pytorch mnist 最简单示例代码

这篇博客介绍了使用PyTorch实现MNIST数据集的最简单示例,从导入必要的库,准备MNIST数据集,设置logging和随机数种子,到构建神经网络模型以及进行训练的详细步骤。
摘要由CSDN通过智能技术生成

导入库

from typing import *

import os
import sys
import random
import pickle
import logging
from datetime import datetime

try:
    from tqdm.notebook import tqdm as _tqdm
except:
    _tqdm = lambda x : x

import numpy as np
from matplotlib import pyplot as plt

import torch as t
import torch.nn as nn
import torch.optim as optim

import torchvision as tv
from torch.utils.data import DataLoader

准备数据集

data_dir = '.'
batch_size = 256
train_dataset = tv.datasets.MNIST(root=data_dir, train=True, transform=tv.transforms.ToTensor(), download=True)
test_dataset = tv.datasets.MNIST(root=data_dir, train=False, transform=tv.transforms.ToTensor(), download=True)

logging 和 seed

logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    stream=sys.stdout,
    force=True,
)

_SEED = 0
random.seed(_SEED)
np.random.seed(_SEED)
t.manual_seed(_SEED)
t.cuda.manual_seed_all(_SEED)
# t.backends.cudnn.benchmark = False # burden for performance
# t.backends.cudnn.deterministic = True # burden for performance
# t.use_deterministic_algorithms(True) # deprecated in cuda 11+

assert random.random() == 0.662743203727557, 'different random generator'
assert np.random.rand() == 0.11717829958543136, 'different random generator'
assert t.rand((1,)).item() == 0.997646152973175, 'different random generator'

if t.cuda.is_available():
    for i in range(t.cuda.device_count()):
        logging.info(t.cuda.get_device_properties(i))
    logging.warning('using cuda:0')
    device = t.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值