Data preparation for fastText

# -*- coding: utf-8 -*-
"""
Created on Tue Jul 10 13:17:30 2018

@author: zltao2
"""

import os
import io
import re

import numpy as np
def context_clean(context):
    # Collapse each document onto a single line: drop newlines and spaces.
    context = context.replace('\n', '')
    context = context.replace(' ', '')
    return context

def clean_str(string):  # regex-based cleaning
    """Tokenization/string cleaning for the dataset.
    Strips escapes, quotes, OCR artifacts and punctuation, then lower-cases."""
    string = re.sub(r"\\", "", string)
    string = re.sub(r"\'", "", string)
    string = re.sub(r"\"", "", string)
    string = re.sub(r"\^", "", string)
    string = re.sub(r"■", "", string)  # common OCR artifact character
    string = re.sub(r"\t", "", string)
    # Runs of ASCII punctuation/whitespace, plus full-width punctuation.
    string = re.sub(u"[\s+\.\!\/_,$%^*+\"\']+|[+——~@#¥%……&*()]+", "", string)
    return string.strip().lower()
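
# A quick sanity check of the two cleaners above (the sample values are
# illustrative, not from the original dataset):
#   context_clean('某 判决\n书')  -> '某判决书'
#   clean_str('Hello, World!')    -> 'helloworld'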
def get_trainfiles(file_dir):
    # Lists holding the file paths and labels for each class
    # (first page / middle page / last page / single-page document).
    first = []
    label_first = []
    middle = []
    label_middle = []
    end1 = []
    label_end1 = []
    only = []
    label_only = []
    skipped = 0
    count = 0
    for file in os.listdir(file_dir):
        for file1 in os.listdir(os.path.join(file_dir, file)):
            if '判决书' not in file1:  # only process judgment documents
                continue
            for file2 in os.listdir(os.path.join(file_dir, file, file1)):
                for file3 in os.listdir(os.path.join(file_dir, file, file1, file2)):
                    path = os.path.join(file_dir, file, file1, file2, file3)
                    if file3.endswith('C.txt'):       # single-page document
                        only.append(path)
                        label_only.append(0)
                        count += 1
                    elif file3.endswith('CB.txt'):    # last page
                        end1.append(path)
                        label_end1.append(3)
                        count += 1
                    elif file3.endswith('CA.txt'):    # first page
                        first.append(path)
                        label_first.append(1)
                        count += 1
                    elif file3.endswith('.txt'):      # everything else: middle page
                        middle.append(path)
                        label_middle.append(2)
                        count += 1
                    else:
                        skipped += 1

    print('There are %d first\nThere are %d middle\nThere are %d end1\nThere are %d only'
          % (len(first), len(middle), len(end1), len(only)))

    # Concatenate all classes, then shuffle paths and labels together.
    image_list = np.hstack((first, middle, end1, only))
    label_list = np.hstack((label_first, label_middle, label_end1, label_only))

    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]
    return image_list, label_list
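
# get_trainfiles assumes a four-level directory tree of OCR'd judgment
# documents. The sketch below is inferred from the loops above; folder
# names are illustrative, only the '判决书' folder level and the file-name
# suffixes matter:
#
#   file_dir/
#       <case>/
#           <...判决书...>/
#               <document>/
#                   xxC.txt   -> label 0, single-page document
#                   xxCA.txt  -> label 1, first page
#                   xx.txt    -> label 2, middle page
#                   xxCB.txt  -> label 3, last page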

def generate_train_val_list_txt(image_list, label_list, train_scale):
    # Split the shuffled lists into train/test according to train_scale
    # (with train_scale=1 the test lists are simply empty).
    num_train_each_class = int(len(image_list) * train_scale)
    train_dir = image_list[:num_train_each_class]
    train_label = label_list[:num_train_each_class]
    test_dir = image_list[num_train_each_class:]
    test_label = label_list[num_train_each_class:]
    return train_dir, train_label, test_dir, test_label
        
def generate_train_test_file(train_dir):
    # Read every file in train_dir and collect one cleaned string per document.
    count = 0
    context1 = []
    for i in range(len(train_dir)):
        txts_list_txt = train_dir[i]
        with io.open(txts_list_txt, 'r', encoding='UTF-8') as f:
            context = f.read()
            context = context_clean(context)
            context = clean_str(context)
            context1.append(context)
            count = count + 1
    print('====total====>>>>>>>>> %d files processed' % (count))
    return context1

file_dir = '/home/4T/2018_runpu_data/test_data/0409117ronghe'

txt_file = "/home/wxliang/润普资料整理_20180411/code/20180411_lyyu_clssification/文本类别分类/segementation_ocr_v2_20180611_linux/data/train22.txt"
txt_file1 = "/home/wxliang/润普资料整理_20180411/code/20180411_lyyu_clssification/文本类别分类/segementation_ocr_v2_20180611_linux/data/test.txt"
def save_classes_index(texts, txt_file, labels):
    # Write one fastText-style training line per document:
    # '__label__<label> , <text>'. fastText needs whitespace after the
    # label token, and Chinese text should additionally be word-segmented
    # before training.
    with io.open(txt_file, 'w', encoding='UTF-8') as f:
        for i, text in enumerate(texts):
            f.write('__label__%s , %s\n' % (labels[i], text))
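
# Each line written by save_classes_index then has the shape
# (label and text are illustrative):
#   __label__1 , 北京市第一中级人民法院民事判决书……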
            


train_scale = 1
image_list, label_list = get_trainfiles(file_dir)
train_dir, train_label, test_dir, test_label = generate_train_val_list_txt(image_list, label_list, train_scale)

context1 = generate_train_test_file(train_dir)

save_classes_index(context1, txt_file, train_label)

# With train_scale < 1, the held-out split can be written the same way:
#context11 = generate_train_test_file(test_dir)
#save_classes_index(context11, txt_file1, test_label)
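
Once train22.txt has been written it can be fed straight to fastText. The snippet below is a minimal sketch using the fasttext pip package; the hyper-parameters and the jieba segmentation step are assumptions, not part of the original pipeline (fastText splits tokens on whitespace, so unsegmented Chinese text would otherwise collapse into one token per line):

import fasttext  # assumption: the `fasttext` pip package is installed
import jieba     # assumption: used only to illustrate word segmentation

def segment(text):
    # fastText tokenizes on whitespace, so put spaces between Chinese words.
    return ' '.join(jieba.cut(text))

model = fasttext.train_supervised(
    input='data/train22.txt',  # file produced by save_classes_index above
    label='__label__',         # label prefix used in save_classes_index
    epoch=25,                  # illustrative hyper-parameters
    wordNgrams=2)
print(model.predict(segment('经审理查明……')))  # -> (labels, probabilities)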

 
