目录
参考资料
李沐《动手学深度学习》
github_deep_learning_500问(图像分割篇)
Pascal VOC 2007(数据集)
还有许多博客,就不一一列出了
代码
VOC2007Dataset.py
import torch
import torchvision
from PIL import Image
import numpy as np
#颜色标签空间转到序号标签空间
def voc_label_indices(colormap, colormap2label):
"""
convert colormap (PIL image) to colormap2label (uint8 tensor).
"""
colormap = np.array(colormap.convert("RGB")).astype('int32')
idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
+ colormap[:, :, 2])
return colormap2label[idx]
# Read the VOC train/val split: parallel lists of image/label PIL images.
def read_voc_images(root="./data/VOCdevkit/VOC2007",
                    is_train=True, max_num=None):
    """Load up to ``max_num`` (image, label) PIL-image pairs for one split."""
    split = 'train.txt' if is_train else 'val.txt'
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, split)
    with open(txt_fname, 'r') as f:
        names = f.read().split()
    if max_num is not None:
        names = names[:min(max_num, len(names))]
    features, labels = [], []
    for fname in names:
        features.append(Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB"))
        labels.append(Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB"))
    return features, labels  # PIL images
# Random crop: cut a fixed-size window from the image and the SAME window
# from its label so that pixels stay aligned.
def voc_rand_crop(feature, label, height, width):
    """Randomly crop ``feature`` and ``label`` (PIL images) to (height, width)."""
    # Draw one random window, then apply it to both images.
    top, left, crop_h, crop_w = torchvision.transforms.RandomCrop.get_params(
        feature, output_size=(height, width))
    cropped_feature = torchvision.transforms.functional.crop(feature, top, left, crop_h, crop_w)
    cropped_label = torchvision.transforms.functional.crop(label, top, left, crop_h, crop_w)
    return cropped_feature, cropped_label
# Custom segmentation dataset, subclass of torch.utils.data.Dataset.
class VOCSegDataset(torch.utils.data.Dataset):
    # A Dataset subclass must implement __init__, __getitem__ and __len__.
    def __init__(self, is_train, crop_size, voc_dir, colormap2label,
                 max_num=None, num_classes=21):
        """
        Args:
            is_train: load the train split if True, else the val split.
            crop_size: (h, w) of the random crop returned per sample.
            voc_dir: root of the VOC2007 directory tree.
            colormap2label: uint8 tensor mapping packed RGB -> class index.
            max_num: optional cap on the number of images read.
            num_classes: number of classes in the one-hot target
                (default 21 = 20 VOC classes + background; previously
                hard-coded, now a parameter for reuse on other label sets).
        """
        # Mean/std normalization required by torchvision's pretrained models.
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])
        self.tsf = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.rgb_mean,
                                             std=self.rgb_std)
        ])
        self.crop_size = crop_size  # (h, w)
        self.num_classes = num_classes
        features, labels = read_voc_images(root=voc_dir,
                                           is_train=is_train,
                                           max_num=max_num)
        # Keep only images at least as large as the crop window.
        self.features = self.filter(features)  # PIL image
        self.labels = self.filter(labels)      # PIL image
        self.colormap2label = colormap2label
        print('read ' + str(len(self.features)) + ' valid examples')

    # Drop images smaller than the crop size (PIL .size is (w, h)).
    def filter(self, imgs):
        return [img for img in imgs if (
            img.size[1] >= self.crop_size[0] and
            img.size[0] >= self.crop_size[1])]

    # Return one (normalized image tensor, one-hot target) pair.
    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        label = voc_label_indices(label, self.colormap2label).numpy().astype('uint8')
        # Build a one-hot ground truth of shape (num_classes, h, w).
        h, w = label.shape
        target = torch.zeros(self.num_classes, h, w)
        for c in range(self.num_classes):
            target[c][label == c] = 1
        return (self.tsf(feature), target)

    # Number of usable (large enough) samples.
    def __len__(self):
        return len(self.features)
# NOTE(review): the triple-quoted block below is a commented-out Dataset
# skeleton kept as a reference template; it is a bare string literal and
# is never executed.
'''
class FirstDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. 初始化文件路径或文件名列表。
#也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
pass
def __getitem__(self, index):
# TODO
#1。从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
#2。预处理数据(例如torchvision.Transform)。
#3。返回数据对(例如图像和标签)。
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# 您应该将0更改为数据集的总大小。
'''
# Data iterators: build the VOC2007 train/val DataLoaders.
def VOC2007SegDataIter(batch_size=64, crop_size=(320, 480), num_workers=4, max_num=None):
    """Return (train_iter, val_iter) DataLoaders over VOC2007 segmentation data."""
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                    [0, 64, 128]]
    # Class names aligned with VOC_COLORMAP indices (kept for reference).
    VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
    # Lookup table from packed RGB value (R*65536 + G*256 + B) to class index.
    colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
    for idx, (r, g, b) in enumerate(VOC_COLORMAP):
        colormap2label[(r * 256 + g) * 256 + b] = idx
    voc_train = VOCSegDataset(True, crop_size, "./data/VOCdevkit/VOC2007", colormap2label, max_num)
    voc_val = VOCSegDataset(False, crop_size, "./data/VOCdevkit/VOC2007", colormap2label, max_num)
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True, drop_last=True,
                                             num_workers=num_workers)
    val_iter = torch.utils.data.DataLoader(voc_val, batch_size, drop_last=True, num_workers=num_workers)
    return train_iter, val_iter
