基于Pytorch框架构建VGG-19模型

一、训练模型

1.导入资源包

import torch.utils.data: 导入了PyTorch的数据工具模块,这个模块提供了用于数据加载和处理的工具,如Dataset和DataLoader。
from torchvision import models: 导入了PyTorch的预训练模型模块,这个模块提供了多种预训练的型,如ResNet、VGG、AlexNet等,可以直接用于迁移学习或特征提取。

from sched import scheduler
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import  os
from torchvision import models

2.定义数据预处理

这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。,这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。

# 定义数据预处理
transform = {
'train': transforms.Compose([
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(size=256),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
}

3.读取数据

读取和准备图像数据集,以便用于训练和验证深度学习模型,这段代码设置了数据加载器,它们将在训练和验证过程中提供经过预处理的图像数据。这些数据加载器是PyTorch中用于批量加载数据并使其易于迭代的重要工具。

# 读取数据
dataset = './dataset'
train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'val')

batch_size = 32
num_classes = 2  # 修改为您的分类数

data = {
'train': datasets.ImageFolder(root=train_directory, transform=transform['train']),
'val': datasets.ImageFolder(root=valid_directory, transform=transform['val'])
}

train_loader = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(data['val'], batch_size=batch_size, shuffle=False, num_workers=8)

二、定义VGG19模型

1.定义自定义的 VGG19 模型

这段代码定义了一个自定义的VGG-19模型,并将其适用于一个新的二分类任务。然后,它设置了损失函数和优化器,并检查了是否有可用的GPU以决定在哪个设备上进行训练。

from torchvision.models import vgg19
# 定义自定义的 VGG-19 模型
class CustomVGG19(nn.Module):
def __init__(self):
super(CustomVGG19, self).__init__()
self.vgg19_model = vgg19(pretrained=True)
for param in self.vgg19_model.features.parameters():
param.requires_grad = False
num_features = self.vgg19_model.classifier[6].in_features
self.vgg19_model.classifier[6] = nn.Sequential(
nn.Linear(num_features, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 2)
)

def forward(self, x):
return self.vgg19_model(x)

# 创建 CustomVGG19 模型实例
vgg19_model = CustomVGG19()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg19_model.parameters(), lr=0.001, weight_decay=1e-4)

# 首先,检查是否有可用的 GPU
if torch.cuda.is_available():
# 定义 GPU 设备
device = torch.device('cuda')
print("CUDA is available! Using GPU for training.")
else:
# 如果没有可用的 GPU,则使用 CPU
device = torch.device('cpu')
print("CUDA is not available. Using CPU for training.")

# 将模型移动到 GPU
vgg19_model.to(device)

# 如果你有优化器和其他需要移动到 GPU 的参数,例如梯度
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg19_model.parameters(), lr=0.001, weight_decay=1e-4)

运行结果:

在这里插入图片描述

四、验证模型

1. 定义验证过程

这个验证函数是评估分类模型性能的基本框架,您可以根据需要调整打印频率或其他参数。在实际使用中,您需要确保在调用这个函数之前已经定义了模型、数据加载器和损失函数。

# 定义验证过程
def val(model, device, test_loader, criterion):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
running_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()

print(f'Validation, Loss: {running_loss / len(test_loader)}, Accuracy: {100 * correct / total}%')

2.用于训练模型并应用学习率调整策略的循环

总的来说,这段代码将训练模型10个周期,并在每个周期结束后进行验证,同时使用学习率调度器来调整学习率。这种学习率调整策略可以帮助模型在训练过程中更好地收敛。在实际应用中,您可能需要根据您的具体任务和数据集调整周期数和学习率调度器的参数。

# 定义学习率调整策略
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# 训练模型
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
train(vgg19_model, device, train_loader, optimizer, epoch)
val(vgg19_model, device, test_loader, criterion)
scheduler.step()  # 调整学习率

运行结果:

在这里插入图片描述

3.保存模型的状态字典

保存模型的状态字典是一个重要的步骤,因为它允许您在训练完成后保存模型的结果,以便将来使用或进行分析。在实际应用中,您可能需要根据您的具体需求选择不同的文件名和保存路径。

# 保存模型的状态字典
torch.save(vgg19_model.state_dict(), 'vgg19_model_weights.pth')

三、训练模型

