pytorch mnist 最简单示例代码

导入库

from typing import *

import os
import sys
import random
import pickle
import logging
from datetime import datetime

try:
    from tqdm.notebook import tqdm as _tqdm
except:
    _tqdm = lambda x : x

import numpy as np
from matplotlib import pyplot as plt

import torch as t
import torch.nn as nn
import torch.optim as optim

import torchvision as tv
from torch.utils.data import DataLoader

准备数据集

data_dir = '.'
batch_size = 256
train_dataset = tv.datasets.MNIST(root=data_dir, train=True, transform=tv.transforms.ToTensor(), download=True)
test_dataset = tv.datasets.MNIST(root=data_dir, train=False, transform=tv.transforms.ToTensor(), download=True)

logging 和 seed

logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    stream=sys.stdout,
    force=True,
)

_SEED = 0
random.seed(_SEED)
np.random.seed(_SEED)
t.manual_seed(_SEED)
t.cuda.manual_seed_all(_SEED)
# t.backends.cudnn.benchmark = False # burden for performance
# t.backends.cudnn.deterministic = True # burden for performance
# t.use_deterministic_algorithms(True) # deprecated in cuda 11+

assert random.random() == 0.662743203727557, 'different random generator'
assert np.random.rand() == 0.11717829958543136, 'different random generator'
assert t.rand((1,)).item() == 0.997646152973175, 'different random generator'

if t.cuda.is_available():
    for i in range(t.cuda.device_count()):
        logging.info(t.cuda.get_device_properties(i))
    logging.warning('using cuda:0')
    device = t.
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值