栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch中forward的调用

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch中forward的调用

  • forward不是在模型创建的时候调用,而是在模型调用的时候调用
model = NeuralNet(tr_set.dataset.dim).to(device)  # Construct model and move to device

pred = model(x)  # model是一个可调用对象,即会自动调用__call__函数,这里会调用__call_impl函数,
                 # 其中的forward_call会调用NeuralNet中实现的forward,后续导致的其他调用在下面类中说明
class NeuralNet(nn.Module):
    ''' A simple fully-connected deep neural network '''

    def __init__(self, input_dim):
        super(NeuralNet, self).__init__()

        # Define your neural network here
        # TODO: How to modify this model to achieve better performance?
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # Mean squared error loss
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, x):  # 这里forward不是override的方法,因为编译器这一行没有o图标
        ''' Given input of size (batch_size x input_dim), compute output of the network '''
        return self.net(x).squeeze(1)  
        # 这里self.net是Sequential对象,Sequential类实现了Module类,Module类实现了__call__函![请添加图片描述](https://img-blog.csdnimg.cn/b08b20342690454a88761183bdb4b07a.png)
数,故Module、Sequential实例都是可调用的 从module.py的forward_call(*input, **kwargs)调用过来
        # 这里self.net(x)会调用__call_impl,(self是Sequential类型,会自动调用__call_impl)从而会导致Sequential中的forward被调用,
        # 在Sequential的forward中对每个module(这里指Linear、RELU、Liner)循环调用,因此分别会调用Liner、RELU、Linear的__call_impl,从而它们各自的forward会被调用
        # 这里impl都是一个,但对不同的类通过impl中的forward_call会分别调用每个类自己的forward
    def cal_loss(self, pred, target):
        ''' Calculate loss '''
        # TODO: you may implement L1/L2 regularization here
        return self.criterion(pred, target)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/870154.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号