基于pytorch的FCN网络简单实现

参考知乎专栏实现FCN网络https://zhuanlan.zhihu.com/p/32506912

import torch
from torch import nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime

使用的数据集是VOC数据集,我们先读取数据

voc_root = "./data/VOC2012"

"""
读取图片
图片的名称在/ImageSets/Segmentation/train.txt ans val.txt里
如果传入参数train为True,则读取train.txt的内容,否则读取val.txt的内容
图片都在./data/VOC2012/JPEGImages文件夹下面,需要在train.txt读取的每一行后面加上.jpg
标签都在./data/VOC2012/SegmentationClass文件夹下面,需要在读取的每一行后面加上.png
最后返回记录图片路径的集合data和记录标签路径集合的label

"""
def read_images(root=voc_root, train=True):
    txt_fname = root + '/ImageSets/Segmentation/' + ('train.txt' if train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    data = [os.path.join(root, 'JPEGImages', i+'.jpg') for i in images]
    label = [os.path.join(root, 'SegmentationClass', i+'.png') for i in images]
    return data, label

先来看一下数据长什么样子

data, label = read_images(voc_root)
im = Image.open(data[0])
plt.subplot(2,2,1)
plt.imshow(im)
im = Image.open(label[0])
plt.subplot(2,2,2)
plt.imshow(im)
im = Image.open(data[1])
plt.subplot(2,2,3)
plt.imshow(im)
im = Image.open(label[1])
plt.subplot(2,2,4)
plt.imshow(im)
plt.show()

在这里插入图片描述

可以发现,图片的尺寸不固定,但是我们输入网络的尺寸必须是固定的,而且必须保证data和label相对应的位置相同,所以我们需要写一个函数随机剪裁图片以适应网络输入的大小,并且data和label剪裁的位置要相同。

"""
切割函数,默认都是从图片的左上角开始切割
切割后的图片宽为width,长为height
"""
def crop(data, label, height, width):
    """
    data和lable都是Image对象
    """
    box = (0, 0, width, height)
    data = data.crop(box)
    label = label.crop(box)
    return data, label
im = Image.open(data[0])
la = Image.open(label[0])
plt.subplot(2,2,1), plt.imshow(im)
plt.subplot(2,2,2), plt.imshow(la)
im, la = crop(im, la, 224, 224)
plt.subplot(2,2,3), plt.imshow(im)
plt.subplot(2,2,4), plt.imshow(la)
plt.show()

在这里插入图片描述

下面我们需要将标签和像素点颜色之间建立映射关系

# VOC数据集中对应的标签
classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']

# 各种标签所对应的颜色
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

因为图片是三通道的,并且每一个通道都有0-255一共256中取值,所以我们初始化一个256^3大小的数组就可以做映射了

cm2lbl = np.zeros(256**3)

# 枚举的时候i是下标,cm是一个三元组,分别标记了RGB值
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0]*256 + cm[1])*256 + cm[2]] = i

# 将标签按照RGB值填入对应类别的下标信息
def image2label(im):
    data = np.array(im, dtype="int32")
    idx = (data[:,:,0]*256 + data[:,:,1])*256 + data[:,:,2]
    return np.array(cm2lbl[idx], dtype="int64")
im = Image.open(label[20]).convert("RGB")
label_im = image2label(im)
plt.imshow(im)
plt.show()
label_im[100:110, 200:210]

在这里插入图片描述

array([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]], dtype=int64)

我们可以看到在截取出来的小区域内都是标记为3的像素点,通过标签列表,我们发现下标为3指示的是bird类
下面定义数据和标签的预处理函数

def image_transforms(data, label, height, width):
    data, label = crop(data, label, height, width)
    # 将数据转换成tensor,并且做标准化处理
    im_tfs = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    data = im_tfs(data)
    label = image2label(label)
    label = torch
评论 38
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值