栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch softmax回归_pytorch中softmax怎么用?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch softmax回归_pytorch中softmax怎么用?

SOFTMAX回归手动实现
import torch
import torchvision
import numpy as np
import sys
import d2l
batch_size=256
train_iter,test_iter =d2l.load_data_fashion_mnist(batch_size)

num_inputs=784
num_output=10

W=torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outputs)),dtype=torch.float)
b=torch.zeros(num_outputs,dtype=torch.float)

W.required_grad_(requires_grad=True)
b.required_grad_(required_grad=True)

def softmax(X):
	X_exp=X.exp()
	partition=X_exp.sum(dim=1,keepdim=True)
	return X_exp/partition # 这里使用了广播的方法

'''
# 测试
X=torch.rand((2,5))
X_prob=softmax(x)
print(X_prob,X_prob.sum(dim=1))
'''
定义模型
def net(X):
	return softmax(torch.mm(X.view((-1,num_inputs)),W)+b)
定义损失函数
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
y=torch.LongTensor([0,2])
print(".gather()函数的解释")
print(y_hat.gather(1,y.view(-1,1)))

'''
tensor([[0.1000],
 [0.5000]])'''
#下⾯实现了3.4节(softmax回归)中介绍的交叉熵损失函数。
def cross_entropy(y_hat, y):
    return -torch.log(y_hat.gather(1, y.view(-1, 1)))
 
'''
给定⼀个类别的预测概率分布 y_hat ,我们把预测概率最⼤的类别作为输出类别。如果它与真实类
别 y ⼀致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之⽐'''
def accuracy(y_hat,y):
    return (y_hat.argmax(dim=1)==y).float().mean().item()

'''
训练softmax回归的实现跟“线性回归的从零开始实现” ⼀节介绍的线性回归中的实现⾮常相似。我们同
样使⽤⼩批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数 num_epochs 和学习
率 lr 都是可以调的超参数。改变它们的值可能会得到分类更准确的模型'''

num_epochs,lr=5,0.1

d2l.train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[W,b],lr)

#3.6.8 预测
#给定⼀系列图像(第三⾏图像输出),我们⽐较⼀下它们的真实标签(第⼀⾏⽂本输出)和模型预测结果(第⼆⾏⽂本输出)
x,y=iter(test_iter).next()
true_labels=d2l.get_fashion_mnist_labels(y.numpy())
pred_lables=d2l.get_fashion_mnist_labels(net(x).argmax(dim=1).numpy())
titles=[true+'n'+pred for true,pred in zip(true_labels,pred_lables)]
d2l.show_fashion_mnist(x[0:9],titles[0:9])
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/783555.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号