Preface
I have not been studying deep learning for long. Inspired by the paper "Satellite Image Prediction Relying on GAN and LSTM Neural Networks" (IEEE Xplore: https://ieeexplore.ieee.org/document/8761462), I put together a neural network that combines a Self-Attention ConvLSTM with a GAN, and compared it against a plain ConvLSTM and a Self-Attention ConvLSTM without adversarial training.
The ConvLSTM follows the paper "Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting".
The Self-Attention ConvLSTM follows the paper "Self-Attention ConvLSTM for Spatiotemporal Prediction".
This article uses PyTorch 2.2.1+cu121 and Python 3.11.
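Dataset preparation
The dataset comes from a past Alibaba Tianchi competition, 【追风少年】台风图像时间序列预测 (typhoon image time-series prediction).
The satellite cloud images come from Himawari-8, and the competition repackaged them as npy files. The original images are 1999×1999, so to reduce storage and computation I rescaled them to 128×128; the download link is given at the end of the article.
The data contain Himawari-8's three water-vapor channels, Band 8, Band 9 and Band 10. I use the Band 9 channel as the training input.
The training and validation data cover four types, A, B, C and E, and the test data cover five types, U, V, W, X and Y. Within each type, one image is captured every hour.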
[Figure: sample images from the dataset]
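The task is to use the satellite images of the past six hours to predict the images of the next six hours.
There are 9,402 images in total. They are grouped into sequences of 12 images each (the first six are the model input and the last six are the labels), giving 7,638 training sequences, 965 validation sequences and 704 test sequences.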
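The grouping uses a stride-1 sliding window (consecutive dataset indices start one hour apart, as in the Dataset class below), which is how thousands of overlapping sequences come out of 9,402 images. A minimal sketch, assuming one type's frames are stored as a time-ordered (T, H, W) array with no missing hours; the function `make_sequences` is hypothetical:

```python
import numpy as np

def make_sequences(frames: np.ndarray, seq_len: int = 12, n_input: int = 6):
    """Slide a window of seq_len frames over a (T, H, W) array with stride 1.

    Returns (inputs, targets) with shapes (N, n_input, H, W) and
    (N, seq_len - n_input, H, W), where N = T - seq_len + 1.
    """
    n = frames.shape[0] - seq_len + 1
    windows = np.stack([frames[i:i + seq_len] for i in range(n)])
    return windows[:, :n_input], windows[:, n_input:]

# Toy example: 20 hourly 128x128 frames give 20 - 12 + 1 = 9 sequences
frames = np.random.rand(20, 128, 128).astype(np.float32)
x, y = make_sequences(frames)
print(x.shape, y.shape)  # (9, 6, 128, 128) (9, 6, 128, 128)
```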
Writing the Dataset class
To make loading data during training easier, we write a custom Dataset class:
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import zipfile
import fnmatch
import io
import re
import numpy as np

# Training-set loader
class loaded_train_Dataset(Dataset):
    def __init__(self, zip_file, sid, frame_num, time_freq, ps):
        self.zip_file = zipfile.ZipFile(zip_file, 'r')
        self.sid = str(sid)      # type, e.g. 'A'
        self.fnum = frame_num    # frames per sample (12)
        self.freq = time_freq    # spacing in hours between target frames
        self.ps = ps
        # List every file name matching the pattern, i.e. the data to load
        filename = self.sid + '_Hour_*_Band_09.png'
        all_files = [name for name in self.zip_file.namelist() if fnmatch.fnmatch(name, filename)]
        if self.sid == 'A':
            max_limit = 1865
        else:
            max_limit = 1987
        # Keep only the files whose hour number lies between 1 and max_limit
        self.files = []
        for file in all_files:
            # Extract the hour number from the file name with a regular expression
            match = re.search(r'Hour_(\d+)', file)
            if match:
                num = int(match.group(1))
                if 1 <= num <= max_limit:
                    self.files.append(file)
        # In-memory cache of decoded frames: each sample needs 12 images, and reading
        # them from the zip every time is slow. The cache is indexed by
        # (hour ID - first_tid), which assumes the hour IDs are contiguous.
        self.allframes = [None] * len(self.files)
        tids = []
        for i in range(len(self.allframes)):
            self.allframes[i] = []
            tids.append(int(re.findall(r'Hour_(\d+)', self.files[i])[0]))
        self.first_tid = min(tids)
        self.last_tid = max(tids)
        self.tids = sorted(tids)

    def __len__(self):
        # Number of start hours whose whole window (inputs + spaced targets) exists
        return len(self.files) - (self.fnum // 2) * (self.freq + 1) + 1

    def __getitem__(self, idx):
        # Hour ID of the first frame; e.g. A_Hour_189 has hour ID 189
        tid = self.tids[idx]
        # Input frames: fnum // 2 consecutive hours starting at tid
        ids1 = list(range(tid, tid + (self.fnum // 2)))
        # Target frames: fnum // 2 frames spaced self.freq hours apart, after the inputs
        ids2 = list(range(tid + (self.fnum // 2) + self.freq - 1,
                          tid + (self.fnum // 2) + self.freq + (self.fnum // 2) * self.freq - 1, self.freq))
        ids = ids1 + ids2
        frames = []  # list holding the 12 frames
        for ctid in ids:
            # Use the cached frame if it exists; otherwise read it from the zip file
            if not len(self.allframes[ctid - self.first_tid]):
                with self.zip_file.open(self.sid + '_Hour_' + str(ctid) + '_Band_09.png') as file:
                    frame = np.array(Image.open(io.BytesIO(file.read()))) / 255.0  # normalize to [0, 1]
                frame = frame[:, :, np.newaxis]
                # PyTorch convolutions expect CHW, while the image array is HWC, so transpose
                frame = np.transpose(frame.astype(np.float32), axes=[2, 0, 1])
                self.allframes[ctid - self.first_tid] = frame
            else:
                frame = self.allframes[ctid - self.first_tid]
            frames.append(frame)
        return ids, frames
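Because my dataset is stored in a zip archive, the code reads it directly with the zipfile library. If you prefer, you can extract the archive first and adjust the file-reading part of the code accordingly.
This class loads the training set; the classes for the validation and test sets are similar.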
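The hour-ID arithmetic in `__getitem__` is easy to get wrong when `time_freq` changes, so here is a standalone restatement for checking it. The function `frame_ids` is hypothetical and only mirrors the two `range` calls above:

```python
def frame_ids(tid: int, frame_num: int = 12, time_freq: int = 1) -> list[int]:
    """Hour IDs of one sample, mirroring loaded_train_Dataset.__getitem__."""
    half = frame_num // 2
    # Input frames: `half` consecutive hours starting at tid
    ids1 = list(range(tid, tid + half))
    # Target frames: `half` hours spaced time_freq apart; the first one is
    # time_freq hours after the last input frame
    ids2 = list(range(tid + half + time_freq - 1,
                      tid + half + time_freq + half * time_freq - 1,
                      time_freq))
    return ids1 + ids2

print(frame_ids(189))       # [189, 190, ..., 200]
print(frame_ids(1, 12, 3))  # [1, 2, 3, 4, 5, 6, 9, 12, 15, 18, 21, 24]
```

With `time_freq=1` a sample covers 12 consecutive hours. With `time_freq=3` the six targets fall on every third hour and the window spans 24 hours, which matches the `(frame_num // 2) * (time_freq + 1)` term in `__len__`.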
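Loss function
I use a combined loss with three parts: a reconstruction loss (L1 loss + L2 loss), SSIM (structural similarity) and an adversarial loss.
Reconstruction loss
The L1 and L2 losses are both familiar and widely used in image processing, but because they treat errors differently, they also behave differently in practice.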
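The L1 loss is the sum (or, as in PyTorch's nn.L1Loss, the mean) of the absolute errors, while the L2 loss is the sum (or, as in nn.MSELoss, the mean) of the squared errors. Under L1 the total grows linearly with each error, so every unit of error counts the same and large errors such as outliers do not dominate. Under L2 the squaring makes large errors contribute disproportionately, so the loss is very sensitive to them.
For this reason the L1 loss tends to help preserve edges and fine detail, where pixel values change sharply. The L2 loss is minimized by the mean of the plausible outcomes, so it pushes the model to avoid large pixel deviations and tends to produce smoother images overall. I combine the two to benefit from both.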
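A tiny numerical illustration of this sensitivity difference, using PyTorch's built-in losses on a made-up error vector with three small errors and one large one:

```python
import torch
import torch.nn as nn

pred = torch.tensor([0.1, 0.1, 0.1, 1.0])
target = torch.zeros(4)

# Per-element losses (no reduction) so each error's contribution is visible
l1 = nn.L1Loss(reduction='none')(pred, target)
l2 = nn.MSELoss(reduction='none')(pred, target)

# Share of the total loss contributed by the single large error
share_l1 = (l1[-1] / l1.sum()).item()  # 1.0 / 1.3
share_l2 = (l2[-1] / l2.sum()).item()  # 1.0 / 1.03
print(f"L1 share: {share_l1:.3f}, L2 share: {share_l2:.3f}")  # L1 share: 0.769, L2 share: 0.971
```

The large error accounts for about 77% of the L1 total but about 97% of the L2 total.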
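Structural similarity
Structural similarity (SSIM) is a metric of how similar two images are. It ranges from -1 to 1 and equals 1 when the two images are identical. It was proposed in the paper "Image quality assessment: from error visibility to structural similarity"; I will not go into the details of the algorithm here.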
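For reference, the standard SSIM between two image windows $x$ and $y$, as defined in that paper, is:

$$
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)},\qquad C_1 = (K_1 L)^2,\quad C_2 = (K_2 L)^2
$$

where $\mu$ are local means, $\sigma^2$ local variances, $\sigma_{xy}$ the local covariance, $L$ the dynamic range of the pixel values (255 when pixels are scaled to 0 to 255, matching `data_range=255` in the loss code), and $K_1 = 0.01$, $K_2 = 0.03$ by default. The SSIM of a whole image is the mean over all local windows.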
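Adversarial loss
The GAN in this article uses WGAN-GP. Its core idea is to use the Wasserstein distance as the loss and, when training the discriminator, to add a gradient penalty: the norm of the discriminator's gradient, evaluated at random interpolations between real and generated samples, is pushed toward 1. This enforces the Lipschitz constraint the Wasserstein formulation requires, without the exploding or vanishing gradients that weight clipping can cause.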
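The generator's adversarial loss is:
$$L_{adv} = -\mathbb{E}_{\hat{x} \sim P_g}\left[D(\hat{x})\right]$$
where $P_g$ is the generator's distribution and $D(\hat{x})$ is the discriminator's score for a generated sample $\hat{x}$; the adversarial loss is the negative of the expected score.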
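In short, the reconstruction loss keeps the generated images consistent with the real ones at the pixel level, SSIM makes them look similar in structure (texture and local contrast), and the adversarial loss can noticeably sharpen the synthesized images. All three losses are widely used in image translation and super-resolution.
The main implementation is: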
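import torch
import torch.nn as nn
# Assumption: `ssim` is scikit-image's SSIM; the data_range argument below matches its signature
from skimage.metrics import structural_similarity as ssim

# Generator loss
class generator_loss_function(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.loss_ssim = ssim

    def forward(self, gen_img, target, gen_D):
        # gen_img, target: (batch, seq_len, 1, H, W) in [0, 1]; gen_D: critic scores of generated sequences
        L_rec = self.l1_loss(gen_img, target) + self.l2_loss(gen_img, target)
        # SSIM is computed on detached NumPy arrays, so this term produces no gradients;
        # it only changes the reported loss value
        output_np = gen_img.detach().cpu().numpy()
        target_np = target.detach().cpu().numpy()
        ssim_value = 0
        for i in range(output_np.shape[0]):
            ssim_seq = 0
            for k in range(output_np.shape[1]):
                result = self.loss_ssim(output_np[i, k, 0, :, :] * 255, target_np[i, k, 0, :, :] * 255, data_range=255)
                ssim_seq += result
            ssim_value += ssim_seq / output_np.shape[1]  # average over the predicted frames
        L_ssim = ssim_value / output_np.shape[0]  # average over the batch
        L_adv = -torch.mean(gen_D)  # WGAN generator loss
        return L_rec + 1e-2 * (1 - L_ssim) + 1e-4 * L_adv, L_rec, L_ssim, L_adv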
Here the SSIM and adversarial loss weights are set to 0.01 and 0.0001, respectively.
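One caveat about this loss: the SSIM term is computed on detached NumPy arrays, so `1 - L_ssim` changes the reported value but sends no gradient back to the generator. If a trainable SSIM term is wanted, a differentiable version can be written directly in PyTorch. The sketch below is a swapped-in alternative, not the original project's code; the names `gaussian_window` and `ssim_torch` are hypothetical, and it assumes single-channel inputs in [0, 1], an 11×11 Gaussian window with σ = 1.5, and only the 'valid' region of each image:

```python
import torch
import torch.nn.functional as F

def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """Normalized 2D Gaussian kernel of shape (1, 1, size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return (g[:, None] * g[None, :])[None, None]

def ssim_torch(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """Mean SSIM of two (N, 1, H, W) tensors; differentiable w.r.t. both inputs."""
    win = gaussian_window().to(device=x.device, dtype=x.dtype)
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    # Local means, variances and covariance via Gaussian filtering (no padding)
    mu_x = F.conv2d(x, win)
    mu_y = F.conv2d(y, win)
    sigma_x = F.conv2d(x * x, win) - mu_x ** 2
    sigma_y = F.conv2d(y * y, win) - mu_y ** 2
    sigma_xy = F.conv2d(x * y, win) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / \
               ((mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2))
    return ssim_map.mean()
```

Inside `generator_loss_function.forward`, something like `L_ssim = ssim_torch(gen_img.flatten(0, 1), target.flatten(0, 1))` would keep the term in the autograd graph. Its values will not match scikit-image exactly, because `structural_similarity` uses a 7×7 uniform window by default.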
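Building the Self-Attention ConvLSTM and the GAN
ConvLSTM
LSTM (long short-term memory) networks are widely used for time-series prediction. ConvLSTM replaces the fully connected input-to-state and state-to-state transitions of the LSTM with convolutions, so the model can better capture spatial features from images.
The basic ConvLSTM cell is implemented as follows: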
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, params):
        super(ConvLSTMCell, self).__init__()
        # The encoder maps each frame to hidden_dim channels, so the cell input has hidden_dim channels too
        self.input_dim = params['hidden_dim']
        self.hidden_dim = params['hidden_dim']
        self.kernel_size = params['kernel_size']
        self.padding = params['kernel_size'] // 2, params['kernel_size'] // 2  # "same" padding for odd kernels
        self.bias = params['bias']
        # A single convolution computes all four gates at once
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, (h_next, c_next)

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
The ConvLSTM cell returns the hidden state and cell state for the next time step.
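Written out, the cell above computes the following, with $*$ denoting 2D convolution, $\circ$ the element-wise product and $[\cdot,\cdot]$ channel-wise concatenation. Note that this implementation omits the peephole terms ($W_{c} \circ C$) that appear in Shi et al.'s original ConvLSTM equations:

$$
\begin{aligned}
i_t &= \sigma\left(W_i * [X_t, H_{t-1}] + b_i\right) \\
f_t &= \sigma\left(W_f * [X_t, H_{t-1}] + b_f\right) \\
o_t &= \sigma\left(W_o * [X_t, H_{t-1}] + b_o\right) \\
g_t &= \tanh\left(W_g * [X_t, H_{t-1}] + b_g\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ g_t \\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}
$$

The four kernels $W_i, W_f, W_o, W_g$ correspond to the four output-channel groups of the single `self.conv` layer, which `torch.split` separates.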
Next, we build the complete ConvLSTM network:
class ConvLstmEncode2Decode(nn.Module):
    def __init__(self, params):
        super(ConvLstmEncode2Decode, self).__init__()
        # Hyperparameters
        self.batch_size = params['batch_size']
        self.img_size = params['img_size']  # size of the encoded feature map, used for the hidden states
        self.cells, self.bns, self.decoderCells = [], [], []
        self.n_layers = params['n_layers']
        # Encoder: 1x1 conv to hidden_dim channels, then two stride-2 convs (128 -> 64 -> 32)
        self.img_encode = nn.Sequential(
            nn.Conv2d(in_channels=params['input_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1)
        )
        # Decoder: two stride-2 transposed convs (32 -> 64 -> 128), then 1x1 conv back to input_dim channels
        self.img_decode = nn.Sequential(
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['input_dim'])
        )
        for i in range(params['n_layers']):
            # Every layer has hidden_dim input channels (encoder output or previous layer)
            self.cells.append(ConvLSTMCell(params))  # needs hidden_dim, kernel_size, bias
            self.bns.append(nn.LayerNorm((params['hidden_dim'], 32, 32)))
        self.cells = nn.ModuleList(self.cells)
        self.bns = nn.ModuleList(self.bns)
        self.decoderCells = nn.ModuleList(self.decoderCells)  # empty, not used
        self.decoder_predict = nn.Conv2d(in_channels=params['hidden_dim'],  # not used in forward
                                         out_channels=1,
                                         kernel_size=(1, 1),
                                         padding=(0, 0))

    def forward(self, frames, hidden=None):
        # frames: (batch, 6, input_dim, 128, 128)
        if hidden is None:
            hidden = self._init_hidden(batch_size=frames.size(0), img_size=self.img_size)
        predict_temp_de = []
        for t in range(6):
            x = frames[:, t, :, :, :]
            x = self.img_encode(x)
            for i, cell in enumerate(self.cells):
                out, hidden[i] = cell(x, hidden[i])
                x = self.bns[i](out)  # the normalized output feeds the next layer
            out = self.img_decode(x)
            predict_temp_de.append(out)
        predict_temp_de = torch.stack(predict_temp_de, dim=1)
        return predict_temp_de

    def _init_hidden(self, batch_size, img_size):
        states = []
        for i in range(self.n_layers):
            states.append(self.cells[i].init_hidden(batch_size, img_size))
        return states
The model returns the six predicted future frames.
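A quick shape check of the encoder and decoder, as a sketch with a hypothetical `hidden_dim = 64` and one input channel (the article does not state the values used). The two stride-2 convolutions take a 128×128 frame down to 32×32, which is why `LayerNorm` is built with `(hidden_dim, 32, 32)` and why `params['img_size']`, used to size the hidden states, presumably has to be `(32, 32)`:

```python
import torch
import torch.nn as nn

hidden_dim = 64  # hypothetical value, not taken from the article

img_encode = nn.Sequential(
    nn.Conv2d(1, hidden_dim, kernel_size=1), nn.LeakyReLU(0.1),
    nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.1),
    nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.1),
)
img_decode = nn.Sequential(
    nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
    nn.LeakyReLU(0.1),
    nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
    nn.LeakyReLU(0.1),
    nn.Conv2d(hidden_dim, 1, kernel_size=1),
)

x = torch.randn(2, 1, 128, 128)  # (batch, channel, H, W): one Band 9 frame per sample
z = img_encode(x)
x_rec = img_decode(z)
print(z.shape)      # torch.Size([2, 64, 32, 32])
print(x_rec.shape)  # torch.Size([2, 1, 128, 128])
```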
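Self-Attention ConvLSTM
The Self-Attention ConvLSTM (SA-ConvLSTM) adds a memory-based self-attention module (SAM) to ConvLSTM so that the model can memorize global, long-range spatiotemporal dependencies during prediction.
The key part of SA-ConvLSTM is the self-attention memory module, implemented as follows: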
class self_attention_memory_module(nn.Module):
    def __init__(self, input_dim, hidden_dim, device):  # device is not used
        super().__init__()
        # 1x1 convolutions producing queries, keys and values
        self.layer_q = nn.Conv2d(input_dim, hidden_dim, 1)   # query from h
        self.layer_k = nn.Conv2d(input_dim, hidden_dim, 1)   # key from h
        self.layer_k2 = nn.Conv2d(input_dim, hidden_dim, 1)  # key from memory m
        self.layer_v = nn.Conv2d(input_dim, input_dim, 1)    # value from h
        self.layer_v2 = nn.Conv2d(input_dim, input_dim, 1)   # value from memory m
        self.layer_z = nn.Conv2d(input_dim * 2, input_dim * 2, 1)  # fuses Z_h and Z_m
        self.layer_m = nn.Conv2d(input_dim * 3, input_dim * 3, 1)  # memory gates
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

    def forward(self, h, m):
        batch_size, channel, H, W = h.shape
        # Self-attention of h over itself
        K_h = self.layer_k(h).view(batch_size, self.hidden_dim, H * W)
        Q_h = self.layer_q(h).view(batch_size, self.hidden_dim, H * W)
        Q_h = Q_h.transpose(1, 2)                          # (batch_size, H*W, hidden_dim)
        A_h = torch.softmax(torch.bmm(Q_h, K_h), dim=-1)   # (batch_size, H*W, H*W)
        V_h = self.layer_v(h).view(batch_size, self.input_dim, H * W)
        Z_h = torch.matmul(A_h, V_h.permute(0, 2, 1))      # (batch_size, H*W, input_dim)
        # Attention of h over the memory m
        K_m = self.layer_k2(m).view(batch_size, self.hidden_dim, H * W)
        V_m = self.layer_v2(m).view(batch_size, self.input_dim, H * W)
        A_m = torch.softmax(torch.bmm(Q_h, K_m), dim=-1)
        Z_m = torch.matmul(A_m, V_m.permute(0, 2, 1))
        Z_h = Z_h.transpose(1, 2).reshape(batch_size, self.input_dim, H, W)
        Z_m = Z_m.transpose(1, 2).reshape(batch_size, self.input_dim, H, W)
        W_z = torch.cat([Z_h, Z_m], dim=1)
        Z = self.layer_z(W_z)
        # Gated memory update
        combined = self.layer_m(torch.cat([Z, h], dim=1))  # 3 * input_dim
        mo, mg, mi = torch.split(combined, self.input_dim, dim=1)
        mi = torch.sigmoid(mi)
        new_m = (1 - mi) * m + mi * torch.tanh(mg)
        new_h = torch.sigmoid(mo) * new_m
        return new_h, new_m
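Reading the module above, it computes the following. Here the $W$'s are the 1×1 convolutions (`layer_q`, `layer_k`, `layer_v`, `layer_k2`, `layer_v2`, `layer_z`, `layer_m`), $\mathcal{H}_t$ is the ConvLSTM hidden state `h`, $\mathcal{M}_{t-1}$ the memory `m`, features are flattened over the $N = H \cdot W$ positions, and the softmax is taken row-wise:

$$
\begin{aligned}
Q &= W_q \mathcal{H}_t,\quad K_h = W_k \mathcal{H}_t,\quad V_h = W_v \mathcal{H}_t,\quad K_m = W_{k2} \mathcal{M}_{t-1},\quad V_m = W_{v2} \mathcal{M}_{t-1} \\
Z_h &= \mathrm{softmax}\left(Q^\top K_h\right) V_h^\top,\qquad Z_m = \mathrm{softmax}\left(Q^\top K_m\right) V_m^\top \\
Z &= W_z [Z_h; Z_m] \\
[o'_t; g'_t; i'_t] &= W_m [Z; \mathcal{H}_t] \\
\mathcal{M}_t &= \left(1 - \sigma(i'_t)\right) \circ \mathcal{M}_{t-1} + \sigma(i'_t) \circ \tanh(g'_t) \\
\hat{\mathcal{H}}_t &= \sigma(o'_t) \circ \mathcal{M}_t
\end{aligned}
$$

The $N \times N$ attention maps let every position attend to every other position, which is what captures global dependencies; they also make memory grow quadratically with the feature-map size ($N = 1024$ for a 32×32 map).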
Then embed it into the ConvLSTM cell:
class SA_Convlstm_cell(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.input_channels = params['hidden_dim']
        self.hidden_dim = params['hidden_dim']
        self.kernel_size = params['kernel_size']
        self.padding = params['padding']
        self.device = params['device']
        self.attention_layer = self_attention_memory_module(params['hidden_dim'], params['att_hidden_dim'],
                                                            self.device)  # 32, 16
        self.conv2d = nn.Sequential(
            nn.Conv2d(in_channels=self.input_channels + self.hidden_dim, out_channels=4 * self.hidden_dim,
                      kernel_size=self.kernel_size, padding=self.padding),
            nn.GroupNorm(4 * self.hidden_dim, 4 * self.hidden_dim))  # (num_groups, num_channels)

    def forward(self, x, hidden):
        h, c, m = hidden
        device = x.device
        h = h.to(device)
        c = c.to(device)
        m = m.to(device)
        combined = torch.cat([x, h], dim=1)  # (batch_size, input_dim + hidden_dim, img_size[0], img_size[1])
        combined_conv = self.conv2d(combined)  # (batch_size, 4 * hidden_dim, img_size[0], img_size[1])
        i, f, o, g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        h_next, m_next = self.attention_layer(h_next, m)
        return h_next, (h_next, c_next, m_next)

    def init_hidden(self, batch_size, img_size):  # initialize h, c, m
        h, w = img_size
        return (torch.zeros(batch_size, self.hidden_dim, h, w),
                torch.zeros(batch_size, self.hidden_dim, h, w),
                torch.zeros(batch_size, self.hidden_dim, h, w))
This gives a complete SA-ConvLSTM cell, on which we build the full SA-ConvLSTM network:
class Encode2Decode(nn.Module):  # self-attention ConvLSTM model
    def __init__(self, params):
        super(Encode2Decode, self).__init__()
        # Hyperparameters
        self.batch_size = params['batch_size']
        self.img_size = params['img_size']
        self.cells, self.bns, self.decoderCells = [], [], []
        self.n_layers = params['n_layers']
        self.input_window_size = params['input_window_size']
        self.output_window_size = params['output_dim']
        # seq2seq-style model: encode each frame, run the recurrent cells, decode back to an image
        self.img_encode = nn.Sequential(
            nn.Conv2d(in_channels=params['input_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1,
                      out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1)
        )
        self.img_decode = nn.Sequential(
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(in_channels=params['hidden_dim'], kernel_size=3, stride=2, padding=1, output_padding=1,
                               out_channels=params['hidden_dim']),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=params['hidden_dim'], kernel_size=1, stride=1, padding=0,
                      out_channels=params['input_dim'])
        )
        for i in range(params['n_layers']):
            # Every layer has hidden_dim input channels (encoder output or previous layer)
            self.cells.append(SA_Convlstm_cell(params))
            self.bns.append(nn.LayerNorm((params['hidden_dim'], 32, 32)))  # Use layernorm
        self.cells = nn.ModuleList(self.cells)
        self.bns = nn.ModuleList(self.bns)
        self.decoderCells = nn.ModuleList(self.decoderCells)  # empty, not used
        self.decoder_predict = nn.Conv2d(in_channels=params['hidden_dim'],  # not used in forward
                                         out_channels=1,
                                         kernel_size=(1, 1),
                                         padding=(0, 0))

    def forward(self, frames, target, mask_true, is_training, hidden=None):
        # frames, target: (batch, 6, C, H, W)
        if hidden is None:
            hidden = self.init_hidden(batch_size=frames.size(0), img_size=self.img_size)
        if is_training:
            frames = torch.cat([frames, target], dim=1)
        predict_temp_de = []
        # 12 frames give 11 transitions; the outputs of steps 5..10 predict frames 6..11
        for t in range(11):
            if is_training:
                if t < self.input_window_size:
                    x = frames[:, t, :, :, :]
                else:  # scheduled sampling: mask 1 feeds the ground truth, 0 the previous prediction
                    x = mask_true[:, t - self.input_window_size, :, :, :] * frames[:, t] + \
                        (1 - mask_true[:, t - self.input_window_size, :, :, :]) * out
            else:
                if t < self.input_window_size:
                    x = frames[:, t, :, :, :]
                else:  # at test time, feed the predicted frame back as input
                    x = out
            x = self.img_encode(x)
            for i, cell in enumerate(self.cells):
                out, hidden[i] = cell(x, hidden[i])
                x = self.bns[i](out)  # the normalized output feeds the next layer
            out = self.img_decode(x)
            predict_temp_de.append(out)
        predict_temp_de = torch.stack(predict_temp_de, dim=1)
        predict_temp_de = predict_temp_de[:, 5:, :, :, :]  # keep the 6 predicted future frames
        return predict_temp_de

    def init_hidden(self, batch_size, img_size):
        states = []
        for i in range(self.n_layers):
            states.append(self.cells[i].init_hidden(batch_size, img_size))
        return states
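This is the complete Self-Attention ConvLSTM network.
Both ConvLSTM and Self-Attention ConvLSTM use an encoder-decoder (seq2seq) structure. Sequence-to-sequence models are widely used for sequence tasks such as machine translation and language modeling; here I apply the idea to a spatiotemporal sequence task.
I also use scheduled sampling in the Self-Attention ConvLSTM. This is a training technique for sequence generation: during training, at every step after the input window, the model is fed either the ground-truth frame or its own prediction from the previous step, chosen at random through the mask `mask_true`. The probability of feeding the ground truth is usually lowered over the course of training, so the model gradually learns to rely on its own outputs, as it must at test time.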
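The article does not show how `mask_true` is generated. One common choice, sketched here under that assumption, is a Bernoulli mask whose probability of feeding the ground truth decays linearly with the training iteration; the function `scheduled_sampling_mask` and the `decay_steps` value are hypothetical. Because `forward` reads `mask_true[:, t - input_window_size]` for t = 6, ..., 10, the mask needs 5 time steps:

```python
import torch

def scheduled_sampling_mask(batch_size: int, n_steps: int, frame_shape: tuple,
                            step: int, decay_steps: int = 50000) -> torch.Tensor:
    """Bernoulli mask of shape (batch, n_steps, *frame_shape).

    1 = feed the ground-truth frame, 0 = feed the model's previous prediction.
    The ground-truth probability decays linearly from 1 to 0 over decay_steps
    iterations (a hypothetical schedule). frame_shape is (C, H, W).
    """
    p_true = max(0.0, 1.0 - step / decay_steps)
    probs = torch.full((batch_size, n_steps, 1, 1, 1), p_true)
    mask = torch.bernoulli(probs)  # one draw per sample and time step
    return mask.expand(batch_size, n_steps, *frame_shape)

mask_true = scheduled_sampling_mask(4, 5, (1, 128, 128), step=10000)
print(mask_true.shape)  # torch.Size([4, 5, 1, 128, 128])
```

Drawing one value per sample and time step means a whole frame is either ground truth or prediction, never a pixel-wise mix; early in training the inputs are mostly ground truth, and by the end the model consumes only its own predictions, matching test-time behavior.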
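Generative adversarial network
In a GAN, the generator and the discriminator are trained against each other: the generator gradually learns to produce images that resemble real ones, until the discriminator can no longer reliably tell real images from generated ones.
As mentioned above, we use WGAN-GP, whose key lies in the loss functions of the generator and the discriminator. The generator loss was described earlier; the discriminator loss is:
$$L_D = \mathbb{E}_{\hat{x} \sim P_g}\left[D(\hat{x})\right] - \mathbb{E}_{x \sim P_r}\left[D(x)\right] + \lambda\, \mathbb{E}_{\tilde{x} \sim P_{\tilde{x}}}\left[\left(\lVert \nabla_{\tilde{x}} D(\tilde{x}) \rVert_2 - 1\right)^2\right]$$
where $x \sim P_r$ is a real sample, $\hat{x} \sim P_g$ a generated sample, $\tilde{x} = \alpha x + (1 - \alpha)\hat{x}$ with $\alpha$ drawn uniformly from $[0, 1]$ a random interpolation between them, and $\lambda$ the penalty weight (`lambda_gp` in the training code). The first two terms together are the negative of the critic's estimate of the Wasserstein distance, and the last term is the gradient penalty, computed as follows: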
import torch
from torch import autograd

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Compute the WGAN-GP gradient penalty."""
    # Random weight term for interpolation between real and fake samples (one per sample)
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, 1, device=real_samples.device).expand_as(real_samples)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    # Per-sample L2 norm over the (seq, channel, H, W) dimensions
    grad_l2norm = gradients.flatten(1).norm(2, dim=1)
    gradient_penalty = torch.mean((grad_l2norm - 1) ** 2)
    return gradient_penalty
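To sanity-check the penalty in isolation, the sketch below restates it compactly and device-agnostically and applies it to a toy linear critic $D(x) = \langle w, x \rangle$, whose gradient is $w$ everywhere, so the penalty must equal $(\lVert w \rVert - 1)^2$. The names `gradient_penalty` and `critic` and the toy shapes are hypothetical:

```python
import torch
from torch import autograd

def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty: mean of (||grad D(x_interp)||_2 - 1)^2 over the batch."""
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_out = critic(interp)
    grads = autograd.grad(outputs=d_out, inputs=interp,
                          grad_outputs=torch.ones_like(d_out),
                          create_graph=True)[0]
    grad_norm = grads.flatten(1).norm(2, dim=1)  # per-sample L2 norm
    return ((grad_norm - 1) ** 2).mean()

torch.manual_seed(0)
shape = (12, 1, 8, 8)                  # (seq, channel, H, W), toy size
w = torch.randn(shape)
w = 2 * w / w.norm()                   # ||w|| = 2, so the penalty should be (2 - 1)^2 = 1
critic = lambda x: (x * w).flatten(1).sum(dim=1, keepdim=True)
real, fake = torch.randn(4, *shape), torch.randn(4, *shape)
penalty = gradient_penalty(critic, real, fake).item()
print(penalty)  # approximately 1.0
```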
The discriminator's architecture is:
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # 3D convolutions over (time, height, width); each layer halves H and W and keeps the 12 time steps
        self.lay1 = nn.Sequential(
            nn.Conv3d(1, 16, (3, 5, 5), (1, 2, 2), (1, 2, 2)),
            nn.LeakyReLU(0.2))
        self.lay2 = nn.Sequential(
            nn.Conv3d(16, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2)),
            nn.LeakyReLU(0.2))
        self.lay3 = nn.Sequential(
            nn.Conv3d(32, 64, (3, 3, 3), (1, 2, 2), 1),
            nn.LeakyReLU(0.2))
        self.lay4 = nn.Sequential(
            nn.Conv3d(64, 128, (3, 3, 3), (1, 2, 2), 1),
            nn.LeakyReLU(0.2))
        self.lay5 = nn.Sequential(
            nn.Conv3d(128, 256, (3, 3, 3), (1, 2, 2), 1),
            nn.LeakyReLU(0.2))
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dense1 = nn.Linear(256, 1)

    def forward(self, x):
        # input: [batch_size, 12 (seq_length), channel, 128, 128]
        x = x.permute(0, 2, 1, 3, 4)
        # now:   [batch_size, channel, 12 (seq_length), 128, 128], the layout Conv3d expects
        x = self.lay1(x)
        x = self.lay2(x)
        x = self.lay3(x)
        x = self.lay4(x)
        x = self.lay5(x)
        x = self.avgpool(x)
        dense_input = x.contiguous().view(x.size(0), -1)
        # No sigmoid: the WGAN critic outputs an unbounded score, not a probability
        dense_output_1 = self.dense1(dense_input)
        return dense_output_1
It consists of five 3D convolutional layers, a global average-pooling layer and one fully connected layer.
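To see where the shapes end up, the standard convolution size formula out = floor((n + 2p - k) / s) + 1 can be applied layer by layer. This small sketch (the helper `conv_out` is hypothetical) traces the time and spatial sizes through `lay1` to `lay5`:

```python
def conv_out(n: int, k: int, s: int, p: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * p - k) // s + 1

# (kernel, stride, padding) of lay1..lay5, each as a (time, height/width) pair
layers = [((3, 5), (1, 2), (1, 2)),
          ((3, 5), (1, 2), (1, 2)),
          ((3, 3), (1, 2), (1, 1)),
          ((3, 3), (1, 2), (1, 1)),
          ((3, 3), (1, 2), (1, 1))]

t, hw = 12, 128
for (kt, kh), (st, sh), (pt, ph) in layers:
    t, hw = conv_out(t, kt, st, pt), conv_out(hw, kh, sh, ph)
    print(t, hw)  # 12 64, 12 32, 12 16, 12 8, 12 4
```

So the last feature map is 256 × 12 × 4 × 4 per sequence; the adaptive average pooling reduces it to a 256-dimensional vector, and `dense1` maps that to a single critic score.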
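During training, the discriminator and the generator are updated at a 2:1 ratio, i.e. the generator is updated once for every two discriminator updates. The code is: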
############# Train the discriminator and the generator #############
# Runs inside the training loop; `ind` is assumed to be the batch index and
# `real_input_flag` the scheduled-sampling mask passed to the generator as mask_true
optimizer_D.zero_grad()
gen_img = generator(input, target, real_input_flag, is_training=True)
# The critic sees whole 12-frame sequences: 6 input frames + 6 generated or real frames
pred_discri_input = torch.cat([input, gen_img], 1)
true_discri_input = torch.cat([input, target], 1)
gen_D = discriminator(pred_discri_input.detach())  # detach: the D step does not update the generator
true_D = discriminator(true_discri_input)
gradient_penalty = compute_gradient_penalty(discriminator, true_discri_input.data,
                                            pred_discri_input.data)
loss_D = -torch.mean(true_D) + torch.mean(gen_D) + lambda_gp * gradient_penalty
loss_D.backward()
optimizer_D.step()
total_loss_D.append(loss_D.item())
optimizer_G.zero_grad()
# Train the generator every second iteration (2:1 update ratio)
if ind % 2 == 0:
    output = discriminator(pred_discri_input)
    loss_G, L_rec, L_ssim, L_adv = criterion_G(gen_img, target, output)
    loss_G.backward()
    optimizer_G.step()
    total_loss_G.append(loss_G.item())
For more details on WGAN-GP, see the paper "Improved Training of Wasserstein GANs".
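Model evaluation
The evaluation metrics are MSE, SSIM, PSNR (peak signal-to-noise ratio) and sharpness; the first two were introduced above. PSNR is another image-quality metric defined in terms of the MSE (mean squared error):
$$\mathrm{PSNR} = 10 \log_{10}\left(\frac{MAX_I^2}{\mathrm{MSE}}\right)$$
where $MAX_I$ is the maximum possible pixel value of the image (255 for 8-bit images).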
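A minimal NumPy sketch of PSNR = 10·log10(MAX_I² / MSE), assuming 8-bit pixel values by default (pass `max_i=1.0` for images normalized to [0, 1]); the function name `psnr` is hypothetical, not from the project:

```python
import numpy as np

def psnr(pred: np.ndarray, target: np.ndarray, max_i: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB; returns inf for identical images."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(max_i ** 2 / mse))

a = np.zeros((128, 128))
b = np.full((128, 128), 25.5)  # uniform error of 25.5, so MSE = 650.25
print(psnr(b, a))              # 20.0, since 255^2 / 650.25 = 100
```

Because PSNR is a monotone function of the MSE, it always moves in the opposite direction to MSE on the same predictions.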
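Sharpness is a no-reference measure of how sharp a single image is; comparing the scores of the predicted and real images shows how much clarity the prediction loses. The computation follows the paper "No-Reference Image Sharpness Assessment Based on Maximum Gradient and Variability of Gradients":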
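import numpy as np
from scipy.signal import convolve2d

def sharpness_calculate(img1):
    # Difference filters for the horizontal and vertical directions
    F1 = np.array([[0, 0], [-1, 1]])
    F2 = F1.T
    # Gradients in the horizontal and vertical directions
    H1 = convolve2d(img1, F1, mode='valid')
    H2 = convolve2d(img1, F2, mode='valid')
    g = np.sqrt(H1 ** 2 + H2 ** 2)  # gradient magnitude
    row, col = g.shape
    # Discard a border of B pixels on every side (also works when B == 0)
    B = round(min(row, col) / 16)
    g_center = g[B:row - B, B:col - B]
    MaxG = np.max(g_center)
    MinG = np.min(g_center)
    CVG = (MaxG - MinG) / np.mean(g_center)  # variability of the gradients
    score = MaxG ** 0.61 * CVG ** 0.39
    return score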
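Results
I trained three models, ConvLSTM, SA-ConvLSTM without adversarial training and SA-ConvLSTM with adversarial training, and used each of them to predict a set of test data.
First, a comparison of the images generated by ConvLSTM and SA-ConvLSTM: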
[Figure: images generated by ConvLSTM]
[Figure: images generated by SA-ConvLSTM]
[Figure: ground-truth images]
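The images generated by SA-ConvLSTM, which adds the self-attention module, are noticeably clearer.
Next, the images generated by SA-ConvLSTM with adversarial training:
The evaluation metrics of the two models:
After adding adversarial training, MSE and PSNR get somewhat worse, but SSIM and sharpness both improve, with sharpness improving the most. This is because adversarial training helps the model generate sharper images.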
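Summary
Since I only built a fairly simple network, the model has plenty of room for improvement. Feel free to adjust it and use it to predict sequences over longer time spans (for example, images taken every three hours).
The full code and data download instructions are on GitHub: GitHub - LEOMMM1/Typhoon-satellite-Image-prediction-based-on-SA-ConvLstm-and-GAN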