Wan2.1 图生视频 多卡推理批量生成视频
flyfish
Phantom 视频生成的实践
Phantom 视频生成的流程
Phantom 视频生成的命令
Wan2.1 图生视频 支持批量生成
Wan2.1 文生视频 支持批量生成、参数化配置和多语言提示词管理
Wan2.1 加速推理方法
Wan2.1 通过首尾帧生成视频
AnyText2 在图片里玩文字而且还是所想即所得
Python 实现从 MP4 视频文件中平均提取指定数量的帧
config.json
{
"task": "i2v-14B",
"size": "832*480",
"frame_num": null,
"ckpt_dir": "/media/models/Wan-AI/Wan2___1-I2V-14B-480P/",
"offload_model": null,
"ulysses_size": 2,
"ring_size": 1,
"t5_fsdp": false,
"t5_cpu": true,
"dit_fsdp": true,
"save_file": null,
"prompt": null,
"use_prompt_extend": false,
"prompt_extend_method": "local_qwen",
"prompt_extend_model": null,
"prompt_extend_target_lang": "zh",
"base_seed": -1,
"image": null,
"first_frame": null,
"last_frame": null,
"sample_solver": "unipc",
"sample_steps": null,
"sample_shift": null,
"sample_guide_scale": 5.0
}
prompt.json
[
{
"prompt": "Dragon Playing with Pearl: A warrior wields a red-tasseled spear, summoning seven dragon-like phantom spear tips amid swirling ink shadows that twist air into a shredding vortex; visuals include ink-black shadows, molten fire-red tassel, and a violent air vortex. ",
"image_paths": ["images/1.png"]
},
{
"prompt": "Slicing the Sky, Chopping the Moon: The warrior leaps, slashing the spear diagonally like lightning to create a glowing vacuum rift with azure electricity, then traces a lunar arc that solidifies space to trap enemies; visuals feature a billowing black cape, crackling rift, and frozen lunar arc. ",
"image_paths": ["images/1.png"]
}
]
模型加载时序图(单例模式)
┌──────────────────────────────────────────────────────────┐
│ WanI2VApp.run() │
│ ┌─────────────────┐ ┌─────────────────┐ ┌────────────┐ │
│ │ 加载配置/验证参数 │ │ 初始化分布式环境 │ │ 首次调用 │ │
│ └─────────────────┘ └─────────────────┘ └──────┬─────┘ │
│ │ │
│ ┌───────────────────────────────────────────────┐ │ │
│ │ VideoGenerator.get_instance() │ │ │
│ ├───────────────────────────────────────────────┤ │ │
│ │ ┌───────────────────────┐ ┌────────────────┐ │ │
│ │ │ 检查_instance是否为None │ │ 创建新实例 │ │ │
│ │ └───────────────────────┘ └────────┬────────┘ │ │
│ │ │ │ │
│ │ ┌──────────────────────────────────┐ │ │
│ │ │ VideoGenerator.__init__() │ │ │
│ │ │ ┌─────────────────────┐ ┌──────────┐ │ │ │
│ │ │ │ 加载WanI2V模型 │ │ 初始化参数 │ │ │ │
│ │ │ └─────────────────────┘ └──────────┘ │ │ │
│ │ └──────────────────────────────────┘ │ │
│ └───────────────────────────────────────────────┘ │ │
│ │ │
│ ┌───────────────────────────────────────────────┐ │ │
│ │ 后续调用 │ │ │
│ │ VideoGenerator.get_instance() 直接返回已创建实例 │ │ │
│ └───────────────────────────────────────────────┘ │ │
│ │ │
│ ┌───────────────────────────────────────────────┐ │ │
│ │ 推理阶段 │ │ │
│ │ generator.generate() 使用已加载模型 │ │ │
│ └───────────────────────────────────────────────┘ │ │
│ │ │
│ ┌───────────────────────────────────────────────┐ │ │
│ │ 清理阶段 │ │ │
│ │ VideoGenerator.cleanup() 释放模型资源 │ │ │
│ └───────────────────────────────────────────────┘ │ │
└──────────────────────────────────────────────────────────┘
↑ ↑ ↑
│ │ │
首次触发模型加载 复用模型实例 释放模型资源
流程图
WanI2VApp.run()
├─ 加载配置(ConfigLoader)
│ ├─ 读取config.json
│ └─ 设置参数默认值
├─ 验证参数(ArgsValidator)
│ ├─ 检查ckpt_dir/task/size等参数
│ └─ 设置sample_steps/frame_num等默认值
├─ 初始化日志(LoggerInitializer)
│ ├─ 主进程(rank=0)输出INFO日志
│ └─ 从进程输出ERROR日志
├─ 初始化分布式环境(DistributedEnv)
│ ├─ 多GPU时启动NCCL进程组
│ ├─ 验证ulysses_size/ring_size配置
│ ├─ 同步所有进程的随机种子
│ └─ 初始化模型并行环境(如需要)
├─ 加载模型(VideoGenerator单例)
│ ├─ 首次调用get_instance时创建模型
│ ├─ 加载checkpoint到指定GPU
│ └─ 后续调用直接复用模型实例
├─ 处理prompt.json循环(N个prompt)
│ ├─ 读取prompt和image_paths
│ ├─ 遍历每个image_path
│ │ ├─ 打开并转换图像为RGB
│ │ ├─ 生成唯一标识符(时间戳+UUID)
│ │ ├─ 调用模型生成视频(generator.generate)
│ │ └─ 主进程保存视频(VideoSaver)
│ │ ├─ 处理提示词格式(替换非法字符)
│ │ └─ 生成唯一文件名并保存
├─ 分布式同步(dist.barrier)
│ └─ 等待所有进程完成推理
└─ 清理资源
├─ 释放模型实例(del model)
├─ 清空GPU缓存(torch.cuda.empty_cache)
└─ 主进程输出完成日志
说明
-
模型加载流程:
- 单例模式确保模型仅加载一次
- 首次调用
VideoGenerator.get_instance()
时创建模型 - 后续调用直接返回已创建的实例,避免重复加载
-
分布式执行流程:
- 多GPU环境下每个进程独立加载模型副本
- 主进程负责文件IO操作(保存视频)
dist.barrier()
确保所有进程完成后再清理资源
-
文件名生成逻辑:
任务_分辨率_并行度_图片名_提示词_时间戳_UUID.mp4 └─ 示例: i2v-14B_832x480_2_1_img_Dragon_20250526_193349_6e7f8a9b.mp4
-
错误处理机制:
- 图像加载/模型生成/视频保存均有异常捕获
- 错误日志包含进程ID和详细上下文
- 失败任务跳过不影响整体流程
代码
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
import json
import uuid
warnings.filterwarnings('ignore')
import torch
import random
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video, cache_image, str2bool
class ArgsValidator:
@staticmethod
def validate(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40
if args.sample_shift is None:
args.sample_shift = 3.0 if args.size in ["832*480", "480*832"] else 5.0
# The default number of frames are 81.
if args.frame_num is None:
args.frame_num = 81
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
assert args.size in SUPPORTED_SIZES[
args.task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
# 验证ulysses_size和ring_size
if args.ulysses_size < 1:
args.ulysses_size = 1
if args.ring_size < 1:
args.ring_size = 1
return args
class ConfigLoader:
@staticmethod
def load_config():
# 从配置文件读取参数
with open('config.json', 'r') as f:
config = json.load(f)
# 创建一个命名空间来存储参数
class ArgsNamespace:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
args = ArgsNamespace(**config)
# 设置默认值
if not hasattr(args, 'ulysses_size'):
args.ulysses_size = 1
if not hasattr(args, 'ring_size'):
args.ring_size = 1
if not hasattr(args, 't5_fsdp'):
args.t5_fsdp = False
if not hasattr(args, 'dit_fsdp'):
args.dit_fsdp = False
if not hasattr(args, 't5_cpu'):
args.t5_cpu = False
if not hasattr(args, 'offload_model'):
args.offload_model = False
return args
class LoggerInitializer:
def __init__(self, rank):
self.rank = rank
def initialize(self):
# logging
if self.rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
class DistributedEnv:
def __init__(self, args):
self.args = args
self.rank = int(os.getenv("RANK", 0))
self.world_size = int(os.getenv("WORLD_SIZE", 1))
self.local_rank = int(os.getenv("LOCAL_RANK", 0))
self.device = self.local_rank
def initialize(self):
if self.args.offload_model is None:
self.args.offload_model = False if self.world_size > 1 else True
logging.info(
f"Process {self.rank}: offload_model is not specified, set to {self.args.offload_model}.")
if self.world_size > 1:
# 设置CUDA设备
torch.cuda.set_device(self.local_rank)
# 初始化NCCL后端分布式环境
logging.info(f"Process {self.rank}: Initializing distributed environment with NCCL backend...")
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=self.rank,
world_size=self.world_size)
logging.info(f"Process {self.rank}: Distributed environment initialized. "
f"Rank: {self.rank}, World size: {self.world_size}, Local rank: {self.local_rank}")
else:
assert not (
self.args.t5_fsdp or self.args.dit_fsdp
), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
assert not (
self.args.ulysses_size > 1 or self.args.ring_size > 1
), f"context parallel are not supported in non-distributed environments."
if self.args.ulysses_size > 1 or self.args.ring_size > 1:
assert self.args.ulysses_size * self.args.ring_size == self.world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
logging.info(f"Process {self.rank}: Initializing model parallel environment...")
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=self.args.ring_size,
ulysses_degree=self.args.ulysses_size,
)
logging.info(f"Process {self.rank}: Model parallel environment initialized. "
f"Ulysses size: {self.args.ulysses_size}, Ring size: {self.args.ring_size}")
# 同步所有进程的随机种子
if dist.is_initialized():
base_seed = [self.args.base_seed] if self.rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
self.args.base_seed = base_seed[0]
logging.info(f"Process {self.rank}: Using synchronized seed: {self.args.base_seed}")
return self.args, self.rank, self.device
class VideoGenerator:
_instance = None # 单例实例
@classmethod
def get_instance(cls, args, rank, device):
# 如果实例不存在,创建新实例
if cls._instance is None:
cls._instance = cls(args, rank, device)
return cls._instance
def __init__(self, args, rank, device):
# 初始化只执行一次
self.args = args
self.rank = rank
self.device = device
self.cfg = WAN_CONFIGS[args.task]
# 设置随机种子
torch.manual_seed(args.base_seed + rank)
random.seed(args.base_seed + rank)
# 加载模型
logging.info(f"Process {self.rank}: Creating WanI2V pipeline (first time)...")
self.model = wan.WanI2V(
config=self.cfg,
checkpoint_dir=self.args.ckpt_dir,
device_id=self.device,
rank=self.rank,
t5_fsdp=self.args.t5_fsdp,
dit_fsdp=self.args.dit_fsdp,
use_usp=(self.args.ulysses_size > 1 or self.args.ring_size > 1),
t5_cpu=self.args.t5_cpu,
)
def generate(self, prompt, img):
# 复用已加载的模型进行推理
logging.info(f"Process {self.rank}: Generating video with existing model...")
video = self.model.generate(
prompt,
img,
max_area=MAX_AREA_CONFIGS[self.args.size],
frame_num=self.args.frame_num,
shift=self.args.sample_shift,
sample_solver=self.args.sample_solver,
sampling_steps=self.args.sample_steps,
guide_scale=self.args.sample_guide_scale,
seed=self.args.base_seed,
offload_model=self.args.offload_model)
return video
@classmethod
def cleanup(cls):
# 清理模型资源
if cls._instance and hasattr(cls._instance, 'model'):
logging.info(f"Process {cls._instance.rank}: Releasing model...")
del cls._instance.model
torch.cuda.empty_cache()
logging.info(f"Process {cls._instance.rank}: Model resources cleaned up.")
cls._instance = None
class VideoSaver:
def __init__(self, args, rank):
self.args = args
self.rank = rank
self.cfg = WAN_CONFIGS[args.task]
def save(self, video, current_prompt, image_path, unique_id):
if self.rank != 0:
return
# 使用传入的唯一ID生成文件名
formatted_id = unique_id
# 从image_path提取文件名前缀(保留20个字符)
image_basename = os.path.basename(image_path).split('.')[0][:20] if image_path else ""
# 处理提示词:截取前30个字符,替换非法字符
formatted_prompt = current_prompt.replace(" ", "_").replace("/", "_").replace(":", "_")[:30]
# 替换所有非法文件名字符
illegal_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*', '’', '“', '”', '!', '#', '$', '%', '&', '(', ')', '[', ']', '{', '}', ';', '@', '+', '=', ',', '.']
for char in illegal_chars:
formatted_prompt = formatted_prompt.replace(char, '_')
# 确保文件名不超过255个字符(NTFS和ext4限制)
max_length = 255 - len(f"{self.args.task}_{self.args.size}_{self.args.ulysses_size}_{self.args.ring_size}_{image_basename}_{formatted_id}.mp4")
formatted_prompt = formatted_prompt[:max_length]
# 构建完整文件名
safe_size = self.args.size.replace('*', 'x') if sys.platform == 'win32' else self.args.size
save_file = f"{self.args.task}_{safe_size}_{self.args.ulysses_size}_{self.args.ring_size}_{image_basename}_{formatted_prompt}_{formatted_id}.mp4"
logging.info(f"Process {self.rank}: Saving generated video to: {save_file}")
# 创建输出目录(如果不存在)
output_dir = os.path.dirname(save_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
cache_video(
tensor=video[None],
save_file=save_file,
fps=self.cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
class WanI2VApp:
def __init__(self):
self.args = None
self.rank = 0
self.device = 0
def run(self):
# 加载配置
config_loader = ConfigLoader()
self.args = config_loader.load_config()
# 验证参数
validator = ArgsValidator()
self.args = validator.validate(self.args)
# 初始化日志
logger_initializer = LoggerInitializer(self.rank)
logger_initializer.initialize()
# 初始化分布式环境
dist_env = DistributedEnv(self.args)
self.args, self.rank, self.device = dist_env.initialize()
logging.info(f"Process {self.rank}: Generation job args: {self.args}")
logging.info(f"Process {self.rank}: Generation model config: {WAN_CONFIGS[self.args.task]}")
# 获取单例模型生成器
generator = VideoGenerator.get_instance(self.args, self.rank, self.device)
# 从prompt.json文件读取prompt和image_paths
with open('prompt.json', 'r') as f:
prompt_list = json.load(f)
for prompt_info in prompt_list:
original_prompt = prompt_info["prompt"]
image_paths = prompt_info["image_paths"]
for image_path in image_paths:
logging.info(f"Process {self.rank}: Input prompt: {original_prompt}")
logging.info(f"Process {self.rank}: Input image: {image_path}")
# 生成全局唯一标识符(时间戳+UUID)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3] # 毫秒级时间戳
unique_id = f"{timestamp}_{uuid.uuid4().hex[:8]}" # 组合时间戳和8位UUID
# 打开并处理图像
try:
img = Image.open(image_path).convert("RGB")
except Exception as e:
logging.error(f"Process {self.rank}: Failed to open image {image_path}: {e}")
continue
# 生成视频
try:
video = generator.generate(original_prompt, img)
except Exception as e:
logging.error(f"Process {self.rank}: Failed to generate video for prompt '{original_prompt}': {e}")
continue
# 保存视频
try:
saver = VideoSaver(self.args, self.rank)
saver.save(video, original_prompt, image_path, unique_id)
except Exception as e:
logging.error(f"Process {self.rank}: Failed to save video for prompt '{original_prompt}': {e}")
continue
# 确保所有进程完成后再清理
if dist.is_initialized():
logging.info(f"Process {self.rank}: Waiting for all processes to synchronize...")
dist.barrier()
logging.info(f"Process {self.rank}: All processes synchronized.")
# 每个进程独立清理自身的模型资源
VideoGenerator.cleanup()
if self.rank == 0:
logging.info("All processes completed.")
if __name__ == "__main__":
app = WanI2VApp()
app.run()
WanI2V程序执行流程
┌──────────────────────────────────────────────────────────┐
│ 开始 │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ WanI2VApp.run() │
└────────┬─────────────────────────────────────────────────┘
▼
┌───────────────────────┬──────────────────────────────────┐
│ 加载配置文件 │ 验证参数合法性 │
└──────────────┬────────┘ └───────────────┬───────────┘
▼ ▼
┌──────────────────────────────────────────────────────────┐
│ 初始化日志系统 │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 初始化分布式环境 │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 设置设备/进程组│ │ 验证并行参数 │ │ 同步随机种子 │ │
│ └───────────────┘ └───────────────┘ └───────────────┘ │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 加载WanI2V模型(单例) │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 检查单例实例 │ ├─ 是 ──→ 直接返回 │ └────┬────────┘ │
│ └──────┬────────┘ │ │ ▼ │
│ ▼ │ │ 创建模型实例 │
│ 否 ────────────────┘ │ 加载checkpoint │
│ └───────────────┘ │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 读取prompt.json │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 遍历处理prompt列表 │
│ ┌───────────────┐ │
│ │ 取出prompt和 │ │
│ │ image_paths │ │
│ └──────┬────────┘ │
│ ▼ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ 遍历处理图片 │
│ ├───────────────┬───────────────────────────────────┤ │
│ │ 打开图片 │ 生成唯一标识符 │
│ └──────┬────────┘ └────────────┬───────────┘
│ ▼ ▼ │
│ ┌──────────────────┐ ┌──────────────────────────┐ │
│ │ 调用模型生成 │ │ 主进程保存视频文件 │ │
│ │ 视频(复用模型)│ ├────────────┬─────────────┤ │
│ └──────────────────┘ │ 处理提示词格式 │ 生成唯一文件名 │
│ └────────────┬─────────────┘ │
│ ▼ │
│ 保存视频到文件 │
│ └──────────────────────────────────────────────────┘ │
└──────────────┬─────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 分布式进程同步 │
│ ┌───────────────┐ │
│ │ 检查是否为分布式│ │
│ └──────┬────────┘ │
│ ▼ │
│ 是 ───→ 调用dist.barrier() 等待所有进程 │
│ └──────────────────────────────────────────────────┘ │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 清理模型资源 │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ 删除模型实例 │ │ 清空GPU缓存 │ │ 单例置为None │ │
│ └───────────────┘ └───────────────┘ └───────────────┘ │
└────────┬─────────────────────────────────────────────────┘
▼
┌──────────────────────────────────────────────────────────┐
│ 结束 │
└──────────────────────────────────────────────────────────┘
核心模块交互关系
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ ConfigLoader│─────→│ArgsValidator│─────→│LoggerInitializer│
└──────────────┘ └──────────────┘ └──────────────┘
│ │
▼ ▼
┌──────────────────────┐ ┌──────────────────────┐
│ DistributedEnv │─────→│ VideoGenerator │
└──────────────────────┘ └──────────────────────┘
│ │
▼ ▼
┌──────────────────────┐ ┌──────────────────────┐
│ WanI2VApp │←─────│ VideoSaver │
└──────────────────────┘ └──────────────────────┘