【TensorFlow2.0】训练本地图片进行手写体数字识别

33 篇文章 0 订阅

配置tensorflow-gpu环境

首先，配置 tensorflow-gpu 环境（不使用 GPU 环境时运行会出错）。
我的环境配置:
python3.7.9
cuda_10.0.130_411.31_win10
cudnn-10.0-windows10-x64-v7.5.0.56
配置步骤

#coding=utf-8

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import datetime
import os
import cv2

# Select which GPU TensorFlow may see. setdefault keeps a value the user has
# already exported in the shell instead of silently overriding it.
# NOTE(review): "2" is the third GPU index; on a machine with fewer GPUs this
# hides them all and TensorFlow runs on CPU -- confirm the intended index.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# Directory that holds the generated TFRecord files.
data_path = "data"

# Source image folders. write_record derives each label from the first
# character of the file name.
image_train_path = "train_pic"
image_test_path = "test_pic"

# TFRecord outputs, placed inside data_path.
tfRecord_train = os.path.join(data_path, "train_pic.tfrecords")
tfRecord_test = os.path.join(data_path, "test_pic.tfrecords")

# Intended network input size.
# NOTE(review): preprocess currently hard-codes its own target size of 32.
resize_height = 32
resize_width = 32

def image_example(image_string, label, image_tostring):
    """Build a ``tf.train.Example`` describing one image.

    Args:
        image_string: raw bytes of the BMP file; decoded only to read its
            height and width.
        label: integer class label.
        image_tostring: pixel bytes stored verbatim under the ``'image'`` key.

    Returns:
        A ``tf.train.Example`` with ``height``, ``width``, ``label`` (int64)
        and ``image`` (bytes) features.
    """
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    decoded_shape = tf.image.decode_bmp(image_string).shape
    pixel_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_tostring]))

    features = {
        'height': _int64_feature(decoded_shape[0]),
        'width': _int64_feature(decoded_shape[1]),
        'label': _int64_feature(label),
        'image': pixel_feature,
    }
    return tf.train.Example(features=tf.train.Features(feature=features))
def write_record(tf_record_name, image_path, label_data):
    """Serialize the listed images into a single TFRecord file.

    Args:
        tf_record_name: output ``.tfrecords`` path.
        image_path: directory that contains the image files.
        label_data: iterable of file names inside ``image_path``. Each
            image's label is the integer value of the first character of
            its file name (e.g. ``"9_11.bmp"`` -> 9).
    """
    print ("center".center(10,"-"))
    pic_id = 0
    # Context managers guarantee that the writer and every opened image are
    # closed, even if one file fails to read or decode.
    with tf.io.TFRecordWriter(tf_record_name) as writer:
        for con in label_data:
            filename = os.path.join(image_path, con)

            with open(filename, 'rb') as image_file:
                image_string = image_file.read()
            with Image.open(filename) as img:
                # tobytes() forces the pixel data to load before the file closes.
                image_tostring = img.tobytes()

            labels = int(con[0])
            tf_example = image_example(image_string, labels, image_tostring)
            writer.write(tf_example.SerializeToString())

            pic_id += 1
            if pic_id % 100 == 0:
                print ("the pic is:", pic_id)
    print ("The data is writed successful!")
def get_label_data(img_path):
    """List the image file names found directly inside ``img_path``.

    Only regular files are returned, because sub-directories would make
    ``write_record`` fail when it tries to open them. The names are sorted,
    so the generated TFRecord has the same record order on every platform
    (``os.listdir`` order is arbitrary).

    Args:
        img_path: directory to scan.

    Returns:
        A sorted list of bare file names (not full paths). Returns ``None``
        after printing ``"input data error!"`` when ``img_path`` does not
        exist.
    """
    if not os.path.exists(img_path):
        print ("input data error!")
        return None
    return sorted(
        name for name in os.listdir(img_path)
        if os.path.isfile(os.path.join(img_path, name))
    )
    
