# -*- coding: utf-8 -*-
"""
Created on Tue Jul 10 13:17:30 2018
@author: zltao2
"""
import argparse
import os
import io
import sys
import time
#sys.path.insert(0,'D:\\graphvis2.38\\bin')
import re
from glob import glob
import random
import numpy as np
from shutil import copy
def context_clean(context):
    """Remove every newline and ASCII space character from *context*.

    Parameters
    ----------
    context : str
        Raw text read from a document file.

    Returns
    -------
    str
        The text with all '\\n' and ' ' characters deleted.
    """
    # BUG FIX: the original called context.strip('\n') and discarded the
    # result (str.strip returns a new string; it does not mutate).  The
    # call was also redundant, because replace('\n', '') below removes
    # every newline anyway, so it is simply dropped here.
    return context.replace('\n', '').replace(' ', '')
def clean_str(string):  # switched to regular expressions
    """Tokenization/string cleaning for the dataset.

    Strips backslashes, quote characters, carets, the OCR artifact '■',
    tabs, and a set of ASCII/full-width punctuation, then lower-cases
    the result.

    Parameters
    ----------
    string : str
        Raw (already newline/space-stripped) document text.

    Returns
    -------
    str
        Cleaned, stripped, lower-cased text.
    """
    string = re.sub(r"\\", "", string)
    string = re.sub(r"\'", "", string)
    string = re.sub(r"\"", "", string)
    # BUG FIX: the original used r"^" — an unescaped start-of-string
    # anchor, which substitutes the empty match and is a no-op.  The
    # intent (consistent with the literal-character removals around it)
    # was a literal caret, so it is escaped here.  Observable behavior
    # is unchanged because the character class below also removes '^'.
    string = re.sub(r"\^", "", string)
    string = re.sub(r"■", "", string)
    string = re.sub(r"\t", "", string)
    # Remove whitespace plus common ASCII and full-width punctuation.
    string = re.sub(u"[\s+\.\!\/_,$%^*+\"\']+|[+——~@#¥%……&*()]+", "", string)
    return string.strip().lower()
def get_trainfiles(file_dir):
    """Collect labelled .txt paths from a four-level directory tree.

    Walks ``file_dir/<case>/<folder containing '判决书'>/<sub>/<file>`` and
    assigns a label from the file name (substring tests, checked in this
    order):

    * contains ``'C.txt'``  -> label 0 ("only" page)
    * contains ``'CB.txt'`` -> label 3 ("end" page)
    * contains ``'CA.txt'`` -> label 1 ("first" page)
    * any other ``.txt``    -> label 2 ("middle" page)

    Parameters
    ----------
    file_dir : str
        Root directory of the dataset.

    Returns
    -------
    (list of str, list of int)
        Shuffled, aligned lists of file paths and integer labels.
    """
    # Lists holding each category's paths and the matching labels.
    first, label_first = [], []
    middle, label_middle = [], []
    end1, label_end1 = [], []
    only, label_only = [], []
    skipped = 0  # files that matched no category (was the `j` counter)
    count = 0    # total files collected (was `cout`)
    for level1 in os.listdir(file_dir):
        for level2 in os.listdir(os.path.join(file_dir, level1)):
            # Only descend into "判决书" (judgment-document) folders.
            if '判决书' not in level2:
                continue
            base = os.path.join(file_dir, level1, level2)
            for level3 in os.listdir(base):
                for fname in os.listdir(os.path.join(base, level3)):
                    # os.path.join replaces the original hard-coded '/'
                    # concatenation (identical on POSIX, portable on
                    # Windows).  `path` replaces the original local
                    # `dir`, which shadowed the builtin.
                    path = os.path.join(base, level3, fname)
                    # Branch order matters: 'C.txt' is tested first and
                    # is NOT a substring of 'CA.txt'/'CB.txt', so the
                    # categories are disjoint.
                    if 'C.txt' in fname:
                        only.append(path)
                        label_only.append(0)
                        count += 1
                    elif 'CB.txt' in fname:
                        end1.append(path)
                        label_end1.append(3)
                        count += 1
                    elif 'CA.txt' in fname:
                        first.append(path)
                        label_first.append(1)
                        count += 1
                    elif '.txt' in fname:
                        # The original re-tested the three branches above
                        # here; they are guaranteed false at this point,
                        # so only the '.txt' membership check remains.
                        middle.append(path)
                        label_middle.append(2)
                        count += 1
                    else:
                        skipped += 1
    # Concatenate the per-class lists, then shuffle paths and labels
    # together so they stay aligned.
    image_list = np.hstack((first, middle, end1, only))
    label_list = np.hstack((label_first, label_middle, label_end1, label_only))
    temp = np.array([image_list, label_list]).transpose()
    np.random.shuffle(temp)
    image_list = list(temp[:, 0])
    # The string/int mix was promoted to strings by np.array; coerce the
    # labels back to int.
    label_list = [int(i) for i in temp[:, 1]]
    return image_list, label_list
