首先需要说明这是一次失败的尝试。最初在云平台上运行代码,结果在解压数据时由于空间不足,导致无法解压。之后将数据集都下载到本地,又因为环境、版本冲突导致库的安装一直存在问题,尝试许久无法解决。因此仅在云平台,采用部分数据,进行了模型的训练。
第一步:安装所需库和下载数据集
需要注意链接的变化
第二步:导入所需库
import os
import pandas as pd
import xarray as xr
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
第三步:数据集路径配置
feature_path = 'E:\\comp3\\earth_baseline\\feature' # 修改为自己的路径
gt_path = 'E:\\comp3\\earth_baseline\\groundtruth' # 修改为自己的路径
years = ['2021']
fcst_steps = list(range(1, 73, 1))
第四步:定义数据集类
class Feature:
def __init__(self):
self.path = feature_path
self.years = years
self.fcst_steps = fcst_steps
self.features_paths_dict = self.get_features_paths()
def get_features_paths(self):
init_time_path_dict = {}
for year in self.years:
init_time_dir_year = os.listdir(os.path.join(self.path, year))
for init_time in sorted(init_time_dir_year):
init_time_path_dict[pd.to_datetime(init_time)] = os.path.join(self.path, year, init_time)
return init_time_path_dict
def get_fts(self, init_time):
return xr.open_mfdataset(self.features_paths_dict.get(init_time) + '/*').sel(lead_time=self.fcst_steps).isel(time=0)
class GT:
def __init__(self):
self.path = gt_path
self.years = years
self.fcst_steps = fcst_steps
self.gt_paths = [os.path.join(self.path, f'{year}.nc') for year in self.years]
self.gts = xr.open_mfdataset(self.gt_paths)
def parser_gt_timestamps(self, init_time):
return [init_time + pd.Timedelta(f'{fcst_step}h') for fcst_step in self.fcst_steps]
def get_gts(self, init_time):
timestamps = self.parser_gt_timestamps(init_time)
try:
return self.gts.sel(time=timestamps)
except KeyError as e:
print(f"KeyError: {e}")
print(f"Adjusting timestamps for init_time: {init_time}")
adjusted_timestamps = self.adjust_time(timestamps)
return self.gts.sel(time=adjusted_timestamps)
def adjust_time(self, timestamps):
available_times = self.gts.time.values
adjusted_timestamps = []
for ts in timestamps:
if ts in available_times:
adjusted_timestamps.append(ts)
else:
adjusted_timestamps.append(available_times[-1])
return adjusted_timestamps
第五步:定义自定义数据集类
class mydataset(Dataset):
def __init__(self):
self.ft = Feature()
self.gt = GT()
self.features_paths_dict = self.ft.features_paths_dict
self.init_times = list(self.features_paths_dict.keys())
def __getitem__(self, index):
init_time = self.init_times[index]
try:
ft_item = self.ft.get_fts(init_time).to_array().isel(variable=0).values
gt_item = self.gt.get_gts(init_time).to_array().isel(variable=0).values
except KeyError as e:
print(e)
print(f'init_time: {init_time} not found')
return self.__getitem__((index + 1) % len(self.init_times))
return ft_item, gt_item
def __len__(self):
return len(list(self.init_times))
第六步:查看数据集
my_data = mydataset()
print('sample num:', mydataset().__len__())
train_loader = DataLoader(my_data, batch_size=1, shuffle=True)
第七步:定义模型
class Model(nn.Module):
def __init__(self, num_in_ch, num_out_ch):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)
def forward(self, x):
B, S, C, W, H = tuple(x.shape)
x = x.reshape(B, -1, W, H)
out = self.conv1(x)
out = out.reshape(B, S, W, H)
return out
in_varibales = 24
in_times = len(fcst_steps)
out_varibales = 1
out_times = len(fcst_steps)
input_size = in_times * in_varibales
output_size = out_times * out_varibales
model = Model(input_size, output_size).cuda()
第八步:定义损失函数
loss_func = nn.MSELoss()
第九步:训练模型
num_epochs = 1
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
for index, (ft_item, gt_item) in enumerate(train_loader):
ft_item = ft_item.cuda().float()
gt_item = gt_item.cuda().float()
# Forward pass
output_item = model(ft_item)
loss = loss_func(output_item, gt_item)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (index+1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{index+1}/{len(train_loader)}], Loss: {loss.item():.4f}")
torch.save(model.state_dict(), 'model_weights.pth')
第十步:模型推理
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
test_data_path = "test/weather.round1.test"
os.makedirs("./output", exist_ok=True)
for index, test_data_file in enumerate(os.listdir(test_data_path)):
test_data = torch.load(os.path.join(test_data_path, test_data_file))
test_data = test_data.cuda().float()
# Forward pass
output_item = model(test_data)
print(f"Output shape for sample {test_data_file.split('.')[0]}: {output_item.shape}")
output_path = f"output/{test_data_file}"
torch.save(output_item.cpu(), output_path)
!zip -r output.zip output
改进方向
-
数据预处理优化:当前代码在数据加载过程中可能会遇到
KeyError
,应增加更好的异常处理机制,避免重复获取失败的数据。 -
模型复杂度:当前模型只有一个卷积层,可能无法充分捕捉复杂的气象数据特征。建议尝试增加卷积层或引入其他更复杂的模型结构,如 U-Net 等。
-
训练过程监控:增加训练过程的详细日志记录和可视化工具,如 TensorBoard,以便更好地监控和分析训练过程中的模型性能。
-
超参数调优:当前训练仅进行了一个 epoch,且学习率固定为 0.001。建议尝试更多的超参数组合,并进行交叉验证,以找到最佳的模型参数配置。
-
数据增强:引入数据增强技术,增加训练数据的多样性,以提高模型的泛化能力。
-
模型评估:增加更多的评估指标和测试集,以全面评估模型的性能,并在实际应用中进行验证。