How to Train a Classifier in PyTorch


1. Preface

As the title says, I want to train a classifier in PyTorch using a public dataset.

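2. Preparing the data

When working with images, text, audio, or video, you can use standard Python packages to load the data into a numpy array and then convert that array into a torch.*Tensor.
These are third-party packages rather than part of the Python standard library. Commonly used ones include: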

For images: Pillow, OpenCV are useful

For audio: scipy and librosa

For text: NLTK and SpaCy 
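
As a rough illustration of the numpy-to-Tensor step described above, the sketch below converts an image-shaped array into a channels-first float tensor. The array is synthetic (random pixels standing in for a decoded image), and the names image_hwc and tensor_chw are made up for this example.

```python
import numpy as np
import torch

# Stand-in for a decoded image: height x width x channels, uint8 in [0, 255].
# (A real image would come from e.g. Pillow or OpenCV; this array is synthetic.)
image_hwc = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

# torch.from_numpy shares memory with the array, so no copy is made here.
tensor_hwc = torch.from_numpy(image_hwc)

# Reorder to channels-first (C x H x W) and rescale to floats in [0, 1].
tensor_chw = tensor_hwc.permute(2, 0, 1).float() / 255.0

print(tensor_chw.shape, tensor_chw.dtype)
```

In practice torchvision's transforms.ToTensor() performs the same reordering and rescaling for PIL images, so this manual version is mainly useful for seeing what happens under the hood.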

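My research area is computer vision, so I will certainly use the torchvision package. It ships loaders for common datasets such as ImageNet, CIFAR10, and MNIST through torchvision.datasets, and these work together with the ready-made torch.utils.data.DataLoader.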

import torch
from torch.utils.data import DataLoader

#-----1. using built-in datasets in PyTorch-----------------------------
from torchvision import datasets, transforms

# Define transformations to apply to the data
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL image to tensor
    # Add any other transformations you need (e.g., normalization)
])

# Create a dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

#-----2.create a "DataLoader"-----------------------------
batch_size = 64
num_workers = 2

# Create a DataLoader for the dataset
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

#-----3.Iterate Over Batches-----------------------------
for batch_idx, (data, target) in enumerate(train_loader):
    # data: batch of input data
    # target: batch of corresponding labels
    # Your training code here
    pass  # placeholder so the loop body is valid Python
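
To see what a DataLoader loop like the one above actually yields, here is a small sketch that avoids any download by wrapping random tensors in a TensorDataset. The sizes (10 samples, batch size 4) and the names toy_dataset and toy_loader are assumptions chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 10 grayscale 28x28 "images" with labels 0..9.
images = torch.rand(10, 1, 28, 28)
labels = torch.arange(10)
toy_dataset = TensorDataset(images, labels)

# num_workers=0 keeps loading in the main process, which is simplest for a demo.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, num_workers=0)

batch_shapes = []
for batch_idx, (data, target) in enumerate(toy_loader):
    # data has shape (batch, channels, height, width); target has shape (batch,)
    batch_shapes.append((tuple(data.shape), tuple(target.shape)))
    print(batch_idx, data.shape, target.shape)
```

With 10 samples and a batch size of 4, the last batch is smaller than the others; passing drop_last=True to DataLoader would discard it instead.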

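Because I am following this blog post, I download CIFAR10 here for later use. The dataset has ten classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. Each image is 3x32x32, that is, a 3-channel color image of 32x32 pixels.
The torchvision datasets return PILImage images with values in the range [0, 1]. We convert them to Tensors and normalize them to the range [-1, 1].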

import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
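
The Normalize step with mean 0.5 and standard deviation 0.5 is what moves the pixel values from [0, 1] to [-1, 1]. The sketch below reproduces that per-channel arithmetic, (x - mean) / std, with plain torch operations on a tiny hand-made tensor; the tensor values are made up for the example.

```python
import torch

# A fake 3x2x2 image with values in [0, 1], as ToTensor() would produce.
x = torch.tensor([0.0, 0.25, 0.5, 1.0]).reshape(1, 2, 2).repeat(3, 1, 1)

mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

# Per-channel normalization: the same formula transforms.Normalize applies.
normalized = (x - mean) / std

print(normalized[0])
```

To display a normalized image again, the operation is inverted with x = normalized * std + mean.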

3. Preparing the model

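import torch.nn as nn

# Define your model here; the single linear layer is only a placeholder
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 32 * 32, 10)  # flattened 3x32x32 image -> 10 classes
    def forward(self, x):
        return self.fc(torch.flatten(x, 1))

net = Net()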
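
The model definition above is only a stub to be filled in. As one possible way to do that for 3x32x32 CIFAR10 images, here is a sketch of a small LeNet-style convolutional network. The architecture and the name SmallCNN are assumptions for illustration, not the model this post was written around.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallCNN(nn.Module):
    """Hypothetical LeNet-style network for 3x32x32 inputs and 10 classes."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)   # 32x32 -> 28x28
        self.pool = nn.MaxPool2d(2, 2)                # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 14x14 -> 10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> 6 x 14 x 14
        x = self.pool(F.relu(self.conv2(x)))  # -> 16 x 5 x 5
        x = torch.flatten(x, 1)               # flatten everything but the batch dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw logits, no softmax


model = SmallCNN()
logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)
```

The network returns raw logits on purpose, because nn.CrossEntropyLoss applies the log-softmax itself.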

4. Preparing the loss function and optimizer

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
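
It may help to see what nn.CrossEntropyLoss computes: it takes raw logits plus integer class indices and returns the mean negative log-probability of the correct class. The sketch below checks this against a manual log-softmax computation; the logits and targets are made-up numbers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Two samples, three classes: raw (unnormalized) scores from a model.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])  # integer class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)

# Same value computed by hand: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss.item(), manual_loss.item())
```

This is why the model should not apply a softmax before the loss: doing so would normalize the scores twice.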

5. Training

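import os

# Pick the device and move the model onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

#*******************if you need to use pretrained weights******
# torch.load does not expand '~', so expand the path first
weights_path = os.path.expanduser('~/weights.pt')
if os.path.exists(weights_path):
    net.load_state_dict(torch.load(weights_path, map_location=device))

#*******************Define the number of epochs******
num_epochs = 100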

# Loop over the dataset for the specified number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    # Enumerate over the training data loader
    for i, data in enumerate(trainloader, 0):
        # Get the inputs and labels from the data loader
        inputs, labels = data

        # Move the inputs and labels to the appropriate device (GPU or CPU)
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass: compute predicted outputs by passing inputs to the model
        outputs = net(inputs)

        # Calculate the loss
        loss = criterion(outputs, labels)

        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()

        # Perform a single optimization step (parameter update)
        optimizer.step()

        # Update statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Print running statistics every 2000 mini-batches.
        # running_loss is not reset here, so the end-of-epoch average stays correct.
        if i % 2000 == 1999:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{len(trainloader)}], Loss: {running_loss / (i + 1):.3f}, Accuracy: {100 * correct / total:.2f}%')

    # Print accuracy and loss at the end of each epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(trainloader):.3f}, Accuracy: {100 * correct / total:.2f}%')

    # Save the model weights
    torch.save(net.state_dict(), f'weights_epoch_{epoch + 1}.pt')
print('Finished Training')

In this code, I print the loss and accuracy for each epoch and save the weights after every epoch.
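
Since the weights are saved with torch.save(net.state_dict(), ...), they can be restored later with load_state_dict. The sketch below shows the round trip on a tiny stand-in model written to a temporary directory; the nn.Linear model and the names source_model and restored_model are assumptions for this example only.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model; any nn.Module is saved and restored the same way.
source_model = nn.Linear(4, 2)
restored_model = nn.Linear(4, 2)  # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'weights_epoch_1.pt')
    torch.save(source_model.state_dict(), path)
    # map_location='cpu' lets weights saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(path, map_location='cpu')

restored_model.load_state_dict(state_dict)
restored_model.eval()  # switch to inference mode before evaluating

weights_match = torch.equal(source_model.weight, restored_model.weight)
print(weights_match)
```

The model class must be constructed first, because a state_dict stores only the parameter tensors and not the architecture.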

6. Testing the classifier (model + weights)

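The classifier has been trained, but we do not yet know whether it has learned anything. A simple check is to run the model on the test dataset: this yields a predicted label for every test sample, and comparing those predictions against the ground-truth labels shows how well the model performs.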

# The testloader was already created earlier; it is repeated here for reference
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
#********now check how the network performs on the whole dataset*********
net.eval()  # put layers such as dropout and batch norm into inference mode
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # move both tensors to the same device as the model
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
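
The per-class counting above loops over samples one at a time in Python. As an alternative sketch, the same counts can be computed with torch.bincount on whole tensors. The three-class list, labels, and predictions below are made-up values for illustration.

```python
import torch

class_names = ('plane', 'car', 'bird')  # shortened, illustrative class list
labels = torch.tensor([0, 0, 1, 1, 1, 2])
predictions = torch.tensor([0, 1, 1, 1, 0, 2])

num_classes = len(class_names)
# Count how many samples each class has, and how many of those were correct.
total_per_class = torch.bincount(labels, minlength=num_classes)
correct_per_class = torch.bincount(labels[labels == predictions],
                                   minlength=num_classes)

per_class_accuracy = 100.0 * correct_per_class.float() / total_per_class.float()
for name, acc in zip(class_names, per_class_accuracy.tolist()):
    print(f'Accuracy for class: {name:5s} is {acc:.1f} %')
```

This version assumes every class appears at least once in the labels; a class with zero samples would produce a division by zero (NaN) and needs separate handling.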
