一. 场景说明
我们经常遇到这样的情况：手机或电脑里的图片太多，其中还有不少重复的图片。这些重复图片既浪费设备的存储空间，也增加了处理这些数据的成本。
博主是学AI的,因此基于神经网络开发了一个图片去重算法。
二. 基本思路
- 先用视觉模型提取图片的特征
- 依次比较图片特征：将每张图片的特征与已保留图片的特征计算余弦相似度，相似度高于阈值（代码中为 0.99）的图片视为重复并过滤掉
代码实现:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import glob
from tqdm import tqdm
class FeatureExtract(object):
    """Extract a global image embedding with an ImageNet-pretrained ResNet-18.

    The final fully-connected layer is removed, so the model returns the
    pooled backbone features (a 512-element vector for ResNet-18 after
    ``squeeze()``) instead of class logits.
    """

    def __init__(self):
        # Load the ImageNet-pretrained ResNet-18.
        self.resnet = models.resnet18(pretrained=True)
        # Drop the last fully-connected (classification) layer so the output
        # is the pooled feature map rather than logits.
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])
        # Inference only: switch BatchNorm/Dropout to evaluation behaviour.
        self.resnet.eval()
        self.preprocess = transforms.Compose([
            # An int size scales the shorter side to 224 px, keeping aspect ratio.
            transforms.Resize(224),
            transforms.ToTensor(),
            # ImageNet per-channel statistics expected by the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def feature_extract(self, image_path):
        """Return the 1-D feature vector of the image stored at *image_path*.

        The image is converted to 3-channel RGB first. Grayscale, palette or
        RGBA files would otherwise yield a tensor whose channel count does not
        match the 3-channel ``Normalize`` / ResNet input, and would raise.
        """
        # The context manager closes the file handle right away, which matters
        # when thousands of images are scanned.
        with Image.open(image_path) as image:
            input_tensor = self.preprocess(image.convert("RGB"))
        # Add the batch dimension expected by the model.
        input_batch = input_tensor.unsqueeze(0)
        with torch.no_grad():
            features = self.resnet(input_batch)
        # Drop the singleton batch/spatial dimensions, leaving a 1-D vector.
        return features.squeeze()
def is_duplicate(features, feature, thres=0.99):
    """Tell whether *feature* is a near-copy of any vector in *features*.

    Every stored vector is compared with *feature* using cosine similarity
    along dim 0. A similarity strictly greater than *thres* counts as a
    duplicate. An empty *features* list never matches.
    """
    # any() stops at the first match, just like an early return, and yields
    # False when there is nothing to compare against.
    return any(
        F.cosine_similarity(stored, feature, dim=0).item() > thres
        for stored in features
    )
if __name__ == "__main__":
    # Scan the folder, keep the features of every image retained so far, and
    # count how many later images are near-duplicates of a retained one.
    extractor = FeatureExtract()
    kept = []
    dup_count = 0
    image_paths = glob.glob("/home/gp/workspace/images/*.jpg")
    for path in tqdm(image_paths):
        vec = extractor.feature_extract(path)
        if not is_duplicate(kept, vec):
            kept.append(vec)
            continue
        dup_count += 1
        print("copy")
    print("copy num:%s" % dup_count)
三. 优化思路
上面的代码在图片较多时可能比较慢：每张新图片都要在 Python 循环中逐个与所有已保留的特征比较，而且特征提取和相似度计算是串行执行的。
- 批处理计算:可以将向量列表分成小批次进行计算,而不是逐个遍历每个向量。这样可以利用矩阵运算的并行性,提高计算效率。可以使用PyTorch的torch.stack()函数将向量列表转换为一个大张量,然后使用矩阵乘法或批量计算余弦相似度;
- 利用生产者消费者模式：生产者线程读取图片、提取特征并放入队列；消费者线程从队列中取出特征，计算相似度并做后处理（保存不重复的图片）。
代码实现:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import glob
from tqdm import tqdm
import threading
import queue
import shutil
import uuid
import concurrent.futures
def is_duplicate_old(features, feature, thres=0.99):
    """Linear-scan duplicate check (the unbatched baseline).

    Compares *feature* with each vector in *features*, one at a time, using
    cosine similarity along dim 0. Stops at the first similarity strictly
    greater than *thres*. Returns False for an empty list.
    """
    for candidate in features:
        score = F.cosine_similarity(candidate, feature, dim=0).item()
        if score > thres:
            return True
    # Nothing was similar enough (this also covers an empty list).
    return False
def is_duplicate(features, feature, thres=0.99, batch_size=1000):
    """Tell whether *feature* duplicates any vector in *features*.

    A duplicate is a stored vector whose cosine similarity with *feature* is
    strictly greater than *thres*.

    Lists with fewer than 1000 entries use the per-vector scan in
    ``is_duplicate_old``. Longer lists are stacked into chunks of at most
    *batch_size* rows. Each chunk is scored with one vectorised
    ``cosine_similarity`` call, and the function returns on the first chunk
    that contains a match.

    Args:
        features: list of equal-length 1-D tensors (combined with torch.stack).
        feature: 1-D tensor to test.
        thres: similarity threshold, honoured on both code paths.
        batch_size: maximum rows per stacked chunk (default matches the
            original hard-coded 1000).

    Returns:
        True if some stored vector exceeds *thres*, else False.
    """
    if len(features) < 1000:
        # Forward the caller's threshold. The original hard-coded 0.99 here,
        # silently ignoring ``thres`` for short lists.
        return is_duplicate_old(features, feature, thres=thres)
    # Shape (1, dim) broadcasts against each (rows, dim) chunk below.
    query = feature.unsqueeze(0)
    for start in range(0, len(features), batch_size):
        chunk = torch.stack(features[start:start + batch_size])
        similarity = F.cosine_similarity(query, chunk, dim=1)
        if similarity.max().item() > thres:
            return True
    return False
class ProducerConsumer:
    """Deduplicate images with one producer thread and one consumer thread.

    The producer extracts a feature vector for every file matching
    *input_glob*. It puts ``[path, feature]`` items on a bounded queue,
    followed by a ``None`` sentinel. The consumer keeps every image that is
    not a near-duplicate (per ``is_duplicate``) of an already-kept image and
    copies it into *output_dir* under a random UUID file name.
    """

    def __init__(self, input_glob="imgs/*.jpg", output_dir="images"):
        # Bounded queue (length 1000): the producer blocks instead of running
        # arbitrarily far ahead of the consumer.
        self.queue = queue.Queue(maxsize=1000)
        self.extract = FeatureExtract()
        self.features = []  # features of the images kept so far
        self.input_glob = input_glob
        self.output_dir = output_dir

    def produce(self):
        """Extract features for every input image and enqueue them."""
        try:
            for img_path in tqdm(glob.glob(self.input_glob)):
                feature = self.extract.feature_extract(img_path)
                self.queue.put([img_path, feature])
        finally:
            # Always send the sentinel, even if extraction raised (e.g. on a
            # corrupt file). Otherwise the consumer would block on get() forever.
            self.queue.put(None)

    def consume(self):
        """Dequeue until the sentinel; keep and copy every non-duplicate image."""
        import os

        # Create the output folder up front. A missing folder would make
        # shutil.copy raise, kill this thread, and leave the producer blocked
        # once the queue fills up.
        os.makedirs(self.output_dir, exist_ok=True)
        while True:
            img_info = self.queue.get()
            if img_info is None:
                break
            img_path, feature = img_info
            if is_duplicate(self.features, feature):
                continue
            self.features.append(feature)
            # A random UUID name means saved copies never overwrite each other.
            img_save_path = os.path.join(self.output_dir, "%s.jpg" % uuid.uuid4())
            shutil.copy(img_path, img_save_path)

    def run(self):
        """Start both threads and wait for both to finish."""
        producer_thread = threading.Thread(target=self.produce)
        consumer_thread = threading.Thread(target=self.consume)
        producer_thread.start()
        consumer_thread.start()
        producer_thread.join()
        consumer_thread.join()
class FeatureExtract(object):
    """Extract a global image embedding with an ImageNet-pretrained ResNet-18.

    The final fully-connected layer is removed, so the model returns the
    pooled backbone features (a 512-element vector for ResNet-18 after
    ``squeeze()``) instead of class logits.
    """

    def __init__(self):
        # Load the ImageNet-pretrained ResNet-18.
        self.resnet = models.resnet18(pretrained=True)
        # Drop the last fully-connected (classification) layer so the output
        # is the pooled feature map rather than logits.
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])
        # Inference only: switch BatchNorm/Dropout to evaluation behaviour.
        self.resnet.eval()
        self.preprocess = transforms.Compose([
            # An int size scales the shorter side to 224 px, keeping aspect ratio.
            transforms.Resize(224),
            transforms.ToTensor(),
            # ImageNet per-channel statistics expected by the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def feature_extract(self, image_path):
        """Return the 1-D feature vector of the image stored at *image_path*.

        The image is converted to 3-channel RGB first. Grayscale, palette or
        RGBA files would otherwise yield a tensor whose channel count does not
        match the 3-channel ``Normalize`` / ResNet input, and would raise.
        """
        # The context manager closes the file handle right away, which matters
        # when thousands of images are scanned by the producer thread.
        with Image.open(image_path) as image:
            input_tensor = self.preprocess(image.convert("RGB"))
        # Add the batch dimension expected by the model.
        input_batch = input_tensor.unsqueeze(0)
        with torch.no_grad():
            features = self.resnet(input_batch)
        # Drop the singleton batch/spatial dimensions, leaving a 1-D vector.
        return features.squeeze()
if __name__ == "__main__":
    import time

    # perf_counter is monotonic and high-resolution, which makes it the right
    # clock for elapsed-time measurement. time.time() follows the wall clock
    # and can jump if the system time is adjusted mid-run.
    start = time.perf_counter()
    # Build the pipeline and run it to completion.
    pc = ProducerConsumer()
    pc.run()
    end = time.perf_counter()
    print(end - start)
在笔者的实验中，速度提升约 50%，改善效果明显（实际收益取决于图片数量和硬件环境）。