DenseNet: A Simple PyTorch Implementation


Paper: Densely Connected Convolutional Networks. Original code: https://github.com/liuzhuang13/DenseNet

(From the paper) Network architecture: [figure omitted]

(From the paper) Detailed structure: [table omitted]

(From the paper) Illustration: [figure omitted]
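Before the full implementation below, the core idea is worth stating on its own: in a dense block, layer l receives the channel-wise concatenation of all earlier feature maps, so with growth rate k the channel count grows linearly (c_l = c_0 + l·k). A minimal sketch of that idea (my own illustration, not the paper's code; the growth rate of 3 is arbitrary):

```python
import torch
import torch.nn as nn

growth_rate = 3                      # hypothetical growth rate k
x = torch.randn(1, 3, 8, 8)          # c_0 = 3 input channels
features = [x]
for l in range(4):
    in_ch = sum(f.size(1) for f in features)          # c_0 + l * k
    conv = nn.Conv2d(in_ch, growth_rate, 3, padding=1)
    # each layer sees the concatenation of everything before it
    features.append(conv(torch.cat(features, dim=1)))
out = torch.cat(features, dim=1)     # 3 + 4 * 3 = 15 channels
print(out.shape)
```

The implementation below follows the same growth pattern, with each block's growth rate equal to its input channel count.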

"""
2021年09月27日21:56:08:DenseNet完成, 细节其实可以修改下,
尤其是一个block中, 维度的合并,以及1*1降维的维数可以指定一下
简易版本:自己按照论文以自己的理解写的,有不足之处还请见谅,可在评论区交流指正
"""
import os

import torch
import torch.nn as nn

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class DenseNet(nn.Module):
    def __init__(self, n_class):
        super(DenseNet, self).__init__()
        # stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (224 -> 113 -> 56);
        # padding=4 keeps the sizes so the final 7x7 average pool sees a 7x7 map
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 3, 7, stride=2, padding=4),
            nn.MaxPool2d(3, 2),
        )

        # each DenseBlock multiplies its input channel count by 5 (see DenseBlock)
        self.layer_dense2 = DenseBlock(3)    # 3 -> 15 channels
        self.transition2 = self.Transition_Layer(15, 6)

        self.layer_dense3 = DenseBlock(6)    # 6 -> 30 channels
        self.transition3 = self.Transition_Layer(30, 12)

        self.layer_dense4 = DenseBlock(12)   # 12 -> 60 channels
        self.transition4 = self.Transition_Layer(60, 120)

        self.layer_pool5 = nn.AvgPool2d(7, 7)  # global 7x7 average pool -> 1x1

        self.Linear = nn.Linear(120, n_class)

    def Transition_Layer(self, in_, out):
        """
        Changes the channel count (up or down) with a BN-ReLU-1x1 conv,
        then halves the spatial size with 2x2 average pooling.
        :param in_: input channels
        :param out: output channels
        :return: nn.Sequential transition layer
        """
        transition = nn.Sequential(
            nn.BatchNorm2d(in_),
            nn.ReLU(),
            nn.Conv2d(in_, out, 1),
            nn.AvgPool2d(2, 2),
        )
        return transition

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer_dense2(x)
        x = self.transition2(x)

        x = self.layer_dense3(x)
        x = self.transition3(x)

        x = self.layer_dense4(x)
        x = self.transition4(x)

        x = self.layer_pool5(x)

        x = x.view(x.size(0), -1)  # flatten (batch, 120, 1, 1) -> (batch, 120)

        x = self.Linear(x)
        return x


class DenseBlock(nn.Module):
    def __init__(self, in_channel):
        """
        每次卷积输入输出的模块维度是相同的, 最后拼接在一起
        :param in_channel:输入维度, 输出维度相同
        """
        super(DenseBlock, self).__init__()
        self.d1 = self.Conv_Block(in_channel, in_channel)
        self.d2 = self.Conv_Block(2 * in_channel, in_channel)
        self.d3 = self.Conv_Block(4 * in_channel, in_channel)
        self.d4 = self.Conv_Block(8 * in_channel, in_channel)

    @staticmethod
    def Conv_Block(in_channel, out):
        Conv = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, out, 1),

            nn.BatchNorm2d(out),
            nn.ReLU(),
            nn.Conv2d(out, out, 3, padding=1),
        )
        return Conv

    def forward(self, x):
        # note: x_cat2 and x_cat3 re-include x1 / x_cat1, so some feature maps
        # appear more than once; the counts still match d2-d4's expected inputs
        x1 = self.d1(x)
        x_cat1 = torch.cat((x, x1), dim=1)                   # 2 * in_channel

        x2 = self.d2(x_cat1)
        x_cat2 = torch.cat((x2, x_cat1, x1), dim=1)          # 4 * in_channel

        x3 = self.d3(x_cat2)
        x_cat3 = torch.cat((x3, x_cat2, x_cat1, x1), dim=1)  # 8 * in_channel

        x4 = self.d4(x_cat3)

        x = torch.cat((x4, x3, x2, x1, x), dim=1)            # 5 * in_channel

        return x


if __name__ == '__main__':
    inputs = torch.randn(10, 3, 224, 224)
    model = DenseNet(n_class=6)
    outputs = model(inputs)
    print(model)
    print(outputs.shape)  # torch.Size([10, 6])
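As the module docstring admits, the concatenation inside a block could be refined: x_cat2 and x_cat3 re-include x1 and x_cat1, so some feature maps are duplicated. A sketch of the standard DenseNet connectivity, where layer l sees exactly one copy of each earlier output, is below (my own variant, not the author's code; since it also returns 5 * in_channel channels, it would slot into the same Transition layers):

```python
import torch
import torch.nn as nn

class StandardDenseBlock(nn.Module):
    """Dense block where layer l sees exactly one copy of each earlier output."""
    def __init__(self, in_channel, n_layers=4):
        super().__init__()
        growth = in_channel  # growth rate = input width, matching the blog's setup
        self.layers = nn.ModuleList()
        for l in range(n_layers):
            in_ch = in_channel + l * growth  # channels grow linearly, no duplicates
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(in_ch), nn.ReLU(),
                nn.Conv2d(in_ch, growth, 1),           # 1x1 bottleneck
                nn.BatchNorm2d(growth), nn.ReLU(),
                nn.Conv2d(growth, growth, 3, padding=1),
            ))

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            features.append(layer(torch.cat(features, dim=1)))
        return torch.cat(features, dim=1)  # 5 * in_channel channels

block = StandardDenseBlock(3)
y = block(torch.randn(2, 3, 56, 56))
print(y.shape)  # torch.Size([2, 15, 56, 56])
```

The trade-off: the duplicated version feeds each conv a wider input (up to 8× vs. 4× here), costing extra parameters without adding new information.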

