"""Split validation/training sets by defect-image similarity.

Extracts a ResNet-50 feature vector for every defect image, then pairs each
image with the first *later* image whose cosine similarity exceeds a
threshold, and dumps the pairing to ./dict.json.
"""
import datetime
import json

import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

# Pretrained ImageNet backbone, used purely as a frozen feature extractor.
model = torchvision.models.resnet50(pretrained=True)
model.eval()

# Standard ImageNet preprocessing pipeline.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Replace the classification head with a 2048->2048 linear layer so the
# network emits a feature vector instead of class logits.
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)
# Zero the new head's bias once (the original initialized it twice).
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  # inference only — no gradients needed

all_imgs = get_files_path(r"C:\Users\29939\Desktop\defect\defect\images")
# Cosine-similarity threshold above which two images count as a match.
similarity_score_threshold = 0.90

match_dict = {}
feat_list = []
# Extract one feature vector per image (CPU inference).
for item in tqdm(all_imgs):
    img = Image.open(item).convert('RGB')
    x = preprocess(img).unsqueeze(dim=0).to('cpu')
    feat_list.append(resnet50_feature_extractor(x))

# For each image, record the first later image that is "too similar".
for i, raw_img_feat in enumerate(tqdm(feat_list)):
    for j in range(i + 1, len(feat_list)):
        similarity = torch.cosine_similarity(raw_img_feat, feat_list[j], dim=1)
        if similarity > similarity_score_threshold:
            match_dict[all_imgs[i]] = all_imgs[j]
            break  # only the first match per image is kept

print("match_dict:", match_dict)
with open("./dict.json", 'w') as f:
    json.dump(match_dict, f, indent=2)
# Flask service setup: similarity-based train/val splitting of defect images.
import json

import torch
import torchvision
from flask import Flask, request
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

app = Flask(__name__)

# Index of the next output sub-folder (used by the commented-out copy code).
dir_index = 0
# Output directory.
output_path = r"./test"

