paddlepaddle使用笔记——使用自己的数据训练ocr模型

1、使用环境:

ubuntu18.04,4gpu,nvidia410.78,cuda9.0,cudnn7.3,python3.6

2、使用代码:

官方提供的ocr模型代码

https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/ocr_recognition

3、将代码运行起来

为了方便看到运行的效果,我修改了参数,save_model_period,这样可以更快的保存数据,好知道运行是否有效

4、生成自己的数据

import random
import cv2
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import os
from unit import segmentation

path_font='/home/zz/文字'
path_out='/media/zz/testtttt'
if not os.path.exists(path_out):
    os.mkdir(path_out)

dicters='0123456789.'
CHARS='0123456789'
number=100
font_index=0
font_list=os.listdir(path_font)
font_list.sort()
f_ind=0
f_size=0

def get_word(length):
    global font_index, CHARS
    f = ''
    for i in range(length):
        f = f + random.choice(CHARS)
        font_index = font_index + 1
    return f

def get_txt():
    txt=''
    f3=random.randint(0,2) # 2/3的可能会出现.
    len_num=random.randint(1,8)
    num=get_word(len_num)
    txt=txt+num
    if f3>0 and len(num)>=3:
        txt=txt[:-2]+'.'+txt[-2:]
    return txt


def get_txt_test():
    global dicters
    return dicters

def get_bg(color, w,h):
    # w_l=random.randint(1,500)
    # w_r=random.randint(1,500)
    # h_t=random.randint(1,50)
    # h_b=random.randint(1,50)
    w_l=50
    w_r=50
    h_t=20
    h_b=20
    bg=np.zeros((h+h_b+h_t,w+w_l+w_r),dtype='uint8')
    bg=bg+color
    # bg=Image.fromarray(bg)
    return bg,w_l,h_t

def oblique(bg,fc):
    bg=np.array(bg)
    imgh,imgw=bg.shape
    new_bg=np.zeros((imgh,imgw+2*imgh),dtype='uint8')
    new_bg[:,imgh:imgh+imgw]=bg
    step=random.randint(15,25)
    st=random.randint(0,step)

    while st+imgh<=imgw+2*imgh:
        pt1=(st,0)
        pt2=(st+imgh,imgh)
        new_bg=cv2.line(new_bg, pt1, pt2, fc, 1, 4)
        st=st+step
        # cv2.imshow('a',new_bg)
        # cv2.waitKey()
    bg=new_bg[:,imgh:imgh+imgw]
    # cv2.imshow('a',new_bg)
    # cv2.imshow('b',bg)
    # cv2.waitKey()
    bg=Image.fromarray(bg)
    return bg




def interfere(bg,x,y,w,h,bc,fc):
    global f_size
    if f_size>=50 and not random.randint(0,3):
        bg=oblique(bg,fc)
    return bg


def gen_data(co):
    global f_ind,f_size

    # 确定颜色
    f_color=random.randint(0,255)
    bg_color=f_color
    while abs(bg_color-f_color)<30:
        bg_color=random.randint(0,255)

    # 字体
    f_ind = f_ind % len(font_list)
    f_size=random.randint(15,75)

    txt = get_txt()  # 文字内容
    font_text = ImageFont.truetype('{}/{}'.format(path_font, font_list[f_ind]), f_size)
    print('{}---{}'.format(f_ind,font_list[f_ind]))

    background_bg,x1,y1 = get_bg(bg_color, font_text.getsize(txt)[0], font_text.getsize(txt)[1])
    background_bg = Image.fromarray(background_bg, mode="L")
    draw_txt = ImageDraw.Draw(background_bg)  # 确认输出文字的背景图片
    draw_txt.text((x1, y1), txt, fill=(f_color), font=font_text)

    background_bg=interfere(background_bg,x1,y1,font_text.getsize(txt)[0], font_text.getsize(txt)[1],bg_color,f_color)

    txt = txt.replace('.', '+')
    background_bg.save('{}/{:08d}_{}.jpg'.format(path_out, co, txt))
    f_ind = f_ind + 1

    return


