论文研读:Tune-a-video — 微调学习单个视频内的物体动作(Arxiv: 2023-03-17)

原文:Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation

1.基本介绍

1.1 Introduction

这里主要分析了文生图模型, 即text-to-image model, T2I model

  • cross-attention

    可根据文本<名词>和<动词>生成语义<图像>的能力

  • self-attention

    能聚焦同一目标, 但无时间连续性

  • T2V model

    图像仅具有空间特征(spatial features),因此:

text-to-video模型需要时序特征(spatio-temporal),具体区别如图:

在这里插入图片描述

  • DDIM Inversion

只有spatio-temporal是不够的,时间的先后顺序,即连续性存在问题,为了改进这个问题:

将时序特征通过DDIM加噪(inversion)的方式,嵌入latent 特征,以保证特征去噪后具有更好的时间连续性。

1.2 Related Work

这里主要说了当前T2V模型(例如CogView)需要大量的视频进行训练,消耗较大,

而本方法仅通过一个<文本-视频>对, 即可将T2I模型转换为T2V模型

1.3 方法概述 (High-level Overview)

通过一个video的frams,将 T2I 模型微调成 T2V模型, 该模型记住了这个video的动作,可以生成类似动作

如图所示:

在这里插入图片描述

2. 方法介绍—微调

需要将video的frames 通过 DDIM Inversion (加噪)为噪为 LDM 的 噪声特征,连同 text 送入 SD.

在SD的U-net中,更新self-attention (ST-Attn),cross-attention (Cross-Attn), 以及一个新增的处理视频的 temporal self-attention (T-Attn)

2.1 处理self-attention (AT1):

将第一帧和后续帧concat,作为attention的可学习矩阵 W V W^V WV, W K W^K WK 完成 value 和 key的运算。

另外将后一帧作为可学习矩阵 W Q W^Q WQ, 作为self-attention的 query

具体如图所示:

在这里插入图片描述

这里,为了减少计算量,以第一帧为锚定,计算v和k,q仅包括第i帧, 且仅学习 W Q W^Q WQ

