栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch简单案例:y=wx+b参数训练

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch简单案例:y=wx+b参数训练

简单练习一下Pytorch:目标方程是y=2x+3,使用六个数据样本进行100次迭代

// An highlighted block
import torch
import numpy
import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0,4.0,5.0,6.0]
y_data=[5.0,7.0,9.0,11.0,13.0,15.0]
w=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
w_list=[]
b_list=[]

def pred(x):
    return x*w+b

def loss(x,y):
    l=(pred(x)-y)**2
    return l

for epoch in range(100):
    w_list.append(w.data.item())
    b_list.append(b.data.item())
    for xs,ys in zip(x_data,y_data):
        lo=loss(xs,ys)
        lo.backward()
        w.data=w.data-0.01*w.grad.item()
        b.data=b.data-0.01*b.grad.item()
        w.grad.data.zero_()
        b.grad.data.zero_()
    print("epoch:",epoch,"loss:",lo.item())

print("预测值为:",pred(5).item())

plt.plot(numpy.arange(100),w_list,color="blue",label='parameter_w')
plt.plot(numpy.arange(100),b_list,color="red",label='parameter_b')
plt.show()

测试结果:蓝线是w,最终逼近2,红线是b,最终逼近3

上面的模型是手动进行梯度下降,下面这个是自动实现的梯度下降

import torch
import numpy
import matplotlib.pyplot as plt

x_data=torch.tensor([[1.0],[2.0],[3.0],[4.0]])
y_data=torch.tensor([[6.0],[9.0],[12.0],[15.0]])
loss_list=[]
#创建模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear=torch.nn.Linear(1,1)

    def forward(self,x):
        y_pred=self.linear(x)
        return y_pred
#损失函数和优化器
model=LinearModel()
criterion=torch.nn.MSELoss(reduction='mean')
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
#模型迭代,3000次
for epoch in range(3000):
    y_pre=model(x_data)
    loss=criterion(y_pre,y_data)
    loss_list.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch_list=numpy.arange(3000)
print(model.linear.weight.item(),model.linear.bias.item())#输出权值和偏差
test_set=torch.tensor([[8.0]])
print(model(test_set).item())
#可视化训练过程
plt.plot(epoch_list,loss_list,label='loss',color='blue')
plt.show()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331103.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号