最近在学b站刘二大人的pytorch教程。将第八讲的课后作业写了下,在此记录。
我也是初学者,写代码常常在写bug,代码也不够简洁,参考了其他同学的课后作业。
如果有幸被您看到,希望可以轻喷,我们一起讨论~
0.数据集准备
科学上网,如果可以的话,上面有一些如何简单处理数据集的代码。
https://www.kaggle.com/competitions/titanic/overview
1.思路提纲
1.数据处理
①性别(sex)的字符串转为整数
male用1表示,female用0表示
df = pd.read_csv(filepath, header=0)
df.replace('male', 1, inplace=True)
df.replace('female', 0, inplace=True)
方法来自这里
②年龄(age)数据有缺失值
使用该列均值填充
df = df.fillna(df.mean())
2.数据格式的转换
pandas的dataframe→numpy的array(float32)→pytorch提供的tensor
【数据集你的转换搞的我代码好复杂好乱QAQ】
xy = df.iloc[:, [1,2,4,5,6,7,9]] #取所需的整数数据列
xy = (np.array(xy)).astype(np.float32)
2.代码呈现
#0.引用库
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import matplotlib.pyplot as plt
#1.准备数据集
class Train_Dataset(Dataset):
def __init__(self, filepath):
df = pd.read_csv(filepath, header=0)#读取数据
df.replace