python 代码实现了一个带有空洞卷积(Atrous Convolution)的注意力U型网络(Attention U-Net),用于图像分割任务

部署运行你感兴趣的模型镜像
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from PIL import Image
import numpy as np


# 定义Atrous卷积块
class AtrousConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dilation_rate=2):
        super(AtrousConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=dilation_rate, dilation=dilation_rate, bias=True)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# 定义Attention U-Net模型
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class AttentionUNetWithAtrous(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, dilation_rate=2):
        super(AttentionUNetWithAtrous, self).__init__()

        # 使用Atrous卷积进行降维和扩展感受野
        self.AtrousConv1 = AtrousConvBlock(img_ch, 64, dilation_rate)
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.Conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.Conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.Conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.Conv5 = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )

        self.Up5 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att5 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.Up4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.Up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.Up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att2 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        e1 = self.AtrousConv1(x)  # 使用空洞卷积进行降维
        e2 = self.Maxpool(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)
        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv_1x1(d2)

        return out


def compute_miou(pred, target, smooth=1e-10):

    intersection = (pred & target).float().sum((1, 2))  # 计算交集
    union = (pred | target).float().sum((1, 2))  # 计算并集

    iou = (intersection + smooth) / (union + smooth)  # 计算IoU
    miou = iou.mean()  # 计算MIoU
    return miou.item()

def compute_pixel_accuracy(pred, target):

    correct = (pred == target).float().sum()
    total = target.numel()

    pa = correct / total
    return pa.item()
def dice(pred, target):
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    intersection = torch.sum(pred * target)
    pred_sum = torch.sum(pred)
    target_sum = torch.sum(target)
    dice_score = (2. * intersection + 1e-5) / (pred_sum + target_sum + 1e-5)
    return dice_score

# 数据集类定义
class CustomDataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(img_dir))
        self.label_files = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])
        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        if self.transform:
            # 数据增强时,将图像和标签同步变换
            seed = np.random.randint(2147483647)
            torch.manual_seed(seed)
            image = self.transform(image)
            torch.manual_seed(seed)
            label = self.transform(label)

        return image, label


# 定义数据增强操作
data_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.ToTensor(),
])



def load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.8):
    # 加载数据集

    img_dir = os.path.join(data_dir, 'img')
    label_dir = os.path.join(data_dir, 'label')
    dataset = CustomDataset(img_dir, label_dir, transform=data_transforms)
    # 划分训练集和验证集
    train_size = int(scala * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    return train_loader,val_loader



train_loader, val_loader = load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.7)

# 模型定义与训练
model = AttentionUNetWithAtrous(img_ch=3, output_ch=1)
model = model.cuda() if torch.cuda.is_available() else model

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)  # L2正则化

# 早停法配置
patience = 3
best_loss = float('inf')
counter = 0

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.cuda() if torch.cuda.is_available() else images
        labels = labels.cuda() if torch.cuda.is_available() else labels

        optimizer.zero_grad()

        # 前向传播
        outputs = model(images)

        # 计算损失
        loss = criterion(outputs, labels)
        train_loss += loss.item() * images.size(0)

        # 反向传播
        loss.backward()
        optimizer.step()

    train_loss = train_loss / len(train_loader.dataset)

    # 验证阶段
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.cuda() if torch.cuda.is_available() else images
            labels = labels.cuda() if torch.cuda.is_available() else labels

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

    val_loss = val_loss / len(val_loader.dataset)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # 早停法
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')  # 保存最佳模型
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break

# 加载最佳模型进行评估

model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# 在验证集上计算MIoU和PA
miou_list = []
pa_list = []
dice_list = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.cuda() if torch.cuda.is_available() else images
        labels = labels.cuda() if torch.cuda.is_available() else labels

        outputs = model(images)
        predicted = (outputs > 0.5).int()

        miou = compute_miou(predicted, labels.int())
        pa = compute_pixel_accuracy(predicted, labels.int())
        dice = dice(outputs, labels)
        miou_list.append(miou)
        pa_list.append(pa)
        dice_list.append(dice)
