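1. NJU-CPOL dual-polarization radar data:
https://box.nju.edu.cn/f/16bbb37458d3443dbf9f/?dl=1
2. Gridded precipitation data:
https://box.nju.edu.cn/f/076f5aeb2ec64b87bde8/?dl=1
Problems:
Using the attached data and any additional material collected, solve the following problems by mathematical modeling: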
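Problem 1: Build a mathematical model that extracts, from dual-polarization radar data, the microphysical characteristics useful for nowcasting severe convection.
For Problem 1, drawing on the literature, we propose a deep U-shaped network-long short-term memory (U-Net-LSTM) framework that combines U-Net's ability to capture the spatial patterns and features of radar data with LSTM's strength in extracting temporal information. A bidirectional convolutional LSTM (BConvLSTM) combines, nonlinearly, the feature maps extracted from the corresponding encoding path and from the previous up-convolution layer of the decoder. The input is one hour (10 frames) of past radar observables (ZH, ZDR, KDP) and the output is a ZH forecast for the following hour (10 frames); each layer learns features at a different level. The attached dataset is split 9:1 into a training set and a validation set, and the validation set is used to verify the reliability of the resulting U-Net-LSTM model.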
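The 10-frames-in / 10-frames-out arrangement and the 9:1 split described above can be sketched with two small helpers. This is a minimal illustration of the described setup, not the loader used in the code below: `make_windows` and `split_train_val` are hypothetical names, frames are represented only by their time indices, 6 minutes per frame is inferred from "1 hour = 10 frames", and splitting by event (so that windows from the same storm never land in both partitions) is an assumed design choice.

```python
import random


def make_windows(num_frames, n_in=10, n_out=10, stride=1):
    """Return (input_indices, target_indices) pairs for one radar event.

    Each pair uses n_in consecutive past frames as input and the following
    n_out frames as the forecast target (10 frames = 1 hour at 6-min steps).
    """
    windows = []
    for start in range(0, num_frames - n_in - n_out + 1, stride):
        inputs = list(range(start, start + n_in))
        targets = list(range(start + n_in, start + n_in + n_out))
        windows.append((inputs, targets))
    return windows


def split_train_val(event_ids, val_fraction=0.1, seed=0):
    """Shuffle event ids reproducibly and split them 9:1 (by default) into train / validation."""
    ids = list(event_ids)
    random.Random(seed).shuffle(ids)
    n_val = int(len(ids) * val_fraction)
    return ids[n_val:], ids[:n_val]


if __name__ == "__main__":
    windows = make_windows(num_frames=25)
    print(len(windows))        # 25 - 20 + 1 = 6 windows
    print(windows[0][1][:3])   # first target frames: [10, 11, 12]
    train_ids, val_ids = split_train_val([f"data_dir_{i:03d}" for i in range(50)])
    print(len(train_ids), len(val_ids))  # 45 5
```

Splitting by event rather than by window keeps overlapping, nearly identical windows from leaking between the training and validation sets.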
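Problem 2: Build a mathematical model that mitigates the blurring of the forecasts.
For Problem 2, building on Problem 1, we aim to mitigate the blurring effect so that the forecast radar echoes show fuller and more realistic detail. Within the U-Net-LSTM framework we increase the model depth and add an attention mechanism, again using BConvLSTM to combine, nonlinearly, the feature maps from the corresponding encoding path and from the previous up-convolution layer. On the data side, we merge the horizontal reflectivity and the differential reflectivity into the rain-rate estimate R(ZH, ZDR), used as one input channel alongside the specific differential phase KDP; combining ZH, R(ZH, ZDR) and KDP gives an alternative composite rainfall estimator intended to yield better nowcasts of severe convective precipitation. The more comprehensive inputs also help derive stable relationships between the variables. Finally, 1/10 of the attached dataset is held out to test the model, verifying that the improved U-Net-LSTM model performs better.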
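The text does not specify which attention mechanism is added. One common choice for U-Net decoders is the additive attention gate of Attention U-Net, which reweights each skip-connection feature map with a per-pixel coefficient in [0, 1] computed from the decoder's gating signal. The sketch below assumes that choice and writes it in NumPy for a single sample: the 1x1 convolutions become channel-mixing matrices, the weights are random placeholders rather than trained parameters, and the gating signal is assumed to be already upsampled to the skip resolution.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def attention_gate(x, g, w_x, w_g, psi, b=0.0):
    """Additive attention gate (Attention U-Net style) for one sample.

    x   : skip-connection features, shape (C_x, H, W)
    g   : gating features from the decoder, shape (C_g, H, W), same H, W as x
    w_x : (C_int, C_x) 1x1-conv weights applied to x
    w_g : (C_int, C_g) 1x1-conv weights applied to g
    psi : (C_int,) 1x1-conv weights mapping the intermediate features to one channel
    Returns the gated features (C_x, H, W) and the attention map (H, W).
    """
    # a 1x1 convolution is a matrix product over the channel axis
    q = np.einsum("ic,chw->ihw", w_x, x) + np.einsum("ic,chw->ihw", w_g, g)
    q = np.maximum(q, 0.0)                               # ReLU
    alpha = sigmoid(np.einsum("i,ihw->hw", psi, q) + b)  # per-pixel weight
    return x * alpha[None, :, :], alpha


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16, 16))    # e.g. encoder features at 16x16
g = rng.standard_normal((16, 16, 16))   # decoder gating signal, upsampled to 16x16
gated, alpha = attention_gate(
    x, g,
    w_x=rng.standard_normal((4, 8)),
    w_g=rng.standard_normal((4, 16)),
    psi=rng.standard_normal(4),
)
print(gated.shape, alpha.shape)  # (8, 16, 16) (16, 16)
```

Because the gate suppresses skip features where the decoder signal indicates no echo, it is one plausible way to keep sharp echo boundaries instead of averaging them out.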
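Problem 3: Build a mathematical model for quantitative precipitation estimation.
For Problem 3: quantitative precipitation estimation (QPE), which estimates rainfall from the information provided by dual-polarization radar, is an important part of modern meteorology. The main dual-polarization parameters used are the horizontal reflectivity factor ZH and the differential reflectivity ZDR, which are very useful for distinguishing rain, snow and ice particles. Traditional QPE methods rely mainly on an empirical relationship between ZH and the rain rate R (the Z-R relationship); to take ZDR into account, this relationship can be extended to an R(ZH, ZDR) relationship. To make the model more reliable and robust, we adopt a multi-source data fusion strategy: the inputs (dual-polarization radar data, observational data and satellite cloud imagery) are combined by weighted fusion, and the output is the rain rate R. The attached dataset is split 9:1 into a training set and a validation set; the validation set is used to evaluate model performance, and cross-validation or similar techniques can then be used to confirm performance on an independent test set.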
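Both relationships are usually written as power laws. The Z-R relationship is Z_H = a R^b with Z_H in linear units (mm^6 m^-3, i.e. 10^(dBZ/10)); the classic Marshall-Palmer coefficients are a = 200, b = 1.6. A form commonly used for the extension is R = a Z_H^b Z_DR^c, with Z_DR as the linear ratio 10^(ZDR_dB/10). Taking log10 makes it linear in (log10 a, b, c), so the coefficients can be fitted by ordinary least squares against gauge or gridded rainfall. The sketch assumes that power-law form; the function names are hypothetical and the synthetic data and "true" coefficients are made up for the demonstration, not fitted to the NJU-CPOL data.

```python
import numpy as np


def db_to_linear(x_db):
    """Convert dBZ (or dB) to linear units: 10 ** (x / 10)."""
    return 10.0 ** (np.asarray(x_db, dtype=float) / 10.0)


def rain_from_z_marshall_palmer(dbz, a=200.0, b=1.6):
    """Invert Z = a * R**b (Marshall-Palmer defaults) to get R in mm/h."""
    return (db_to_linear(dbz) / a) ** (1.0 / b)


def fit_r_z_zdr(dbz, zdr_db, rain):
    """Least-squares fit of R = a * Z_H**b * Z_DR**c in log10 space; returns (a, b, c).

    Samples with non-positive rain are dropped because log10 is undefined there.
    """
    dbz, zdr_db, rain = (np.asarray(v, dtype=float) for v in (dbz, zdr_db, rain))
    keep = rain > 0
    # log10 Z_H = dBZ / 10 and log10 Z_DR = ZDR_dB / 10, so the dB values enter directly
    design = np.column_stack([np.ones(keep.sum()), dbz[keep] / 10.0, zdr_db[keep] / 10.0])
    coef, *_ = np.linalg.lstsq(design, np.log10(rain[keep]), rcond=None)
    return 10.0 ** coef[0], coef[1], coef[2]


# synthetic check: generate rain from known coefficients and recover them
rng = np.random.default_rng(0)
dbz = rng.uniform(20, 55, 500)
zdr_db = rng.uniform(0.2, 3.0, 500)
true_a, true_b, true_c = 0.01, 0.9, -3.0   # illustrative values only
rain = true_a * db_to_linear(dbz) ** true_b * db_to_linear(zdr_db) ** true_c
a, b, c = fit_r_z_zdr(dbz, zdr_db, rain)
print(round(float(a), 4), round(float(b), 3), round(float(c), 3))  # 0.01 0.9 -3.0
print(round(float(rain_from_z_marshall_palmer(40.0)), 1))           # 11.5
```

With real data the fit would be run on collocated radar pixels and rainfall values, and the fitted R(ZH, ZDR) field could then serve as one input to the weighted fusion.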
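Problem 4: Build a mathematical model that evaluates the contribution of dual-polarization radar data to severe convective precipitation nowcasting.
For Problem 4: dual-polarization radar data provide information on the shape, size, phase and water content of precipitation particles, which is very useful for accurately forecasting severe convective weather such as thunderstorms, tornadoes and hail; conventional (single-polarization) radar cannot provide this information. We produce forecasts from conventional radar data and from dual-polarization radar data separately and compare them to assess the contribution of dual-polarization data to severe convective precipitation nowcasting, using forecast accuracy, forecast timeliness (lead time) and spatial accuracy as the evaluation criteria. We then optimize the data-fusion strategy to better handle sudden, highly localized severe convective weather. Finally, the model is validated and compared on historical and recent data, and its parameters and data-fusion strategy are tuned according to the validation results.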
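The text does not fix the accuracy metric. A common choice in radar nowcast verification is the set of contingency-table scores POD (probability of detection), FAR (false alarm ratio) and CSI (critical success index) at a reflectivity or rain-rate threshold, computed separately for each lead time; the contribution of dual-polarization data can then be read off as the CSI difference between the two forecasts. The sketch assumes that choice; the function names, the 35 dBZ threshold and the toy grids are hypothetical.

```python
def contingency_scores(forecast, observed, threshold):
    """POD, FAR and CSI for one forecast field against one observed field.

    Both fields are flat sequences of equal length; a grid point is an
    'event' when its value is >= threshold. Undefined scores are NaN.
    """
    hits = misses = false_alarms = 0
    for f, o in zip(forecast, observed):
        f_event, o_event = f >= threshold, o >= threshold
        if f_event and o_event:
            hits += 1
        elif o_event:
            misses += 1
        elif f_event:
            false_alarms += 1
    nan = float("nan")
    pod = hits / (hits + misses) if hits + misses else nan
    far = false_alarms / (hits + false_alarms) if hits + false_alarms else nan
    csi = hits / (hits + misses + false_alarms) if hits + misses + false_alarms else nan
    return pod, far, csi


def dual_pol_gain(single_pol_fcsts, dual_pol_fcsts, observations, threshold=35.0):
    """CSI(dual-pol) - CSI(single-pol) for each lead time (one field per lead time)."""
    return [
        contingency_scores(dp, obs, threshold)[2] - contingency_scores(sp, obs, threshold)[2]
        for sp, dp, obs in zip(single_pol_fcsts, dual_pol_fcsts, observations)
    ]


obs = [[40, 10, 38, 5], [42, 36, 12, 8]]      # two lead times, 4 grid points each (toy data)
single = [[40, 10, 10, 5], [10, 36, 40, 8]]
dual = [[41, 10, 37, 5], [40, 30, 12, 8]]
print(contingency_scores(single[0], obs[0], 35.0))  # (0.5, 0.0, 0.5)
print(dual_pol_gain(single, dual, obs))
```

Reporting the gain per lead time covers the timeliness criterion; a neighborhood version of CSI would be one way to address spatial accuracy.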
Technical roadmap:
Network architecture diagram:
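Although a U-Net can effectively extract spatial features from multivariate data, it ignores information along the time dimension, whereas an LSTM is good at extracting temporal information and has long-term memory. Latent-variable models have long struggled with preserving long-term information and skipping short-term inputs, and vanishing gradients limit their long-range memory; the LSTM, whose structure includes a hidden state, is one of the classic solutions. We therefore propose a new U-Net-LSTM framework to predict complex, time-evolving fluid-dynamical features. Figure 7 shows the structure of the proposed deep U-Net-LSTM framework, which combines a deep U-shaped network, two LSTM layers and a skip-connection part. The framework consists of an encoder, an LSTM part, a decoder and skip connections. The encoder extracts complex features, the LSTM layers learn temporal information, and the decoder restores the vector output of the LSTM layers to the same size as the input. The skip connections pass more high-level semantics to the decoder and recover as much as possible of the information lost through pooling. The design has 51 layers in total: 36 convolutional layers, 6 max-pooling layers, 1 global average pooling (GAP) layer, 2 LSTM layers and 6 upsampling layers.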
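To check how the encoder, LSTM and decoder fit together, the tensor shapes of the `UNetLSTM` class in the code below (six `Down` blocks from 4 to 256 channels, a 2-layer bidirectional LSTM with hidden size 128*4*4, six `Up` blocks) can be traced with plain arithmetic. This sketch reproduces shapes only, not the network; the helper name is hypothetical, and it assumes 256x256 inputs, transposed-convolution upsampling, batch size 1, 10 time steps and one output channel.

```python
def trace_unet_lstm_shapes(batch=1, steps=10, in_ch=9, size=256,
                           widths=(4, 8, 16, 32, 64, 128, 256),
                           lstm_hidden=128 * 4 * 4):
    """Return (stage, shape) pairs for the U-Net-LSTM forward pass (shapes only)."""
    n = batch * steps                       # frames are folded into the batch axis
    trace = [("input", (batch, steps, in_ch, size, size)),
             ("inc", (n, widths[0], size, size))]
    s = size
    for i, w in enumerate(widths[1:], start=1):
        s //= 2                             # each Down block max-pools by 2
        trace.append((f"down{i}", (n, w, s, s)))
    flat = widths[-1] * s * s               # bottleneck features per frame (LSTM input size)
    lstm_out = 2 * lstm_hidden              # bidirectional: forward + backward hidden states
    if lstm_out != flat:
        raise ValueError(f"LSTM output {lstm_out} cannot be reshaped to {widths[-1]}x{s}x{s}")
    trace.append(("lstm", (batch, steps, lstm_out)))
    for i, w in enumerate(reversed(widths[:-1]), start=1):
        s *= 2                              # each Up block doubles the resolution
        trace.append((f"up{i}", (n, w, s, s)))
    trace.append(("output", (batch, steps, 1, size, size)))
    return trace


for stage, shape in trace_unet_lstm_shapes():
    print(f"{stage:>6}: {shape}")
```

The check on `lstm_out` makes explicit why `hidden_size=128*4*4` pairs with a bidirectional LSTM: 2 x 2048 = 4096 = 256 x 4 x 4, so the LSTM output can be reshaped back into the bottleneck feature map for the decoder.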
Data flow diagram:
Code:
The code is modified from
https://github.com/milesial/Pytorch-UNet
train.py
import argparse
import logging
import os
from pathlib import Path

import torch
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from evaluate import evaluate
from unet import UNet, UNetLSTM
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.loss import smooth_l1_loss

dir_img = "./data/NJU_CPOL_update2308/"
dir_mask = "./data/NJU_CPOL_kdpRain/"
dir_checkpoint = Path('./checkpoints/')
def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions (9:1 with the default val_percent)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    train_loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **train_loader_args)
    val_loader_args = dict(batch_size=1, num_workers=os.cpu_count(), pin_memory=True)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **val_loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)  # goal: minimize validation MAE
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = smooth_l1_loss
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']  # torch.Size([1, 10, 9, 256, 256]) torch.Size([1, 10, 1, 256, 256])
                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)  # torch.Size([1, 10, 1, 256, 256])
                    loss = criterion(masks_pred, true_masks)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # clip gradients before the optimizer step (clipping after step() does not affect the update)
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

        # Evaluation round (once per epoch)
        val_mae_score, val_mse_score = evaluate(model, val_loader, device, amp)
        scheduler.step(val_mae_score)
        logging.info("epoch: {} loss: {} Validation MAE score: {} MSE score: {}".format(
            epoch, epoch_loss / len(train_loader), val_mae_score, val_mse_score))
        try:
            # log the last frame of the first sample (input channel 0, target, prediction)
            experiment.log({
                'learning rate': optimizer.param_groups[0]['lr'],
                'validation MAE': val_mae_score,
                'validation MSE': val_mse_score,
                'images': wandb.Image(images[0, -1, 0].cpu().numpy()),
                'masks': {
                    'true': wandb.Image(true_masks[0, -1, 0].cpu().numpy()),
                    'pred': wandb.Image(masks_pred[0, -1, 0].detach().float().cpu().numpy()),
                },
                'step': global_step,
                'epoch': epoch
            })
        except Exception:
            logging.warning('Could not log validation images to wandb')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=100, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=128, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    # no checkpoint by default; pass e.g. --load checkpoints/checkpoint_epoch26.pth to resume
    parser.add_argument('--load', '-f', type=str, default=None, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=1, help='Number of classes')

    return parser.parse_args()
if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # model = UNet(n_channels=9, n_classes=args.classes, bilinear=args.bilinear)
    model = UNetLSTM(n_channels=9, n_classes=args.classes, bilinear=args.bilinear)
    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
predict.py
import argparse
import logging
import os

import numpy as np
import torch

from unet import UNet, UNetLSTM


def predict_img(net,
                full_img,
                device,
                scale_factor=1):
    # scale_factor is accepted for CLI compatibility but is not used
    net.eval()
    img = torch.from_numpy(full_img)
    img = img.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        output = net(img).cpu()
    return output.numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='./unet_epoch20.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=1, help='Number of classes')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    images_dir = "./data/NJU_CPOL_update2308/"
    mask_dir = "./data/NJU_CPOL_kdpRain/"
    image_mask_items = "data_dir_000"
    time_gap = 10

    # load NJU_CPOL_update2308 -> (1, 10, 9, 256, 256): 3 variables x 3 heights = 9 channels
    input_numpy = np.zeros((1, time_gap, 9, 256, 256), np.float32)
    class_name_lists = ["dBZ", "KDP", "ZDR"]
    kilometers_name_lists = ["1.0km", "3.0km", "7.0km"]
    for class_name in class_name_lists:
        for kilometers_name in kilometers_name_lists:
            fore_dir = os.path.join(images_dir, class_name, kilometers_name, image_mask_items)
            index = class_name_lists.index(class_name) * 3 + kilometers_name_lists.index(kilometers_name)
            # use the first 10 frames for testing; os.listdir order is arbitrary, so sort
            # (assumes the frame file names sort chronologically)
            for num, name in enumerate(sorted(os.listdir(fore_dir))[:time_gap]):
                input_numpy[:, num, index, :, :] = np.load(os.path.join(fore_dir, name))

    # load NJU_CPOL_kdpRain -> (1, 10, 1, 256, 256); again the first 10 frames
    target_numpy = np.zeros((1, time_gap, 1, 256, 256), np.float32)
    target_dir = os.path.join(mask_dir, image_mask_items)
    for num, name in enumerate(sorted(os.listdir(target_dir))[:time_gap]):
        target_numpy[:, num, :, :, :] = np.load(os.path.join(target_dir, name))

    # use UNetLSTM(...) instead when loading a checkpoint produced by train.py
    net = UNet(n_channels=9, n_classes=args.classes, bilinear=args.bilinear)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(args.model, map_location=device)
    net.load_state_dict(state_dict)
    logging.info('Model loaded!')

    result = predict_img(net=net,
                         full_img=input_numpy,
                         device=device,
                         scale_factor=args.scale)
    print("MAE score:", np.abs(result - target_numpy).mean())
    # np.save('result.npy', result)
evaluate.py
import torch
from tqdm import tqdm


@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    mae_score = 0.0
    mse_score = 0.0

    # iterate over the validation set
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32)
            mask_true = mask_true.to(device=device, dtype=torch.float32)

            # predict the target field
            mask_pred = net(image).float()

            # accumulate per-batch errors, then average over batches below
            mae_score += torch.mean(torch.abs(mask_pred - mask_true)).item()
            mse_score += torch.mean((mask_pred - mask_true) ** 2).item()

    net.train()
    return mae_score / max(num_val_batches, 1), mse_score / max(num_val_batches, 1)
hubconf.py
import torch

from unet import UNet as _UNet


def unet_carvana(pretrained=False, scale=0.5):
    """
    UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
    Set the scale to 0.5 (50%) when predicting.
    """
    net = _UNet(n_channels=3, n_classes=2, bilinear=False)
    if pretrained:
        if scale == 0.5:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
        elif scale == 1.0:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
        else:
            raise RuntimeError('Only 0.5 and 1.0 scales are available')
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
        if 'mask_values' in state_dict:
            state_dict.pop('mask_values')
        net.load_state_dict(state_dict)

    return net
utils/data_loading.py
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class BasicDataset(Dataset):
    def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0):
        # scale is kept for interface compatibility with Pytorch-UNet; it is not used here
        self.scale = scale

        # load NJU_CPOL_update2308: image_dict[event] holds 9 frame-path lists,
        # one per (variable, height) channel
        self.image_dict = {}
        class_name_lists = ["dBZ", "KDP", "ZDR"]
        kilometers_name_lists = ["1.0km", "3.0km", "7.0km"]
        for class_name in class_name_lists:
            for kilometers_name in kilometers_name_lists:
                fore_dir = os.path.join(images_dir, class_name, kilometers_name)
                index = class_name_lists.index(class_name) * 3 + kilometers_name_lists.index(kilometers_name)
                for image_items in sorted(os.listdir(fore_dir)):
                    if image_items not in self.image_dict:
                        # one independent list per channel ([[]] * 9 would alias one list 9 times)
                        self.image_dict[image_items] = [[] for _ in range(9)]
                    # sorted so that frames are in time order (os.listdir order is arbitrary)
                    for name in sorted(os.listdir(os.path.join(fore_dir, image_items))):
                        self.image_dict[image_items][index].append(os.path.join(fore_dir, image_items, name))

        # load NJU_CPOL_kdpRain
        self.mask_dict = {}
        for mask_items in sorted(os.listdir(mask_dir)):
            self.mask_dict[mask_items] = [
                os.path.join(mask_dir, mask_items, name)
                for name in sorted(os.listdir(os.path.join(mask_dir, mask_items)))
            ]
        self.items = list(self.mask_dict.keys())

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        name = self.items[idx]
        time_gap = 10
        time_num = len(self.mask_dict[name])
        # random window start; randint is inclusive, so the last valid start is time_num - time_gap
        time_start = random.randint(0, time_num - time_gap)
        height, width = 256, 256

        # read image: (time, 9 channels, H, W)
        image_numpy = np.zeros((time_gap, 9, height, width), np.float32)
        for i in range(time_gap):
            for j in range(len(self.image_dict[name])):
                image_numpy[i, j, :, :] = np.load(self.image_dict[name][j][time_start + i])

        # read mask: (time, 1, H, W)
        mask_numpy = np.zeros((time_gap, 1, height, width), np.float32)
        for i in range(time_gap):
            mask_numpy[i, :, :, :] = np.load(self.mask_dict[name][time_start + i])

        return {
            'image': torch.as_tensor(image_numpy.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask_numpy.copy()).float().contiguous()
        }


class CarvanaDataset(BasicDataset):
    def __init__(self, images_dir, mask_dir, scale=1):
        super().__init__(images_dir, mask_dir, scale)
utils/loss.py
import torch


def smooth_l1_loss(input, target, sigma=1.0, reduce=True, normalizer=1.0):
    # smooth L1 (Huber-type) loss with threshold beta = 1 / sigma**2
    beta = 1. / (sigma ** 2)
    diff = torch.abs(input - target)
    cond = diff < beta
    loss = torch.where(cond, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    if reduce:
        return torch.mean(loss) / normalizer
    return torch.mean(loss, dim=1) / normalizer
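For reference, with d = |input - target| and β = 1/σ², `smooth_l1_loss` computes the following per element, then averages (over all elements when `reduce=True`, otherwise over dim 1) and divides by `normalizer`:

```latex
\ell(d) =
\begin{cases}
  \dfrac{0.5\, d^{2}}{\beta}, & d < \beta,\\[4pt]
  d - 0.5\,\beta, & d \ge \beta,
\end{cases}
\qquad \beta = \frac{1}{\sigma^{2}}
```

With the default σ = 1 this is the Huber loss with threshold 1: quadratic for small errors and linear for large ones, so a few extreme pixels do not dominate the gradient. For `reduce=True` and `normalizer=1` it gives the same value as PyTorch's built-in `torch.nn.functional.smooth_l1_loss(input, target, beta=beta)` with its default mean reduction.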
utils/utils.py
import matplotlib.pyplot as plt


def plot_img_and_mask(img, mask):
    classes = mask.max() + 1
    fig, ax = plt.subplots(1, classes + 1)
    ax[0].set_title('Input image')
    ax[0].imshow(img)
    for i in range(classes):
        ax[i + 1].set_title(f'Mask (class {i + 1})')
        ax[i + 1].imshow(mask == i)
    plt.xticks([]), plt.yticks([])
    plt.show()
unet/unet_lstm_model.py
""" Full assembly of the parts to form the complete network """
from torch.utils.checkpoint import checkpoint

from .unet_parts import *


class UNetLSTM(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNetLSTM, self).__init__()
        if bilinear:
            # with bilinear upsampling the Up() channel counts and the LSTM input size
            # below no longer match, so only transposed-conv upsampling is supported
            raise ValueError('UNetLSTM supports bilinear=False only')
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.checkpointing = False

        # encoder; shapes for a (1, 10, 9, 256, 256) input with frames folded into the batch axis
        self.inc = DoubleConv(n_channels, 4)   # (10, 4, 256, 256)
        self.down1 = Down(4, 8)                # (10, 8, 128, 128)
        self.down2 = Down(8, 16)               # (10, 16, 64, 64)
        self.down3 = Down(16, 32)              # (10, 32, 32, 32)
        self.down4 = Down(32, 64)              # (10, 64, 16, 16)
        self.down5 = Down(64, 128)             # (10, 128, 8, 8)
        self.down6 = Down(128, 256)            # (10, 256, 4, 4)
        # decoder
        self.up1 = Up(256, 128, bilinear)
        self.up2 = Up(128, 64, bilinear)
        self.up3 = Up(64, 32, bilinear)
        self.up4 = Up(32, 16, bilinear)
        self.up5 = Up(16, 8, bilinear)
        self.up6 = Up(8, 4, bilinear)
        self.outc = OutConv(4, n_classes)
        # 2-layer bidirectional LSTM over the flattened 256x4x4 bottleneck (requires 256x256 inputs);
        # 2 directions x 2048 hidden units = 4096 = 256*4*4, so the output reshapes back to (256, 4, 4)
        self.Lstm = nn.LSTM(input_size=256 * 4 * 4, hidden_size=128 * 4 * 4, num_layers=2,
                            batch_first=True, bidirectional=True)

    def _run(self, module, *inputs):
        # recompute activations in the backward pass when gradient checkpointing is on
        if self.checkpointing and self.training:
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        # reshape (B, T, C, H, W) -> (B*T, C, H, W) so every frame passes through the encoder
        batch_size, timesteps, channel_x, h_x, w_x = x.shape
        x_viewed = x.reshape(batch_size * timesteps, channel_x, h_x, w_x)  # torch.Size([10, 9, 256, 256])
        x1 = self._run(self.inc, x_viewed)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)  # torch.Size([10, 64, 16, 16])
        x6 = self._run(self.down5, x5)
        x7 = self._run(self.down6, x6)  # torch.Size([10, 256, 4, 4])

        # LSTM over time: (B*T, 256, 4, 4) -> (B, T, 4096) -> (B*T, 256, 4, 4)
        _, channel_x7, h_x7, w_x7 = x7.shape
        lstm_output, _ = self.Lstm(x7.reshape(batch_size, timesteps, -1))
        # the decoder consumes the LSTM output, not the raw bottleneck x7
        lstm_output = lstm_output.reshape(batch_size * timesteps, channel_x7, h_x7, w_x7)

        x = self._run(self.up1, lstm_output, x6)
        x = self._run(self.up2, x, x5)
        x = self._run(self.up3, x, x4)
        x = self._run(self.up4, x, x3)
        x = self._run(self.up5, x, x2)
        x = self._run(self.up6, x, x1)
        logits = self._run(self.outc, x).reshape(batch_size, timesteps, self.n_classes, h_x, w_x)  # torch.Size([1, 10, 1, 256, 256])
        return logits

    def use_checkpointing(self):
        # torch.utils.checkpoint is a module, not a callable wrapper, so checkpointing
        # is enabled through a flag that _run() checks in forward()
        self.checkpointing = True
unet/unet_model.py
""" Full assembly of the parts to form the complete network """
from torch.utils.checkpoint import checkpoint

from .unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.checkpointing = False

        self.inc = DoubleConv(n_channels, 16)
        self.down1 = Down(16, 32)
        self.down2 = Down(32, 64)
        self.down3 = Down(64, 128)
        factor = 2 if bilinear else 1
        self.down4 = Down(128, 256 // factor)
        self.up1 = Up(256, 128 // factor, bilinear)
        self.up2 = Up(128, 64 // factor, bilinear)
        self.up3 = Up(64, 32 // factor, bilinear)
        self.up4 = Up(32, 16, bilinear)
        self.outc = OutConv(16, n_classes)

    def _run(self, module, *inputs):
        # recompute activations in the backward pass when gradient checkpointing is on
        if self.checkpointing and self.training:
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        # reshape (B, T, C, H, W) -> (B*T, C, H, W): every frame is processed independently
        batch_size, timesteps, channel_x, h_x, w_x = x.shape
        x_viewed = x.reshape(batch_size * timesteps, channel_x, h_x, w_x)  # torch.Size([10, 9, 256, 256])
        x1 = self._run(self.inc, x_viewed)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x).reshape(batch_size, timesteps, self.n_classes, h_x, w_x)  # torch.Size([1, 10, 1, 256, 256])
        return logits

    def use_checkpointing(self):
        # torch.utils.checkpoint is a module, not a callable wrapper, so checkpointing
        # is enabled through a flag that _run() checks in forward()
        self.checkpointing = True
unet/__init__.py
from .unet_model import UNet
from .unet_lstm_model import UNetLSTM
unet/unet_parts.py
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
Experimental results: