本文是在
https://www.jianshu.com/p/d8feaddc7bdf文章的基础上用Pytorch实现的
话不多说,直接上代码,具体的可以看代码中的解释
代码实现
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.transform
import numpy as np
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''
使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
argparse是python的一个包,用来解析输入的参数
如:
python mnist.py --outf model
(意思是将训练的模型保存到model文件夹下,当然,你也可以不加参数,那样的话代码最后一行
torch.save()就需要注释掉了)
python mnist.py --net model/net_005.pth
(意思是加载之前训练好的网络模型,前提是训练使用的网络和测试使用的网络是同一个网络模型,保证权重参数矩阵相等)
'''
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') # 模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") # 模型加载路径
opt = parser.parse_args() # 解析得到你在路径中输入的参数,比如 --outf 后的"model"或者 --net 后的"model/net_005.pth",是作为字符串形式保存的
# Load training and testing datasets.
ROOT_PATH = "./traffic"
train_data_dir = os.path.join(ROOT_PATH, "datasets/BelgiumTS/Training")
test_data_dir = os.path.join(ROOT_PATH, "datasets/BelgiumTS/Testing")
'''
定义LeNet神经网络,进一步的理解可查看Pytorch入门,里面很详细,代码本质上是一样的,这里做了一些封装
'''
class LeNet(nn.Module):
'''
该类继承了torch.nn.Modul类
构建LeNet神经网络模型
'''
def __init__(self):
super(LeNet,