1. 简介
TensorBoard主要模块:
- GRAPHS: 保存了模型的结构图,可以比较清晰地显示模型搭建的每一个模块
- SCALARS: 保存了训练过程中的accuracy、train_loss、learning_rate等标量
- HISTOGRAMS: 保存了各层权重数值的分布
- IMAGES: 保存了每个epoch中图片的预测结果
2.花卉识别
利用ResNet网络实现花卉的识别。花卉数据集(flower_photos,下载地址见train.py中的注释)总共有5种花,分别存放在不同文件夹中:
1. daisy
2. dandelion
3. roses
4. sunflowers
5. tulips
2.1 ResNet网络及代码
ResNet网络于2015年由微软研究院提出,斩获当年ImageNet竞赛中分类任务第一名和目标检测第一名,并获得COCO数据集上目标检测第一名和图像分割第一名。下图是34层ResNet(ResNet34)模型的结构简图。
- ResNet 模型代码
model.py
import torch.nn as nn
import torch
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34.

    Main path: 3x3 conv (carries ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is ``x`` itself, or ``downsample(x)`` when a projection is
    supplied. The sum is passed through a final ReLU. The output has
    ``out_channel * expansion`` channels; ``expansion`` is 1 for this block.
    """

    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        # First 3x3 conv: the only place the spatial stride is applied.
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        # Second 3x3 conv keeps both the resolution and the channel count.
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # Optional projection that makes the shortcut match the main path.
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block used by ResNet-50/101/152.

    Main path: 1x1 conv (squeeze to ``out_channel``) -> BN -> ReLU ->
    3x3 conv (carries ``stride``) -> BN -> ReLU -> 1x1 conv (expand to
    ``out_channel * expansion``) -> BN. The shortcut is ``x`` itself, or
    ``downsample(x)`` when given. The sum goes through a final ReLU.
    """

    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        widened = out_channel * self.expansion
        # 1x1: squeeze channels
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3x3: spatial mixing; stride > 1 downsamples here
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1x1: unsqueeze (expand) channels
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=widened,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
class ResNet(nn.Module):
    """Generic ResNet backbone with an optional classification head.

    Args:
        block: residual block class (e.g. BasicBlock or Bottleneck). It must
            expose an ``expansion`` attribute and accept
            ``(in_channel, out_channel, stride=..., downsample=...)``.
        blocks_num: number of blocks in each of the four stages.
        num_classes: output size of the final fully connected layer.
        include_top: when False, no avgpool/fc is created and ``forward``
            returns the stage-4 feature map.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super().__init__()
        self.include_top = include_top
        self.in_channel = 64

        # Stem: 7x7 stride-2 conv, BN, ReLU, then 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages (layer1..layer4); stages 2-4 use stride 2.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stride) in enumerate(stage_specs):
            setattr(self, "layer{}".format(idx + 1),
                    self._make_layer(block, width, blocks_num[idx], stride=stride))

        if include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming (fan_out) initialisation for every convolution.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and advance ``self.in_channel``."""
        out_channel = channel * block.expansion
        shortcut = None
        # The first block needs a 1x1 projection when the resolution or the
        # channel count changes.
        if stride != 1 or self.in_channel != out_channel:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel))
        blocks = [block(self.in_channel, channel, downsample=shortcut, stride=stride)]
        self.in_channel = out_channel
        blocks.extend(block(self.in_channel, channel) for _ in range(1, block_num))
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        if not self.include_top:
            return x
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
def resnet34(num_classes=1000, include_top=True):
    """Build a 34-layer ResNet: BasicBlock with stage depths 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths,
                  num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
    """Build a 101-layer ResNet: Bottleneck with stage depths 3, 4, 23, 3."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes, include_top=include_top)
