Python 多分类任务中按照类别进行分层采样（按 70% / 15% / 15% 的比例将每个类别划分为训练集、验证集和测试集）

# Split the labelled documents into train / val / test files, stratified by class.
import os


def _load_documents(ssdfile_dir):
    """Read every regular file in ``ssdfile_dir`` into a dict.

    Keys are file names with a trailing ``.txt`` removed; values are the file
    contents with all newlines removed (the case descriptions span several
    lines, and each document must end up on a single output line).
    """
    documents = {}
    for entry in os.listdir(ssdfile_dir):
        path = os.path.join(ssdfile_dir, entry)
        if not os.path.isfile(path):  # skip sub-directories etc.
            continue
        with open(path, 'r', encoding='utf-8') as fh:
            text = fh.read().replace('\n', '')
        key = entry[:-len('.txt')] if entry.endswith('.txt') else entry
        documents[key] = text
    return documents


def _iter_labeled_rows(filename, documents):
    """Yield ``(name, label)`` for each row of ``filename`` whose name is in ``documents``.

    Rows are comma-separated with the document name in column 0 and the label
    in column 1. Line endings are stripped first so that a label in the last
    column still compares equal to its category. A row with no label column
    yields an empty label.
    """
    with open(filename, 'r', encoding='utf-8') as fh:
        for raw in fh:
            name, _, rest = raw.rstrip('\r\n').partition(',')
            if name not in documents:
                continue
            yield name, rest.split(',')[0]


def _print_distribution(split_name, counts):
    """Print the per-class counts of one split and their percentage shares."""
    print(split_name + ":")
    print(counts)
    total = sum(counts)
    if total:
        distribution = [(c / total) * 100 for c in counts]
    else:
        # An empty split would otherwise raise ZeroDivisionError.
        distribution = [0.0 for _ in counts]
    print("分布:")
    print(distribution)


def save_file_stratified(filename, ssdfile_dir, categories,
                         output_dir='../data/usefuldata-711depart',
                         train_percent=70, val_percent=15):
    """Write a class-stratified train/val/test split of the documents.

    filename: path of the label table. Each line is ``name,label[,...]``,
        where ``name`` is a document file name (without ``.txt``) in
        ``ssdfile_dir``.
    ssdfile_dir: directory holding one text file per document.
    categories: list of class labels. Rows whose label is not listed are
        ignored; rows with an empty label are counted and reported.
    output_dir: directory that receives ``train.txt``, ``val.txt`` and
        ``test.txt`` (created if missing). Each output line is
        ``label\\tcontent``.
    train_percent / val_percent: share of each class that goes to train and
        to val; the remainder goes to test.

    The split is deterministic: within each class, the rows are taken in
    table order. The first ``count * train_percent / 100`` rows go to train,
    the rows up to ``count * (train_percent + val_percent) / 100`` go to val,
    and the rest go to test. Shuffle the table beforehand if a random split
    is wanted.

    Returns a tuple ``(train_counts, val_counts, test_counts)`` of per-class
    counts, indexed like ``categories``.
    """
    documents = _load_documents(ssdfile_dir)

    # First occurrence wins, mirroring an if/elif chain over categories.
    category_index = {}
    for i, category in enumerate(categories):
        category_index.setdefault(category, i)
    n_classes = len(categories)

    # Pass 1: how many documents each class has in total.
    doc_counts = [0] * n_classes
    for _, label in _iter_labeled_rows(filename, documents):
        idx = category_index.get(label)
        if idx is not None:
            doc_counts[idx] += 1

    # Cumulative cut-offs per class (integer percentages avoid float drift).
    train_cut = [c * train_percent / 100 for c in doc_counts]
    val_cut = [c * (train_percent + val_percent) / 100 for c in doc_counts]

    seen = [0] * n_classes
    train_counts = [0] * n_classes
    val_counts = [0] * n_classes
    test_counts = [0] * n_classes
    blank_tag = 0  # documents whose label is empty

    os.makedirs(output_dir, exist_ok=True)
    train_path = os.path.join(output_dir, 'train.txt')
    val_path = os.path.join(output_dir, 'val.txt')
    test_path = os.path.join(output_dir, 'test.txt')

    # Pass 2: route each document to a split according to its rank within its class.
    with open(train_path, 'w', encoding='utf-8') as f_train, \
            open(val_path, 'w', encoding='utf-8') as f_val, \
            open(test_path, 'w', encoding='utf-8') as f_test:
        for name, label in _iter_labeled_rows(filename, documents):
            if label == '':
                blank_tag += 1
                continue
            idx = category_index.get(label)
            if idx is None:
                continue
            seen[idx] += 1
            record = label + '\t' + documents[name] + '\n'
            if seen[idx] <= train_cut[idx]:
                f_train.write(record)
                train_counts[idx] += 1
            elif seen[idx] <= val_cut[idx]:
                f_val.write(record)
                val_counts[idx] += 1
            else:
                f_test.write(record)
                test_counts[idx] += 1

    print("有" + str(blank_tag) + "个文书的行业标记为空!")
    _print_distribution("train", train_counts)
    _print_distribution("val", val_counts)
    _print_distribution("test", test_counts)
    return train_counts, val_counts, test_counts


if __name__ == '__main__':
    categories = ["class1", "class2", "class3", "class4", "class5", "class6", "class7",
                  "class8", "class9", "class10", "class11", "class12", "class13"]
    save_file_stratified('../data/qwdata/shuffle-try3/classified_table_ms.txt',
                         '../data/qwdata/ms-ygscplusssdqw', categories)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值