哈希编码类型数据集制作
实际模型训练中,我们经常需要创建自己的数据集,其中大多数都是简单的生成路径,然后通过路径将数据传入模型,但是哈希编码的标签与通常的不一样,在这里分享一种创建自己的哈希编码数据集程序,希望可以达到抛砖引玉效果.
数据集格式
生成txt文本
我们首先将数据集图片的路径转化成txt文本格式,并划分成总数据集、训练数据集、测试数据集:
# coding=gbk
import os
import random
def generate(dir, label):
    """Append one "path label" line per image in *dir* to Silk2/all_list.txt.

    dir   -- directory containing the images of one class
    label -- integer class label appended after each file path

    Files with a ``.txt`` extension are skipped.
    BUG FIXES vs. the original:
      * the class-folder component of the written path came from a module
        global ``folder``; it is now derived from *dir* itself, so the
        function works standalone;
      * the txt-skip test used ``os.path.split`` (whole filename), which only
        matched a file literally named ".txt"; ``os.path.splitext`` checks
        the extension as the comment intended;
      * the list file is opened via ``with`` so it is closed on any exit path.
    """
    files = os.listdir(dir)  # names of all entries in the class directory
    files.sort()             # deterministic ordering
    print(files)
    print('****************')
    print('input :', dir)
    print('start...')
    # class sub-folder name, e.g. './Silk2/classA' -> 'classA'
    folder_name = os.path.basename(dir)
    # 'a+' creates the list file if missing and appends across repeated calls
    with open('Silk2/all_list.txt', 'a+') as listText:
        for file in files:
            # skip annotation/list files by extension
            if os.path.splitext(file)[1] == '.txt':
                continue
            # one line per image: relative path, a space, then the label
            name = './data/Silk2/' + folder_name + '/' + file + ' ' + str(int(label)) + '\n'
            listText.write(name)
    print('down!')
    print('****************')
# Module-level accumulators filled by TrainValTestFile() below.
l_train = []  # lines assigned to the training split
l_val = []    # lines assigned to the validation split
l_test = []   # lines assigned to the test split
# Read every line of the list file and return them in random order.
def ReadFileDatas(original_filename):
    """Return all lines of *original_filename*, shuffled in place."""
    with open(original_filename, 'r+') as list_file:
        all_lines = list_file.readlines()
    random.shuffle(all_lines)  # randomize sample order before splitting
    print("数据集总量:", len(all_lines))
    return all_lines
# Randomly split the (pre-shuffled) dataset into train/val/test lists.
def TrainValTestFile(FileNameList):
    """Split *FileNameList* into the module-level l_train/l_val/l_test lists.

    Returns the three lists (train, val, test).

    BUG FIX: the original second branch repeated the same ``i < j*0.8`` test
    as the first, so l_val was always empty and l_test received 20% of the
    data. The second bound is now ``j*0.9``, giving an 80%/10%/10% split.
    """
    i = 0
    j = len(FileNameList)
    for line in FileNameList:
        if i < (j * 0.8):        # first 80% -> training set
            i += 1
            l_train.append(line)
        elif i < (j * 0.9):      # next 10% -> validation set
            i += 1
            l_val.append(line)
        else:                    # final 10% -> test set
            i += 1
            l_test.append(line)
    print("总数量:%d,此时创建train,val,test数据集" % i)
    return l_train, l_val, l_test
# Persist one split's entries, one per line, into its own txt file.
def WriteDatasToFile(listInfo, new_filename):
    """Write every string in *listInfo* to *new_filename*, overwriting it."""
    with open(new_filename, 'w') as out_file:
        for record in listInfo:
            out_file.write(record)
    print('写入 %s 文件成功.' % new_filename)
outer_path = './Silk2'  # root directory with one sub-folder per image class

