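- 1. Data description
- 2. Importing and preprocessing the data
- 3. Building the network
- 4. Training and evaluating the model
- 4.1 Defining the plotting function
- 4.2 Training and saving the model
- 5. Model prediction
- 5.1 Defining a function to plot ground truth against predictions
- 5.2 Running prediction
- Final words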
This article is based on Baidu's PaddlePaddle platform.
Project link:
House price prediction with PaddlePaddle
1. Data description
The Boston housing data frame has 506 rows and 14 columns.
The columns are:
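- crim: per-capita crime rate
- zn: proportion of residential land zoned for lots over 25,000 square feet
- indus: proportion of non-retail business area per town
- chas: whether the tract borders the river
- nox: nitrogen oxides concentration (parts per 10 million)
- rm: average number of rooms per dwelling
- age: proportion of owner-occupied units built before 1940
- dis: weighted mean distance to five Boston employment centres
- rad: index of accessibility to radial highways
- tax: full-value property-tax rate per $10,000
- ptratio: pupil-teacher ratio by town
- black: 1000(Bk - 0.63)^2, where Bk is the proportion of Black residents by town
- lstat: percentage of the population with lower income status
- medv: median value of owner-occupied homes, in thousands of dollars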
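As a reading aid for the code that follows, the column names above can be kept in a list and paired with a feature vector. This is only a minimal sketch: `FEATURE_NAMES`, `TARGET_NAME`, and `describe_sample` are hypothetical names introduced here, not part of Paddle, and the sketch assumes the dataset stores its 13 features in the order listed above.

```python
# Column names from the list above; the last column (medv) is the prediction target.
FEATURE_NAMES = ["crim", "zn", "indus", "chas", "nox", "rm", "age",
                 "dis", "rad", "tax", "ptratio", "black", "lstat"]
TARGET_NAME = "medv"


def describe_sample(features, target):
    """Pair each feature value with its column name (hypothetical helper)."""
    if len(features) != len(FEATURE_NAMES):
        raise ValueError(
            "expected %d features, got %d" % (len(FEATURE_NAMES), len(features))
        )
    named = dict(zip(FEATURE_NAMES, features))
    named[TARGET_NAME] = target
    return named


# Placeholder feature values, used only to show the shape of the result.
sample = describe_sample([0.0] * 13, 24.0)
print(len(FEATURE_NAMES), sample[TARGET_NAME])
```

The 13 names plus the target account for the 14 columns of the data frame.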
2. Importing and preprocessing the data
# Import the Boston housing data
import os
import paddle
import numpy as np

# Training batch size
BATCH_SIZE = 20
# Training set
train_datasets = paddle.text.datasets.UCIHousing(mode='train')
# Validation set
valid_datasets = paddle.text.datasets.UCIHousing(mode='test')
# Training data loader: yields one shuffled batch at a time;
# a final batch smaller than BATCH_SIZE is dropped
train_loader = paddle.io.DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Validation data loader: also shuffled, with the incomplete final batch dropped
valid_loader = paddle.io.DataLoader(valid_datasets, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
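To make the effect of `drop_last=True` concrete, the arithmetic below counts batches per epoch. It is a sketch under stated assumptions: the split sizes of 404 training and 102 test samples are assumed here (the 102 matches the number of predictions printed at the end of this article), and `batches_per_epoch` is a hypothetical helper, not a Paddle function.

```python
def batches_per_epoch(num_samples, batch_size, drop_last):
    """Return (number of batches, number of samples discarded) for one epoch."""
    full, remainder = divmod(num_samples, batch_size)
    if drop_last:
        # The incomplete final batch is thrown away.
        return full, remainder
    # Otherwise the leftover samples form one extra, smaller batch.
    return full + (1 if remainder else 0), 0


ASSUMED_TRAIN_SIZE = 404  # assumption: size of the 'train' split
ASSUMED_TEST_SIZE = 102   # assumption: size of the 'test' split
BATCH_SIZE = 20

print(batches_per_epoch(ASSUMED_TRAIN_SIZE, BATCH_SIZE, drop_last=True))   # (20, 4)
print(batches_per_epoch(ASSUMED_TEST_SIZE, BATCH_SIZE, drop_last=True))    # (5, 2)
print(batches_per_epoch(ASSUMED_TEST_SIZE, BATCH_SIZE, drop_last=False))   # (6, 0)
```

With about 20 training batches per epoch, a check such as `batch_id % 40 == 0` is true only for the first batch of each epoch.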
# Print the dataset type
print(type(train_datasets))
# Inspect one uci_housing sample
print(train_datasets[0])
# Each sample is a (features, label) pair: 13 features and 1 label
# print(train_datasets)
(array([-0.0405441 , 0.06636363, -0.32356226, -0.06916996, -0.03435197,
0.05563625, -0.03475696, 0.02682186, -0.37171334, -0.21419305,
-0.33569506, 0.10143217, -0.21172912], dtype=float32), array([24.], dtype=float32))
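The feature values printed above are small and centered near zero, which suggests that the dataset is already normalized when it is loaded. One common scheme for this dataset is to subtract each column's mean and divide by its range; whether Paddle uses exactly this formula is an assumption here. The sketch below applies that scheme to made-up numbers, and `normalize_by_range` is a hypothetical helper.

```python
import numpy as np


def normalize_by_range(data):
    """Center each column on its mean and scale it by (max - min)."""
    data = np.asarray(data, dtype=np.float32)
    col_mean = data.mean(axis=0)
    col_range = data.max(axis=0) - data.min(axis=0)
    # Guard against constant columns, whose range would be zero.
    col_range = np.where(col_range == 0, 1.0, col_range)
    return (data - col_mean) / col_range


# Made-up raw values: two columns on very different scales.
raw = np.array([[0.1, 300.0], [0.5, 400.0], [0.9, 700.0]])
normalized = normalize_by_range(raw)
print(normalized)
```

After this transformation every column has zero mean and values within [-1, 1], so no single feature dominates the gradient because of its units.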
3. Building the network
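# Define the network: a single linear layer mapping 13 features to 1 output
net = paddle.nn.Linear(13, 1)
# Define the optimizer: stochastic gradient descent
optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=net.parameters())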
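A linear layer with 13 inputs and 1 output computes an affine map: each row of the input is multiplied by a 13 x 1 weight matrix and a single bias is added. The NumPy sketch below shows that computation on random data; the names `linear_forward`, `weight`, and `bias` are illustrative and the values are not the trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(size=(20, 13)).astype(np.float32)   # 20 samples, 13 features each
weight = rng.normal(size=(13, 1)).astype(np.float32)   # one weight per feature
bias = np.zeros((1,), dtype=np.float32)                # single bias term


def linear_forward(x, w, b):
    """Affine map: one predicted value per input row."""
    return x @ w + b


out = linear_forward(batch, weight, bias)
print(out.shape)  # (20, 1)
```

The model therefore has only 14 trainable numbers (13 weights and 1 bias), which is why this task is a plain linear regression.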
4. Training and evaluating the model
4.1 Defining the plotting function
# Define the plotting function
import matplotlib.pyplot as plt

iter = 0
iters = []
train_costs = []

def draw_train_process(iters, train_costs):
    title = 'training costs'
    plt.title(title, fontsize=24)
    plt.xlabel('iter', fontsize=14)
    plt.ylabel('cost', fontsize=14)
    plt.plot(iters, train_costs, color='red', label='training cost')
    plt.grid()
    plt.show()
4.2 Training and saving the model
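The loop in this section performs mini-batch gradient descent on a mean-squared-error loss. As a framework-free reference, the sketch below carries out the same kind of update by hand in NumPy. It is an illustration under assumptions, not the Paddle implementation: the data are synthetic and noise-free, the learning rate differs from the one used in this article, and `sgd_step` is a hypothetical helper.

```python
import numpy as np


def sgd_step(w, b, x, y, lr):
    """One gradient-descent step on mean squared error for y_hat = x @ w + b."""
    n = x.shape[0]
    err = x @ w + b - y                    # residuals, shape (n, 1)
    loss = float(np.mean(err ** 2))        # loss before the update
    grad_w = 2.0 / n * (x.T @ err)         # d(loss)/dw, shape (features, 1)
    grad_b = 2.0 * np.mean(err, axis=0)    # d(loss)/db, shape (1,)
    return w - lr * grad_w, b - lr * grad_b, loss


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))
true_w = np.array([[2.0], [-1.0], [0.5]])
y = x @ true_w + 3.0                       # synthetic targets with bias 3.0

w, b = np.zeros((3, 1)), np.zeros(1)
losses = []
for epoch in range(200):
    for start in range(0, len(x), 20):     # mini-batches of 20 samples
        xb, yb = x[start:start + 20], y[start:start + 20]
        w, b, loss = sgd_step(w, b, xb, yb, lr=0.05)
    losses.append(loss)                    # loss of the last batch in the epoch

print(losses[0], losses[-1])
```

In the Paddle code, `train_loss.backward()` computes the gradients and `optimizer.step()` applies the subtraction, so neither gradient formula has to be written out.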
# Number of training epochs
EPOCH_NUM = 50
# Train for EPOCH_NUM epochs
for pass_id in range(EPOCH_NUM):
    # Training phase: iterate over the batches from train_loader
    for batch_id, data in enumerate(train_loader()):
        # Split the batch into features and labels
        inputs = paddle.to_tensor(data[0])
        labels = paddle.to_tensor(data[1])
        # Forward pass
        out = net(inputs)
        # Loss: mean squared error
        train_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
        # Backward pass
        train_loss.backward()
        # Update the parameters, then clear the gradients
        optimizer.step()
        optimizer.clear_grad()
        # Print the loss every 40 batches, starting from batch 0: 0, 40, 80, ...
        if batch_id % 40 == 0:
            print("Pass id: %d, cost: %0.5f" % (pass_id, float(train_loss)))
        iter = iter + BATCH_SIZE
        iters.append(iter)
        # float() works whether the loss has shape [1] or is a 0-D tensor
        train_costs.append(float(train_loss))
    # Validation phase: report the loss of the last validation batch
    test_loss = 0
    # Iterate over the batches from valid_loader
    for batch_id, data in enumerate(valid_loader()):
        # Split the batch into features and labels
        inputs = paddle.to_tensor(data[0])
        labels = paddle.to_tensor(data[1])
        # Forward pass
        out = net(inputs)
        # Loss: mean squared error
        test_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
    # Print the loss of the last validation batch
    print("Pass id: %d, cost: %0.5f" % (pass_id, float(test_loss)))
# Save the model parameters
paddle.save(net.state_dict(), 'fit_a_line.pdparams')
draw_train_process(iters, train_costs)
Pass id: 0, cost: 697.82361
Pass id: 0, cost: 244.72691
Pass id: 1, cost: 552.90552
Pass id: 1, cost: 221.64818
Pass id: 2, cost: 429.48627
Pass id: 2, cost: 192.41925
Pass id: 3, cost: 484.38715
Pass id: 3, cost: 173.44687
Pass id: 4, cost: 387.63263
Pass id: 4, cost: 221.82217
Pass id: 5, cost: 402.83636
Pass id: 5, cost: 128.57277
Pass id: 6, cost: 547.67670
Pass id: 6, cost: 173.92017
Pass id: 7, cost: 308.60843
Pass id: 7, cost: 147.69878
Pass id: 8, cost: 219.78215
Pass id: 8, cost: 107.03541
Pass id: 9, cost: 208.08089
Pass id: 9, cost: 112.56085
Pass id: 10, cost: 311.31699
Pass id: 10, cost: 123.64721
Pass id: 11, cost: 363.99536
Pass id: 11, cost: 105.76099
Pass id: 12, cost: 420.03595
Pass id: 12, cost: 72.11946
Pass id: 13, cost: 327.58844
Pass id: 13, cost: 88.04731
Pass id: 14, cost: 390.53604
Pass id: 14, cost: 85.30426
Pass id: 15, cost: 230.41187
Pass id: 15, cost: 63.41622
Pass id: 16, cost: 166.12924
Pass id: 16, cost: 63.62786
Pass id: 17, cost: 236.01692
Pass id: 17, cost: 77.35274
Pass id: 18, cost: 290.79498
Pass id: 18, cost: 33.24523
Pass id: 19, cost: 97.46697
Pass id: 19, cost: 51.36712
Pass id: 20, cost: 340.30243
Pass id: 20, cost: 32.06834
Pass id: 21, cost: 168.25182
Pass id: 21, cost: 38.49080
Pass id: 22, cost: 175.69730
Pass id: 22, cost: 35.95297
Pass id: 23, cost: 205.52931
Pass id: 23, cost: 71.39227
Pass id: 24, cost: 207.87970
Pass id: 24, cost: 28.49346
Pass id: 25, cost: 116.07450
Pass id: 25, cost: 35.36184
Pass id: 26, cost: 182.95099
Pass id: 26, cost: 33.28909
Pass id: 27, cost: 181.40346
Pass id: 27, cost: 40.69643
Pass id: 28, cost: 53.31166
Pass id: 28, cost: 23.08777
Pass id: 29, cost: 159.42133
Pass id: 29, cost: 49.56973
Pass id: 30, cost: 148.92783
Pass id: 30, cost: 34.58585
Pass id: 31, cost: 98.45925
Pass id: 31, cost: 22.08068
Pass id: 32, cost: 64.72221
Pass id: 32, cost: 18.82707
Pass id: 33, cost: 39.10216
Pass id: 33, cost: 41.70573
Pass id: 34, cost: 86.37087
Pass id: 34, cost: 27.22211
Pass id: 35, cost: 30.80962
Pass id: 35, cost: 32.63113
Pass id: 36, cost: 112.65536
Pass id: 36, cost: 16.17883
Pass id: 37, cost: 69.60898
Pass id: 37, cost: 19.16402
Pass id: 38, cost: 39.25970
Pass id: 38, cost: 22.66752
Pass id: 39, cost: 19.43631
Pass id: 39, cost: 24.29722
Pass id: 40, cost: 118.90907
Pass id: 40, cost: 17.93452
Pass id: 41, cost: 63.60981
Pass id: 41, cost: 21.98477
Pass id: 42, cost: 48.37024
Pass id: 42, cost: 42.80917
Pass id: 43, cost: 120.80444
Pass id: 43, cost: 16.80975
Pass id: 44, cost: 67.58469
Pass id: 44, cost: 19.20858
Pass id: 45, cost: 19.36415
Pass id: 45, cost: 31.67330
Pass id: 46, cost: 67.76525
Pass id: 46, cost: 19.26246
Pass id: 47, cost: 91.39543
Pass id: 47, cost: 30.07190
Pass id: 48, cost: 120.14481
Pass id: 48, cost: 17.26798
Pass id: 49, cost: 89.36503
Pass id: 49, cost: 21.93978
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
5. Model prediction
5.1 Defining a function to plot ground truth against predictions
infer_results = []
groud_truths = []

# Plot predictions against ground truth
def draw_infer_result(groud_truths, infer_results):
    title = 'Boston'
    plt.title(title, fontsize=24)
    # Reference line y = x: points on it are perfect predictions
    x = np.arange(1, 20)
    y = x
    plt.plot(x, y)
    plt.xlabel('ground truth', fontsize=14)
    plt.ylabel('infer result', fontsize=14)
    plt.scatter(groud_truths, infer_results, color='green', label='infer results')
    plt.legend()
    plt.grid()
    plt.show()
5.2 Running prediction
import paddle
import numpy as np
import matplotlib.pyplot as plt

valid_datasets = paddle.text.UCIHousing(mode='test')
infer_loader = paddle.io.DataLoader(valid_datasets, batch_size=200)
# First create a network with the same architecture
infer_net = paddle.nn.Linear(13, 1)
# Load the saved parameters
param = paddle.load('fit_a_line.pdparams')
infer_net.set_dict(param)
# Take the first batch, which holds the whole test set because batch_size=200
data = next(infer_loader())
inputs = paddle.to_tensor(data[0])
results = infer_net(inputs)
for idx, item in enumerate(zip(results, data[1])):
    print("Index:%d, Infer Result: %.2f, Ground Truth: %.2f" % (idx, item[0], item[1]))
    infer_results.append(item[0].numpy()[0])
    groud_truths.append(item[1].numpy()[0])
draw_infer_result(groud_truths, infer_results)
Index:0, Infer Result: 12.86, Ground Truth: 8.50
Index:1, Infer Result: 12.75, Ground Truth: 5.00
Index:2, Infer Result: 12.67, Ground Truth: 11.90
Index:3, Infer Result: 14.19, Ground Truth: 27.90
Index:4, Infer Result: 13.26, Ground Truth: 17.20
Index:5, Infer Result: 13.88, Ground Truth: 27.50
Index:6, Infer Result: 13.41, Ground Truth: 15.00
Index:7, Infer Result: 13.40, Ground Truth: 17.20
Index:8, Infer Result: 11.47, Ground Truth: 17.90
Index:9, Infer Result: 13.08, Ground Truth: 16.30
Index:10, Infer Result: 10.83, Ground Truth: 7.00
Index:11, Infer Result: 12.48, Ground Truth: 7.20
Index:12, Infer Result: 13.14, Ground Truth: 7.50
Index:13, Infer Result: 12.54, Ground Truth: 10.40
Index:14, Infer Result: 12.28, Ground Truth: 8.80
Index:15, Infer Result: 13.67, Ground Truth: 8.40
Index:16, Infer Result: 14.19, Ground Truth: 16.70
Index:17, Infer Result: 14.12, Ground Truth: 14.20
Index:18, Infer Result: 14.37, Ground Truth: 20.80
Index:19, Infer Result: 13.34, Ground Truth: 13.40
Index:20, Infer Result: 13.96, Ground Truth: 11.70
Index:21, Infer Result: 12.72, Ground Truth: 8.30
Index:22, Infer Result: 14.42, Ground Truth: 10.20
Index:23, Infer Result: 13.74, Ground Truth: 10.90
Index:24, Infer Result: 13.72, Ground Truth: 11.00
Index:25, Infer Result: 13.13, Ground Truth: 9.50
Index:26, Infer Result: 14.08, Ground Truth: 14.50
Index:27, Infer Result: 13.92, Ground Truth: 14.10
Index:28, Infer Result: 14.86, Ground Truth: 16.10
Index:29, Infer Result: 14.02, Ground Truth: 14.30
Index:30, Infer Result: 13.76, Ground Truth: 11.70
Index:31, Infer Result: 13.28, Ground Truth: 13.40
Index:32, Infer Result: 13.44, Ground Truth: 9.60
Index:33, Infer Result: 12.45, Ground Truth: 8.70
Index:34, Infer Result: 12.15, Ground Truth: 8.40
Index:35, Infer Result: 13.51, Ground Truth: 12.80
Index:36, Infer Result: 13.52, Ground Truth: 10.50
Index:37, Infer Result: 14.01, Ground Truth: 17.10
Index:38, Infer Result: 14.17, Ground Truth: 18.40
Index:39, Infer Result: 14.03, Ground Truth: 15.40
Index:40, Infer Result: 13.11, Ground Truth: 10.80
Index:41, Infer Result: 13.01, Ground Truth: 11.80
Index:42, Infer Result: 14.04, Ground Truth: 14.90
Index:43, Infer Result: 14.22, Ground Truth: 12.60
Index:44, Infer Result: 14.10, Ground Truth: 14.10
Index:45, Infer Result: 13.93, Ground Truth: 13.00
Index:46, Infer Result: 13.73, Ground Truth: 13.40
Index:47, Infer Result: 14.30, Ground Truth: 15.20
Index:48, Infer Result: 14.39, Ground Truth: 16.10
Index:49, Infer Result: 14.67, Ground Truth: 17.80
Index:50, Infer Result: 13.58, Ground Truth: 14.90
Index:51, Infer Result: 13.85, Ground Truth: 14.10
Index:52, Infer Result: 13.45, Ground Truth: 12.70
Index:53, Infer Result: 13.72, Ground Truth: 13.50
Index:54, Infer Result: 14.41, Ground Truth: 14.90
Index:55, Infer Result: 14.69, Ground Truth: 20.00
Index:56, Infer Result: 14.41, Ground Truth: 16.40
Index:57, Infer Result: 14.75, Ground Truth: 17.70
Index:58, Infer Result: 14.88, Ground Truth: 19.50
Index:59, Infer Result: 15.11, Ground Truth: 20.20
Index:60, Infer Result: 15.39, Ground Truth: 21.40
Index:61, Infer Result: 15.43, Ground Truth: 19.90
Index:62, Infer Result: 13.83, Ground Truth: 19.00
Index:63, Infer Result: 14.07, Ground Truth: 19.10
Index:64, Infer Result: 14.76, Ground Truth: 19.10
Index:65, Infer Result: 15.33, Ground Truth: 20.10
Index:66, Infer Result: 14.93, Ground Truth: 19.90
Index:67, Infer Result: 15.19, Ground Truth: 19.60
Index:68, Infer Result: 15.38, Ground Truth: 23.20
Index:69, Infer Result: 15.82, Ground Truth: 29.80
Index:70, Infer Result: 14.06, Ground Truth: 13.80
Index:71, Infer Result: 13.74, Ground Truth: 13.30
Index:72, Infer Result: 14.54, Ground Truth: 16.70
Index:73, Infer Result: 13.26, Ground Truth: 12.00
Index:74, Infer Result: 14.30, Ground Truth: 14.60
Index:75, Infer Result: 14.83, Ground Truth: 21.40
Index:76, Infer Result: 15.92, Ground Truth: 23.00
Index:77, Infer Result: 16.14, Ground Truth: 23.70
Index:78, Infer Result: 16.29, Ground Truth: 25.00
Index:79, Infer Result: 16.34, Ground Truth: 21.80
Index:80, Infer Result: 15.95, Ground Truth: 20.60
Index:81, Infer Result: 16.18, Ground Truth: 21.20
Index:82, Infer Result: 15.11, Ground Truth: 19.10
Index:83, Infer Result: 15.84, Ground Truth: 20.60
Index:84, Infer Result: 15.65, Ground Truth: 15.20
Index:85, Infer Result: 14.94, Ground Truth: 7.00
Index:86, Infer Result: 14.31, Ground Truth: 8.10
Index:87, Infer Result: 15.73, Ground Truth: 13.60
Index:88, Infer Result: 16.46, Ground Truth: 20.10
Index:89, Infer Result: 20.31, Ground Truth: 21.80
Index:90, Infer Result: 20.51, Ground Truth: 24.50
Index:91, Infer Result: 20.40, Ground Truth: 23.10
Index:92, Infer Result: 19.09, Ground Truth: 19.70
Index:93, Infer Result: 19.87, Ground Truth: 18.30
Index:94, Infer Result: 20.13, Ground Truth: 21.20
Index:95, Infer Result: 19.60, Ground Truth: 17.50
Index:96, Infer Result: 19.72, Ground Truth: 16.80
Index:97, Infer Result: 21.10, Ground Truth: 22.40
Index:98, Infer Result: 20.79, Ground Truth: 20.60
Index:99, Infer Result: 21.10, Ground Truth: 23.90
Index:100, Infer Result: 21.01, Ground Truth: 22.00
Index:101, Infer Result: 20.78, Ground Truth: 11.90
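A scatter plot shows the overall trend, but a single number is often easier to compare between runs. The sketch below computes the mean absolute error (MAE) and root mean squared error (RMSE) on the first five prediction and ground-truth pairs copied from the output above. The helper names `mae` and `rmse` are illustrative, and five samples are far too few to judge the model; the same functions could be applied to all 102 pairs.

```python
import math

# First five (prediction, ground truth) pairs copied from the printed output.
pairs = [(12.86, 8.50), (12.75, 5.00), (12.67, 11.90), (14.19, 27.90), (13.26, 17.20)]


def mae(pairs):
    """Mean absolute error: average size of the prediction error."""
    return sum(abs(pred - truth) for pred, truth in pairs) / len(pairs)


def rmse(pairs):
    """Root mean squared error: penalizes large errors more heavily."""
    return math.sqrt(sum((pred - truth) ** 2 for pred, truth in pairs) / len(pairs))


print("MAE: %.2f, RMSE: %.2f" % (mae(pairs), rmse(pairs)))
```

Both metrics are in the same unit as medv (thousands of dollars), so they can be read directly as a typical pricing error.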
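Final words
Dear readers, since you have made it this far, please give the author a like. Your support is the author's greatest motivation to keep writing!
<(^-^)>
My knowledge is limited; if you spot any mistakes, corrections are welcome.
This article is intended solely for study and exchange and not for any commercial purpose. If it raises any copyright concerns, please contact the author promptly.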



