Alexnet复现学习问题解决
https://blog.csdn.net/weixin_44023658/article/details/105798326?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_baidulandingword-9&spm=1001.2101.3001.4242
文中代码参考此篇博客,又对代码做了详细补充注释,且对文中所涉及函数有所介绍
数据下载
DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
- 处理数据的操作
实现了对文件的创建以及图片的分类
# spile_data.py
# 注意:多次运行该程序会反复生成数据
import os
from shutil import copy
import random
def mkfile(file):
if not os.path.exists(file): # 判断路径是否存在
os.makedirs(file) # 创建指定目录
file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla] # 若cla中不包含".txt",返回True,复制目录文件名至列表
mkfile('flower_data/train')
for cla in flower_class:
mkfile('flower_data/train/'+cla)
mkfile('flower_data/val')
for cla in flower_class:
mkfile('flower_data/val/'+cla)
split_rate = 0.1
for cla in flower_class: # 该层for循环为第一层文件夹中的内容,即图片种类
cla_path = file + '/' + cla + '/'
images = os.listdir(cla_path) # 返回指定目录下的文件名
num = len(images)
eval_index = random.sample(images, k=int(num*split_rate)) # # 用于随机截取指定长度的列表
for index, image in enumerate(images): # 该层for循环为第二层文件夹中的内容,即图片
if image in eval_index:
image_path = cla_path + image
new_path = 'flower_data/val/' + cla
copy(image_path, new_path) # 复制date, mode bit由image_path至new_path
else:
image_path = cla_path + image
new_path = 'flower_data/train/' + cla
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar \r为回车,end=""取消print默认换行,二者配合实现输出为一行数字的变化
print() # 为第一次for循环的输出实现换行操作(print()函数默认输出换行)
print("processing done!")
- 训练的实现
train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
#device : GPU or CPU
device = torch.device("cuda:0"