1、下载数据
import sys
import os
import json
if __name__ == "__main__":
    # Download the first rows of the lambdalabs/pokemon-blip-captions dataset
    # from the Hugging Face datasets server, then save every image as
    # <savedir>/<row_idx>.jpg, every caption as <savedir>/<row_idx>.txt, and
    # all captions (one per line) into <savedir>_text.txt.
    #
    # Usage: python <script> [jsonfile] [savedir]
    import subprocess  # local import: only this script entry point needs it

    # Optional command-line overrides; the defaults are the original paths.
    jsonfile = sys.argv[1] if len(sys.argv) > 1 else "data/pokeman.json"
    savedir = sys.argv[2] if len(sys.argv) > 2 else "data/pokemon-blip-captions/"
    api_url = (
        "https://datasets-server.huggingface.co/first-rows"
        "?dataset=lambdalabs%2Fpokemon-blip-captions"
        "&config=lambdalabs--pokemon-blip-captions&split=train"
    )

    # Neither curl/wget nor open(..., "w") creates missing directories.
    os.makedirs(os.path.dirname(jsonfile) or ".", exist_ok=True)
    os.makedirs(savedir, exist_ok=True)

    # Argument lists (no shell) so URLs and paths are never re-parsed by a
    # shell. --fail + check=True stops here on an HTTP error instead of
    # failing later while parsing an error page.
    subprocess.run(
        ["curl", "--fail", "-X", "GET", api_url, "-o", jsonfile],
        check=True,
    )

    txtfile = savedir.rstrip("/") + "_text.txt"

    # Parse as JSON rather than eval(): eval executes arbitrary downloaded
    # text and breaks on the JSON literals true/false/null.
    with open(jsonfile, encoding="utf-8") as fb:
        data_dict = json.load(fb)
    print(data_dict.keys())

    txt_list = []
    for item in data_dict["rows"]:
        idx = item["row_idx"]
        row = item["row"]
        url = row["image"]["src"]
        txt = row["text"]

        image_path = os.path.join(savedir, f"{idx}.jpg")
        result = subprocess.run(["wget", url, "-O", image_path])
        if result.returncode != 0:
            # Best effort, as in the original: report and keep going.
            print(f"warning: failed to download {url}", file=sys.stderr)

        txt_list.append(txt)
        caption_path = os.path.join(savedir, f"{idx}.txt")
        with open(caption_path, "w", encoding="utf-8") as fb:
            fb.write(txt)

    # One caption per line, in the same order as the downloaded rows.
    with open(txtfile, "w", encoding="utf-8") as fb:
        fb.write("\n".join(txt_list))
##
2、diffusers仓库
1)安装环境,需要额外安装 datasets:pip install datasets
2)升级accelerate==0.16.0
3、训练脚本
Using LoRA for Efficient Stable Diffusion Fine-Tuning
采用 diffusers 训练脚本
# LoRA fine-tuning run using the diffusers text-to-image example script.
# The three variables below select the base model, the dataset and the
# directory the script is told to write its output to.
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="models/mymodel/lora_pokeman"

# One option per line; run from the root of the diffusers repository so the
# relative path to the example script resolves.
accelerate launch \
  --mixed_precision="fp16" \
  examples/text_to_image/train_text_to_image_lora.py \
  --pretrained_model_name_or_path="${MODEL_NAME}" \
  --dataset_name="${DATASET_NAME}" \
  --dataloader_num_workers=8 \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-04 \
  --max_grad_norm=1 \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=0 \
  --output_dir="${OUTPUT_DIR}" \
  --checkpointing_steps=500 \
  --validation_prompt="" \
  --seed=1337
4、输出(sd1.5)
a drawing of a gray and yellow pokemon <lora:lora_pokeman-15000:1>