栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch 深度学习 第七讲

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch 深度学习 第七讲

处理多维数据特征的输入

关于一维特征的输入(也就是只有一个x)应该如何处理?比如下图这个预测一个人在一年之后得糖尿病的概率的例子,这个时候我们的输入将会有很多的指标。你可以把它看成是我们体检的各种值。最后一排的Y代表了他是否会得糖尿病。

那么多维的特征输入应该怎么办呢?

我们就需要把每一个特征x付以相应的权重。在进行逻辑回归时,把每一个维度的x乘相应的权值的和加上一个偏置量,送入sigmoid函数进行二分类,就像这样:

 

import  torch
import  torch.nn.functional as F
import  numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

xy=np.loadtxt('./data/Diabetes_class.csv.gz',delimiter=',',dtype=np.float32)#加载训练集合
x_data = torch.from_numpy(xy[:,:-1])#取前八列
y_data = torch.from_numpy(xy[:,[-1]])#取最后一列

test =np.loadtxt('./data/test_class.csv.gz',delimiter=',',dtype=np.float32)#加载测试集合,这里我用数据集的最后一个样本做测试,训练集中没有最后一个样本
test_x = torch.from_numpy(test)

class Model(torch.nn.Module):
    def __init__(self):#构造函数
        super(Model,self).__init__()
        self.linear1 = torch.nn.Linear(8,6)#8维到6维
        self.linear2 = torch.nn.Linear(6, 4)#6维到4维
        self.linear3 = torch.nn.Linear(4, 1)#4维到1维
        self.sigmoid = torch.nn.Sigmoid()#因为他里边也没有权重需要更新,所以要一个就行了,单纯的算个数


    def forward(self, x):#构建一个计算图,就像上面图片画的那样
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))#将上面一行的输出作为输入
        x = self.sigmoid(self.linear3(x))
        return  x

model = Model()#实例化模型

criterion = torch.nn.BCELoss(size_average=False)
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)#lr为学习率,因为0.01太小了,我改成了0.1

for epoch in range(1000):
    #Forward
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())
    #Backward
    optimizer.zero_grad()
    loss.backward()
    #update
    optimizer.step()

y_pred = model(x_data)

print(y_pred.detach().numpy())

y_pred2 = model(test_x)
print(y_pred2.data.item())

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/743822.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号