栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

MLP-Mixer的Pytorch实现和分析(一)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

MLP-Mixer的Pytorch实现和分析(一)

        MLP-Mixer网络结构分析:MLP-Mixer: An all-MLP Architecture for Vision_HHzdh的博客-CSDN博客

        Mixer的Pytorch代码实现的难点在于矩阵旋转,我们使用einops中的Rearrange实现矩阵旋转,同时使用torchsummary和tensorboardX来查看网络结构。

 1、导入相关库

import torch
import torch.nn as nn
from torchsummary import summary
from torch.nn import Conv2d
from einops.layers.torch import Rearrange, Reduce
from tensorboardX import SummaryWriter

 2、首先实现MLP(多层感知机),即上结构的MLP1和MLP2

        可以看到MLP的结构如下所示:Fully-connected、GELU、Fully-connected。 

# 这个是两层FC加一个激活函数的mlp block
# 因为有两个mixing,进出的维度都不变,只是中间全连接层的神经元数量不同
# 定义多层感知机
class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,dropout=0.):
        super().__init__()
        self.net=nn.Sequential(
            #由此可以看出 FeedForward 的输入和输出维度是一致的
            nn.Linear(dim,hidden_dim),
            #激活函数
            nn.GELU(),
            #防止过拟合
            nn.Dropout(dropout),
            #重复上述过程
            nn.Linear(hidden_dim,dim),
            nn.Dropout(dropout)
        )
    def forward(self,x):
        x=self.net(x)
        return x
 

 2、接着来实现Mixer Block

class MixerBlock(nn.Module):
    def __init__(self,dim,num_patch,token_dim,channel_dim,dropout=0.):
        super().__init__()
        self.token_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch,token_dim,dropout),
            Rearrange('b d n -> b n d')

         )
        self.channel_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            FeedForward(dim,channel_dim,dropout)
        )
    def forward(self,x):
        x = x+self.token_mixer(x)
        x = x+self.channel_mixer(x)
        return x

 3、搭建MLPMixer

class MLPMixer(nn.Module):
    def __init__(self,in_channels,dim,num_classes,patch_size,image_size,depth,token_dim,channel_dim,dropout=0.):
        super().__init__()
        assert image_size%patch_size==0
        self.num_patches=(image_size//patch_size)**2   # (224/16)**2=196
        # embedding 操作,看见没用卷积来分成一小块一小块的
        # 通过embedding可以将这张3*224*224的图片转换为Channel*Patches=512*196,再通过Rearrange转为196*512
        self.to_embedding=nn.Sequential(Conv2d(in_channels=in_channels,out_channels=dim,kernel_size=patch_size,stride=patch_size),
            Rearrange('b c h w -> b (h w) c')
        )

        # 输入为196*512的table
        # 以下为token-mixing MLPs(MLP1)和channel-mixing MLPs(MLP2)各一层
        self.mixer_blocks=nn.ModuleList([])
        for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim,self.num_patches,token_dim,channel_dim,dropout))

        #
        self.layer_normal=nn.LayerNorm(dim)

        #
        self.mlp_head=nn.Sequential(
            nn.Linear(dim,num_classes)
        )
    def forward(self,x):
        x = self.to_embedding(x)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        x = self.layer_normal(x)
        x = x.mean(dim=1)

        x = self.mlp_head(x)
        return x

 4、测试MLPMixer

#测试Mlp-Mixer
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MLPMixer(in_channels=3, dim=512, num_classes=1000, patch_size=16, image_size=224, depth=1, token_dim=256,
                     channel_dim=2048).to(device)
    summary(model,(3,224,224))

    # torch.Tensor([1, 2, 3, 4, 5, 6])
    inputs = torch.Tensor(1, 3, 224, 224)
    inputs = inputs.to(device)
    print(inputs.shape)

    # 将model保存为graph
    with SummaryWriter(log_dir='logs', comment='model') as w:
        w.add_graph(model, (inputs,))
        print("success")

         得到网络结构如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 512, 14, 14]         393,728
         Rearrange-2             [-1, 196, 512]               0
         LayerNorm-3             [-1, 196, 512]           1,024
         Rearrange-4             [-1, 512, 196]               0
            Linear-5             [-1, 512, 256]          50,432
              GELU-6             [-1, 512, 256]               0
           Dropout-7             [-1, 512, 256]               0
            Linear-8             [-1, 512, 196]          50,372
           Dropout-9             [-1, 512, 196]               0
      FeedForward-10             [-1, 512, 196]               0
        Rearrange-11             [-1, 196, 512]               0
        LayerNorm-12             [-1, 196, 512]           1,024
           Linear-13            [-1, 196, 2048]       1,050,624
             GELU-14            [-1, 196, 2048]               0
          Dropout-15            [-1, 196, 2048]               0
           Linear-16             [-1, 196, 512]       1,049,088
          Dropout-17             [-1, 196, 512]               0
      FeedForward-18             [-1, 196, 512]               0
       MixerBlock-19             [-1, 196, 512]               0
        LayerNorm-20             [-1, 196, 512]           1,024
           Linear-21                 [-1, 1000]         513,000

        通过tensorboardX可以查看网络

 5、网络输入输出框图

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/329281.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号