Datawhale AI 夏令营 -- baseline拓展

根据赛题我们可以得知:

  • 输入数据大小为(1, 72, 24, W, H)- 长为72的时序,每个step中为24个通道的数组
  • 输出数据大小为(1, 72, W, H)

在baseline代码中,使用:

x = x.reshape(B, -1, W, H)

也就是将S*C,视作新的通道数,带入CNN模型中。但这样显然忽略了时序信息,自然想到可以融入RNN思想。我们第一部分先基于baseline改进一下模型,第二部分尝试兼顾时序信息

(LSTM在MIM论文中提及说遗忘门总是处于饱和状态,对于非平稳数据,MIM效果更优,但暂时没找到现成的预训练的MIM模型)

一、扩大模型规模

在baseline-enhance文件中,作者增加了几层CNN和Act,我们不妨在这个思想上更进一步,使用更大的模型。

1.模型的选择

首先自然想到的是ResNet,但考虑到本次数据量较大(3年*2部分*7G),而ResNet的参数量随着残差块的增多而增加,并不适合在博主的笔记本上运行。因此考虑使用EfficientNet。

关于两种模型的准确度、性能、效果等,可以查看往期文章,这里简单列个对比图

ResNet50 in Deep Fake Dataset with 5 epochs

EfficientNet50 in Deep Fake Dataset with 5 epochs

2.加载数据、模型

数据加载部分略,详见baseline。不过有一点注意:数据集文件路径不能包含中文,不然xarray读取nc文件时会报错No Such File

首先加载模型,这里使用了timm的efficientnet_b0,相比b1,b2更加轻量化,且性能还处在不错的水平

effnet = timm.create_model('timm/efficientnet_b0.ra_in1k', pretrained=True, features_only=True)
data_config = timm.data.resolve_model_data_config(effnet)
val_trans = timm.data.create_transform(**data_config, is_training=False)
train_trans = timm.data.create_transform(**data_config, is_training=True)

因为我们这里不需要进行分类,最后输出的数据大小为(1, 72, W, H),因此将最后的分类层(Flatten, Linear等)去除

观察effnet结构发现,模型要求的输入的数据通道为3,通过取出一个batch的数据代入,发现最终输出的大小为(1, 320, 2, 3),因此我们需要对模型结构进行一定的修改。

3.模型修改

对于首层结构,最简单的修改就是直接更改Conv2d的输入通道数,当然这里直接从72*24个channels压缩到32个channels是不太妥当的,可以逐步缩小,每层加入BatchNorm, Dropout, Activation等。本文没有深究这一方面。

B, S, C, W, H = tuple(ft_item.shape)
input_f = S * C
x = ft_item.reshape(B, -1, W, H)
feedforward = nn.Conv2d(
    input_f, 32, 
    kernel_size=3, 
    padding=(1, 1), 
    stride=(2, 2), 
    bias=False
)
effnet.conv_stem = feedforward

对于尾层结构,需要增加其他层,而不是简单修改。观察effnet结构,发现主要三个部分:conv_stem, btn1, blocks,blocks里面为一个个的efficientnet block的堆叠,不便进行修改,而我们目标是将(1, 320, 2, 3)的张量放大到(1, 72, 57, 81),于是自然想到转置卷积

那接下来就很好办了,从320个通道逐步缩小到72,同时将宽高逐步放大到(57, 81)        

covt = nn.Sequential(
    nn.ConvTranspose2d(320, 256, kernel_size=5, stride=(5, 5), bias=False),
    nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    nn.ConvTranspose2d(256, 128, kernel_size=6, stride=(5, 5), bias=False),
    nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    nn.ConvTranspose2d(128, 72, kernel_size=(7, 6), stride=(1, 1), bias=False),
)

这其中的BatchNorm2d参数直接使用了blocks里面的,整个covt的结构还需斟酌,本文只做引导

最终输出大小变为

那后续就方便了,读取数据然后模型训练,这里就省略了

4.后续

这里面要不要进行数据增强?在图像分类里面,数据增强的效果很直观,很容易理解数据增强的原理和作用。但是这里用的是观测到的数据,具有实际意义——降水、风速……好像这里使用数据增强有点没必要。个人觉得还是应该增强一下的,但是方法要限制一下,比如旋转就不太妥当(个人观点),观测数据数值上会有误差,位置上不应该吧……

二、考虑时序信息

数据每次获取一个Batch,每个Batch为(72, 24, 57, 81)。数据中24, 57, 81,表示24个特征,每个特征包含(57, 81)的数据。将72视作time_step,即用72个数据去预测下一个数据。

那么最后进行test时,前面的部分纯用CNN预测,从第73个开始融入GRU,那么我们每次就有两个预测值——CNN的输出与GRU的输出,那这两个输出怎么融合,进行综合考虑呢?最简单的思想自然是加权融合,然后让这个权重W1, W2也是个可学习的参数,然后就成了Multi-Stream Networks

或者直接将RNN的思想在CNN部分层中体现,直接上TCN或者CNN-RNN?直觉上觉得是可以的。查询资料后,觉得可以试试如下网络:

  • 3D卷积神经网络(3D CNN)

3D CNN可以处理具有时间维度的数据,通过在三个维度上应用卷积操作来提取特征。这种网络可以捕捉到时间序列数据中的动态变化和空间特征。

  • 时空卷积网络(Spatio-Temporal CNN, ST-CNN)

ST-CNN是一种专门为视频数据设计的网络,它结合了空间卷积和时间卷积。在这种网络中,通常首先使用2D卷积来提取空间特征,然后使用1D卷积来提取时间特征。

在Pytorch提供的GRU类中,输入要求是(B, T, F),或者(T, B, F),咱们这里五维的数据好像不太支持。但是仔细想想,我们相当于是每个数具有24个features,而每个feature大小为(57, 81)的二维数组,在NLP中每个feature多为单个数值,看来需要手写一下GRU了……不想写一点

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值