栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

基于pytorch简单线性神经网络处理Titanic数据预测笔记

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

基于pytorch简单线性神经网络处理Titanic数据预测笔记

处理好的数据带有索引行,所以不需要第一行

用np.loadtxt读取CSV文件和pd.read_csv区别

用法     pd.read_csv('xxxx.csv')         np.loadtxt(filepath, delimiter=',', dtype=np.float32,skiprows =)

读取后的数据类型  

x_data = np.loadtxt('p_train.csv', delimiter=',', dtype=np.float32,skiprows=1)
print(type(x_data))
#输出

train = pd.read_csv('train.csv')
print(type(train))
#输出

输出格式

train = pd.read_csv('train.csv')
print(train)
#输出

     PassengerId  Survived  Pclass  ...     Fare Cabin  Embarked
0              1         0       3  ...   7.2500   NaN         S
1              2         1       1  ...  71.2833   C85         C
2              3         1       3  ...   7.9250   NaN         S
3              4         1       1  ...  53.1000  C123         S
4              5         0       3  ...   8.0500   NaN         S
..           ...       ...     ...  ...      ...   ...       ...
886          887         0       2  ...  13.0000   NaN         S
887          888         1       1  ...  30.0000   B42         S
888          889         0       3  ...  23.4500   NaN         S
889          890         1       1  ...  30.0000  C148         C
890          891         0       3  ...   7.7500   NaN         Q
x_data = np.loadtxt('p_train.csv', delimiter=',', dtype=np.float32,skiprows=1)
print(x_data)
#输出
[[  0.   0.   0. ...   0.   0.   1.]
 [  1.   0.   0. ...   1.   0.   0.]
 [  2.   0.   1. ...   0.   0.   1.]
 ...
 [888.   0.   1. ...   0.   0.   1.]
 [889.   0.   0. ...   1.   0.   0.]
 [890.   0.   0. ...   0.   1.   0.]]

切片问题

[ : , : ]  前面是多少行到多少行  后面是取几列

单独取出一列时可以用[ : , [-1]]表示取出最后一列

对于训练好的模型保存方式

第一种:保存模型

#----保存----

torch.save(model, 'model_name.pth')

#----加载----

model = torch.load('model_name.pth')

第二种:保存模型参数

#----保存----
torch.save(model.state_dict(), 'params_name.pth') #保存的文件名后缀一般是.pt或.pth
#----加载----
model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth'))  #加载模型参数

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331468.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号