
Neural Network Layers and Blocks (code from Li Mu's *Dive into Deep Learning*)



import torch
from torch import nn
from torch.nn import functional as F

# A simple MLP built with nn.Sequential: Linear -> ReLU -> Linear
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
X = torch.rand(size=(5, 8))
print(net(X))

# Parameter access: net[2] is the last Linear layer
print(net[2].state_dict())
print(type(net[2].bias), net[2].bias, net[2].bias.data)
print(net[2].weight.grad is None)  # True: no backward pass has run yet

# Named parameters of one layer, then of the whole network.
# nn.Linear stores weight as (out_features, in_features), so the
# shapes are (16, 8) and (1, 16), not (8, 16) and (16, 1).
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
print(net.state_dict()['2.bias'].data)
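The weight shapes can be verified directly: PyTorch's nn.Linear stores its weight as (out_features, in_features), so a Linear(8, 16) layer holds a (16, 8) weight. A minimal check:

```python
import torch
from torch import nn

layer = nn.Linear(8, 16)
print(layer.weight.shape)  # torch.Size([16, 8]): (out_features, in_features)
print(layer.bias.shape)    # torch.Size([16])

# The forward pass computes X @ weight.T + bias
X = torch.rand(5, 8)
print(layer(X).shape)      # torch.Size([5, 16])
```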

# Nested blocks: block2 chains four copies of block1
def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}', block1())
    return net

rgnet = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), block2(), nn.Linear(4, 1))
print(rgnet(X))
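Because the blocks nest, individual parameters inside rgnet can be reached by chained indexing. A standalone sketch (rebuilding the same nested net so it runs on its own):

```python
import torch
from torch import nn

def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}', block1())
    return net

rgnet = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), block2(), nn.Linear(4, 1))

# Chained indexing walks the nesting: rgnet[2] is block2's Sequential,
# rgnet[2][0] is the first block1, rgnet[2][0][0] its first Linear(4, 8)
print(rgnet)
print(rgnet[2][0][0].bias.data)
```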

# A custom block: subclass nn.Module and define forward
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(8, 16)
        self.out = nn.Linear(16, 1)

    def forward(self, X):
        return self.out(F.relu(self.hidden(X)))

net = MLP()
print(net(X))

# Reimplementing Sequential: storing child blocks in self._modules
# lets PyTorch track their parameters automatically
class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self._modules[str(idx)] = module

    def forward(self, X):
        for block in self._modules.values():
            X = block(X)
        return X

net = MySequential(nn.Linear(8, 20), nn.ReLU(), nn.Linear(20, 1))
print(net(X))

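Because MySequential registers each child in self._modules, PyTorch's bookkeeping sees them: named_parameters() and state_dict() behave just as they do for nn.Sequential. A quick self-contained check:

```python
import torch
from torch import nn

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self._modules[str(idx)] = module  # registration happens here

    def forward(self, X):
        for block in self._modules.values():
            X = block(X)
        return X

net = MySequential(nn.Linear(8, 20), nn.ReLU(), nn.Linear(20, 1))
# Both Linear layers' parameters were registered automatically
# (the ReLU at index 1 has none)
print([name for name, _ in net.named_parameters()])
# ['0.weight', '0.bias', '2.weight', '2.bias']
```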
# A block with a constant (untrained) weight and control flow in forward
class FixHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor, not nn.Parameter: never updated during training
        self.rand_weight = torch.rand((8, 8), requires_grad=False)
        self.linear = nn.Linear(8, 8)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # Reuse the same layer: both applications share parameters
        X = self.linear(X)
        # Control flow: halve X until its L1 norm is at most 1
        while X.abs().sum() > 1:
            X = X / 2
        return X.sum()

net = FixHiddenMLP()
print(net(X))
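One consequence worth noting: since rand_weight is a plain tensor rather than an nn.Parameter or a registered buffer, it never appears in state_dict(), so it would be lost when saving and reloading the model. A small check, redefining the class so the snippet stands alone:

```python
import torch
from torch import nn
from torch.nn import functional as F

class FixHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((8, 8), requires_grad=False)
        self.linear = nn.Linear(8, 8)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        X = self.linear(X)
        while X.abs().sum() > 1:
            X = X / 2
        return X.sum()

net = FixHiddenMLP()
# Only the shared Linear layer's parameters are tracked;
# rand_weight is absent from the state dict
print(list(net.state_dict().keys()))  # ['linear.weight', 'linear.bias']
```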

# Mixing custom and built-in blocks
class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 20), nn.ReLU(),
                                 nn.Linear(20, 8), nn.ReLU())
        self.linear = nn.Linear(8, 8)

    def forward(self, X):
        # Concatenate two 8-dim outputs into a 16-dim result
        Y = self.net(X)
        return torch.cat([self.linear(Y), Y], 1)

net = nn.Sequential(NestMLP(), nn.Linear(16, 8), FixHiddenMLP())
print(net(X))
print(NestMLP()(X))
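The shapes line up through this composition: NestMLP maps 8 to 16 (two 8-dim outputs concatenated), Linear(16, 8) maps back to 8, and FixHiddenMLP reduces to a scalar. A standalone shape check for the NestMLP stage:

```python
import torch
from torch import nn

class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 20), nn.ReLU(),
                                 nn.Linear(20, 8), nn.ReLU())
        self.linear = nn.Linear(8, 8)

    def forward(self, X):
        Y = self.net(X)
        return torch.cat([self.linear(Y), Y], 1)

X = torch.rand(5, 8)
out = NestMLP()(X)
print(out.shape)  # torch.Size([5, 16]): two 8-dim halves concatenated
```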



Reprinted from www.mshxw.com. Original article: https://www.mshxw.com/it/768211.html