EDGE Music-to-Dance Diffusion Algorithm: Code Walkthrough

1. Data Slicing and Music Feature Extraction

(1) Input data: music files (.wav) and the corresponding SMPL-format motion data (.pkl).

(2) Split the data into train / test / evaluation sets according to the txt files under splits.

(3) Augment and slice the data with stride=0.5, length=5, i.e. sliding windows of length 5 s taken every 0.5 s.

(4) Use Jukebox [a music generation model built on a hierarchy of multi-scale VQ-VAEs; at generation time it takes random noise plus conditioning such as artist and genre, combines the levels top-down, and decodes the result] to extract music features and save them as .npy files of shape (seconds × frames per second, 4800).

(In this method: load the audio -> pass it through the VQ-VAE encoder -> add default conditioning inputs -> run the top-down cascade and take the features right before the decoder as the output.)
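A minimal sketch of this feature extraction using jukemirlib, mirroring the call that appears (commented out) later in test.py; the output is roughly (seconds × 30, 4800):

import numpy as np
import jukemirlib

# extract Jukebox layer-66 activations, downsampled to 30 feature frames per second
audio = jukemirlib.load_audio("example.wav")   # path is illustrative
reps = jukemirlib.extract(audio, layers=[66], downsample_target_rate=30)[66]
np.save("example.npy", reps)                   # shape: (num_frames, 4800)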

Notes for using your own data (changes are needed if the motion data is not 60 fps):

        1. In slice.py, lines 38 and 39 assume the motion FPS is 60; change this to your own fps.

        2. Jukebox feature extraction downsamples the audio features to 30 FPS.

        3. In dataset/dance_dataset.py, self.raw_fps = 60 and self.data_fps = 30 may need adjusting: by default motion data is subsampled with a stride of 2. If your motion data is already 30 fps, set self.raw_fps = 30 (see the sketch below).
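A minimal sketch of how these two values presumably determine the motion subsampling stride (the attribute names come from dance_dataset.py; the integer division itself is an assumption based on the default 60 -> 30 behavior):

raw_fps = 60                        # fps of the raw motion data; set to your capture rate
data_fps = 30                       # fps the model trains on (jukebox features are 30 FPS)
data_stride = raw_fps // data_fps   # 2 by default; 1 if the motion is already 30 fps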

create_dataset.py

def create_dataset(opt):
    # split the data according to the splits files
    print("Creating train / test split")

    # rename folders and move files into train / test
    split_data(opt.dataset_folder)
    # slice motions/music into sliding windows to create training dataset

    # slice train / test sets into sliding windows: stride = 0.5 s, length = 5 s
    print("Slicing train data")
    slice_aistpp(f"train/motions", f"train/wavs")
    print("Slicing test data")
    slice_aistpp(f"test/motions", f"test/wavs")
    # process dataset to extract audio features
    # audio feature extraction (baseline and/or jukebox via jukemirlib)
    if opt.extract_baseline:
        print("Extracting baseline features")
        baseline_extract("train/wavs_sliced", "train/baseline_feats")
        baseline_extract("test/wavs_sliced", "test/baseline_feats")
    if opt.extract_jukebox:
        print("Extracting jukebox features")
        jukebox_extract("train/wavs_sliced", "train/jukebox_feats")
        jukebox_extract("test/wavs_sliced", "test/jukebox_feats")

2. Data Preprocessing

(1) Read the .npy music features (n × 4800) and the .pkl SMPL translation / rotation data, and crop the SMPL translation / rotation according to the music fps and data_fps.

(2) For the SMPL translation (n × 3) and rotation (n × 24 × 3) data:

               1) Rotate the root from y-up to z-up.

                2) Use ax_to_6v to convert the joint axis-angle rotations to the 6D rotation representation (n × 24 × 6).

                3) Run forward kinematics to get 3D joint positions and detect whether the foot joints are (nearly) in contact (n × 4).

                4) Concatenate the contacts from 3), the root translation, and the 6D rotations from 2); flatten to (n × 151) and standardize with the normalizer.

dance_dataset.py

def load_aistpp(self):
        # open data path
        split_data_path = os.path.join(
            self.data_path, "train" if self.train else "test"
        )

        # Structure:
        # data
        #   |- train
        #   |    |- motion_sliced
        #   |    |- wav_sliced
        #   |    |- baseline_features
        #   |    |- jukebox_features
        #   |    |- motions
        #   |    |- wavs

        motion_path = os.path.join(split_data_path, "motions_sliced")
        sound_path = os.path.join(split_data_path, f"{self.feature_type}_feats")
        wav_path = os.path.join(split_data_path, f"wavs_sliced")
        # sort motions and sounds
        motions = sorted(glob.glob(os.path.join(motion_path, "*.pkl")))
        features = sorted(glob.glob(os.path.join(sound_path, "*.npy")))
        wavs = sorted(glob.glob(os.path.join(wav_path, "*.wav")))

        # stack the motions and features together
        all_pos = []
        all_q = []
        all_names = []
        all_wavs = []
        assert len(motions) == len(features)
        w = []
        for motion, feature, wav in zip(motions, features, wavs):
            # make sure name is matching
            m_name = os.path.splitext(os.path.basename(motion))[0]
            f_name = os.path.splitext(os.path.basename(feature))[0]
            w_name = os.path.splitext(os.path.basename(wav))[0]
            assert m_name == f_name == w_name, str((motion, feature, wav))
            # load motion
            data = pickle.load(open(motion, "rb"))
            pos = data["pos"]
            q = data["q"]
            local_q = torch.Tensor(q)
            if torch.isnan(local_q).any():
                continue

            all_pos.append(pos)
            all_q.append(q)
            all_names.append(feature)
            all_wavs.append(wav)

        all_pos = np.array(all_pos)  # N x seq x 3
        all_q = np.array(all_q)  # N x seq x (joint * 3)
        # downsample the motions to the data fps
        #
        all_pos = all_pos[:, :: self.data_stride, :]
        all_q = all_q[:, :: self.data_stride, :]
        data = {"pos": all_pos, "q": all_q, "filenames": all_names, "wavs": all_wavs}
        return data

    def process_dataset(self, root_pos, local_q):  # root_pos: (b, s, 3), local_q: (b, s, 72)
        # FK skeleton
        smpl = SMPLSkeleton()
        # to Tensor
        root_pos = torch.Tensor(root_pos)
        local_q = torch.Tensor(local_q)

        assert not torch.isnan(root_pos).any()
        assert not torch.isnan(local_q).any()
        # to ax
        bs, sq, c = local_q.shape
        local_q = local_q.reshape((bs, sq, -1, 3)) # b,l,24,3

        # AISTPP dataset comes y-up - rotate to z-up to standardize against the pretrain dataset
        # AIST++ is y-up while the pretraining data is z-up, so rotate the root joint from y-up to z-up
        root_q = local_q[:, :, :1, :]  # sequence x 1 x 3
        root_q_quat = axis_angle_to_quaternion(root_q)
        rotation = torch.Tensor(
            [0.7071068, 0.7071068, 0, 0]
        )  # 90 degrees about the x axis
        root_q_quat = quaternion_multiply(rotation, root_q_quat)
        root_q = quaternion_to_axis_angle(root_q_quat)
        local_q[:, :, :1, :] = root_q

        #
        # don't forget to rotate the root position too 😩
        # rotate the root translation as well
        pos_rotation = RotateAxisAngle(90, axis="X", degrees=True)
        root_pos = pos_rotation.transform_points(
            root_pos
        )  # basically (y, z) -> (-z, y), expressed as a rotation for readability

        # do FK
        # forward kinematics to absolute joint positions: 24 SMPL joints, 3 = (x, y, z)
        positions = smpl.forward(local_q, root_pos)  # batch x sequence x 24 x 3
        feet = positions[:, :, (7, 8, 10, 11)]
        feetv = torch.zeros(feet.shape[:3])
        feetv[:, :-1] = (feet[:, 1:] - feet[:, :-1]).norm(dim=-1)
        contacts = (feetv < 0.01).to(local_q)  # cast to right dtype

        # to 6d
        # ax_to_6v converts the joint axis-angle rotations to the 6D rotation representation, a continuous parameterization of 3D rotation that avoids the discontinuities of axis-angle / quaternion representations
        local_q = ax_to_6v(local_q)

        # now, flatten everything into: batch x sequence x [...]
        # 4 + 3 + 24*6
        l = [contacts, root_pos, local_q]

        # flatten to (b, l, 151)
        global_pose_vec_input = vectorize_many(l).float().detach()

        # normalize the data. Both train and test need the same normalizer.
        if self.train:
            self.normalizer = Normalizer(global_pose_vec_input)
        else:
            assert self.normalizer is not None
        global_pose_vec_input = self.normalizer.normalize(global_pose_vec_input)

        assert not torch.isnan(global_pose_vec_input).any()
        data_name = "Train" if self.train else "Test"

        # cut the dataset
        if self.data_len > 0:
            global_pose_vec_input = global_pose_vec_input[: self.data_len]


        print(f"{data_name} Dataset Motion Features Dim: {global_pose_vec_input.shape}")

        return global_pose_vec_input
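As a refresher on the 6D rotation representation used above: keep the first two columns of the rotation matrix and recover the third via Gram-Schmidt plus a cross product (Zhou et al. 2019). A minimal self-contained sketch, not the project's ax_to_6v (whose exact row/column convention may differ), assuming pytorch3d is available:

import torch
import torch.nn.functional as F
from pytorch3d.transforms import axis_angle_to_matrix

def axis_angle_to_6d(aa: torch.Tensor) -> torch.Tensor:
    # (..., 3) axis-angle -> (..., 6): the first two columns of the rotation matrix
    rot = axis_angle_to_matrix(aa)                    # (..., 3, 3)
    return rot[..., :2].transpose(-1, -2).reshape(*aa.shape[:-1], 6)

def six_d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    # (..., 6) -> (..., 3, 3) via Gram-Schmidt orthogonalization
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)          # columns b1, b2, b3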

(3) Training: build the dataloaders, run the training loop, compute the losses, and backpropagate.

  EDGE.py
def train_loop(self, opt):
        # load datasets (use cached tensor datasets if available)
        train_tensor_dataset_path = os.path.join(
            opt.processed_data_dir, f"train_tensor_dataset.pkl"
        )
        test_tensor_dataset_path = os.path.join(
            opt.processed_data_dir, f"test_tensor_dataset.pkl"
        )
        if (
            not opt.no_cache
            and os.path.isfile(train_tensor_dataset_path)
            and os.path.isfile(test_tensor_dataset_path)
        ):
            train_dataset = pickle.load(open(train_tensor_dataset_path, "rb"))
            test_dataset = pickle.load(open(test_tensor_dataset_path, "rb"))
        else:
            train_dataset = AISTPPDataset(
                data_path=opt.data_path,
                backup_path=opt.processed_data_dir,
                train=True,
                force_reload=opt.force_reload,
            )
            test_dataset = AISTPPDataset(
                data_path=opt.data_path,
                backup_path=opt.processed_data_dir,
                train=False,
                normalizer=train_dataset.normalizer,
                force_reload=opt.force_reload,
            )
            # cache the dataset in case
            if self.accelerator.is_main_process:
                pickle.dump(train_dataset, open(train_tensor_dataset_path, "wb"))
                pickle.dump(test_dataset, open(test_tensor_dataset_path, "wb"))

        # set normalizer
        self.normalizer = test_dataset.normalizer

        # data loaders
        # decide number of workers based on cpu count
        num_cpus = multiprocessing.cpu_count()
        train_data_loader = DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=min(int(num_cpus * 0.75), 32),
            pin_memory=True,
            drop_last=True,
        )
        test_data_loader = DataLoader(
            test_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True,
        )

        train_data_loader = self.accelerator.prepare(train_data_loader)
        # boot up multi-gpu training. test dataloader is only on main process
        # only the main process shows a progress bar
        load_loop = (
            partial(tqdm, position=1, desc="Batch")
            if self.accelerator.is_main_process
            else lambda x: x
        )
        #
        if self.accelerator.is_main_process:
            save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
            opt.exp_name = save_dir.split("/")[-1]

            # initialize wandb (commented out here)
            # wandb.init(project=opt.wandb_pj_name, name=opt.exp_name)
            # create the weights save directory
            save_dir = Path(save_dir)
            wdir = save_dir / "weights"
            wdir.mkdir(parents=True, exist_ok=True)

        self.accelerator.wait_for_everyone()
        # loop over epochs
        for epoch in range(1, opt.epochs + 1):
            avg_loss = 0
            avg_vloss = 0
            avg_fkloss = 0
            avg_footloss = 0
            # train
            self.train()
            for step, (x, cond, filename, wavnames) in enumerate(load_loop(train_data_loader)):
                # x: motion (32, 75, 151); cond: music features (32, 150, 4800); then the .npy and .wav filenames
                total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion(
                    x, cond, t_override=None
                )
                self.optim.zero_grad()
                self.accelerator.backward(total_loss)

                self.optim.step()

                # ema update and train loss update only on main
                if self.accelerator.is_main_process:
                    avg_loss += loss.detach().cpu().numpy()
                    avg_vloss += v_loss.detach().cpu().numpy()
                    avg_fkloss += fk_loss.detach().cpu().numpy()
                    avg_footloss += foot_loss.detach().cpu().numpy()
                    if step % opt.ema_interval == 0:
                        self.diffusion.ema.update_model_average(
                            self.diffusion.master_model, self.diffusion.model
                        )
            # Save model
            if (epoch % opt.save_interval) == 0:
                # everyone waits here for the val loop to finish ( don't start next train epoch early)
                self.accelerator.wait_for_everyone()
                # save only if on main thread
                if self.accelerator.is_main_process:
                    self.eval()
                    # log
                    avg_loss /= len(train_data_loader)
                    avg_vloss /= len(train_data_loader)
                    avg_fkloss /= len(train_data_loader)
                    avg_footloss /= len(train_data_loader)
                    log_dict = {
                        "Train Loss": avg_loss,
                        "V Loss": avg_vloss,
                        "FK Loss": avg_fkloss,
                        "Foot Loss": avg_footloss,
                    }
                    wandb.log(log_dict)
                    ckpt = {
                        "ema_state_dict": self.diffusion.master_model.state_dict(),
                        "model_state_dict": self.accelerator.unwrap_model(
                            self.model
                        ).state_dict(),
                        "optimizer_state_dict": self.optim.state_dict(),
                        "normalizer": self.normalizer,
                    }
                    torch.save(ckpt, os.path.join(wdir, f"train-{epoch}.pt"))
                    # generate a sample
                    render_count = 2
                    shape = (render_count, self.horizon, self.repr_dim)
                    print("Generating Sample")
                    # draw a music from the test dataset
                    (x, cond, filename, wavnames) = next(iter(test_data_loader))
                    cond = cond.to(self.accelerator.device)
                    self.diffusion.render_sample(
                        shape,
                        cond[:render_count],
                        self.normalizer,
                        epoch,
                        os.path.join(opt.render_dir, "train_" + opt.exp_name),
                        name=wavnames[:render_count],
                        sound=True,
                    )
                    print(f"[MODEL SAVED at Epoch {epoch}]")
        if self.accelerator.is_main_process:
            wandb.run.finish()
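The ema.update_model_average call above maintains an exponential moving average of the model weights, which is what gets saved as ema_state_dict. A minimal sketch of the usual update rule (the decay value is illustrative, not taken from the repo):

import torch

class EMA:
    def __init__(self, beta: float = 0.9999):
        self.beta = beta  # decay factor; illustrative value

    @torch.no_grad()
    def update_model_average(self, ma_model: torch.nn.Module, current_model: torch.nn.Module):
        # ma_param <- beta * ma_param + (1 - beta) * current_param
        for ma_param, param in zip(ma_model.parameters(), current_model.parameters()):
            ma_param.mul_(self.beta).add_(param, alpha=1 - self.beta)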
 diffusion.py (training part)
def p_losses(self, x_start, cond, t):
        # x_start: (b, l, 151); cond: (b, l, 4800); t: (b,) with values in [0, n_timestep)
        # sample standard normal noise with the same shape as x_start
        noise = torch.randn_like(x_start)

        # forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # reconstruct
        # predict the denoised motion from the noisy input, the music condition, and the timestep
        x_recon = self.model(x_noisy, cond, t, cond_drop_prob=self.cond_drop_prob)

        assert noise.shape == x_recon.shape

        model_out = x_recon
        if self.predict_epsilon:
            target = noise
        else:
            target = x_start

        # full reconstruction loss
        # MSE reconstruction loss, reduced to a per-sample mean
        # then reweighted with the p2 loss weighting
        loss = self.loss_fn(model_out, target, reduction="none")
        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)

        # split off contact from the rest
        # split the 4 contact channels from the rest of the prediction / target
        model_contact, model_out = torch.split(
            model_out, (4, model_out.shape[2] - 4), dim=2
        )
        target_contact, target = torch.split(target, (4, target.shape[2] - 4), dim=2)

        # velocity loss
        # velocity loss on frame-to-frame differences
        target_v = target[:, 1:] - target[:, :-1]
        model_out_v = model_out[:, 1:] - model_out[:, :-1]
        v_loss = self.loss_fn(model_out_v, target_v, reduction="none")
        v_loss = reduce(v_loss, "b ... -> b (...)", "mean")
        v_loss = v_loss * extract(self.p2_loss_weight, t, v_loss.shape)

        # FK loss
        b, s, c = model_out.shape
        # unnormalize
        # model_out = self.normalizer.unnormalize(model_out)
        # target = self.normalizer.unnormalize(target)
        # X, Q
        model_x = model_out[:, :, :3]
        model_q = ax_from_6v(model_out[:, :, 3:].reshape(b, s, -1, 6))
        target_x = target[:, :, :3]
        target_q = ax_from_6v(target[:, :, 3:].reshape(b, s, -1, 6))

        # perform FK
        # MSE loss on the forward-kinematics 3D joint positions
        model_xp = self.smpl.forward(model_q, model_x)
        target_xp = self.smpl.forward(target_q, target_x)

        fk_loss = self.loss_fn(model_xp, target_xp, reduction="none")
        fk_loss = reduce(fk_loss, "b ... -> b (...)", "mean")
        fk_loss = fk_loss * extract(self.p2_loss_weight, t, fk_loss.shape)

        # foot skate loss
        # foot-contact loss to suppress foot skating

        foot_idx = [7, 8, 10, 11]
        # find static indices consistent with model's own predictions
        static_idx = model_contact > 0.95  # N x S x 4
        model_feet = model_xp[:, :, foot_idx]  # foot positions (N, S, 4, 3)
        model_foot_v = torch.zeros_like(model_feet)
        model_foot_v[:, :-1] = (
            model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :]
        )  # (N, S-1, 4, 3)
        model_foot_v[~static_idx] = 0
        # any nonzero foot velocity during predicted contact is sliding, so the target velocity is zero
        foot_loss = self.loss_fn(
            model_foot_v, torch.zeros_like(model_foot_v), reduction="none"
        )
        foot_loss = reduce(foot_loss, "b ... -> b (...)", "mean")

        # weighted sum of the loss terms
        losses = (
            0.636 * loss.mean(),
            2.964 * v_loss.mean(),
            0.646 * fk_loss.mean(),
            10.942 * foot_loss.mean(),
        )
        return sum(losses), losses
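A minimal sketch of what q_sample presumably does (the standard DDPM forward noising; the coefficient gathering here stands in for the repo's extract helper, and the names are illustrative):

import torch

def q_sample(x_start, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    a = sqrt_alphas_cumprod[t].view(-1, 1, 1)
    b = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
    return a * x_start + b * noise

# usage with a simple linear beta schedule
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.randn(2, 150, 151)              # fake motion batch
t = torch.randint(0, 1000, (2,))
xt = q_sample(x0, t, torch.randn_like(x0),
              alphas_cumprod.sqrt(), (1 - alphas_cumprod).sqrt())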



model.py: the self.model() forward pass (training part)
def forward(
        self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
    ):
        batch_size, device = x.shape[0], x.device

        # project to latent space
        # linear projection of the motion features: 151 -> 512
        x = self.input_projection(x)
        # add the positional embeddings of the input sequence to provide temporal information
        # x + positional encoding, (8, 150, 512)
        # (original note: not used)
        x = self.abs_pos_encoding(x)

        # create music conditional embedding with conditional dropout
        # boolean keep-mask of shape (batch_size,), True with probability 1 - cond_drop_prob
        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
        # (batch_size,1,1)
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1") # (batch_size,1)

        # linear projection of the music features: 4800 -> 512
        cond_tokens = self.cond_projection(cond_embed)
        # encode tokens: positional encoding on the condition tokens (original note: not used)
        cond_tokens = self.abs_pos_encoding(cond_tokens)
        # encode with a standard 2-layer transformer encoder
        cond_tokens = self.cond_encoder(cond_tokens) # [8, 150, 512]

        # learnable null-condition embedding
        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        # where keep_mask_embed is True, keep cond_tokens; otherwise replace them with null_cond_embed.
        # This selectively keeps or drops the conditioning according to the conditional dropout probability.
        cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed) # [8, 150, 512]

        mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)  # mean pooling over the sequence: (8, 512)
        cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens) #  [8, 512]

        # create the diffusion timestep embedding, add the extra music projection
        # embed the diffusion timestep
        t_hidden = self.time_mlp(times)

        # project to attention and FiLM conditioning
        # project the timestep embedding to the FiLM condition and to time tokens
        t = self.to_time_cond(t_hidden)
        t_tokens = self.to_time_tokens(t_hidden)

        # FiLM conditioning
        # add the pooled music condition to the timestep condition for FiLM (Feature-wise Linear Modulation)
        null_cond_hidden = self.null_cond_hidden.to(t.dtype)
        cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
        t += cond_hidden  # timestep + music features

        # cross-attention conditioning
        # cross-attention conditioning: concatenate music tokens and time tokens
        c = torch.cat((cond_tokens, t_tokens), dim=-2)
        cond_tokens = self.norm_cond(c)  # layer normalization

        # Pass through the transformer decoder
        # attending to the conditional embedding
        # feed x through the decoder while attending to the music condition cond_tokens and the time condition t.
        # This is a custom transformer decoder layer with self-attention, cross-attention (encoder-decoder attention), and a feed-forward network,
        # plus FiLM (Feature-wise Linear Modulation) layers that modulate each of these parts, with layer normalization and dropout.
        output = self.seqTransDecoder(x, cond_tokens, t)

        # project 512 -> output size
        output = self.final_layer(output)

        return output
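A minimal sketch of the FiLM modulation mentioned above: the conditioning vector is mapped to a per-channel scale and shift applied to the features. This is illustrative, not the repo's exact layer:

import torch
import torch.nn as nn

class FiLM(nn.Module):
    # Feature-wise Linear Modulation: h -> h * (1 + scale) + shift
    def __init__(self, cond_dim: int, feature_dim: int):
        super().__init__()
        self.to_scale_shift = nn.Linear(cond_dim, feature_dim * 2)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # h: (b, seq, feature_dim), cond: (b, cond_dim)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# usage
film = FiLM(cond_dim=512, feature_dim=512)
out = film(torch.randn(8, 150, 512), torch.randn(8, 512))  # (8, 150, 512)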

3. Inference

(1) Define the output length and derive sample_size from it.

(2) Find the music files and extract Jukebox features.

(3) Initialize the EDGE model and loop over the music features to run inference.
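On the length arithmetic in the code below: slices are 5 s long with a 2.5 s stride, so consecutive slices overlap by half and each extra slice contributes 2.5 s of new audio. With out_length = 30, sample_size = int(30 / 2.5) - 1 = 11, and 11 slices cover 5 + 10 × 2.5 = 30 s of music.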

import glob
import os
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory
import random

import jukemirlib
import numpy as np
import torch
from tqdm import tqdm

from args import parse_test_opt
from data.slice import slice_audio
from EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract

# sort filenames that look like songname_slice{number}.ext
key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])


def stringintcmp_(a, b):
    aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
    ka, kb = key_func(a), key_func(b)
    if aa < bb:
        return -1
    if aa > bb:
        return 1
    if ka < kb:
        return -1
    if ka > kb:
        return 1
    return 0


stringintkey = cmp_to_key(stringintcmp_)

def test(opt):
    # choose the feature extractor (jukebox or baseline)
    feature_func = juke_extract if opt.feature_type == "jukebox" else baseline_extract
    # define the output length
    sample_length = opt.out_length
    sample_size = int(sample_length / 2.5) - 1

    temp_dir_list = []
    all_cond = []
    all_filenames = []
    if opt.use_cached_features:
        print("Using precomputed features")
        # all subdirectories
        dir_list = glob.glob(os.path.join(opt.feature_cache_dir, "*/"))
        for dir in dir_list:
            file_list = sorted(glob.glob(f"{dir}/*.wav"), key=stringintkey)
            juke_file_list = sorted(glob.glob(f"{dir}/*.npy"), key=stringintkey)
            assert len(file_list) == len(juke_file_list)
            # random chunk after sanity check
            rand_idx = random.randint(0, len(file_list) - sample_size)
            file_list = file_list[rand_idx : rand_idx + sample_size]
            juke_file_list = juke_file_list[rand_idx : rand_idx + sample_size]
            cond_list = [np.load(x) for x in juke_file_list]
            all_filenames.append(file_list)
            all_cond.append(torch.from_numpy(np.array(cond_list)))
    else:
        print("Computing features for input music")
        # loop over all .wav files in the input music directory
        for wav_file in glob.glob(os.path.join(opt.music_dir, "*.wav")):
            # create temp folder (or use the cache folder if specified)
            if opt.cache_features:
                songname = os.path.splitext(os.path.basename(wav_file))[0]
                save_dir = os.path.join(opt.feature_cache_dir, songname)
                Path(save_dir).mkdir(parents=True, exist_ok=True)
                dirname = save_dir
            else:
                # create a temporary folder
                temp_dir = TemporaryDirectory()
                temp_dir_list.append(temp_dir)
                dirname = temp_dir.name
            # slice the audio file
            print(f"Slicing {wav_file}")
            # slice the audio into the temp folder: stride 2.5 s, length 5 s
            slice_audio(wav_file, 2.5, 5.0, dirname)
            file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=stringintkey)
            # randomly sample a chunk of length at most sample_size
            rand_idx = random.randint(0, len(file_list) - sample_size)
            cond_list = []
            # generate juke representations
            print(f"Computing features for {wav_file}")
            for idx, file in enumerate(tqdm(file_list)):
                # if not caching then only calculate for the interested range
                if (not opt.cache_features) and (not (rand_idx <= idx < rand_idx + sample_size)):
                    continue
                # audio = jukemirlib.load_audio(file)
                # reps = jukemirlib.extract(
                #     audio, layers=[66], downsample_target_rate=30
                # )[66]
                # extract the music features
                reps, _ = feature_func(file)
                # save reps
                if opt.cache_features:
                    featurename = os.path.splitext(file)[0] + ".npy"
                    np.save(featurename, reps)
                # if in the random range, put it into the list of reps we want
                # to actually use for generation
                if rand_idx <= idx < rand_idx + sample_size:
                    cond_list.append(reps)
            cond_list = torch.from_numpy(np.array(cond_list))
            all_cond.append(cond_list)
            all_filenames.append(file_list[rand_idx : rand_idx + sample_size])
    # initialize the EDGE model
    model = EDGE(opt.feature_type, opt.checkpoint)
    model.eval()

    # directory for optionally saving the dances for eval
    fk_out = None
    if opt.save_motions:
        fk_out = opt.motion_save_dir

    print("Generating dances")
    for i in range(len(all_cond)):
        data_tuple = None, all_cond[i], all_filenames[i]
        # run inference
        model.render_sample(
            data_tuple, "test", opt.render_dir, render_count=-1, fk_out=fk_out, render=not opt.no_render
        )
    print("Done")
    
    # clean up temporary files
    torch.cuda.empty_cache()
    for temp_dir in temp_dir_list:
        temp_dir.cleanup()


if __name__ == "__main__":
    opt = parse_test_opt()
    test(opt)
 diffusion.py

def render_sample(
        self,
        shape,
        cond,
        normalizer,
        epoch,
        render_out,
        fk_out=None,
        name=None,
        sound=True,
        mode="normal",
        noise=None,
        constraint=None,
        sound_folder="ood_sliced",
        start_point=None,
        render=True
    ):
        if isinstance(shape, tuple):
            if mode == "inpaint":
                func_class = self.inpaint_loop
            elif mode == "normal":
                func_class = self.ddim_sample
            elif mode == "long":
                func_class = self.long_ddim_sample
            else:
                assert False, "Unrecognized inference mode"
            samples = (
                func_class(
                    shape,
                    cond,
                    noise=noise,
                    constraint=constraint,
                    start_point=start_point,
                )
                .detach()
                .cpu()
            )  # sample motions conditioned on the music via reverse diffusion
        else:
            samples = shape

        # unnormalize back to the original feature scale
        samples = normalizer.unnormalize(samples)

        # split the result into 4 + 147: 4 foot-contact channels, then 3 (translation) + 24*6 (rotation)
        if samples.shape[2] == 151:
            sample_contact, samples = torch.split(
                samples, (4, samples.shape[2] - 4), dim=2
            ) 
        else:
            sample_contact = None

        # do the FK all at once
        b, s, c = samples.shape
        pos = samples[:, :, :3].to(cond.device)  # np.zeros((sample.shape[0], 3)) # (3)
        q = samples[:, :, 3:].reshape(b, s, 24, 6) # (144)
        # go 6d to ax 144 - 72
        q = ax_from_6v(q).to(cond.device)


        if mode == "long":
            b, s, c1, c2 = q.shape
            assert s % 2 == 0
            half = s // 2
            if b > 1:
                # if long mode, stitch position using linear interp

                fade_out = torch.ones((1, s, 1)).to(pos.device)
                fade_in = torch.ones((1, s, 1)).to(pos.device)
                fade_out[:, half:, :] = torch.linspace(1, 0, half)[None, :, None].to(
                    pos.device
                )
                fade_in[:, :half, :] = torch.linspace(0, 1, half)[None, :, None].to(
                    pos.device
                )

                pos[:-1] *= fade_out
                pos[1:] *= fade_in

                full_pos = torch.zeros((s + half * (b - 1), 3)).to(pos.device)
                idx = 0
                for pos_slice in pos:
                    full_pos[idx : idx + s] += pos_slice
                    idx += half

                # stitch joint angles with slerp
                slerp_weight = torch.linspace(0, 1, half)[None, :, None].to(pos.device)

                left, right = q[:-1, half:], q[1:, :half]
                # convert to quat
                left, right = (
                    axis_angle_to_quaternion(left),
                    axis_angle_to_quaternion(right),
                )
                merged = quat_slerp(left, right, slerp_weight)  # (b-1) x half x ...
                # convert back
                merged = quaternion_to_axis_angle(merged)

                full_q = torch.zeros((s + half * (b - 1), c1, c2)).to(pos.device)
                full_q[:half] += q[0, :half]
                idx = half
                for q_slice in merged:
                    full_q[idx : idx + half] += q_slice
                    idx += half
                full_q[idx : idx + half] += q[-1, half:]

                # unsqueeze for fk
                full_pos = full_pos.unsqueeze(0)
                full_q = full_q.unsqueeze(0)
            else:
                full_pos = pos
                full_q = q

            # run FK and optionally save the pkl
            full_pose = (
                self.smpl.forward(full_q, full_pos).detach().cpu().numpy()
            )  # b, s, 24, 3
            if fk_out is not None:
                outname = f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.pkl'
                Path(fk_out).mkdir(parents=True, exist_ok=True)
                pickle.dump(
                    {
                        "smpl_poses": full_q.squeeze(0).reshape((-1, 72)).cpu().numpy(),
                        "smpl_trans": full_pos.squeeze(0).cpu().numpy(),
                        "full_pose": full_pose[0],
                    },
                    open(os.path.join(fk_out, outname), "wb"),
                )
            return

        poses = self.smpl.forward(q, pos).detach().cpu().numpy()

        # save pkl files
        if fk_out is not None and mode != "long":
            Path(fk_out).mkdir(parents=True, exist_ok=True)
            for num, (qq, pos_, filename, pose) in enumerate(zip(q, pos, name, poses)):
                path = os.path.normpath(filename)
                pathparts = path.split(os.sep)
                pathparts[-1] = pathparts[-1].replace("npy", "wav")
                # path is like "data/train/features/name"
                pathparts[2] = "wav_sliced"
                audioname = os.path.join(*pathparts)
                outname = f"{epoch}_{num}_{pathparts[-1][:-4]}.pkl"
                pickle.dump(
                    {
                        "smpl_poses": qq.reshape((-1, 72)).cpu().numpy(),
                        "smpl_trans": pos_.cpu().numpy(),
                        "full_pose": pose,
                    },
                    open(f"{fk_out}/{outname}", "wb"),
                )
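The long-mode stitching above blends overlapping windows: root positions are cross-faded linearly, while joint rotations are blended with spherical linear interpolation (slerp). A minimal quaternion slerp sketch, the textbook formula rather than the repo's quat_slerp:

import torch

def quat_slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # slerp between unit quaternions q0, q1 of shape (..., 4) at weights t in [0, 1]
    d = (q0 * q1).sum(-1, keepdim=True)
    q1 = torch.where(d < 0, -q1, q1)                     # take the shorter arc
    d = d.abs().clamp(max=1.0)
    theta = torch.acos(d)
    sin_theta = torch.sin(theta)
    near = sin_theta < 1e-6                              # nearly identical: fall back to lerp
    w0 = torch.where(near, 1.0 - t, torch.sin((1.0 - t) * theta) / sin_theta.clamp(min=1e-6))
    w1 = torch.where(near, t, torch.sin(t * theta) / sin_theta.clamp(min=1e-6))
    return w0 * q0 + w1 * q1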

# This implements reverse diffusion, as used in diffusion models (DDPM, DDIM, etc.): the forward process gradually adds noise to the data, and the reverse process recovers data from noise
    @torch.no_grad()
    def ddim_sample(self, shape, cond, **kwargs):

        # batch size, device, total diffusion timesteps, 50 DDIM sampling steps, eta = 1
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1

        # sampling_timesteps + 1 evenly spaced timesteps from -1 to total_timesteps - 1
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))  # reverse to go from T-1 down to -1
        time_pairs = list(zip(times[:-1], times[1:])) # [(T[-1], T[-2]), (T[-2], T[-3]), ..., (1, 0), (0, -1)]

        # initialize x from a standard normal distribution
        x = torch.randn(shape, device = device)
        cond = cond.to(device)

        x_start = None
        # reverse diffusion loop
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            # (batch,) tensor filled with the current timestep
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)

            # predict the noise and x_start from the noisy sample, the music condition, and the current timestep
            pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start = self.clip_denoised)

            # on the last step, take x_start directly and finish
            if time_next < 0:
                x = x_start
                continue

            # cumulative alpha values for the current and next timesteps; they control the noise scaling
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            # sigma and c weight the fresh noise and the predicted noise in the update
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(x)
            # DDIM update: combine x_start, the predicted noise, and fresh random noise
            x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

        return x
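Written out, the update inside the loop is the standard DDIM step (matching the alpha, sigma, and c variables in the code, with $\hat x_0$ the predicted clean sample and $\hat\epsilon$ the predicted noise):

$$
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat x_0 + \sqrt{1-\bar\alpha_{t'}-\sigma_t^2}\,\hat\epsilon + \sigma_t z,\qquad
\sigma_t = \eta\sqrt{\frac{(1-\bar\alpha_t/\bar\alpha_{t'})(1-\bar\alpha_{t'})}{1-\bar\alpha_t}},\qquad z\sim\mathcal N(0,I)
$$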

# model prediction step used during sampling
    def model_predictions(self, x, cond, t, weight=None, clip_x_start = False):
        weight = weight if weight is not None else self.guidance_weight # 2
        # classifier-free guided forward pass
        model_output = self.model.guided_forward(x, cond, t, weight)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
        
        x_start = model_output
        x_start = maybe_clip(x_start)  # optionally clamp the prediction to [-1, 1]
        # recover the implied noise from x, t, and the predicted x_start
        pred_noise = self.predict_noise_from_start(x, t, x_start)

        return pred_noise, x_start
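guided_forward implements classifier-free guidance: the model is run with and without the music condition and the outputs are combined using the guidance weight. A minimal sketch of the usual formula, assuming cond_drop_prob=1.0 fully replaces the condition with the null embedding as in the forward pass above:

import torch

def guided_forward(model, x, cond, t, guidance_weight: float = 2.0):
    # classifier-free guidance: out = uncond + w * (cond - uncond)
    uncond = model(x, cond, t, cond_drop_prob=1.0)        # condition dropped (null embedding)
    conditioned = model(x, cond, t, cond_drop_prob=0.0)   # condition kept
    return uncond + guidance_weight * (conditioned - uncond)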