2.2 数据处理代码
data_utils.py
- 划分验证集训练集
import os
import json
import pickle
import random
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Split an image-folder dataset into training and validation sets.

    Every sub-directory of ``root`` is one class. Class names are sorted, so the
    class indices are stable. The mapping ``{index: class_name}`` is written to
    ``class_indices.json`` in the current working directory. For each class,
    ``int(n_images * val_rate)`` images (extensions .jpg/.JPG/.png/.PNG) are
    sampled with a fixed seed for validation; the rest go to training.

    Args:
        root: dataset root directory containing one folder per class.
        val_rate: fraction of each class's images used for validation.
        plot_image: if True, show a bar chart of per-class image counts.

    Returns:
        ``(train_images_path, train_images_label, val_images_path, val_images_label)``.

    Raises:
        AssertionError: if ``root`` does not exist.
    """
    # A private, fixed-seed RNG keeps the split reproducible without reseeding
    # the global ``random`` module that other code may depend on.
    rng = random.Random(0)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One folder per class; sorted so class indices are identical on every run.
    flower_class = sorted(cla for cla in os.listdir(root)
                          if os.path.isdir(os.path.join(root, cla)))
    class_indices = {name: idx for idx, name in enumerate(flower_class)}
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class index of each validation image
    every_class_num = []  # total number of images per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # os.listdir order is filesystem-dependent; sort it so the seeded split
        # is the same on every machine.
        images = [os.path.join(cla_path, i) for i in sorted(os.listdir(cla_path))
                  if os.path.splitext(i)[-1] in supported]
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # Set membership keeps the partitioning below O(n) rather than O(n^2).
        val_path = set(rng.sample(images, k=int(len(images) * val_rate)))
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    if plot_image:
        # Bar chart of the number of images per class.
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # Replace the x tick positions 0, 1, 2, ... with the class names.
        plt.xticks(range(len(flower_class)), flower_class)
        # Write the count above each bar.
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
    """Display up to 4 de-normalized images, labelled with class names, per batch.

    The ``{index: class_name}`` mapping is read from ``./class_indices.json``.
    ``plt.show()`` is called once per batch, so every batch is shown in turn.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches, where
            ``images[i]`` is a CHW tensor and ``labels[i]`` a scalar tensor.
            Images are assumed to be normalized with the ImageNet mean/std used
            elsewhere in this file (TODO confirm for other transforms).

    Raises:
        AssertionError: if ``./class_indices.json`` does not exist.
    """
    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for images, labels in data_loader:
        # Use the real batch length: the last batch can be shorter than
        # batch_size, and batch_size is None when a batch_sampler is used.
        plot_num = min(len(images), 4)
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize(mean, std) and rescale to 0..255.
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # hide x-axis ticks
            plt.yticks([])  # hide y-axis ticks
            # Clip before the uint8 cast so out-of-range values don't wrap around.
            plt.imshow(img.clip(0, 255).astype('uint8'))
        plt.show()
def write_pickle(list_info: list, file_name: str):
    """Serialize ``list_info`` with pickle into ``file_name``, overwriting it."""
    payload = pickle.dumps(list_info)
    with open(file_name, 'wb') as sink:
        sink.write(payload)
def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in ``file_name``.

    Security: unpickling can execute arbitrary code, so only call this on
    files this project wrote itself (e.g. via ``write_pickle``).
    """
    with open(file_name, 'rb') as source:
        raw = source.read()
    return pickle.loads(raw)
