Predicting Age at Death with Linear Regression
Using the World Health Organization life expectancy dataset:
https://www.kaggle.com/kumarajarshi/life-expectancy-who/data#
import torch
import torchvision
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_split
import pickle
DATASET_URL = "https://raw.githubusercontent.com/Federico-abss/pytorch_gans/master/datasets/life_expectancy_data.csv"
DATA_FILENAME = "life_expectancy_data.csv"
input_size=5
output_size=1
# Download the data
download_url(DATASET_URL, '.')
dataframe = pd.read_csv(DATA_FILENAME)
dataframe.describe()
Data cleaning
Some column names have trailing spaces or consist of two separate words; I will fix that.
# Rename columns whose names contain leading/trailing spaces or embedded spaces.
dataframe.rename(columns={" BMI ":"BMI","Life expectancy ":"Life_Expectancy","Adult Mortality":"Adult_Mortality",
                          "infant deaths":"Infant_Deaths","percentage expenditure":"Percentage_Exp","Hepatitis B":"HepatitisB",
                          "Measles ":"Measles","under-five deaths ":"Under_Five_Deaths","Diphtheria ":"Diphtheria",
                          " HIV/AIDS":"HIV/AIDS"," thinness 1-19 years":"thinness_1to19_years"," thinness 5-9 years":"thinness_5to9_years",
                          "Income composition of resources":"Income_Comp_Of_Resources",
                          "Total expenditure":"Tot_Exp"},inplace=True)
# Save the countries in a list
country_list = dataframe.Country.unique()
# List of the columns that have missing values
fill_list = ['Life_Expectancy','Adult_Mortality','Alcohol','HepatitisB','BMI','Polio','Tot_Exp','Diphtheria','GDP','Population','thinness_1to19_years','thinness_5to9_years','Income_Comp_Of_Resources','Schooling']
# Fill null values by interpolating within each country.
for country in country_list:
    mask = dataframe['Country'] == country
    dataframe.loc[mask, fill_list] = dataframe.loc[mask, fill_list].interpolate()
# Drop the null values that remain after interpolation.
dataframe.dropna(inplace=True)
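To see why the interpolation is done country by country, here is a minimal sketch on a toy table (the countries "A" and "B" and all values are made up for illustration, and the rows are assumed to be sorted by year within each country). It uses `groupby(...).transform(...)` as a compact alternative to the explicit loop; it is not the code used in this post.

```python
import numpy as np
import pandas as pd

# Toy data: made-up values, two hypothetical countries.
toy = pd.DataFrame({
    "Country": ["A", "A", "A", "B", "B", "B"],
    "Year": [2000, 2001, 2002, 2000, 2001, 2002],
    "GDP": [1.0, np.nan, 3.0, np.nan, 10.0, 12.0],
})

filled = toy.copy()
# Linear interpolation inside each country, so values never leak
# from the end of one country's series into the start of the next.
filled["GDP"] = toy.groupby("Country")["GDP"].transform(lambda s: s.interpolate())

# A gap at the start of a series has no earlier value to interpolate from,
# so it stays NaN and is removed by dropna().
cleaned = filled.dropna()
print(filled)
print(len(cleaned))
```

Country A's gap is filled with the midpoint of its neighbours, while country B's first row stays empty and is dropped. Without the grouping, that row would have been filled from country A's last value, which is exactly what the per-country loop avoids.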
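Data analysis
We will look for meaningful correlations to decide which columns to use for training our model.
Let's start with some general statistics about our data.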
plt.figure(figsize=(4,5))
plt.bar(dataframe.groupby('Status')['Status'].count().index,dataframe.groupby('Status')['Life_Expectancy'].mean(),color='blue',alpha=0.50)
plt.xlabel("Country Status",fontsize=12)
plt.ylabel("Avg Life_Expectancy",fontsize=12)
plt.title("Life_Expectancy w.r.t Country Development")
plt.show()
# Life expectancy w.r.t. year, using a bar chart.
plt.figure(figsize=(7,5))
plt.bar(dataframe.groupby('Year')['Year'].count().index,dataframe.groupby('Year')['Life_Expectancy'].mean(),color='pink',alpha=0.65)
plt.xlabel("Year",fontsize=12)
plt.ylabel("Avg Life_Expectancy",fontsize=12)
plt.title("Life_Expectancy w.r.t Year")
plt.show()
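The bars in the two charts are simply group means, so the same numbers can be read off directly without plotting. The sketch below shows this on a toy table with made-up values; only the column names `Status`, `Year` and `Life_Expectancy` are taken from the dataset.

```python
import pandas as pd

# Toy data with made-up values.
toy = pd.DataFrame({
    "Status": ["Developed", "Developed", "Developing", "Developing"],
    "Year": [2000, 2001, 2000, 2001],
    "Life_Expectancy": [78.0, 79.0, 60.0, 62.0],
})

# Mean life expectancy per development status and per year.
by_status = toy.groupby("Status")["Life_Expectancy"].mean()
by_year = toy.groupby("Year")["Life_Expectancy"].mean()

# True only if the yearly mean rises from each year to the next.
improving = by_year.diff().dropna().gt(0).all()

print(by_status)
print(by_year)
print(improving)
```

Running the same three lines on the real `dataframe` gives a quick numeric check of what the charts suggest visually.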
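So there is a clear gap between developed countries and the rest, but at least average life expectancy improves year over year.
Now let's try to find out which features influence life expectancy the most.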
plt.figure(figsize=(7,5))
plt.title("LifeExpectancy w.r.t Income_Comp_Of_Resources")
plt.xlabel("Income",fontsize=12)
plt.ylabel("Life_Expectancy",fontsize=12)
plt.scatter(dataframe["Income_Comp_Of_Resources"], dataframe["Life_Expectancy"])
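Scatter plots like this one can be complemented by a numeric ranking. As a rough sketch, the helper below (`rank_features` is a hypothetical name, not part of the original notebook) sorts the numeric columns by the absolute value of their Pearson correlation with the target. The toy values are made up, and correlation only captures linear relationships, so this is a first filter rather than a definitive feature selection.

```python
import pandas as pd

def rank_features(df: pd.DataFrame, target: str) -> pd.Series:
    """Correlation of each numeric column with `target`, strongest first."""
    numeric = df.select_dtypes(include="number")   # ignore text columns such as Status
    corr = numeric.corr()[target].drop(target)     # Pearson correlation with the target
    order = corr.abs().sort_values(ascending=False).index
    return corr.reindex(order)                     # keep the sign, sort by magnitude

# Toy data with made-up values.
toy = pd.DataFrame({
    "Life_Expectancy": [50.0, 60.0, 70.0, 80.0],
    "Income_Comp_Of_Resources": [0.3, 0.5, 0.7, 0.9],
    "Adult_Mortality": [400, 300, 250, 100],
    "Status": ["Developing", "Developing", "Developed", "Developed"],
})

ranking = rank_features(toy, "Life_Expectancy")
print(ranking)
```

Applied to the cleaned `dataframe`, the columns at the top of the ranking would be natural candidates for the model inputs; a strongly negative value (as for adult mortality in the toy data) is just as informative as a strongly positive one.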
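Code download
There are two folders:
1. The dataset folder contains the data and a README that mentions the data source.
2. The model folder contains a Jupyter notebook that explains everything in detail. It can be run in Anaconda. The folder also contains the pkl file of the model implemented in the Jupyter notebook.
Link: https://pan.baidu.com/s/1h5KTg3C337GJEzFdAG3kIA
Extraction code: tkzk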