使用cnn提取特征,图像相似度对比。pytorch 推理的时候报内存不足的问题

with torch.no_grad()
https://blog.csdn.net/CRDarwin/article/details/119943128


# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os

def cos_sim(a, b):
    """
    计算两个向量之间的余弦相似度
    """
    a = np.mat(a)
    b = np.mat(b)
    return float(a * b.T) / (np.linalg.norm(a) * np.linalg.norm(b))

class MyDataset(Dataset):
    def __init__(self, file_path,transform = None):
        file_name=os.listdir(file_path)
        imgs = []
        img_names = []
        for file in file_name:
            image_path=os.path.join(file_path,file)
            imgs.append(image_path)
            img_names.append(file)
        self.imgs = imgs
        self.img_names = img_names
        self.transform = transform
    def __getitem__(self, index):
        fn= self.imgs[index]
        img_name=self.img_names[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img,img_name
    def __len__(self):
        return len(self.imgs)

file_path="F:/my_code/2021/sf/data/ICR_EXT/"
save_path="F:/my_code/2021/sf/data/s/"
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512,640)),
])
mydata=MyDataset(file_path,transform=preprocess)
l=mydata.__len__()
mydata_loader=DataLoader(mydata,batch_size=32,shuffle=True)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = torchvision.models.shufflenet_v2_x0_5(pretrained=True)
model.fc = nn.Sequential()
model.to(device)
model.eval()
img_names=[]
img_features = []
for image,img_name in mydata_loader:
    image = image.to(device, torch.float)
    img_names.append(img_name)
    with torch.no_grad():
        predictions = model(image)
    predictions.unsqueeze_(1)
    for t in predictions:
        img_features.append(t)
    print(len(img_features))
b = str(img_names)
b = b.replace('(', '')
b = b.replace(')', '')
img_names = list(eval(b))
torch.cat(img_features,dim=0)

dictionary = dict(zip( img_features,img_names))
result=[]
while len(dictionary)>0:
    imgf=list(dictionary.keys())[0]
    result.append(dictionary.get(imgf))
    dictionary.pop(imgf)
    for img_feature in list(dictionary.keys()):
        if cos_sim(imgf.tolist(),img_feature.tolist())>0.93:
            dictionary.pop(img_feature)

for image_name in result:
    image_path = os.path.join(file_path,image_name)
    image = Image.open(image_path)
    save_img_path= os.path.join(save_path,image_name)
    image.save(save_img_path)
print(result.__len__())








  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值