# 工作记录
对 OpenAI CLIP 的初步研究
使用 CLIP 对 COCO 数据集图片进行分类
import os
import shutil
import torch
import clip
from PIL import Image
from tqdm import tqdm
# Output mode, either "sole" or "multiple":
#   "sole"     -> copy each image only into the folder of its best-scoring category;
#   "multiple" -> copy it into every category folder whose probability clears the threshold.
sole_or_multiple = "multiple"

# Filename suffixes treated as images (file names are lower-cased before matching).
endwithjpg = (
    '.jpg',
    '.jpeg',
    '.png',
)

# Probability threshold: an image is copied into a category folder only when that
# category's CLIP softmax probability reaches this value (despite the name, this is
# a probability cut-off, not an accuracy metric).
accuracy = 0.5

# Categories to sort into. Each entry is tokenized as a CLIP text prompt, and a
# folder with the same name is created under the output directory.
choose_list = [
    "person",
    "dog",
    "cat",
    "bird",
    "station",
    "bus",
    "table",
]
########################################################################
# Load the CLIP ViT-B/32 model and its matching image preprocessing transform,
# placing them on the GPU when CUDA is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Tokenize the category names once at import time; clip_image() reuses this
# tensor for every image it scores.
text = clip.tokenize(choose_list).to(device)
def load_img(path, extensions=None):
    """Recursively collect image file paths under *path*.

    Args:
        path: Root directory to walk; every sub-directory is searched.
        extensions: Optional suffix (str) or iterable of suffixes to accept.
            Matching is case-insensitive. Defaults to the module-level
            ``endwithjpg`` tuple.

    Returns:
        A list of file paths built with ``os.path.join``. The order is
        deterministic: each directory's files are listed in sorted order, and a
        directory is listed before its (sorted) sub-directories. The list is
        empty when nothing matches or when *path* does not exist (``os.walk``
        yields nothing in that case).
    """
    if extensions is None:
        extensions = endwithjpg
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith() needs a tuple (a list raises TypeError). Lower-case the
    # suffixes because the file name is lower-cased before comparison.
    suffixes = tuple(ext.lower() for ext in extensions)

    img_file = []
    for root, dirs, files in os.walk(path):
        # os.walk order is filesystem-dependent; sorting in place (which os.walk
        # honours when walking top-down) makes runs reproducible.
        dirs.sort()
        for name in sorted(files):
            if name.lower().endswith(suffixes):
                img_file.append(os.path.join(root, name))
    return img_file
def clip_image(image):
    """Score one image file against every tokenized category prompt with CLIP.

    Uses the module-level ``model``, ``preprocess``, ``text`` and ``device``.

    Args:
        image: Path of the image file to classify.

    Returns:
        A NumPy array of softmax probabilities. It has one row for this single
        image (the batch of one built by ``unsqueeze(0)``) and one column per
        prompt in ``text``, and each row sums to 1.
    """
    # Open with a context manager so the file handle is closed deterministically.
    # preprocess() reads the pixels, so it must run while the file is still open.
    with Image.open(image) as img:
        pixels = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        # model(pixels, text) encodes both the image and the prompts internally,
        # so separate encode_image/encode_text calls would only duplicate that work.
        logits_per_image, _ = model(pixels, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return probs
def find_max_index(array):
    """Locate the largest element of a one-dimensional sequence.

    Args:
        array: Finite iterable of mutually comparable values, e.g. one row of
            CLIP probabilities.

    Returns:
        ``(max_value, max_index)``, where ``max_index`` is the position of the
        first occurrence of the maximum.

    Raises:
        ValueError: If *array* is empty.
    """
    # A single pass over (index, value) pairs. max() keeps the earliest pair on
    # ties, which matches "first occurrence".
    position, largest = max(enumerate(array), key=lambda pair: pair[1])
    return largest, position
def save_image(image, probs, savepath_list, *, threshold=None, mode=None):
    """Copy *image* into the category folder(s) whose probability clears the threshold.

    Each copy is named ``<original stem>_<probability><original extension>``.

    Args:
        image: Path of the source image file.
        probs: One-dimensional sequence of per-category probabilities, aligned
            by position with *savepath_list*.
        savepath_list: Destination folder for each category, in the same order
            as *probs*.
        threshold: Minimum probability required to copy. Defaults to the
            module-level ``accuracy``.
        mode: ``'sole'`` (copy only to the best category) or ``'multiple'``
            (copy to every category at or above the threshold). Defaults to the
            module-level ``sole_or_multiple``. Any other value copies nothing.

    Returns:
        The number of copies made. This is 0 when even the best probability is
        below the threshold.
    """
    if threshold is None:
        threshold = accuracy
    if mode is None:
        mode = sole_or_multiple

    max_value, max_index = find_max_index(probs)
    if max_value < threshold:
        return 0

    # os.path.splitext keeps multi-dot stems intact ("a.v2.png" -> "a.v2") and
    # returns the real extension, so PNG data is not saved under a .jpg name.
    stem, ext = os.path.splitext(os.path.basename(image))

    def _copy_to(folder, prob):
        # Copy the image into *folder*, tagging the file name with *prob*.
        shutil.copy(image, os.path.join(folder, f"{stem}_{prob!s}{ext}"))

    if mode == 'sole':
        _copy_to(savepath_list[max_index], max_value)
        return 1

    if mode == 'multiple':
        copied = 0
        for i, prob in enumerate(probs):
            # Use >= so this agrees with the gate above: a probability exactly
            # equal to the threshold passes in both modes.
            if prob >= threshold:
                _copy_to(savepath_list[i], prob)
                copied += 1
        return copied

    return 0
if __name__ == '__main__':
    # Root directory that is scanned recursively for images.
    img_path = '/home/yyh/large_model/datasets/cocoimages'
    # Output root; one sub-folder per entry of choose_list is created under it.
    savepath = '/home/yyh/large_model/datasets/cocoimagesort'

    savepath_list = [os.path.join(savepath, name) for name in choose_list]
    for folder in savepath_list:
        # exist_ok=True replaces the racy exists()-then-makedirs() check.
        os.makedirs(folder, exist_ok=True)

    img_list = load_img(img_path)
    for image in tqdm(img_list):
        try:
            probs = clip_image(image)[0]
        except OSError as err:
            # Pillow reports unreadable or corrupt files with OSError subclasses
            # (e.g. UnidentifiedImageError). Skip such a file and report it
            # instead of aborting the whole batch.
            tqdm.write(f"skipping unreadable image {image}: {err}")
            continue
        save_image(image, probs, savepath_list)