栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

实现线性回归《PyTorch深度学习实践》

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

实现线性回归《PyTorch深度学习实践》

pytorch写神经网络:
1.构造数据集 -----x,y必须是矩阵
2.设计模型#计算y_hat-------构造计算图
3.构造损失函数和优化器(API)
4.写训练循环(前馈,反馈,更新)
在minibatch里
注:
1.向量是一维张量,矩阵是二维张量,
2.标量l,才能backward
3.一个样本表示一个及一个以上的feature,行表示样本,列表示feature
4.对象():说明是可调用对象,callable
forward 来自torch.nn.Module是callable,重写的forward也是callable
5.继承自Module可以进行反向传播
6.增加epoch次数,注意,训练集上loss减少,测试集上loss增加(过拟合)

import torch
# x,y是矩阵,3行1列----3个样本,特征feature为1
# minibatch 为3
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
 
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)#构造linear对象,参数包括w,b (1,1)为输入,输出样本维度(特征数量)
    def forward(self, x): #重写,覆盖Module里的forward
        y_pred = self.linear(x)
        return y_pred
 
model = LinearModel() #实例化
 
 #损失,优化器实例化
# criterion = torch.nn.MSELoss(size_average = False) #size_average 取均值1/n
criterion = torch.nn.MSELoss(reduction = 'sum') #继承自Module
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # model.parameters()自动检查并拿到其所有成员需要训练的权重
 	#优化器对象,可对权重进行优化			    优化器不继承自Module,不会构造计算图
# 1.前馈(y_hat,loss),2.反馈(backward),3.更新
for epoch in range(100):
    y_pred = model(x_data) # y_hat
    loss = criterion(y_pred, y_data) # loss
    print(epoch, loss.item())
 
    optimizer.zero_grad() # 梯度清零
    loss.backward() # backward
    optimizer.step() # 更新w,b
 #打印出w,b
print('w = ', model.linear.weight.item())#w,b是矩阵
print('b = ', model.linear.bias.item())
 
 #对训练出的模型进行测试
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

补充:

def func(*args,x,y):
	print(args) #args里是传入的元组(1,2,3,4)
func(1,2,3,4,x=3,y=5)
# (1,2,3,4)
def func(*args,**kwargs):
	print(kwargs) #字典
func(1,2,3,4,x=3,y=5)
#{'x':3,'y':5}

__call __:使对象有类似函数的功能

class Foobar:
	def __init__(self):
		pass
	def __call__(self,*args,**kwargs):
		print("hello"+str(args[0]))
foobar =Foobar()
foobar(1,2,3)
#hello1
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/840658.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号