一、baseline代码的大致结构
-
导入库:首先,代码导入了需要用到的库,包括
torch
和处理气象数据必要的xarray
,用于处理结构数据的pandas和常用操作系统库os -
数据集构建:代码通过使用
class Feature
类和class GT
类定义了从气象.nc文件中读取数据,同时通过class Dataset
类将训练数据和对应标签建立起对应关系, 最后使用torch.utils.data
中的DataLoader
定义数据加载工具, 方便我们在训练过程中获取数据。 -
定义模型和使用的损失函数:定义了只含有一层卷积的简单网络, 使用
MSE
作为损失函数, 特别注意模型的输入输出要根据赛题要求设计。 -
模型训练:完成优化器和训练周期的定义后, 我们就可以开始训练模型以便在数据上得到一个拟合程度最好的训练模型,模型训练时不要忘了保存模型参数文件
.pth
-
加载训练好的模型进行预测输出:加载第4步中的训练参数以后, 需要用model.eval()将模型置于推理模式, 然后我们就有了一个拟合程度相对不错的降水预测模型, 再把测试数据输入, 就可以生成预测结果~
伪代码:
# 1. 导入需要用到的相关库
import os
import torch
import pandas as pd
import xarray as xr
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# 2. 定义数据集
path = "" # 配置路径的设置..
class Feature:
pass
class GT:
pass
class mydataset(Dataset):
def __init__(self):
...
def __getitem__(self, index):
# 获取训练数据的方法, 同时将训练数据和真值建立联系
...
def __len__(self):
#获取数据集长度的方法
...
my_data = mydataset() # 初始化dataset
train_loader = DataLoader(my_data, batch_size=1, shuffle=True) # 定义dataloader
# 3. 定义模型和损失函数
class Model(nn.Module):
# 模型初始化
def __init__(self, *args):
...
# 定义前向传播函数
def forward(self, *args):
...
# 模型具体输入输出定义
input_chnl = ...
# loss定义
loss = nn.MSELoss()
# 4. 模型训练
num_epochs = 1 # 定义模型训练轮数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 定义优化器
# 从dataloader中取数据训练
for epoch in range(num_epochs):
for index, (ft_item, gt_item) in enumerate(train_loader):
...
torch.save(model.state_dict(), "xxx.pth") # 保存模型参数
# 5. 模型推理
model.load_state_dict(torch.load('model_weights.pth')) # 加载模型
model.eval() # 将模型置于推理状态
test_data_path = "xxx"
# 模型推理
for index, test_data_file in enumerate(os.listdir(test_data_path)):
...
二、pytorch概述
PyTorch是由Meta AI(Facebook)人工智能研究小组开发的一种基于Lua编写的Torch库的Python实现的深度学习库,目前被广泛应用于学术界和工业界。通过pytroch, 我们可以自由的搭建神经网络模型, 使之适配我们所需要的任务。
重要的工具类:
-
构建数据集的类Dataset
-
数据加载类的Dataloader
-
定义模型的类Model
pytorch学习网站:深入浅出PyTorch — 深入浅出PyTorch (datawhalechina.github.io)
深度学习:GitHub - datawhalechina/unusual-deep-learning: 水很深的深度学习