栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch深度学习入门 || 系列(六)——多元分类_学pytorch的基础?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch深度学习入门 || 系列(六)——多元分类_学pytorch的基础?

文章目录

0 写在前面1 softmax函数2 数据预处理

2.1 scatter()函数的cmap属性 3 激活函数4 模型搭建5 完整代码6 输出分析

6.1 目标6.2 运行过程7 总结

0 写在前面

二分类问题是多分类问题的一种特殊情况,区别在于多分类用softmax代替sigmoid函数。softmax函数将所有分类的分数值转化为概率,且各概率的和为1。 1 softmax函数

softmax函数首先对所有的输出值通过指数函数,将实数输出映射到正无穷然后将所有的结果相加作为分母
2 数据预处理

首先下面这段代码可以直接运行!

首先cluster数据,形状为500×2,各元素值为1。然后用normal()函数,以4为均值,2 为标准差生产data0
其他同理
import torch
import matplotlib.pyplot as plt

cluster = torch.ones(500, 2)
data0 = torch.normal(4*cluster, 2)
data1 = torch.normal(-4*cluster, 1)
data2 = torch.normal(-8*cluster, 1)
label0 = torch.zeros(500)
label1 = torch.ones(500)
label2 = 2*label1

x = torch.cat((data0, data1, data2), ).type(torch.FloatTensor)
y = torch.cat((label0, label1, label2), ).type(torch.LongTensor)

plt.scatter(x.numpy()[:, 0], x.numpy()[:, 1], c=y.numpy(), s=10, lw=0, cmap='Accent')
plt.show()

根据均值,可以看出灰色的数据群是data2,蓝色数据群是data1,绿色的数据群是data0。符合正态分布的特点——中心点聚集的数据点较多,四周的数据点分散。
2.1 scatter()函数的cmap属性

1 可以不写,有默认值的2 如果想花里胡哨一点的话,可以看下图
3 激活函数

隐藏层激活函数采用relu(),最后分类激活函数采用softmax 4 模型搭建

一个隐藏层(维度变换从input_figure到num_hidden)一个输出层(维度变换从num_hidden到outputs)

class Net(nn.Module):
    def __init__(self, input_feature, num_hidden, outputs):
        super(Net, self).__init__()
        self.hidden = nn.Linear(input_feature, num_hidden)
        self.out = nn.Linear(num_hidden, outputs)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.out(x)
        x = F.softmax(x, dim=1)
        return x

这里初始化模型,输入是2d,中间是20d,输出是3d。输入输出维度不能改变,中间的维度可以随意设定。

net = Net(input_feature=2, num_hidden=20, outputs=3).cuda()
inputs = x.cuda()
target = y.cuda()
5 完整代码

可以直接运行!如果你觉得本文对你有帮助的话,感谢点赞收藏+评论哦!

import torch
import matplotlib.pyplot as plt


cluster = torch.ones(500, 2)
data0 = torch.normal(4*cluster, 2)
data1 = torch.normal(-4*cluster, 1)
data2 = torch.normal(-8*cluster, 1)
label0 = torch.zeros(500)
label1 = torch.ones(500)
label2 = 2*label1

x = torch.cat((data0, data1, data2), ).type(torch.FloatTensor)
y = torch.cat((label0, label1, label2), ).type(torch.LongTensor)

# plt.scatter(x.numpy()[:, 0], x.numpy()[:, 1], c=y.numpy(), s=10, lw=0, cmap='Accent')
# plt.show()

import torch.nn.functional as F
from torch import nn, optim


class Net(nn.Module):
    def __init__(self, input_feature, num_hidden, outputs):
        super(Net, self).__init__()
        self.hidden = nn.Linear(input_feature, num_hidden)
        self.out = nn.Linear(num_hidden, outputs)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.out(x)
        x = F.softmax(x, dim=1)
        return x


net = Net(input_feature=2, num_hidden=20, outputs=3).cuda()
inputs = x.cuda()
target = y.cuda()

optimizer = optim.SGD(net.parameters(), lr=0.02)
criterion = nn.CrossEntropyLoss()


def draw(output):
    output = output.cpu()
    plt.cla()

    output = torch.max((output), 1)[1]
    pred_y = output.data.numpy().squeeze()
    target_y = y.numpy()

    plt.scatter(x.numpy()[:, 0], x.numpy()[:, 1], c=pred_y, s=10, lw=0, cmap='RdYlGn')
    accuracy = sum(pred_y == target_y) / 1500.0
    plt.text(1.5, -4, 'Accuracy=%s' % (accuracy), fontdict={'size':20, 'color':'red'})
    plt.show()


def train(model, criterion, optimizer, epochs):
    for epoch in range(epochs):
        output = model(inputs)
        loss = criterion(output, target)

        optimizer.zero_grad()  # 梯度清零
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            draw(output)

train(net, criterion, optimizer, 1000)
6 输出分析 6.1 目标

目标就是向第二节的图像靠拢,三个数据群分别是不一样的颜色,然后accuracy接近1.

6.2 运行过程

训练100epochs的情况,效果不是很好,这里只分了2类

训练200epochs的情况,效果和100差不多

训练300epochs的情况

训练400epochs的情况

训练500epochs的情况,可以看到,data1的数据群一半已经变成黄色

训练600epochs的情况

训练700epochs的情况

训练800epochs的情况

训练900epochs的情况

训练00epochs的情况

7 总结

训练到1000epochs的时候,accuracy已经达到0.99,能够非常明显地看出图片被分为了三类。目前所有的操作都是在训练集上完成的,以后会学习测试集上训练的情况!
- 如果有用的话麻烦三连,点点关注哦!

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/783300.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号