python基于sklearn使用MLP手写数字识别模型训练及应用流程

参考书籍:掌控python.人工智能之机器视觉 / 程晨编著
图片处理相关工具参考了csdn大佬的相关文章,如有侵权,请联系删除。

1. 使用自制数据集进行复现及模型训练

# train_model.py

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
import joblib
from produce_train_data import *

digits = load_digits()
# 以下为自制数据集加入,使用mnist数据集时注释掉下面四行
path = './train_img/train/'
data,target = image_datasets(path)
digits.data = data
digits.target = target
x_train, x_test, y_train, y_test = train_test_split(digits['data'],digits['target'],test_size=0.3,random_state=0)
mplc = MLPClassifier(max_iter=1000)
mplc.fit(x_train,y_train)
pred = mplc.predict(x_test)
print(pred==y_test)
joblib.dump(mplc,'./model/mplc.pkl')

2.模型应用

# infer.py

def infer(model_path,img):
    classfier = joblib.load(model_path)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
    # (1,8930)根据实际训练图片大小进行修改,将图片转换为一维数组
    img = numpy.reshape(img,(1,8930))
    x = str(classfier.predict(img)[0])
    print(x)
    return x

3.自制数据集

# produce_train_data.py

# -- coding: utf-8 --
from sklearn import datasets
import  numpy as np
from PIL import Image
import os
import csv
import skimage

#读取文件夹中的图像信息,生成列表
def generate_dataset(path):
    filelist = os.listdir(path)
    csvfile = open("imgfile.txt", 'w')
    for files in filelist:
        filename = os.path.splitext(files)[0]
        str1 = path + files + ' ' + filename[0] + '\n'
        csvfile.writelines(str1)
    csvfile.close()
    return csvfile.name

#生成sklearn训练数据集
def load_imgesets(filename):
    file = open(filename,'r')
    data = []
    target = []
    data = np.array(data,dtype=float)
    flag = 1
    for line in file:
        # 分割图像路径与类别
        str = line.split(' ',1)
        #读取图片转换为灰度图
        #将灰度矩阵转换为一维数据
        if flag == 1:
            flag = 0
            data = np.array(Image.open(str[0]).convert('L')).reshape(1,-1)
        else:
            row = np.array(Image.open(str[0]).convert('L')).reshape(1, -1)
            data = np.row_stack((data, row))
        target.append(str[1])
    file.close()
    target =np.asarray(target,dtype=int)
    return data,target

def image_datasets(path):
    filename = generate_dataset(path)
    data,target = load_imgesets(filename)
    return data,target


if __name__ == '__main__':
    path = './img_new1/'
    data,target = image_datasets(path)

4.图片阈值寻找工具

from __future__ import division
import cv2
import numpy as np
def nothing(*arg):
        pass
icol = (0, 0, 0, 255, 255, 255)
cv2.namedWindow('colorTest')
#阈值低点
cv2.createTrackbar('lowHue', 'colorTest', icol[0], 255, nothing)
cv2.createTrackbar('lowSat', 'colorTest', icol[1], 255, nothing)
cv2.createTrackbar('lowVal', 'colorTest', icol[2], 255, nothing)
#阈值高点
cv2.createTrackbar('highHue', 'colorTest', icol[3], 255, nothing)
cv2.createTrackbar('highSat', 'colorTest', icol[4], 255, nothing)
cv2.createTrackbar('highVal', 'colorTest', icol[5], 255, nothing)
#读取图片,图片命名test1.jpg
frame = cv2.imread('test1.jpg')
while True:
    lowHue = cv2.getTrackbarPos('lowHue', 'colorTest')
    lowSat = cv2.getTrackbarPos('lowSat', 'colorTest')
    lowVal = cv2.getTrackbarPos('lowVal', 'colorTest')
    highHue = cv2.getTrackbarPos('highHue', 'colorTest')
    highSat = cv2.getTrackbarPos('highSat', 'colorTest')
    highVal = cv2.getTrackbarPos('highVal', 'colorTest')
    #色域转换
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    #设定阈值
    colorLow = np.array([lowHue,lowSat,lowVal])
    colorHigh = np.array([highHue,highSat,highVal])
    mask = cv2.inRange(hsv, colorLow, colorHigh)
    result = cv2.bitwise_and(frame, frame, mask = mask)
    #图片拼接
    imgs = np.hstack([frame,result])
    cv2.imshow('colorTest', imgs)
    k = cv2.waitKey(5) & 0xFF
    if k == 27:
        break
cv2.destroyAllWindows()

5.图片变换工具

# 红色字体时使用,红色转变为白色
def red2white(img):
    for x in range(img.shape[0]):   # 图片的高
        for y in range(img.shape[1]):   # 图片的宽
            px = img[x,y]
            #print(px)    # 这样就能得到每个点的bgr值
            if img[x,y,0] < 10 and img[x,y,1] < 189 and img[x,y,2] < 209 :
                img[x,y,0] = img[x,y,1] = img[x,y,2] = 0
            else: 
                img[x,y,0] = img[x,y,1] = img[x,y,2] = 255
    return img

# 将图片背景转换为黑色
def background2black(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    #字体为红色时使用
    #lower_color = np.array([0,115, 181])
    #字体为白色时使用
    lower_color = np.array([0,0, 230])
    upper_color = np.array([255, 255, 255])
    mask = cv2.inRange(hsv, lower_color, upper_color)
    res = cv2.bitwise_and(img, img, mask=mask)
    return res
  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值