Make-A-Video - Pytorch (wip)

Implementation of Make-A-Video, the new SOTA text-to-video generator from Meta AI, in Pytorch. They combine pseudo-3d convolutions (axial convolutions) and temporal attention, and show much better temporal fusion.

The pseudo-3d convolution is not a new concept. It has been explored before in other contexts, for example for protein contact prediction as "dimensional hybrid residual networks".

The gist of the paper comes down to: take a SOTA text-to-image model (here they use DALL-E2, but the same learnings apply just as well to Imagen), make a few minor modifications for attention across time and other ways to save on compute cost, do frame interpolation correctly, and you get a great video model.
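The pseudo-3d (axial) convolution idea can be sketched in plain PyTorch as a 2d convolution over space followed by a 1d convolution over time. This is a minimal illustration of the technique only, not this repository's actual implementation; the class name `Pseudo3dConvSketch` is invented here, and the temporal convolution is initialized to the identity in the spirit of the paper.

```python
import torch
from torch import nn

class Pseudo3dConvSketch(nn.Module):
    # hypothetical minimal sketch of a pseudo-3d (axial) convolution:
    # a 2d conv over (height, width) followed by a 1d conv over frames
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        pad = kernel_size // 2
        self.spatial = nn.Conv2d(dim, dim, kernel_size, padding = pad)
        self.temporal = nn.Conv1d(dim, dim, kernel_size, padding = pad)
        # initialize the temporal conv to the identity, so pretrained
        # 2d (spatial) behavior is preserved at the start of video training
        nn.init.dirac_(self.temporal.weight)
        nn.init.zeros_(self.temporal.bias)

    def forward(self, x):
        # x: (batch, features, frames, height, width)
        b, c, f, h, w = x.shape
        # fold frames into the batch and convolve spatially
        x = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
        x = self.spatial(x)
        # fold spatial positions into the batch and convolve over time
        x = x.reshape(b, f, c, h, w).permute(0, 3, 4, 2, 1).reshape(b * h * w, c, f)
        x = self.temporal(x)
        return x.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)
```

Compared to a full 3d convolution, this factorization touches each axis separately, which is considerably cheaper while still mixing information across time.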
Appreciation

- Stability.ai for the generous sponsorship of cutting-edge artificial intelligence research
- Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
Installation

```bash
$ pip install make-a-video-pytorch
```
Usage

Passing in video features

```python
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
```
Passing in images (say if one were to pretrain on images first), both the temporal convolution and attention will be automatically skipped. In other words, you can use these modules directly in a 2d Unet and then port them over to a 3d Unet once that stage of training is done. The temporal modules are initialized to output the identity, as the paper did.
```python
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)

conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
```
You can also control the two modules so that, even when fed video features, training is done only spatially.
```python
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention

conv = PseudoConv3d(
    dim = 256,
    kernel_size = 3
)

attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

# below, it will not train across time

conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
```
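The behavior above — skipping the temporal pass for image inputs or when time is disabled — can be sketched as follows. This is a hypothetical illustration in plain PyTorch, not the package's code; the `MaybeTemporal` class is invented here, but it mirrors the dispatch on input rank and the `enable_time` flag, with an identity-initialized temporal op so it is harmless right after porting from 2d.

```python
import torch
from torch import nn

class MaybeTemporal(nn.Module):
    # hypothetical sketch: wraps a temporal op and applies it only when the
    # input has a frame dimension (5-dim) and time is enabled; images (4-dim)
    # pass through untouched
    def __init__(self, dim):
        super().__init__()
        self.temporal = nn.Conv1d(dim, dim, 3, padding = 1)
        # identity at initialization, as the paper does
        nn.init.dirac_(self.temporal.weight)
        nn.init.zeros_(self.temporal.bias)

    def forward(self, x, enable_time = True):
        is_video = x.ndim == 5
        if not (is_video and enable_time):
            return x  # image input or time disabled: skip the temporal pass
        b, c, f, h, w = x.shape
        x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, f)
        x = self.temporal(x)
        return x.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)
```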
SpaceTimeUnet

The full SpaceTimeUnet is agnostic to image or video training; even when video is passed in, time can be ignored.
```python
import torch
from make_a_video_pytorch import SpaceTimeUnet

unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, False, True),
    condition_on_timestep = False
).cuda()

# train on images

images = torch.randn(1, 3, 128, 128).cuda()
images_out = unet(images)

assert images.shape == images_out.shape

# then train on videos

video = torch.randn(1, 3, 16, 128, 128).cuda()
video_out = unet(video)

assert video_out.shape == video.shape

# or even treat your videos as images

video_as_images_out = unet(video, enable_time = False)
```
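The two-phase schedule this enables — pretrain on images, then continue on videos with the same weights — can be sketched with a stand-in model. `StandInUnet` below is a made-up placeholder that merely preserves shape for both 4-dim and 5-dim batches; it is not SpaceTimeUnet, and the denoising-style MSE objective is only illustrative.

```python
import torch
from torch import nn
from torch.optim import Adam

class StandInUnet(nn.Module):
    # hypothetical stand-in for a shape-preserving Unet that accepts
    # image (b, c, h, w) or video (b, c, f, h, w) batches
    def __init__(self, channels = 3):
        super().__init__()
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        if x.ndim == 4:
            return self.proj(x)
        b, c, f, h, w = x.shape
        x = x.transpose(1, 2).reshape(b * f, c, h, w)  # fold frames into batch
        x = self.proj(x)
        return x.reshape(b, f, c, h, w).transpose(1, 2)

model = StandInUnet()
opt = Adam(model.parameters(), lr = 1e-3)

# phase 1: train on image batches
images = torch.randn(2, 3, 16, 16)
loss = (model(images) - images).pow(2).mean()
loss.backward(); opt.step(); opt.zero_grad()

# phase 2: continue on video batches with the very same weights
video = torch.randn(2, 3, 4, 16, 16)
loss = (model(video) - video).pow(2).mean()
loss.backward(); opt.step(); opt.zero_grad()
```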
Todo

- attend to the best positional embeddings research has to offer
- improve the attention
- make sure dalle2-pytorch can accept SpaceTimeUnet for training