if __name__ == '__main__':


    for i in range(number):
        print(i)
        gen_data(i)

生成灰度图片,规则:

1、背景和文字的颜色差大于30

2、字的个数在1-8个

3、左右上下有一个随机的扩大范围

4、如果出现小数点,保留两位小数

5、干扰,斜线

 

5、预处理

包括,包括裁剪,二值化,统计max size,这里由于之后需要resize,所以统计的是max ration,就是w/h的最大值

二值化使用opencv提供的otsu方法

在处理过程中缩减周围可缩小范围


import numpy as np
import cv2
import os
import time
from unit import segmentation

t1 = 0
t2 = 0
name=''



def linkseg(segx,dis,img):
    nst=0
    nen=0
    nseg=[]

    h=min(img.shape[0]//10,5)
    for i in range(len(segx)):
        st,en=segx[i]
        img_p=img[:,st:en]
        y=np.sum(img_p,axis=1)
        y=y-np.average(y)
        pt=len(np.where(y!=0)[0])
        if pt<h:
            segx[i]=segx[i+1]
        else:
            break
    for i in range(len(segx)):
        ind=len(segx)-1-i
        st,en=segx[ind]
        img_p=img[:,st:en]
        y=np.sum(img_p,axis=1)
        y=y-np.average(y)
        pt=len(np.where(y!=0)[0])
        if pt<h:
            segx[ind]=segx[ind-1]
        else:
            break
    nseg.append([segx[0][0],segx[-1][1]])
    return nseg

def binar(img):
    global t1,t2,name
    count=0

    # cv2imshow('origin', img)
    # cv2waitKey()

    if len(img.shape)==3:
        # 灰度
        img=cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)

    # cv2.imwrite('{}/{}.jpg'.format(path_out,name),img)
    # cv2.imshow('binar', img)
    # cv2.waitKey()


    # 二值化

    ret3, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # 确保黑字0
    a=np.sum(img)/(img.shape[0]*img.shape[1]*255)
    if a>0.5:
        img=255-img

    # cv2.imshow('binar', img)
    # cv2.waitKey()

    # 裁剪文字,将之裁剪成一个一个字
    h,w=img.shape
    y=np.sum(img,axis=1)
    y=y-min(y)
    segy=segmentation(y,0)

    # 根据横向的空白,将裁剪下来的数据分成多条,认为每一条的纵向上面没有干扰
    # 生成数据,上下没有干扰,所以删除裁剪不正确的数据
    if len(segy)>1:
        cv2.imwrite('{}/{}'.format(path_err,filename),img)
        return

    for st,en in segy:
        img_p=img[st:en,:]
        x=np.sum(img_p,axis=0)
        x=x-min(x)
        segx=segmentation(x,0)
        # link seg
        segx=linkseg(segx,5,img_p)


        # 根据纵向的空白,将裁剪下来的数据分成多块,认为每一块上一个数字
        for st,en in segx:

            img_n=img_p[:,st:en]
            if count>1:
                cv2.imshow('a',img_n)
                cv2.waitKey()
            img_n = 255 - img_n
            cv2.imwrite('{}/{}'.format(path_out, filename), img_n)
            count = count + 1
            # cv2.imshow('c1', img_n)
            # cv2.waitKey()

            


if __name__ == '__main__':

    # img=cv2.imread('/home/zz/图片/桌面/test/lALPDgQ9qxS5d1jNAwDNBVY_1366_768(第 5 个复件).png')
    # binar(img)

    path='/media/zz/AE1AD9D91AD99F21/book-zz/digit'
    path_out='/media/zz/AE1AD9D91AD99F21/book-zz/digit-cut'
    path_err='/media/zz/AE1AD9D91AD99F21/book-zz/digit-err'
    if not os.path.exists(path_out):
        os.mkdir(path_out)
    if not os.path.exists(path_err):
        os.mkdir(path_err)
    img_list=os.listdir(path)
    img_list.sort()
    for i in range(544,len(img_list)):
        filename=img_list[i]
        name = '{:03d}'.format(i)
        print(filename)
        img=cv2.imread('{}/{}'.format(path,filename))
        binar(img)



    print(t1)
    print(t2)


 

 

