Deep Learning Mini-Example: Counterfeit Banknote Detection with a Multilayer Perceptron

1. Configuration File

class Hyperparameter:
    device = 'cpu'  # or 'cuda' if a GPU is available
    data_dir = './data/'
    data_path = './data/data_banknote_authentication.txt'
    trainset_path = './data/train.txt'
    devset_path = './data/dev.txt'
    testset_path = './data/test.txt'

    in_features = 4  # input feature dim
    out_dim = 2  # output feature dim (number of classes)
    seed = 1234  # random seed

    # MLP layer sizes: 4 input features, hidden layers of 64 / 128 / 64, 2 output classes
    layer_list = [in_features, 64, 128, 64, out_dim]

    batch_size = 64
    # learning rate
    init_lr = 1e-3
    epochs = 100
    # evaluate and log every N steps
    verbose_step = 10
    # save a checkpoint every N steps
    save_step = 200

HP = Hyperparameter()
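
The paths above assume the raw UCI banknote file has already been split into train.txt, dev.txt, and test.txt. That preprocessing step is not shown in this post; below is a minimal sketch of how it could be done, where the 8:1:1 split ratio and the shuffling are my own assumptions.

import random

from config import HP


def split_dataset(ratios=(0.8, 0.1, 0.1)):
    # each line of the raw file: variance,skewness,curtosis,entropy,class
    with open(HP.data_path, 'r') as f:
        lines = [ln.strip() for ln in f if ln.strip()]

    random.seed(HP.seed)
    random.shuffle(lines)

    n_train = int(len(lines) * ratios[0])
    n_dev = int(len(lines) * ratios[1])

    splits = {
        HP.trainset_path: lines[:n_train],
        HP.devset_path: lines[n_train:n_train + n_dev],
        HP.testset_path: lines[n_train + n_dev:],
    }
    for path, rows in splits.items():
        with open(path, 'w') as f:
            f.write('\n'.join(rows) + '\n')


if __name__ == '__main__':
    split_dataset()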

2. Model Code

import torch
from torch import nn
from torch.nn import functional as F
from config import HP

class BanknoteClassificationModel(nn.Module):
    def __init__(self):
        super(BanknoteClassificationModel, self).__init__()
        # nn.ModuleList works much like a plain Python list, but it registers the layers
        # so they are wired into the forward pass and included in backpropagation.
        # The input/output dims come from the hyperparameter layer_list:
        # each layer's input size is the previous layer's output size.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(HP.layer_list[:-1], HP.layer_list[1:])
        ])
        print('Input dims: {}, output dims: {}'.format(HP.layer_list[:-1], HP.layer_list[1:]))

    def forward(self, input_x):
        # ReLU after every hidden layer; the final layer outputs raw logits,
        # since nn.CrossEntropyLoss applies log-softmax internally.
        for i, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            if i < len(self.linear_layer) - 1:
                input_x = F.relu(input_x)
        return input_x


if __name__ == '__main__':
    model = BanknoteClassificationModel().to(HP.device)
    x = torch.randn(size=(16, HP.in_features)).to(HP.device)
    y_pred = model(x)
    print(y_pred)
    print(y_pred.size())

For the line

for in_dim, out_dim in zip(HP.layer_list[:-1], HP.layer_list[1:])

I added print('Input dims: {}, output dims: {}'.format(HP.layer_list[:-1], HP.layer_list[1:]))

which prints

Input dims: [4, 64, 128, 64], output dims: [64, 128, 64, 2]

So zip pairs the list with an offset copy of itself: the input dims start at the first element and stop before the last one, while the output dims start at the second element and run to the end, giving one (in_dim, out_dim) pair per Linear layer.
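
As a quick standalone check of that pairing (using the same values as layer_list in the config):

layer_list = [4, 64, 128, 64, 2]
# zip the list against itself shifted by one element to get (in_dim, out_dim) pairs
pairs = list(zip(layer_list[:-1], layer_list[1:]))
print(pairs)  # [(4, 64), (64, 128), (128, 64), (64, 2)]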

3. Model Training

import os
from argparse import ArgumentParser
import torch.optim as optim
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from model import BanknoteClassificationModel
from config import HP
from dataset_banknote import BanknoteDataset

logger = SummaryWriter('./log')

# seed init: Ensure Reproducible Result
torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)
random.seed(HP.seed)
np.random.seed(HP.seed)


def evaluate(model_, devloader, crit):
    model_.eval()  # set evaluation flag
    # accumulated loss over the dev set
    sum_loss = 0.
    with torch.no_grad():
        for batch in devloader:
            x, y = batch
            x, y = x.to(HP.device), y.to(HP.device)  # move the batch to the configured device
            pred = model_(x)
            loss = crit(pred, y)
            sum_loss += loss.item()

    model_.train()  # switch back to training mode
    return sum_loss / len(devloader)


def save_checkpoint(model_, epoch_, optm, checkpoint_path):
    save_dict = {
        'epoch': epoch_,
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict()
    }
    torch.save(save_dict, checkpoint_path)


def train():
    # command-line arguments
    parser = ArgumentParser(description="Model Training")
    # path of a checkpoint to resume from; None means training starts from scratch
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()
    # new model instance
    model = BanknoteClassificationModel()
    model = model.to(HP.device)

    # cross-entropy loss
    criterion = nn.CrossEntropyLoss()

    # optimizer
    opt = optim.Adam(model.parameters(), lr=HP.init_lr)
    # opt = optim.SGD(model.parameters(), lr=HP.init_lr)

    # train dataloader
    trainset = BanknoteDataset(HP.trainset_path)
    # drop_last=True drops the leftover samples that don't fill a full batch, so every batch has the same size
    train_loader = DataLoader(trainset, batch_size=HP.batch_size, shuffle=True, drop_last=True)

    # dev dataloader (evaluation only, not used for training)
    devset = BanknoteDataset(HP.devset_path)
    dev_loader = DataLoader(devset, batch_size=HP.batch_size, shuffle=True, drop_last=False)

    # make sure the checkpoint directory exists
    os.makedirs('model_save', exist_ok=True)

    # starting epoch, plus a global step counter that counts how many batches have been processed
    start_epoch, step = 0, 0

    # if training was interrupted, restore model / optimizer / epoch from the checkpoint
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        # restore the optimizer state as well (learning rate, moment estimates, etc.)
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        # restore the epoch to resume from
        start_epoch = checkpoint['epoch']
        print('Resume From %s.' % args.c)
    else:
        # no checkpoint given: start from scratch
        print('Training From scratch!')

    model.train()   # set training flag

    # main loop
    for epoch in range(start_epoch, HP.epochs):
        print('Start Epoch: %d, Steps per Epoch: %d' % (epoch, len(train_loader)))
        for batch in train_loader:
            x, y = batch    # load data
            x, y = x.to(HP.device), y.to(HP.device)  # move the batch to the configured device
            opt.zero_grad()             # clear old gradients
            pred = model(x)             # forward pass
            loss = criterion(pred, y)   # loss calculation
            loss.backward()             # backward pass
            # update the model parameters; this completes one training step
            opt.step()

            # log this step's training loss to tensorboard
            logger.add_scalar('Loss/Train', loss, step)
            # every verbose_step steps, evaluate on the dev set;
            # comparing train and dev loss helps spot over- or under-fitting
            if not step % HP.verbose_step:  # evaluate and log
                eval_loss = evaluate(model, dev_loader, criterion)  # pass the model, dataloader and loss function
                logger.add_scalar('Loss/Dev', eval_loss, step)

            if not step % HP.save_step: # model save
                model_path = 'model_%d_%d.pth' % (epoch, step)
                save_checkpoint(model, epoch, opt, os.path.join('model_save', model_path))

            # advance the global step counter
            step += 1
            logger.flush()
            print('Epoch: [%d/%d], step: %d Train Loss: %.5f, Dev Loss: %.5f'
                  % (epoch, HP.epochs, step, loss.item(), eval_loss))
    logger.close()


if __name__ == '__main__':
    train()
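
The trainer imports BanknoteDataset from dataset_banknote, which is not shown in this post. Below is a minimal sketch of what such a Dataset could look like, assuming the split files keep the comma-separated layout of the UCI banknote data (variance, skewness, curtosis, entropy, class); the actual implementation may differ.

import numpy as np
import torch
from torch.utils.data import Dataset

from config import HP


class BanknoteDataset(Dataset):
    def __init__(self, data_path):
        # each row: variance, skewness, curtosis, entropy, class
        raw = np.loadtxt(data_path, delimiter=',', dtype=np.float32)
        self.features = torch.from_numpy(raw[:, :HP.in_features])  # float32 feature vectors
        self.labels = torch.from_numpy(raw[:, -1]).long()          # int64 labels for CrossEntropyLoss

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

With everything in place, training starts from scratch with python trainer.py and resumes from a saved checkpoint with python trainer.py --c model_save/model_0_200.pth (the script and checkpoint file names here are only examples).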