train.py
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 20 23:21:02 2020
@author: ZLH
"""
from torch import optim
import torch.nn as nn
from tqdm import tqdm
from VOC2007Dataset import VOC2007SegDataIter
import torch
import numpy as np
from torchvision import models
# Pick the device: CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 20 foreground classes + 1 background class.
num_classes = 21
# Build a bilinear-interpolation kernel to initialize a transposed convolution.
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return an (in_channels, out_channels, k, k) float32 tensor whose
    channel-diagonal slices hold a 2-D bilinear upsampling filter."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    # Separable triangular (bilinear) weights, peaking at the kernel center.
    filt = (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)
def train_net(net, device, batch_size=1, epochs=40, lr=0.00001):
    """Train ``net`` on VOC2007 with SGD; the two new head layers get a
    10x larger learning rate than the pretrained backbone."""
    # Data loaders: 320x480 crops, 0 workers, at most 200 images per split.
    train_iter, val_iter = VOC2007SegDataIter(batch_size, (320, 480), 0, 200)
    # Separate the pretrained backbone parameters from the two new head layers.
    head_param_ids = list(map(id, net[-1].parameters())) + list(map(id, net[-2].parameters()))
    backbone_params = filter(lambda p: id(p) not in head_param_ids, net.parameters())
    optimizer = optim.SGD([{'params': backbone_params},
                           {'params': net[-2].parameters(), 'lr': lr * 10},
                           {'params': net[-1].parameters(), 'lr': lr * 10}],
                          lr=lr, weight_decay=0.001)
    # Per-pixel, per-class sigmoid + binary cross-entropy against the one-hot target.
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(epochs):
        net.train()
        for image, label in tqdm(train_iter):
            optimizer.zero_grad()
            # Move the batch onto the training device.
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            loss.backward()
            optimizer.step()
    # Persist the final weights.
    torch.save(net.state_dict(), 'best_model2.pth')
if __name__ == '__main__':
    # Backbone: pretrained ResNet-18 with its avgpool and fc layers removed.
    resnet18 = models.resnet18(pretrained=True)
    net = nn.Sequential()
    for idx, layer in enumerate(list(resnet18.children())[:-2]):
        net.add_module(str(idx), layer)
    # 1x1 conv projects the 512-channel feature map to per-class score maps.
    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    # Transposed conv upsamples the score maps 32x back to input resolution.
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))
    # Fixed bilinear weights for the upsampling layer,
    # Xavier-uniform weights for the projection layer.
    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)
    # Move the network onto the chosen device before training.
    net.to(device=device)
    train_net(net, device, batch_size=4, epochs=5, lr=0.1)
predict.py
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 20 22:00:29 2020
@author: ZLH
"""
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
import numpy as np
from VOC2007Dataset import VOC2007SegDataIter
from torch import nn
import os
# Pick the device: CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 20 foreground classes + 1 background class.
num_classes = 21
# Bilinear upsampling kernel for initializing ConvTranspose2d weights.
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return an (in_channels, out_channels, k, k) float32 weight tensor
    whose (i, i) channel slices contain a 2-D bilinear filter."""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    # 1-D triangular profile; the 2-D filter is its outer product.
    coords = np.arange(kernel_size)
    one_d = 1 - np.abs(coords - center) / factor
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = np.outer(one_d, one_d)
    return torch.tensor(weight)
def decode_segmap(image, nc=21):
    """Colorize a 2-D array of class indices with the Pascal VOC palette.

    Returns an (H, W, 3) uint8 RGB image; labels outside [0, nc) stay black.
    """
    # RGB color for each class index.
    label_colors = np.array([(0, 0, 0), # 0=background
                             # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                             (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                             # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                             (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                             # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                             (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                             # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                             (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
    # Fill the three channels class by class.
    channels = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for cls in range(nc):
        mask = image == cls
        for ch in range(3):
            channels[ch][mask] = label_colors[cls, ch]
    return np.stack(channels, axis=2)
if __name__ == '__main__':
    batch_size = 4
    # Rebuild the exact architecture from train.py: pretrained ResNet-18
    # without avgpool/fc, plus a 1x1 class projection and a 32x upsampler.
    resnet18 = models.resnet18(pretrained=True)
    resnet18_modules = [layer for layer in resnet18.children()]
    net = nn.Sequential()
    for i, layer in enumerate(resnet18_modules[:-2]):
        net.add_module(str(i), layer)
    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))
    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)
    PATH = './best_model2.pth'
    # map_location so weights trained on a GPU still load on a CPU-only machine.
    net.load_state_dict(torch.load(PATH, map_location=device))
    net.to(device=device)
    # Inference mode; hoisted out of the per-image loop (it is loop-invariant).
    net.eval()
    img_root = './input'
    pred_root = './output'
    # Preprocessing expected by torchvision's pretrained models
    # (fixed ImageNet mean/std); also loop-invariant, built once.
    trf = T.Compose([T.Resize(256),
                     T.CenterCrop(224),
                     T.ToTensor(),
                     T.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
    # Run inference on every image under img_root.
    for file in os.listdir(img_root):
        filesname = img_root + '/' + file
        print(filesname)
        # convert("RGB") guards against grayscale/RGBA inputs, which would
        # break the 3-channel Normalize above.
        img = Image.open(filesname).convert("RGB")
        plt.imshow(img)
        plt.show()
        inp = trf(img).unsqueeze(0).to(device)
        with torch.no_grad():
            out = net(inp)
        # Per-pixel argmax over the class dimension -> (H, W) index map.
        pred = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
        # BUG FIX: the colorized map was computed but never used — the raw
        # index map was shown and saved. Display and save the decoded RGB map.
        rgb = decode_segmap(pred)
        plt.imshow(rgb)
        plt.show()
        plt.imsave(pred_root + '/' + file, rgb)
效果对比
(复现效果属实不太行 O^O)