1 Task Description
给出了美国某个州的过去四天的调查结果,然后预测接下来第五天的概率。
- 其中数据在a.csv中
- 每行代表一个数据样本,包含118个特点
- 最后一行元素是标签
2 解决思路
2.1 Load data/Preprocessing
- Load data: 可以利用Pandas找到csv文件
train_data = pd.read_csv('./covid.train.csv').drop(columns=['data']).values
- preprocessing:得到模型输入和标签
x_train, y_train = train_data[:, :-1], train_data[:, -1]
2.2 Dataset
2.3 Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
- Group data into batches
- 如果设置shuffle=True,倒进的数据将自动重新标记
- 在训练中常设置shuffle=True
2.4 Model
- 输入维数是117
- 模型的输出是标量
2.5 Criterion(标准)
- 做一个回归模型,选择均方误差作为损失函数是一个方案
criterion = torch.nn.MSELoss(reduction='mean')
2.6 Optimizer
选用一个适应调节网络参数的优化方式减小误差
exmple:选用随即梯度下降作为优化算法
optomizer = torch.optiom.SGD(model.parameters(), 1r=1e-5,momentum=0.9)
2.7 循环训练(Training loop)
- 得到一个预测模型,计算梯度,更新参数,重新设置模型参数的梯度