栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

一些Attention代码解释

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

一些Attention代码解释

。这个注意力的应用方式我也没在其他地方遇到过 等遇到再来填坑。

class SingleLayerAttention(nn.Module):

def __init__(self, d_model, d_k, attn_dropout 0.1):
 super(SingleLayerAttention, self).__init__()
 self.dropout nn.Dropout(attn_dropout)
 self.softmax nn.Softmax(dim 2)
 # self.linear nn.Linear(2*d_k, d_k)
 self.weight nn.Parameter(torch.FloatTensor(d_k, 1)) # 利用批处理机制对元素进行处理
 self.act nn.LeakyReLU()
 init.xavier_normal(self.weight)
def forward(self, q, k, v, attn_mask None):
 # q (mb_size x len_q x d_k)
 # k (mb_size x len_k x d_k)
 # v (mb_size x len_v x d_v)
 mb_size, len_q, d_k q.size()
 mb_size, len_k, d_k k.size()
 q q.unsqueeze(2).expand(-1, -1, len_k, -1) 
 k k.unsqueeze(1).expand(-1, len_q, -1, -1)
 x q - k
 attn self.act(torch.matmul(x, self.weight).squeeze(3)) # mb_size * len_q * len_k
 if attn_mask is not None: # mb_size * len_q * len_k
 assert attn_mask.size() attn.size()
 attn_mask attn_mask.eq(0).data
 attn.data.masked_fill_(attn_mask, -float( inf )) # 广播掩码
 attn self.softmax(attn)
 attn.data.masked_fill_(attn_mask, 0)
 attn self.dropout(attn)
 output torch.bmm(attn, v)
 return output, attn
4. MultiHeadAttention层

多头注意力机制 调用点乘注意力机制

class MultiHeadAttention(nn.Module):
 Multi-Head Attention module 
 def __init__(self, n_head, d_input, d_model, d_input_v None, dropout 0.1):
 super(MultiHeadAttention, self).__init__()
 self.n_head n_head
 d_k, d_v d_model//n_head, d_model//n_head
 self.d_k d_k
 self.d_v d_v
 if d_input_v is None:
 d_input_v d_input
 self.w_qs nn.Parameter(torch.FloatTensor(n_head, d_input, d_k))
 self.w_ks nn.Parameter(torch.FloatTensor(n_head, d_input, d_k))
 self.w_vs nn.Parameter(torch.FloatTensor(n_head, d_input_v, d_v))
 self.attention DotProductAttention(d_model)
 self.proj Linear(n_head*d_v, d_model)
 self.dropout nn.Dropout(dropout)
 init.xavier_normal(self.w_qs)
 init.xavier_normal(self.w_ks)
 init.xavier_normal(self.w_vs)
 def forward(self, q, k, v, attn_mask None):
 d_k, d_v self.d_k, self.d_v
 n_head self.n_head
 # residual q
 mb_size, len_q, d_input q.size()
 mb_size, len_k, d_input k.size()
 mb_size, len_v, d_input_v v.size()
 # treat as a (n_head) size batch. 依照多头数量对数据形式进行处理 - n_head x (mb_size*len_q) x d_model
 q_s q.repeat(n_head, 1, 1).view(n_head, -1, d_input) # n_head x (mb_size*len_q) x d_model
 k_s k.repeat(n_head, 1, 1).view(n_head, -1, d_input) # n_head x (mb_size*len_k) x d_model
 v_s v.repeat(n_head, 1, 1).view(n_head, -1, d_input_v) # n_head x (mb_size*len_v) x d_model
 # treat the result as a (n_head * mb_size) size batch - 理解 d_model//n_head 结果最后一维是 d_k
 q_s torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k) # (n_head*mb_size) x len_q x d_k
 k_s torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k) # (n_head*mb_size) x len_k x d_k
 v_s torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v) # (n_head*mb_size) x len_v x d_v
 # perform attention, result size (n_head * mb_size) x len_q x d_v 处理完多头操作 进行注意力机制计算
 outputs, attns self.attention(q_s, k_s, v_s, attn_mask attn_mask.repeat(n_head, 1, 1))
 # back to original mb_size batch, result size mb_size x len_q x (n_head*d_v)
 outputs torch.cat(torch.split(outputs, mb_size, dim 0), dim -1) 
 # project back to residual size 返回原始维度
 outputs self.proj(outputs)
 outputs self.dropout(outputs)
 # return self.layer_norm(outputs residual), attns 返回进行层正则化(outputs residual)
 return outputs, attns
5. BiAttention层
class BiAttention(nn.Module):
 def __init__(self, input_size, dropout):
 super().__init__()
 self.dropout nn.Dropout(dropout)
 self.input_linear nn.Linear(input_size, 1, bias False)
 self.memory_linear nn.Linear(input_size, 1, bias False)
 self.dot_scale nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5)))
 self.softmax nn.Softmax(dim -1)
def forward(self, input, memory, q_mask):
 Args:
 input: batch_size * doc_word_len * emb_size
 memory: h_question_word batch_size * ques_len * emb_size
 q_mask:
 Returns:
 bsz, input_len, memory_len input.size(0), input.size(1), memory.size(1)
 input self.dropout(input)
 memory self.dropout(memory)
 input_dot self.input_linear(input)
 memory_dot self.memory_linear(memory).view(bsz, 1, memory_len)
 cross_dot torch.bmm(input * self.dot_scale, memory.permute(0, 2, 1).contiguous())
 # input先进行缩放 -- [batch_size * doc_word_len * ques_len]
 att input_dot memory_dot cross_dot # 注意力矩阵
 att att - 1e30 * (1 - q_mask[:, None]) # None可以在所处维度中多一维 处理问题中padding字符
 weight_one self.softmax(att) # 对查询做归一化, 获得文档对问题注意力权重矩阵
 output_one torch.bmm(weight_one, memory) # 获得文档单词对问题的权重
 weight_two self.softmax(att.max(dim -1)[0]).view(bsz, 1, input_len) # 获得问题对文档注意力权重矩阵
 output_two torch.bmm(weight_two, input) # 问题在每个向量上的权重
 return torch.cat([input, output_one, input*output_one, output_two*output_one], dim -1)
 # input*output_one 每个单词的权重*单词 output_two*output_one 问题在每个单词的权重*每个单词
 # 拼接 原始文档 每个单词的权重 每个单词的权重*单词 问题在每个单词的权重*每个单词
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267549.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号