PyTorch House Price Prediction Exercise

Task: implement house price prediction with PyTorch

  1. Collect the data and describe its attributes
  2. Implement the data preprocessing and save the result
  3. Run statistical analysis on the data and plot the results

Gitee repository

Personal blog

Data source

Importing packages

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

Reading the data

train_data = pd.read_csv("../data/train.csv")
test_data = pd.read_csv("../data/test.csv")

Inspecting the data

train_data.head(20)
      Id  MSSubClass MSZoning  LotFrontage  LotArea Street Alley LotShape  ... MoSold YrSold SaleType SaleCondition SalePrice
0      1          60       RL         65.0     8450   Pave   NaN      Reg  ...      2   2008       WD        Normal    208500
1      2          20       RL         80.0     9600   Pave   NaN      Reg  ...      5   2007       WD        Normal    181500
2      3          60       RL         68.0    11250   Pave   NaN      IR1  ...      9   2008       WD        Normal    223500
3      4          70       RL         60.0     9550   Pave   NaN      IR1  ...      2   2006       WD       Abnorml    140000
4      5          60       RL         84.0    14260   Pave   NaN      IR1  ...     12   2008       WD        Normal    250000
...  ...         ...      ...          ...      ...    ...   ...      ...  ...    ...    ...      ...           ...       ...

20 rows × 81 columns

train_data.describe()
                Id   MSSubClass  LotFrontage        LotArea  ...       YrSold      SalePrice
count  1460.000000  1460.000000  1201.000000    1460.000000  ...  1460.000000    1460.000000
mean    730.500000    56.897260    70.049958   10516.828082  ...  2007.815753  180921.195890
std     421.610009    42.300571    24.284752    9981.264932  ...     1.328095   79442.502883
min       1.000000    20.000000    21.000000    1300.000000  ...  2006.000000   34900.000000
25%     365.750000    20.000000    59.000000    7553.500000  ...  2007.000000  129975.000000
50%     730.500000    50.000000    69.000000    9478.500000  ...  2008.000000  163000.000000
75%    1095.250000    70.000000    80.000000   11601.500000  ...  2009.000000  214000.000000
max    1460.000000   190.000000   313.000000  215245.000000  ...  2010.000000  755000.000000

8 rows × 38 columns

Each attribute is described in detail in data/data_description.txt, so the descriptions are not repeated here.

Data preprocessing

train_data.shape, test_data.shape
((1460, 81), (1459, 80))

The Id attribute has no bearing on the sale price, so it is dropped.

train_data = train_data.drop(['Id'], axis=1)
train_data
      MSSubClass MSZoning  LotFrontage  LotArea Street  ... YrSold SaleType SaleCondition SalePrice
0             60       RL         65.0     8450   Pave  ...   2008       WD        Normal    208500
1             20       RL         80.0     9600   Pave  ...   2007       WD        Normal    181500
...          ...      ...          ...      ...    ...  ...    ...      ...           ...       ...
1458          20       RL         68.0     9717   Pave  ...   2010       WD        Normal    142125
1459          20       RL         75.0     9937   Pave  ...   2008       WD        Normal    147500

1460 rows × 80 columns

# Split features and labels
# tx = train_data.drop(['SalePrice'], axis=1)
# ty = train_data['SalePrice']
# tx, ty
tx = train_data
tx
      MSSubClass MSZoning  LotFrontage  LotArea Street  ... YrSold SaleType SaleCondition SalePrice
0             60       RL         65.0     8450   Pave  ...   2008       WD        Normal    208500
1             20       RL         80.0     9600   Pave  ...   2007       WD        Normal    181500
...          ...      ...          ...      ...    ...  ...    ...      ...           ...       ...
1458          20       RL         68.0     9717   Pave  ...   2010       WD        Normal    142125
1459          20       RL         75.0     9937   Pave  ...   2008       WD        Normal    147500

1460 rows × 80 columns

# After standardization, SalePrice also lands near [-1, 1], so the model's
# predictions are not real prices. Save the statistics of the raw column here
# so the transform can be inverted later.
d_mean = tx['SalePrice'].mean()
d_max = tx['SalePrice'].max()
d_min = tx['SalePrice'].min()
d_std = tx['SalePrice'].std()  # the std is required to invert a z-score exactly
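Since the standardization applied below is a z-score, (x - mean) / std, inverting it exactly needs the column's std as well, not just the mean/max/min. A toy sketch (the price values are taken from the head() rows above):

```python
import pandas as pd

prices = pd.Series([208500.0, 181500.0, 223500.0, 140000.0, 250000.0])
mu, sigma = prices.mean(), prices.std()
z = (prices - mu) / sigma      # standardize, same formula the notebook uses
restored = z * sigma + mu      # exact inverse of the z-score
```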

Continuous attributes

Some of the attributes are continuous and some are discrete; the continuous (numeric) ones are collected below.

continuous_colmuns = []
continuous_colmuns.extend(list(tx.dtypes[train_data.dtypes == np.int64].index))
continuous_colmuns.extend(list(tx.dtypes[train_data.dtypes == np.float64].index))
continuous_colmuns
['MSSubClass',
 'LotArea',
 'OverallQual',
 'OverallCond',
 'YearBuilt',
 'YearRemodAdd',
 'BsmtFinSF1',
 'BsmtFinSF2',
 'BsmtUnfSF',
 'TotalBsmtSF',
 '1stFlrSF',
 '2ndFlrSF',
 'LowQualFinSF',
 'GrLivArea',
 'BsmtFullBath',
 'BsmtHalfBath',
 'FullBath',
 'HalfBath',
 'BedroomAbvGr',
 'KitchenAbvGr',
 'TotRmsAbvGrd',
 'Fireplaces',
 'GarageCars',
 'GarageArea',
 'WoodDeckSF',
 'OpenPorchSF',
 'EnclosedPorch',
 '3SsnPorch',
 'ScreenPorch',
 'PoolArea',
 'MiscVal',
 'MoSold',
 'YrSold',
 'SalePrice',
 'LotFrontage',
 'MasVnrArea',
 'GarageYrBlt']
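The same selection can be written more compactly with `select_dtypes`; a sketch on a toy frame (column names borrowed from the dataset):

```python
import pandas as pd

df = pd.DataFrame({"LotArea": [8450, 9600],
                   "LotFrontage": [65.0, 80.0],
                   "MSZoning": ["RL", "RL"]})
# "number" covers int64, float64 and other numeric dtypes in one go
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
```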
# Check the continuous columns for missing values (isna is an alias of isnull, so one call suffices)
tx[continuous_colmuns].isnull().sum()
MSSubClass         0
LotArea            0
OverallQual        0
OverallCond        0
YearBuilt          0
YearRemodAdd       0
BsmtFinSF1         0
BsmtFinSF2         0
BsmtUnfSF          0
TotalBsmtSF        0
1stFlrSF           0
2ndFlrSF           0
LowQualFinSF       0
GrLivArea          0
BsmtFullBath       0
BsmtHalfBath       0
FullBath           0
HalfBath           0
BedroomAbvGr       0
KitchenAbvGr       0
TotRmsAbvGrd       0
Fireplaces         0
GarageCars         0
GarageArea         0
WoodDeckSF         0
OpenPorchSF        0
EnclosedPorch      0
3SsnPorch          0
ScreenPorch        0
PoolArea           0
MiscVal            0
MoSold             0
YrSold             0
SalePrice          0
LotFrontage      259
MasVnrArea         8
GarageYrBlt       81
dtype: int64
# Standardize the continuous columns (z-score)
tx[continuous_colmuns] = tx[continuous_colmuns].apply(lambda x: (x - x.mean())/(x.std()))
tx[continuous_colmuns]
      MSSubClass   LotArea  OverallQual  ...  LotFrontage  MasVnrArea  GarageYrBlt
0       0.073350 -0.207071     0.651256  ...    -0.207948    0.509840     0.992066
1      -0.872264 -0.091855    -0.071812  ...     0.409724   -0.572637    -0.101506
...          ...       ...          ...  ...          ...         ...          ...
1458   -0.872264 -0.080133    -0.794879  ...    -0.084413   -0.572637    -1.154576
1459   -0.872264 -0.058092    -0.794879  ...     0.203833   -0.572637    -0.547036

1460 rows × 37 columns

# Fill missing values with the column mean; after standardization the mean is 0
tx[continuous_colmuns] = tx[continuous_colmuns].fillna(0)
tx
      MSSubClass MSZoning  LotFrontage   LotArea Street  ... SaleType SaleCondition SalePrice
0       0.073350       RL    -0.207948 -0.207071   Pave  ...       WD        Normal  0.347154
1      -0.872264       RL     0.409724 -0.091855   Pave  ...       WD        Normal  0.007286
...          ...      ...          ...       ...    ...  ...      ...           ...       ...
1458   -0.872264       RL    -0.084413 -0.080133   Pave  ...       WD        Normal -0.488356
1459   -0.872264       RL     0.203833 -0.058092   Pave  ...       WD        Normal -0.420697

1460 rows × 80 columns
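Filling with 0 after z-scoring is equivalent to mean imputation, because the standardized column has mean 0 and NaNs pass through the transform untouched. A minimal illustration:

```python
import numpy as np
import pandas as pd

s = pd.Series([1.0, 2.0, np.nan, 4.0])
z = (s - s.mean()) / s.std()   # mean/std skip NaN; the NaN entry stays NaN
z_filled = z.fillna(0)         # 0 is the mean of the standardized column
```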

Discrete attributes

Collect the discrete (object-typed) attributes.

discrete_colmuns = []
discrete_colmuns.extend(list(tx.dtypes[train_data.dtypes  == 'object'].index))
discrete_colmuns
['MSZoning',
 'Street',
 'Alley',
 'LotShape',
 'LandContour',
 'Utilities',
 'LotConfig',
 'LandSlope',
 'Neighborhood',
 'Condition1',
 'Condition2',
 'BldgType',
 'HouseStyle',
 'RoofStyle',
 'RoofMatl',
 'Exterior1st',
 'Exterior2nd',
 'MasVnrType',
 'ExterQual',
 'ExterCond',
 'Foundation',
 'BsmtQual',
 'BsmtCond',
 'BsmtExposure',
 'BsmtFinType1',
 'BsmtFinType2',
 'Heating',
 'HeatingQC',
 'CentralAir',
 'Electrical',
 'KitchenQual',
 'Functional',
 'FireplaceQu',
 'GarageType',
 'GarageFinish',
 'GarageQual',
 'GarageCond',
 'PavedDrive',
 'PoolQC',
 'Fence',
 'MiscFeature',
 'SaleType',
 'SaleCondition']
# Missing values in the discrete columns (isna is an alias of isnull, so one call suffices)
tx[discrete_colmuns].isnull().sum()
MSZoning            0
Street              0
Alley            1369
LotShape            0
LandContour         0
Utilities           0
LotConfig           0
LandSlope           0
Neighborhood        0
Condition1          0
Condition2          0
BldgType            0
HouseStyle          0
RoofStyle           0
RoofMatl            0
Exterior1st         0
Exterior2nd         0
MasVnrType          8
ExterQual           0
ExterCond           0
Foundation          0
BsmtQual           37
BsmtCond           37
BsmtExposure       38
BsmtFinType1       37
BsmtFinType2       38
Heating             0
HeatingQC           0
CentralAir          0
Electrical          1
KitchenQual         0
Functional          0
FireplaceQu       690
GarageType         81
GarageFinish       81
GarageQual         81
GarageCond         81
PavedDrive          0
PoolQC           1453
Fence            1179
MiscFeature      1406
SaleType            0
SaleCondition       0
dtype: int64

Convert the discrete values to one-hot encodings.

tx = pd.get_dummies(tx, dummy_na=True)
tx
      MSSubClass  LotFrontage   LotArea  ...  SaleCondition_Normal  SaleCondition_Partial  SaleCondition_nan
0       0.073350    -0.207948 -0.207071  ...                     1                      0                  0
1      -0.872264     0.409724 -0.091855  ...                     1                      0                  0
...          ...          ...       ...  ...                   ...                    ...                ...
1458   -0.872264    -0.084413 -0.080133  ...                     1                      0                  0
1459   -0.872264     0.203833 -0.058092  ...                     1                      0                  0

1460 rows × 332 columns
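On a toy column, `get_dummies` with `dummy_na=True` produces one indicator per category plus a trailing `_nan` indicator for missing values:

```python
import pandas as pd

df = pd.DataFrame({"Alley": ["Grvl", "Pave", None]})
encoded = pd.get_dummies(df, dummy_na=True)
# columns: Alley_Grvl, Alley_Pave, Alley_nan (the NaN indicator comes last)
```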

Data processing is now complete: missing values filled and the data standardized.

Export the processed training data.

tx.to_csv("../data_process.csv",index=False)
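`index=False` keeps the row index out of the written file; a quick round-trip sketch using an in-memory buffer instead of a real path:

```python
import io
import pandas as pd

df = pd.DataFrame({"SalePrice": [0.347154, 0.007286]})
buf = io.StringIO()
df.to_csv(buf, index=False)   # no extra index column in the CSV
buf.seek(0)
back = pd.read_csv(buf)       # same shape and columns come back
```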

Model training

y = torch.tensor(pd.DataFrame(tx['SalePrice']).values, dtype=torch.float)
# drop() returns a copy, so the result must be assigned back;
# otherwise SalePrice stays inside x and the label leaks into the features
tx = tx.drop(['SalePrice'], axis=1)
x = torch.tensor(tx.values, dtype=torch.float)
x, y
(tensor([[ 0.0733, -0.2079, -0.2071,  ...,  1.0000,  0.0000,  0.0000],
         [-0.8723,  0.4097, -0.0919,  ...,  1.0000,  0.0000,  0.0000],
         [ 0.0733, -0.0844,  0.0735,  ...,  1.0000,  0.0000,  0.0000],
         ...,
         [ 0.3098, -0.1668, -0.1478,  ...,  1.0000,  0.0000,  0.0000],
         [-0.8723, -0.0844, -0.0801,  ...,  1.0000,  0.0000,  0.0000],
         [-0.8723,  0.2038, -0.0581,  ...,  1.0000,  0.0000,  0.0000]]),
 tensor([[ 0.3472],
         [ 0.0073],
         [ 0.5360],
         ...,
         [ 1.0772],
         [-0.4884],
         [-0.4207]]))
class Net(nn.Module):
    def __init__(self, data_in, l1, l2, l3, data_out):
        super(Net, self).__init__()
        
        self.linear1 = nn.Linear(data_in, l1)
        self.linear2 = nn.Linear(l1, l2)
        self.linear3 = nn.Linear(l2, l3)
        self.linear4 = nn.Linear(l3, data_out)
        
    def forward(self, x):
        y_pred = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(y_pred).clamp(min=0)
        y_pred = self.linear3(y_pred).clamp(min=0)
        y_pred = self.linear4(y_pred)
        return y_pred
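`clamp(min=0)` in the forward pass is exactly ReLU, so the same architecture can be written with `nn.Sequential` and explicit `nn.ReLU` layers (an equivalent sketch, not the code used above):

```python
import torch
import torch.nn as nn

def make_net(data_in, l1, l2, l3, data_out):
    # Linear -> ReLU stacks equivalent to the clamp(min=0) forward pass
    return nn.Sequential(
        nn.Linear(data_in, l1), nn.ReLU(),
        nn.Linear(l1, l2), nn.ReLU(),
        nn.Linear(l2, l3), nn.ReLU(),
        nn.Linear(l3, data_out),
    )

net = make_net(332, 500, 1000, 200, 1)
out = net(torch.randn(4, 332))   # a batch of 4 gives a (4, 1) prediction
```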

# def get_net(feature_num):
#     net = nn.Linear(feature_num, 1)
#     for param in net.parameters():
#         nn.init.normal_(param, mean=0, std=0.01)
#     return net

l1, l2, l3 = 500, 1000, 200
data_in = x.shape[1]
data_out = y.shape[1]
model = Net(data_in,l1,l2,l3,data_out)
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4*2, weight_decay=0)
losses1 = []
y_p = y
for t in range(500):
    y_pred = model(x)
    
    loss = criterion(y_pred, y)
    print(t, loss.item())
    losses1.append(loss.item())
    
    if torch.isnan(loss):
        break
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    y_p=y_pred
0 1459.40771484375
1 1392.9884033203125
2 1327.1253662109375
3 1258.0968017578125
4 1184.120361328125
...
497 0.04296367987990379
498 0.03738382086157799
499 0.03214520961046219
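The loop above does full-batch gradient descent over all 1460 rows at once. For larger datasets, a mini-batch loop with `DataLoader` is the usual pattern; a sketch on random toy data (not the run recorded above):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

x = torch.randn(64, 10)   # toy features
y = torch.randn(64, 1)    # toy targets
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    for xb, yb in loader:           # one optimizer step per mini-batch
        loss = criterion(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```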

Training results

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# The line above works around a duplicate-OpenMP crash of the Jupyter kernel; ignore it otherwise
# x axis: iteration; y axis: loss
plt.figure(figsize=(12, 10))
plt.plot(range(len(losses1)), losses1)
plt.show()


(Figure: training loss vs. iteration, output_30_0.png)

y_p.shape,x.shape
(torch.Size([1460, 1]), torch.Size([1460, 332]))

Model prediction

Given limited time, this is only a simple training run; optimization and other improvements are deferred.

The test data is preprocessed in the same way first.

test_data.isnull().sum()
Id                 0
MSSubClass         0
MSZoning           4
LotFrontage      227
LotArea            0
                ... 
MiscVal            0
MoSold             0
YrSold             0
SaleType           1
SaleCondition      0
Length: 80, dtype: int64
continuous_colmuns.remove('SalePrice')
# Note: strictly, the test set should be standardized with the training-set
# mean/std; using the test set's own statistics here is an approximation
test_data[continuous_colmuns] = test_data[continuous_colmuns].apply(lambda x: (x - x.mean())/(x.std()))
test_data[continuous_colmuns] = test_data[continuous_colmuns].fillna(0)
test_data[continuous_colmuns].isnull().sum()
MSSubClass       0
LotArea          0
OverallQual      0
OverallCond      0
YearBuilt        0
YearRemodAdd     0
BsmtFinSF1       0
BsmtFinSF2       0
BsmtUnfSF        0
TotalBsmtSF      0
1stFlrSF         0
2ndFlrSF         0
LowQualFinSF     0
GrLivArea        0
BsmtFullBath     0
BsmtHalfBath     0
FullBath         0
HalfBath         0
BedroomAbvGr     0
KitchenAbvGr     0
TotRmsAbvGrd     0
Fireplaces       0
GarageCars       0
GarageArea       0
WoodDeckSF       0
OpenPorchSF      0
EnclosedPorch    0
3SsnPorch        0
ScreenPorch      0
PoolArea         0
MiscVal          0
MoSold           0
YrSold           0
LotFrontage      0
MasVnrArea       0
GarageYrBlt      0
dtype: int64
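Strictly speaking, the test split should be standardized with statistics computed on the training split, rather than with its own mean/std as done above; a minimal sketch of that pattern:

```python
import pandas as pd

train_col = pd.Series([10.0, 12.0, 14.0, 16.0])
test_col = pd.Series([11.0, 15.0])
mu, sigma = train_col.mean(), train_col.std()   # statistics from TRAIN only
test_z = (test_col - mu) / sigma                # applied to the test split
```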
# Align the test columns with the training columns so the feature dimension
# matches the network input; reindex both inserts the training-only columns
# (filled with 0) and enforces the training column order
td = pd.get_dummies(test_data.drop(['Id'], axis=1), dummy_na=True)
td = td.reindex(columns=tx.columns, fill_value=0)
# Predict
pred_y = model(torch.tensor(td.values, dtype=torch.float))
pred_y
tensor([[-0.3723],
        [-0.1471],
        [-0.0315],
        ...,
        [-0.1320],
        [-0.4483],
        [-0.0833]], grad_fn=<AddmmBackward0>)
res = pd.DataFrame(pred_y.data.numpy(), columns=['SalePrice'])
res['SalePrice']
0      -0.372343
1      -0.147053
2      -0.031502
3       0.103289
4      -0.258128
          ...   
1454   -0.568579
1455   -0.618316
1456   -0.132024
1457   -0.448310
1458   -0.083297
Name: SalePrice, Length: 1459, dtype: float32
# Convert predictions back to prices.
# Note: SalePrice was standardized as (x - mean) / std, so the exact inverse
# is pred * std + mean; the (max - min) scaling below does not undo the
# z-score, which is one reason the resulting values look implausible.
res['SalePrice'] = res['SalePrice'] * (d_max - d_min) + d_mean
res
# The predictions look off; to be revisited and corrected when time allows
          SalePrice
0     -87203.301559
1      75028.274554
2     158236.460120
3     255299.261917
4      -4957.003893
...             ...
1454 -228512.597976
1455 -264328.452371
1456   85850.581086
1457 -141906.568785
1458  120939.053495

1459 rows × 1 columns