5、保存label,并resize图片,分别保存到train和test

import os
import cv2
import re

# 文件名称
path_f='/media/cj1/data/digit_pic_lab'
dir_img='digit-cut'
dir_train='train_images'
list_train='train_list'
dir_test='test_images'
list_test='test_list'

# 计数器
count=0

# 字典
dicters='0123456789.-¥'

# 生成数据集
img_list=os.listdir('{}/{}'.format(path_f,dir_img))
img_list.sort()

def get_lab_num(lab):
    s=[]
    for l in lab:
        if l=='+':
            ind=dicters.index('.')
        else:
            ind=dicters.index(l)
        s.append(str(ind))
    return s

fs_train=open('{}/{}'.format(path_f,list_train),'w')
fs_test=open('{}/{}'.format(path_f,list_test),'w')

for filename in img_list:
    print(filename)
    lab=filename.split('.')[0].split('_')[-1] # 取出标签内容
    try:
        num_list=get_lab_num(lab) # 转为数字标签
    except:
        # 转化错误的话,就直接下一个
        continue
    try:
        img=cv2.imread('{}/{}/{}'.format(path_f,dir_img,filename))
        h,w,c=img.shape
        nw=int(w*IMG_H/h)
        img=cv2.resize(img,(nw,IMG_H))
        img=cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
        img_n[:,:nw]=img
    except:
        # 读取错误,就直接下一个
        continue
    name=re.sub('\D','',lab) # 取一个只有数字的名字
    newfilename='{}_{}.jpg'.format(filename.split('_')[0],name)
    if count%100==0:
        # 每100张训练,保存一张测试
        if not os.path.exists('{}/{}'.format(path_f,dir_test)):
            os.mkdir('{}/{}'.format(path_f,dir_test))
        fs_test.write('{} {} {} {}\n'.format(w,h,newfilename,','.join(num_list)))
        cv2.imwrite('{}/{}/{}'.format(path_f,dir_test,newfilename),img)
    else:
        if not os.path.exists('{}/{}'.format(path_f,dir_train)):
            os.mkdir('{}/{}'.format(path_f,dir_train))
        fs_train.write('{} {} {} {}\n'.format(w, h, newfilename, ','.join(num_list)))
        cv2.imwrite('{}/{}/{}'.format(path_f, dir_train, newfilename),img)
    count=count+1

在同一个目录下生成如上四个文件,然后文件夹中保存的是图片,test_list保存的是标签

6、修改模型代码

因为我们把数据格式根模型读取的数据格式生成的一样,所以大部分不用修改,只需要data_reader里面一些内容就可以了

1、分类数和图片大小,根据自己实际需要修改

我的这里:

NUM_CLASSES =10
DATA_SHAPE = [1, 32,300]

2、文件读取路径

这个直接写成自己的

data_dir='zz/data'

 

7、错误汇总

--------------------------------------

遇到了一个问题,模型是上面的我的数据的新模型,数据是从生成数据中截留的一点测试数据

在官方的infer中执行结果:

但是我train的时候准确率是高达0.99的,所以我用train修改了一个可以输出测试结果的代码

发现,问题出现在这一步indexes = prune(np.array(result[0]).flatten(), 0, 1)

由于我的输出当中全部都是数字,所以生成的结果通过0和1缩短一下以后就不成数据了。

---------------------------------------

Enforce failed. Expected x_mat_dims[1] == y_mat_dims[0], but received x_mat_dims[1]:768 != y_mat_dims[0]:512.
First matrix's width must be equal with second matrix's height. 768, 512 at [/paddle/paddle/fluid/operators/mul_op.cc:61]

错误原因:SHAPE的大小不对

