本篇代码、数据集来源于李宏毅老师的HW1
本文通过李老师的第一作业以及提供的参考代码来进行Pytorch入门。本文为入门文,不会涉及具体的网络设计。
当我们想使用数据训练一个模型的时候,其实主要分为两个步骤:读取数据、训练模型。那么我们就按照这个步骤进行pytorch使用入门。
读取模型
1、使用dataset和dataloader来进行数据读取
这是我在参考代码中看到的使用方法,应该也是比较推荐的使用方法。(以下读取已经经过简化,去掉了一些特殊的数据处理)
from torch.utils.data import Dataset, DataLoader
import numpy as np
class COVID19Dataset(Dataset):
''' Dataset for loading and preprocessing the COVID19 dataset '''
def __init__(self,
path):
#根据路径读取所需数据(使用pandas)
df = pd.read_csv(path)
#需要将数据转化为pytorch所需的格式
data = torch.tensor(df.values, dtype=torch.float)
#第一列为ID,无用数据,去除
data = data[:,1:]
#这里可以取所有列,也可以经过一些筛选,只使用有用的列
feats = list(range(93))
self.target = data[:, -1]
self.data = data[:, feats]
def __getitem__(self, index):
# 必须要实现的魔术方法,用于训练模型时返回数据
return self.data[index], self.target[index]
def __len__(self):
# 返回数据长度,后面有使用到这个方法
return len(self.data)
#再使用dataloader来实现打乱数据,批次读取等效果
batch_size = 100
train_ds = DataLoader(ds, batch_size=batch_size, shuffle=True)
dev_ds = DataLoader(ds, batch_size=batch_size, shuffle=True)
2、更直接地方法
参考文章
也可以不使用,直接自己实现打乱,按批次读取的效果
import pandas as pd
import torch
from torch import nn
path