数据处理(过采样)

25 篇文章 2 订阅
# coding: utf-8
import os
import sys
import matplotlib.pyplot as plt
#将全部数据索引至字典中
path = "/media/dell/dell/data/remote_sensing/remote/train_image"
dirs = sorted(os.listdir(path))
print(dirs)
files = {}
for index, dir in enumerate(dirs):
    path_ = path + "/" + dir + "/"
    files[int(dir)] = []
    for file in os.listdir(path_):
        files[int(dir)].append(path_+file)
    sys.stdout.write('\r>> Loading data %d/%d'%(index+1, 9))
    sys.stdout.flush()
sys.stdout.write("\n")    
#print(len(files))
#查看各种数据所占比例
file_num = []
for i in range(1, 10):
    file_num.append(len(files[i]))
print(file_num)
plt.bar(range(len(file_num)),file_num)
plt.xlabel('class_id')
plt.ylabel('amount')
plt.ylim(0, 11000)
for x,y in zip(range(len(file_num)),file_num):
    plt.text(x, y+100, '%d' % y, ha='center', va= 'bottom')
plt.show()
#写入valid数据
import random
all_={}
f1=open('./all.txt','w+')
for i in range(1, 10):
    all_[i] = files[i]
    for item in all_[i]:
        print(item)
        f1.write(item+"\n")
f1.close()
f = open("./valid.txt", "w+")
valid_data = {}
train_data = {}
for i in range(1, 10):
    valid_data[i] = random.sample(files[i], 200)
    train_data[i] = list(set(files[i]) - set(valid_data[i]))
    for item in valid_data[i]:
        print(item)
        f.write(item+"\n")
f.close()
#查看除去valid的数据
file_num_ = []
for i in range(1, 10):
    file_num_.append(len(train_data[i]))
print(file_num_)
plt.bar(range(len(file_num_)), file_num_)
plt.xlabel('class_id')
plt.ylabel('amount')
plt.ylim(0, 11000)
for x,y in zip(range(len(file_num_)), file_num_):
    plt.text(x, y+100, '%d' % y, ha='center', va= 'bottom')
plt.show()
#写入train数据
f = open("./train.txt", "w+")
for i in range(1, 10):
    for item in train_data[i]:
        print(item)
        f.write(item+"\n")
f.close()
#进行过采样
max_amount = max(file_num_)
for i in range(1, 10):
    for j in range(max_amount-len(train_data[i])):
        train_data[i].append(train_data[i][random.randint(0, file_num_[i-1]-1)])
#可视化数据比例
file_num_ = []
for i in range(1, 10):
    file_num_.append(len(train_data[i]))
print(file_num_)
plt.bar(range(len(file_num_)), file_num_)
plt.xlabel('class_id')
plt.ylabel('amount')
plt.ylim(0, 11000)
for x,y in zip(range(len(file_num_)), file_num_):
    plt.text(x, y+100, '%d' % y, ha='center', va= 'bottom')
plt.show()
#写入过采样的数据
f = open("./train_oversampling.txt", "w+")
for i in range(1, 10):
    for item in train_data[i]:
        print(item)
        f.write(item+"\n")
f.close()
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值