【飞桨黑客松】AIGC - DreamBooth LoRA 文生图模型微调_en

★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>

【PaddlePaddle Hackathon 4】基于 PPDiffusers 训练 AIGC 趣味模型官方教程

🔥🔥🔥 前往官网了解题目详情、报名、冲大奖、玩创意!👉👉👉 https://github.com/PaddlePaddle/Paddle/issues/50631#task105

本教程将从以下两个方面带领大家熟悉整个流程。

  • 1. 准备工作
    • 1.1 环境安装
    • 1.2 Hugging Face Space 注册和登录
  • 2. 如何训练
    • 2.1 上传图片
    • 2.2 训练参数调整
    • 2.3 可视化训练过程
    • 2.4 挑选满意的权重上传至Huggingface

1. 准备工作

1.1 环境安装


在开始之前,我们需要准备我们所需的环境,运行下面的命令安装依赖。为了确保安装成功,安装完毕请重启内核!(注意:这里只需要运行一次!)

pip install "paddlenlp>=2.5.2" "ppdiffusers>=0.11.1" safetensors --user
# 请运行这里安装所需要的依赖环境!!
!pip install "paddlenlp>=2.5.2" safetensors "ppdiffusers>=0.11.1" --user
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting paddlenlp>=2.5.2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e4/67/f97788181e3a49afcdf8ffa162a335c21d18a55387bae85be24b01383165/paddlenlp-2.5.2-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting safetensors
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7c/34/54c2207f5b4eaf9b455ab679d7aa9a1c0c13e159fc09a491507933b02ad1/safetensors-0.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting ppdiffusers>=0.11.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ab/fe/d305869384636e0c3e272c41d78915fabf0776064bbe7282bbeb4a377300/ppdiffusers-0.14.2-py3-none-any.whl (977 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m977.2/977.2 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hRequirement already satisfied: multiprocess<=0.70.12.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.70.11.1)
Requirement already satisfied: colorama in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.4.4)
Collecting typer
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/bf/0e/c68adf10adda05f28a6ed7b9f4cd7b8e07f641b44af88ba72d9c89e4de7a/typer-0.9.0-py3-none-any.whl (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.9/45.9 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hRequirement already satisfied: huggingface-hub>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.12.1)
Requirement already satisfied: uvicorn in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.21.1)
Requirement already satisfied: dill<0.3.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.3.3)
Requirement already satisfied: paddlefsl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (1.1.0)
Requirement already satisfied: Flask-Babel<3.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (1.0.0)
Requirement already satisfied: rich in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (13.3.2)
Requirement already satisfied: colorlog in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (4.1.0)
Requirement already satisfied: jieba in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.42.1)
Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (2.4.0)
Requirement already satisfied: fastapi in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.95.0)
Requirement already satisfied: datasets>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (2.10.1)
Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (4.64.1)
Requirement already satisfied: paddle2onnx in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (1.0.0)
Requirement already satisfied: seqeval in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (1.2.2)
Requirement already satisfied: sentencepiece in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.5.2) (0.1.96)
Collecting ftfy
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e1/1e/bf736f9576a8979752b826b75cbd83663ff86634ea3055a766e2d8ad3ee5/ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting regex
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9d/1e/8eb13233ac58edecdd58aa7de0d5b68fc04f7141891c1934036b0b34890a/regex-2023.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (755 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m755.7/755.7 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hRequirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from ppdiffusers>=0.11.1) (8.2.0)
Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (2.24.0)
Requirement already satisfied: packaging in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (21.3)
Requirement already satisfied: aiohttp in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (3.8.4)
Requirement already satisfied: pyarrow>=6.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (11.0.0)
Requirement already satisfied: numpy>=1.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (1.19.5)
Requirement already satisfied: responses<0.19 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (0.18.0)
Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (2023.1.0)
Requirement already satisfied: xxhash in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (3.2.0)
Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (5.1.2)
Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (4.2.0)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from datasets>=2.0.0->paddlenlp>=2.5.2) (1.1.5)
Requirement already satisfied: Flask in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel<3.0.0->paddlenlp>=2.5.2) (1.1.1)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel<3.0.0->paddlenlp>=2.5.2) (2019.3)
Requirement already satisfied: Jinja2>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel<3.0.0->paddlenlp>=2.5.2) (3.0.0)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel<3.0.0->paddlenlp>=2.5.2) (2.8.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from huggingface-hub>=0.11.1->paddlenlp>=2.5.2) (4.3.0)
Requirement already satisfied: filelock in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from huggingface-hub>=0.11.1->paddlenlp>=2.5.2) (3.0.12)
Requirement already satisfied: starlette<0.27.0,>=0.26.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from fastapi->paddlenlp>=2.5.2) (0.26.1)
Requirement already satisfied: pydantic!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0,>=1.6.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from fastapi->paddlenlp>=2.5.2) (1.10.6)
Collecting wcwidth>=0.2.5
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/20/f4/c0584a25144ce20bfcf1aecd041768b8c762c1eb0aa77502a3f0baa83f11/wcwidth-0.2.6-py2.py3-none-any.whl (29 kB)
Requirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from rich->paddlenlp>=2.5.2) (2.2.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from rich->paddlenlp>=2.5.2) (2.13.0)
Requirement already satisfied: scikit-learn>=0.21.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from seqeval->paddlenlp>=2.5.2) (0.24.2)
Collecting click<9.0.0,>=7.1.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c2/f1/df59e28c642d583f7dacffb1e0965d0e00b218e0186d7858ac5233dce840/click-8.1.3-py3-none-any.whl (96 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m96.6/96.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hRequirement already satisfied: h11>=0.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from uvicorn->paddlenlp>=2.5.2) (0.14.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddlenlp>=2.5.2) (0.8.53)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddlenlp>=2.5.2) (3.20.0)
Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddlenlp>=2.5.2) (2.2.3)
Requirement already satisfied: six>=1.14.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddlenlp>=2.5.2) (1.16.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask->Flask-Babel<3.0.0->paddlenlp>=2.5.2) (0.16.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask->Flask-Babel<3.0.0->paddlenlp>=2.5.2) (1.1.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (1.8.2)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (22.1.0)
Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (3.0.1)
Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (1.3.3)
Requirement already satisfied: asynctest==0.13.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (0.13.0)
Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (1.3.1)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (4.0.2)
Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from aiohttp->datasets>=2.0.0->paddlenlp>=2.5.2) (6.0.4)
Requirement already satisfied: MarkupSafe>=2.0.0rc2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.5->Flask-Babel<3.0.0->paddlenlp>=2.5.2) (2.0.1)
Requirement already satisfied: mdurl~=0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from markdown-it-py<3.0.0,>=2.2.0->rich->paddlenlp>=2.5.2) (0.1.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging->datasets>=2.0.0->paddlenlp>=2.5.2) (3.0.9)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.19.0->datasets>=2.0.0->paddlenlp>=2.5.2) (1.25.11)
Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.19.0->datasets>=2.0.0->paddlenlp>=2.5.2) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.19.0->datasets>=2.0.0->paddlenlp>=2.5.2) (2.8)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.19.0->datasets>=2.0.0->paddlenlp>=2.5.2) (2019.9.11)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.5.2) (2.1.0)
Requirement already satisfied: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.5.2) (0.14.1)
Requirement already satisfied: scipy>=0.19.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.5.2) (1.6.3)
Requirement already satisfied: anyio<5,>=3.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from starlette<0.27.0,>=0.26.1->fastapi->paddlenlp>=2.5.2) (3.6.1)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddlenlp>=2.5.2) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddlenlp>=2.5.2) (0.18.0)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->datasets>=2.0.0->paddlenlp>=2.5.2) (3.8.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl->paddlenlp>=2.5.2) (1.1.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl->paddlenlp>=2.5.2) (2.8.2)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl->paddlenlp>=2.5.2) (0.10.0)
Requirement already satisfied: sniffio>=1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from anyio<5,>=3.4.0->starlette<0.27.0,>=0.26.1->fastapi->paddlenlp>=2.5.2) (1.3.0)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib->visualdl->paddlenlp>=2.5.2) (56.2.0)
Installing collected packages: wcwidth, safetensors, regex, ftfy, click, typer, paddlenlp, ppdiffusers
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
parl 1.4.1 requires pyzmq==18.1.1, but you have pyzmq 23.2.1 which is incompatible.[0m[31m
[0mSuccessfully installed click-8.1.3 ftfy-6.1.1 paddlenlp-2.5.2 ppdiffusers-0.14.2 regex-2023.6.3 safetensors-0.3.1 typer-0.9.0 wcwidth-0.2.6

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.1.2[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

1.2 Hugging Face Space 注册和登录

题目要求将模型上传到 Hugging Face,需要先注册、登录。

  • 注册和登录:https://huggingface.co/join

  • 获取登录 Token

  • Aistudio 登录 Huggingface Hub

Tips:为了方便我们之后上传权重,我们需要登录 Huggingface Hub,想要了解更多的信息我们可以查阅 官方文档

!git config --global credential.helper store
from huggingface_hub import login
login()
VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…
  • tips:如何检测是否登录成功?

打开日志控制控制台,查看日志。

登录成功时,日志如下:

2. 如何训练模型,并上传到HF

2.1 上传图片

  • 首先,我们需要将所需训练的图片上传到aistudio上的文件夹, 我们可以通过👉拖拽上传 的方式,将我们所需的图片上传至指定的文件夹。
  • 在这里,我们已经在👉dogs文件夹准备好了如下所示的5张图片。

2.2 训练参数调整

在训练过程中,我们可以尝试修改训练的默认参数,下面将从三个方面介绍部分参数。

👉主要修改的参数:

  • pretrained_model_name_or_path :想要训练的模型名称或者本地路径的模型,例如:"runwayml/stable-diffusion-v1-5",更多模型可参考 PaddleNLP 文档
  • instance_data_dir:训练图片所在的文件夹目录,我们可以将图片上传至aistudio项目。
  • instance_prompt:训练所使用的 Prompt 文本。
  • resolution:训练时图像的分辨率,建议为 512
  • output_dir:训练过程中,模型保存的目录。
  • checkpointing_steps:每隔多少步保存模型,默认为100步。
  • learning_rate:训练使用的学习率,当我使用 LoRA 训练模型的时候,我们需要使用更大的学习率,因此我们这里使用 1e-4 而不是 2e-6
  • max_train_steps:最大训练的步数,默认为500步。

👉可选修改的参数:

  • train_batch_size:训练时候使用的 batch_size,当我们的GPU显存比较大的时候可以加大这个值,默认值为4
  • gradient_accumulation_steps:梯度累积的步数,当我们GPU显存比较小的时候还想模拟大的训练批次,我们可以适当增加梯度累积的步数,默认值为1
  • seed:随机种子,设置后可以复现训练结果。
  • lora_rankLoRA 层的 rank 值,默认值为4,最终我们会得到 3.5MB 的模型,我们可以适当修改这个值,如:32、64、128、256 等。
  • lr_scheduler:学习率衰减策略,可以是"linear", "constant", "cosine"等。
  • lr_warmup_steps:学习率衰减前,warmup 到最大学习率所需要的步数。

👉训练过程中评估使用的参数:

  • num_validation_images:训练的过程中,我们希望返回多少张图片,默认值为4张图片。
  • validation_prompt:训练的过程中我们会评估训练的怎么样,因此我们需要设置评估使用的 prompt 文本。
  • validation_steps:每隔多少个 steps 评估模型,我们可以查看训练的进度条,知道当前到了第几个 steps

🔥Tips:
训练过程中会每隔 validation_steps 将生成的图片保存到 {你指定的输出路径}/validation_images/{步数}.jpg

👉权重上传的参数:

  • push_to_hub: 是否将模型上传到 huggingface hub,默认值为 False
  • hub_token: 上传到 huggingface hub 所需要使用的 token,如果我们已经登录了,那么我们就无需填写。
  • hub_model_id: 上传到 huggingface hub 的模型库名称, 如果为 None 的话表示我们将使用 output_dir 的名称作为模型库名称。

在下面的例子中,由于我们前面已经登录了,因此我们可以开启 push_to_hub 按钮,将最终训练好的模型同步上传到 huggingface.co

当我们开启push_to_hub后,等待程序运行完毕后会自动将权重上传到这个路径 https://huggingface.co/{你的用户名}/{你指定的输出路径} ,例如: https://huggingface.co/junnyu/lora_outputs

这里使用Linaqruf/anything-v3.0模型

该模型可通过输入几个文本提示词就能生成高质量、高度详细的动漫风格图片,该模型支持使用 danbooru 标签文本 生成图像。

!python train_dreambooth_lora.py \
  --pretrained_model_name_or_path="hakurei/waifu-diffusion"  \
  --instance_data_dir="./pika" \
  --output_dir="pikachu_outputs" \
  --instance_prompt="a portrait of pikachu with a black cap growing on its head. intricate. lifelike. soft light. sony a 7 r iv 5 5 mm. cinematic post - processing" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --seed=0 \
  --lora_rank=16 \
  --push_to_hub=False \
  --validation_prompt="a portrait of pikachu with a black cap growing on its head. intricate. lifelike. soft light. sony a 7 r iv 5 5 mm. cinematic post - processing" \
  --validation_steps=100 \
  --num_validation_images=4
W0615 11:06:20.931605   710 gpu_resources.cc:85] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0615 11:06:20.936533   710 gpu_resources.cc:115] device: 0, cuDNN Version: 8.2.
正在下载模型权重,请耐心等待。。。。。。。。。。
100%|███████████████████████████████████████████| 291/291 [00:00<00:00, 227kB/s]
100%|████████████████████████████████████████| 842k/842k [00:00<00:00, 2.79MB/s]
100%|████████████████████████████████████████| 512k/512k [00:00<00:00, 12.0MB/s]
100%|████████████████████████████████████████| 2.00/2.00 [00:00<00:00, 1.60kB/s]
100%|███████████████████████████████████████████| 478/478 [00:00<00:00, 226kB/s]
100%|███████████████████████████████████████████| 487/487 [00:00<00:00, 368kB/s]
Downloading (…)cheduler_config.json: 100%|█████| 342/342 [00:00<00:00, 30.7kB/s]
100%|███████████████████████████████████████████| 487/487 [00:00<00:00, 448kB/s]
Downloading (…)model_state.pdparams:  21%|▋  | 105M/492M [00:12<00:22, 17.2MB/s]

2.3 可视化训练过程

VisualDL使用参考:官方教程

我们可以参照如图所示的步骤,开启visualdl,然后查看训练过程中的指标变化。

2.4 挑选满意的权重上传至Huggingface

参数解释:

  • upload_dir:我们需要上传的文件夹目录。
  • repo_name:我们需要上传的repo名称,最终我们会上传到 https://huggingface.co/{你的用户名}/{你指定的repo名称}, 例如: https://huggingface.co/junnyu/lora_sks_dogs.
  • pretrained_model_name_or_path:训练该模型所使用的基础模型。
  • prompt:搭配该权重需要使用的Prompt文本。
from utils import upload_lora_folder
upload_dir                    = "pikachu_outputs"                   # 我们需要上传的文件夹目录
repo_name                     = "pika_comic"                  # 我们需要上传的repo名称
pretrained_model_name_or_path = "hakurei/waifu-diffusion" # 训练该模型所使用的基础模型
prompt                        = "a portrait of pikachu with a black cap growing on its head. intricate. lifelike. soft light. sony a 7 r iv 5 5 mm. cinematic post - processing"  # 搭配该权重需要使用的Prompt文本

upload_lora_folder(
    upload_dir=upload_dir,
    repo_name=repo_name,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    prompt=prompt, 
)
Pushing to enkilee/pika_comic

ained_model_name_or_path=pretrained_model_name_or_path,
prompt=prompt,
)
``

Pushing to enkilee/pika_comic



Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]



paddle_lora_weights.pdparams:   0%|          | 0.00/12.8M [00:00<?, ?B/s]

欢迎加入NLP技术交流群,一起相互讨论交流~

此文章为搬运
原项目链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值