print(f'Mean MIoU: {np.mean(miou_list):.4f}')
print(f'Mean Dice: {np.mean(dice_list):.4f}')
print(f'Mean Pixel Accuracy: {np.mean(pa_list):.4f}')

这段代码实现了一个带有空洞卷积(Atrous Convolution)的注意力U型网络(Attention U-Net),用于图像分割任务。下面详细介绍其功能:

1. 导入必要的库

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from PIL import Image
import numpy as np

导入了操作系统操作、PyTorch深度学习框架、数据处理和图像操作等相关的库。

2. 定义空洞卷积块(Atrous Conv Block)

class AtrousConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dilation_rate=2):
        super(AtrousConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=dilation_rate, dilation=dilation_rate, bias=True)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

定义了一个空洞卷积块,包含一个空洞卷积层、一个批归一化层和一个ReLU激活函数。空洞卷积可以在不增加参数和计算量的情况下扩大感受野。

3. 定义注意力块(Attention Block)

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

定义了一个注意力块,用于在U型网络的跳跃连接中引入注意力机制,帮助网络更好地聚焦于重要特征。

4. 定义带有空洞卷积的注意力U型网络(Attention UNet With Atrous)

class AttentionUNetWithAtrous(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, dilation_rate=2):
        super(AttentionUNetWithAtrous, self).__init__()
        # 省略部分代码
    def forward(self, x):
        # 省略部分代码
        return out

定义了整个网络结构,结合了空洞卷积和注意力机制。网络由下采样、上采样和跳跃连接组成,最终输出分割结果。

5. 定义评估指标计算函数

def compute_miou(pred, target, smooth=1e-10):
    # 省略部分代码
    return miou.item()

def compute_pixel_accuracy(pred, target):
    # 省略部分代码
    return pa.item()

def dice(pred, target):
    # 省略部分代码
    return dice_score

定义了计算平均交并比(MIoU)、像素准确率(PA)和Dice系数的函数,用于评估模型的性能。

6. 定义自定义数据集类

class CustomDataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        # 省略部分代码
    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, idx):
        # 省略部分代码
        return image, label

定义了一个自定义数据集类,用于加载图像和对应的标签数据,并支持数据增强。

7. 数据加载和预处理

data_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.ToTensor(),
])

def load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.8):
    # 省略部分代码
    return train_loader,val_loader

train_loader, val_loader = load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.7)

定义了数据增强操作,包括随机水平翻转、垂直翻转、旋转和随机裁剪等。然后定义了一个函数来加载数据集,并将其划分为训练集和验证集。

8. 模型训练和评估

model = AttentionUNetWithAtrous(img_ch=3, output_ch=1)
model = model.cuda() if torch.cuda.is_available() else model

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# 早停法配置
patience = 3
best_loss = float('inf')
counter = 0

num_epochs = 20

# 训练循环
for epoch in range(num_epochs):
    # 省略部分代码

# 加载最佳模型进行评估
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# 计算评估指标
miou_list = []
pa_list = []
dice_list = []
with torch.no_grad():
    # 省略部分代码
print(f'Mean MIoU: {np.mean(miou_list):.4f}')
print(f'Mean Dice: {np.mean(dice_list):.4f}')
print(f'Mean Pixel Accuracy: {np.mean(pa_list):.4f}')

初始化模型、损失函数和优化器,使用早停法来防止过拟合。在训练过程中,模型在每个epoch上进行训练和验证,并记录损失值。训练结束后,加载最佳模型并在验证集上计算MIoU、PA和Dice系数等评估指标。

总结

这段代码实现了一个完整的图像分割流程,包括数据加载、数据增强、模型定义、训练和评估。通过结合空洞卷积和注意力机制,提高了U型网络在图像分割任务中的性能。

您可能感兴趣的与本文相关的镜像

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值