torchvision.datasets下载数据集后,怎样取其中部分数据做训练

本文介绍了如何在PyTorch中使用Mnist数据集进行训练。首先,通过transforms进行数据预处理,然后创建Dataloader加载数据。训练过程中,利用GPU加速计算,计算损失函数,进行反向传播和优化器更新。同时,文章还展示了如何从数据集中选取部分数据进行训练。
摘要由CSDN通过智能技术生成

Mnist数据集为例

一、直接在整个数据集上训练

数据下载和预处理

trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset_train = torchvision.datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = torchvision.datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)

然后即可放入Dataloader中,

trainDataLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True)  # 批量读取并打乱
testDataLoader = torch.utils.data.DataLoader(dataset=testData, batch_size=batch_size)

训练时迭代读取数据

for epoch in range(1, epochs + 1):
    processBar = tqdm(trainDataLoader, unit='step')
    model.train(True)
    train_loss, train_correct = 0, 0
    for step, (train_imgs, labels) in enumerate(processBar):

        if torch.cuda.is_available():  # GPU可用
            train_imgs = train_imgs.cuda()
            labels = labels.cuda()
        model.zero_grad()  # 梯度清零
        outputs = model(train_imgs)  # 输入训练集
        loss = criterion(outputs, labels)  # 计算损失函数
        predictions = torch.argmax(outputs, dim=1)  # 得到预测值
        correct = torch.sum(predictions == labels)
        accuracy = correct / labels.shape[0]  # 计算这一批次的正确率
        loss.backward()  # 反向传播
        optimizer.step()  # 更新优化器参数
        processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f" %  # 可视化训练进度条设置
                                   (epoch, epochs, loss.item(), accuracy.item()))
二、取数据集上部分数据训练

以数据集索引提取训练数据
取其中序号为data_idx的数据

dataset_train[data_idx][0] #取图像数据(image)
dataset_train[data_idx][1] #取对应的标签(label)

于是,采样一些数据作为训练集可使用如下代码

sample_index = [i for i in range(500)] #假设取前500个训练数据
X_train = []
y_train = []
for i in sample_index:
    X = dataset_train[i][0]
    X_train.append(X)
    y = dataset_train[i][1]
    y_train.append(y)

sampled_train_data = [(X, y) for X, y in zip(X_train, y_train)] #包装为数据对
trainDataLoader = torch.utils.data.DataLoader(sampled_train_data, batch_size=16, shuffle=True)

将trainDataloader带入训练过程中即可。

参考资料

[1] PyTorch入门——实现MNIST分类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值