References
Paper: https://arxiv.org/pdf/1911.07067.pdf
Code: https://github.com/DebeshJha/ResUNetPlusPlus
1. preprocess.py
Preface:
If images are very large, the GPU can run out of memory during training. A simple workaround is to slice the images: crop each large image into patches of a fixed, moderate size that are convenient to train on.
This part of the code crops the training and test sets into 224×224 patches and stores them in new folders.
1.1. Parameter declarations
1.1.1. Command-line arguments
python preprocess.py --config "configs/default.yaml" --train ./DataSet_png512/train --valid ./DataSet_png512/test
--train: path to the training set
--valid: path to the validation set
--config: the configuration file, with the following contents:
train: "./DataPreprocess/train"   # training data folder
valid: "./DataPreprocess/test"    # validation data folder
log: "logs"                       # where tensorboard events are stored: ./logs
logging_step: 100
validation_interval: 20           # Save and valid have same interval
checkpoints: "checkpoints"
batch_size: 4
lr: 0.001
RESNET_PLUS_PLUS: True            # use the ResUNet++ model; if False, use ResUNet instead
IMAGE_SIZE: 512                   # 1500
CROP_SIZE: 224                    # 224
1.1.2. Argument parsing in the code
if __name__ == '__main__':
    # These were given values on the command line above
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-t', '--train', type=str, required=True,
                        help="Training Folder.")
    parser.add_argument('-v', '--valid', type=str, required=True,
                        help="Validation Folder")
    args = parser.parse_args()
    # Load the --config file into hp; the parameters are then accessed through hp
    hp = HParam(args.config)
    with open(args.config, 'r') as f:
        hp_str = ''.join(f.readlines())
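The post doesn't show the HParam class itself; it simply loads the YAML file and exposes its top-level keys as attributes (hp.lr, hp.IMAGE_SIZE, and so on). A minimal sketch of that idea (the exact attribute-lookup mechanics are an assumption):

import yaml

class HParam:
    """Load a YAML config and expose its top-level keys as attributes."""
    def __init__(self, path):
        with open(path, 'r') as f:
            self._cfg = yaml.safe_load(f)

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails
        cfg = self.__dict__.get('_cfg', {})
        if key in cfg:
            return cfg[key]
        raise AttributeError(key)

With this, hp = HParam("configs/default.yaml") makes hp.batch_size return 4 and hp.CROP_SIZE return 224.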
Assigning the arguments:
# Dataset paths
train_dir = args.train  # './DataSet_png512/train'
valid_dir = args.valid  # './DataSet_png512/test'
# The start_points() function is explained below
X_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0, 224, 288]
Y_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0, 224, 288]
# Folders holding the training images and masks
train_img_dir = os.path.join(train_dir, "images")  # './DataSet_png512/train/images'
train_mask_dir = os.path.join(train_dir, "masks")  # './DataSet_png512/train/masks'
# Where the preprocessed crops are saved (folders are created if they don't exist yet)
train_img_crop_dir = os.path.join(hp.train, "images_crop")  # './DataPreprocess/train/images_crop'
os.makedirs(train_img_crop_dir, exist_ok=True)
train_mask_crop_dir = os.path.join(hp.train, "masks_crop")  # './DataPreprocess/train/masks_crop'
os.makedirs(train_mask_crop_dir, exist_ok=True)
# Collect all images, then print the counts
img_files = glob.glob(os.path.join(train_img_dir, '**', '*.png'), recursive=True)
mask_files = glob.glob(os.path.join(train_mask_dir, '**', '*.png'), recursive=True)
print("Length of image :", len(img_files))
print("Length of mask :", len(mask_files))
With the settings above, the start_points() function returns [0, 224, 288] for both X_points and Y_points: at zero overlap the stride equals the crop size (224), and the final start point is pulled back to IMAGE_SIZE − CROP_SIZE = 512 − 224 = 288 so the last crop ends exactly at the image border. These three values are the starting coordinates of the 224×224 crops; a sketch of start_points follows, and the actual cropping is done by the crop_image_mask() function after it.
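The post never lists start_points itself; a sketch consistent with the values above (stepping by the crop size and anchoring the final window at the border):

def start_points(size, split_size, overlap=0):
    """Top/left coordinates of each crop along one axis."""
    points = [0]
    stride = int(split_size * (1 - overlap))  # 224 when overlap == 0
    counter = 1
    while True:
        pt = stride * counter
        if pt + split_size >= size:
            # The next window would overrun: anchor it at the border instead
            points.append(size - split_size)  # 512 - 224 = 288
            break
        points.append(pt)
        counter += 1
    return points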
def crop_image_mask(image_dir, mask_dir, mask_path, X_points, Y_points, split_height=224, split_width=224):
    img_id = os.path.basename(mask_path).split(".")[0]
    mask = load_image(mask_path)
    img = load_image(mask_path.replace("masks", "images"))
    count = 0
    num_skipped = 1
    for i in Y_points:
        for j in X_points:
            # img[0:224, 0:224],   [0:224, 224:448],   [0:224, 288:512]
            # img[224:448, 0:224], [224:448, 224:448], [224:448, 288:512]
            # img[288:512, 0:224], [288:512, 224:448], [288:512, 288:512]
            new_image = img[i:i + split_height, j:j + split_width]
            new_mask = mask[i:i + split_height, j:j + split_width]
            # Binarize the mask
            new_mask[new_mask > 100] = 255
            new_mask[new_mask <= 100] = 0
            # If the white/black pixel ratio is below 0.01, skip the patch entirely.
            # This heuristic is unsuitable for small-object segmentation
            # (e.g., fundus exudate segmentation).
            if np.any(new_mask):
                # np.unique returns counts for [0, 255]; this unpacking assumes
                # the patch contains both black and white pixels
                num_black_pixels, num_white_pixels = np.unique(new_mask, return_counts=True)[1]
                if num_white_pixels / num_black_pixels < 0.01:
                    num_skipped += 1
                    continue
            mask_ = Image.fromarray(new_mask.astype(np.uint8))
            mask_.save("{}/{}_{}.jpg".format(mask_dir, img_id, count), "JPEG")
            im = Image.fromarray(new_image.astype(np.uint8))
            im.save("{}/{}_{}.jpg".format(image_dir, img_id, count), "JPEG")
            count = count + 1
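The loop that applies crop_image_mask() to every file is not shown in the post; presumably something along these lines runs for the training set (and again for the test set with the corresponding folders):

from tqdm import tqdm

# Hypothetical driver loop: crop every training mask/image pair.
for mask_path in tqdm(mask_files, desc="cropping train set"):
    crop_image_mask(train_img_crop_dir, train_mask_crop_dir, mask_path,
                    X_points, Y_points, hp.CROP_SIZE, hp.CROP_SIZE)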
This completes the image preprocessing: the training and test sets have been cropped into 224×224 patches and stored in new folders, and it is from these new folders that train.py reads its data.
2. train.py
2.1. Parameter declarations
python train.py --name "default" --config "configs/default.yaml"
--name: the folder name used both for the saved weights and for the saved tensorboard events
--config: the configuration file, identical to the configs/default.yaml listed in section 1.1.1
Once the arguments are declared, execution jumps to the main function.
2.2. The main function (excluding the training phase)
2.2.1. Parameters
main(hp, num_epochs=args.epochs, resume=args.resume, name=args.name)
hp: the parameters from configs/default.yaml
num_epochs: defaults to 75
resume: defaults to the empty string ''
name: the string 'default'
def main(hp, num_epochs, resume, name):
    checkpoint_dir = "{}/{}".format(hp.checkpoints, name)  # 'checkpoints/default': where weights are saved
    writer = MyWriter("{}/{}".format(hp.log, name))  # logdir: 'logs/default'
    model = ResUnetPlusPlus(3).cuda()
    criterion = metrics.BCEDiceLoss()  # combined binary cross-entropy and Dice loss
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)  # Adam optimizer
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
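metrics.BCEDiceLoss is not reproduced in the post. A minimal sketch of such a combined loss, assuming sigmoid outputs (which the model produces, see res_unet_plus.py below) and a smoothing constant of 1.0 (an assumption):

import torch
import torch.nn as nn

class BCEDiceLoss(nn.Module):
    """Binary cross-entropy plus (1 - soft Dice)."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.bce = nn.BCELoss()  # inputs already in (0, 1) via Sigmoid
        self.smooth = smooth

    def forward(self, outputs, targets):
        bce = self.bce(outputs, targets)
        outputs = outputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (outputs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            outputs.sum() + targets.sum() + self.smooth)
        return bce + (1.0 - dice)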
2.2.2. Reading the data
mass_dataset_train = dataloader.ImageDataset(  # no False here: this builds the training set
    hp, transform=transforms.Compose([dataloader.ToTensorTarget()]))
mass_dataset_val = dataloader.ImageDataset(  # False here: this builds the validation set
    hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()]))
This calls the dataloader.ImageDataset class; note that it reads the preprocessed crops, i.e. the DataPreprocess folder.
class ImageDataset(Dataset):
This class reads each image and its mask into a sample; if self.transform is set, it applies self.transform to the sample, then returns the sample.
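The class body is not reproduced in the post; a minimal sketch of the behavior just described, assuming the *_crop folders produced by preprocess.py and the sample keys used in the training loop below:

import glob, os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Pairs each cropped image with its mask and returns a sample dict."""
    def __init__(self, hp, train=True, transform=None):
        root = hp.train if train else hp.valid
        self.mask_list = glob.glob(os.path.join(root, "masks_crop", "*.jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.mask_list)

    def __getitem__(self, idx):
        mask_path = self.mask_list[idx]
        image_path = mask_path.replace("masks_crop", "images_crop")
        sample = {
            "sat_img": np.array(Image.open(image_path)),
            "map_img": np.array(Image.open(mask_path)),
        }
        if self.transform:
            sample = self.transform(sample)
        return sample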
2.2.3. Creating the loaders
train_dataloader = DataLoader(
    mass_dataset_train, batch_size=hp.batch_size, num_workers=2, shuffle=True)
val_dataloader = DataLoader(
    mass_dataset_val, batch_size=1, num_workers=2, shuffle=False)
2.3. Training phase
step = 0
for epoch in range(start_epoch, num_epochs):
    lr_scheduler.step()  # update the learning rate (note: since PyTorch 1.1 this is
                         # normally called after the epoch's optimizer steps)
    # Track accuracy and loss; update() is called on them below
    train_acc = metrics.MetricTracker()
    train_loss = metrics.MetricTracker()
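metrics.MetricTracker is not shown in the post either; it behaves like the classic running-average meter. A minimal sketch:

class MetricTracker:
    """Running average: update(val, n) folds in a batch of size n."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count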
Load the data and train the model:
loader = tqdm(train_dataloader, desc="training")
for idx, data in enumerate(loader):
    # Get the input images and masks
    inputs = data["sat_img"].cuda()
    labels = data["map_img"].cuda()
    # zero the parameter gradients
    optimizer.zero_grad()
    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, labels)  # BCE + Dice loss, declared earlier
    # Backward pass
    loss.backward()
    optimizer.step()
    # Update the acc and loss trackers
    train_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
    train_loss.update(loss.data.item(), outputs.size(0))
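metrics.dice_coeff serves as the accuracy metric here. Its exact implementation is not shown in the post; a typical soft Dice coefficient looks like this (the smoothing constant is an assumption):

def dice_coeff(pred, target, smooth=1.0):
    """Soft Dice coefficient over a batch: 2|A∩B| / (|A| + |B|)."""
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)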
Immediately after, the training phase is logged to tensorboard:
# tensorboard logging: hp.logging_step = 100
if step % hp.logging_step == 0:  # log every 100 steps
    writer.log_training(train_loss.avg, train_acc.avg, step)
    # Refresh the tqdm progress-bar description every 100 steps
    loader.set_description(
        "Training Loss: {:.4f} Acc: {:.4f}".format(
            train_loss.avg, train_acc.avg))
2.4. Validation phase
The validation() function is the core of this part:
# hp.validation_interval = 20
if step % hp.validation_interval == 0:
    # Enter the validation() function: the validation phase
    valid_metrics = validation(
        val_dataloader, model, criterion, writer, step)
    # Checkpoint path: 'checkpoints/default/default_checkpoint_xx.pt'
    save_path = os.path.join(
        checkpoint_dir, "%s_checkpoint_%04d.pt" % (name, step))
    # Take the smallest loss so far; it is saved below
    best_loss = min(valid_metrics["valid_loss"], best_loss)
    # Save the parameters to save_path
    torch.save(
        {
            "step": step,
            "epoch": epoch,
            "arch": "ResUnet++",
            "state_dict": model.state_dict(),
            "best_loss": best_loss,
            "optimizer": optimizer.state_dict(),
        },
        save_path)
    print("Saved checkpoint to: %s" % save_path)
step += 1
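The resume argument (empty by default) presumably names one of these checkpoint files; restoring from it would look roughly like the following, matching the keys saved above (this reconstruction is an assumption, not the post's code):

# Hypothetical resume logic, matching the dict keys saved above.
if resume:
    checkpoint = torch.load(resume)
    start_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])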
The implementation of validation():
def validation(valid_loader, model, criterion, logger, step):
    # Same trackers as in training
    valid_acc = metrics.MetricTracker()
    valid_loss = metrics.MetricTracker()
    # Switch to evaluation mode
    model.eval()
    # Iterate over data.
    for idx, data in enumerate(tqdm(valid_loader, desc="validation")):
        # get the inputs and wrap in Variable
        inputs = data["sat_img"].cuda()
        labels = data["map_img"].cuda()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Update the acc and loss trackers
        valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
        valid_loss.update(loss.data.item(), outputs.size(0))
        # Log example images from the first batch only
        if idx == 0:
            logger.log_images(inputs.cpu(), labels.cpu(), outputs.cpu(), step)
    # Write the validation acc and loss to tensorboard
    logger.log_validation(valid_loss.avg, valid_acc.avg, step)
    print("Validation Loss: {:.4f} Acc: {:.4f}".format(valid_loss.avg, valid_acc.avg))
    # Switch back to training mode
    model.train()
    return {"valid_loss": valid_loss.avg, "valid_acc": valid_acc.avg}
The purpose of model.train() on the second-to-last line:
Calling model.train() after the validation phase switches the model back into training mode.
In deep learning, some layers (e.g., Dropout and Batch Normalization) behave differently in training mode and evaluation mode. In training mode these layers perform operations that improve the model's generalization and stability; in evaluation mode their behavior changes so that inference is deterministic and reproducible.
In short, model.train() ensures the model returns to training mode after validation, so training continues with the correct layer behavior.
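A tiny demonstration of the difference, using Dropout:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 4)
drop.train()    # training mode: ~half the activations are zeroed, the rest scaled by 2
print(drop(x))  # e.g. tensor([[2., 0., 2., 0.]]) -- random each call
drop.eval()     # eval mode: dropout is a no-op
print(drop(x))  # tensor([[1., 1., 1., 1.]]) -- deterministic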
3. Other related code
3.1. model.py
The ResUNet++ model architecture (see the block diagram in the paper).
The concrete implementation is as follows:
3.1.1. res_unet_plus.py
import torch.nn as nn
import torch
from core.modules import (
    ResidualConv,
    ASPP,
    AttentionBlock,
    Upsample_,
    Squeeze_Excite_Block,
)
class ResUnetPlusPlus(nn.Module):
    def __init__(self, channel, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])
        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])
        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])
        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid())

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out
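A quick shape sanity check (hypothetical usage, not from the post; assumes core.modules is importable): a 3-channel 224×224 crop yields a 1-channel probability map of the same spatial size, since the three stride-2 encoder stages are undone by the three ×2 upsamples:

import torch

model = ResUnetPlusPlus(3).eval()  # 3 input channels (RGB)
x = torch.randn(1, 3, 224, 224)    # one crop, as produced by preprocess.py
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 224, 224]); values in (0, 1) from the final Sigmoid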
3.1.1.1. Squeeze and Excitation Units
The module's inputs are the channel count of the previous layer and a configurable reduction parameter:
class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)    # squeeze: global average pool to (b, c)
        y = self.fc(y).view(b, c, 1, 1)    # excite: per-channel weights in (0, 1)
        return x * y.expand_as(x)          # rescale each channel of the input
What is this module for? The paper explains it as follows:
The squeeze-and-excitation blocks are stacked together with the residual blocks to increase effective generalization across different datasets and improve the network's performance.