KaTeX parse error: Unexpected character: '' at position 26: …v_i} , K = W^K ̲[z_{v_1} , z_{v…

在源代码中,该模块改自diffusers.models.attention的CrossAttention

2.2 处理cross-attention (AT2)

这是文生图部分,即将文本通过clip找到图像对应的特征,嵌入 W Q W^Q WQ 并更新该矩阵

在源代码中,该模块直接使用diffusers.models.attention的CrossAttention

2.3 处理 temporal self-attention (AT3)

这一步学习全局时序特征,将AT1学到的时序特征输入到AT2,学习文本到video特征,再将AT2 的输出送到AT3,

AT3,本质是一个self-attention,同时更新Q,K,V三个特征矩阵( W Q , W K , W V W^Q, W^K,W^V WQ,WK,WV)。

在源代码中,该模块直接使用diffusers.models.attention的CrossAttention

最后的方法整体结构如下(Pipeline):
在这里插入图片描述

原文部分摘录:

It is com posed of stacked 2D convolutional residual blocks and transformer blocks.

Each transformer block consists of a spatial self-attention layer, a cross-attention layer, and a feed-forward network (FFN).

The spatio temporal attention (ST-Attn) is designed to model temporal consistency by querying relevant positions in previous
frames. Therefore, we propose to fix parameters W^K and W^V , and only update W^Q in ST-Attn layers.

Moreover, we propose to refine the text-video alignment by updating the query projection in cross-attention (Cross-Attn).

In contrast, we finetune the entire temporal self-attention (T-Attn) layers as they are newly added.

3.实验

3.1 微调设置

  • 设备 A100 GPU(这款GPU有40g和80g两个显存版本,估计是80g)
  • 训练输入: 32 frames (512x512),
  • 训练时间:500 次迭代, 约10min
  • 学习率:3e-5-
  • Batch_size: 1
  • 推理方法: DDIM, classifier-free guidance, 约1min

3.2 推理设置(评估)

从 DAVIS 数据集 中选择了 42 个视频,涵盖了动物、车辆和人类等一系列类别。选定的视频项目列在表 2 中。

为了获取视频片段,使用 BLIP-2 进行自动字幕 生成(off-the-shelf)。

为每个视频手动设计三个编辑提示,总共得到 140 个编辑提示。这些编辑提示包括对象编辑、背景更改和风格转换。

在这里插入图片描述

4. 实验结果

4.1 消融实验

验证 AT1 与 DDIM Inversion的效果:

在这里插入图片描述

4.2 方法局限

图像中存在多个目标主体时,数量会出错,这个也是SD常见问题:
在这里插入图片描述

4.3 baseline比较

在这里插入图片描述

4.4 方法整体效果

  • 改变text:

在这里插入图片描述

  • 与conditional T2I model可以无缝结合(T2I-Adapter 与 ControlNet, 最后一行) :

在这里插入图片描述

Ref:

  • Jonathan Ho, Tim Salimans, Alexey Gritsenko, William Chan, Mohammad Norouzi, and David J Fleet. Video diffusion models. arXiv:2204.03458, 2022
  • https://github.com/showlab/Tune-A-Video
  • https://github.com/bryandlee/Tune-A-Video
  • https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/attention.py#L272
以下是用ffmpeg编码h264并用srs-librtmp推流的C++例子: ```cpp #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <fcntl.h> #include <errno.h> #include <sys/types.h> #include <sys/stat.h> #include <sys/time.h> #include <time.h> #include <stdint.h> #include <math.h> #include <signal.h> extern "C" { #include <libavcodec/avcodec.h> #include <libavformat/avformat.h> #include <libavutil/opt.h> #include <libavutil/imgutils.h> #include <libavutil/mathematics.h> #include <libavutil/time.h> #include <libswscale/swscale.h> #include <libswresample/swresample.h> #include <librtmp/rtmp.h> #include <librtmp/log.h> } #define W 640 #define H 480 #define FPS 30 #define BITRATE 500000 AVFormatContext *pFormatCtx; AVOutputFormat *pOutputFmt; AVStream *pVideoStream; AVCodecContext *pVideoCodecCtx; AVCodec *pVideoCodec; AVFrame *pVideoFrame; AVPacket videoPkt; uint8_t *videoBuf; SwsContext *pImgConvertCtx; RTMP *pRtmp; RTMPPacket rtmpPkt; int64_t pts = 0; void signal_handler(int signo) { if (signo == SIGINT) { printf("Got SIGINT signal, exiting...\n"); exit(1); } } void init_video_codec() { pVideoCodec = avcodec_find_encoder(AV_CODEC_ID_H264); if (!pVideoCodec) { fprintf(stderr, "Failed to find H.264 codec\n"); exit(1); } pVideoCodecCtx = avcodec_alloc_context3(pVideoCodec); if (!pVideoCodecCtx) { fprintf(stderr, "Failed to allocate H.264 codec context\n"); exit(1); } pVideoCodecCtx->bit_rate = BITRATE; pVideoCodecCtx->width = W; pVideoCodecCtx->height = H; pVideoCodecCtx->time_base.num = 1; pVideoCodecCtx->time_base.den = FPS; pVideoCodecCtx->gop_size = FPS * 2; pVideoCodecCtx->max_b_frames = 1; pVideoCodecCtx->pix_fmt = AV_PIX_FMT_YUV420P; if (pOutputFmt->flags & AVFMT_GLOBALHEADER) pVideoCodecCtx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; av_opt_set(pVideoCodecCtx->priv_data, "preset", "ultrafast", 0); av_opt_set(pVideoCodecCtx->priv_data, "tune", "zerolatency", 0); if (avcodec_open2(pVideoCodecCtx, pVideoCodec, NULL) < 0) { fprintf(stderr, "Failed to open H.264 codec\n"); exit(1); } pVideoFrame = av_frame_alloc(); if (!pVideoFrame) { fprintf(stderr, "Failed to allocate video frame\n"); exit(1); } pVideoFrame->format = pVideoCodecCtx->pix_fmt; pVideoFrame->width = pVideoCodecCtx->width; pVideoFrame->height = pVideoCodecCtx->height; if (av_frame_get_buffer(pVideoFrame, 32) < 0) { fprintf(stderr, "Failed to allocate video frame buffer\n"); exit(1); } videoBuf = (uint8_t*)av_malloc(av_image_get_buffer_size(pVideoCodecCtx->pix_fmt, pVideoCodecCtx->width, pVideoCodecCtx->height, 1)); if (!videoBuf) { fprintf(stderr, "Failed to allocate video buffer\n"); exit(1); } av_image_fill_arrays(pVideoFrame->data, pVideoFrame->linesize, videoBuf, pVideoCodecCtx->pix_fmt, pVideoCodecCtx->width, pVideoCodecCtx->height, 1); } void init_sws_context() { pImgConvertCtx = sws_getContext(pVideoCodecCtx->width, pVideoCodecCtx->height, AV_PIX_FMT_BGR24, pVideoCodecCtx->width, pVideoCodecCtx->height, pVideoCodecCtx->pix_fmt, SWS_BICUBIC, NULL, NULL, NULL); if (!pImgConvertCtx) { fprintf(stderr, "Failed to create SwsContext\n"); exit(1); } } void init_output() { avformat_alloc_output_context2(&pFormatCtx, NULL, "flv", NULL); if (!pFormatCtx) { fprintf(stderr, "Failed to allocate output context\n"); exit(1); } pOutputFmt = pFormatCtx->oformat; if (pOutputFmt->video_codec == AV_CODEC_ID_NONE) { fprintf(stderr, "Failed to find suitable video codec\n"); exit(1); } init_video_codec(); init_sws_context(); pVideoStream = avformat_new_stream(pFormatCtx, pVideoCodec); if (!pVideoStream) { fprintf(stderr, "Failed to allocate video stream\n"); exit(1); } pVideoStream->id = 0; pVideoStream->time_base.num = 1; pVideoStream->time_base.den = FPS; pVideoStream->codecpar->codec_id = pVideoCodec->id; pVideoStream->codecpar->codec_type = AVMEDIA_TYPE_VIDEO; pVideoStream->codecpar->width = W; pVideoStream->codecpar->height = H; pVideoStream->codecpar->format = pVideoCodecCtx->pix_fmt; if (avformat_write_header(pFormatCtx, NULL) < 0) { fprintf(stderr, "Failed to write header\n"); exit(1); } } void init_rtmp(const char *url) { RTMP_LogSetLevel(RTMP_LOGDEBUG); pRtmp = RTMP_Alloc(); if (!pRtmp) { fprintf(stderr, "Failed to allocate RTMP object\n"); exit(1); } if (!RTMP_Init(pRtmp)) { fprintf(stderr, "Failed to initialize RTMP object\n"); exit(1); } if (!RTMP_SetupURL(pRtmp, (char*)url)) { fprintf(stderr, "Failed to set RTMP URL\n"); exit(1); } RTMP_EnableWrite(pRtmp); if (!RTMP_Connect(pRtmp, NULL)) { fprintf(stderr, "Failed to connect to RTMP server\n"); exit(1); } if (!RTMP_ConnectStream(pRtmp, 0)) { fprintf(stderr, "Failed to connect to RTMP stream\n"); exit(1); } } void cleanup() { av_write_trailer(pFormatCtx); avcodec_free_context(&pVideoCodecCtx); av_frame_free(&pVideoFrame); av_free(videoBuf); sws_freeContext(pImgConvertCtx); RTMP_Close(pRtmp); RTMP_Free(pRtmp); avformat_free_context(pFormatCtx); } void encode_frame() { AVFrame *pFrameBGR24 = av_frame_alloc(); if (!pFrameBGR24) { fprintf(stderr, "Failed to allocate BGR24 frame\n"); exit(1); } pFrameBGR24->format = AV_PIX_FMT_BGR24; pFrameBGR24->width = W; pFrameBGR24->height = H; if (av_frame_get_buffer(pFrameBGR24, 32) < 0) { fprintf(stderr, "Failed to allocate BGR24 frame buffer\n"); exit(1); } // generate test pattern for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) { uint8_t r, g, b; if (x < W / 2) { r = 255 * x * 2 / W; g = 255 * y / H; b = 255 - 255 * x * 2 / W; } else { r = 255 - 255 * (x - W / 2) * 2 / W; g = 255 - 255 * y / H; b = 255 * (x - W / 2) * 2 / W; } pFrameBGR24->data[0][y * pFrameBGR24->linesize[0] + x * 3] = b; pFrameBGR24->data[0][y * pFrameBGR24->linesize[0] + x * 3 + 1] = g; pFrameBGR24->data[0][y * pFrameBGR24->linesize[0] + x * 3 + 2] = r; } } sws_scale(pImgConvertCtx, pFrameBGR24->data, pFrameBGR24->linesize, 0, H, pVideoFrame->data, pVideoFrame->linesize); av_init_packet(&videoPkt); videoPkt.data = NULL; videoPkt.size = 0; pVideoFrame->pts = pts++; int ret = avcodec_send_frame(pVideoCodecCtx, pVideoFrame); if (ret < 0) { fprintf(stderr, "Failed to send video frame: %s\n", av_err2str(ret)); exit(1); } while (ret >= 0) { ret = avcodec_receive_packet(pVideoCodecCtx, &videoPkt); if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break; else if (ret < 0) { fprintf(stderr, "Failed to receive video packet: %s\n", av_err2str(ret)); exit(1); } rtmpPkt.m_nChannel = 0x04; rtmpPkt.m_headerType = RTMP_PACKET_TYPE_VIDEO; rtmpPkt.m_nTimeStamp = videoPkt.pts * 1000 / FPS; rtmpPkt.m_nBodySize = videoPkt.size; rtmpPkt.m_nInfoField2 = pRtmp->m_stream_id; rtmpPkt.m_body = videoPkt.data; if (!RTMP_IsConnected(pRtmp)) break; if (!RTMP_SendPacket(pRtmp, &rtmpPkt, TRUE)) { fprintf(stderr, "Failed to send RTMP packet\n"); exit(1); } av_packet_unref(&videoPkt); } av_frame_free(&pFrameBGR24); } int main(int argc, char **argv) { if (argc < 2) { printf("Usage: %s <rtmp url>\n", argv[0]); exit(1); } init_output(); init_rtmp(argv[1]); signal(SIGINT, signal_handler); while (1) { encode_frame(); usleep(1000000 / FPS); } cleanup(); return 0; } ``` 这个例子会生成一个测试图案并用ffmpeg编码成H.264,然后用srs-librtmp推流到指定的RTMP服务器上。你需要将代码中的推流地址替换为你自己的RTMP服务器地址。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值