提高训练样本质量,去除相似图片

提高训练样本质量,去除相似图片

使用 FiftyOne 对 YOLO 格式的数据集去重：从多个数据集中分别挑选质量最高（即彼此最不相似）的一批图片合并为训练集，其余图片作为测试集。

import fiftyone as fo
import fiftyone.zoo as foz
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import os
import shutil
from pathlib import Path

def fiftyone_image_deduplication_yolo(in_dir, out_dir, limit=50, thresh=0.92):
    """Deduplicate a YOLO image folder with FiftyOne embeddings and export a subset.

    Reads ``in_dir/images`` and ``in_dir/labels`` and computes a
    ``mobilenet-v2-imagenet-torch`` embedding for every image. Two images are
    treated as duplicates when their cosine similarity exceeds ``thresh``. For
    each duplicate group, the image that comes first in dataset order is kept.

    From the surviving images, the ``limit`` images with the lowest maximum
    similarity to any other image are copied, together with their label files,
    to ``out_dir/images`` and ``out_dir/labels``. Every other regular file at
    the top level of ``in_dir/images`` (including removed duplicates) is copied
    to ``out_dir/test`` without labels.

    Args:
        in_dir: dataset root containing ``images/`` and ``labels/``.
        out_dir: output root; deleted and recreated on every call.
        limit: maximum number of images to select for ``out_dir/images``.
        thresh: cosine-similarity threshold above which two images are
            considered duplicates.
    """
    path_img = os.path.join(in_dir, "images")
    path_label = os.path.join(in_dir, "labels")

    path_out_img = os.path.join(out_dir, "images")
    path_out_label = os.path.join(out_dir, "labels")
    path_out_test = os.path.join(out_dir, "test")

    # Always start from an empty output tree so reruns never mix results.
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(path_out_img)
    os.makedirs(path_out_label)
    os.makedirs(path_out_test)

    dataset = fo.Dataset.from_images_dir(path_img)
    if len(dataset) == 0:
        # cosine_similarity() rejects an empty matrix; leave the empty
        # output tree in place instead of crashing.
        print("no images found in:", path_img)
        return

    model = foz.load_zoo_model("mobilenet-v2-imagenet-torch")
    embeddings = dataset.compute_embeddings(model)

    # Pairwise cosine similarity between all images. The index-based lookups
    # below rely on row i matching the i-th sample in dataset iteration order.
    similarity_matrix = cosine_similarity(embeddings)
    # Each image is maximally similar to itself. Zero the diagonal so an image
    # is never counted as its own duplicate or its own nearest neighbour.
    np.fill_diagonal(similarity_matrix, 0.0)

    id_map = [s.id for s in dataset.select_fields(["id"])]

    # Greedy single pass: keep the first image of each duplicate group, in
    # dataset order, and mark every image whose similarity to it exceeds
    # `thresh` for removal.
    samples_to_remove = set()
    samples_to_keep = set()
    for idx, sample in enumerate(dataset):
        row = similarity_matrix[idx]
        sample["max_similarity"] = float(row.max())
        if sample.id in samples_to_remove:
            sample.tags.append("duplicate")
        else:
            samples_to_keep.add(sample.id)
            dup_idxs = np.where(row > thresh)[0]
            samples_to_remove.update(id_map[dup] for dup in dup_idxs)
            if len(dup_idxs) > 0:
                sample.tags.append("has_duplicates")
        # One save per sample persists both the score and any tag.
        sample.save()

    print("samples total:", len(dataset))
    print("remove:", len(samples_to_remove))
    print("keep:", len(samples_to_keep))

    dataset.delete_samples(list(samples_to_remove))

    # Select the `limit` most distinctive survivors (lowest max similarity).
    view = dataset.sort_by("max_similarity", reverse=False).limit(limit)
    select_img = set()
    for idx, sample in enumerate(view):
        print(idx, sample.filepath, sample.max_similarity)
        bname = os.path.basename(sample.filepath)
        select_img.add(bname)
        stem = os.path.splitext(bname)[0]
        shutil.copy(sample.filepath, os.path.join(path_out_img, bname))
        org_file_label = os.path.join(path_label, stem + ".txt")
        if os.path.isfile(org_file_label):
            shutil.copy(org_file_label, os.path.join(path_out_label, stem + ".txt"))
        else:
            # An image without objects may have no label file. Copy the image
            # anyway rather than aborting the whole run, but report it.
            print("warning: no label file for", sample.filepath)

    # Everything not selected (including removed duplicates) becomes test
    # data. Only regular files are copied; shutil.copy fails on directories.
    test_img = [
        name
        for name in os.listdir(path_img)
        if name not in select_img and os.path.isfile(os.path.join(path_img, name))
    ]
    print("test img:", len(test_img))
    for name in test_img:
        shutil.copy(os.path.join(path_img, name), path_out_test)
    
