栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

《PyTorch深度学习实战》第五讲

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

《PyTorch深度学习实战》第五讲

Linear Regression with PyTorch 传送门:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5

目标:PyTorch实现线性模型预测 详细标注
import torch
import matplotlib.pyplot as plt

# 如果运行图像时报错,请添加如下两行代码可解决问题
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Step:
# 1. Prepare dataset
# 2. Design model using Class
# 3. Construct loss and optimizer
# 4. Training cycle

# 1.准备数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# 2.设计模型
# torch.nn.Module为torch父类,在使用过程中,我们创建子类继承父类,可实现多功能,方便快捷
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__() # 继承父类
        self.linear = torch.nn.Linear(1, 1) # 构造对象, 包含权重和偏置,参数:输入参数维度和输出参数维度
	# 前向传播
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
# 运行模型
model = LinearModel()

# 3.构造损失函数和训练循环
criterion = torch.nn.MSELoss(size_average=False)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

# 4.训练周期:前馈反馈更新
res = []
for epoch in range(100): # 训练100次
    y_pred = model(x_data) # 模型预测
    loss = criterion(y_pred, y_data) # loss计算
    res.append(loss.item()) # 绘图,loss数据收集
    print(epoch, loss) # 打印,训练次数和loss

    optimizer.zero_grad() # 梯度清零 # 将模块中的梯度值清零
    loss.backward() #反向传播
    optimizer.step() # 权重更新

print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

# 绘图
plt.plot(list(range(len(res))), res)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
可选择多个优化器: torch.optim.Adagrad torch.optim.Adam torch.optim.Adamax torch.optim.ASGD torch.optim.LBFGS torch.optim.RMSprop torch.optim.Rprop torch.optim.SGD 尝试不同的优化器得到不同的结果 运行结果:

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/325836.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号