Splitting the validation and training sets by defect-image similarity
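
The idea: extract a feature vector for every defect image with an ImageNet-pretrained ResNet-50, measure how alike two defects are with cosine similarity, group images whose similarity exceeds a threshold, and send one representative of each group to the validation set, so that near-duplicate defects never end up on both sides of the split. As a minimal sketch of the core measurement (the file names img_a.png / img_b.png are placeholders, and the classifier head is simply dropped here, whereas the scripts below swap it for a fresh 2048x2048 linear layer):

import torch
import torchvision
from PIL import Image
from torchvision import transforms

# Minimal sketch: cosine similarity between two ResNet-50 feature vectors.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = torchvision.models.resnet50(pretrained=True)
model.fc = torch.nn.Identity()  # keep the 2048-d pooled feature instead of class logits
model.eval()

with torch.no_grad():
    # img_a.png / img_b.png are placeholder paths, not files from this post
    feats = [model(preprocess(Image.open(p).convert('RGB')).unsqueeze(0))
             for p in ("img_a.png", "img_b.png")]

similarity = torch.cosine_similarity(feats[0], feats[1], dim=1).item()
print(f"cosine similarity: {similarity:.3f}")  # close to 1.0 means near-duplicate defects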

import json

import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

# Load an ImageNet-pretrained ResNet-50 backbone
model = torchvision.models.resnet50(pretrained=True)
model.eval()

# Preprocessing (ImageNet normalization)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Build the feature extractor: replace the classification head with a 2048 -> 2048 linear layer
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)  # other sizes work too, e.g. 512, 1024 ...
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
# Freeze all parameters; the extractor is used for inference only
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False


# Collect the paths of all defect images
all_imgs = get_files_path(r"C:\Users\29939\Desktop\defect\defect\images")

# Similarity threshold
similarity_score_threshold = 0.90

# Maps an image path to the first later image judged similar to it
match_dict = {}

# One feature vector per image
feat_list = []

for index, item in enumerate(tqdm(all_imgs)):
    img = Image.open(item).convert('RGB')
    img_tensor = preprocess(img)
    x = img_tensor.unsqueeze(dim=0)
    x = x.to('cpu')
    # Extract the image feature vector
    feat_list.append(resnet50_feature_extractor(x))


# Compare every pair of features; record the first match above the threshold
for i, raw_img_feat in enumerate(tqdm(feat_list)):
    for j in range(i + 1, len(feat_list)):
        output = feat_list[j]
        # Cosine similarity between the two feature vectors
        similarity = torch.cosine_similarity(raw_img_feat, output, dim=1)
        if similarity > similarity_score_threshold:
            match_dict[all_imgs[i]] = all_imgs[j]
            break

print("match_dict:", match_dict)

# Save the matches as JSON
with open("./dict.json", 'w') as f:
    json.dump(match_dict, f, indent=2)
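
After the script finishes, dict.json maps each image path to the first later image whose feature similarity exceeds 0.90. A small snippet for inspecting the result, assuming the script above has already been run in the current directory:

import json

# Load the matches produced by the script above and print a few of them
with open("./dict.json") as f:
    match_dict = json.load(f)

print(len(match_dict), "images have at least one near-duplicate")
for src, dst in list(match_dict.items())[:5]:
    print(src, "->", dst)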


import json
import os

import torch
import torchvision
from flask import Flask, request
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

model = torchvision.models.resnet50(pretrained=True)
model.eval()

app = Flask(__name__)

# Index of the next output folder
dir_index = 0
# Output directory
output_path = r"./test"



# Preprocessing (no resize here: images are fed at their original resolution)
preprocess = transforms.Compose([
    # transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Build the feature extractor: replace the classification head with a 2048 -> 2048 linear layer
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)  # other sizes work too, e.g. 512, 1024 ...
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
# Freeze all parameters; the extractor is used for inference only
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False



@app.route("/test", methods=["GET"])
def test():
    '''
    Walk over all defect images, extract a feature vector for each one,
    and split the images into a validation set and a training set.
    '''

    # Directory containing the defect images, passed as ?source=<path>
    all_args = request.args.get("source")

    print(all_args)
    # GuaiJiao:F:\PUCPDataSet\Seg\RT_GuaiJiao_DataSet\RawSimiTest\blend\defect\images
    # SideArray:F:\PUCPDataSet\Seg\Side_Array_DataSet_Raw_Bleed\RawTest\blend\defect\images
    all_imgs = get_files_path(all_args)
    all_imgs.sort()
    # Similarity threshold
    similarity_score_threshold = 0.7

    match_dict = {}

    feat_list = []

    # Validation set
    val_list = []
    # Training set
    train_list = []

    for index, item in enumerate(tqdm(all_imgs)):
        img = Image.open(item).convert('RGB')
        img_tensor = preprocess(img)
        x = img_tensor.unsqueeze(dim=0)
        x = x.to('cpu')
        # Extract the image feature vector
        feat_list.append(resnet50_feature_extractor(x))

    # Sanity check: every image must have exactly one feature vector
    assert len(all_imgs) == len(feat_list)
    print("all_imgs", len(all_imgs))
    print("feat_list", len(feat_list))

    '''
    Greedy grouping: take the first ungrouped defect, collect every remaining defect
    whose similarity exceeds the threshold, remove them all, then repeat on the rest.
    '''

    # Indices of all defects that have not been grouped yet
    all_index = [i for i in range(len(all_imgs))]

    # Iterate over a snapshot so that removing grouped indices does not disturb the iteration
    for index_item in tqdm(list(all_index)):
        # Skip indices that were already grouped with an earlier defect
        if index_item not in all_index:
            continue
        # Images / features / indices of the current similarity group
        current_similarities_imgs_list = []
        current_similarities_feats_list = []
        current_index_list = []
        # Feature vector of the seed defect of this group
        raw_img_feat = feat_list[index_item]

        current_similarities_imgs_list.append(all_imgs[index_item])
        current_similarities_feats_list.append(feat_list[index_item])

        # Remove the seed index from the ungrouped pool
        current_index_list.append(index_item)
        all_index.remove(index_item)

        '''
        Collect every defect similar to the seed into the same group
        '''
        # Walk over the remaining ungrouped feature vectors
        for j in list(all_index):
            # Candidate defect feature vector
            output = feat_list[j]
            # Cosine similarity between the two feature vectors
            similarity = torch.cosine_similarity(raw_img_feat, output, dim=1)
            # Above the threshold: add the candidate to the current group
            if similarity > similarity_score_threshold:
                current_similarities_imgs_list.append(all_imgs[j])
                current_similarities_feats_list.append(feat_list[j])
                current_index_list.append(j)
                all_index.remove(j)



        print("current_index_list",current_index_list)
        print("current_similarities_imgs_list",len(current_similarities_imgs_list))

        # makedir(os.path.join(output_path,str(dir_index)))
        # for iii in current_index_list:
        #     src_path = all_imgs[iii]
        #     dst_path = os.path.join(output_path,str(dir_index),os.path.basename(src_path))
        #     copyfile(src_path=src_path,dst_path=dst_path)
        # dir_index = dir_index +1

        # 遍历相似缺陷图像,选择缺陷图像最少的作为验证集
        if len(current_similarities_imgs_list) > 1:
            val_list.append(current_similarities_imgs_list[0])



    print(val_list)
    train_list = [i for i in all_imgs if i not in val_list]
    print(val_list)
    # print(len(all_imgs))
    # print(len(val_list))
    # print(len(train_list))
    val_list = [os.path.basename(i) for i in val_list ]

    train_list = [os.path.basename(i) for i in train_list ]

    print(val_list)
    print(train_list)
    res = {"val_dataset":val_list,"train_dataset":train_list}

    return str(res)



if __name__ == '__main__':
    app.run()
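
With the service running on Flask's default port (5000), the split can be requested by passing the image directory as the source query parameter. A minimal client sketch, assuming a local directory of defect images at ./defect/images (not a path from this post):

import requests

# 'source' is the directory that get_files_path will scan on the server side
resp = requests.get(
    "http://127.0.0.1:5000/test",
    params={"source": r"./defect/images"},
)
print(resp.text)  # contains the val_dataset / train_dataset lists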

Final version

import json
import os

import torch
import torchvision
from flask import Flask, request
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

model = torchvision.models.resnet50(pretrained=True)
model.eval()

app = Flask(__name__)

# Index of the next output folder
dir_index = 0
# Output directory
output_path = r"./test"



# Preprocessing (no resize here: images are fed at their original resolution)
preprocess = transforms.Compose([
    # transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Build the feature extractor: replace the classification head with a 2048 -> 2048 linear layer
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)  # other sizes work too, e.g. 512, 1024 ...
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
# Freeze all parameters; the extractor is used for inference only
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False



@app.route("/test", methods=["GET"])
def test():
    '''
    Walk over all defect images, extract a feature vector for each one,
    and split the data into a validation set and a training set at the pin-name level.
    '''

    # Directory containing the defect images, passed as ?source=<path>
    all_args = request.args.get("source")

    print(all_args)
    # GuaiJiao:F:\PUCPDataSet\Seg\RT_GuaiJiao_DataSet\RawSimiTest\blend\defect\images
    # SideArray:F:\PUCPDataSet\Seg\Side_Array_DataSet_Raw_Bleed\RawTest\blend\defect\images
    all_imgs = get_files_path(all_args)
    all_imgs.sort()
    # Similarity threshold
    similarity_score_threshold = 0.7

    match_dict = {}

    feat_list = []

    # Validation set
    val_list = []
    # Training set
    train_list = []

    for index, item in enumerate(tqdm(all_imgs)):
        img = Image.open(item).convert('RGB')
        img_tensor = preprocess(img)
        x = img_tensor.unsqueeze(dim=0)
        x = x.to('cpu')
        # Extract the image feature vector
        feat_list.append(resnet50_feature_extractor(x))

    # Sanity check: every image must have exactly one feature vector
    assert len(all_imgs) == len(feat_list)
    print("all_imgs", len(all_imgs))
    print("feat_list", len(feat_list))

    '''
    Greedy grouping: take the first ungrouped defect, collect every remaining defect
    whose similarity exceeds the threshold, remove them all, then repeat on the rest.
    '''

    # Indices of all defects that have not been grouped yet
    all_index = [i for i in range(len(all_imgs))]

    # Pin names of all images, parsed from the file names
    all_images_name_list = []
    for item_imgName in all_imgs:
        name_img = os.path.basename(item_imgName)
        # print(name_img)
        # Drop the first and last underscore-separated fields to get the pin name
        pin_name = name_img.split('_')[1:-1]
        pin_name = '_'.join(pin_name)
        all_images_name_list.append(pin_name)
    # Deduplicate
    all_images_name_list = set(all_images_name_list)

    # Iterate over a snapshot so that removing grouped indices does not disturb the iteration
    for index_item in tqdm(list(all_index)):
        # Skip indices that were already grouped with an earlier defect
        if index_item not in all_index:
            continue
        # Images / features / indices of the current similarity group
        current_similarities_imgs_list = []
        current_similarities_feats_list = []
        current_index_list = []
        # Feature vector of the seed defect of this group
        raw_img_feat = feat_list[index_item]

        current_similarities_imgs_list.append(all_imgs[index_item])
        current_similarities_feats_list.append(feat_list[index_item])

        # Remove the seed index from the ungrouped pool
        current_index_list.append(index_item)
        all_index.remove(index_item)

        '''
        Collect every defect similar to the seed into the same group
        '''
        # Walk over the remaining ungrouped feature vectors
        for j in list(all_index):
            # Candidate defect feature vector
            output = feat_list[j]
            # Cosine similarity between the two feature vectors
            similarity = torch.cosine_similarity(raw_img_feat, output, dim=1)
            # Above the threshold: add the candidate to the current group
            if similarity > similarity_score_threshold:
                current_similarities_imgs_list.append(all_imgs[j])
                current_similarities_feats_list.append(feat_list[j])
                current_index_list.append(j)
                all_index.remove(j)



        print("current_index_list",current_index_list)
        print("current_similarities_imgs_list",len(current_similarities_imgs_list))

        # makedir(os.path.join(output_path,str(dir_index)))
        # for iii in current_index_list:
        #     src_path = all_imgs[iii]
        #     dst_path = os.path.join(output_path,str(dir_index),os.path.basename(src_path))
        #     copyfile(src_path=src_path,dst_path=dst_path)
        # dir_index = dir_index +1

        # 遍历相似缺陷图像,选择缺陷图像最少的作为验证集
        if len(current_similarities_imgs_list) > 1:
            val_list.append(current_similarities_imgs_list[0])



    # print(val_list)
    # train_list = [i for i in all_imgs if i not in val_list]
    # print(val_list)
    # print(len(all_imgs))


    # print(len(val_list))
    # print(len(train_list))
    val_list = [os.path.basename(i) for i in val_list ]


    all_val_name_list = []
    for item_imgName in val_list:
        name_img = os.path.basename(item_imgName)
        # print(name_img)
        pin_name = name_img.split('_')[1:-1]
        pin_name = '_'.join(pin_name)
        all_val_name_list.append(pin_name)

    all_val_name_list = set(all_val_name_list)


    train_list_name_list = [i for i in all_images_name_list if i not in all_val_name_list]




    print(val_list)
    print(train_list)
    res = {"val_dataset":all_val_name_list,"train_dataset":train_list_name_list}

    print(len(all_images_name_list))
    print(len(all_val_name_list))
    print(len(train_list_name_list))
    print(res)
    return str(res)



if __name__ == '__main__':
    app.run()
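
The only real difference in this final version is that the split is made at the pin level rather than the image level: the pin name is recovered from each file name by dropping the first and last underscore-separated fields, so every crop of a given pin lands on the same side of the split. A small illustration, assuming a hypothetical file name of the form <index>_<pin name>_<crop id>.png (the real naming convention is not shown in the post):

# Hypothetical file name; the dataset's actual naming scheme is an assumption here
name_img = "0003_U12_PIN07_2.png"

# Same parsing as in the route above: keep the middle underscore-separated fields
pin_name = '_'.join(name_img.split('_')[1:-1])
print(pin_name)  # -> "U12_PIN07"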