【DA-CLIP】Windows下使用clip_interrogator/BLIP进行生成图像-文本-退化类型数据集的generate_caption.py代码运行逻辑

 一、背景:

DA-CLIP/open_clip模型创建代码思路:

总体而言代码使用了多层方法的调用

clip_interrogator在open_clip的最外层又定义了一层

Interrogator():
__init__()加载BLIP模型
load_clip_model()加载CLIP模型

最外层:open_clip\factory.py

open_clip.create_model_and_transforms,open_clip.create_model_from_pretrained

次外层:open_clip\factory.py

create_model()

第三层:

这个比较多样。以openai模型为例,open_clip\openai.py。还有CoCa、CustomTextCLIP、CLIP等直接创建类实例,而非第五层才创建,该途径需要根据create_model()输入的custom_text布尔值设置。具体见create_model()

load_openai_model()

第四层:open_clip\model.py

以openai模型为例

build_model_from_openai_state_dict()

第五层:

CLIP类、CustomTextCLIP类 open_clip\model.py

CoCa类 open_clip\coca_model.py

 DA-CLIP接受一个CLIP实例作为参数初始化

def __init__(self, clip_model: CLIP):

【3万字代码解读】DA-CLIP/open_clip模型创建、模型配置读取、预训练权重地址读取icon-default.png?t=N7T8http://t.csdnimg.cn/HTz2m

BLIP:

旨在实现统一的视觉-语言理解和生成。该库提供了预训练和微调后的模型检查点,支持

图像-文本检索、图像标题生成、视觉问答和NLVR2等多种任务。

DA-CLIP关于使用BLIP生成数据集的代码注释icon-default.png?t=N7T8http://t.csdnimg.cn/rW7E1一个BLIP在colab上的运行demo,和对generate_caption.py代码讲解,注意本文才涉及带代码运行和windows运行相关问题。

clip_interrogator

是一个结合了 OpenAI 的 CLIP 和 Salesforce 的 BLIP 技术的 prompt 工程工具,它专门设计用于优化文本提示(prompts),以便与给定图像相匹配。该工具的技术实现基于这两个先进的多模态模型,通过分析图像内容和相关的文本描述,生成高质量的文本提示。

功能方面,CLIP-Interrogator 允许用户通过自然语言与 AI 进行交互,提出关于图像的问题,并获取相应的文本描述。这些生成的文本提示可以用于文本到图像的模型(如 Stable Diffusion)来创造新的艺术作品或进行图像生成的实验。此外,CLIP-Interrogator 支持在不同的 CLIP 模型之间进行选择,并且可以配置以适应不同的硬件条件,如低 VRAM 设备。

CLIP-Interrogator 可以作为一个库来使用,允许开发者在自己的脚本中调用其功能,从而实现图像内容的自动化分析和描述生成。此外,它还提供了与自己定义的术语列表进行排名对比的功能,这使得用户可以根据自己的特定需求定制化模型的输出。

GitHub - pharmapsychotic/clip-interrogator: Image to prompt with BLIP and CLIP

建议提前阅读该仓库说明,,对模型使用和参数有相关介绍

CLIP查询器使用OpenCLIP,它支持许多不同的预训练CLIP模型。对于Stable Diffusion1的最佳提示。X使用viti - l -14/openai为clip_model_name。Stable Diffusion2.0使用viti - h -14/ laon2b_s32b_b79k

Config对象允许您配置CLIP询问者的处理。
clip_model_name:使用哪个OpenCLIP预训练的CLIP模型
Cache_path:保存预先计算的文本嵌入的路径
download_cache:当为True时,将从huggingface下载预先计算的嵌入
chunk_size: CLIP的批处理大小,对于较小的VRAM使用更小的
quiet: True时不显示进度条或文本输出

二、环境安装准备数据集等

在原先环境的基础上需要clip_interrogator。

Create dataset:

To generate clean captions with BLIP, we use the clip-interrogator tool. Install it with pip install clip-interrogator==0.6.0 and run:

python ../scripts/generate_captions.py

Then you will get daclip_train.csv and daclip_val.csv under the datasets/universal directory.

使用pip自带的channel直接下载网速太慢又容易中断。

我直接修改配置,下次就不用指定下载源了

在命令行输入notepad %APPDATA%\pip\pip.ini

修改pip.ini文件的index-url改为清华源,地址如下

https://pypi.tuna.tsinghua.edu.cn/simple/

光速下完

 

index-url 是默认的 channel。

extra-index-url 是额外的 channels

trusted-host 是你信任的 hosts,

设定数据集地址

三、模型加载问题:网络及地址

3.1:Can't load tokenizer for 'bert-base-uncased'.

不挂梯子连不了hugging face,,挂了代理又运行不了。

先看报错: Can't load tokenizer for 'bert-base-uncased'.点击查看报错点源码

文件路径为C:\Users\86136\anaconda3\envs\DA-CLIP\Lib\site-packages\blip\models\blip.py

