pytorch模型导入、修改、保存、读取

pytorch对模型的操作

导入

这里以vgg11的classifier层为例
导入模型:

vgg11 = torchvision.models.vgg11(pretrained=False)
print(vgg11)

其中classifier的输出如下
在这里插入图片描述

修改

主要的修改方式:

# 在某层中添加层
vgg11.classifier.add_module('new_linear', nn.Linear(1000, 10))
# 修改某层
vgg11.classifier[6] = nn.Linear(4096, 10)
print(vgg11)

修改后:
在这里插入图片描述

保存与读取

# 模型的保存与读取
# 方式1,保存模型和参数
torch.save(vgg11, 'vgg11_method1.pth')
# 在读取时需要保证原模型已经引入
model = torch.load('vgg11_method1.pth')

# 方式2,只保存模型参数,一个字典形式(官方推荐)
torch.save(vgg11.state_dict(), 'vgg11_method2.pth')

vgg11_new = torchvision.models.vgg11(pretrained=False)
vgg11_new.load_state_dict(torch.load('vgg11_method2.pth'))
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
你可以按照以下步骤在PyTorch中实现CNN手写数字识别,包括使用CSV文件进行数据读取保存和加载模型: 1. 导入所需的库和模块: ```python import torch import torch.nn as nn import torch.optim as optim import pandas as pd from torch.utils.data import DataLoader, Dataset ``` 2. 创建一个自定义的数据集类,用于读取CSV文件中的数据: ```python class DigitDataset(Dataset): def __init__(self, csv_file): self.data = pd.read_csv(csv_file) def __len__(self): return len(self.data) def __getitem__(self, idx): image = self.data.iloc[idx, 1:].values.reshape(28, 28).astype('float32') / 255.0 label = self.data.iloc[idx, 0] return image, label ``` 3. 定义CNN模型: ```python class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) self.relu2 = nn.ReLU() self.pool2 = nn.MaxPool2d(kernel_size=2) self.fc = nn.Linear(7*7*32, 10) def forward(self, x): x = self.conv1(x) x = self.relu1(x) x = self.pool1(x) x = self.conv2(x) x = self.relu2(x) x = self.pool2(x) x = x.view(x.size(0), -1) x = self.fc(x) return x ``` 4. 定义训练函数和测试函数: ```python def train(model, train_loader, criterion, optimizer): model.train() for images, labels in train_loader: optimizer.zero_grad() outputs = model(images.unsqueeze(1)) loss = criterion(outputs, labels) loss.backward() optimizer.step() def test(model, test_loader): model.eval() correct = 0 with torch.no_grad(): for images, labels in test_loader: outputs = model(images.unsqueeze(1)) _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() accuracy = correct / len(test_loader.dataset) return accuracy ``` 5. 加载数据集并创建数据加载器: ```python train_dataset = DigitDataset('train.csv') test_dataset = DigitDataset('test.csv') train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) ``` 6. 创建CNN模型实例、损失函数和优化器: ```python model = CNN() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) ``` 7. 进行训练和测试: ```python num_epochs = 10 for epoch in range(num_epochs): train(model, train_loader, criterion, optimizer) accuracy = test(model, test_loader) print(f'Epoch {epoch+1}, Test Accuracy: {accuracy}') torch.save(model.state_dict(), 'digit_model.pt') ``` 8. 加载保存模型并进行预测: ```python model = CNN() model.load_state_dict(torch.load('digit_model.pt')) # 假设有一个名为image的张量用于预测 output = model(image.unsqueeze(0).unsqueeze(0)) _, predicted = torch.max(output.data, 1) print(f'Predicted digit: {predicted.item()}') ``` 这就是使用CSV文件进行手写数字识别的基本步骤。你可以根据自己的需求进行修改和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值