if __name__ == '__main__':
    # Label the class sub-folders 1, 2, 3, ... in listing order and append
    # each folder's files to the global list file.
    for i, folder in enumerate(os.listdir(outer_path), start=1):
        print(os.path.join(outer_path, folder))
        generate(os.path.join(outer_path, folder), i)
    # Shuffle the combined list and write out the three split files.
    listFileInfo = ReadFileDatas('Silk2/all_list.txt')
    l_train, l_val, l_test = TrainValTestFile(listFileInfo)
    WriteDatasToFile(l_train, './recordvideo/all_train.txt')
    WriteDatasToFile(l_val, './recordvideo/all_val.txt')
    WriteDatasToFile(l_test, './recordvideo/all_test.txt')
类型如下:
生成csv格式
我们将txt文本格式的路径转换成csv类型的格式,以便于后期哈希编码
# coding=utf8
import os
import csv
# ============================ Notes ==================================
# Put the txt files to be converted under rootdir. After running, a csv
# file with a header row is generated next to every txt file found under
# ..\**\date\**\**\recordvideo.
# ============================ Notes ==================================
rootdir = './recordvideo'  # directory tree searched for *.txt list files
def findtxt(path, ret):
    """Recursively collect the paths of all *.txt files under *path* into *ret*."""
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            # directory: descend into it
            findtxt(full_path, ret)
        elif full_path.endswith(".txt"):
            # regular file with the wanted extension
            ret.append(full_path)