def dir_cut(src, dst, rm=True):
    """Move every file in ``src`` into ``dst`` (copy all, then delete ``src``).

    Args:
        src: source directory. It is removed after all files are copied.
        dst: destination directory. It is created if missing.
        rm: if True, delete ``dst`` first so it ends up holding only the files
            from ``src``. If False, merge into the existing ``dst``.

    A file already in ``dst`` with the same name as one in ``src`` is
    overwritten. A warning is printed in that case: when several datasets are
    merged, a name clash otherwise silently drops the earlier file.
    """
    if rm and os.path.exists(dst):
        shutil.rmtree(dst)
    os.makedirs(dst, exist_ok=True)
    for name in os.listdir(src):
        file_dst = os.path.join(dst, name)
        if os.path.exists(file_dst):
            print("warning: overwriting existing file:", file_dst)
        shutil.copy(os.path.join(src, name), dst)
    # Remove the source only after every copy succeeded. A failure midway
    # then leaves the original files in place.
    shutil.rmtree(src)
    
def deduplication_copy(src, dst):
    """Merge the images/, labels/ and test/ sub-folders of ``src`` into ``dst``.

    Each sub-folder is moved with ``dir_cut(..., rm=False)``. Files already in
    the matching destination sub-folder are kept, and each source sub-folder
    is deleted once its files have been copied.
    """
    for sub_dir in ("images", "labels", "test"):
        dir_cut(os.path.join(src, sub_dir), os.path.join(dst, sub_dir), rm=False)

def merge_deduplication(*, root="/workspace/data/small_flame_fog_smoke", limits=None):
    """Deduplicate several YOLO sub-datasets and merge the selections into one.

    Directory layout used under ``root``:
      org/<name>/images, org/<name>/labels : one YOLO dataset per entry of ``limits``
      org/classes.txt                      : class list copied into the merged labels
      tmp/<name>                           : per-dataset intermediate output (removed)
      deduplication/yolo                   : merged images/ and labels/ (recreated)
      test                                 : merged leftover images (replaced)

    Args:
        root: dataset root directory.
        limits: mapping of sub-dataset name to the number of images to select,
            processed in insertion order. Defaults to
            ``{"smoke": 300, "fog": 50, "small_flame": 70}``.
    """
    root = Path(root)
    if limits is None:
        limits = {"smoke": 300, "fog": 50, "small_flame": 70}

    org_dir = root / "org"
    tmp_dir = root / "tmp"
    deduplication_dir = root / "deduplication" / "yolo"

    # Input and output are YOLO layout: images/ and labels/ directories.
    for name, limit in limits.items():
        fiftyone_image_deduplication_yolo(
            in_dir=str(org_dir / name),
            out_dir=str(tmp_dir / name),
            limit=limit)

    if os.path.exists(deduplication_dir):
        shutil.rmtree(deduplication_dir)
    os.makedirs(deduplication_dir)
    for name in limits:
        deduplication_copy(tmp_dir / name, deduplication_dir)
    shutil.rmtree(tmp_dir)
    shutil.copy(org_dir / "classes.txt", os.path.join(deduplication_dir, "labels", "classes.txt"))

    # Move the merged leftover images out of the training tree.
    path_test = os.path.join(deduplication_dir, "test")
    path_out_test = root / "test"
    dir_cut(path_test, path_out_test)

# Run the pipeline only when executed as a script. Importing this module must
# not trigger it, because the pipeline deletes and recreates data directories.
if __name__ == "__main__":
    merge_deduplication()
    print("done")
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

绯虹剑心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值