基础实战——FashionMNIST时装分类

基础实战——FashionMNIST时装分类

首先导入必要的包

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

配置训练环境和超参数

Gpu_device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
## 配置其他超参数,如batch_size, num_workers, learning rate, 以及总的epochs
batch_size = 256
num_workers = 4   # 对于Windows用户,这里应设置为0,否则会出现多线程错误
lr = 1e-4
epochs = 20

数据读入和加载
这里同时展示两种方式:

  • 下载并使用PyTorch提供的内置数据集
  • 从网站下载以csv格式存储的数据,读入并转成预期的格式
    第一种数据读入方式只适用于常见的数据集,如MNIST,CIFAR10等,PyTorch官方提供了数据下载。这种方式往往适用于快速测试方法(比如测试下某个idea在MNIST数据集上是否有效)
    第二种数据读入方式需要自己构建Dataset,这对于PyTorch应用于自己的工作中十分重要

同时,还需要对数据进行必要的变换,比如说需要将图片统一为一致的大小,以便后续能够输入网络训练;需要将数据格式转为Tensor类,等等。

这些变换可以很方便地借助torchvision包来完成,这是PyTorch官方用于图像处理的工具库,上面提到的使用内置数据集的方式也要用到。PyTorch的一大方便之处就在于它是一整套“生态”,有着官方和第三方各个领域的支持。这些内容我们会在后续课程中详细介绍。

from torchvision import transforms

image_size = 28
data_transform = transforms.Compose([
    # transforms.ToPILImage(),# 是转换数据格式
    transforms.Resize(image_size),
    transforms.ToTensor()# 对于PILImage转化的Tensor,其数据类型是torch.FloatTensor
])
## 读取方式一:使用torchvision自带数据集,下载可能需要一段时间
from torchvision import datasets

train_data = datasets.FashionMNIST(root='./', train=True, download=True, transform=data_transform)
test_data = datasets.FashionMNIST(root='./', train=False, download=True, transform=data_transform)

在构建训练和测试数据集完成后,需要定义DataLoader类,以便在训练和测试时加载数据

train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = 
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值