# Preprocessing: NOTE no resize — images are fed at their native resolution.
preprocess = transforms.Compose([
    # transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the feature extractor: pretrained ResNet-50 whose classification head
# is replaced by a 2048->2048 linear layer (512/512, 2048/1024, ... would
# also work).
model = torchvision.models.resnet50(pretrained=True)
model.eval()
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)
# Zero the new head's bias once (the original initialized it twice).
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  # frozen: feature extraction only
@app.route("/test",methods=["Get"])
def test():
'''
遍历所有缺陷图像,获取缺陷图像的特征向量
'''
all_args = request.args.get("source")
print(all_args)
# GuaiJiao:F:\PUCPDataSet\Seg\RT_GuaiJiao_DataSet\RawSimiTest\blend\defect\images
# SideArray:F:\PUCPDataSet\Seg\Side_Array_DataSet_Raw_Bleed\RawTest\blend\defect\images
all_imgs = get_files_path(all_args)
all_imgs.sort()
# 相似度阈值
similarity_score_threold = 0.7
match_dict = {}
feat_list = []
# 验证集
val_list = []
# 训练集
train_list = []
for index,item in enumerate(tqdm(all_imgs)):
img = Image.open(item).convert('RGB')
img_tensor = preprocess(img)
x = img_tensor.unsqueeze_(dim=0)
x = x.to('cpu')
# 抽取图像特征
feat_list.append(resnet50_feature_extractor(x))
# 断言:判定是否每张图像都抽取特征向量
assert len(all_imgs) == len(feat_list)
print("all_imgs",len(all_imgs))
print("feat_list",len(feat_list))
'''
找到第一个缺陷,然后找到满足相似度的所有缺陷,然后去除这些,继续上诉操作
'''
# 所有索引
all_index = [i for i in range(len(all_imgs))]
for index_item in tqdm(all_index):
# 当前相似度高于阈值的特征
current_similarities_imgs_list = []
current_similarities_fests_list = []
current_index_list = []
# 获取图像特征
# 第一个缺陷特征向量
raw_img_feat = feat_list[index_item]
current_similarities_imgs_list.append(all_imgs[index_item])
current_similarities_fests_list.append(feat_list[index_item])
# 去除当前索引
current_index_list.append(index_item)
all_index.remove(index_item)
'''
归类所有相似的缺陷,分别存储
'''
# 遍历剩余的缺陷特征向量(找到与当前特征向量相近的所有向量)
for j in all_index:
# 缺陷特征向量
output = feat_list[j]
# 计算两个特征向量之间的相似度
similarity = torch.cosine_similarity(raw_img_feat, output, dim=1)
# 相似度大于阈值则存入匹配字典中
if similarity > similarity_score_threold:
current_similarities_imgs_list.append(all_imgs[j])
current_similarities_fests_list.append(feat_list[j])
current_index_list.append(j)
all_index.remove(j)
print("current_index_list",current_index_list)
print("current_similarities_imgs_list",len(current_similarities_imgs_list))
# makedir(os.path.join(output_path,str(dir_index)))
# for iii in current_index_list:
# src_path = all_imgs[iii]
# dst_path = os.path.join(output_path,str(dir_index),os.path.basename(src_path))
# copyfile(src_path=src_path,dst_path=dst_path)
# dir_index = dir_index +1
# 遍历相似缺陷图像,选择缺陷图像最少的作为验证集
if len(current_similarities_imgs_list) > 1:
val_list.append(current_similarities_imgs_list[0])
print(val_list)
train_list = [i for i in all_imgs if i not in val_list]
print(val_list)
# print(len(all_imgs))
# print(len(val_list))
# print(len(train_list))
val_list = [os.path.basename(i) for i in val_list ]
train_list = [os.path.basename(i) for i in train_list ]
print(val_list)
print(train_list)
res = {"val_dataset":val_list,"train_dataset":train_list}
return str(res)
if __name__ == '__main__':
    # Launch the Flask development server with default host/port.
    app.run()
# ---- final version (pin-name based train/val split) ----
# Flask service setup: similarity-based train/val splitting of defect images.
import json

import torch
import torchvision
from flask import Flask, request
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from utils import *

app = Flask(__name__)

# Index of the next output sub-folder (used by the commented-out copy code).
dir_index = 0
# Output directory.
output_path = r"./test"

# Preprocessing: NOTE no resize — images are fed at their native resolution.
preprocess = transforms.Compose([
    # transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the feature extractor: pretrained ResNet-50 whose classification head
# is replaced by a 2048->2048 linear layer (512/512, 2048/1024, ... would
# also work).
model = torchvision.models.resnet50(pretrained=True)
model.eval()
resnet50_feature_extractor = model
resnet50_feature_extractor.fc = torch.nn.Linear(2048, 2048)
# Zero the new head's bias once (the original initialized it twice).
torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  # frozen: feature extraction only
@app.route("/test",methods=["Get"])
def test():
'''
遍历所有缺陷图像,获取缺陷图像的特征向量
'''
all_args = request.args.get("source")
print(all_args)
# GuaiJiao:F:\PUCPDataSet\Seg\RT_GuaiJiao_DataSet\RawSimiTest\blend\defect\images
# SideArray:F:\PUCPDataSet\Seg\Side_Array_DataSet_Raw_Bleed\RawTest\blend\defect\images
all_imgs = get_files_path(all_args)
all_imgs.sort()
# 相似度阈值
similarity_score_threold = 0.7
match_dict = {}
feat_list = []
# 验证集
val_list = []
# 训练集
train_list = []
for index,item in enumerate(tqdm(all_imgs)):
img = Image.open(item).convert('RGB')
img_tensor = preprocess(img)
x = img_tensor.unsqueeze_(dim=0)
x = x.to('cpu')
# 抽取图像特征
feat_list.append(resnet50_feature_extractor(x))
# 断言:判定是否每张图像都抽取特征向量
assert len(all_imgs) == len(feat_list)
print("all_imgs",len(all_imgs))
print("feat_list",len(feat_list))
'''
找到第一个缺陷,然后找到满足相似度的所有缺陷,然后去除这些,继续上诉操作
'''
# 所有缺陷索引
all_index = [i for i in range(len(all_imgs))]
# 所有图像图像名称
all_images_name_list = []
for item_imgName in all_imgs:
name_img = os.path.basename(item_imgName)
# print(name_img)
pin_name = name_img.split('_')[1:-1]
pin_name = '_'.join(pin_name)
all_images_name_list.append(pin_name)
# 去重
all_images_name_list = set(all_images_name_list)
for index_item in tqdm(all_index):
# 当前相似度高于阈值的特征
current_similarities_imgs_list = []
current_similarities_fests_list = []
current_index_list = []
# 获取图像特征
# 第一个缺陷特征向量
raw_img_feat = feat_list[index_item]
current_similarities_imgs_list.append(all_imgs[index_item])
current_similarities_fests_list.append(feat_list[index_item])
# 去除当前索引
current_index_list.append(index_item)
all_index.remove(index_item)
'''
归类所有相似的缺陷,分别存储
'''
# 遍历剩余的缺陷特征向量(找到与当前特征向量相近的所有向量)
for j in all_index:
# 缺陷特征向量
output = feat_list[j]
# 计算两个特征向量之间的相似度
similarity = torch.cosine_similarity(raw_img_feat, output, dim=1)
# 相似度大于阈值则存入匹配字典中
if similarity > similarity_score_threold:
current_similarities_imgs_list.append(all_imgs[j])
current_similarities_fests_list.append(feat_list[j])
current_index_list.append(j)
all_index.remove(j)
print("current_index_list",current_index_list)
print("current_similarities_imgs_list",len(current_similarities_imgs_list))
# makedir(os.path.join(output_path,str(dir_index)))
# for iii in current_index_list:
# src_path = all_imgs[iii]
# dst_path = os.path.join(output_path,str(dir_index),os.path.basename(src_path))
# copyfile(src_path=src_path,dst_path=dst_path)
# dir_index = dir_index +1
# 遍历相似缺陷图像,选择缺陷图像最少的作为验证集
if len(current_similarities_imgs_list) > 1:
val_list.append(current_similarities_imgs_list[0])
# print(val_list)
# train_list = [i for i in all_imgs if i not in val_list]
# print(val_list)
# print(len(all_imgs))
# print(len(val_list))
# print(len(train_list))
val_list = [os.path.basename(i) for i in val_list ]
all_val_name_list = []
for item_imgName in val_list:
name_img = os.path.basename(item_imgName)
# print(name_img)
pin_name = name_img.split('_')[1:-1]
pin_name = '_'.join(pin_name)
all_val_name_list.append(pin_name)
all_val_name_list = set(all_val_name_list)
train_list_name_list = [i for i in all_images_name_list if i not in all_val_name_list]
print(val_list)
print(train_list)
res = {"val_dataset":all_val_name_list,"train_dataset":train_list_name_list}
print(len(all_images_name_list))
print(len(all_val_name_list))
print(len(train_list_name_list))
print(res)
return str(res)
if __name__ == '__main__':
    # Launch the Flask development server with default host/port.
    app.run()