[Code Reproduction] Semantic Segmentation with ResUNet++ (with Image Tiling Preprocessing)


References

Paper: https://arxiv.org/pdf/1911.07067.pdf
Code: https://github.com/DebeshJha/ResUNetPlusPlus

1. preprocess.py

Preface:
Training on very large images can easily exhaust GPU memory. A simple workaround is to tile the images: crop each large image into patches of a fixed, moderate size, which are then used for training.

What this script does: crop both the training set and the test set into 224×224 patches and store them in new folders.

1.1. Parameter declarations

1.1.1. Command-line arguments

python preprocess.py --config "configs/default.yaml" --train ./DataSet_png512/train --valid ./DataSet_png512/test

--train: path to the training set
--valid: path to the validation set
--config: configuration file, with the following contents:

train: "./DataPreprocess/train" # 训练数据文件夹路径
valid: "./DataPreprocess/test"  # 验证数据文件夹路径
log: "logs"                     # tensorboard的events存储路径: ./logs
logging_step: 100
validation_interval: 20 # Save and valid have same interval
checkpoints: "checkpoints"

batch_size: 4
lr: 0.001
RESNET_PLUS_PLUS: True  # 使用ResUNet++模型;若该值为False则使用ResUNet模型
IMAGE_SIZE: 512         # 1500
CROP_SIZE: 224          # 224

1.1.2. Argument parsing in the code

if __name__ == '__main__':
    # These were supplied on the command line above
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-t', '--train', type=str, required=True,
                        help="Training Folder.")
    parser.add_argument('-v', '--valid', type=str, required=True,
                        help="Validation Folder")
    args = parser.parse_args()


    # Load the --config file into hp; parameters are accessed through hp
    hp = HParam(args.config)
    with open(args.config, 'r') as f:
        hp_str = ''.join(f.readlines())
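
HParam is a thin wrapper that loads the YAML file and exposes its keys as attributes (hp.lr, hp.CROP_SIZE, and so on). A minimal sketch of such a wrapper, assuming a flat YAML dict; the repo's actual HParam may be more elaborate:

import yaml

class HParam:
    """Load a YAML config and expose its top-level keys as attributes."""
    def __init__(self, path):
        with open(path, 'r') as f:
            self.__dict__.update(yaml.safe_load(f))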

Assigning the parameters:

    # dataset paths
    train_dir = args.train   # './DataSet_png512/train'
    valid_dir = args.valid   # './DataSet_png512/test'
    
    # start_points() is explained below
    X_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0, 224, 288]
    Y_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0, 224, 288]

    ## folders holding the training images and masks
    train_img_dir = os.path.join(train_dir, "images") # './DataSet_png512/train/images'
    train_mask_dir = os.path.join(train_dir, "masks") # './DataSet_png512/train/masks'

    # where the preprocessed patches are saved (folders are created if missing)
    train_img_crop_dir = os.path.join(hp.train, "images_crop") # './DataPreprocess/train/images_crop'
    os.makedirs(train_img_crop_dir, exist_ok=True)
    train_mask_crop_dir = os.path.join(hp.train, "masks_crop") # './DataPreprocess/train/masks_crop'
    os.makedirs(train_mask_crop_dir, exist_ok=True)

    # collect all images, then print how many there are
    img_files = glob.glob(os.path.join(train_img_dir, '**', '*.png'), recursive=True)
    mask_files = glob.glob(os.path.join(train_mask_dir, '**', '*.png'), recursive=True)
    print("Length of image :", len(img_files))
    print("Length of mask :", len(mask_files))

The start_points() function used above returns [0, 224, 288] for both X_points and Y_points: with IMAGE_SIZE = 512, CROP_SIZE = 224 and zero overlap, the stride is 224, and the last start point is clamped to 512 - 224 = 288 so the final crop ends exactly at the image border. These are the starting coordinates of each 224×224 crop; the actual cropping is done by the crop_image_mask() function below.
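
For reference, an implementation of start_points() consistent with this behavior looks like the following (a sketch; the repo's version may differ in details):

def start_points(size, split_size, overlap=0):
    # evenly strided crop origins; clamp the last one to size - split_size
    points = [0]
    stride = int(split_size * (1 - overlap))
    counter = 1
    while True:
        pt = stride * counter
        if pt + split_size >= size:
            points.append(size - split_size)
            break
        points.append(pt)
        counter += 1
    return points

print(start_points(512, 224, 0))  # [0, 224, 288]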

def crop_image_mask(image_dir, mask_dir, mask_path, X_points, Y_points, split_height=224, split_width=224):
    img_id = os.path.basename(mask_path).split(".")[0]
    mask = load_image(mask_path)
    img = load_image(mask_path.replace("masks", "images"))

    count = 0
    num_skipped = 1
    for i in Y_points:
        for j in X_points:
        
            # img[0:224, 0:224],   img[0:224, 224:448],   img[0:224, 288:512]
            # img[224:448, 0:224], img[224:448, 224:448], img[224:448, 288:512]
            # img[288:512, 0:224], img[288:512, 224:448], img[288:512, 288:512]
            new_image = img[i:i + split_height, j:j + split_width] 
            new_mask = mask[i:i + split_height, j:j + split_width]
            new_mask[new_mask > 100] = 255
            new_mask[new_mask <= 100] = 0

            # If the white/black pixel ratio is below 0.01, skip the patch entirely.
            # This filter is unsuitable for small-object segmentation
            # (e.g., fundus exudate segmentation).
            if np.any(new_mask):
                # assumes both 0 and 255 occur; an all-white patch would fail to unpack here
                num_black_pixels, num_white_pixels = np.unique(new_mask, return_counts=True)[1]

                if num_white_pixels / num_black_pixels < 0.01:
                    num_skipped += 1
                    continue

            mask_ = Image.fromarray(new_mask.astype(np.uint8))
            mask_.save("{}/{}_{}.jpg".format(mask_dir, img_id, count), "JPEG")
            im = Image.fromarray(new_image.astype(np.uint8))
            im.save("{}/{}_{}.jpg".format(image_dir, img_id, count), "JPEG")
            count = count + 1
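
The patches are produced by looping over every mask file and calling crop_image_mask() on it; a sketch of the driver loop (the exact loop in preprocess.py may differ slightly):

from tqdm import tqdm

for mask_path in tqdm(mask_files, desc="cropping train set"):
    crop_image_mask(train_img_crop_dir, train_mask_crop_dir, mask_path,
                    X_points, Y_points, hp.CROP_SIZE, hp.CROP_SIZE)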

At this point preprocessing is done: the training and test sets have each been cropped into 224×224 patches and stored in new folders, and train.py reads its data from these folders.

2. train.py

2.1. Parameter declarations

python train.py --name "default" --config "configs/default.yaml"

--name: the folder name used both for saved weights and for saved event files
--config: configuration file, with the following contents:

train: "./DataPreprocess/train" # 训练数据文件夹路径
valid: "./DataPreprocess/test"  # 验证数据文件夹路径
log: "logs"                     # tensorboard的events存储路径: ./logs
logging_step: 100
validation_interval: 20 # Save and valid have same interval
checkpoints: "checkpoints"

batch_size: 4
lr: 0.001
RESNET_PLUS_PLUS: True  # 使用ResUNet++模型;若该值为False则使用ResUNet模型
IMAGE_SIZE: 512         # 1500
CROP_SIZE: 224          # 224

Once the arguments are parsed, execution jumps to the main function.

2.2. The main function (excluding the training loop)

2.2.1. Arguments

main(hp, num_epochs=args.epochs, resume=args.resume, name=args.name)

hp: the parameters from configs/default.yaml
num_epochs: defaults to 75
resume: defaults to the empty string ''
name: the string 'default'

def main(hp, num_epochs, resume, name):
    checkpoint_dir = "{}/{}".format(hp.checkpoints, name)  # 'checkpoints/default', where weights are saved
    writer = MyWriter("{}/{}".format(hp.log, name))        # logdir: 'logs/default'
    model = ResUnetPlusPlus(3).cuda()
    criterion = metrics.BCEDiceLoss()  # combined binary cross-entropy and Dice loss
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)  # Adam optimizer
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
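
metrics.BCEDiceLoss combines binary cross-entropy with a soft Dice loss. A minimal sketch of such a combined loss, assuming the model output is already a sigmoid probability (it is; see the Sigmoid at the end of res_unet_plus.py below). The repo's weighting and smoothing constants may differ:

import torch
import torch.nn as nn

class BCEDiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.bce = nn.BCELoss()  # inputs are probabilities, not logits
        self.smooth = smooth

    def forward(self, pred, target):
        bce = self.bce(pred, target)
        pred_flat = pred.contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + self.smooth) / (
            pred_flat.sum() + target_flat.sum() + self.smooth)
        return bce + (1.0 - dice)  # lower is better for both terms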

2.2.2. Loading the data

mass_dataset_train = dataloader.ImageDataset(      # no False here: process the training set
        hp, transform=transforms.Compose([dataloader.ToTensorTarget()]))
mass_dataset_val = dataloader.ImageDataset(         # False here: process the validation set
        hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()]))

This instantiates the dataloader.ImageDataset class; note that it reads the preprocessed patches, i.e. the DataPreprocess folders.

class ImageDataset(Dataset):
This class reads an image and its mask into a sample; if self.transform is set, self.transform is applied to the sample.
The sample is then returned.
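
A minimal sketch of what such a Dataset can look like, assuming the patches are read from the images_crop / masks_crop folders created by preprocess.py and returned under the sat_img / map_img keys used later in train.py (the exact file handling in the repo may differ):

import glob
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, hp, train=True, transform=None):
        root = hp.train if train else hp.valid
        self.img_paths = sorted(glob.glob(os.path.join(root, "images_crop", "*.jpg")))
        self.mask_paths = sorted(glob.glob(os.path.join(root, "masks_crop", "*.jpg")))
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.img_paths[idx]))
        mask = np.array(Image.open(self.mask_paths[idx]))
        sample = {"sat_img": image, "map_img": mask}
        if self.transform:
            sample = self.transform(sample)  # e.g. ToTensorTarget
        return sample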

2.2.3. Creating the loaders

    train_dataloader = DataLoader(
        mass_dataset_train, batch_size=hp.batch_size, num_workers=2, shuffle=True)
        
    val_dataloader = DataLoader(
        mass_dataset_val, batch_size=1, num_workers=2, shuffle=False)

2.3. Training stage

    step = 0
    for epoch in range(start_epoch, num_epochs):
        lr_scheduler.step()   # update the learning rate (note: recent PyTorch expects this after the epoch's optimizer steps)

        # trackers for accuracy and loss; update() is called on them below
        train_acc = metrics.MetricTracker()
        train_loss = metrics.MetricTracker()
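
MetricTracker itself is just a running-average helper; a minimal sketch consistent with how update() and .avg are used here:

class MetricTracker:
    """Keeps a running average of a metric, weighted by batch size."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count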

Load the data and train the model:

        loader = tqdm(train_dataloader, desc="training")
        for idx, data in enumerate(loader):

            # fetch the input images and masks
            inputs = data["sat_img"].cuda()
            labels = data["map_img"].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # the BCE + Dice loss declared earlier

            # backward pass
            loss.backward()
            optimizer.step()

            # update the accuracy and loss trackers
            train_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
            train_loss.update(loss.data.item(), outputs.size(0))

Next, the training stage is visualized with TensorBoard:

            # TensorBoard logging; hp.logging_step = 100
            if step % hp.logging_step == 0:   # log every 100 steps
                writer.log_training(train_loss.avg, train_acc.avg, step)

                # refresh the tqdm progress-bar description every 100 steps
                loader.set_description(
                    "Training Loss: {:.4f} Acc: {:.4f}".format(
                        train_loss.avg, train_acc.avg))

2.4. Validation stage

The validation() function is the core of this part:

            # hp.validation_interval = 20
            if step % hp.validation_interval == 0:
                
                # run the validation stage
                valid_metrics = validation(
                    val_dataloader, model, criterion, writer, step)

                # checkpoint path: 'checkpoints/default/default_checkpoint_xxxx.pt'
                save_path = os.path.join(
                    checkpoint_dir, "%s_checkpoint_%04d.pt" % (name, step))
                    
                # track the best (minimum) validation loss, saved below
                best_loss = min(valid_metrics["valid_loss"], best_loss)

                # save a checkpoint to save_path
                torch.save(
                    {
                        "step": step,
                        "epoch": epoch,
                        "arch": "ResUnet++",
                        "state_dict": model.state_dict(),
                        "best_loss": best_loss,
                        "optimizer": optimizer.state_dict(),
                    },
                    save_path,
                )
                print("Saved checkpoint to: %s" % save_path)

            step += 1

The implementation of validation():

def validation(valid_loader, model, criterion, logger, step):

    # same trackers as in training
    valid_acc = metrics.MetricTracker()
    valid_loss = metrics.MetricTracker()

    # switch to evaluation mode
    model.eval()

    # Iterate over data.
    for idx, data in enumerate(tqdm(valid_loader, desc="validation")):

        # get the inputs and wrap in Variable
        inputs = data["sat_img"].cuda()
        labels = data["map_img"].cuda()

        # forward pass (ideally wrapped in torch.no_grad() to save memory)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # update the accuracy and loss trackers
        valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
        valid_loss.update(loss.data.item(), outputs.size(0))
        
        if idx == 0:
            logger.log_images(inputs.cpu(), labels.cpu(), outputs.cpu(), step)
            
    # write the validation acc and loss to TensorBoard
    logger.log_validation(valid_loss.avg, valid_acc.avg, step)

    print("Validation Loss: {:.4f} Acc: {:.4f}".format(valid_loss.avg, valid_acc.avg))

    # switch back to training mode (explained below)
    model.train()
    return {"valid_loss": valid_loss.avg, "valid_acc": valid_acc.avg}

What does model.train() on the second-to-last line do?
Calling model.train() after validation switches the model back to training mode.
In deep learning, some layers (e.g., Dropout and Batch Normalization) behave differently in training and evaluation modes. In training mode these layers perform operations that improve generalization and stability; in evaluation mode their behavior changes to be deterministic and reproducible.
In short, model.train() ensures the model returns to training mode once validation ends, keeping training and evaluation behavior consistent.
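
A quick illustration of the mode switch using Dropout (outputs are random in training mode, so the first print is just one possible result):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training mode: zero elements at random, scale survivors by 1/(1-p)
print(drop(x))  # e.g. tensor([[2., 0., 2., 2., 0., 0., 2., 0.]])

drop.eval()     # evaluation mode: identity mapping, deterministic
print(drop(x))  # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])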

3. Other related code

3.1. model.py

The ResUNet++ architecture:
(figure: ResUNet++ model diagram, omitted)

The concrete implementation follows:

3.1.1. res_unet_plus.py

import torch.nn as nn
import torch
from core.modules import (
    ResidualConv,
    ASPP,
    AttentionBlock,
    Upsample_,
    Squeeze_Excite_Block,
)


class ResUnetPlusPlus(nn.Module):
    def __init__(self, channel, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])

        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])

        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])

        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid())

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out
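
As a quick sanity check, feeding a dummy 224×224 patch through the model yields a single-channel probability map of the same spatial size (a sketch; assumes core.modules from the repo is importable):

import torch

model = ResUnetPlusPlus(channel=3)
x = torch.randn(1, 3, 224, 224)  # one RGB patch, as produced by preprocess.py
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 224, 224]): per-pixel foreground probability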

3.1.1.1. Squeeze and Excitation Units

This module takes the channel count of the previous layer as input, plus a configurable reduction parameter.

class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

What does this module do? The paper explains it as follows:

The squeeze-and-excitation block is stacked together with the residual block to improve generalization across different datasets and boost the network's performance.
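
In other words, the block learns one scale factor in (0, 1) per channel and reweights the feature map accordingly; a quick shape check (sketch):

import torch

se = Squeeze_Excite_Block(channel=64)
x = torch.randn(2, 64, 56, 56)
y = se(x)   # same shape as x; each channel rescaled by its learned weight
print(y.shape)  # torch.Size([2, 64, 56, 56])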
