import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from PIL import Image
import numpy as np
# Atrous (dilated) convolution block
class AtrousConvBlock(nn.Module):
    """3x3 dilated convolution followed by BatchNorm and ReLU.

    Using ``padding == dilation_rate`` keeps the spatial size unchanged
    while the dilation enlarges the receptive field at no parameter cost.
    """

    def __init__(self, in_ch, out_ch, dilation_rate=2):
        super(AtrousConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=3,
            stride=1,
            padding=dilation_rate,
            dilation=dilation_rate,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU and return the activation."""
        return self.relu(self.bn(self.conv(x)))
# Attention gate used on the U-Net skip connections
class AttentionBlock(nn.Module):
    """Additive attention gate (Attention U-Net style).

    The gating signal ``g`` (from the decoder) and the skip feature ``x``
    (from the encoder) are each projected to ``F_int`` channels, summed,
    passed through ReLU, and squeezed to a single-channel sigmoid map that
    re-weights ``x`` spatially.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from (g, x)."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        # Single-channel attention coefficients in (0, 1), broadcast over x.
        attn = self.psi(self.relu(gate_proj + skip_proj))
        return x * attn
class AttentionUNetWithAtrous(nn.Module):
    """U-Net with attention-gated skip connections and an atrous stem.

    Encoder: one atrous-convolution stage (64 ch) followed by four plain
    double-conv stages (128/256/512/1024 ch), each preceded by 2x2 max-pool.
    Decoder: transpose-conv upsampling, an attention gate on each skip
    connection, channel concatenation, then a double-conv block.
    Output: raw logits with ``output_ch`` channels (no final activation;
    pair with BCEWithLogitsLoss downstream).
    """

    def __init__(self, img_ch=3, output_ch=1, dilation_rate=2):
        super(AttentionUNetWithAtrous, self).__init__()
        # Encoder stage 1: atrous conv enlarges the receptive field without
        # shrinking spatial resolution.
        self.AtrousConv1 = AtrousConvBlock(img_ch, 64, dilation_rate)
        # Shared 2x2 max-pool reused between every encoder stage.
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Encoder stages 2-5: double (conv -> BN -> ReLU), doubling channels.
        self.Conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.Conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.Conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.Conv5 = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 5: upsample 1024 -> 512, gate the e4 skip, then fuse
        # the concatenated 1024 channels back down to 512.
        self.Up5 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att5 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 4 (same pattern, 512 -> 256).
        self.Up4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 3 (256 -> 128).
        self.Up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        # Decoder stage 2 (128 -> 64).
        self.Up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0, bias=True)
        self.Att2 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # 1x1 projection to the requested number of output channels (logits).
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the encoder-decoder and return per-pixel logits.

        NOTE(review): pooling four times assumes input H and W are divisible
        by 16, otherwise the skip concatenations will shape-mismatch —
        confirm the data pipeline always feeds such sizes.
        """
        # ---- Encoder ----
        e1 = self.AtrousConv1(x)  # atrous stem, full resolution
        e2 = self.Maxpool(e1)
        e2 = self.Conv2(e2)
        e3 = self.Maxpool(e2)
        e3 = self.Conv3(e3)
        e4 = self.Maxpool(e3)
        e4 = self.Conv4(e4)
        e5 = self.Maxpool(e4)
        e5 = self.Conv5(e5)
        # ---- Decoder: upsample, attention-gate the skip, concat, fuse ----
        d5 = self.Up5(e5)
        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        # Final 1x1 conv emits raw logits (no sigmoid here).
        out = self.Conv_1x1(d2)
        return out
def compute_miou(pred, target, smooth=1e-10):
    """Return the mean IoU over the batch for binary masks.

    Args:
        pred: predicted binary mask, shape ``(B, ...)`` (e.g. ``(B, 1, H, W)``).
        target: ground-truth binary mask, same shape as ``pred``.
        smooth: small constant to avoid division by zero on empty masks.

    Returns:
        float: IoU computed per sample and averaged over the batch.
    """
    # Cast to bool so the bitwise &/| are well-defined for int inputs too.
    pred = pred.bool()
    target = target.bool()
    # BUG FIX: the original reduced only dims (1, 2); for 4-D (B, 1, H, W)
    # input that left the width dimension unreduced, averaging per-column
    # IoUs instead of per-sample IoUs. Reduce every non-batch dimension.
    dims = tuple(range(1, pred.dim()))
    intersection = (pred & target).float().sum(dims)
    union = (pred | target).float().sum(dims)
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean().item()
def compute_pixel_accuracy(pred, target):
    """Return the fraction of pixels where ``pred`` equals ``target``."""
    matches = torch.eq(pred, target).float()
    accuracy = matches.sum() / target.numel()
    return accuracy.item()
def dice(pred, target):
    """Return the soft Dice coefficient between two tensors as a 0-d tensor.

    Both inputs are flattened; the 1e-5 smoothing term keeps the result
    defined when both tensors are all zeros.
    """
    eps = 1e-5
    p = pred.contiguous().view(-1)
    t = target.contiguous().view(-1)
    overlap = torch.sum(p * t)
    return (2. * overlap + eps) / (torch.sum(p) + torch.sum(t) + eps)
# Dataset class definition
class CustomDataset(Dataset):
    """Paired image/mask dataset read from two parallel directories.

    Images are loaded as RGB, labels as single-channel ("L") masks.  Pairs
    are formed by sorted filename order, so both directories are assumed to
    hold the same number of files in corresponding sort positions — TODO
    confirm against the actual data layout.
    """

    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        # Sorted so image i pairs with label i.
        self.image_files = sorted(os.listdir(img_dir))
        self.label_files = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])
        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("L")
        if self.transform:
            # Re-seed torch's RNG with the same value before each call so
            # the random transforms make identical choices for the image and
            # its label (keeps the augmentation spatially synchronized).
            # NOTE(review): this only synchronizes transforms that draw from
            # torch's RNG — confirm the installed torchvision version does.
            seed = np.random.randint(2147483647)
            torch.manual_seed(seed)
            image = self.transform(image)
            torch.manual_seed(seed)
            label = self.transform(label)
        return image, label
# Data-augmentation pipeline (applied to both image and mask inside
# CustomDataset via the shared-seed trick).
data_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    # ToTensor scales pixel values into [0, 1].
    transforms.ToTensor(),
    # NOTE(review): RandomRotation / RandomResizedCrop interpolate, which
    # can leave non-binary values in the label mask — confirm masks are
    # re-binarized downstream if exact 0/1 labels are required.
])
def load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.8):
    """Build train/validation DataLoaders from ``data_dir``.

    ``data_dir`` must contain ``img`` and ``label`` subdirectories with
    paired files; ``scala`` is the fraction of samples assigned to the
    training split.
    """
    dataset = CustomDataset(
        os.path.join(data_dir, 'img'),
        os.path.join(data_dir, 'label'),
        transform=data_transforms,
    )
    # Random train/validation split.
    n_train = int(scala * len(dataset))
    train_dataset, val_dataset = random_split(
        dataset, [n_train, len(dataset) - n_train]
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    return train_loader, val_loader
train_loader, val_loader = load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.7)

# --- Model, loss, optimizer ---------------------------------------------
# Resolve the device once instead of re-checking CUDA availability for
# every single tensor transfer as the original code did.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentionUNetWithAtrous(img_ch=3, output_ch=1).to(device)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits from the model
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)  # L2 regularization

# --- Early-stopping configuration ---------------------------------------
patience = 3
best_loss = float('inf')
counter = 0
num_epochs = 20

for epoch in range(num_epochs):
    # Training phase.
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample.
        train_loss += loss.item() * images.size(0)
    train_loss = train_loss / len(train_loader.dataset)

    # Validation phase.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
    val_loss = val_loss / len(val_loader.dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping: keep the best checkpoint; stop after `patience`
    # consecutive epochs without improvement.
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break
# Load the best checkpoint and evaluate on the validation set.
eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
# Collect per-batch MIoU, pixel accuracy and Dice scores.
miou_list = []
pa_list = []
dice_list = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(eval_device)
        labels = labels.to(eval_device)
        outputs = model(images)
        # BUG FIX: the model emits raw logits, so apply sigmoid before
        # thresholding (the original thresholded logits at 0.5, which
        # corresponds to a probability of ~0.62).
        probs = torch.sigmoid(outputs)
        predicted = (probs > 0.5).int()
        miou_list.append(compute_miou(predicted, labels.int()))
        pa_list.append(compute_pixel_accuracy(predicted, labels.int()))
        # BUG FIX: the original wrote `dice = dice(outputs, labels)`,
        # shadowing the dice() function and crashing on the second batch;
        # also score probabilities (not logits) and unwrap to a float so
        # np.mean below works on plain numbers.
        dice_list.append(dice(probs, labels).item())
print(f'Mean MIoU: {np.mean(miou_list):.4f}')
print(f'Mean Dice: {np.mean(dice_list):.4f}')
print(f'Mean Pixel Accuracy: {np.mean(pa_list):.4f}')
这段代码实现了一个带有空洞卷积(Atrous Convolution)的注意力U型网络(Attention U-Net),用于图像分割任务。下面详细介绍其功能:
1. 导入必要的库
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets
from PIL import Image
import numpy as np
导入了操作系统操作、PyTorch深度学习框架、数据处理和图像操作等相关的库。
2. 定义空洞卷积块(Atrous Conv Block)
class AtrousConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, dilation_rate=2):
super(AtrousConvBlock, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=dilation_rate, dilation=dilation_rate, bias=True)
self.bn = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
定义了一个空洞卷积块,包含一个空洞卷积层、一个批归一化层和一个ReLU激活函数。空洞卷积可以在不增加参数和计算量的情况下扩大感受野。
3. 定义注意力块(Attention Block)
class AttentionBlock(nn.Module):
def __init__(self, F_g, F_l, F_int):
super(AttentionBlock, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
定义了一个注意力块,用于在U型网络的跳跃连接中引入注意力机制,帮助网络更好地聚焦于重要特征。
4. 定义带有空洞卷积的注意力U型网络(Attention UNet With Atrous)
class AttentionUNetWithAtrous(nn.Module):
def __init__(self, img_ch=3, output_ch=1, dilation_rate=2):
super(AttentionUNetWithAtrous, self).__init__()
# 省略部分代码
def forward(self, x):
# 省略部分代码
return out
定义了整个网络结构,结合了空洞卷积和注意力机制。网络由下采样、上采样和跳跃连接组成,最终输出分割结果。
5. 定义评估指标计算函数
def compute_miou(pred, target, smooth=1e-10):
# 省略部分代码
return miou.item()
def compute_pixel_accuracy(pred, target):
# 省略部分代码
return pa.item()
def dice(pred, target):
# 省略部分代码
return dice_score
定义了计算平均交并比(MIoU)、像素准确率(PA)和Dice系数的函数,用于评估模型的性能。
6. 定义自定义数据集类
class CustomDataset(Dataset):
def __init__(self, img_dir, label_dir, transform=None):
# 省略部分代码
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
# 省略部分代码
return image, label
定义了一个自定义数据集类,用于加载图像和对应的标签数据,并支持数据增强。
7. 数据加载和预处理
data_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(15),
transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
transforms.ToTensor(),
])
def load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.8):
# 省略部分代码
return train_loader,val_loader
train_loader, val_loader = load_train(data_dir = r'E:\Pycharmproject\AUNet\data', scala=0.7)
定义了数据增强操作,包括随机水平翻转、垂直翻转、旋转和随机裁剪等。然后定义了一个函数来加载数据集,并将其划分为训练集和验证集。
8. 模型训练和评估
model = AttentionUNetWithAtrous(img_ch=3, output_ch=1)
model = model.cuda() if torch.cuda.is_available() else model
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# 早停法配置
patience = 3
best_loss = float('inf')
counter = 0
num_epochs = 20
# 训练循环
for epoch in range(num_epochs):
# 省略部分代码
# 加载最佳模型进行评估
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
# 计算评估指标
miou_list = []
pa_list = []
dice_list = []
with torch.no_grad():
# 省略部分代码
print(f'Mean MIoU: {np.mean(miou_list):.4f}')
print(f'Mean Dice: {np.mean(dice_list):.4f}')
print(f'Mean Pixel Accuracy: {np.mean(pa_list):.4f}')
初始化模型、损失函数和优化器,使用早停法来防止过拟合。在训练过程中,模型在每个epoch上进行训练和验证,并记录损失值。训练结束后,加载最佳模型并在验证集上计算MIoU、PA和Dice系数等评估指标。
总结
这段代码实现了一个完整的图像分割流程,包括数据加载、数据增强、模型定义、训练和评估。通过结合空洞卷积和注意力机制,提高了U型网络在图像分割任务中的性能。
6857

被折叠的评论
为什么被折叠?