def read_and_decode(example_string):
    """Parse one serialized ``tf.train.Example`` as written by ``image_example``.

    Args:
        example_string: scalar string tensor holding one serialized Example.

    Returns:
        A dict with ``"image"``, a float32 tensor of shape (height, width)
        holding the raw pixel values (0-255), and ``"label"``, an int64
        scalar.
    """
    feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string)
        }

    feature_dict = tf.io.parse_single_example(example_string, feature_description)
    image = tf.io.decode_raw(feature_dict['image'], out_type=tf.uint8)
    # Use the dimensions stored in the record instead of a hard-coded 28x28,
    # so other image sizes decode too.
    # NOTE(review): assumes one byte per pixel (single-channel images); RGB
    # BMPs would store 3 bytes per pixel and fail this reshape.
    height = tf.cast(feature_dict['height'], tf.int32)
    width = tf.cast(feature_dict['width'], tf.int32)
    image = tf.cast(tf.reshape(image, tf.stack([height, width])), dtype=tf.float32)
    # 'label' is already parsed as int64, so no cast is needed.
    return {"image": image, "label": feature_dict['label']}


def read_record(tfRecord_train, batch_size=32, shuffle_buffer=100):
    """Build an endless, shuffled, batched training pipeline from a TFRecord file.

    Args:
        tfRecord_train: path of the TFRecord file to read.
        batch_size: number of examples per batch (default 32).
        shuffle_buffer: size of the shuffle buffer (default 100).

    Returns:
        A ``tf.data.Dataset`` that yields ``(images, labels)`` batches forever.
        Callers must bound training with ``steps_per_epoch``.
    """
    dataset = tf.data.TFRecordDataset(tfRecord_train)
    dataset = dataset.map(read_and_decode)  # parse records
    # Shuffle before repeating, so each pass is reshuffled and the buffer never
    # mixes records from different passes. A single repeat() is enough for an
    # endless stream. Repeating before batching keeps every batch full.
    dataset = dataset.shuffle(shuffle_buffer).map(preprocess).repeat().batch(batch_size)
    # Prepare the next batches while the model trains on the current one.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def preprocess(x, target_size=32):
    """Turn one decoded record into a normalized ``(image, label)`` pair.

    Args:
        x: dict with ``"image"``, a 2-D float tensor of raw pixel values
            (0-255), and ``"label"``. The input dict is not modified.
        target_size: side length of the square output image (default 32).

    Returns:
        ``(image, label)``, where ``image`` is float32 with shape
        ``(target_size, target_size, 1)``, scaled to [0, 1].
    """
    image = tf.expand_dims(x['image'], axis=-1)  # add the channel axis
    image = tf.image.resize(image, [target_size, target_size])
    # Scale to [0, 1] so training inputs match digit_predict, which divides
    # the inference image by 255.
    return image / 255.0, x['label']

def _build_lenet():
    """Return an uncompiled LeNet-5 style CNN for 32x32x1 inputs and 10 classes."""
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), padding='valid',
                               activation=tf.nn.relu, input_shape=(32, 32, 1)),
        tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
        # input_shape only matters on the first layer of a Sequential model.
        tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), padding='valid',
                               activation=tf.nn.relu),
        tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=120, activation=tf.nn.relu),
        tf.keras.layers.Dense(units=84, activation=tf.nn.relu),
        tf.keras.layers.Dense(units=10, activation=tf.nn.softmax),
    ])


def _write_datasets():
    """Convert the train and test image folders into their TFRecord files."""
    for image_path, record_path in ((image_train_path, tfRecord_train),
                                    (image_test_path, tfRecord_test)):
        label_data = get_label_data(image_path)
        if label_data is None:
            # Fail with a clear message instead of a TypeError on len(None).
            raise FileNotFoundError("image folder not found: %s" % image_path)
        print ("pic_num:", len(label_data))
        write_record(record_path, image_path, label_data)


