栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch搭建一个简单的神经网络(模板)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch搭建一个简单的神经网络(模板)

导库
import torch
import torch.nn as nn
import torch.nn.functional as F
搭建网络
class Model(nn.Module):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
    //可以在这你定义你的链接层
        pass
    def reset_parameters(self):
    //在这里初始化权重
        pass
    def forward(self, inputs):
    //在这里定义前向传播
        pass

推理
model = Model()

model(inputs) # 或model.forward(inputs)
优化器和损失函数定义
optimize=torch.optim.SGD(model.parameters(),lr=0.1)
loss_fn=nn.MSELoss()
训练
for epoch in range(10):
    //清空梯度
    optimize.zero_grad()
    //前向传播
    pred=model(features)
    //损失值
    loss=loss_fn(pred,labels)
    //后向传播
    loss.backward()
    //优化器更新参数
    optimize.step()
    //打印损失
    print(loss)
实例

产生数据集

%matplotlib inline
import random
import torch

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

训练网络

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.dense=nn.Linear(2,1)
    def forward(self, inputs):
        outputs=self.dense(inputs)
        return outputs

model=Model()

optimize=torch.optim.SGD(model.parameters(),lr=0.1)
loss_fn=nn.MSELoss()

for epoch in range(10):
    optimize.zero_grad()
    pred=model(features)
    loss=loss_fn(pred,labels)
    loss.backward()
    optimize.step()
    print(loss)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/767370.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号