1. Data slicing and music feature generation
(1) Input data: music audio (.wav) and the corresponding SMPL-format motion data (.pkl)
(2) Split the data into training / test / evaluation sets according to the txt files under splits
(3) Augment and slice the data with stride=0.5 and length=5, i.e. a 5 s segment starting every 0.5 s
(4) Use Jukebox [a music generation model built from a multi-scale VQ-VAE; at generation time it takes random noise plus artist/genre information and decodes the levels cascaded from top to bottom] to extract music features and save them as .npy files of shape (seconds * frames per second, 4800)
(This method: read the music -> VQ-VAE encoder -> add the default conditioning inputs -> run the top-down cascade and take the features before the decoder)
Notes when using your own data (change these if the motion data is not 60 fps):
1. In slice.py, lines 38-39, the motion FPS defaults to 60; change it to your own fps.
2. Jukebox feature extraction downsamples the music to 30 FPS (see the sketch below the notes).
3. In dataset/dance_dataset.py, self.raw_fps = 60 and self.data_fps = 30 need to be adjusted; by default the motion data is subsampled with a stride of 2, so if your data is already 30 fps set self.raw_fps = 30.
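For reference, Jukebox feature extraction roughly follows the jukemirlib call that appears (commented out) in the inference script further below; a minimal sketch, where the layer index 66 and the 30 FPS downsampling come from that snippet and the output path is illustrative:

import numpy as np
import jukemirlib

def extract_jukebox_features(wav_path, out_path):
    # load the audio and pull Jukebox layer-66 activations, downsampled to 30 FPS
    audio = jukemirlib.load_audio(wav_path)
    reps = jukemirlib.extract(audio, layers=[66], downsample_target_rate=30)[66]
    # for a 5 s slice this gives roughly (150, 4800)
    np.save(out_path, reps)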
create_dataset.py
def create_dataset(opt):
# split the data according to the splits files
print("Creating train / test split")
# rename/move the files into the train and test folders
split_data(opt.dataset_folder)
# slice motions/music into sliding windows to create training dataset
# slice the train / test data into windows: stride = 0.5 s, length = 5 s
print("Slicing train data")
slice_aistpp(f"train/motions", f"train/wavs")
print("Slicing test data")
slice_aistpp(f"test/motions", f"test/wavs")
# process dataset to extract audio features
# audio feature extraction (baseline features / jukemirlib)
if opt.extract_baseline:
print("Extracting baseline features")
baseline_extract("train/wavs_sliced", "train/baseline_feats")
baseline_extract("test/wavs_sliced", "test/baseline_feats")
if opt.extract_jukebox:
print("Extracting jukebox features")
jukebox_extract("train/wavs_sliced", "train/jukebox_feats")
jukebox_extract("test/wavs_sliced", "test/jukebox_feats")
2. Data preprocessing
(1) Read the .npy music features (n*4800) and the .pkl SMPL translation/rotation data, then subsample the SMPL data according to the music fps / data_fps.
(2) For the SMPL root translation (n*3) and joint rotations (n*24*3):
1) Rotate from y-up to z-up.
2) Convert the per-joint axis-angle rotations to the 6D rotation representation with ax_to_6v (n*24*6); see the sketch below this list.
3) Run forward kinematics to get 3D joint positions and detect whether the foot joints are in (near-)contact with the ground (n*4).
4) Concatenate the contacts from 3), the root translation, and the 6D rotations from 2), flatten to (n*151), and standardize with the normalizer.
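The 6D rotation representation keeps two rows of the rotation matrix, which is continuous and easy to regress. A minimal sketch of what ax_to_6v presumably does, using pytorch3d (an assumption about the helper, not the repo's exact code):

import torch
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_rotation_6d

def ax_to_6v_sketch(aa: torch.Tensor) -> torch.Tensor:
    # aa: (..., 24, 3) axis-angle per joint -> (..., 24, 6)
    # axis-angle -> 3x3 rotation matrix -> keep two rows = 6 numbers per joint
    return matrix_to_rotation_6d(axis_angle_to_matrix(aa))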
dance_dataset.py
def load_aistpp(self):
# open data path
split_data_path = os.path.join(
self.data_path, "train" if self.train else "test"
)
# Structure:
# data
# |- train
# | |- motion_sliced
# | |- wav_sliced
# | |- baseline_features
# | |- jukebox_features
# | |- motions
# | |- wavs
motion_path = os.path.join(split_data_path, "motions_sliced")
sound_path = os.path.join(split_data_path, f"{self.feature_type}_feats")
wav_path = os.path.join(split_data_path, f"wavs_sliced")
# sort motions and sounds
motions = sorted(glob.glob(os.path.join(motion_path, "*.pkl")))
features = sorted(glob.glob(os.path.join(sound_path, "*.npy")))
wavs = sorted(glob.glob(os.path.join(wav_path, "*.wav")))
# stack the motions and features together
all_pos = []
all_q = []
all_names = []
all_wavs = []
assert len(motions) == len(features)
w = []
for motion, feature, wav in zip(motions, features, wavs):
# make sure name is matching
m_name = os.path.splitext(os.path.basename(motion))[0]
f_name = os.path.splitext(os.path.basename(feature))[0]
w_name = os.path.splitext(os.path.basename(wav))[0]
assert m_name == f_name == w_name, str((motion, feature, wav))
# load motion
data = pickle.load(open(motion, "rb"))
pos = data["pos"]
q = data["q"]
local_q = torch.Tensor(q)
if torch.isnan(local_q).any():
continue
all_pos.append(pos)
all_q.append(q)
all_names.append(feature)
all_wavs.append(wav)
all_pos = np.array(all_pos) # N x seq x 3
all_q = np.array(all_q) # N x seq x (joint * 3)
# downsample the motions to the data fps
#
all_pos = all_pos[:, :: self.data_stride, :]
all_q = all_q[:, :: self.data_stride, :]
data = {"pos": all_pos, "q": all_q, "filenames": all_names, "wavs": all_wavs}
return data
def process_dataset(self, root_pos, local_q): # 3, 72
# FK skeleton
smpl = SMPLSkeleton()
# to Tensor
root_pos = torch.Tensor(root_pos)
local_q = torch.Tensor(local_q)
assert not torch.isnan(root_pos).any()
assert not torch.isnan(local_q).any()
# to ax
bs, sq, c = local_q.shape
local_q = local_q.reshape((bs, sq, -1, 3)) # b,l,24,3
# AISTPP dataset comes y-up - rotate to z-up to standardize against the pretrain dataset
# AIST++ is y-up while the pretraining data is z-up, so rotate the root joint from y-up to z-up
root_q = local_q[:, :, :1, :] # sequence x 1 x 3
root_q_quat = axis_angle_to_quaternion(root_q)
rotation = torch.Tensor(
[0.7071068, 0.7071068, 0, 0]
) # 90 degrees about the x axis
root_q_quat = quaternion_multiply(rotation, root_q_quat)
root_q = quaternion_to_axis_angle(root_q_quat)
local_q[:, :, :1, :] = root_q
#
# don't forget to rotate the root position too 😩
# apply the same rotation to the root translation
pos_rotation = RotateAxisAngle(90, axis="X", degrees=True)
root_pos = pos_rotation.transform_points(
root_pos
) # basically (y, z) -> (-z, y), expressed as a rotation for readability
# do FK
# forward kinematics to absolute joint positions: 24 SMPL joints, 3 = (x, y, z)
positions = smpl.forward(local_q, root_pos) # batch x sequence x 24 x 3
feet = positions[:, :, (7, 8, 10, 11)]
feetv = torch.zeros(feet.shape[:3])
feetv[:, :-1] = (feet[:, 1:] - feet[:, :-1]).norm(dim=-1)
contacts = (feetv < 0.01).to(local_q) # cast to right dtype
# to 6d
# ax_to_6v converts the per-joint axis-angle rotations to the 6D rotation representation, which is continuous and avoids the discontinuities of axis-angle / quaternion parameterizations
local_q = ax_to_6v(local_q)
# now, flatten everything into: batch x sequence x [...]
# 4 + 3 + 24*6
l = [contacts, root_pos, local_q]
# resize b * l * 151
global_pose_vec_input = vectorize_many(l).float().detach()
# normalize the data. Both train and test need the same normalizer.
if self.train:
self.normalizer = Normalizer(global_pose_vec_input)
else:
assert self.normalizer is not None
global_pose_vec_input = self.normalizer.normalize(global_pose_vec_input)
assert not torch.isnan(global_pose_vec_input).any()
data_name = "Train" if self.train else "Test"
# cut the dataset
if self.data_len > 0:
global_pose_vec_input = global_pose_vec_input[: self.data_len]
global_pose_vec_input = global_pose_vec_input
print(f"{data_name} Dataset Motion Features Dim: {global_pose_vec_input.shape}")
return global_pose_vec_input
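The Normalizer is fit on the training split only and reused for the test split. Presumably it is a simple per-dimension standardization; a minimal sketch under that assumption (the class below is illustrative, not the repo's implementation):

import torch

class SimpleNormalizer:
    # per-feature standardization over the flattened (batch * seq, 151) data
    def __init__(self, data: torch.Tensor, eps: float = 1e-8):
        flat = data.reshape(-1, data.shape[-1])
        self.mean = flat.mean(dim=0)
        self.std = flat.std(dim=0) + eps

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean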
(3) Training: build the dataloaders, train, compute the losses, backpropagate
EDGE.py
def train_loop(self, opt):
# load datasets
# load the cached tensor datasets if they exist
train_tensor_dataset_path = os.path.join(
opt.processed_data_dir, f"train_tensor_dataset.pkl"
)
test_tensor_dataset_path = os.path.join(
opt.processed_data_dir, f"test_tensor_dataset.pkl"
)
if (
not opt.no_cache
and os.path.isfile(train_tensor_dataset_path)
and os.path.isfile(test_tensor_dataset_path)
):
train_dataset = pickle.load(open(train_tensor_dataset_path, "rb"))
test_dataset = pickle.load(open(test_tensor_dataset_path, "rb"))
else:
train_dataset = AISTPPDataset(
data_path=opt.data_path,
backup_path=opt.processed_data_dir,
train=True,
force_reload=opt.force_reload,
)
test_dataset = AISTPPDataset(
data_path=opt.data_path,
backup_path=opt.processed_data_dir,
train=False,
normalizer=train_dataset.normalizer,
force_reload=opt.force_reload,
)
# cache the dataset in case
if self.accelerator.is_main_process:
pickle.dump(train_dataset, open(train_tensor_dataset_path, "wb"))
pickle.dump(test_dataset, open(test_tensor_dataset_path, "wb"))
# set normalizer
self.normalizer = test_dataset.normalizer
# data loaders
# decide number of workers based on cpu count
num_cpus = multiprocessing.cpu_count()
train_data_loader = DataLoader(
train_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=min(int(num_cpus * 0.75), 32),
pin_memory=True,
drop_last=True,
)
test_data_loader = DataLoader(
test_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=2,
pin_memory=True,
drop_last=True,
)
train_data_loader = self.accelerator.prepare(train_data_loader)
# boot up multi-gpu training. test dataloader is only on main process
# only the main process shows the progress bar
load_loop = (
partial(tqdm, position=1, desc="Batch")
if self.accelerator.is_main_process
else lambda x: x
)
#
if self.accelerator.is_main_process:
save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
opt.exp_name = save_dir.split("/")[-1]
# initialize wandb
# wandb.init(project=opt.wandb_pj_name, name=opt.exp_name)
# create the directory for saving weights
save_dir = Path(save_dir)
wdir = save_dir / "weights"
wdir.mkdir(parents=True, exist_ok=True)
self.accelerator.wait_for_everyone()
# epoch loop
for epoch in range(1, opt.epochs + 1):
avg_loss = 0
avg_vloss = 0
avg_fkloss = 0
avg_footloss = 0
# train
self.train()
for step, (x, cond, filename, wavnames) in enumerate(load_loop(train_data_loader)):
# x [32, 75, 151] motion, cond [32, 150, 4800] music features, plus the .npy and .wav filenames
total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion(
x, cond, t_override=None
)
self.optim.zero_grad()
self.accelerator.backward(total_loss)
self.optim.step()
# ema update and train loss update only on main
if self.accelerator.is_main_process:
avg_loss += loss.detach().cpu().numpy()
avg_vloss += v_loss.detach().cpu().numpy()
avg_fkloss += fk_loss.detach().cpu().numpy()
avg_footloss += foot_loss.detach().cpu().numpy()
if step % opt.ema_interval == 0:
self.diffusion.ema.update_model_average(
self.diffusion.master_model, self.diffusion.model
)
# Save model
if (epoch % opt.save_interval) == 0:
# everyone waits here for the val loop to finish ( don't start next train epoch early)
self.accelerator.wait_for_everyone()
# save only if on main thread
if self.accelerator.is_main_process:
self.eval()
# log
avg_loss /= len(train_data_loader)
avg_vloss /= len(train_data_loader)
avg_fkloss /= len(train_data_loader)
avg_footloss /= len(train_data_loader)
log_dict = {
"Train Loss": avg_loss,
"V Loss": avg_vloss,
"FK Loss": avg_fkloss,
"Foot Loss": avg_footloss,
}
wandb.log(log_dict)
ckpt = {
"ema_state_dict": self.diffusion.master_model.state_dict(),
"model_state_dict": self.accelerator.unwrap_model(
self.model
).state_dict(),
"optimizer_state_dict": self.optim.state_dict(),
"normalizer": self.normalizer,
}
torch.save(ckpt, os.path.join(wdir, f"train-{epoch}.pt"))
# generate a sample
render_count = 2
shape = (render_count, self.horizon, self.repr_dim)
print("Generating Sample")
# draw a music from the test dataset
(x, cond, filename, wavnames) = next(iter(test_data_loader))
cond = cond.to(self.accelerator.device)
self.diffusion.render_sample(
shape,
cond[:render_count],
self.normalizer,
epoch,
os.path.join(opt.render_dir, "train_" + opt.exp_name),
name=wavnames[:render_count],
sound=True,
)
print(f"[MODEL SAVED at Epoch {epoch}]")
if self.accelerator.is_main_process:
wandb.run.finish()
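update_model_average keeps an exponential moving average (EMA) of the online model's weights in master_model; those EMA weights are what get saved as ema_state_dict. A minimal sketch of the usual EMA update (the decay value is illustrative):

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.995):
    # ema_param <- decay * ema_param + (1 - decay) * online_param
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)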
diffusion.py (training part)
def p_losses(self, x_start, cond, t):
# b,l,151 , b,l,4800 , b,(0-1000)
# sample Gaussian noise with the same shape as x_start
noise = torch.randn_like(x_start)
# forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
# reconstruct
# predict the denoised motion from the noisy input, the music condition, and the timestep
x_recon = self.model(x_noisy, cond, t, cond_drop_prob=self.cond_drop_prob)
assert noise.shape == x_recon.shape
model_out = x_recon
if self.predict_epsilon:
target = noise
else:
target = x_start
# full reconstruction loss
# element-wise reconstruction loss (MSE by default)
# reduce to a per-sample mean, then apply the p2 loss weighting
loss = self.loss_fn(model_out, target, reduction="none")
loss = reduce(loss, "b ... -> b (...)", "mean")
loss = loss * extract(self.p2_loss_weight, t, loss.shape)
# split off contact from the rest
# split the 4 contact channels off the prediction and the target
model_contact, model_out = torch.split(
model_out, (4, model_out.shape[2] - 4), dim=2
)
target_contact, target = torch.split(target, (4, target.shape[2] - 4), dim=2)
# velocity loss
# velocity loss on frame-to-frame differences
target_v = target[:, 1:] - target[:, :-1]
model_out_v = model_out[:, 1:] - model_out[:, :-1]
v_loss = self.loss_fn(model_out_v, target_v, reduction="none")
v_loss = reduce(v_loss, "b ... -> b (...)", "mean")
v_loss = v_loss * extract(self.p2_loss_weight, t, v_loss.shape)
# FK loss
b, s, c = model_out.shape
# unnormalize
# model_out = self.normalizer.unnormalize(model_out)
# target = self.normalizer.unnormalize(target)
# X, Q
model_x = model_out[:, :, :3]
model_q = ax_from_6v(model_out[:, :, 3:].reshape(b, s, -1, 6))
target_x = target[:, :, :3]
target_q = ax_from_6v(target[:, :, 3:].reshape(b, s, -1, 6))
# perform FK
# run FK on both and take the MSE between predicted and target 3D joint positions
model_xp = self.smpl.forward(model_q, model_x)
target_xp = self.smpl.forward(target_q, target_x)
fk_loss = self.loss_fn(model_xp, target_xp, reduction="none")
fk_loss = reduce(fk_loss, "b ... -> b (...)", "mean")
fk_loss = fk_loss * extract(self.p2_loss_weight, t, fk_loss.shape)
# foot skate loss
# foot contact loss to suppress foot sliding
foot_idx = [7, 8, 10, 11]
# find static indices consistent with model's own predictions
static_idx = model_contact > 0.95 # N x S x 4
model_feet = model_xp[:, :, foot_idx] # foot positions (N, S, 4, 3)
model_foot_v = torch.zeros_like(model_feet)
model_foot_v[:, :-1] = (
model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :]
) # (N, S-1, 4, 3)
model_foot_v[~static_idx] = 0
# the target velocity is zero: any non-zero foot velocity while predicted to be in contact counts as sliding
foot_loss = self.loss_fn(
model_foot_v, torch.zeros_like(model_foot_v), reduction="none"
)
foot_loss = reduce(foot_loss, "b ... -> b (...)", "mean")
# weighted sum of the individual losses
losses = (
0.636 * loss.mean(),
2.964 * v_loss.mean(),
0.646 * fk_loss.mean(),
10.942 * foot_loss.mean(),
)
return sum(losses), losses
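q_sample above implements the standard DDPM forward process. A minimal sketch under the usual buffer names (sqrt_alphas_cumprod etc. are assumptions about the precomputed schedule, not necessarily the repo's attribute names):

import torch

def q_sample_sketch(x_start, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # with the per-timestep scalars gathered at t and broadcast over (batch, seq, 151)
    a = sqrt_alphas_cumprod[t].view(-1, 1, 1)
    b = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
    return a * x_start + b * noise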
model.py — the self.model() forward pass used during training
def forward(
self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
):
batch_size, device = x.shape[0], x.device
# project to latent space
# linear projection of the motion features: 151 -> 512
x = self.input_projection(x)
# add the positional embeddings of the input sequence to provide temporal information
# x + positional encoding, e.g. [8, 150, 512]
# (noted as not actually used in this configuration)
x = self.abs_pos_encoding(x)
# create music conditional embedding with conditional dropout
# boolean keep-mask of shape (batch_size,), True with probability 1 - cond_drop_prob
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
# (batch_size,1,1)
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
keep_mask_hidden = rearrange(keep_mask, "b -> b 1") # (batch_size,1)
# linear projection of the music features: 4800 -> 512
cond_tokens = self.cond_projection(cond_embed)
# encode tokens: add the positional encoding to the condition tokens
cond_tokens = self.abs_pos_encoding(cond_tokens)
# encode with a standard 2-layer transformer encoder
cond_tokens = self.cond_encoder(cond_tokens) # [8, 150, 512]
# learnable null-condition embedding
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
# where keep_mask_embed is True, keep cond_tokens; where it is False, replace them with null_cond_embed
# this implements conditional dropout: the music condition is randomly dropped with probability cond_drop_prob
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed) # [8, 150, 512]
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)  # mean pooling -> [8, 512]
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens) # [8, 512]
# create the diffusion timestep embedding, add the extra music projection
# embed the diffusion timestep `times` into a hidden representation
t_hidden = self.time_mlp(times)
# project to attention and FiLM conditioning
# project the timestep embedding to the FiLM time condition and to time tokens
t = self.to_time_cond(t_hidden)
t_tokens = self.to_time_tokens(t_hidden)
# FiLM conditioning
# add the pooled music condition to the time embedding for FiLM (Feature-wise Linear Modulation) conditioning
null_cond_hidden = self.null_cond_hidden.to(t.dtype)
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
t += cond_hidden  # time embedding + pooled music features
# cross-attention conditioning
# concatenate music tokens and time tokens as the cross-attention context
c = torch.cat((cond_tokens, t_tokens), dim=-2)
cond_tokens = self.norm_cond(c)  # layer normalization
# Pass through the transformer decoder
# attending to the conditional embedding
# the input x attends to both the music condition cond_tokens and the time condition t
# seqTransDecoder is a custom transformer decoder layer with self-attention, cross-attention (encoder-decoder attention), and a feed-forward network,
# plus FiLM layers that modulate each sub-block with the conditioning, along with layer normalization and dropout
output = self.seqTransDecoder(x, cond_tokens, t)
# 512 - out size
output = self.final_layer(output)
return output
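FiLM conditioning modulates intermediate features with a per-channel scale and shift computed from the conditioning vector. A minimal sketch of the idea (layer names and shapes are illustrative, not the repo's exact decoder code):

import torch
from torch import nn

class FiLM(nn.Module):
    # predict per-channel (scale, shift) from the conditioning vector and apply them
    def __init__(self, cond_dim: int, feat_dim: int):
        super().__init__()
        self.to_scale_shift = nn.Linear(cond_dim, feat_dim * 2)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, feat_dim), cond: (batch, cond_dim)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)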
3. Inference
(1) Define the output length (sample_size)
(2) Find the music files and extract features with Jukebox
(3) Initialize the EDGE model and run inference over each music feature sequence
import glob
import os
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory
import random
import jukemirlib
import numpy as np
import torch
from tqdm import tqdm
from args import parse_test_opt
from data.slice import slice_audio
from EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract
# sort filenames that look like songname_slice{number}.ext
key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
def stringintcmp_(a, b):
aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
ka, kb = key_func(a), key_func(b)
if aa < bb:
return -1
if aa > bb:
return 1
if ka < kb:
return -1
if ka > kb:
return 1
return 0
stringintkey = cmp_to_key(stringintcmp_)
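For example, this comparator sorts slice filenames numerically instead of lexically:

files = ["mix_slice10.wav", "mix_slice2.wav", "mix_slice1.wav"]
print(sorted(files, key=stringintkey))
# -> ['mix_slice1.wav', 'mix_slice2.wav', 'mix_slice10.wav']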
def test(opt):
# choose the feature extractor (jukebox or baseline)
feature_func = juke_extract if opt.feature_type == "jukebox" else baseline_extract
# define the output length
sample_length = opt.out_length
sample_size = int(sample_length / 2.5) - 1
temp_dir_list = []
all_cond = []
all_filenames = []
if opt.use_cached_features:
print("Using precomputed features")
# all subdirectories
dir_list = glob.glob(os.path.join(opt.feature_cache_dir, "*/"))
for dir in dir_list:
file_list = sorted(glob.glob(f"{dir}/*.wav"), key=stringintkey)
juke_file_list = sorted(glob.glob(f"{dir}/*.npy"), key=stringintkey)
assert len(file_list) == len(juke_file_list)
# random chunk after sanity check
rand_idx = random.randint(0, len(file_list) - sample_size)
file_list = file_list[rand_idx : rand_idx + sample_size]
juke_file_list = juke_file_list[rand_idx : rand_idx + sample_size]
cond_list = [np.load(x) for x in juke_file_list]
all_filenames.append(file_list)
all_cond.append(torch.from_numpy(np.array(cond_list)))
else:
print("Computing features for input music")
# loop over all .wav files in the input music directory
for wav_file in glob.glob(os.path.join(opt.music_dir, "*.wav")):
# create temp folder (or use the cache folder if specified)
if opt.cache_features:
songname = os.path.splitext(os.path.basename(wav_file))[0]
save_dir = os.path.join(opt.feature_cache_dir, songname)
Path(save_dir).mkdir(parents=True, exist_ok=True)
dirname = save_dir
else:
# create a temporary folder
temp_dir = TemporaryDirectory()
temp_dir_list.append(temp_dir)
dirname = temp_dir.name
# slice the audio file
print(f"Slicing {wav_file}")
# slice the music into the temp folder: stride 2.5 s, length 5 s
slice_audio(wav_file, 2.5, 5.0, dirname)
file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=stringintkey)
# randomly sample a chunk of length at most sample_size
rand_idx = random.randint(0, len(file_list) - sample_size)
cond_list = []
# generate juke representations
print(f"Computing features for {wav_file}")
for idx, file in enumerate(tqdm(file_list)):
# if not caching then only calculate for the interested range
if (not opt.cache_features) and (not (rand_idx <= idx < rand_idx + sample_size)):
continue
# audio = jukemirlib.load_audio(file)
# reps = jukemirlib.extract(
# audio, layers=[66], downsample_target_rate=30
# )[66]
# extract the music features
reps, _ = feature_func(file)
# save reps
if opt.cache_features:
featurename = os.path.splitext(file)[0] + ".npy"
np.save(featurename, reps)
# if in the random range, put it into the list of reps we want
# to actually use for generation
if rand_idx <= idx < rand_idx + sample_size:
cond_list.append(reps)
cond_list = torch.from_numpy(np.array(cond_list))
all_cond.append(cond_list)
all_filenames.append(file_list[rand_idx : rand_idx + sample_size])
# initialize the EDGE model
model = EDGE(opt.feature_type, opt.checkpoint)
model.eval()
# directory for optionally saving the dances for eval
fk_out = None
if opt.save_motions:
fk_out = opt.motion_save_dir
print("Generating dances")
for i in range(len(all_cond)):
data_tuple = None, all_cond[i], all_filenames[i]
# run inference
model.render_sample(
data_tuple, "test", opt.render_dir, render_count=-1, fk_out=fk_out, render=not opt.no_render
)
print("Done")
# clean up the temporary folders
torch.cuda.empty_cache()
for temp_dir in temp_dir_list:
temp_dir.cleanup()
if __name__ == "__main__":
opt = parse_test_opt()
test(opt)
diffusion.py
def render_sample(
self,
shape,
cond,
normalizer,
epoch,
render_out,
fk_out=None,
name=None,
sound=True,
mode="normal",
noise=None,
constraint=None,
sound_folder="ood_sliced",
start_point=None,
render=True
):
if isinstance(shape, tuple):
if mode == "inpaint":
func_class = self.inpaint_loop
elif mode == "normal":
func_class = self.ddim_sample
elif mode == "long":
func_class = self.long_ddim_sample
else:
assert False, "Unrecognized inference mode"
samples = (
func_class(
shape,
cond,
noise=noise,
constraint=constraint,
start_point=start_point,
)
.detach()
.cpu()
)  # run the sampler: generate motion from noise, conditioned on the music
else:
samples = shape
# normalizer
samples = normalizer.unnormalize(samples)
# split the result: 4 contact channels + 147 = 3 (root translation) + 24*6 (rotations)
if samples.shape[2] == 151:
sample_contact, samples = torch.split(
samples, (4, samples.shape[2] - 4), dim=2
)
else:
sample_contact = None
# do the FK all at once
b, s, c = samples.shape
pos = samples[:, :, :3].to(cond.device) # np.zeros((sample.shape[0], 3)) # (3)
q = samples[:, :, 3:].reshape(b, s, 24, 6) # (144)
# go 6d to ax 144 - 72
q = ax_from_6v(q).to(cond.device)
if mode == "long":
b, s, c1, c2 = q.shape
assert s % 2 == 0
half = s // 2
if b > 1:
# if long mode, stitch position using linear interp
fade_out = torch.ones((1, s, 1)).to(pos.device)
fade_in = torch.ones((1, s, 1)).to(pos.device)
fade_out[:, half:, :] = torch.linspace(1, 0, half)[None, :, None].to(
pos.device
)
fade_in[:, :half, :] = torch.linspace(0, 1, half)[None, :, None].to(
pos.device
)
pos[:-1] *= fade_out
pos[1:] *= fade_in
full_pos = torch.zeros((s + half * (b - 1), 3)).to(pos.device)
idx = 0
for pos_slice in pos:
full_pos[idx : idx + s] += pos_slice
idx += half
# stitch joint angles with slerp
slerp_weight = torch.linspace(0, 1, half)[None, :, None].to(pos.device)
left, right = q[:-1, half:], q[1:, :half]
# convert to quat
left, right = (
axis_angle_to_quaternion(left),
axis_angle_to_quaternion(right),
)
merged = quat_slerp(left, right, slerp_weight) # (b-1) x half x ...
# convert back
merged = quaternion_to_axis_angle(merged)
full_q = torch.zeros((s + half * (b - 1), c1, c2)).to(pos.device)
full_q[:half] += q[0, :half]
idx = half
for q_slice in merged:
full_q[idx : idx + half] += q_slice
idx += half
full_q[idx : idx + half] += q[-1, half:]
# unsqueeze for fk
full_pos = full_pos.unsqueeze(0)
full_q = full_q.unsqueeze(0)
else:
full_pos = pos
full_q = q
# FK and save the stitched long result as a pkl
full_pose = (
self.smpl.forward(full_q, full_pos).detach().cpu().numpy()
) # b, s, 24, 3
if fk_out is not None:
outname = f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.pkl'
Path(fk_out).mkdir(parents=True, exist_ok=True)
pickle.dump(
{
"smpl_poses": full_q.squeeze(0).reshape((-1, 72)).cpu().numpy(),
"smpl_trans": full_pos.squeeze(0).cpu().numpy(),
"full_pose": full_pose[0],
},
open(os.path.join(fk_out, outname), "wb"),
)
return
poses = self.smpl.forward(q, pos).detach().cpu().numpy()
# save each slice as a pkl
if fk_out is not None and mode != "long":
Path(fk_out).mkdir(parents=True, exist_ok=True)
for num, (qq, pos_, filename, pose) in enumerate(zip(q, pos, name, poses)):
path = os.path.normpath(filename)
pathparts = path.split(os.sep)
pathparts[-1] = pathparts[-1].replace("npy", "wav")
# path is like "data/train/features/name"
pathparts[2] = "wav_sliced"
audioname = os.path.join(*pathparts)
outname = f"{epoch}_{num}_{pathparts[-1][:-4]}.pkl"
pickle.dump(
{
"smpl_poses": qq.reshape((-1, 72)).cpu().numpy(),
"smpl_trans": pos_.cpu().numpy(),
"full_pose": pose,
},
open(f"{fk_out}/{outname}", "wb"),
)
# This implements reverse diffusion, as used by diffusion models (DDPM, DDIM, ...): data is first diffused by gradually adding noise, then recovered by the learned reverse process.
@torch.no_grad()
def ddim_sample(self, shape, cond, **kwargs):
# batch size, device, total diffusion timesteps, 50 DDIM sampling steps, eta = 1
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1
# build the sampling timesteps from -1 up to total_timesteps - 1
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = list(reversed(times.int().tolist()))  # reverse to go from T-1 down to -1
time_pairs = list(zip(times[:-1], times[1:])) # [(T[-1], T[-2]), (T[-2], T[-3]), ..., (1, 0), (0, -1)]
# initialize x with Gaussian noise
x = torch.randn(shape, device = device)
cond = cond.to(device)
x_start = None
# reverse diffusion loop
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
# (batch,) tensor filled with the current timestep
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
# predict the noise and x_start from the current x, the music condition, and the timestep
pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start = self.clip_denoised)
# on the last step, take x_start as the final output
if time_next < 0:
x = x_start
continue
# cumulative alphas for the current and next timestep; they control how the noise is scaled
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
# sigma and c weight the predicted noise and the fresh noise in the update
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(x)
# DDIM update: combine x_start, the predicted noise, and fresh noise
x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
return x
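Written out, the update above is the DDIM step: x_next = sqrt(alpha_next) * x_start + sqrt(1 - alpha_next - sigma^2) * pred_noise + sigma * z, with sigma = eta * sqrt((1 - alpha/alpha_next) * (1 - alpha_next) / (1 - alpha)). With eta = 1 the sampler is stochastic (DDPM-like); eta = 0 would give deterministic DDIM.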
# model predictions used by the samplers
def model_predictions(self, x, cond, t, weight=None, clip_x_start = False):
weight = weight if weight is not None else self.guidance_weight # 2
# classifier-free guided forward pass
model_output = self.model.guided_forward(x, cond, t, weight)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
x_start = model_output
x_start = maybe_clip(x_start)  # optionally clamp the prediction to [-1, 1]
# x - x_start
pred_noise = self.predict_noise_from_start(x, t, x_start)
return pred_noise, x_start
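guided_forward combines a conditional and an unconditional pass of the network using the guidance weight (classifier-free guidance, weight = 2 here). A minimal sketch of that combination; the exact mechanism inside the repo's model may differ (e.g. it may reuse cached null-condition embeddings):

import torch

@torch.no_grad()
def guided_forward_sketch(model, x, cond, t, weight=2.0):
    # classifier-free guidance: push the conditional prediction away from
    # the unconditional one by a factor of `weight`
    uncond = model(x, cond, t, cond_drop_prob=1.0)  # condition fully dropped
    cond_out = model(x, cond, t, cond_drop_prob=0.0)
    return uncond + weight * (cond_out - uncond)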