哈希编码类型数据集制作

哈希编码类型数据集制作

实际模型训练中,我们经常需要创建自己的数据集,其中大多数都是简单的生成路径,然后通过路径将数据传入模型,但是哈希编码的标签与通常的不一样,在这里分享一种创建自己的哈希编码数据集程序,希望可以达到抛砖引玉效果.

数据集格式

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

生成txt文本

我们首先将数据集图片的路径转化成txt文本格式,并划分成总数据集,训练数据集、测试数据集:

# coding=gbk
import os
import random
def generate(dir, label):
    files = os.listdir(dir) # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

    files.sort()  #对文件或文件夹进行排序
    print(files)

    print('****************')
    print('input :', dir)
    print('start...')
    listText = open('Silk2/all_list.txt', 'a+')  #创建并打开一个txt文件,a+表示打开一个文件并追加内容


    # names = "path" + ' ' + "class"
    # listText.write(names)
    j = 1

    for file in files:  #遍历文件夹中的文件
        fileType = os.path.split(file) #os.path.split()返回文件的路径和文件名,【0】为路径,【1】为文件名
        if fileType[1] == '.txt':  # 若文件名的后缀为txt,则继续遍历循环,否则退出循环
            continue
        name = './data/Silk2/'+ folder+ '/' +file  + ' '+ str(int(label))  +'\n'  #name 为文件路径和文件名+空格+label+换行
        listText.write(name)  #在创建的txt文件中写入name

        j = j + 1


    listText.close() #关闭txt文件

    print('down!')
    print('****************')


l_train = []
l_val = []
l_test = []

# 读取文件中的内容,并将其打乱写入列表FileNameList
def ReadFileDatas(original_filename):
    file=open(original_filename,'r+')
    FileNameList=file.readlines()
    random.shuffle(FileNameList)
    file.close()
    print("数据集总量:", len(FileNameList))
    return FileNameList

#将数据集随机划分
def TrainValTestFile(FileNameList):
    i=0
    j=len(FileNameList)
    for line in FileNameList:
        if i<(j*0.8):
            i+=1
            l_train.append(line)
        elif i<(j*0.8):
            i+=1
            l_val.append(line)
        else:
            i+=1
            l_test.append(line)
    print("总数量:%d,此时创建train,val,test数据集"%i)
    return l_train,l_val,l_test

#将获取到的各个数据集的包含的文件名写入txt中
def WriteDatasToFile(listInfo, new_filename):
    file_handle = open(new_filename,'w')
    for str_Result in listInfo:
        file_handle.write(str_Result)
    file_handle.close()
    print('写入 %s 文件成功.' % new_filename)





outer_path = './Silk2'  # 这里是你的图片路径

if __name__ == '__main__':  #主函数


    i = 1
    folderlist = os.listdir(outer_path)# 列举文件夹
    for folder in folderlist:  #遍历文件夹中的文件夹(若engagement文件夹中存在txt或py文件,则后面会报错)
        print(os.path.join(outer_path, folder))
        generate(os.path.join(outer_path, folder), i) #调用generate函数,函数中的参数为:(图片路径+文件夹名,标签号)
        i += 1

    listFileInfo = ReadFileDatas('Silk2/all_list.txt')  # 读取文件
    l_train,l_val, l_test = TrainValTestFile(listFileInfo)
    WriteDatasToFile(l_train, './recordvideo/all_train.txt')
    WriteDatasToFile(l_val, './recordvideo/all_val.txt')
    WriteDatasToFile(l_test, './recordvideo/all_test.txt')


类型如下:
在这里插入图片描述

生产csv格式

我们将txt文本格式的路径转换成csv类型的格式,以便于后期哈希编码

# coding=utf8
import os
import csv

# ============================说明==================================
# 把要转换成csv文件的txt文件路径复制到rootdir下,运行程序后,在..\\**\\日期\\**\\**\\recordvideo
# 目录下会生成txt文档对应的csv格式文档,且带表头。
# ============================说明==================================

rootdir = './recordvideo'


def findtxt(path, ret):
    """Finding the *.txt file in specify path"""
    filelist = os.listdir(path)
    for filename in filelist:
        de_path = os.path.join(path, filename)
        if os.path.isfile(de_path):
            if de_path.endswith(".txt"):  # Specify to find the txt file.
                ret.append(de_path)
        else:
            findtxt(de_path, ret)


def txt2csv(filepath, data):
    with open(filepath + '.csv', 'w', newline='') as csvfile:
        spamwriter = csv.writer(csvfile)
        # 添加csv文件的表头
        spamwriter.writerow(['path', 'super_class_id'])
        with open(filepath, 'r') as file_object:
            lines = file_object.readlines()
            # 清除lines list中换行符\n, 空格符''
            lines = [x.strip('\n') for x in lines if x.strip() != '']
            for i, line in enumerate(lines):
                data.append([])
                elements = line.split(' ')
                # 清除单行line 中存在的空格符字符串''
                elements = [x for x in elements if x != '']
                for element in elements:
                    data[i].append(element)
        # 将txt文件内容写入csv文件中
        spamwriter.writerows(data)