def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})       
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]  
    return tokenizer

 明确目标先离线下载bert-base-uncased,参考博文

http://t.csdnimg.cn/BDNALicon-default.png?t=N7T8http://如何下载和在本地使用Bert预训练模型推荐谷歌浏览器,下载更快。博主下载了pytorch版本和其他相关配置文件。

地址放在本项目该代码目录下的新建文件夹bert-base-uncased中

修改该方法下代码读取的目录,,修改为你的地址

tokenizer = BertTokenizer.from_pretrained('C:\\Users\\86136\\Desktop\\daclip-uir-main\\scripts\\bert-base-uncased')

3.2 简述generate_caption.py和Interrogator函数思路:

ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))

该代码创建一个Interrogator类实例ci 。Interrogator类中加载了BLIP模型和CLIP模型,相应代码下面会有。随后使用generate_caption.py中创建的generate_captions()方法

generate_captions(dataroot, ci, 'val')

 该该方法中 创建caption的语句是Interrogator类中的generate_caption方法。

所以我们有必要查看clip_interrogator.py的相关代码

caption = ci.generate_caption(image)

3.3BLIP模型权重下载和地址存放问题

在generate_caption.py的ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))

 查看Config方法,鼠标双击Config后等待一会出现注解

点击标蓝的文字出现代码文件路径

 打开

C:\Users\86136\anaconda3\envs\DA-CLIP\________这部分是你的环境路径

Lib\site-packages\clip_interrogator\clip_interrogator.py

3.3.1模型选择和下载地址 

 该文件开头定义了BLIP_MODELS的下载地址,,我运行时自动下载很快就没有修改这部分

有需要可以自行下载再修改地址,,不过下载后本地地址要写对不然报错RuntimeError,加载地址相关代码可见3.3.4

"large"1.75Ghttps://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth

BLIP_MODELS = {
    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
}

查看Config类 

blip_model_type: str = 'large' 

# choose between 'base' or 'large'
#尽管后面关于blip_model_type的定义全都是base.但在最外层用的是large

因为Interrogator使用的是Config.blip_model_type

3.3.2创建BLIP的代码

Interrogator类的__init__中,如下 

 if config.blip_model is None:
            if not config.quiet:
                print("Loading BLIP model...")
            blip_path = os.path.dirname(inspect.getfile(blip_decoder))
            configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
            med_config = os.path.join(configs_path, 'med_config.json')
            blip_model = blip_decoder(
                pretrained=BLIP_MODELS[config.blip_model_type],
                image_size=config.blip_image_eval_size, 
                vit=config.blip_model_type, 
                med_config=med_config
            ) #创建模型
            blip_model.eval() # 将模型设置为评估模式
            if not self.config.blip_offload:
                blip_model = blip_model.to(config.device)
            self.blip_model = blip_model

3.3.3配置文件本地地址

 可以看到该blip模型的配置文件路径为:你的虚拟环境包下的blip\configs\bert_config.json

C:\Users\86136\anaconda3\envs\DA-CLIP\Lib\site-packages\blip\configs\bert_config.json

3.3.4模型权重文件.pth下载后的保存地址

该blip_decoder仍是嵌套了一层模型加载 

def blip_decoder(pretrained='',**kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model  

 在BLIP_Decoder中读取了配置文件

med_config = BertConfig.from_json_file(med_config)

关于模型权重文件更多细节需要查看loac_checkpoint() 

def load_checkpoint(model,url_or_filename):
    # print("url_or_filename",url_or_filename)
    # https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu') 
    elif os.path.isfile(url_or_filename):        
        checkpoint = torch.load(url_or_filename, map_location='cpu') 
    else:
        raise RuntimeError('checkpoint url or path is invalid')
        
    state_dict = checkpoint['model']
    
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)    
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                del state_dict[key]
    
    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)  
    return model,msg

这段Python代码定义了一个名为 load_checkpoint 的函数,它的目的是加载一个预训练的模型检查点(checkpoint)到一个给定的模型中。这个函数接受两个参数:model(要加载检查点的模型实例)和 url_or_filename(包含检查点文件的URL或文件路径)。

由于download_cached_file()代码包含存放地址检测代码,已经下载过的不会再次下载

函数的主要步骤如下:

  1. 首先,使用 is_url 函数检查提供的 url_or_filename 是否是一个有效的URL。如果是,那么 download_cached_file 函数将被用来下载并缓存该文件。torch.load 函数随后用于加载下载的检查点文件,map_location='cpu' 参数确保加载过程在CPU上执行,这在没有GPU可用的情况下很有用。

  2. 如果 url_or_filename 不是一个URL,那么代码将检查它是否指向一个存在的文件。如果文件存在,同样使用 torch.load 来加载检查点。

  3. 如果 url_or_filename 既不是有效的URL也不是存在的文件,函数将引发一个 RuntimeError,指出检查点的URL或路径无效。

  4. 加载检查点后,函数将处理检查点中的 visual_encoder.pos_embed 键,使用 interpolate_pos_embed 函数对其进行插值。这可能是为了匹配模型的当前位置嵌入(position embeddings)的尺寸。

  5. 接下来,函数检查 visual_encoder_m.pos_embed 键是否存在于模型的状态字典中,如果存在,也对其进行插值处理。

  6. 然后,函数遍历模型的状态字典和检查点的状态字典中的所有键,如果键在两个字典中都存在但形状不匹配,那么该键将从检查点的状态字典中删除。

  7. 最后,使用 model.load_state_dict 方法将处理后的检查点加载到模型中,strict=False 参数允许忽略不匹配的键。函数打印一条消息,指示已从何处加载检查点,并返回模型和加载操作的结果 msg

 需要使用本地BLIP模型权重地址的同学可以考虑这一行。

    elif os.path.isfile(url_or_filename):        
        checkpoint = torch.load(url_or_filename, map_location='cpu') 

