flowers 数据集分类 vgg19 微调网络 保存为pb格式

原文链接: flowers 数据集分类 vgg19 微调网络 保存为pb格式

上一篇: tfjs vue 风格迁移 展示

下一篇: mobilenet tfjs 分类网络 vue包装 保存后本地加载后使用

vgg19 下载

https://github.com/tensorflow/models/tree/master/research/slim/#Pretrained

flowers数据集下载,包含原始record,提取的图片,以及生成的

链接:https://pan.baidu.com/s/1Aa3tk1VCV5vimPLqCXXA_Q
提取码:48u0
d9585d5126053e51905337586ca91ef6c7f.jpg

jpg格式转为record格式

文件夹结构

f97dca42434f3a8a7574df3cabde840f464.jpg

转为record格式

将分类放好的图片按照指定大小统一转化为record格式

注意需要将顺序打乱

import tensorflow as tf
import os
from PIL import Image
import numpy as np

IMAGE_SIZE = 224


# 把传入的value转化为整数型的属性,int64_list对应着 tf.train.Example 的定义
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 把传入的value转化为字符串型的属性,bytes_list对应着 tf.train.Example 的定义
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# image_dir 为图像存放路径,文件夹内含有0->n个类别的图像
# out_path 为输出record路径
def save(image_dir, out_path, n_class=5):
    # 包含两个信息,分类和图片地址
    image_paths = [
        [label, os.path.join(image_dir, str(label), name)]
        for label in range(n_class)
        for name in os.listdir(os.path.join(image_dir, str(label)))
    ]

    #  打乱顺序后写入record文件
    image_paths = np.array(image_paths)
    np.random.shuffle(image_paths)

    with tf.python_io.TFRecordWriter(out_path) as writer:
        for label, path in image_paths:
            # 因为含有字符串,所以会被自动转为字符串,需要再次转为int
            label = int(label)
            # print(label, path)

            # 读取图片矩阵,统一大小,转化为二进制格式
            img = Image.open(path).resize((IMAGE_SIZE, IMAGE_SIZE)).tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
            }))  # example对象对label和image数据进行封装
            writer.write(example.SerializeToString())  # 序列化为字符串


def main():
    train_dir = r'D:\data\flowers\jpg\train'
    train_record = r'D:\data\flowers\train.record'
    test_dir = r'D:\data\flowers\jpg\test'
    test_record = r'D:\data\flowers\test.record'
    save(train_dir, train_record)
    save(test_dir, test_record)


if __name__ == '__main__':
    main()

数据集有16600张,大小计算如下,uint占用一个字节

可以看到标签大小对数据集影响较小


print(16600 * 224 * 224 * 3 / 1024 ** 3)
print((16600 * 2 + 16600 * 224 * 224 * 3) / 1024 ** 3)

1a818d635a409de6262cd82824c7dc34891.jpg

构建分类网络,进行训练

然后保存为pb格式供其他程序使用

960最多可可以用批次大小为48,64会OOM

训练一次大概一秒左右,1000次后效果可以达到百分之九十以上

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.python.framework import graph_util
import time

TRAIN_STEP = 1000
SHOW_STEP = 10
# 批次太大会OOM
# batch_size = 64
BATCH_SIZE = 32
IMAGE_SIZE = 224
CAPACITY = 256
MIN_AFTER = 128
TEST_SIZE = 32
NUM_THREADS = 4
N_CLASS = 5
LR = .001

TRAIN_RECORD = r"D:\data\flowers\train.record"
TEST_RECORD = r&#
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值