人民币识别 (RMB banknote recognition)
split_dataset
import os
import random
import shutil
def makedir(new_dir):
    """Create ``new_dir`` (including any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent without a separate
    existence check. That removes the window in which another process
    could create the directory between the check and ``makedirs``.

    Args:
        new_dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``new_dir`` exists but is not a directory.
    """
    os.makedirs(new_dir, exist_ok=True)
def split_images(imgs, train_pct=0.8, valid_pct=0.1):
    """Partition an already-shuffled list of image names into train/valid/test.

    The first ``int(n * train_pct)`` items go to train. Items up to index
    ``int(n * (train_pct + valid_pct))`` go to valid. Everything after
    that goes to test, so the test fraction is implicitly
    ``1 - train_pct - valid_pct``.

    Args:
        imgs: List of image file names, in the order they should be split.
        train_pct: Fraction of images assigned to the training split.
        valid_pct: Fraction of images assigned to the validation split.

    Returns:
        A ``(train, valid, test)`` tuple of lists.
    """
    img_count = len(imgs)
    train_point = int(img_count * train_pct)
    valid_point = int(img_count * (train_pct + valid_pct))
    return imgs[:train_point], imgs[train_point:valid_point], imgs[valid_point:]


if __name__ == '__main__':
    # Copy each class folder found under data/RMB_data into
    # data/rmb_split/{train,valid,test}/<class> with an 80/10/10 split.
    random.seed(1)

    dataset_dir = os.path.join("data", "RMB_data")
    split_dir = os.path.join("data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    # The test split receives the remainder (1 - train_pct - valid_pct).

    if not os.path.isdir(dataset_dir):
        # os.walk yields nothing for a missing directory, which would make the
        # script finish silently without copying anything.
        raise FileNotFoundError('dataset directory not found: {}'.format(dataset_dir))

    for root, dirs, _files in os.walk(dataset_dir):
        for sub_dir in dirs:
            # List and copy from the same directory so nested layouts work too.
            class_dir = os.path.join(root, sub_dir)
            # os.listdir order is filesystem-dependent; sort first so the seeded
            # shuffle produces the same split on every machine.
            imgs = sorted(x for x in os.listdir(class_dir) if x.endswith('.jpg'))
            random.shuffle(imgs)

            splits = split_images(imgs, train_pct, valid_pct)
            for split_root, subset in zip((train_dir, valid_dir, test_dir), splits):
                if not subset:
                    continue  # don't create an empty class folder for this split
                out_dir = os.path.join(split_root, sub_dir)
                makedir(out_dir)
                for img in subset:
                    shutil.copy(os.path.join(class_dir, img), os.path.join(out_dir, img))

            train_imgs, valid_imgs, test_imgs = splits
            print('Class:{}, train:{}, valid:{}, test:{}'.format(
                sub_dir, len(train_imgs), len(valid_imgs), len(test_imgs)))
train_lenet
import os
import random
import numpy as np
import torch
import torch. nn as nn
from torch. utils. data import DataLoader
import torchvision. transforms as transforms
import torch. optim as optim
from matplotlib import pyplot as plt
from model. lenet import LeNet
from tools. my_dataset import RMBDataset
def set_seed (</