Testing the mean teacher algorithm on the MNIST dataset
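
Mean teacher trains a student network with an ordinary supervised loss on the small labeled subset, plus a consistency loss that pulls the student's predictions toward those of a teacher network. The teacher is never trained by gradient descent; its weights are an exponential moving average (EMA) of the student's weights. The update used in the training loop below is

$$\theta'_t = \alpha\,\theta'_{t-1} + (1 - \alpha)\,\theta_t, \qquad \alpha = 0.95$$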

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import os
import struct
import numpy as np

class mnistNet(nn.Module):
    def __init__(self):
        super(mnistNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5, 1)    # 1x28x28 -> 30x24x24
        self.conv2 = nn.Conv2d(30, 60, 5, 1)   # 30x12x12 -> 60x8x8
        self.fc1 = nn.Linear(4 * 4 * 60, 300)  # 60x4x4 after the second pool
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 60)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
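
# Shape check (illustrative): a single 1x28x28 input yields a (1, 10) tensor
# of log-probabilities:
#   mnistNet()(torch.zeros(1, 1, 28, 28)).shape  # -> torch.Size([1, 10])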

def load_mnist(path, kind='train'):
    # Read the raw MNIST IDX files into numpy arrays
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 28, 28)
    return images, labels
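
# Quick sanity check (illustrative; assumes the raw IDX files sit under
# ./mnist/MNIST/raw, the same path train() uses below):
#   images, labels = load_mnist('./mnist/MNIST/raw', kind='train')
#   print(images.shape, labels.shape)  # -> (60000, 28, 28) (60000,)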

class datasetsMnist(Dataset):
    def __init__(self, root, kind, augment=None):
        # Load the MNIST images and labels
        self.images, self.labels = load_mnist(root, kind)
        self.augment = augment
        if kind=="train":
            # 随机设置255,使label(%98)失效
            c=np.linspace(0, 60000-1,60000)
            a2 = np.random.choice(c, size=int(60000*0.98), replace=False ).astype(np.int32)
            self.labels[a2] = 250
    
    def __getitem__(self, index):
        image = self.images[index]
        if self.augment is not None:
            image = self.augment(image)  # apply the transform (ToTensor + Normalize)
        return image, self.labels[index]
    
    def __len__(self):
        return len(self.images)

def train(args, device):

    # Data loaders: 98% of the training labels are masked; the test set keeps all labels
    train_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_datasets = datasetsMnist('./mnist/MNIST/raw', "train", train_transform)
    train_loader = torch.utils.data.DataLoader(train_datasets, num_workers=8, batch_size=args.train_batch_size, shuffle=True)
    
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    test_datasets = datasetsMnist('./mnist/MNIST/raw', "t10k", test_transform)
    test_loader = torch.utils.data.DataLoader(test_datasets, num_workers=4, batch_size=args.test_batch_size, shuffle=True)

    # Models: the student is trained with SGD, the teacher is an EMA copy
    student_model = mnistNet().to(device)
    mean_teacher = mnistNet().to(device)
    
    # Optimizer (only the student's parameters are updated by gradient descent)
    optimizer = optim.SGD(student_model.parameters(), lr=args.lr, momentum=args.momentum)
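
    # Optional, not in the original script: initializing the teacher from the
    # student avoids averaging against an unrelated random initialization, and
    # the teacher needs no gradients since it is only updated via EMA below.
    # mean_teacher.load_state_dict(student_model.state_dict())
    # for p in mean_teacher.parameters():
    #     p.requires_grad_(False)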

    for epoch in range(args.epochs):

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.float().to(device), target.long().to(device)
            idx = torch.where(target < 20)  # indices of labeled samples (masked labels were set to 250)
            
            if idx[0].numel() == 0:
                continue  # no labeled samples in this batch; skip it
            optimizer.zero_grad()

            output = student_model(data)
            # Keep gradients from flowing into the mean teacher
            with torch.no_grad():
                mean_t_output = mean_teacher(data)

            # Consistency loss: MSE between the student's and the teacher's
            # outputs (note both are log-probabilities from log_softmax)
            const_loss = F.mse_loss(output, mean_t_output)

            # Total loss: supervised NLL on the labeled samples plus the
            # weighted consistency loss over the whole batch
            weight = 0.2
            loss = F.nll_loss(output[idx], target[idx]) + weight * const_loss
            loss.backward()
            optimizer.step()

            # Update the mean teacher's parameters as an EMA of the student's
            alpha = 0.95
            for mean_param, param in zip(mean_teacher.parameters(), student_model.parameters()):
                mean_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

            # print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, loss.item()))
        test(student_model, device, test_loader, "student")
        test(mean_teacher, device, test_loader, "teacher")
        if args.save_model:
            torch.save(student_model.state_dict(), "mnist_cnn.pt")

def test(model, device, test_loader, name):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.float().to(device), target.long().to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('{} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( name,
        test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    model.train()


if __name__ == '__main__':
    # Configuration
    parser = argparse.ArgumentParser(description='Semi-supervised learning with PyTorch (mean teacher)')
    parser.add_argument('--train_batch_size', type=int, default=30)
    parser.add_argument('--test_batch_size', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False)
    parser.add_argument('--save-model', action='store_true', default=False)
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    
    device = torch.device("cuda" if use_cuda else "cpu")
    # Train
    train(args, device)
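
To run the script (the file name is whatever you saved it as; the raw MNIST files are expected under ./mnist/MNIST/raw):

`python mean_teacher_mnist.py --epochs 10 --lr 0.01`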

[Figure: test results for the three training settings discussed below]

Conclusions:

  • Training on all 60,000 labeled samples gives the best test accuracy.
  • Training on only the 2% labeled subset gives noticeably unstable test results.
  • Training on the 2% labeled plus 98% unlabeled data (mean teacher) lands between the two.