我在的图片是32*300的,但是shape的大小设置成了48*500,然后就会报这个错误

 

 

 

 

 

 

 

 

 

  • 5
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
### 回答1: wrf-chem数据下载的相关网址链接: 1. NCEP/NCAR Reanalysis I: ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.dailyavgs/surface/ 2. Chemical Transport Model (CTM) data from the GEOS-Chem group: https://acmg.seas.harvard.edu/geos/ 3. Emissions data from the Emissions Database for Global Atmospheric Research (EDGAR): https://edgar.jrc.ec.europa.eu/ 4. The Community Multi-scale Air Quality (CMAQ) modeling system data: https://www.epa.gov/air-research/community-multiscale-air-quality-cmaq-modeling-system 请注意,不同的数据来源可能需要不同的许可证才能访问,请确保您具有访问所需数据的合法资格。 ### 回答2: WRF-Chem是一种大气化学模型,它用于模拟大气中化学物种的输运和转化过程。在建立WRF-Chem模型之前,我们需要收集和处理一些数据,以确保模型的准确性和可靠性。这些数据包括地理信息、排放数据、气象数据和化学初始和边界条件等。 首先,地理信息数据是建立WRF-Chem模型的基础。这些数据包括经纬度、高程和土地覆盖类型等信息,可以用于生成地形和表面辐射强度图。我们可以在https://www.ngdc.noaa.gov/上下载世界各地的地理数据。 其次,排放数据是描述大气中污染物来源和排放速率的关键数据。这些数据包括人工排放和自然排放两种来源。人工排放包括工业、交通和农业等活动产生的污染物,自然排放包括植被的插值和火山喷发等自然事件。各个国家和地区的排放数据可在Emission Database for Global Atmospheric Research (EDGAR) (https://www.sciencedirect.com/science/article/pii/S1352231009003904 )上下载。 第三,气象数据是WRF-Chem模型的必需数据。气象数据包括气温、风速、风向和湿度等逐小时或逐分钟的数据。我们可以在National Centers for Environmental Prediction (NCEP) (https://www.ncdc.noaa.gov/data-access/model-data/model-datasets)或European Center for Medium-Range Weather Forecasts (ECMWF) (https://www.ecmwf.int/en/forecasts/datasets)上下载气象数据。 最后,化学初始和边界条件数据是指大气中化学物种的浓度和化学反应速率等信息。这些数据通常由现场观测或其他化学模型得出,可以在全球化学输送模型 (GEOS-Chem) (http://acmg.seas.harvard.edu/geos/)上获取。 总之,WRF-Chem模型的建立需要以上四个基本数据。这些数据可以在相关数据下载网址上获取。但是,这些数据的质量和格式都需要我们认真审查和处理,以确保WRF-Chem模型的准确性和可靠性。 ### 回答3: wrf-chem是一种用于模拟大气物质输运和化学反应的数值模型。在进行wrf-chem模拟时,需要使用许多与气体和颗粒物浓度、化学反应等相关的数据。这些数据可以通过官方网站和其他一些数据平台进行下载。 其中,官方网站是wrf-chem模型最全面的数据源,开发者提供了许多与模型运行相关的数据和工具。这些数据包括了不同时间尺度上的气象模型、气体和颗粒物浓度模型、化学反应模型、辐射强度模型等。此外,网站中还提供了许多工具,例如反求模块、统计模块等,可以用于模型调试和后处理。下载方式为直接点击网站上的下载链接,选择相应的数据和工具即可。 另外,还有一些数据平台也可以提供相关数据的下载,例如NCAR Data Portal、Earth System Grid、国家气象信息中心等。这些平台通常提供了一些免费的数据下载服务,但需要用户进行注册和认证。同时,有些数据需要进行特定的格式转换,才能够被wrf-chem模型使用。 总体来说,wrf-chem模型所需的数据比较丰富,但是通过官方网站和其他数据平台的配合,用户可以方便地获取这些数据,并进行相应的分析和后处理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值