import numpy
import numpy as np
# Seed NumPy's global RNG so the np.random.normal weight initialisation
# used by the layers below is reproducible from run to run.
np.random.seed(42)
def MSELoss(x, y):
    """Return the squared-error loss between ``x`` and ``y``.

    Despite the name, the value is the SUM of squared element-wise
    differences (the squared 2-norm of ``x - y``); it is not divided by
    the number of elements.

    Args:
        x: Array-like of predictions or targets.
        y: Array-like with exactly the same shape as ``x``.

    Returns:
        The scalar ``||x - y||**2``.

    Raises:
        ValueError: If the two shapes differ. This is checked explicitly
            (rather than with ``assert``) so that running under ``-O``
            cannot let e.g. (4,) vs (4, 1) silently broadcast to a
            (4, 4) difference and produce a wrong loss.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError(
            "shape mismatch: x has shape %s but y has shape %s"
            % (x.shape, y.shape)
        )
    return np.linalg.norm(x - y) ** 2
class LinearLayer:
    """Fully connected (affine) layer computing ``X @ W + b``.

    The layer keeps its parameters together with the most recently computed
    gradients, so a training step is: ``forward`` -> ``backward`` -> ``update``.
    """

    def __init__(self, input_dim, output_dim):
        """Create a layer mapping ``input_dim`` features to ``output_dim``.

        Args:
            input_dim: Number of input features per sample.
            output_dim: Number of output features per sample.
        """
        # W and b must not start out as all zeros; otherwise the gradients
        # stay at zero and the parameters can never be updated.
        # W: (input_dim, output_dim) weights, b: (1, output_dim) bias row.
        self.W = np.random.normal(0, 0.1, (input_dim, output_dim))
        self.b = np.random.normal(0, 0.1, (1, output_dim))
        # Gradient buffers, same shapes as the parameters they belong to.
        self.dW = np.zeros((input_dim, output_dim))
        self.db = np.zeros((1, output_dim))

    def forward(self, X):
        """Return ``X @ W + b`` for a batch ``X`` of shape (n, input_dim)."""
        return np.matmul(X, self.W) + self.b

    def backward(self, X, grad):
        """Store parameter gradients and return the gradient w.r.t. ``X``.

        Args:
            X: The input that was fed to ``forward``, shape (n, input_dim).
            grad: Gradient of the loss w.r.t. this layer's output,
                shape (n, output_dim).

        Returns:
            Gradient of the loss w.r.t. ``X``, shape (n, input_dim).
        """
        self.dW = np.matmul(X.T, grad)
        # Sum the per-sample gradients over the batch axis. keepdims=True
        # keeps db at (1, output_dim), matching self.b and the buffer
        # allocated in __init__ (instead of collapsing to (output_dim,)).
        self.db = np.sum(grad, axis=0, keepdims=True)
        return np.matmul(grad, self.W.T)

    def update(self, lr):
        """Apply one gradient-descent step with learning rate ``lr``."""
        self.W = self.W - self.dW * lr
        self.b = self.b - self.db * lr
class Relu:
    """ReLU activation: negative inputs become 0, others pass through.

    The layer has no parameters and keeps no state, so no ``__init__`` is
    needed; ``Relu()`` still constructs an instance.
    """

    def forward(self, X):
        """Return ``X`` with every negative entry replaced by 0."""
        return np.where(X < 0, 0, X)

    def backward(self, X, grad):
        """Return the gradient w.r.t. the input of ``forward``.

        Args:
            X: The array that was passed to ``forward``.
            grad: Gradient of the loss w.r.t. the output of ``forward``,
                same shape as ``X``.

        Returns:
            ``grad`` where ``X > 0`` and 0 elsewhere.
        """
        # A boolean mask (rather than an integer 0/1 array) keeps grad's
        # dtype: an int64 mask would upcast float32 gradients to float64.
        return (X > 0) * grad
# 2. Train:
# Training data: the classic XOR classification problem.
train_X = np.array([[0,0],[0,1],[1,0],[1,1]])
train_y = np.array([0,1,1,0])

# Build the network: two linear layers in total. The input is 2-dimensional,
# the first layer has 3 units, the second layer has 1 unit and acts as the
# output layer; ReLU is the activation between them.
fc1 = LinearLayer(2,3)
relu1 = Relu()
fc2 = LinearLayer(3,1)

# Learning rate
learn_rate = 0.01

# Training loop (at most 10000 full-batch iterations)
for i in range(10000):
    # Forward pass: push the whole batch through the network.
    input_x = train_X
    fc1_out = fc1.forward(input_x)
    relu1_out = relu1.forward(fc1_out)
    fc2_out = fc2.forward(relu1_out)
    output_y = fc2_out

    # Flatten the network output, (4,1) => (4,), so it has the same shape as
    # train_y, then compute the loss reported below.
    result = output_y.reshape(output_y.shape[0])
    loss = MSELoss(train_y, result)

    # Backward pass: propagate the gradient layer by layer, last layer first.
    # The starting gradient (result - train_y) is the derivative of
    # 0.5 * sum((result - train_y)**2) with respect to the network output.
    grad = (result - train_y).reshape(result.shape[0],1)
    grad = fc2.backward(relu1_out, grad)
    grad = relu1.backward(fc1_out, grad)
    grad = fc1.backward(input_x, grad)

    # Update the parameters of the linear layers (ReLU has none).
    fc1.update(learn_rate)
    fc2.update(learn_rate)

    # Log the loss every 100 iterations; stop early once it is small enough.
    # Note: `loss` was computed before this iteration's parameter update.
    if i % 100 == 0:
        print(loss)
    if loss < 0.001:
        print("train over! 第%d次迭代" %(i))
        break
# 3. Predict:
# Stack the trained layers, in forward order, into a model.
model = [fc1, relu1, fc2]
# Prediction
def predict(model, X, threshold=0.5):
    """Run ``X`` through every layer of ``model`` and binarise the output.

    Args:
        model: Iterable of layer objects, applied in order; each must
            provide a ``forward(array)`` method.
        X: Input batch handed to the first layer.
        threshold: Decision boundary applied to the final output.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        Integer array shaped like the final layer output: 1 where the
        output is strictly greater than ``threshold``, otherwise 0.
    """
    tmp = X
    for layer in model:
        tmp = layer.forward(tmp)
    return np.where(np.asarray(tmp) > threshold, 1, 0)
# Separator between the training log above and the prediction report.
print("*" * 20)

# Evaluate the model on all four XOR input combinations.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
result = predict(model, X)

# Report: a heading ("预测数据") followed by the inputs, a separator, then a
# heading ("预测结果") followed by the predictions -- one item per line,
# exactly as individual print() calls would emit them.
print("预测数据", X, "*" * 20, "预测结果", result, sep="\n")