想知道自动下载的保存地址查看download_cached_file()函数

from timm.models.hub import download_cached_file

 博主最终查到如下

C:\Users\86136\.cache\torch\hub\checkpoints\model_large_caption.pth

红色为你的缓存路径

3.4CLIP模型权重加载和标签下载

   def load_clip_model(self):
        start_time = time.time()
        config = self.config

        if config.clip_model is None:#没有传入模型就
            if not config.quiet:
                print("Loading CLIP model...")

            clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                clip_model_name, 
                pretrained=clip_model_pretrained_name, 
                precision='fp16' if config.device == 'cuda' else 'fp32',
                device=config.device,
                jit=False,
                cache_dir=config.clip_model_path
            )
            print("create clip_model over")
            self.clip_model.eval()
        else:
            self.clip_model = config.clip_model
            self.clip_preprocess = config.clip_preprocess
        self.tokenize = open_clip.get_tokenizer(clip_model_name)

该代码主要关注点open_clip.create_model_and_transforms,不记得创建模型函数调用过程可以回到开头。注意该函数参数我们可以根据config.clip_model_path设置本地预训练权重地址

cache_dir=config.clip_model_path

 权重加载问题描述:HTTPSConnectionPool(host='huggingface.co', port=443)

ViT-L-4/openai模型权重下载地址在HF,连不上。

根据上篇文章的考证,可以在open_clip\pretrained.py找到所有模型相关下载地址和函数

https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt

挂梯子下载到本地

添加地址参数。注意不包含权重文件名,否则报错。暂时还没找相关代码

ci = Interrogator(Config(
        clip_model_name="ViT-L-14/openai",
        clip_model_path="E:\\download\\ViT-L-14"  # 你的模型权重文件的本地路径
    ))

  标签下载HTTPSConnectionPool(host='huggingface.co', port=443)


        sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 
                 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 
                 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
        trending_list = [site for site in sites]
        trending_list.extend(["trending on "+site for site in sites])
        trending_list.extend(["featured on "+site for site in sites])
        trending_list.extend([site+" contest winner" for site in sites])

        raw_artists = _load_list(config.data_path, 'artists.txt')
        artists = [f"by {a}" for a in raw_artists]
        artists.extend([f"inspired by {a}" for a in raw_artists])
        # print(config.data_path)
        # C:\Users\86136\anaconda3\envs\DA-CLIP\lib\site-packages\clip_interrogator\data
        self._prepare_clip()
        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
        self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)

这段代码的主要目的是构建和准备一系列的标签(labels)和搜索提示(prompts),这些标签和提示将用于与CLIP模型交互,以便生成或搜索特定类型的艺术作品。代码中创建了几个不同的标签列表,包括基于流行网站和艺术家的标签,以及一些特定的艺术风格、媒介、运动和负面标签。

具体步骤如下:

  1. 定义了一个名为 sites 的列表,包含了一系列艺术和设计相关的网站名称,如 'Artstation', 'Behance' 等。

  2. 使用列表推导式创建了 trending_list,这个列表包含了基于 sites 中的网站名称生成的各种搜索提示,如 "trending on Artstation", "featured on Behance" 等。

  3. 通过调用 _load_list 函数,从配置路径中加载了一个名为 'artists.txt' 的文件,该文件包含了艺术家的名字。

  4. 基于加载的艺术家名单 raw_artists,创建了 artists 列表,其中包含了由艺术家名字构成的标签,如 "by {artist_name}" 和 "inspired by {artist_name}"。

  5. 接下来,代码创建了五个 LabelTable 对象,每个对象都与特定的标签列表和CLIP模型相关联。这些对象分别是:

  • self.artists:与艺术家相关的标签。
  • self.flavors:与艺术风格相关的标签。
  • self.mediums:与艺术媒介相关的标签。
  • self.movements:与艺术运动相关的标签。
  • self.trendings:与流行趋势相关的标签。
  • self.negative:与负面标签相关的列表

 

 提供对应标签下载地址

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors

 

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors

https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors

放到你的项目该文件夹下,cache是运行时会生成的,如果没有,新建即可 

 

 over

现在只剩数据集问题了

  • 30
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值