Datawhale AI夏令营
Task 2 抽丝剥茧——降水预测baseline详解
分析数据集特征
- feature_path和gt_path是官方提供的train.xxx数据和gt.xxx数据存放的路径,挑选想尝试的数据集,并更改列表中相应的字符串,Feature类和GroundTruth类是数据集的定义 方便后续自定义数据集和数据加载类.
class Feature:
def __init__(self):
self.path = feature_path
self.years = years
self.fcst_steps = fcst_steps
self.features_paths_dict = self.get_features_paths()
def get_features_paths(self):
init_time_path_dict = {}
for year in self.years:
init_time_dir_year = os.listdir(os.path.join(self.path, year))
for init_time in sorted(init_time_dir_year):
init_time_path_dict[pd.to_datetime(init_time)] = os.path.join(self.path, year, init_time)
return init_time_path_dict
def get_fts(self, init_time):
return xr.open_mfdataset(self.features_paths_dict.get(init_time) + '/*').sel(lead_time=self.fcst_steps).isel(
time=0)
class GT:
def __init__(self):
self.path = gt_path
self.years = years
self.fcst_steps = fcst_steps
self.gt_paths = [os.path.join(self.path, f'{year}.nc') for year in self.years]
self.gts = xr.open_mfdataset(self.gt_paths)
def parser_gt_timestamps(self, init_time):
return [init_time + pd.Timedelta(f'{fcst_step}h') for fcst_step in self.fcst_steps]
def get_gts(self, init_time):
return self.gts.sel(time=self.parser_gt_timestamps(init_time))
feature_path = '/mnt/workspace/AICamp_earth_baseline/feature/'
gt_path = '/mnt/workspace/AICamp_earth_baseline/truth'
years = ['2020']
fcst_steps = list(range(1, 73, 1))
baseline结构解读
- 定义数据集, 建立起训练数据和标签之间的关系;定义数据加载器(DataLoader), 方便取数据进行训练
- 定义模型, 利用PyTorch搭建网络,根据输入输出数据维度实例化模型
- 定义损失函数, 优化器, 训练周期, 训练模型并保存模型参数
- 模型加载及推理(模型预测),输入测试数据输出要提交的文件
问题总结
- 数据导入不足 导致梯度消失;因为平台无法下载数据集,本地上传很慢,所以先上传了小部分数据验证脚本可行性,随后出现梯度消失的问题。
- 第二个问题是继续上传特征数据,编写上传脚本f1运行,发现再次训练模型时出错,是脚本运行产生了隐藏文件,藏在特征文件夹中导致无法识别。