Caffe2 Study Notes 03: From Images to an LMDB Dataset

1. Preface

This article uses a Caffe2 model for recognizing Chinese characters as its running example;

2. Importing the required modules

If the cell prints "Required modules imported." the imports succeeded; if a module is reported as missing, search for how to install it;

# -*- coding: UTF-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

%matplotlib inline
import os
import sys
import math
import argparse

import numpy as np
import skimage
import skimage.io as io
import skimage.transform
from matplotlib import pyplot
import matplotlib.image as mpimg

import lmdb
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper
print("Required modules imported.")

3. Preparation

Set the data path, define the label lookup table, and cap the maximum size of the output file.

path = "/home/hw/H/00_dataOfPlate/15_hanzi/01_new_chn/train" #数据路径
sep = os.path.sep #当前系统(linux)路径分隔符

chn = ["beijing", "tianjin", "hebei", "shanxi", "neimenggu", "liaoning", "jilin", "heilongjiang", "shanghai", "jiangsu", "zhejiang", "anhui", "fujian", "jiangxi", "shandong", "henan", "hubei", "hunan", "guangdong", "guangxi", "hainan", "sichuan", "guizhou", "yunnan", "chongqin", "xizang", "shengxi", "gansu", "qinghai", "ningxia", "xinjiang"] # No. 31 # 标签对应表

LMDB_MAP_SIZE =  1099511627776  #max output file < 1TB
print("prepared") 最大输出文件大小

4. The write function: read the image files and labels and convert them into an LMDB file

Directory layout, using the train folder as an example: train contains 26 letter subfolders (one per class), and the label of each image is taken from the folder that contains it, as in the tree and table below;
- train
  - A
    - 0001.bmp
    - 0002.bmp
    - …
    - 4000.bmp
  - B
    - 0001.bmp
    - …
    - 4100.bmp
  - …
  - …

| First-level directory | Second-level directory | Image |
| --- | --- | --- |
| train | A | 1022.bmp |
| train | A | … |
| train | A | 4032.bmp |
| train | B | 1022.bmp |
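
The batch_size used for verification in Section 5 must equal the total number of images, so it can help to count them from this layout first; a small sketch that only assumes path, sep and the folder structure above:

# Count the images per class folder; the grand total is the batch_size
# expected by read_db_with_caffe2 in Section 5.
total = 0
for dirs in sorted(os.listdir(path)):
    n = len(os.listdir(path + sep + dirs))
    print("{}: {} images".format(dirs, n))
    total += n
print("total: {} images".format(total))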
def write_db_with_caffe2(output_file):
    print(">>> Write database ...")
    LMDB_MAP_SIZE = 1099511627776
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)
    checksum = 0    # label-weighted checksum, compared against the read step
    checksumm = 0   # plain pixel-sum checksum, printed for reference
    j = 0           # running index, used as the LMDB key

    with env.begin(write=True) as txn:
        for dirs in os.listdir(path):
            new_path = path + sep + dirs
            label = chn.index(dirs)  # label = position of the folder name in chn
            for pics in os.listdir(new_path):
                pic_path = new_path + sep + pics
                img_data = skimage.img_as_float(skimage.io.imread(pic_path)).astype(np.float32)
                print("before: {}".format(img_data.shape))

                img_data = img_data[:, :, :1]  # keep only the first channel (3 channels -> 1)
                img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)  # HWC -> CHW
                print("after: {}".format(img_data.shape))

                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = 1  # TensorProto.FLOAT

                flatten_img = img_data.reshape(np.prod(img_data.shape))
                print("after: {}".format(flatten_img.shape))
                img_tensor.float_data.extend(flatten_img.flat)

                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = 2  # TensorProto.INT32
                label_tensor.int32_data.append(label)
                txn.put('{}'.format(j).encode('ascii'), tensor_protos.SerializeToString())

                checksum += np.sum(img_data) * label
                checksumm += np.sum(img_data)
                j += 1
        print("Checksum/write: {}".format(int(checksum)))
        print("Checksumm/write: {}".format(int(checksumm)))

5. The read function: read back the LMDB file and verify it (this step is optional)

Arguments of read_db_with_caffe2(db_file, expected_checksum):
db_file: path to the database file
expected_checksum: the expected checksum, which should match the value printed by write_db_with_caffe2

def read_db_with_caffe2(db_file, expected_checksum):

    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    batch_size = 744000  # total number of samples; must be exact, otherwise the
                         # verification fails ("Read/write checksums dont match")
    data, label = model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")

    checksum = 0
    checksumm = 0
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    for _ in range(0, 1):
        workspace.RunNet(model.net.Proto().name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]
            checksumm += np.sum(img_datas[j, :])
    print("Checksum/read: {}".format(int(checksum)))
    print("minus of read and write: {}".format(np.abs(expected_checksum - checksum)))
    assert np.abs(expected_checksum - checksum) < 0.1, \
        "Read/write checksums dont match"

6. Run

This takes quite a while, so be patient: converting the 744,000 grayscale 20*20 images took about twenty minutes, and reading the db back for verification froze my machine, o(╯□╰)o;

write_db_with_caffe2("./chn_db") 
read_db_with_caffe2("./chn_db", 640020532) #640020532为校验值,应该等于write中输出的checksum大小

7. Errors you may run into

1. The CHW vs. HWC problem:

input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1
Cause: by default images are read with shape HWC (height/width/channels), while Caffe2 expects image data in CHW order, so the axes must be swapped; without the conversion the following error is raised:

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1 Error from operator: 
input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"

Fix: when converting the images to the LMDB file, add img_data = img_data.swapaxes(1, 2).swapaxes(0, 1) (already included in the code above).
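
For reference, the two successive swapaxes calls are equivalent to a single np.transpose with axis order (2, 0, 1); a tiny NumPy check:

# HWC -> CHW: two swapaxes calls equal one transpose(2, 0, 1).
hwc = np.arange(20 * 20 * 3, dtype=np.float32).reshape(20, 20, 3)  # H, W, C
chw = hwc.swapaxes(1, 2).swapaxes(0, 1)  # as in write_db_with_caffe2
print(chw.shape)  # (3, 20, 20)
print(np.array_equal(chw, np.transpose(hwc, (2, 0, 1))))  # True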

2. The channel-count mismatch problem:

input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1
Cause: images are read as three-channel arrays even when they are grayscale, so after the HWC -> CHW conversion from point 1 above the data still has 3 channels, which does not match the single input channel of the LeNet model from the MNIST example, producing the error below:

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1 Error from operator:
input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"

Fix: when converting the images to the LMDB file, add img_data = img_data[:,:,:1] (already included in the code above).
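
A quick way to see what the slice does, plus a possible alternative using skimage.color.rgb2gray, which takes a weighted combination of the channels instead of discarding two of them; whether that matters depends on the data, so treat it as an option rather than part of the original recipe:

import skimage.color

# img_data[:, :, :1] keeps only the first channel: (H, W, 3) -> (H, W, 1).
rgb = np.ones((20, 20, 3), dtype=np.float32)
print(rgb[:, :, :1].shape)  # (20, 20, 1)

# Alternative: convert to true grayscale, then restore the channel axis.
gray = skimage.color.rgb2gray(rgb)  # shape (20, 20)
print(gray[:, :, np.newaxis].shape)  # (20, 20, 1)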
