栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

编码器-解码器

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

编码器-解码器

一、编码器的基本架构
forward层:传入一个x,输出一个out

from torch import nn
class Encoder(nn.Module):
    def __init__(self,**kwargs):
        super(Encoder, self).__init__(**kwargs)
    def forward(self,x,*args):
        raise NotImplementedError

二、解码器的基本架构
def init_state(self,enc_outputs,*args):利用encoder的输出建立中间状态(decoder的初始状态)
forward()内部传入encoder的输入和decoder的初始状态(随后会不断变化)
这里decoder的初始状态也就是:encoder压缩成的向量,也就是编好的码

class Decoder(nn.Module):
    def __init__(self,**kwargs):
        super(Decoder, self).__init__(**kwargs)
    #enc_outputs是encoder的输出,初始化状态,就是用encoder的东西来转化成直接想要的状态
    def init_state(self,enc_outputs,*args):
        raise NotImplementedError
    #可以有额外的输入,state:一开始从encoder那翻译过来,之后随着forward可以不断变化
    def forward(self,x,state):
        raise NotImplementedError

三、编码器-解码器架构
将上面定义的编码解码器加载进来

class EncoderDecoder(nn.Module):
    def __init__(self,encoder,decoder,**kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder=encoder
        self.decoder=decoder
    #输入编号的码,输入编码器和解码器的输入
    def forward(self,enc_x,dec_x,*args):
        enc_outputs = self.encoder(enc_x,*args)
        dec_state=self.decoder.init_state(enc_outputs,*args)
        return self.decoder(dec_x,dec_state)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/348318.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号