def txt2csv(filepath, data):
    """Convert a space-separated txt list file into ``<filepath>.csv``.

    Rows (one per non-blank txt line) are accumulated into *data* and
    written under a 'path,super_class_id' header row.
    """
    with open(filepath + '.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        # header expected by the downstream hash-encoding step
        writer.writerow(['path', 'super_class_id'])
        with open(filepath, 'r') as txt_file:
            raw_lines = txt_file.readlines()
        # drop blank lines and trailing newlines
        cleaned = [ln.strip('\n') for ln in raw_lines if ln.strip() != '']
        for row_text in cleaned:
            # split on spaces, discarding empty fields from repeated spaces
            data.append([field for field in row_text.split(' ') if field != ''])
        writer.writerows(data)
if __name__ == "__main__":
    # Find every txt list file under rootdir and emit a csv next to each.
    txt_paths = []
    findtxt(rootdir, txt_paths)
    for txt_path in txt_paths:
        txt2csv(txt_path, [])
将CSV格式路径通过哈希编码转化成我们需要的类型
# coding=gbk
import os
import numpy as np
import pandas as pd
# --- Convert the csv path/label list into hashing-style label files. ---
# Reads ./Silk2/all_list.txt.csv (columns: path, super_class_id), builds a
# one-hot label vector per sample, splits each class 20%/80% into
# query(test)/database (with the training set drawn from the database),
# then writes database.txt / train.txt / test.txt as "path 0 1 0 ..." lines.
df = pd.read_csv('./Silk2/all_list.txt.csv',engine='python')
# df.columns =['path','class','id']
print(df)
n = len(df)  # total number of samples
print(n)
nclass = 11  # number of classes; assumes super_class_id runs 1..11 -- TODO confirm
# one-hot label matrix: row i holds the label vector of sample i
all_label_data = np.zeros((n, nclass), dtype=np.int8)
all_label_path = list()  # sample paths, index-aligned with all_label_data rows
for i, data in df.iterrows():
    # super_class_id is 1-based; column index is 0-based
    all_label_data[i][data.super_class_id-1] = 1
    all_label_path.append(data.path)
print(all_label_data)
print(all_label_path)
train_num = 4031  # leftover from the fixed-count split variant; unused below
test_num = 1000   # leftover; unused below
data_id = np.array(range(0, n))  # global row indices 0..n-1
test_data_index = []
train_data_index = []
database_data_index = []
# Per-class split; no shuffling, so order follows the csv.
for i in range(nclass):
    # boolean mask selecting the rows of class i+1
    class_mask = df['super_class_id'] == i + 1
    # print(class_mask)
    index_of_class = data_id[class_mask].copy() # index of the class [2, 10, 656,...]
    print("index_of_class:",index_of_class)
    print("个数",len(index_of_class))
    # np.random.shuffle(index_of_class)
    # query_n = int(test_num / nclass)
    query_n = int(len(index_of_class) * 0.2)  # first 20% of the class -> query/test set
    print("query_n:",query_n)
    # train_n = int(train_num / nclass)
    train_n = int(len(index_of_class)* 0.8)   # up to 80% of the class -> training set
    print("train_n:",train_n)
    index_for_query = index_of_class[:query_n].tolist()
    print("index_for_query:",index_for_query)
    # everything outside the query set forms the retrieval database
    index_for_db = index_of_class[query_n:].tolist()
    print("index_for_db:",index_for_db)
    # training samples are the first train_n database entries
    index_for_train = index_for_db[:train_n]
    print("index_for_train:",index_for_train)
    train_data_index.extend(index_for_train)
    test_data_index.extend(index_for_query)
    database_data_index.extend(index_for_db)
# Each output line: "<path> <b1> <b2> ... <b11>" -- the one-hot vector
# rendered as space-separated 0/1 digits.
with open("database.txt", "w") as f:
    for index in database_data_index:
        line = all_label_path[index] + " " \
               + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("train.txt", "w") as f:
    for index in train_data_index:
        line = all_label_path[index] + " " \
               + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("test.txt", "w") as f:
    for index in test_data_index:
        line = all_label_path[index] + " " \
               + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
附录–mirflickr25k转化成哈希代码程序
# download data from here: https://press.liacs.nl/mirflickr/mirdownload.html
# import hashlib
# with open("mirflickr25k.zip","rb") as f:
# md5_obj = hashlib.md5()
# md5_obj.update(f.read())
# hash_code = md5_obj.hexdigest()
# print(str(hash_code).upper() == "A23D0A8564EE84CDA5622A6C2F947785")
import os
import numpy as np
# Build mirflickr25k multi-label hashing files (database/train/test.txt).
# Each annotation file in mirflickr25k_annotations_v080 becomes one label
# column; "*_r1.txt" files are merged into the same column as their base file.
all_label_data = np.zeros((25000, 24), dtype=np.int8)  # 25k images x 24 labels
label_index = -1
label_dir_name = "mirflickr25k_annotations_v080"
for label_file in os.listdir(label_dir_name):
    # skip the README and *_r1 variants (merged with their base file below)
    if "README.txt" != label_file and "_r1" not in label_file:
        label_index += 1
        with open(os.path.join(label_dir_name, label_file), "r") as f:
            print(label_file)
            for line in f.readlines():
                # annotation lines hold 1-based image ids
                all_label_data[int(line.strip()) - 1][label_index] = 1
        # BUG FIX: the existence check must include the annotation directory
        # (the original tested the bare filename against the CWD, so the
        # *_r1.txt relaxed annotations were never merged).
        r1_file = label_file[:-4] + "_r1.txt"
        if os.path.exists(os.path.join(label_dir_name, r1_file)):
            with open(os.path.join(label_dir_name, r1_file), "r") as f:
                print(label_file + "_r1")
                for line in f.readlines():
                    all_label_data[int(line.strip()) - 1][label_index] = 1
train_num = 5000
test_num = 2000
# perm_index = np.random.permutation(all_label_data.shape[0])
# Drop images that received no label at all, then shuffle what remains.
non_exist_index = np.where((all_label_data.sum(1)==0) == True)[0]
perm_index = np.array(list(range(all_label_data.shape[0])))
perm_index = np.delete(perm_index, non_exist_index)
np.random.shuffle(perm_index)
# train_data_index = perm_index[:train_num]
# test_data_index = perm_index[train_num:train_num + test_num]
# database_data_index = perm_index[train_num + test_num:]
# query/test = first 2000 shuffled ids; train = next 5000;
# database = everything except the query set (train is a subset of database).
test_data_index = perm_index[:test_num]
train_data_index = perm_index[test_num:test_num + train_num]
database_data_index = perm_index[test_num:]
# One line per image: "data/mirflickr/imK.jpg b1 b2 ... b24" (K is 1-based).
with open("database.txt", "w") as f:
    for index in database_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("train.txt", "w") as f:
    for index in train_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("test.txt", "w") as f:
    for index in test_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)