【Pytorch实现】——nn.Sequential()
import torch
import torch.nn as nn
class MySequential(nn.Module):
def __init__(self,*args):
super().__init__()
# 将args中的层存入有顺序的dict中
for block in args:
self._modules[block] = block
def forward(self,X):
# 从有顺序的dict中逐个拿出字典的values
for block in self._modules.values():
X = block(X)
return X
net = MySequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))
X = torch.randn(2,20)
Y = net(X)
print(Y)
- 经过nn.Sequential()包装后,其实就相当于一个list
- print(net[0])
- print(net[0][0])
- print(net[0][0][0])