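Welcome to follow my CSDN blog: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/141605718
Disclaimer: This article is based on personal knowledge and public materials and is intended for academic exchange only. Discussion is welcome; reposting is not permitted.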
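The Joy Caption model (developed by Fancy Feast) builds on SigLIP and Llama 3.1 to produce better image captions. Its released checkpoint is an image adapter that maps SigLIP image features into Llama 3.1's input embedding space, so it has to be used together with SigLIP and Llama 3.1: given an image, it outputs a semantically rich description.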
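- Google's SigLIP (Sigmoid Loss for Language Image Pre-Training) is a CLIP-like multimodal model that replaces CLIP's softmax-based contrastive loss with a pairwise sigmoid loss.
- Meta-Llama-3.1-8B-bnb-4bit (published by unsloth) is Meta's multilingual Llama 3.1 8B model quantized to 4 bits with the bitsandbytes library, which greatly reduces memory usage while largely preserving model quality.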
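To make the loss difference concrete, here is a minimal pure-Python sketch of the pairwise sigmoid loss: every image-text pair in a batch is scored independently as a binary match (label +1 when i == j, -1 otherwise), instead of a softmax over the whole batch as in CLIP. This is an illustration, not Google's implementation; the temperature `t=10` and bias `b=-10` are illustrative values (in SigLIP both are learnable parameters).

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sigmoid_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    """Pairwise sigmoid loss over a batch of (image, text) embeddings.

    Pair (i, j) is a positive when i == j and a negative otherwise.
    -log(sigmoid(z * logit)) is computed as softplus(-z * logit).
    """
    imgs = [l2_normalize(v) for v in img_emb]
    txts = [l2_normalize(v) for v in txt_emb]
    n = len(imgs)
    total = 0.0
    for i in range(n):
        for j in range(n):
            z = 1.0 if i == j else -1.0
            logit = t * dot(imgs[i], txts[j]) + b
            total += softplus(-z * logit)
    return total / n  # normalized by batch size

aligned = sigmoid_loss([[1, 0], [0, 1]], [[1, 0], [0, 1]])
swapped = sigmoid_loss([[1, 0], [0, 1]], [[0, 1], [1, 0]])
print(aligned < swapped)  # True: matching pairs give the lower loss
```

Because each pair is an independent binary term, no batch-wide normalization is needed, which is the property SigLIP relies on.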
## 1. Environment Setup
Related projects:
- fancyfeast/joy-caption-pre-alpha
- google/siglip-so400m-patch14-384
- unsloth/Meta-Llama-3.1-8B-bnb-4bit
Download the models from HuggingFace through the hf-mirror.com mirror:
```bash
export HF_ENDPOINT="https://hf-mirror.com"
export HF_HUB_ENABLE_HF_TRANSFER=1  # hf-transfer is only used when this is set
pip install -U huggingface_hub hf-transfer
huggingface-cli download --token [your hf token] Wi-zz/joy-caption-pre-alpha --local-dir joy-caption-pre-alpha
huggingface-cli download --token [your hf token] google/siglip-so400m-patch14-384 --local-dir siglip-so400m-patch14-384
huggingface-cli download --token [your hf token] unsloth/Meta-Llama-3.1-8B-bnb-4bit --local-dir Meta-Llama-3.1-8B-bnb-4bit
```
On the mirror, use `Wi-zz/joy-caption-pre-alpha` in place of `fancyfeast/joy-caption-pre-alpha`.
Storage directories:
- fancyfeast/joy-caption-pre-alpha: `joy-caption-pre-alpha/`
- google/siglip-so400m-patch14-384: `siglip-so400m-patch14-384/`
- unsloth/Meta-Llama-3.1-8B-bnb-4bit: `Meta-Llama-3.1-8B-bnb-4bit/`
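If you prefer to check or fetch the models from Python, the sketch below maps each repository to the local directory listed above and calls `huggingface_hub.snapshot_download` only for directories that are missing. The names `MODEL_DIRS`, `missing_dirs`, and `download_missing` are hypothetical helpers, not part of Joy Caption. `HF_ENDPOINT` is read when `huggingface_hub` is imported, so it is set before the (lazy) import.

```python
import os
from pathlib import Path

# Repository -> local directory, matching the layout described above.
# Wi-zz/joy-caption-pre-alpha is used on the mirror instead of fancyfeast/joy-caption-pre-alpha.
MODEL_DIRS = {
    "Wi-zz/joy-caption-pre-alpha": "joy-caption-pre-alpha",
    "google/siglip-so400m-patch14-384": "siglip-so400m-patch14-384",
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit": "Meta-Llama-3.1-8B-bnb-4bit",
}

def missing_dirs(root="."):
    """Return the repo ids whose local directory does not exist under root."""
    root = Path(root)
    return [repo for repo, local in MODEL_DIRS.items() if not (root / local).is_dir()]

def download_missing(root=".", token=None):
    """Download only the repos whose directories are missing (hypothetical helper)."""
    # Must be set before huggingface_hub is first imported in this process
    os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
    from huggingface_hub import snapshot_download

    for repo in missing_dirs(root):
        snapshot_download(repo_id=repo, local_dir=str(Path(root) / MODEL_DIRS[repo]), token=token)

if __name__ == "__main__":
    print("missing:", missing_dirs())
```

Calling `download_missing(token="...")` from the working directory would then fetch only what is absent.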
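Update the paths in `joy-caption-pre-alpha/app-multi-alpha.py`. They are relative to the directory that contains the three model folders:
```python
from pathlib import Path

CLIP_PATH = "siglip-so400m-patch14-384"                        # SigLIP vision encoder
VLM_PROMPT = "A descriptive caption for this image:\n"         # prompt placed after the image tokens
MODEL_PATH = "Meta-Llama-3.1-8B-bnb-4bit"                      # 4-bit Llama 3.1 8B
CHECKPOINT_PATH = Path("joy-caption-pre-alpha/wpkklhc6")       # folder containing image_adapter.pt
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')  # files treated as images
```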
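`IMAGE_EXTENSIONS` decides which files count as images. A minimal sketch of that filtering, assuming a case-insensitive suffix match over a single directory; `collect_images` is a hypothetical helper, and the script's own logic may differ (for example in recursion or ordering):

```python
from pathlib import Path

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def collect_images(input_dir):
    """Return image files directly under input_dir, sorted by path."""
    return sorted(
        p for p in Path(input_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
```

Lower-casing the suffix means `2.PNG` is picked up as well as `1.jpg`.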
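## 2. Running the Script
Run the script from the directory that contains the model folders. This example uses four GPUs and captions a single image with a batch size of 1:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python joy-caption-pre-alpha/app-multi-alpha.py \
    joy-caption-pre-alpha/input_dir/shoes_dataset/1.jpg \
    --bs 1
```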
Log:
```text
Loading CLIP 📎 on GPU 1
Loading CLIP 📎 on GPU 2
Loading CLIP 📎 on GPU 0
Loading CLIP 📎 on GPU 3
Loading tokenizer 🪙 on GPU 1
Loading tokenizer 🪙 on GPU 2
Loading tokenizer 🪙 on GPU 0
Loading tokenizer 🪙 on GPU 3
Loading LLM 🤖 on GPU 1
Loading LLM 🤖 on GPU 0
Loading LLM 🤖 on GPU 2
Loading LLM 🤖 on GPU 3
Loading image adapter 🖼️ on GPU 3
Loading image adapter 🖼️ on GPU 0
Loading image adapter 🖼️ on GPU 1
Loading image adapter 🖼️ on GPU 2
Processing single image 🎞️: 1.jpg
Processing image: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.60s/image]
```
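The log shows each GPU loading its own full copy of CLIP, the tokenizer, the LLM, and the adapter, i.e. one worker per visible GPU. How the image list is split among those workers is not shown here. A common scheme, sketched below purely as an assumption (hypothetical `shard` and `batches` helpers), is a round-robin split followed by batching with the `--bs` value:

```python
def shard(items, rank, world_size):
    """Round-robin split: worker `rank` takes every world_size-th item."""
    return items[rank::world_size]

def batches(items, bs):
    """Yield consecutive batches of at most `bs` items."""
    for start in range(0, len(items), bs):
        yield items[start:start + bs]

images = [f"{i}.jpg" for i in range(1, 11)]  # 10 hypothetical files
for rank in range(4):  # one worker per GPU in CUDA_VISIBLE_DEVICES=0,1,2,3
    print(rank, list(batches(shard(images, rank, 4), bs=2)))
```

Under this scheme a single input image lands on one worker, so extra GPUs only help when there are many images.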
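Loading the LLM with `device_map` requires the `accelerate` package. Install it if it is missing:
```bash
pip install accelerate -i https://pypi.tuna.tsinghua.edu.cn/simple
```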
The caption is saved to a file with the same name as the image, with the `.jpg` extension replaced by `.txt`. For example:
The image is a high-resolution photograph featuring a close-up view of a person’s feet and part of their legs. The individual is wearing a pair of stylish, high-heeled shoes. The shoes are predominantly navy blue, with a textured, woven pattern that gives them a sophisticated and elegant appearance. The upper part of the shoes is made of leather, featuring a cut-out design on the front that reveals the toes, and a small, round, gold-toned buckle on the side. The shoes have a chunky, black leather strap that wraps around the ankle, adding a touch of luxury. The wearer is also seen wearing black leggings that extend up to the knee. The shoes have a sturdy, black leather sole and a high, chunky heel. The background of the image is a light beige, plush sofa, which adds a touch of comfort and elegance to the setting. The overall scene is well-lit, highlighting the details of the shoes and the person’s neatly manicured toenails, painted a bright red. The image focuses on the aesthetic and fashion aspects, providing a clear and detailed view of the footwear.
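The caption-file naming rule is a simple suffix swap. A minimal sketch with `pathlib`; `caption_path` and `save_caption` are hypothetical helpers, and UTF-8 encoding is an assumption:

```python
from pathlib import Path

def caption_path(image_path):
    """1.jpg -> 1.txt, in the same directory as the image."""
    return Path(image_path).with_suffix(".txt")

def save_caption(image_path, caption):
    """Write the caption next to the image and return the .txt path."""
    out = caption_path(image_path)
    out.write_text(caption, encoding="utf-8")
    return out
```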
The left image is the original; the right image was generated by Flux from the caption:
![Img](https://i-blog.csdnimg.cn/direct/f04a72f02ca140f897bfce2e38bb136a.png)
## 3. Source Code Walkthrough
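`ImageAdapter` is a two-layer MLP: two `nn.Linear` layers with a `GELU` activation in between.
- Input dimension: `clip_model.config.hidden_size`, which is 1152
- Output dimension: `text_model.config.hidden_size`, which is 4096
- It maps `image_features: torch.Size([1, 729, 1152])` to `embedded_images: torch.Size([1, 729, 4096])`, projecting each of the 729 image tokens into the LLM's embedding space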
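The numbers above follow from the model configurations, and a few lines of arithmetic confirm them. `siglip-so400m-patch14-384` cuts a 384×384 image into non-overlapping 14×14 patches, giving 27 per side (the remaining 6 pixels along each axis are not covered), so 27 × 27 = 729 image tokens; the adapter's size is just the two `nn.Linear` weight matrices plus their biases. The sketch assumes no padding and no extra class token, which matches the 729 in the printed shape.

```python
def num_patches(image_size, patch_size):
    # Non-overlapping patches: stride == patch size, no padding
    per_side = (image_size - patch_size) // patch_size + 1
    return per_side * per_side

def linear_params(d_in, d_out):
    return d_in * d_out + d_out  # weight + bias

CLIP_HIDDEN = 1152  # clip_model.config.hidden_size
LLM_HIDDEN = 4096   # text_model.config.hidden_size

tokens = num_patches(384, 14)
adapter_params = linear_params(CLIP_HIDDEN, LLM_HIDDEN) + linear_params(LLM_HIDDEN, LLM_HIDDEN)
print(tokens)          # 729
print(adapter_params)  # 21504000
```

So the trainable part is only about 21.5M parameters, tiny next to the 8B-parameter LLM.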
Source code:
```python
import torch
from torch import nn
from transformers import (AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                          PreTrainedTokenizer, PreTrainedTokenizerFast)

class ImageAdapter(nn.Module):
    """Projects SigLIP features (1152-d) into the LLM embedding space (4096-d)."""
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        return self.linear2(self.activation(self.linear1(vision_outputs)))

# ...

def load_models(rank):
    # Each worker loads its own full copy of every model onto GPU `rank`
    print(f"Loading CLIP 📎 on GPU {rank}")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    # Keep only the SigLIP vision tower, frozen and in eval mode
    clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model.eval().requires_grad_(False).to(rank)

    print(f"Loading tokenizer 🪙 on GPU {rank}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    print(f"Loading LLM 🤖 on GPU {rank}")
    # device_map={"": rank} puts the whole model on this GPU (requires accelerate)
    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map={"": rank}, torch_dtype=torch.bfloat16).eval()

    print(f"Loading image adapter 🖼️ on GPU {rank}")
    image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
    image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location=f"cuda:{rank}", weights_only=True))
    image_adapter.eval().to(rank)
    return clip_processor, clip_model, tokenizer, text_model, image_adapter

# ...

# autocast takes a device *type* ("cuda"); the GPU index comes from where the tensors live
with torch.amp.autocast_mode.autocast("cuda", enabled=True):
    vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
    # Use the penultimate hidden layer rather than the final one
    image_features = vision_outputs.hidden_states[-2]
    print(f"[CL] image_features: {image_features.shape}")
    embedded_images = image_adapter(image_features).to(dtype=torch.bfloat16)
    print(f"[CL] embedded_images: {embedded_images.shape}")
```
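Next, build the LLM input sequence: the BOS (beginning-of-sequence) token embedding, then `embedded_images`, then the prompt embeddings, i.e. BOS (1) + `embedded_images` (729) + `prompt_embeds` (8) = 738 positions: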
```python
prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt')
prompt_embeds = text_model.model.embed_tokens(prompt.to(rank)).to(dtype=torch.bfloat16)
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=rank, dtype=torch.int64)).to(dtype=torch.bfloat16)

# [BOS] + image tokens + prompt, each expanded to the batch size
inputs_embeds = torch.cat([
    embedded_bos.expand(embedded_images.shape[0], -1, -1),
    embedded_images,
    prompt_embeds.expand(embedded_images.shape[0], -1, -1),
], dim=1).to(dtype=torch.bfloat16)
print(f"[CL] inputs_embeds.shape: {inputs_embeds.shape}")

# Matching token ids; the image positions are dummy zeros because the model reads inputs_embeds
input_ids = torch.cat([
    torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
    torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
    prompt.expand(embedded_images.shape[0], -1),
], dim=1).to(rank)
print(f"[CL] input_ids.shape: {input_ids.shape}")
```
Output:
```text
[CL] inputs_embeds.shape: torch.Size([1, 738, 4096])
[CL] input_ids.shape: torch.Size([1, 738])
```
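The length bookkeeping can be mirrored without a GPU. In this sketch, the hypothetical `build_placeholder_ids` reproduces the `input_ids` layout from the code above: the BOS id, then one dummy `0` per image token (those positions are actually fed from `inputs_embeds`), then the prompt ids. `128000` is the `<|begin_of_text|>` id in the Llama 3 tokenizer; the eight prompt ids are made-up stand-ins, since only their count matters here.

```python
BOS_ID = 128000                        # <|begin_of_text|> in the Llama 3 tokenizer
N_IMAGE_TOKENS = 729                   # SigLIP so400m-patch14-384 image tokens
PROMPT_IDS = [1, 2, 3, 4, 5, 6, 7, 8]  # hypothetical stand-ins for the 8 prompt tokens

def build_placeholder_ids(bos_id, n_image_tokens, prompt_ids):
    # [BOS] + [0] * image tokens + prompt, matching the torch.cat above
    return [bos_id] + [0] * n_image_tokens + list(prompt_ids)

ids = build_placeholder_ids(BOS_ID, N_IMAGE_TOKENS, PROMPT_IDS)
print(len(ids))  # 738
```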
Pass `inputs_embeds`, `input_ids`, and `attention_mask` to the text model (Llama 3.1) to generate tokens, then decode them:
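```python
# Assumption: an all-ones mask (the full script defines attention_mask elsewhere);
# every sequence has the same length (738), so there is no padding
attention_mask = torch.ones_like(input_ids)

generate_ids = text_model.generate(
    input_ids=input_ids,
    inputs_embeds=inputs_embeds,  # the model reads these embeddings for the prompt
    attention_mask=attention_mask,
    max_new_tokens=300,
    do_sample=True,   # sample instead of greedy decoding
    top_k=10,         # keep only the 10 most likely next tokens
    temperature=0.5,  # < 1 sharpens the distribution
)

# Drop the 738 prompt positions and keep only the newly generated tokens
generate_ids = generate_ids[:, input_ids.shape[1]:]
for ids in generate_ids:
    # Strip a trailing EOS token before decoding
    caption = tokenizer.decode(ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
    all_captions.append(caption)
```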
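With `do_sample=True, top_k=10, temperature=0.5`, each new token is drawn from the 10 highest-scoring candidates after the logits are divided by 0.5, which sharpens the distribution. In `transformers`, `generate` does this with logits warpers before sampling; the pure-Python sketch below reproduces one such step for illustration only, and the library's internals may differ in detail.

```python
import math
import random

def sample_top_k(logits, k=10, temperature=0.5, rng=random):
    """Sample one token index from `logits` using top-k + temperature."""
    scaled = [x / temperature for x in logits]             # temperature < 1 sharpens
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    m = max(scaled[i] for i in top)                        # subtract max for stability
    weights = [math.exp(scaled[i] - m) for i in top]       # softmax over the top-k only
    return rng.choices(top, weights=weights, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0, 3.0]
rng = random.Random(0)
picks = [sample_top_k(logits, k=2, temperature=0.5, rng=rng) for _ in range(1000)]
# Only the two best tokens (indices 4 and 0) can ever be picked
```

Lowering the temperature further pushes the choice toward greedy decoding; `k=1` makes it fully greedy.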
For the complete code, see `app-multi-alpha.py`.
Error:
```text
File "joy-caption-pre-alpha/app-multi-alpha.py", line 81, in stream_chat
  with torch.amp.autocast_mode.autocast(rank, enabled=True):
File "miniconda3/envs/ai-toolkit/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 225, in __init__
  raise ValueError(
ValueError: Expected `device_type` of type `str`, got: `<class 'int'>`
```
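Fix: the first argument of `autocast`, `device_type`, must be a string, but the script passes the integer `rank`. Pass the device type `"cuda"` instead; the GPU index is already determined by where the tensors live (`.to(rank)`):
```python
# with torch.amp.autocast_mode.autocast(rank, enabled=True):  # rank is an int -> ValueError
with torch.amp.autocast_mode.autocast("cuda", enabled=True):
    ...  # body unchanged
```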