# 导入所需的库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# 定义PSPNet类
class PSPNet(nn.Module):
def __init__(self, num_classes):
super(PSPNet, self).__init__()
self.num_classes = num_classes # 设置类别数
self.backbone = models.resnet50(pretrained=True) # 使用预训练的ResNet50作为特征提取器
self.layer5a = _PSPModule(2048, 512) # 定义金字塔池化模块
self.layer5c = _PSPModule(2048, 512) # 定义金字塔池化模块
self.layer5e = _PSPModule(2048, 512) # 定义金字塔池化模块
self.layer5b = _PSPModule(2048, 512) # 定义金字塔池化模块
self.fc = nn.Linear(512 * 4, num_classes) # 定义全连接层
def forward(self, x):
x = self.backbone(x) # 使用ResNet50提取特征
x1 = self.layer5a(x) # 通过金字塔池化模块
x2 = self.layer5b(x) # 通过金字塔池化模块
x3 = self.layer5c(x) # 通过金字塔池化模块
x4 = self.layer5e(x) # 通过金字塔池化模块
x = torch.cat((x1, x2, x3, x4), dim=1) # 沿通道维度拼接特征图
x = self.fc(x) # 通过全连接层输出预测结果
return x
# 定义金字塔池化模块
class _PSPModule(nn.Module):
def __init__(self, in_channels, out_channels, pool_size):
super(_PSPModule, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(pool_size) # 自适应平均池化层
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) # 1x1卷积层
self.bn = nn.BatchNorm2d(out_channels) # 批标准化层
self.relu = nn.ReLU(inplace=True) # ReLU激活函数
def forward(self, x):
x = self.pool(x) # 应用自适应平均池化
x = self.conv(x) # 应用1x1卷积
x = self.bn(x) # 应用批标准化
x = self.relu(x) # 应用ReLU激活函数
return x
# 训练PSPNet
# 加载数据集,设置超参数等
# ...
# 初始化模型
num_classes = 21 # 根据具体任务调整类别数
model = PSPNet(num_classes)
model = model.to(device) # 将模型移到设备(GPU或CPU)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 使用交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) # 使用Adam优化器
# 训练模型
for epoch in range(num_epochs): # 对每个epoch进行迭代
for i, (images, labels) in enumerate(train_loader): # 从数据加载器中取出一批数据
images = images.to(device) # 将图像移到设备
labels = labels.to(device) # 将标签移到设备
# 前向传播
outputs = model(images) # 通过模型计算输出
# 计算损失
loss = criterion(outputs, labels) # 计算损失
# 反向传播
optimizer.zero_grad() # 清空梯度缓存
loss.backward() # 计算梯度
optimizer.step() # 更新模型参数
# 输出训练信息
if (i + 1) % print_freq == 0:
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item()}')
# 保存模型
torch.save(model.state_dict(), 'pspnet_model.pth') # 保存模型权重
# 模型评估
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 关闭梯度计算
correct = 0
total = 0
for images, labels in val_loader: # 对验证集中的每一批数据进行迭代
images = images.to(device) # 将图像移到设备
labels = labels.to(device) # 将标签移到设备
outputs = model(images) # 通过模型计算输出
_, predicted = torch.max(outputs.data, 1) # 获得预测结果
total += labels.size(0) # 更新样本总数
correct += (predicted == labels).sum().item() # 更新正确预测数
print(f'Validation accuracy of the model: {100 * correct / total}%') # 输出验证准确率
# 使用模型进行预测
def predict(model, img):
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 关闭梯度计算
img = img.to(device) # 将图像移到设备
output = model(img) # 通过模型计算输出
_, predicted = torch.max(output.data, 1) # 获得预测结果
return predicted.cpu().numpy() # 将预测结果转换为numpy数组并返回
# 加载测试数据
test_image, test_label = test_dataset[0] # 从测试数据集中获取一张图片
plt.imshow(test_image.permute(1, 2, 0).numpy()) # 显示原始图像
plt.show()
# 使用模型
prediction = predict(model, test_image.unsqueeze(0)) # 对测试图像进行预测
# 可视化预测结果
prediction_mask = np.squeeze(prediction) # 去掉多余的维度
plt.imshow(prediction_mask) # 显示预测结果
plt.show()
PSPNet的pytorch实现
于 2023-03-17 21:01:52 首次发布