Openai Clip进行图像分类

# 工作记录

对Openai Clip初步研究

对Coco数据集分类

import os
import shutil
import torch
import clip
from PIL import Image
from tqdm import tqdm

# 一张图对应一个输出还是多个输出选项:sole or multiple
sole_or_multiple = "multiple"
# 读取图片后缀名
endwithjpg = ('.jpg', '.jpeg', '.png')
# 准确率阈值
accuracy = 0.5
# 自选类别输出(会自己创建对应文件夹,输入为列表)
choose_list = ["person", "dog", "cat","bird","station","bus","table"]
########################################################################

# 加载模型
device = "cuda"  if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
text = clip.tokenize(choose_list).to(device)


def load_img(path):
    img_file=[]
    for root,_,files in os.walk(path):
        for file in files:
            if file.lower().endswith(endwithjpg):
                img_file.append(str(root+'/'+file))
    return img_file


def clip_image(image):
    image = preprocess(Image.open(image)).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return probs


def find_max_index(array):
    lst = list(array)
    max_value = max(lst)
    max_index = lst.index(max_value)
    return max_value, max_index

def save_image(image,probs,savepath_list):
    max_value, max_index = find_max_index(probs)
    if max_value < accuracy:
       return 0     
    else:
        if sole_or_multiple == 'sole':
            savepath = savepath_list[max_index]
            shutil.copy(image,savepath+'/'+(image.split("/")[-1]).split('.')[0]+'_'+str(max_value)+'.jpg')
        if sole_or_multiple == 'multiple':
            for i in range(len(probs)):
                if probs[i] > accuracy:
                    savepath = savepath_list[i]
                    shutil.copy(image,savepath+'/'+(image.split("/")[-1]).split('.')[0]+'_'+str(probs[i])+'.jpg')

if __name__ == '__main__':
    # images路径
    img_path = '/home/yyh/large_model/datasets/cocoimages'
    # 存储路径
    savepath = '/home/yyh/large_model/datasets/cocoimagesort'
    savepath_list = []
    for name in choose_list:
        temp_path = savepath + '/' + name
        if not os.path.exists(temp_path):
            os.makedirs(temp_path)
        savepath_list.append(savepath+'/'+name)
        
    img_list = load_img(img_path)
    for image in tqdm(img_list):
        save_image(image,clip_image(image)[0],savepath_list)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值