def plot_class_preds(net,
                     images_dir: str,
                     transform,
                     num_plot: int = 5,
                     device="cpu"):
    """Predict a few labelled images and draw them in a single matplotlib figure.

    ``images_dir`` must contain a ``label.txt`` whose non-empty lines are
    ``<file_name> <class_name>``, separated by whitespace. Class names are
    mapped to indices through ``./class_indices.json`` (``{"0": "daisy", ...}``).
    Lines that point to a missing image or an unknown class are skipped with a
    message, and at most ``num_plot`` entries are used. Each subplot title shows
    the predicted class, its softmax probability and the true label. The title
    is green when the prediction is correct and red otherwise.

    Args:
        net: classification ``nn.Module`` returning logits of shape
            (N, num_classes). It is run in eval mode here, and its previous
            train/eval mode is restored afterwards.
        images_dir: directory containing the images and ``label.txt``.
        transform: callable mapping a PIL RGB image to a CHW tensor, assumed
            to be ImageNet-normalized (the de-normalization below uses those
            statistics).
        num_plot: maximum number of images to plot.
        device: device the input batch is moved to before inference.

    Returns:
        The matplotlib Figure, or None if the directory or label file is
        missing or no usable entries were found.

    Raises:
        AssertionError: if ``./class_indices.json`` is missing or a label line
            does not have exactly two fields.
    """
    if not os.path.exists(images_dir):
        print("not found {} path, ignore add figure.".format(images_dir))
        return None

    label_path = os.path.join(images_dir, "label.txt")
    if not os.path.exists(label_path):
        print("not found {} file, ignore add figure".format(label_path))
        return None

    # read class_indict: {"0": "daisy"}, inverted to {"daisy": "0"}
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "not found {}".format(json_label_path)
    with open(json_label_path, 'r') as json_file:
        flower_class = json.load(json_file)
    class_indices = {v: k for k, v in flower_class.items()}

    # reading label.txt file
    label_info = []
    with open(label_path, "r") as rd:
        for line in rd:
            # split() accepts any whitespace (spaces or tabs) between fields.
            split_info = line.split()
            if not split_info:
                continue
            assert len(split_info) == 2, "label format error, expect file_name and class_name"
            image_name, class_name = split_info
            image_path = os.path.join(images_dir, image_name)
            # Skip entries whose image file does not exist.
            if not os.path.exists(image_path):
                print("not found {}, skip.".format(image_path))
                continue
            # Skip entries whose class is not in class_indices.json.
            if class_name not in class_indices:
                print("unrecognized category {}, skip".format(class_name))
                continue
            label_info.append([image_path, class_name])

    # Keep only the first num_plot entries; a non-positive num_plot leaves
    # nothing to plot (and torch.stack would fail on an empty list).
    label_info = label_info[:num_plot]
    if not label_info:
        return None
    num_imgs = len(label_info)

    images = []
    labels = []
    for img_path, class_name in label_info:
        # The context manager closes the underlying file handle.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        labels.append(int(class_indices[class_name]))
        images.append(transform(img))

    batch = torch.stack(images, dim=0).to(device)

    # Run in eval mode: in train mode BatchNorm would normalize with batch
    # statistics and update its running stats from these few images.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            output = net(batch)
    finally:
        net.train(was_training)
    probs, preds = torch.max(torch.softmax(output, dim=1), dim=1)
    probs = probs.cpu().numpy()
    preds = preds.cpu().numpy()

    # figsize is (width, height) in inches.
    fig = plt.figure(figsize=(num_imgs * 2.5, 3), dpi=100)
    for i in range(num_imgs):
        # 1 row, num_imgs columns; draw subplot i+1.
        ax = fig.add_subplot(1, num_imgs, i + 1, xticks=[], yticks=[])
        # CHW -> HWC
        npimg = batch[i].cpu().numpy().transpose(1, 2, 0)
        # Undo normalization (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]).
        npimg = (npimg * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
        # Clip before the uint8 cast so out-of-range values don't wrap around.
        ax.imshow(npimg.clip(0, 255).astype('uint8'))
        title = "{}, {:.2f}%\n(label: {})".format(
            flower_class[str(preds[i])],  # predicted class
            probs[i] * 100,  # predicted probability
            flower_class[str(labels[i])]  # true class
        )
        ax.set_title(title, color=("green" if preds[i] == labels[i] else "red"))

    return fig
- 数据转为pytorch模型加载的DataSet格式,自定义数据集
my_dataset.py
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Custom image-classification dataset built from file paths and labels.

    At construction, each image is opened once to read its size. Images with a
    width/height ratio above 10 or below 0.1 are dropped, along with their
    labels. ``__getitem__`` opens the image lazily, requires it to be in RGB
    mode, and applies ``transform`` when one is given.

    Args:
        images_path: image file paths (not modified).
        images_class: integer label for each path (not modified).
        transform: optional callable applied to the PIL image.

    Raises:
        ValueError: if ``images_path`` and ``images_class`` differ in length.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        if len(images_path) != len(images_class):
            raise ValueError("images_path and images_class must have the same length")
        self.transform = transform
        # Collect the kept samples into new lists. Popping from the arguments
        # would mutate the caller's lists and costs O(n) per removal.
        self.images_path = []
        self.images_class = []
        for img_path, img_class in tqdm(zip(images_path, images_class), total=len(images_path)):
            # Close the file handle that PIL would otherwise keep open.
            with Image.open(img_path) as img:
                w, h = img.size
            ratio = w / h
            if 0.1 <= ratio <= 10:
                self.images_path.append(img_path)
                self.images_class.append(img_class)

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' is a colour image; other modes (e.g. 'L', grayscale) are rejected.
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack ``(image_tensor, label)`` pairs into ``(images, labels)`` tensors.

        All image tensors must have the same shape. For reference, see PyTorch's
        default_collate:
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
2.3 训练模型,并保存tensorboard
将训练过程中的train_loss、accuracy、learning_rate以及相关权重保存到tensorboard中进行显示。
- 1 实例化SummaryWriter对象
from torch.utils.tensorboard import SummaryWriter
# Create the SummaryWriter. Event files are written to runs/flower_experiment;
# the directory is created automatically when this line runs.
tb_writer = SummaryWriter(log_dir="runs/flower_experiment")
- 2 添加ResNet网络的结构图
# Write the model graph to TensorBoard.
# All-zero dummy batch with the model's input shape (batch, channel, height, width).
init_img=torch.zeros((1,3,224,224),device=device)
# add_graph runs init_img through the model's forward pass and builds the graph
# from the operations executed during that pass.
tb_writer.add_graph(model,init_img)
- 3 添加mean_loss、accuracy、learning_rate
在模型训练的每个epoch中,验证完模型后,将该epoch的mean_loss(平均损失)、accuracy和learning_rate写入tensorboard中。
for epoch in range(args.epochs):
    # train for one epoch; the return value is logged below as the training loss
    mean_loss = train_one_epoch(model=model,
                                optimizer=optimizer,
                                data_loader=train_loader,
                                device=device,
                                epoch=epoch)
    # update learning rate
    scheduler.step()
    # validate
    acc = evaluate(model=model,
                   data_loader=val_loader,
                   device=device)
    # add loss, acc and lr into tensorboard
    print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
    tags = ["train_loss", "accuracy", "learning_rate"]
    tb_writer.add_scalar(tags[0], mean_loss, epoch)  # mean training loss
    tb_writer.add_scalar(tags[1], acc, epoch)  # validation accuracy
    # NOTE(review): the lr is read after scheduler.step(), so this logs the lr
    # scheduled for the next epoch rather than the one used in this epoch.
    tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
add_scalar(tag, scalar_value, global_step):tag(str)为标签名;scalar_value(float)为要记录的数值;global_step(int)为横轴上的步数,这里一般取epoch。
- 4 添加预测的图片结果到tensorboard中
# add figure into tensorboard
# Build the prediction figure for the labelled images in ./plot_img;
# plot_class_preds returns None when that folder or its label.txt is missing.
fig = plot_class_preds(net=model,
                       images_dir="./plot_img",
                       transform=data_transform["val"],
                       num_plot=5,
                       device=device)
if fig is not None:
    tb_writer.add_figure("predictions vs. actuals",
                         figure=fig,
                         global_step=epoch)
- 5 添加训练过程中的权重数据到tensorboard中
权重数据一般以直方图(histogram)的形式进行统计展示:
# add conv1 weights into tensorboard
# Log weight histograms for the stem convolution and for the first conv of the
# first block in layer1, one snapshot per epoch.
tb_writer.add_histogram(tag="conv1",
                        values=model.conv1.weight,
                        global_step=epoch)
tb_writer.add_histogram(tag="layer1/block0/conv1",
                        values=model.layer1[0].conv1.weight,
                        global_step=epoch)
完整的train.py代码如下:
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import resnet34
from my_dataset import MyDataSet
from data_utils import read_split_data, plot_class_preds
from train_eval_utils import train_one_epoch, evaluate
def main(args):
    """Train ResNet-34 on an image-folder dataset and log progress to TensorBoard.

    Each epoch logs the following under ``runs/flower_experiment``:
    - scalars: mean training loss, validation accuracy, and the learning rate
      used for that epoch;
    - a prediction figure for ``./plot_img`` (when available);
    - weight histograms of ``conv1`` and ``layer1[0].conv1``.
    A checkpoint is saved to ``./weights/model-<epoch>.pth``.

    Args:
        args: parsed namespace providing ``device``, ``data_path``,
            ``batch_size``, ``num_classes``, ``weights``, ``freeze_layers``,
            ``lr``, ``lrf`` and ``epochs``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    # Create the SummaryWriter (event files go to runs/flower_experiment).
    tb_writer = SummaryWriter(log_dir="runs/flower_experiment")
    os.makedirs("./weights", exist_ok=True)

    # Split the dataset into training and validation sets.
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    # Preprocessing: random crop + flip augmentation for training, and a
    # deterministic resize + center crop for validation (ImageNet mean/std).
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])
    val_data_set = MyDataSet(images_path=val_images_path,
                             images_class=val_images_label,
                             transform=data_transform["val"])

    batch_size = args.batch_size
    # Number of DataLoader workers. os.cpu_count() can return None, so fall back to 1.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)

    model = resnet34(num_classes=args.num_classes).to(device)

    # Write the model graph: trace one forward pass on a zero batch of shape
    # (batch, channel, height, width) = (1, 3, 224, 224).
    init_img = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_img)

    # Load pre-trained weights if the file exists.
    if os.path.exists(args.weights):
        weights_dict = torch.load(args.weights, map_location=device)
        model_state = model.state_dict()  # build once, not once per key
        # Keep only tensors whose key exists in this model and whose shape
        # matches, e.g. dropping a classifier head trained for another number
        # of classes. A missing key previously raised KeyError.
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if k in model_state and model_state[k].shape == v.shape}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        print("not using pretrain-weights.")

    # Optionally freeze every parameter except those of the final fc layer.
    if args.freeze_layers:
        print("freeze layers except fc layer.")
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)

    def lf(x):
        """Cosine decay of the lr factor from 1 to ``args.lrf`` over ``args.epochs``
        (https://arxiv.org/pdf/1812.01187.pdf)."""
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    try:
        for epoch in range(args.epochs):
            # Read the lr before scheduler.step() so the logged value is the
            # one actually used during this epoch.
            current_lr = optimizer.param_groups[0]["lr"]
            # train
            mean_loss = train_one_epoch(model=model,
                                        optimizer=optimizer,
                                        data_loader=train_loader,
                                        device=device,
                                        epoch=epoch)
            # update learning rate
            scheduler.step()
            # validate
            acc = evaluate(model=model,
                           data_loader=val_loader,
                           device=device)

            # add loss, acc and lr into tensorboard
            print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
            tags = ["train_loss", "accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], acc, epoch)
            tb_writer.add_scalar(tags[2], current_lr, epoch)

            # add figure into tensorboard
            fig = plot_class_preds(net=model,
                                   images_dir="./plot_img",
                                   transform=data_transform["val"],
                                   num_plot=5,
                                   device=device)
            if fig is not None:
                tb_writer.add_figure("predictions vs. actuals",
                                     figure=fig,
                                     global_step=epoch)

            # add conv1 weights into tensorboard
            tb_writer.add_histogram(tag="conv1",
                                    values=model.conv1.weight,
                                    global_step=epoch)
            tb_writer.add_histogram(tag="layer1/block0/conv1",
                                    values=model.layer1[0].conv1.weight,
                                    global_step=epoch)

            # save weights
            torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
    finally:
        # Flush pending events to disk and release the event file.
        tb_writer.close()
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no", ...).

        ``type=bool`` is wrong for argparse: bool("False") is True, so any
        non-empty string would enable the flag.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.1)

    # Dataset root directory
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    img_root = "/home/wz/my_project/my_github/data_set/flower_data/flower_photos"
    parser.add_argument('--data-path', type=str, default=img_root)

    # Official resnet34 weights download URL
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    parser.add_argument('--weights', type=str, default='resNet34.pth',
                        help='initial weights path')
    # Accepts "--freeze-layers", "--freeze-layers True" and "--freeze-layers False".
    parser.add_argument('--freeze-layers', type=_str2bool, nargs='?', const=True, default=False)
    # The value is passed straight to torch.device(), so it must be a device string.
    parser.add_argument('--device', default='cuda', help='device (e.g. cuda, cuda:0 or cpu)')

    opt = parser.parse_args()
    main(opt)
- 6执行训练代码
python train.py  # start training; TensorBoard event files are written to runs/flower_experiment
- 训练结束后,在命令行中切换到runs目录,执行如下命令:
tensorboard --logdir=./ --samples_per_plugin=images=50  # keep up to 50 image samples per tag in the IMAGES tab
执行命令后,终端会输出tensorboard的访问网址:
在浏览器中访问:http://localhost:6006,就可以看到训练工程中保存的一系列数据了。
SCALARS模块展示训练过程中每个epoch的train_loss、accuracy、learning_rate的数值变化
IMAGES模块显示图片的预测结果
GRAPHS模块展示的是模型的网络结构
HISTOGRAMS模块展示添加到tensorboard中的各层权重分布情况
3 参考
源代码deep-learning-for-image-processing 中的tensorboard_test项目
Resnet:https://www.cnblogs.com/yanshw/p/10576354.html