Tianchi AI Earth Challenge: A PyTorch Deep-Learning Model Improvement Log (Code Edition)

0. Preface

For the written explanation, see the linked companion article.

1. As of February 23, 14:40 (score = -19.6)

1.1. Basic Approach

1.2. Data Preprocessing

Code Snippet 1.2.1

import pickle
from netCDF4 import Dataset
import numpy as np
import gc
import torch

with open('pkl/population_stats.pkl', 'rb') as pf0:
    pop_mean, pop_std = pickle.load(pf0)
file_path = 'C:\\Users\\13372\\Downloads\\Compressed\\enso_round1_train_20210201\\'
SODA_sample_path = file_path + 'SODA_train.nc'
CMIP_sample_path = file_path + 'CMIP_train.nc'
SODA_label_path = file_path + 'SODA_label.nc'
CMIP_label_path = file_path + 'CMIP_label.nc'
random_seed = 623


def train_val_split(sample, label, val_size, seed=random_seed):
    """
    :param sample: ndarray, shape (year, 12, 24, 72, 4)
    :param label: ndarray, shape (year, 24)
    :param val_size: fraction of the validation set, between 0 and 1
    :param seed: random seed
    """
    assert sample.ndim == 5 and label.ndim == 2 and 0 < val_size <= 1
    np.random.seed(seed)
    sample_num = sample.shape[0]
    shuffle_idx = np.arange(sample_num)
    np.random.shuffle(shuffle_idx)
    val_size = int(val_size * sample_num)
    x_train, x_val = sample[shuffle_idx[:-val_size]], \
                     sample[shuffle_idx[-val_size:]]
    t_train, t_val = label[shuffle_idx[:-val_size]], \
                     label[shuffle_idx[-val_size:]]
    ts = lambda x: torch.from_numpy(x).type(torch.float32)
    return {'x_train': ts(x_train), 'x_val': ts(x_val),
            't_train': ts(t_train), 't_val': ts(t_val)}


def sample_extract(sample_name, batch_id=None):
    extract = lambda x: np.expand_dims(np.array(x[:, :12, :, :]), axis=4)
    assert sample_name in ['SODA', 'CMIP', 'CMIP_batch']
    if sample_name == 'SODA':
        sample_path = SODA_sample_path
    else:
        sample_path = CMIP_sample_path
        if sample_name == 'CMIP_batch':
            extract = lambda x: np.expand_dims(np.array(x[batch_id, :12, :, :]), axis=4)
    sample_nc = Dataset(sample_path, 'r')
    sst = extract(sample_nc.variables['sst'])
    t300 = extract(sample_nc.variables['t300'])
    ua = extract(sample_nc.variables['ua'])
    va = extract(sample_nc.variables['va'])
    sample = np.concatenate((sst, t300, ua, va), axis=-1)
    return sample


def population_stats():
    population = np.concatenate((sample_extract('SODA'), sample_extract('CMIP')), axis=0)
    mean = np.nanmean(population.reshape(-1, 4), axis=0).reshape((1, 1, 1, 1, 4))
    std = np.nanstd(population.reshape(-1, 4), axis=0).reshape((1, 1, 1, 1, 4))
    with open('pkl/population_stats.pkl', 'wb') as pf1:
        pickle.dump((mean, std), pf1)


def nan_fill(data):
    # Fill missing values with the per-feature population mean,
    # so after standardization they become exactly 0.
    nan = np.isnan(data)
    if nan.any():
        mean = pop_mean.ravel()
        for ft in range(4):
            data[nan[:, :, :, :, ft], ft] = mean[ft]
    return data


def SODA_gen():
    SODA_sample = sample_extract('SODA')
    SODA_sample = nan_fill(SODA_sample)
    SODA_sample = (SODA_sample - pop_mean) / pop_std

    SODA_label_nc = Dataset(SODA_label_path, 'r')
    SODA_label = (np.array(SODA_label_nc.variables['nino'])[:, 12:])
    assert not np.isnan(SODA_label).any()  # labels contain no missing values
    SODA_data = train_val_split(SODA_sample, SODA_label, val_size=0.3)
    del SODA_label, SODA_label_nc, SODA_sample
    gc.collect()

    with open('pkl/SODA_scaled.pkl', 'wb') as pf2:
        pickle.dump(SODA_data, pf2)

    del SODA_data
    gc.collect()


def CMIP_gen(i):
    # Fix the seed so every call shuffles identically; otherwise the batches
    # drawn by separate runs would overlap and miss samples.
    np.random.seed(random_seed)
    total_idx = np.arange(4645)
    np.random.shuffle(total_idx)
    assert 1 <= i <= 13
    if i == 12:
        batch_id = total_idx[4400:4645]  # the 245 leftover samples
    elif i == 13:
        batch_id = total_idx  # all 4645 CMIP samples
    else:
        batch_id = total_idx[400 * (i - 1):400 * i]
    CMIP_sample = sample_extract('CMIP_batch', batch_id=batch_id)
    CMIP_sample = nan_fill(CMIP_sample)
    CMIP_sample = (CMIP_sample - pop_mean) / pop_std

    CMIP_label_nc = Dataset(CMIP_label_path, 'r')
    CMIP_label = (np.array(CMIP_label_nc.variables['nino'])[batch_id, 12:])

    CMIP_data = train_val_split(CMIP_sample, CMIP_label, val_size=0.3)
    del CMIP_label, CMIP_label_nc, CMIP_sample
    gc.collect()

    if i == 13:
        with open('pkl/SODA_scaled.pkl', 'rb') as pf1:
            SODA_data = pickle.load(pf1)
        for key, val in SODA_data.items():
            CMIP_data[key] = torch.cat((val, CMIP_data[key]), dim=0)

    with open('pkl/CMIP_scaled_' + str(i) + '.pkl', 'wb') as pf2:
        pickle.dump(CMIP_data, pf2)
    del CMIP_data
    gc.collect()


# population_stats()
CMIP_gen(13)
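
For reference, a minimal driver sketch (my addition, not part of the original run) that would materialize all 13 pickles in one pass; it assumes population_stats() has been run once beforehand and that SODA_gen() has written pkl/SODA_scaled.pkl, which batch 13 concatenates:

if __name__ == '__main__':
    SODA_gen()
    for batch in range(1, 14):  # batches 1-12, then the full batch 13
        CMIP_gen(batch)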

代码段1.2.2

import torch
import pickle
import gc
import numpy as np
from matplotlib import pyplot as plt

with open('pkl/CMIP_scaled_13.pkl', 'rb') as pf1:
    data = pickle.load(pf1)
labels = torch.cat((data['t_train'], data['t_val']), dim=0).numpy()
# Rows 0-69 are the SODA training labels; the validation block (the last
# 1423 rows) starts with the 30 SODA validation labels.
SODA_labels = np.vstack((labels[:70], labels[-1423:-1423 + 30]))
del data
gc.collect()
labels = labels.ravel()
SODA_labels = SODA_labels.ravel()
print('max:', labels.max(), '  min:', labels.min(),
      '  mean:', labels.mean(), '  std:', labels.std())
print('SODA_max:', SODA_labels.max(), '  SODA_min:', SODA_labels.min(),
      '  mean:', SODA_labels.mean(), '  std:', SODA_labels.std())

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.hist(SODA_labels, color='orange', label='SODA', alpha=0.5)
ax1.set_ylabel('SODA')
ax2.hist(labels, color='blue', label='Total', alpha=0.5)
ax2.set_ylabel('Total')
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
plt.legend(handles1 + handles2, labels1 + labels2, loc='best')
plt.show()
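
As a quick sanity check on the split sizes implied above (70 + 3252 training rows, 30 + 1393 validation rows), a short sketch:

with open('pkl/CMIP_scaled_13.pkl', 'rb') as pf:
    check = pickle.load(pf)
assert check['t_train'].shape[0] == 70 + 3252  # SODA train + CMIP train
assert check['t_val'].shape[0] == 30 + 1393    # SODA val + CMIP val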

1.3. Network Architecture

Code Snippet 1.3.1

import torch
import torch.nn as nn


class simpleSpatailTimeNN(nn.Module):
    def __init__(self, n_cnn_layer: int = 1, n_gru_units: int = 64, dropout=0.1):
        super(simpleSpatailTimeNN, self).__init__()
        self.n_gru_units = n_gru_units
        self.conv_sst1 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_t3001 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_ua1 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_va1 = nn.Conv2d(12, 12, 3, padding=1)

        self.conv_sst2 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_t3002 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_ua2 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_va2 = nn.Conv2d(12, 12, 3, padding=1)
        self.pool1 = nn.Sequential(nn.MaxPool2d(6, 6),
                                   # nn.BatchNorm2d(12),
                                   nn.ReLU())
        self.pool2 = nn.Sequential(nn.MaxPool2d(4, 4),
                                   # nn.BatchNorm2d(12),
                                   nn.ReLU())
        self.batch_norm = nn.BatchNorm1d(12, affine=True)
        self.dropout = nn.Dropout(p=dropout)
        # batch_first=True: the 12 months of (b, 12, 12) are the sequence dimension
        self.gru = nn.GRU(12, n_gru_units, num_layers=2, batch_first=True)
        self.linear = nn.Linear(12 * n_gru_units, 24)

        for m in self.modules():
            if isinstance(m, nn.GRU):
                for layer in range(2):
                    for idx in range(2):
                        nn.init.orthogonal_(m.all_weights[layer][idx])
            elif isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight.data)

    def forward(self, x):
        sst = self.conv_sst1(x[:, :, :, :, 0])  # (b,12,24,72)
        t300 = self.conv_t3001(x[:, :, :, :, 1])
        ua = self.conv_ua1(x[:, :, :, :, 2])
        va = self.conv_va1(x[:, :, :, :, 3])  # (b,12,24,72)

        sst = self.pool1(sst)  # (b,12,4,12)
        t300 = self.pool1(t300)
        ua = self.pool1(ua)
        va = self.pool1(va)  # (b,12,4,12)

        sst = self.conv_sst2(sst)  # (b,12,4,12)
        t300 = self.conv_t3002(t300)
        ua = self.conv_ua2(ua)
        va = self.conv_va2(va)  # (b,12,4,12)

        sst = self.pool2(sst)  # (b,12,1,3)
        t300 = self.pool2(t300)
        ua = self.pool2(ua)
        va = self.pool2(va)  # (b,12,1,3)

        tmp = (-1, 12, 3)  # flatten the latitude/longitude features
        x = torch.cat([sst.reshape(tmp), t300.reshape(tmp),
                       ua.reshape(tmp), va.reshape(tmp)], dim=-1)  # (b,12,12)
        x = self.dropout(x)
        x, _ = self.gru(x)  # (b,12,n_gru_units)
        x = x.reshape(-1, 12 * self.n_gru_units)  # (b,12*n_gru_units)
        x = self.linear(x)  # (b,24)
        return x
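
A quick smoke test of the forward pass, as a sketch with random inputs (my addition; the batch size 8 and the 46 units are arbitrary):

net = simpleSpatailTimeNN(n_gru_units=46, dropout=0.4)
dummy = torch.randn(8, 12, 24, 72, 4)  # (batch, month, lat, lon, feature)
assert net(dummy).shape == (8, 24)  # one 24-month prediction per sample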

1.4. Training the Network

Code Snippet 1.4.1

import torch
import torch.nn as nn


class MyMSELoss(nn.Module):
    def __init__(self, device):
        super(MyMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        # Per-month weights a_i and ln(i) mirror the competition's accskill
        # weighting; note ln(1) = 0, so month 1 contributes nothing.
        self.a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6
        self.a = torch.Tensor(self.a).to(device)
        self.i = torch.arange(1, 25).type(torch.float32).to(device)

    def forward(self, y, t):
        loss = self.criterion(y, t)
        return (loss * torch.log(self.i) * self.a).mean()
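
A minimal usage sketch (random CPU tensors, my addition) showing the per-month weighting:

crit = MyMSELoss(device='cpu')
y, t = torch.randn(8, 24), torch.randn(8, 24)
print(crit(y, t))  # scalar; month 1 is zero-weighted since ln(1) = 0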

Code Snippet 1.4.2

import numpy as np


def time_series_score(y, t):
    """
    Evaluation metric, following the competition page:
    https://tianchi.aliyun.com/competition/entrance/531871/information
    """
    accskill_score = 0
    y = y.cpu().detach().numpy()
    t = t.cpu().detach().numpy()
    y_mean = y.mean(axis=0)
    t_mean = t.mean(axis=0)
    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6
    for idx in range(24):
        d_t = t[:, idx] - t_mean[idx]
        d_y = y[:, idx] - y_mean[idx]
        fenzi = np.sum(d_t * d_y)  # numerator
        fenmu = np.sqrt(np.sum(d_t ** 2) * np.sum(d_y ** 2))  # denominator
        cor_i = fenzi / fenmu  # Pearson correlation at lead month idx + 1
        accskill_score += a[idx] * np.log(idx + 1) * cor_i
    rmse_score = np.mean((y - t) ** 2, axis=1).sum()
    return 2 / 3.0 * accskill_score - rmse_score
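
As implemented above, the score being maximized is (cor_i is the Pearson correlation across the N validation samples at lead month i, and a_i are the step weights 1.5/2/3/4; note that the second term, named rmse_score in the code, is a sum of per-sample mean squared errors rather than a true RMSE):

\text{score} \;=\; \frac{2}{3}\sum_{i=1}^{24} a_i \,\ln(i)\,\mathrm{cor}_i \;-\; \sum_{n=1}^{N}\frac{1}{24}\sum_{i=1}^{24}\bigl(y_{n,i}-t_{n,i}\bigr)^2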

Code Snippet 1.4.3

import torch
from torch import optim

# hyper-parameter settings; lr and decay are base-10 exponents (used as 10 ** x)
batch_size = 128
lr = -2.245
decay, flooding, dropout = -1.9, 0., 0.4
n_gru_units = 46
epoch_num = 400
round_num = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = simpleSpatailTimeNN(n_gru_units=n_gru_units,
                            dropout=dropout).to(device)
criterion = MyMSELoss(device=device)

# Apply weight decay to weights only; biases are excluded.
weight_p, bias_p = [], []
for name, p in model.named_parameters():
    if 'bias' in name:
        bias_p += [p]
    else:
        weight_p += [p]
optimizer = optim.Adam([
    {'params': weight_p, 'weight_decay': 10 ** decay},
    {'params': bias_p, 'weight_decay': 0}
], lr=10 ** lr, betas=(0.7307, 0.8732), eps=1e-9)
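
The flooding hyper-parameter above is not applied anywhere in the snippets shown; below is a minimal sketch of one training loop that would use it, assuming x_train and t_train were loaded from the pickles of Section 1.2 (the flooding trick keeps the training loss from sinking below a floor b via loss = |loss - b| + b):

# Sketch only: x_train / t_train are assumed loaded from pkl/CMIP_scaled_*.pkl.
for epoch in range(epoch_num):
    model.train()
    for beg in range(0, x_train.shape[0], batch_size):
        x_b = x_train[beg:beg + batch_size].to(device)
        t_b = t_train[beg:beg + batch_size].to(device)
        loss = criterion(model(x_b), t_b)
        loss = (loss - flooding).abs() + flooding  # flooding regularizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()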

2. As of February 23, 23:30 (score = 6.2)

2.1. Model Architecture

Code Snippet 2.1.1

import math

import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class MyConvMaxout(nn.Module):
    __constants__ = ['k', 'features']
    k: int
    features: int
    weight: torch.Tensor

    def __init__(self, k: int, features: int, bias: bool = True) -> None:
        super(MyConvMaxout, self).__init__()
        self.features = features
        self.k = k
        self.weight = Parameter(torch.Tensor(k, features, features))
        if bias:
            self.bias = Parameter(torch.Tensor(features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_normal_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Maxout over k affine maps of the channel dimension.
        :param inp: (batch, channel, height, width)
        :return: (batch, channel, height, width)
        """
        x = inp.permute(0, 2, 3, 1)  # (b,h,w,c)
        b, h, w, c = x.shape
        x = x.reshape(-1, c)  # (b*h*w,c)
        x = x.matmul(self.weight).permute(1, 0, 2)  # (b*h*w,k,c)
        if self.bias is not None:  # guard: bias=False would otherwise crash here
            x = x + self.bias
        x = x.reshape(b, h, w, self.k, c)  # (b,h,w,k,c)
        x = x.permute(0, 3, 4, 1, 2)  # (b,k,c,h,w)
        out = torch.max(x, dim=1).values  # elementwise max over the k pieces
        return out

    def extra_repr(self) -> str:
        return 'k={}, features={}, bias={}'.format(
            self.k, self.features, self.bias is not None
        )
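
A shape check as a sketch (my addition): the layer preserves (b, c, h, w) and only mixes channels through the k affine pieces:

maxout = MyConvMaxout(k=4, features=12)
assert maxout(torch.randn(8, 12, 24, 72)).shape == (8, 12, 24, 72)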

3. As of February 24, 19:10 (best score = 6.2)

3.1. Data Preprocessing

Code Snippet 3.1.1

import gc
import pickle

import numpy as np


def batch_gen(seed, last_seed):
    """
    Training pool: 70 SODA plus 4645 * 0.7 = 3252 CMIP samples; fixed
    validation set: 30 SODA plus 4645 * 0.3 = 1393 CMIP, 1423 in total.
    :param seed: random seed;
    :param last_seed: the previous seed, used to recover last_CMIP_ind;
    :return: all SODA samples, 130 random CMIP samples seen last round, and
    400 other random CMIP samples (possibly seen in earlier rounds);
    600 training samples, 1423 validation samples.
    """
    with open('pkl/batch_seed_' + str(last_seed) + '.pkl', 'rb') as pf0:
        _, last_CMIP_ind = pickle.load(pf0)
    del _
    gc.collect()

    np.random.seed(seed)
    # x_train rows 0-69 are SODA; the 3252 CMIP rows occupy 70-3321.
    CMIP_ind = np.arange(70, 70 + 3252)
    SODA_ind = np.arange(70)
    # Sample without replacement so the 600 training rows are distinct.
    last_CMIP_ind = np.random.choice(last_CMIP_ind, 130, replace=False)
    left_CMIP_ind = np.array(list(set(CMIP_ind).difference(set(last_CMIP_ind))))
    new_CMIP_ind = np.random.choice(left_CMIP_ind, 400, replace=False)
    ind = np.hstack((SODA_ind, last_CMIP_ind, new_CMIP_ind))
    with open('pkl/CMIP_scaled_13.pkl', 'rb') as pf1:
        data = pickle.load(pf1)
    data = {'x_train': data['x_train'][ind],
            'x_val': data['x_val'],
            't_train': data['t_train'][ind],
            't_val': data['t_val']}
    with open('pkl/batch_seed_' + str(seed) + '.pkl', 'wb') as pf2:
        pickle.dump((data, new_CMIP_ind), pf2)
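
batch_gen expects a previous pkl/batch_seed_<last_seed>.pkl to exist; below is a hedged bootstrap sketch (the seed-0 file and its contents are my assumption, not shown in the original) plus one chained call:

# Hypothetical bootstrap: write an initial (data, new_CMIP_ind) pair so the
# chain of seeds can start; afterwards each round feeds the previous seed.
np.random.seed(623)
init_ind = np.random.choice(np.arange(70, 70 + 3252), 400, replace=False)
with open('pkl/batch_seed_0.pkl', 'wb') as pf:
    pickle.dump((None, init_ind), pf)

batch_gen(seed=1, last_seed=0)  # later rounds chain on: (2, 1), (3, 2), ...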

6. As of February 25, 23:41 (best score = 17.89)

6.1. Network Architecture

Code Snippet 6.1.1

import random

import torch
from torch import nn


class AttentionSTNN(nn.Module):
    def __init__(self, dropout, device, init_t=-0.05,
                 train_mode='schedule_sampling'):
        """
        :param train_mode: str, one of 'schedule_sampling',
            'teacher_forcing', 'free_running'
        """
        super(AttentionSTNN, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.init_t = init_t
        self.conv_sst1 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_t3001 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_ua1 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_va1 = nn.Conv2d(12, 12, 3, padding=1)

        self.conv_sst2 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_t3002 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_ua2 = nn.Conv2d(12, 12, 3, padding=1)
        self.conv_va2 = nn.Conv2d(12, 12, 3, padding=1)
        self.pool1 = nn.Sequential(nn.MaxPool2d(3, 3),
                                   nn.ReLU())
        self.pool2 = nn.Sequential(nn.MaxPool2d(4, 4),
                                   nn.ReLU()
                                   )
        self.enc_self_attn = nn.MultiheadAttention(
            embed_dim=48, num_heads=2, dropout=dropout)
        self.dec_self_attn = nn.MultiheadAttention(
            embed_dim=48, num_heads=2, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=48, num_heads=2, dropout=dropout)
        self.out_linear = nn.Linear(48, 1)
        self.tgt_embeds = nn.Embedding(8000, 48)
        # p_free_run is the probability of free running; it grows by 0.005 per
        # training step. -1.1e9 pins teacher forcing on; 1. pins free running on.
        init_p_free_run = {'teacher_forcing': -1.1e9,
                           'schedule_sampling': 0.,
                           'free_running': 1.}
        self.p_free_run = init_p_free_run[train_mode]
        self.device = device

        for m in self.modules():
            # (This model has no GRU, so only conv/linear layers need init.)
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight.data)

    def attn(self, src_embed, t):
        """
        :param src_embed: (b,12,48)
        :param t: (b,seq), 1 <= seq <= 24; with teacher forcing, seq = 24
        :return: (b,seq,48)
        """
        # encoder self-attention with a causal mask over the 12 input months
        src_mask = (torch.tril(torch.ones(12, 12)) != 1).to(self.device)
        src_embed = src_embed.permute(1, 0, 2)
        enc_attn, _ = self.enc_self_attn(src_embed, src_embed, src_embed,
                                         attn_mask=src_mask)
        # decoder self-attention
        # adding 3.59 makes all values positive and multiplying by 1000 keeps a
        # precision of 0.001 when quantizing targets into embedding ids; clamp
        # so early free-running predictions cannot index outside the embedding
        t = ((t + 3.59) * 1000).type(torch.int64).clamp(0, 7999).to(self.device)
        tgt_embed = self.tgt_embeds(t).permute(1, 0, 2)  # (seq,b,48)
        tgt_mask = (torch.tril(torch.ones(t.size(1), t.size(1))) != 1).to(self.device)
        dec_attn, _ = self.dec_self_attn(tgt_embed, tgt_embed, tgt_embed,
                                         attn_mask=tgt_mask)  # (seq,b,48)
        # cross attention between decoder queries and encoder memory
        out, _ = self.cross_attn(dec_attn, enc_attn, enc_attn)
        return out.permute(1, 0, 2)

    def last_layer(self, x):
        x = self.out_linear(x)  # (b,seq,1)
        return x.squeeze(dim=2)

    def forward(self, x, t=None):
        """
        :param x:(b,12,24,72,4)
        :param t: (b,24)
        :return:(b,24)
        """
        sst = self.conv_sst1(x[:, :, :, :, 0])  # (b,12,24,72)
        t300 = self.conv_t3001(x[:, :, :, :, 1])
        ua = self.conv_ua1(x[:, :, :, :, 2])
        va = self.conv_va1(x[:, :, :, :, 3])  # (b,12,24,72)

        sst = self.dropout(self.pool1(sst))  # (b,12,8,24)
        t300 = self.dropout(self.pool1(t300))
        ua = self.dropout(self.pool1(ua))
        va = self.dropout(self.pool1(va))  # (b,12,8,24)

        sst = self.conv_sst2(sst)  # (b,12,8,24)
        t300 = self.conv_t3002(t300)
        ua = self.conv_ua2(ua)
        va = self.conv_va2(va)  # (b,12,8,24)

        sst = self.pool2(sst)  # (b,12,2,6)
        t300 = self.pool2(t300)
        ua = self.pool2(ua)
        va = self.pool2(va)  # (b,12,2,6)

        tmp = (-1, 12, 12)  # flatten the latitude/longitude features
        x = torch.cat([sst.reshape(tmp), t300.reshape(tmp),
                       ua.reshape(tmp), va.reshape(tmp)], dim=-1)  # (b,12,48)
        x = self.dropout(x)
        if self.training:  # advance the sampling schedule only during training
            self.p_free_run += 0.005
        if random.random() > self.p_free_run and self.training:
            # teacher forcing: feed the ground truth shifted right by one month
            assert t is not None
            roll_t = torch.zeros_like(t)
            roll_t[:, 0] = self.init_t
            roll_t[:, 1:] = t[:, :-1]  # (b,24)
            x = self.attn(x, roll_t)  # (b,24,48)
            return self.last_layer(x)
        else:
            # free running: decode autoregressively, re-predicting the whole
            # prefix each step behind the fixed start value init_t
            t_first_col = torch.full((x.shape[0], 1), self.init_t).to(self.device)
            changing_t = torch.zeros(x.shape[0], 0).to(self.device)  # empty (b,0)
            for month in range(1, 25):
                changing_t = torch.cat((t_first_col, changing_t), dim=1)  # history (b,month)
                changing_t = self.attn(x, changing_t)  # current (b,month,48)
                changing_t = self.last_layer(changing_t)  # current (b,month)
            return changing_t
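
A smoke test of both decoding paths, as a sketch with random CPU tensors (my addition; the batch size 4 is arbitrary):

net = AttentionSTNN(dropout=0.1, device='cpu', train_mode='teacher_forcing')
x = torch.randn(4, 12, 24, 72, 4)
t = torch.randn(4, 24).clamp(-3.5, 3.5)  # keep quantized ids in range
net.train()
assert net(x, t).shape == (4, 24)  # teacher forcing during training
net.eval()
assert net(x).shape == (4, 24)     # free running at inference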