Task: house-price prediction with PyTorch

- Collect the data, describe its attributes, implement the preprocessing in code, save the processed data, run statistical analysis, and plot the results.

Gitee repository
Personal blog
Data source
Import the packages:

```python
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
```

Read the data:

```python
train_data = pd.read_csv("../data/train.csv")
test_data = pd.read_csv("../data/test.csv")
```
View the data:

```python
train_data.head(20)
```
| | Id | MSSubClass | MSZoning | LotFrontage | LotArea | Street | Alley | LotShape | LandContour | Utilities | ... | PoolArea | PoolQC | Fence | MiscFeature | MiscVal | MoSold | YrSold | SaleType | SaleCondition | SalePrice |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 60 | RL | 65.0 | 8450 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2008 | WD | Normal | 208500 |
| 1 | 2 | 20 | RL | 80.0 | 9600 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 5 | 2007 | WD | Normal | 181500 |
| 2 | 3 | 60 | RL | 68.0 | 11250 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 9 | 2008 | WD | Normal | 223500 |
| 3 | 4 | 70 | RL | 60.0 | 9550 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2006 | WD | Abnorml | 140000 |
| 4 | 5 | 60 | RL | 84.0 | 14260 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 12 | 2008 | WD | Normal | 250000 |
| 5 | 6 | 50 | RL | 85.0 | 14115 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | MnPrv | Shed | 700 | 10 | 2009 | WD | Normal | 143000 |
| 6 | 7 | 20 | RL | 75.0 | 10084 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 8 | 2007 | WD | Normal | 307000 |
| 7 | 8 | 60 | RL | NaN | 10382 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | Shed | 350 | 11 | 2009 | WD | Normal | 200000 |
| 8 | 9 | 50 | RM | 51.0 | 6120 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 4 | 2008 | WD | Abnorml | 129900 |
| 9 | 10 | 190 | RL | 50.0 | 7420 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 1 | 2008 | WD | Normal | 118000 |
| 10 | 11 | 20 | RL | 70.0 | 11200 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2008 | WD | Normal | 129500 |
| 11 | 12 | 60 | RL | 85.0 | 11924 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 7 | 2006 | New | Partial | 345000 |
| 12 | 13 | 20 | RL | NaN | 12968 | Pave | NaN | IR2 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 9 | 2008 | WD | Normal | 144000 |
| 13 | 14 | 20 | RL | 91.0 | 10652 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 8 | 2007 | New | Partial | 279500 |
| 14 | 15 | 20 | RL | NaN | 10920 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | GdWo | NaN | 0 | 5 | 2008 | WD | Normal | 157000 |
| 15 | 16 | 45 | RM | 51.0 | 6120 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | GdPrv | NaN | 0 | 7 | 2007 | WD | Normal | 132000 |
| 16 | 17 | 20 | RL | NaN | 11241 | Pave | NaN | IR1 | Lvl | AllPub | ... | 0 | NaN | NaN | Shed | 700 | 3 | 2010 | WD | Normal | 149000 |
| 17 | 18 | 90 | RL | 72.0 | 10791 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | Shed | 500 | 10 | 2006 | WD | Normal | 90000 |
| 18 | 19 | 20 | RL | 66.0 | 13695 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | NaN | NaN | 0 | 6 | 2008 | WD | Normal | 159000 |
| 19 | 20 | 20 | RL | 70.0 | 7560 | Pave | NaN | Reg | Lvl | AllPub | ... | 0 | NaN | MnPrv | NaN | 0 | 5 | 2009 | COD | Abnorml | 139000 |
20 rows × 81 columns
```python
train_data.describe()
```
| | Id | MSSubClass | LotFrontage | LotArea | OverallQual | OverallCond | YearBuilt | YearRemodAdd | MasVnrArea | BsmtFinSF1 | ... | WoodDeckSF | OpenPorchSF | EnclosedPorch | 3SsnPorch | ScreenPorch | PoolArea | MiscVal | MoSold | YrSold | SalePrice |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1460.000000 | 1460.000000 | 1201.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1452.000000 | 1460.000000 | ... | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 | 1460.000000 |
| mean | 730.500000 | 56.897260 | 70.049958 | 10516.828082 | 6.099315 | 5.575342 | 1971.267808 | 1984.865753 | 103.685262 | 443.639726 | ... | 94.244521 | 46.660274 | 21.954110 | 3.409589 | 15.060959 | 2.758904 | 43.489041 | 6.321918 | 2007.815753 | 180921.195890 |
| std | 421.610009 | 42.300571 | 24.284752 | 9981.264932 | 1.382997 | 1.112799 | 30.202904 | 20.645407 | 181.066207 | 456.098091 | ... | 125.338794 | 66.256028 | 61.119149 | 29.317331 | 55.757415 | 40.177307 | 496.123024 | 2.703626 | 1.328095 | 79442.502883 |
| min | 1.000000 | 20.000000 | 21.000000 | 1300.000000 | 1.000000 | 1.000000 | 1872.000000 | 1950.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 2006.000000 | 34900.000000 |
| 25% | 365.750000 | 20.000000 | 59.000000 | 7553.500000 | 5.000000 | 5.000000 | 1954.000000 | 1967.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 5.000000 | 2007.000000 | 129975.000000 |
| 50% | 730.500000 | 50.000000 | 69.000000 | 9478.500000 | 6.000000 | 5.000000 | 1973.000000 | 1994.000000 | 0.000000 | 383.500000 | ... | 0.000000 | 25.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 6.000000 | 2008.000000 | 163000.000000 |
| 75% | 1095.250000 | 70.000000 | 80.000000 | 11601.500000 | 7.000000 | 6.000000 | 2000.000000 | 2004.000000 | 166.000000 | 712.250000 | ... | 168.000000 | 68.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 8.000000 | 2009.000000 | 214000.000000 |
| max | 1460.000000 | 190.000000 | 313.000000 | 215245.000000 | 10.000000 | 9.000000 | 2010.000000 | 2010.000000 | 1600.000000 | 5644.000000 | ... | 857.000000 | 547.000000 | 552.000000 | 508.000000 | 480.000000 | 738.000000 | 15500.000000 | 12.000000 | 2010.000000 | 755000.000000 |
8 rows × 38 columns
Each attribute is described in detail in data/data_description.txt, so the descriptions are not repeated here.
Data preprocessing

```python
train_data.shape, test_data.shape
```
((1460, 81), (1459, 80))
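The shapes differ by one column because the test split lacks SalePrice. One pitfall worth noting at this point: applying `get_dummies` to each split separately can produce mismatched columns when a category appears in only one split. A sketch of encoding both splits together, using toy frames standing in for the real CSVs:

```python
import pandas as pd

# Toy stand-ins for train.csv / test.csv ("FV" never appears in train).
train = pd.DataFrame({"MSZoning": ["RL", "RM"], "LotArea": [8450, 9600],
                      "SalePrice": [208500, 181500]})
test = pd.DataFrame({"MSZoning": ["RL", "FV"], "LotArea": [11250, 9550]})

# Encoding each split on its own would give different dummy columns, so
# concatenate the feature parts, encode once, then split back apart.
features = pd.concat([train.drop("SalePrice", axis=1), test], ignore_index=True)
features = pd.get_dummies(features, dummy_na=True)

train_x = features.iloc[:len(train)]
test_x = features.iloc[len(train):]
print(list(train_x.columns) == list(test_x.columns))  # True
```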
The Id attribute carries no information about the sale price, so drop it:
```python
train_data = train_data.drop(['Id'], axis=1)
train_data
```
| | MSSubClass | MSZoning | LotFrontage | LotArea | Street | Alley | LotShape | LandContour | Utilities | LotConfig | ... | PoolArea | PoolQC | Fence | MiscFeature | MiscVal | MoSold | YrSold | SaleType | SaleCondition | SalePrice |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | RL | 65.0 | 8450 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2008 | WD | Normal | 208500 |
| 1 | 20 | RL | 80.0 | 9600 | Pave | NaN | Reg | Lvl | AllPub | FR2 | ... | 0 | NaN | NaN | NaN | 0 | 5 | 2007 | WD | Normal | 181500 |
| 2 | 60 | RL | 68.0 | 11250 | Pave | NaN | IR1 | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 9 | 2008 | WD | Normal | 223500 |
| 3 | 70 | RL | 60.0 | 9550 | Pave | NaN | IR1 | Lvl | AllPub | Corner | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2006 | WD | Abnorml | 140000 |
| 4 | 60 | RL | 84.0 | 14260 | Pave | NaN | IR1 | Lvl | AllPub | FR2 | ... | 0 | NaN | NaN | NaN | 0 | 12 | 2008 | WD | Normal | 250000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1455 | 60 | RL | 62.0 | 7917 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 8 | 2007 | WD | Normal | 175000 |
| 1456 | 20 | RL | 85.0 | 13175 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | MnPrv | NaN | 0 | 2 | 2010 | WD | Normal | 210000 |
| 1457 | 70 | RL | 66.0 | 9042 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | GdPrv | Shed | 2500 | 5 | 2010 | WD | Normal | 266500 |
| 1458 | 20 | RL | 68.0 | 9717 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 4 | 2010 | WD | Normal | 142125 |
| 1459 | 20 | RL | 75.0 | 9937 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 6 | 2008 | WD | Normal | 147500 |
1460 rows × 80 columns
```python
# Split features and label
# tx = train_data.drop(['SalePrice'], axis=1)
# ty = train_data['SalePrice']
# tx, ty
tx = train_data
tx
```
| | MSSubClass | MSZoning | LotFrontage | LotArea | Street | Alley | LotShape | LandContour | Utilities | LotConfig | ... | PoolArea | PoolQC | Fence | MiscFeature | MiscVal | MoSold | YrSold | SaleType | SaleCondition | SalePrice |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | RL | 65.0 | 8450 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2008 | WD | Normal | 208500 |
| 1 | 20 | RL | 80.0 | 9600 | Pave | NaN | Reg | Lvl | AllPub | FR2 | ... | 0 | NaN | NaN | NaN | 0 | 5 | 2007 | WD | Normal | 181500 |
| 2 | 60 | RL | 68.0 | 11250 | Pave | NaN | IR1 | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 9 | 2008 | WD | Normal | 223500 |
| 3 | 70 | RL | 60.0 | 9550 | Pave | NaN | IR1 | Lvl | AllPub | Corner | ... | 0 | NaN | NaN | NaN | 0 | 2 | 2006 | WD | Abnorml | 140000 |
| 4 | 60 | RL | 84.0 | 14260 | Pave | NaN | IR1 | Lvl | AllPub | FR2 | ... | 0 | NaN | NaN | NaN | 0 | 12 | 2008 | WD | Normal | 250000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1455 | 60 | RL | 62.0 | 7917 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 8 | 2007 | WD | Normal | 175000 |
| 1456 | 20 | RL | 85.0 | 13175 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | MnPrv | NaN | 0 | 2 | 2010 | WD | Normal | 210000 |
| 1457 | 70 | RL | 66.0 | 9042 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | GdPrv | Shed | 2500 | 5 | 2010 | WD | Normal | 266500 |
| 1458 | 20 | RL | 68.0 | 9717 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 4 | 2010 | WD | Normal | 142125 |
| 1459 | 20 | RL | 75.0 | 9937 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | 0 | NaN | NaN | NaN | 0 | 6 | 2008 | WD | Normal | 147500 |
1460 rows × 80 columns
```python
# After standardization, SalePrice is rescaled along with everything else,
# so the model's predictions will not be real prices. Saving the label's
# mean and std lets us invert the z-score later (price = pred * d_std + d_mean).
d_mean = tx['SalePrice'].mean()
d_std = tx['SalePrice'].std()
```

Continuous attributes
Some attributes are continuous, others discrete. Collect the continuous (numeric) ones first:
```python
continuous_colmuns = []
continuous_colmuns.extend(list(tx.dtypes[train_data.dtypes == np.int64].index))
continuous_colmuns.extend(list(tx.dtypes[train_data.dtypes == np.float64].index))
continuous_colmuns
```
['MSSubClass', 'LotArea', 'OverallQual', 'OverallCond', 'YearBuilt', 'YearRemodAdd', 'BsmtFinSF1', 'BsmtFinSF2', 'BsmtUnfSF', 'TotalBsmtSF', '1stFlrSF', '2ndFlrSF', 'LowQualFinSF', 'GrLivArea', 'BsmtFullBath', 'BsmtHalfBath', 'FullBath', 'HalfBath', 'BedroomAbvGr', 'KitchenAbvGr', 'TotRmsAbvGrd', 'Fireplaces', 'GarageCars', 'GarageArea', 'WoodDeckSF', 'OpenPorchSF', 'EnclosedPorch', '3SsnPorch', 'ScreenPorch', 'PoolArea', 'MiscVal', 'MoSold', 'YrSold', 'SalePrice', 'LotFrontage', 'MasVnrArea', 'GarageYrBlt']
```python
# Check the continuous columns for missing values
tx[continuous_colmuns].isnull().sum(), tx[continuous_colmuns].isna().sum()
```
(Both calls return identical results, since `isnull` is an alias of `isna`. Every continuous column is complete except LotFrontage with 259 missing values, MasVnrArea with 8, and GarageYrBlt with 81.)
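Only LotFrontage, MasVnrArea and GarageYrBlt contain gaps. When there are many columns, sorting by missing fraction makes the damage easier to read; a sketch on a toy frame:

```python
import numpy as np
import pandas as pd

# Toy frame with the same kind of gaps as the real data.
df = pd.DataFrame({"LotFrontage": [65.0, np.nan, 68.0, np.nan],
                   "MasVnrArea": [196.0, 0.0, np.nan, 0.0],
                   "LotArea": [8450, 9600, 11250, 9550]})

# Rank columns by the share of rows they are missing, worst first.
missing_ratio = (df.isnull().sum() / len(df)).sort_values(ascending=False)
print(missing_ratio)
```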
```python
# Standardize the continuous columns (z-score each column)
tx[continuous_colmuns] = tx[continuous_colmuns].apply(lambda x: (x - x.mean()) / x.std())
tx[continuous_colmuns]
```
| | MSSubClass | LotArea | OverallQual | OverallCond | YearBuilt | YearRemodAdd | BsmtFinSF1 | BsmtFinSF2 | BsmtUnfSF | TotalBsmtSF | ... | 3SsnPorch | ScreenPorch | PoolArea | MiscVal | MoSold | YrSold | SalePrice | LotFrontage | MasVnrArea | GarageYrBlt |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.073350 | -0.207071 | 0.651256 | -0.517023 | 1.050634 | 0.878367 | 0.575228 | -0.288554 | -0.944267 | -0.459145 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -1.598563 | 0.138730 | 0.347154 | -0.207948 | 0.509840 | 0.992066 |
| 1 | -0.872264 | -0.091855 | -0.071812 | 2.178881 | 0.156680 | -0.429430 | 1.171591 | -0.288554 | -0.641008 | 0.466305 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -0.488943 | -0.614228 | 0.007286 | 0.409724 | -0.572637 | -0.101506 |
| 2 | 0.073350 | 0.073455 | 0.651256 | -0.517023 | 0.984415 | 0.829930 | 0.092875 | -0.288554 | -0.301540 | -0.313261 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | 0.990552 | 0.138730 | 0.535970 | -0.084413 | 0.322063 | 0.911061 |
| 3 | 0.309753 | -0.096864 | 0.651256 | -0.517023 | -1.862993 | -0.720051 | -0.499103 | -0.288554 | -0.061648 | -0.687089 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -1.598563 | -1.367186 | -0.515105 | -0.413838 | -0.572637 | 0.789553 |
| 4 | 0.073350 | 0.375020 | 1.374324 | -0.517023 | 0.951306 | 0.733056 | 0.463410 | -0.288554 | -0.174805 | 0.199611 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | 2.100173 | 0.138730 | 0.869545 | 0.574436 | 1.360357 | 0.870558 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1455 | 0.073350 | -0.260471 | -0.071812 | -0.517023 | 0.918196 | 0.733056 | -0.972685 | -0.288554 | 0.873022 | -0.238040 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | 0.620678 | -0.614228 | -0.074534 | -0.331482 | -0.572637 | 0.830055 |
| 1456 | -0.872264 | 0.266316 | -0.071812 | 0.381612 | 0.222899 | 0.151813 | 0.759399 | 0.721865 | 0.049245 | 1.104547 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -1.598563 | 1.644646 | 0.366036 | 0.615614 | 0.084581 | -0.020501 |
| 1457 | 0.309753 | -0.147760 | 0.651256 | 3.077516 | -1.002149 | 1.023678 | -0.369744 | -0.288554 | 0.701025 | 0.215567 | ... | -0.116299 | -0.270116 | -0.068668 | 4.951415 | -0.488943 | 1.644646 | 1.077242 | -0.166770 | -0.572637 | -1.519100 |
| 1458 | -0.872264 | -0.080133 | -0.794879 | 0.381612 | -0.704164 | 0.539309 | -0.865252 | 6.090101 | -1.283736 | 0.046889 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -0.858816 | 1.644646 | -0.488356 | -0.084413 | -0.572637 | -1.154576 |
| 1459 | -0.872264 | -0.058092 | -0.794879 | 0.381612 | -0.207523 | -0.962236 | 0.847099 | 1.509123 | -0.975951 | 0.452629 | ... | -0.116299 | -0.270116 | -0.068668 | -0.087658 | -0.119069 | 0.138730 | -0.420697 | 0.203833 | -0.572637 | -0.547036 |
1460 rows × 37 columns
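A quick sanity check of the transform above, sketched on a toy column: after z-scoring, each column has mean ≈ 0 and sample std ≈ 1, though individual values are not confined to any fixed interval such as [-1, 1].

```python
import pandas as pd

s = pd.Series([8450.0, 9600.0, 11250.0, 9550.0, 14260.0])
z = (s - s.mean()) / s.std()  # the same transform applied to the real columns

# Mean is ~0 and sample std is ~1 after the transform.
print(z.mean(), z.std())
```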
```python
# Fill missing values with the column mean, which is 0 after standardization
tx[continuous_colmuns] = tx[continuous_colmuns].fillna(0)
tx
```
| | MSSubClass | MSZoning | LotFrontage | LotArea | Street | Alley | LotShape | LandContour | Utilities | LotConfig | ... | PoolArea | PoolQC | Fence | MiscFeature | MiscVal | MoSold | YrSold | SaleType | SaleCondition | SalePrice |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.073350 | RL | -0.207948 | -0.207071 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | -1.598563 | 0.138730 | WD | Normal | 0.347154 |
| 1 | -0.872264 | RL | 0.409724 | -0.091855 | Pave | NaN | Reg | Lvl | AllPub | FR2 | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | -0.488943 | -0.614228 | WD | Normal | 0.007286 |
| 2 | 0.073350 | RL | -0.084413 | 0.073455 | Pave | NaN | IR1 | Lvl | AllPub | Inside | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | 0.990552 | 0.138730 | WD | Normal | 0.535970 |
| 3 | 0.309753 | RL | -0.413838 | -0.096864 | Pave | NaN | IR1 | Lvl | AllPub | Corner | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | -1.598563 | -1.367186 | WD | Abnorml | -0.515105 |
| 4 | 0.073350 | RL | 0.574436 | 0.375020 | Pave | NaN | IR1 | Lvl | AllPub | FR2 | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | 2.100173 | 0.138730 | WD | Normal | 0.869545 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1455 | 0.073350 | RL | -0.331482 | -0.260471 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | 0.620678 | -0.614228 | WD | Normal | -0.074534 |
| 1456 | -0.872264 | RL | 0.615614 | 0.266316 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | MnPrv | NaN | -0.087658 | -1.598563 | 1.644646 | WD | Normal | 0.366036 |
| 1457 | 0.309753 | RL | -0.166770 | -0.147760 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | GdPrv | Shed | 4.951415 | -0.488943 | 1.644646 | WD | Normal | 1.077242 |
| 1458 | -0.872264 | RL | -0.084413 | -0.080133 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | -0.858816 | 1.644646 | WD | Normal | -0.488356 |
| 1459 | -0.872264 | RL | 0.203833 | -0.058092 | Pave | NaN | Reg | Lvl | AllPub | Inside | ... | -0.068668 | NaN | NaN | NaN | -0.087658 | -0.119069 | 0.138730 | WD | Normal | -0.420697 |
1460 rows × 80 columns
Discrete attributes

Collect the discrete (object-typed) attributes and count the values each one takes:
```python
discrete_colmuns = []
discrete_colmuns.extend(list(tx.dtypes[train_data.dtypes == 'object'].index))
discrete_colmuns
```
['MSZoning', 'Street', 'Alley', 'LotShape', 'LandContour', 'Utilities', 'LotConfig', 'LandSlope', 'Neighborhood', 'Condition1', 'Condition2', 'BldgType', 'HouseStyle', 'RoofStyle', 'RoofMatl', 'Exterior1st', 'Exterior2nd', 'MasVnrType', 'ExterQual', 'ExterCond', 'Foundation', 'BsmtQual', 'BsmtCond', 'BsmtExposure', 'BsmtFinType1', 'BsmtFinType2', 'Heating', 'HeatingQC', 'CentralAir', 'Electrical', 'KitchenQual', 'Functional', 'FireplaceQu', 'GarageType', 'GarageFinish', 'GarageQual', 'GarageCond', 'PavedDrive', 'PoolQC', 'Fence', 'MiscFeature', 'SaleType', 'SaleCondition']
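The heading above also promises a count of the values each discrete attribute takes. One way to get it, sketched on a toy frame standing in for `tx`:

```python
import numpy as np
import pandas as pd

# Toy discrete columns, including one with missing values.
df = pd.DataFrame({"MSZoning": ["RL", "RM", "RL", "FV"],
                   "Street": ["Pave", "Pave", "Grvl", "Pave"],
                   "Alley": [np.nan, "Grvl", np.nan, "Pave"]})

cat_cols = df.dtypes[df.dtypes == "object"].index
print(df[cat_cols].nunique())                  # distinct non-NaN values per column
print(df["Alley"].value_counts(dropna=False))  # per-value counts including NaN
```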
```python
# Missing values in the discrete columns
tx[discrete_colmuns].isnull().sum(), tx[discrete_colmuns].isna().sum()
```
(Both calls agree. The discrete columns with missing values are: Alley 1369, MasVnrType 8, BsmtQual 37, BsmtCond 37, BsmtExposure 38, BsmtFinType1 37, BsmtFinType2 38, Electrical 1, FireplaceQu 690, GarageType 81, GarageFinish 81, GarageQual 81, GarageCond 81, PoolQC 1453, Fence 1179, MiscFeature 1406; the rest have none.)

Convert the discrete values to one-hot encodings:
```python
tx = pd.get_dummies(tx, dummy_na=True)
tx
```
| | MSSubClass | LotFrontage | LotArea | OverallQual | OverallCond | YearBuilt | YearRemodAdd | MasVnrArea | BsmtFinSF1 | BsmtFinSF2 | ... | SaleType_Oth | SaleType_WD | SaleType_nan | SaleCondition_Abnorml | SaleCondition_AdjLand | SaleCondition_Alloca | SaleCondition_Family | SaleCondition_Normal | SaleCondition_Partial | SaleCondition_nan |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.073350 | -0.207948 | -0.207071 | 0.651256 | -0.517023 | 1.050634 | 0.878367 | 0.509840 | 0.575228 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1 | -0.872264 | 0.409724 | -0.091855 | -0.071812 | 2.178881 | 0.156680 | -0.429430 | -0.572637 | 1.171591 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 2 | 0.073350 | -0.084413 | 0.073455 | 0.651256 | -0.517023 | 0.984415 | 0.829930 | 0.322063 | 0.092875 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 3 | 0.309753 | -0.413838 | -0.096864 | 0.651256 | -0.517023 | -1.862993 | -0.720051 | -0.572637 | -0.499103 | -0.288554 | ... | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0.073350 | 0.574436 | 0.375020 | 1.374324 | -0.517023 | 0.951306 | 0.733056 | 1.360357 | 0.463410 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1455 | 0.073350 | -0.331482 | -0.260471 | -0.071812 | -0.517023 | 0.918196 | 0.733056 | -0.572637 | -0.972685 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1456 | -0.872264 | 0.615614 | 0.266316 | -0.071812 | 0.381612 | 0.222899 | 0.151813 | 0.084581 | 0.759399 | 0.721865 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1457 | 0.309753 | -0.166770 | -0.147760 | 0.651256 | 3.077516 | -1.002149 | 1.023678 | -0.572637 | -0.369744 | -0.288554 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1458 | -0.872264 | -0.084413 | -0.080133 | -0.794879 | 0.381612 | -0.704164 | 0.539309 | -0.572637 | -0.865252 | 6.090101 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 1459 | -0.872264 | 0.203833 | -0.058092 | -0.794879 | 0.381612 | -0.207523 | -0.962236 | -0.572637 | 0.847099 | 1.509123 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1460 rows × 332 columns
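The column count grows from 80 to 332 because each category becomes its own indicator column, and `dummy_na=True` turns missingness itself into a column, which is why no separate fillna was needed for the discrete attributes. A minimal illustration:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"Fence": ["MnPrv", np.nan, "GdPrv"]})
encoded = pd.get_dummies(df, dummy_na=True)

# dummy_na=True adds a Fence_nan indicator, so a missing fence becomes a
# feature in its own right rather than a dropped row or a silent zero.
print(list(encoded.columns))  # ['Fence_GdPrv', 'Fence_MnPrv', 'Fence_nan']
```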
This completes the preprocessing: missing values have been filled, continuous attributes standardized, and discrete attributes one-hot encoded.
Export the processed training data:

```python
tx.to_csv("../data_process.csv", index=False)
```
Training a regression network
```python
y = torch.tensor(pd.DataFrame(tx['SalePrice']).values, dtype=torch.float)
tx = tx.drop(['SalePrice'], axis=1)  # drop() returns a copy, so assign it back
x = torch.tensor(tx.values, dtype=torch.float)
x, y
```
(tensor([[ 0.0733, -0.2079, -0.2071, ..., 1.0000, 0.0000, 0.0000],
[-0.8723, 0.4097, -0.0919, ..., 1.0000, 0.0000, 0.0000],
[ 0.0733, -0.0844, 0.0735, ..., 1.0000, 0.0000, 0.0000],
...,
[ 0.3098, -0.1668, -0.1478, ..., 1.0000, 0.0000, 0.0000],
[-0.8723, -0.0844, -0.0801, ..., 1.0000, 0.0000, 0.0000],
[-0.8723, 0.2038, -0.0581, ..., 1.0000, 0.0000, 0.0000]]),
tensor([[ 0.3472],
[ 0.0073],
[ 0.5360],
...,
[ 1.0772],
[-0.4884],
[-0.4207]]))
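One portability caveat with the conversion above: newer pandas versions emit boolean dummy columns, and mixing bool with float makes `DataFrame.values` an object array that `torch.tensor` rejects. Casting the frame to a single float dtype first avoids this; a sketch with a toy frame:

```python
import pandas as pd
import torch

# Toy frame mixing a standardized numeric column with a boolean dummy.
df = pd.DataFrame({"LotArea": [0.5, -0.2], "MSZoning_RL": [True, False]})

# to_numpy(dtype="float32") coerces every column to one numeric dtype
# before the tensor conversion.
x = torch.tensor(df.to_numpy(dtype="float32"))
print(x.dtype, x.shape)  # torch.float32 torch.Size([2, 2])
```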
```python
class Net(nn.Module):
    def __init__(self, data_in, l1, l2, l3, data_out):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(data_in, l1)
        self.linear2 = nn.Linear(l1, l2)
        self.linear3 = nn.Linear(l2, l3)
        self.linear4 = nn.Linear(l3, data_out)

    def forward(self, x):
        # clamp(min=0) applies a ReLU after each hidden layer
        y_pred = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(y_pred).clamp(min=0)
        y_pred = self.linear3(y_pred).clamp(min=0)
        y_pred = self.linear4(y_pred)
        return y_pred

# def get_net(feature_num):
#     net = nn.Linear(feature_num, 1)
#     for param in net.parameters():
#         nn.init.normal_(param, mean=0, std=0.01)
#     return net
```
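Since `clamp(min=0)` in `forward` is exactly ReLU, the same architecture can be written more compactly with `nn.Sequential`. A sketch (the layer widths mirror the 500/1000/200 used below):

```python
import torch
import torch.nn as nn

def make_mlp(data_in: int, data_out: int) -> nn.Sequential:
    # Same stack as Net above: three ReLU hidden layers, linear output.
    return nn.Sequential(
        nn.Linear(data_in, 500), nn.ReLU(),
        nn.Linear(500, 1000), nn.ReLU(),
        nn.Linear(1000, 200), nn.ReLU(),
        nn.Linear(200, data_out),
    )

model = make_mlp(331, 1)          # 331 is a hypothetical feature count
out = model(torch.randn(4, 331))
print(out.shape)  # torch.Size([4, 1])
```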
```python
l1, l2, l3 = 500, 1000, 200
data_in = x.shape[1]
data_out = y.shape[1]
model = Net(data_in, l1, l2, l3, data_out)
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=0)
```
```python
losses1 = []
y_p = y
for t in range(500):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())
    losses1.append(loss.item())
    if torch.isnan(loss):
        break
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    y_p = y_pred
```
The printed loss drops from about 1459.4 at iteration 0 to roughly 0.03 by iteration 499, falling rapidly at first and then oscillating with occasional spikes (full per-iteration log omitted).

Training results
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # workaround for the Jupyter notebook kernel crashing; can be ignored

# x-axis: iteration number, y-axis: loss value
plt.figure(figsize=(12, 10))
plt.plot(range(len(losses1)), losses1)
plt.show()
(Figure: training loss curve, output_30_0.png)
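The raw loss curve oscillates heavily, so a simple moving average can make the downward trend easier to read. A minimal sketch, where `losses1` below is a synthetic stand-in for the real recorded loss history:

```python
import numpy as np

# Synthetic stand-in for the recorded loss history `losses1`
losses1 = list(1.0 / np.arange(1, 501) + 0.05 * np.abs(np.sin(np.arange(500))))

# Simple moving average over a 25-step window smooths out the oscillation
window = 25
smoothed = np.convolve(losses1, np.ones(window) / window, mode="valid")
# plt.plot(range(len(smoothed)), smoothed) would then show the trend
```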
y_p.shape, x.shape
(torch.Size([1460, 1]), torch.Size([1460, 332]))
Model prediction
Time is limited, so for now the model is only trained in this simple form; optimization and other refinements are deferred.
As with the training set, the test data needs preprocessing first.
test_data.isnull().sum()
Id 0
MSSubClass 0
MSZoning 4
LotFrontage 227
LotArea 0
...
MiscVal 0
MoSold 0
YrSold 0
SaleType 1
SaleCondition 0
Length: 80, dtype: int64
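The counts above show missing values in categorical columns too (e.g. MSZoning, SaleType). `pd.get_dummies(..., dummy_na=True)` deals with those by giving NaN its own indicator column rather than dropping rows; a small sketch with a toy column:

```python
import pandas as pd

# Toy column with a missing category, standing in for e.g. MSZoning
df = pd.DataFrame({"MSZoning": ["RL", None, "RM"]})

# dummy_na=True adds an extra indicator column that marks the NaN rows
dummies = pd.get_dummies(df, dummy_na=True)
```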
continuous_colmuns.remove('SalePrice')  # the test set has no SalePrice column
# standardize the continuous columns (note: this uses the test set's own mean/std
# rather than the training statistics)
test_data[continuous_colmuns] = test_data[continuous_colmuns].apply(lambda x: (x - x.mean()) / x.std())
test_data[continuous_colmuns] = test_data[continuous_colmuns].fillna(0)
test_data[continuous_colmuns].isnull().sum()
MSSubClass       0
LotArea          0
OverallQual      0
OverallCond      0
YearBuilt        0
YearRemodAdd     0
BsmtFinSF1       0
BsmtFinSF2       0
BsmtUnfSF        0
TotalBsmtSF      0
1stFlrSF         0
2ndFlrSF         0
LowQualFinSF     0
GrLivArea        0
BsmtFullBath     0
BsmtHalfBath     0
FullBath         0
HalfBath         0
BedroomAbvGr     0
KitchenAbvGr     0
TotRmsAbvGrd     0
Fireplaces       0
GarageCars       0
GarageArea       0
WoodDeckSF       0
OpenPorchSF      0
EnclosedPorch    0
3SsnPorch        0
ScreenPorch      0
PoolArea         0
MiscVal          0
MoSold           0
YrSold           0
LotFrontage      0
MasVnrArea       0
GarageYrBlt      0
dtype: int64
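One caveat about the step above: it standardizes the test set with the test set's own mean and std. Reusing the training statistics for both sets keeps train and test features on the same scale. A minimal sketch with a single hypothetical column:

```python
import pandas as pd

# Tiny stand-ins for one continuous column of train_data / test_data
train = pd.DataFrame({"LotArea": [8450.0, 9600.0, 11250.0, 9550.0]})
test = pd.DataFrame({"LotArea": [10000.0, None]})

# Standardize both sets with the *training* mean and std
mu, sigma = train["LotArea"].mean(), train["LotArea"].std()
train["LotArea"] = (train["LotArea"] - mu) / sigma
test["LotArea"] = ((test["LotArea"] - mu) / sigma).fillna(0)
```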
# this step aligns the test features with the training features so the network input size matches
td = pd.get_dummies(test_data, dummy_na=True)
for col in tx.columns:
    if col not in td:
        td[col] = 0
td = td.drop(['Id'], axis=1)
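The column-filling loop above adds training-only dummy columns to the test frame, but dummy columns that appear only in the test set would still leave the two frames with different widths and orders. `DataFrame.reindex` handles both directions in one step; a sketch with toy frames (the category names are illustrative):

```python
import pandas as pd

# Toy frames whose one-hot columns disagree ("C (all)" only appears in test)
train_x = pd.get_dummies(pd.DataFrame({"MSZoning": ["RL", "RM"]}), dummy_na=True)
test_x = pd.get_dummies(pd.DataFrame({"MSZoning": ["RL", "C (all)"]}), dummy_na=True)

# reindex adds missing training columns (filled with 0) and drops test-only ones,
# so the test matrix gets exactly the training columns, in the same order
test_x = test_x.reindex(columns=train_x.columns, fill_value=0)
```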
# predict on the test set
pred_y = model(torch.tensor(td.values, dtype=torch.float))
pred_y
tensor([[-0.3723],
[-0.1471],
[-0.0315],
...,
[-0.1320],
[-0.4483],
[-0.0833]], grad_fn=<...>)
res = pd.DataFrame(pred_y.data.numpy(), columns=['SalePrice'])
res['SalePrice']
0 -0.372343
1 -0.147053
2 -0.031502
3 0.103289
4 -0.258128
...
1454 -0.568579
1455 -0.618316
1456 -0.132024
1457 -0.448310
1458 -0.083297
Name: SalePrice, Length: 1459, dtype: float32
# recover the house prices from the normalized predictions
res['SalePrice'] = res['SalePrice'] * (d_max - d_min) + d_mean
res
# the predicted values look off; will dig into this and fix it when time permits
# (one suspect: SalePrice was standardized as (x - mean) / std, so scaling back by
# (d_max - d_min) does not invert that transform)
|  | SalePrice |
| --- | --- |
| 0 | -87203.301559 |
| 1 | 75028.274554 |
| 2 | 158236.460120 |
| 3 | 255299.261917 |
| 4 | -4957.003893 |
| ... | ... |
| 1454 | -228512.597976 |
| 1455 | -264328.452371 |
| 1456 | 85850.581086 |
| 1457 | -141906.568785 |
| 1458 | 120939.053495 |
1459 rows × 1 columns
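The negative prices are consistent with a mismatched inverse transform: if SalePrice was standardized as (x - mean) / std during preprocessing, the inverse must be pred * std + mean, not pred * (max - min) + mean. A small round-trip check using the first few SalePrice values from the training sample:

```python
import numpy as np

prices = np.array([208500.0, 181500.0, 223500.0, 140000.0, 250000.0])
mu, sigma = prices.mean(), prices.std(ddof=1)  # ddof=1 matches pandas' .std()

normalized = (prices - mu) / sigma
recovered = normalized * sigma + mu  # exact inverse of the standardization
```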



