在本文中,我们将介绍如何使用Python的PyTorch库进行迁移学习。迁移学习是一种机器学习技术,它允许我们利用一个预先训练好的模型对新任务进行训练,从而节省了大量的计算资源和时间。
1. 安装并导入库
首先,我们需要安装PyTorch库。在终端或命令提示符中输入以下命令进行安装:
pip install torch torchvision
然后导入所需的库:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
2. 数据预处理
在本例中,我们将使用CIFAR-10数据集进行迁移学习。首先,我们需要对数据进行预处理:
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True