General steps for building a neural network:
1. Prepare the dataset
2. Design the model using a class (to compute y_hat)
3. Construct the loss and optimizer (using the PyTorch API)
4. Training cycle (forward to compute the loss, backward to compute the gradients, then update the weights)
Broadcasting:
Linear regression done in mini-batch style.
Step 1
x and y are 3×1 tensors; w is a 1×1 tensor (you can think of it as a matrix) that broadcasting expands to 3×1, so the multiplication is element-wise (corresponding positions), not a matrix product.
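The broadcasting above can be sketched as follows; the concrete values of w and b here are only illustrative:

```python
import torch

x = torch.tensor([[1.0], [2.0], [3.0]])  # 3x1
w = torch.tensor([[2.0]])                # 1x1; broadcast along dim 0 to 3x1
b = torch.tensor([[0.5]])                # 1x1; also broadcast

y_hat = x * w + b  # element-wise multiply, not a matrix product
print(y_hat)       # a 3x1 tensor
```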
Step 2
Key goal: build the computational graph so that PyTorch can compute the gradients automatically.
The affine model y_hat = x*w + b corresponds in PyTorch to the linear unit z = w*x + b; to determine the sizes of w and b, you need to know the dimensions of x and y_hat.
The loss must eventually reduce to a scalar; you cannot call backward() on a vector.
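A minimal sketch of why the loss must be a scalar before calling backward() (the tensor values are made up for illustration):

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

vec = w * 2               # a 2-element vector, not a scalar
try:
    vec.backward()        # fails: backward() needs a scalar output
except RuntimeError as e:
    print("vector backward failed:", e)

loss = (w * 2).sum()      # reduce to a scalar first
loss.backward()           # now autograd works
print(w.grad)             # d(2*w1 + 2*w2)/dw = [2, 2]
```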
nn.Linear handles backward automatically.
bias=True by default, i.e. the layer has a bias term.
self.linear(x): putting parentheses after an object invokes it as a callable object, a common Python idiom.
Example:

class Foobar():
    def __init__(self):
        pass
    def __call__(self, *args, **kwargs):
        print("Hello" + str(args[0]))

def func(*args, **kwargs):
    # Python's variadic parameters: *args collects any number of positional
    # arguments into a tuple; **kwargs collects keyword arguments into a dict
    print(args)
    print(kwargs)

func(1, 2, 4, 3, x=3, y=5)
foobar = Foobar()
foobar(1, 2, 3)

Output:
(1, 2, 4, 3)
{'x': 3, 'y': 5}
Hello1
Step 3: construct the loss function and optimizer
The loss is loss = (y_hat - y)**2, i.e. MSE. size_average=True averages the losses, False sums them; either way the gradient direction is the same, only its scale differs. In mini-batch training, averaging matters when a batch has fewer samples (e.g. the last batch). criterion takes y_hat and y and computes the loss (it must build a computational graph, so it inherits from nn.Module). reduce controls whether to reduce at all; in practice only size_average is worth considering. (Both flags are deprecated in newer PyTorch in favor of the reduction argument.)
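In newer PyTorch the size_average/reduce flags are replaced by the single reduction argument; a small sketch of the two reductions (the values are illustrative):

```python
import torch

y_hat = torch.tensor([[2.1], [3.9], [6.2]])
y = torch.tensor([[2.0], [4.0], [6.0]])

# reduction='sum' corresponds to size_average=False, 'mean' to size_average=True
sum_loss = torch.nn.MSELoss(reduction='sum')(y_hat, y)    # summed squared error
mean_loss = torch.nn.MSELoss(reduction='mean')(y_hat, y)  # averaged over elements
print(sum_loss.item(), mean_loss.item())
```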
The optimizer is not a Module and does not build a computational graph. SGD is a class; instantiate it. Its first argument, params, is the set of weights to optimize;
lr is the learning rate. PyTorch supports using different learning rates for different parts of a model.
model.parameters() finds all the parameters no matter how complex the model is.
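A small sketch showing that parameters()/named_parameters() walks nested submodules; TwoLayer is a hypothetical model, not the one used in this lesson:

```python
import torch

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 4)
        self.l2 = torch.nn.Linear(4, 1)
    def forward(self, x):
        return self.l2(self.l1(x))

net = TwoLayer()
for name, p in net.named_parameters():  # weight and bias of each submodule
    print(name, tuple(p.shape))

# parameters() hands the same tensors to the optimizer
opt = torch.optim.SGD(net.parameters(), lr=0.01)
```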
Step 4: the training cycle
Four steps per iteration: compute y_hat, compute the loss, backward, update.
weight is a matrix; to print just its value, use the item() method.
Complete code for this section:
import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):  # our model class inherits from nn.Module (neural network module)
    def __init__(self):  # constructor, initialization
        super().__init__()  # call the parent class's __init__
        self.linear = torch.nn.Linear(1, 1)  # nn.Linear holds two tensors: weight and bias
    def forward(self, x):  # override forward() from the parent class
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()  # create an instance; model is callable, i.e. you can write model(x)

# nn.MSELoss inherits from nn.Module
criterion = torch.nn.MSELoss(size_average=False)  # deprecated flag; newer PyTorch uses reduction='sum'
# choose an optimizer; lr is the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):
    y_pred = model(x_data)  # forward: predict y_hat
    loss = criterion(y_pred, y_data)  # forward: loss
    print(epoch, loss)
    optimizer.zero_grad()  # zero the gradients before backward
    loss.backward()  # backward: autograd
    optimizer.step()  # step() performs the update

# output weight and bias
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

# test the model
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=', y_test.data)
Output (abridged; the full log prints the loss at every epoch):

D:\Anacoda\envs\pytorch-py36\python.exe C:/Users/hp/Desktop/python_work/PyTorch/Lesson1/LinearRegression.py
D:\Anacoda\envs\pytorch-py36\lib\site-packages\torch\nn\_reduction.py:44: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
0 tensor(48.8937, grad_fn=<MseLossBackward>)
1 tensor(21.9066, grad_fn=<MseLossBackward>)
2 tensor(9.8907, grad_fn=<MseLossBackward>)
...
98 tensor(0.0629, grad_fn=<MseLossBackward>)
99 tensor(0.0620, grad_fn=<MseLossBackward>)
w= 1.834191083908081
b= 0.3769226372241974
y_pred= tensor([7.7137])

Process finished with exit code 0
The result after 100 iterations is not ideal; the result after 1000 iterations:
Output (abridged; the captured log is cut off around epoch 250, so the final w, b and prediction are missing):

0 tensor(111.9742, grad_fn=<MseLossBackward>)
1 tensor(49.9149, grad_fn=<MseLossBackward>)
2 tensor(22.2869, grad_fn=<MseLossBackward>)
...
249 tensor(0.0034, grad_fn=<MseLossBackward>)
250 tensor(0.0033, grad_fn=<MseLossBackward>)
...


