import xarray as xr
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import ndimage
import matplotlib.pyplot as plt
# --- Training configuration -------------------------------------------------

# Placeholder value (the literal reads "your training data path"); replace it
# with the real directory that holds the training data before running.
ROOT_DIR = '你的训练数据路径'

# Learning rate. LEARNING_RATE is the correctly spelled name; the original
# misspelled LEANING_RATE is kept as an alias so existing references to it
# keep working unchanged.
LEARNING_RATE = 1e-4
LEANING_RATE = LEARNING_RATE

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Presumably DataLoader settings (batch size, shuffling, worker processes);
# their use is not visible in this part of the file -- TODO confirm.
BATCH_SIZE = 16
SHUFFLE = False
NUM_WORKERS = 0

# Presumably the number of training epochs -- TODO confirm against the loop.
EPOCHS = 1000
class PPTDataset(Dataset):
def __init__(self,root_dir,train=True):
self.train_list=[2010,2011,2012,2013,2014,2015]
self.test_list=[2016]
self.root_dir=root_dir
if train:
self.arr=np.zeros(len(self.train_list)*365, 301, 620)
for idx,year in enumerate(self.train_list):
data=xr.open_dataset(f'{
root_dir}\\imerg_ppt_10km2_{
year}.nc')['ppt'].values[:365]
data[np.isnan(data)] = 0
self.arr