1. 定义训练函数

这个训练函数是训练分类模型的基本框架,您可以根据需要调整打印频率或其他参数。在实际使用中,您需要确保在调用这个函数之前已经定义了模型、数据加载器、优化器和损失函数。

def train(model, device, train_loader, optimizer, epoch):
model.train()
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()

if batch_idx % 10 == 0:  # 每10个批次打印一次
print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

print(f'Epoch {epoch}, Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%')

五、创建 CustomVGG19 模型实例

1. 导入资源包

from torch.autograd import Variable: 导入了PyTorch的自动求导变量模块。在早期的PyTorch版本中,Variable是用于封装张量并记录计算图的工具。但在最新的PyTorch版本中,Variable已经不再推荐使用,因为PyTorch自动将普通张量转换为Variable。如果您使用的是最新版本的PyTorch,这行代码可能是不必要的。

import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
from torch.autograd import Variable

2.定义数据预处理

这些步骤是图像识别任务中的常见操作,用于准备数据和选择计算设备。在实际应用中,您可能需要根据您的具体任务和数据集调整这些参数。

# 定义数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 定义类别
classes = ['cat', 'dog']  # 替换为您的实际类别名称

# 检查是否有可用的 GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3.定义自定义的 VGG-19 模型
这个自定义的VGG-19模型是用于二分类任务的,它保留了VGG-19的卷积层和池化层不变,只修改了最后的全连接层以适应新的类别数。在实际应用中,您可能需要根据您的具体任务调整最后的全连接层,以匹配您的类别数。

# 定义自定义的 VGG-19 模型
class CustomVGG19(nn.Module):
def __init__(self):
super(CustomVGG19, self).__init__()
self.vgg19_model = models.vgg19(pretrained=True)
for param in self.vgg19_model.features.parameters():
param.requires_grad = False
num_features = self.vgg19_model.classifier[6].in_features
self.vgg19_model.classifier[6] = nn.Sequential(
nn.Linear(num_features, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, len(classes))
)

def forward(self, x):
return self.vgg19_model(x)

4.创建 CustomVGG19 模型实例

用于创建自定义的VGG-19模型实例,加载模型的权重,并将模型移动到指定的设备上。

# 创建 CustomVGG19 模型实例
model = CustomVGG19()
# 加载权重
model.load_state_dict(torch.load("vgg19_model_weights.pth"))
model.to(DEVICE)
model.eval()

5.定义预测函数

这个预测函数是使用模型进行图像分类的基本框架,我们可以根据需要调整打印频率或其他参数。在实际应用中,您需要确保在调用这个函数之前已经定义了模型、数据预处理和类别列表。

# 定义预测函数
def predict_image(image_path):
# 打开图片
image = Image.open(image_path)
# 应用预处理
image = transform(image).unsqueeze(0)  # 添加batch维度
# 转换为Variable(如果模型需要)
image = Variable(image).to(DEVICE)
# 获取模型预测
output = model(image)
_, prediction = torch.max(output.data, 1)
return classes[prediction.item()]

# 上传的图片路径
uploaded_image_path = '1111.jpg'
# 进行预测
predicted_class = predict_image(uploaded_image_path)

print(f"The uploaded image is predicted as: {predicted_class}")

运行结果:
在这里插入图片描述

6.定义了一个可视化函数

当我们运行这个脚本时,它会打开一个Matplotlib窗口,显示上传的图片,并在图片上添加一个标题显示预测的类别。同时,脚本会打印出预测结果。这个脚本是用于可视化图像分类结果的基本框架,您可以根据需要调整打印频率或其他参数。在实际应用中,您需要确保在调用这个函数之前已经定义了模型、数据预处理和类别列表。

import matplotlib.pyplot as plt

# 定义可视化函数
def visualize_prediction(image_path, predicted_class):
# 打开图片
image = Image.open(image_path)
# 显示图片
plt.imshow(image)
plt.axis('off')
plt.title(f'Predicted: {predicted_class}')
plt.show()
# 上传的图片路径
uploaded_image_path = '44.jpg'
# 进行预测
predicted_class = predict_image(uploaded_image_path)

# 可视化预测结果
visualize_prediction(uploaded_image_path, predicted_class)

print(f"The uploaded image is predicted as: {predicted_class}")

运行结果:

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值