def generate_train_val_list_txt(image_list, label_list, train_scale):
    """Take the leading *train_scale* fraction of the data as the train split.

    Parameters
    ----------
    image_list : list
        File paths (already shuffled by the caller).
    label_list : list
        Labels aligned with *image_list*.
    train_scale : float
        Fraction in [0, 1] of samples to keep for training.

    Returns
    -------
    (train_dir, train_label, test_dir, test_label)
        The test lists are intentionally always empty — the remainder of
        the data is simply discarded (the split was disabled upstream).
    """
    cutoff = int(len(image_list) * train_scale)
    return image_list[:cutoff], label_list[:cutoff], [], []
def generate_train_test_file(train_dir):
    """Read each file in *train_dir*, clean its text, and return the texts.

    Each file is read as UTF-8, passed through ``context_clean`` (drops
    newlines/spaces) and ``clean_str`` (drops punctuation, lower-cases),
    and the resulting strings are returned in input order.

    Parameters
    ----------
    train_dir : list of str
        Paths of the .txt files to process.

    Returns
    -------
    list of str
        One cleaned string per input file.
    """
    cleaned = []
    for path in train_dir:
        with io.open(path, 'r', encoding='UTF-8') as handle:
            raw_text = handle.read()
        cleaned.append(clean_str(context_clean(raw_text)))
    # Progress report (kept verbatim from the original script).
    print('====共有文件数===>>>>>>>>>处理图片数%d'%(len(cleaned)))
    return cleaned
# --- Script configuration (hard-coded, machine-specific paths) ---
#from scipy import misc
# Root directory of the raw dataset to scan.
# NOTE(review): absolute path to a specific machine — confirm before
# running anywhere else.
file_dir='/home/4T/2018_runpu_data/test_data/0409117ronghe'
# NOTE(review): this list appears unused below — save_classes_index is
# called with the cleaned document texts, not with these names; verify
# whether it can be removed.
classes=['sdasd','我会骄傲是','asdasdasd']
# Output file for the fastText-style training lines ("__label__N,text").
txt_file="/home/wxliang/润普资料整理_20180411/code/20180411_lyyu_clssification/文本类别分类/segementation_ocr_v2_20180611_linux/data/train22.txt"
# Output file for the test split (currently disabled at the bottom of
# the script).
txt_file1="/home/wxliang/润普资料整理_20180411/code/20180411_lyyu_clssification/文本类别分类/segementation_ocr_v2_20180611_linux/data/test.txt"
def save_classes_index(classes, txt_file, train_label1):
    """Write one fastText-style line per text: ``__label__<label>,<text>``.

    Parameters
    ----------
    classes : list of str
        The texts to write (one per line).
    txt_file : str
        Destination path; the file is truncated and written as UTF-8.
    train_label1 : list
        Labels aligned with *classes*; indexed positionally, so a
        shorter list raises IndexError mid-write (as in the original).
    """
    with io.open(txt_file, 'w', encoding='UTF-8') as out:
        position = 0
        for text in classes:
            out.write('__label__%s,%s\n' % (train_label1[position], text))
            position += 1
# --- Script entry point ---
# NOTE(review): runs at import time; there is no `if __name__ ==
# "__main__"` guard, so importing this module executes the whole
# pipeline against the hard-coded paths above.
train_scale=1  # use 100% of the data for training; the test split stays empty
# Scan the dataset tree and get shuffled (path, label) lists.
image_list,label_list=get_trainfiles(file_dir)
# With train_scale=1 this keeps everything; test_dir/test_label are [].
train_dir, train_label, test_dir, test_label=generate_train_val_list_txt(image_list, label_list, train_scale)
# Read and clean every training document.
context1=generate_train_test_file(train_dir)
# Emit the fastText-style training file.
save_classes_index(context1, txt_file,train_label)
# Test-split output is disabled (test_dir is always empty anyway).
#context11=generate_train_test_file(test_dir)
#
#save_classes_index(context11, txt_file1,test_label)