代码从 https://tensorflow.google.cn/tutorials/keras/classification 搬移
改了一下输入的构造 (仅是一个 demo 输入构造简化了很多)
代码中使用 图像数据下载链接 https://download.csdn.net/download/zslngu/16165364
最终的结果
1. 输入准备
首先拿到 0~9 数字的图片,清理成下面这种格式
数字 0 示例 (文件暂时命名为 0.model)
01111110
01111111
11100111
11100011
11100011
11000011
11100011
11100011
11100111
01111110
01111110
因为是 demo 所以简单构造一下输入 将所有数字的图形信息都复制一下
def copy_data():
l = [str(i) for i in range(0, 10)]
for ll in l:
path = "./tf/" + str(ll) + ".model"
with open(path, "r") as f:
cur = f.read()
for j in range(0, 15):
# s_path ex: 0_1.model 生成15个一样的输入
s_path = "./tf/" + str(ll) + "_" + str(j) + ".model"
with open(s_path, "w") as wf:
wf.write(cur)
这样一共生成 10 * 15 个图像数据
2. 格式化输入
代码引用的库 & 环境配置
import os
model_path = "代码根目录"
os.chdir(model_path)
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
将原始内容可视化
trains = []
labels=[]
def get_train(path):
with open(path,