if __name__ == "__main__":
    ret = []
    findtxt(rootdir, ret)
    for ret_ in ret:
        data = []
        txt2csv(ret_, data)

在这里插入图片描述

将CSV格式路径通过哈希编码转化成我们需要的类型

# coding=gbk
import os
import numpy as np
import pandas as pd



df = pd.read_csv('./Silk2/all_list.txt.csv',engine='python')
# df.columns =['path','class','id']
print(df)
n = len(df)
print(n)
nclass = 11
all_label_data = np.zeros((n, nclass), dtype=np.int8)
all_label_path = list()

for i, data in df.iterrows():
    all_label_data[i][data.super_class_id-1] = 1
    all_label_path.append(data.path)
print(all_label_data)
print(all_label_path)

train_num = 4031
test_num = 1000

data_id = np.array(range(0, n))

test_data_index = []
train_data_index = []
database_data_index = []

for i in range(nclass):
    class_mask = df['super_class_id'] == i + 1
    # print(class_mask)
    index_of_class = data_id[class_mask].copy()  # index of the class [2, 10, 656,...]
    print("index_of_class:",index_of_class)
    print("个数",len(index_of_class))
    # np.random.shuffle(index_of_class)

    # query_n = int(test_num / nclass)
    query_n = int(len(index_of_class) * 0.2)
    print("query_n:",query_n)
    # train_n = int(train_num / nclass)
    train_n = int(len(index_of_class)* 0.8)
    print("train_n:",train_n)

    index_for_query = index_of_class[:query_n].tolist()
    print("index_for_query:",index_for_query)
    index_for_db = index_of_class[query_n:].tolist()
    print("index_for_db:",index_for_db)
    index_for_train = index_for_db[:train_n]
    print("index_for_train:",index_for_train)

    train_data_index.extend(index_for_train)
    test_data_index.extend(index_for_query)
    database_data_index.extend(index_for_db)

with open("database.txt", "w") as f:
    for index in database_data_index:
        line =   all_label_path[index] + " " \
        + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)

with open("train.txt", "w") as f:
    for index in train_data_index:
        line =  all_label_path[index] + " " \
        + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)

with open("test.txt", "w") as f:
    for index in test_data_index:
        line =  all_label_path[index] + " " \
        + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)


在这里插入图片描述

附录–mirflickr25k转化成哈希代码程序

# download data from here: https://press.liacs.nl/mirflickr/mirdownload.html

# import hashlib
# with open("mirflickr25k.zip","rb") as f:
#     md5_obj = hashlib.md5()
#     md5_obj.update(f.read())
#     hash_code = md5_obj.hexdigest()
#     print(str(hash_code).upper() == "A23D0A8564EE84CDA5622A6C2F947785")

import os
import numpy as np

all_label_data = np.zeros((25000, 24), dtype=np.int8)
label_index = -1
label_dir_name = "mirflickr25k_annotations_v080"
for label_file in os.listdir(label_dir_name):
    if "README.txt" != label_file and "_r1" not in label_file:
        label_index += 1
        with open(os.path.join(label_dir_name, label_file), "r") as f:
            print(label_file)
            for line in f.readlines():
                all_label_data[int(line.strip()) - 1][label_index] = 1
        
        if os.path.exists(label_file[:-4]+"_r1.txt"):
            with open(os.path.join(label_dir_name, label_file[:-4]+"_r1.txt"), "r") as f:
                print(label_file+"_r1")
                for line in f.readlines():
                    all_label_data[int(line.strip()) - 1][label_index] = 1

train_num = 5000
test_num = 2000
# perm_index = np.random.permutation(all_label_data.shape[0])
non_exist_index = np.where((all_label_data.sum(1)==0) == True)[0]
perm_index = np.array(list(range(all_label_data.shape[0])))
perm_index = np.delete(perm_index, non_exist_index)
np.random.shuffle(perm_index)
# train_data_index = perm_index[:train_num]
# test_data_index = perm_index[train_num:train_num + test_num]
# database_data_index = perm_index[train_num + test_num:]

test_data_index = perm_index[:test_num]
train_data_index = perm_index[test_num:test_num + train_num]
database_data_index = perm_index[test_num:]

with open("database.txt", "w") as f:
    for index in database_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("train.txt", "w") as f:
    for index in train_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)
with open("test.txt", "w") as f:
    for index in test_data_index:
        line = "data/mirflickr/im" + str(index + 1) + ".jpg " + str(all_label_data[index].tolist())[1:-1].replace(", ", " ") + "\n"
        f.write(line)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Peihj2021

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值