
"Attention Is All You Need": A PyTorch Implementation



Full code for the PyTorch implementation:
  • transformer_from_scratch.py
Self Attention diagram

Code implementation
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        """
        :param embed_size: int, dimension of the token embeddings
        :param heads: int, number of attention heads
        """
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads."

        # Q, K, V projection layers (applied per head, so they act on head_dim)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Output projection
        self.fc_out = nn.Linear(self.head_dim * heads, embed_size)

    def forward(self, values, keys, queries, mask):
        """
        :param values:  (N, value_len, embed_size)
        :param keys:    (N, key_len, embed_size)
        :param queries: (N, query_len, embed_size)
        :param mask:    broadcastable to (N, heads, query_len, key_len), or None
        :return out:    (N, query_len, embed_size)
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Per-head linear projections of V, K, Q
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # attention shape: (N, heads, query_len, key_len)
        # note: the paper scales by sqrt(head_dim); this implementation uses sqrt(embed_size)
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # value_len always equals key_len
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(N, query_len, self.heads * self.head_dim)
        # out shape after reshape: (N, query_len, embed_size)
        out = self.fc_out(out)

        return out
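
A quick shape check for the module above (a minimal sketch; embed_size, heads, and the input sizes are arbitrary illustration values):

attn = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)      # (N, seq_len, embed_size)
out = attn(x, x, x, mask=None)   # self-attention: values = keys = queries
print(out.shape)                 # torch.Size([2, 10, 256])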

Notes
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

if mask is not None:
    energy = energy.masked_fill(mask == 0, float("-1e20"))

❗ The mask's shape must either match energy's shape or be broadcastable to it (i.e., the trailing dimensions must line up).

# Demo
a = torch.arange(12).reshape(3, 4)
mask = torch.tril(torch.ones((3, 4)))
# mask shape: (3, 4) -- same shape as a
a.masked_fill(mask == 0, -1)
Out[11]:
tensor([[ 0, -1, -1, -1],
        [ 4,  5, -1, -1],
        [ 8,  9, 10, -1]])
# mask shape: (4,) -- broadcast against the last dimension of a
m = torch.arange(4)
a.masked_fill(m == 0, -1)
Out[15]:
tensor([[-1,  1,  2,  3],
        [-1,  5,  6,  7],
        [-1,  9, 10, 11]])
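
The 4-D masks used later by the full model rely on exactly this broadcasting behaviour. A small sketch (shapes chosen arbitrarily; pad index 0, as in the demo at the end of this article):

src = torch.tensor([[1, 5, 6, 0, 0],
                    [1, 8, 7, 3, 2]])
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)              # (2, 1, 1, 5), as in make_src_mask below
trg_mask = torch.tril(torch.ones(5, 5)).expand(2, 1, 5, 5)   # (2, 1, 5, 5), as in make_trg_mask below
energy = torch.randn(2, 8, 5, 5)                             # (N, heads, query_len, key_len)
print(energy.masked_fill(src_mask == 0, float("-1e20")).shape)  # torch.Size([2, 8, 5, 5])
print(energy.masked_fill(trg_mask == 0, float("-1e20")).shape)  # torch.Size([2, 8, 5, 5])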

TransformerBlock diagram

Code implementation
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
				"""
				
				:param embed_size:  int
				:param heads:       int
				:param dropout:     float
				:param forward_expansion:   Coefficient of expansion in Feed Forward Module
				        """
				super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        # Norm
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        # Multi-Head Attention
        attention = self.attention(value, key, query, mask)
        # Add & Norm
        x = self.dropout(self.norm1(attention + query))
        # Feed Forward
        forward = self.feed_forward(x)
        # Add & Norm
        out = self.dropout(self.norm2(forward + x))

        return out
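
As a quick sanity check (illustrative sizes only), a TransformerBlock maps an (N, seq_len, embed_size) tensor to a tensor of the same shape:

block = TransformerBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4)
x = torch.randn(2, 10, 256)
print(block(x, x, x).shape)   # torch.Size([2, 10, 256]); mask defaults to None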

Encoder diagram

Code implementation
class Encoder(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
    ):
				"""
				
				:param src_vocab_size:  size of source vocabulary
				:param embed_size:      
				:param num_layers:
				:param heads:
				:param forward_expansion:
				:param dropout:
				:param device:
				:param max_length:      max length of sequence
				        """
				super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.dropout = nn.Dropout(dropout)
        self.device = device
        # Embedding Layer
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Transformer Layer
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout,
                    forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, mask=None):
        N, seq_length = x.shape
        # positions = [0, 1, ..., seq_length-1] for each of the N sequences
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # get input for TransformerBlock
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # num_layers stacked TransformerBlocks
        for layer in self.layers:
            # Encoder -> value=key=query
            out = layer(out, out, out, mask)

        return out
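
A minimal shape check for the Encoder (vocabulary size and hyper-parameters are arbitrary here; device="cpu" keeps the sketch self-contained):

enc = Encoder(src_vocab_size=10, embed_size=256, num_layers=2, heads=8,
              forward_expansion=4, dropout=0.1, device="cpu", max_length=100)
tokens = torch.randint(0, 10, (2, 9))   # (N, seq_len) of token ids
print(enc(tokens).shape)                # torch.Size([2, 9, 256]) = (N, seq_len, embed_size)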

DecoderBlock diagram

Code implementation
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, device):
"""

:param embed_size:
:param heads:
:param dropout:
:param forward_expansion:
:param device:
        """
super(DecoderBlock, self).__init__()
        # Masked Multi-Head Attention
        self.attention = SelfAttention(embed_size, heads)
        # Norm
        self.norm = nn.LayerNorm(embed_size)
        # TransformerBlock
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout=dropout, forward_expansion=forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # get query for TransformerBlock
        attention = self.attention(x, x, x, trg_mask)
        # Add & Norm
        query = self.dropout(self.norm(attention + x))
        # TransformerBlock with query from target
        out = self.transformer_block(value, key, query, src_mask)

        return out
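
A small sketch of a single DecoderBlock call (sizes chosen arbitrarily; the encoder output is simulated with random values):

dec_block = DecoderBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4, device="cpu")
x = torch.randn(2, 7, 256)                                   # target-side representations
enc_out = torch.randn(2, 9, 256)                             # stand-in for encoder output
trg_mask = torch.tril(torch.ones(7, 7)).expand(2, 1, 7, 7)   # causal mask over target positions
out = dec_block(x, enc_out, enc_out, src_mask=None, trg_mask=trg_mask)
print(out.shape)   # torch.Size([2, 7, 256])
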
Decoder diagram

Code implementation
class Decoder(nn.Module):
    def __init__(
            self,
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            dropout,
            forward_expansion,
            device,
            max_length,

    ):
				"""
				:param trg_vocab_size:  size of target vocabulary
				:param embed_size:
				:param num_layers:
				:param heads:
				:param dropout:
				:param forward_expansion:
				:param device:
				:param max_length:
				        """
				super(Decoder, self).__init__()
        self.device = device

        # Embedding Layer
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # DecoderBlock Layer
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_size,
                    heads,
                    dropout,
                    forward_expansion,
                    device,
                )
                for _ in range(num_layers)
            ]
        )

        # Other Layers (Dropout Layer, Linear Layer)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape

        # positions = [0, 1, ..., seq_length-1] for each of the N sequences
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # get input for DecoderBlock
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # num_layers stacked DecoderBlocks
        for layer in self.layers:
            out = layer(out, enc_out, enc_out, src_mask, trg_mask)

        # Linear Layer
        out = self.fc_out(out)

        return out

Transformer diagram

Code implementation
class Transformer(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size=256,
            num_layers=2,
            heads=8,
            dropout=0.1,
            forward_expansion=4,
            device="cuda",
            max_length=100,
    ):
        """
        :param src_vocab_size:
        :param trg_vocab_size:
        :param src_pad_idx:
        :param trg_pad_idx:
        :param embed_size:
        :param num_layers:
        :param heads:
        :param dropout:
        :param forward_expansion:
        :param device:
        :param max_length:
        """
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            dropout,
            forward_expansion,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """

        :param src: (N,src_len)
        :return:
        src_mask:   (N,1,1,src_len)
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        """
        :param trg: (N,trg_len)
        :return:    (N, 1, trg_len, trg_len)
        """
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )

        return trg_mask.to(self.device)

    def forward(self, src, trg):
        # Get Mask for source and target
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        # Encoder and Decoder
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)

        return out
Demo code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.tensor(
    [
        [1, 5, 6, 4, 3, 9, 5, 2, 0],
        [1, 8, 7, 3, 4, 5, 6, 7, 2]
    ]
).to(device)

trg = torch.tensor(
    [
        [1, 7, 4, 3, 5, 9, 2, 0],
        [1, 5, 6, 2, 4, 7, 6, 2]
    ]
).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx).to(device)

out = model(x, trg)
print(out.shape)
Output

torch.Size([2, 8, 10])
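
The logits have shape (N, trg_len, trg_vocab_size): one score per vocabulary entry for every target position. For illustration only (the model here is untrained), predicted token ids can be read off with an argmax over the last dimension:

pred = out.argmax(dim=-1)   # (N, trg_len) predicted token ids
print(pred.shape)           # torch.Size([2, 8])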

References

Attention Is All You Need: https://arxiv.org/abs/1706.03762
A good blog post on Transformers: http://peterbloem.nl/blog/transformers
Einsum Is All You Need: NumPy, PyTorch and TensorFlow: https://youtu.be/pkVwUVEHmfI
PyTorch Transformers from Scratch (Attention is all you need): https://youtu.be/U0s0f995w14

This article is for technical exchange and sharing only; unauthorized use for any other purpose is strictly prohibited. If you find problems with the implementation above, or would like to discuss further, contact 1377157216@qq.com.