def _train_model():
    """Train LeNet on the training TFRecord and save it as leNet_model.h5."""
    batch = read_record(tfRecord_train)  # read_record sets the batch size
    model = _build_lenet()
    model.summary()

    # Training schedule actually passed to fit(). read_record yields an
    # endless stream, so steps_per_epoch bounds each epoch.
    num_epochs = 100
    steps_per_epoch = 100
    lr = 0.0001  # learning rate

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )

    start_time = datetime.datetime.now()
    model.fit(batch, epochs=num_epochs, steps_per_epoch=steps_per_epoch)
    time_cost = datetime.datetime.now() - start_time
    print('time_cost: ', time_cost)
    # NOTE(review): tfRecord_test is written but never used for evaluation.
    model.save('leNet_model.h5')


def generate_tfRecord(write_data=False):
    """Either build the TFRecord datasets or train the model on them.

    Args:
        write_data: if True, convert ``image_train_path`` / ``image_test_path``
            into ``tfRecord_train`` / ``tfRecord_test``. Otherwise, train LeNet
            on ``tfRecord_train`` and save it. Defaults to False, which is the
            behavior of the original hard-coded flag.

    Raises:
        FileNotFoundError: when writing and an image folder does not exist.
    """
    if not os.path.exists(data_path):
        os.makedirs(data_path)
        print ("root create successfully!")
    else:
        print ("root already exists!")

    if write_data:
        _write_datasets()
    else:
        _train_model()

def main():
    """Script entry point: build the datasets or train, as chosen by generate_tfRecord."""
    return generate_tfRecord()


if __name__ == "__main__":
    main()

1. 当 write_data == True 时，将本地图片及其标签制作成 TFRecord 数据集。

在这里插入图片描述

2. 当 write_data == False 时，读取本地 TFRecord 数据集并进行训练。

在这里插入图片描述

数据集共 10000 张 28×28 的图片：image 为二进制图片数据，label 为 0–9 的数字标签。
在这里插入图片描述

#coding=utf-8
import tensorflow as tf
mnist=tf.keras.datasets.mnist
import matplotlib.pyplot as plt
import matplotlib as m
import numpy as np
import cv2
import os
# Hide TensorFlow's C++ INFO and WARNING messages ('2' keeps only errors).
# NOTE(review): tensorflow is imported above, so messages emitted during the
# import itself are not filtered; set this before `import tensorflow` to
# silence those too.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. This applies to every visible GPU, so machines with zero or several
# GPUs no longer fail an assertion.
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu_device in gpu:
    tf.config.experimental.set_memory_growth(gpu_device, True)


# Work from this script's folder so the relative model/image paths resolve.
root = os.path.abspath(os.path.dirname(__file__))
os.chdir(root)
#加载模型

def digit_predict(model_path='leNet_model.h5', image_path='train_pic/9_11.bmp'):
    """Predict the digit shown in one image using the trained LeNet model.

    The image is converted to grayscale, resized to 32x32 and scaled to [0, 1].
    NOTE(review): the training pipeline must apply the same 1/255 scaling for
    the predictions to be meaningful.

    Args:
        model_path: Keras HDF5 model file. The default is the file the
            training script saves.
        image_path: image to classify.

    Returns:
        The predicted digit as an int (it is also printed).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    # Read the image first so a bad path fails before the slow model load.
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError('cannot read image: %s' % image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (32, 32))
    img = img / 255

    model = tf.keras.models.load_model(model_path)
    pred = model.predict(img.reshape(1, 32, 32, 1))
    digit = int(pred.argmax())
    print('prediction Number: ', digit)
    return digit


# Run the prediction only when executed as a script, not on import.
if __name__ == "__main__":
    digit_predict()
     
     
     

测试图片:
在这里插入图片描述
predict result（预测结果）：
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佐倉

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值