Torch上分之路——Mnist

Torch上分之路——Mnist

文档目录(其中common用于写模型,utils用于写算法过程中要用到的函数工具,所以只有datalist,model,train,config会被用到)
在这里插入图片描述

config.py

模型配置文件

import argparse

# Training settings
# The `metavar` argument controls how an option's value placeholder is
# displayed in the generated help/usage text.


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"`` and ``"0"``)
    is truthy, so the flag could never be switched off explicitly.

    Args:
        value: The raw option value (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` according to the textual value.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, matching what bool('') used to yield.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# Parsed with _str2bool instead of type=bool so that "--use_cuda False"
# really yields False.
parser.add_argument('--use_cuda', type=_str2bool, default=False,
                    help='whether to use cuda to accerlate')
datalist.py

`

from torchvision.datasets import mnist
from torchvision import datasets,transforms
from minist_torch.config import parser

#train_set=mnist.MNIST('./data',train=True,download=True)
#test_set=mnist.MNIST('./data',train=False,download=True)


def _mnist_transform():
    """Build the per-image preprocessing pipeline used by both splits.

    Images are converted to tensors and then normalized with mean 0.1307 and
    std 0.3081 (presumably the MNIST pixel statistics -- TODO confirm).
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)


# Training split, stored under ./data; download is requested for this split.
train_set = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=_mnist_transform(),
)

# Test split, read from the same ./data directory (no download requested here).
test_set = datasets.MNIST(
    './data',
    train=False,
    transform=_mnist_transform(),
)

#自己写Dataset至少需要有这样的格式
# class YoloDataset(Dataset):
#     def __init__(self):
#         super(YoloDataset, self).__init__()
#     def __len__(self):
#         return
#     def __getitem__(self, index):
#         return
# # DataLoader中collate_fn使用
# def dataset_collate(batch):
#     images = []
#     bboxes = []
#     for img, box in batch:
#         images.append(img)
#         bboxes.append(box)
#     images = np.array(images)
#     bboxes = np.array(bboxes)
#     return images, bboxes

model.py
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN classifier: two 5x5 convolutions followed by two linear layers.

    The forward pass returns log-probabilities over 10 classes. The flatten
    step hard-codes 320 features, which is what a single-channel 28x28 input
    produces (20 channels * 4 * 4 after the two conv + pool stages).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 10 -> 20 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Classifier head: 320 flattened features -> 50 hidden -> 10 outputs.
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities.

        Args:
            x: Input batch fed to ``conv1`` (one input channel).

        Returns:
            Tensor with 10 columns holding ``log_softmax`` scores along dim 1.
        """
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        stage1 = self.conv1(x)
        stage1 = F.max_pool2d(stage1, 2)
        stage1 = F.relu(stage1)

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU.
        stage2 = self.conv2(stage1)
        stage2 = self.conv2_drop(stage2)
        stage2 = F.max_pool2d(stage2, 2)
        stage2 = F.relu(stage2)

        # Flatten to rows of 320 features for the linear layers.
        flat = stage2.view(-1, 320)

        hidden = F.relu(self.fc1(flat))
        # Functional dropout has to be told explicitly whether we are training.
        hidden = F.dropout(hidden, training=self.training)

        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
train.py
from __future__ import print_function
import torch
import torch.nn.functional as F
import torch.optim as optim
from minist_torch.model import Net
from minist_torch.config import parser
from minist_torch.datalist import test_set,train_set
from tqdm import tqdm

class train(object):
    """Train and evaluate ``Net`` on MNIST.

    Constructing an instance parses the command-line options, builds the data
    loaders, model and SGD optimizer, and then immediately runs the full
    train/test loop for ``args.epochs`` epochs.
    """

    def __init__(self):
        self.args = parser.parse_args()

        # CUDA is used only if it is not disabled via --no-cuda and is available.
        use_cuda = not self.args.no_cuda and torch.cuda.is_available()
        torch.manual_seed(self.args.seed)
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Extra loader options are only passed when running on CUDA.
        kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self.args.batch_size, shuffle=True, **kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=self.args.test_batch_size, shuffle=True, **kwargs)

        self.model = Net().to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        for epoch in range(1, self.args.epochs + 1):
            self.train(epoch)
            self.test()

    def test(self):
        """Evaluate the model on ``self.test_loader`` and print loss and accuracy.

        The loss is summed over all samples and divided by the dataset size,
        so the printed value is a per-sample average.
        """
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # reduction='sum' replaces the deprecated size_average=False:
                # sum up the batch loss, the averaging happens once below.
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # Predicted class = index of the largest log-probability.
                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        print('\n Test_set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(self.test_loader.dataset),
                      100. * correct / len(self.test_loader.dataset)))

    def train(self, epoch):
        """Run one training epoch and save a checkpoint for it.

        Args:
            epoch: 1-based epoch number; used in the progress bar text and in
                the checkpoint file name ``./weight/weight-{epoch}.pth``.
        """
        import os

        self.model.train()
        # Stays None if the loader yields no batches, so the checkpoint below
        # can still be written instead of failing on an unbound name.
        loss = None
        # len(loader) is the real number of batches, including a trailing
        # partial batch; len(dataset) // batch_size undercounts by one then.
        with tqdm(total=len(self.train_loader)) as pbar:
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                pbar.set_postfix_str(f'epoch:{epoch} total_loss:{round(loss.item(),2)}')
                pbar.update(1)

        # torch.save does not create missing directories.
        os.makedirs('./weight', exist_ok=True)
        torch.save({
            'epoch':epoch,
            'model_state_dict':self.model.state_dict(),
            'optimizer_state_dict':self.optimizer.state_dict(),
            'loss':loss
        },f'./weight/weight-{epoch}.pth')

# Script entry point: instantiating `train` runs the whole train/test loop,
# because the constructor itself iterates over the epochs.
if __name__=="__main__":
    train()
模型训练结果

在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值