【代码学习】EAT复现+代码分析

论文:Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation

代码:yuangan/EAT_code: Official code for ICCV 2023 paper: "Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation". (github.com)

1. 训练

1.1 A2KP Training

training A2KP transformer with latent and pca loss:pretrain_a2kp.py

if __name__ == "__main__":
    
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    # 调用parser.parse_args()来解析命令行参数,并将结果存储在opt变量中。
    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-transformer.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train",])
    parser.add_argument("--gen", default="spade", choices=["original", "spade"])
    parser.add_argument("--log_dir", default='./output/', help="path to log into")
    parser.add_argument("--checkpoint", default='./00000189-checkpoint.pth.tar', help="path to checkpoint to restore")
    #parser.add_argument("--device_ids", default="0, 1, 2, 3, 4, 5, 6, 7", type=lambda x: list(map(int, x.split(','))),
    parser.add_argument("--device_ids", default="0, 1", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)

    opt = parser.parse_args()
    
    # 打开配置文件,并使用yaml库加载配置文件中的内容,并将结果存储在config变量中。
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    
    # log dir when checkpoint is set
    # if opt.checkpoint is not None:
    #     log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    # else:
    # 根据配置文件的路径和当前时间生成一个日志目录。它使用了os.path模块来操作路径,并使用strftime函数生成日期和时间字符串。
    log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
    log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
        
    # 根据选择的opt.gen参数创建不同类型的生成器模型对象。根据配置文件中的参数,调用相应的生成器类进行初始化。
    if opt.gen == 'original':
        generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                            **config['model_params']['common_params'])
    elif opt.gen == 'spade':
        generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
                                                 **config['model_params']['common_params'])
        
    # 检查CUDA是否可用,并将生成器模型移动到指定的设备上(如果可用)。如果设置了verbose标志,则打印生成器模型的结构。
    if torch.cuda.is_available():
        print('cuda is available')
        generator.to(opt.device_ids[0])
    if opt.verbose:
        print(generator)

    discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
                                            **config['model_params']['common_params'])
    if torch.cuda.is_available():
        discriminator.to(opt.device_ids[0])
    if opt.verbose:
        print(discriminator)

    # 创建关键点检测器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的关键点检测器类进行初始化。
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])

    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])

    if opt.verbose:
        print(kp_detector)

    # 创建音频到关键点转换器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的音频到关键点转换器类进行初始化。
    audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'])

    if torch.cuda.is_available():
        audio2kptransformer.to(opt.device_ids[0])

    # 创建数据集对象。根据配置文件中的参数,调用相应的数据集类进行初始化。
    dataset = FramesWavsDatasetMEL25(is_train=(opt.mode == 'train'), **config['dataset_params'])

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    # 根据opt.mode参数的值决定进行训练或其他操作。如果设置为"train",则调用train函数进行模型训练,传递所需的参数。
    if opt.mode == 'train':
        print("Training...")
        train(config, generator, discriminator, kp_detector, audio2kptransformer, opt.checkpoint, log_dir, dataset, opt.device_ids)

training A2KP transformer with all loss :pretrain_a2kp_img.py

大致同上

frames_dataset_transformer25.py:

有三个主要的类:

1.class FramesWavsDatasetMEL25VoxBoxQG2(Dataset):处理包含视频和音频数据的数据集。

  • __getitem__ 方法:根据索引获取数据集中的一个样本。

    • 如果是训练集且满足概率条件,则获取无情感的样本;否则获取中性样本。
    • 对于无情感样本,根据视频路径和索引获取相关数据,包括音频、Mel频谱、姿势等信息,并返回一个包含这些信息的字典。
    • 对于中性样本,根据视频路径和索引获取相关数据,包括音频、Mel频谱、姿势等信息,并返回一个包含这些信息的字典。
  • 其他辅助函数:

    • get_frame_id:根据帧路径获取帧的索引。
    • get_window:根据起始帧、头部姿势和姿势图像获取窗口帧、头部姿势和姿势图像。
    • crop_audio_window:对音频、姿势和深度特征进行裁剪。
    • _load_tensor:加载音频数据并转换为张量。
    • getitem_neu:获取训练集中的中性样本。
    • getitem_vox_woemo:获取训练集中的无情感样本。

2.class FramesWavsDatasetMEL25VoxBoxQG2ImgAll(Dataset):

        ImgAll表示数据集中包含所有的图像样本。

3.class FramesWavsDatasetMEL25VoxBoxQG2ImgPrompt(Dataset):"ImgPrompt" 表示数据集中的样本是根据图像提示或问题进行选择的。

  • getitem_vox_emo:获取训练集中具有情感标签的样本。
  • getitem_vox_woemo:获取训练集中无情感标签的样本。

  def getitem_neu(self, idx):  在训练集中获取中性样本数据,并返回一个包含各种特征和路径的字典,以供后续处理和训练使用:

    def getitem_neu(self, idx): # 获取训练集中的中性样本
        while 1:
            idx = idx%len(self.videos)  # 对索引进行取余操作,以循环访问视频列表
            name = self.videos[idx]
            path = os.path.join(self.root_dir, name)   # 构建视频路径

            video_name = os.path.basename(path)  # 获取视频文件名
            vsp = video_name.split('_')  # 将视频文件名按下划线分割成列表

            out = {}
            
            deep_path = f'{mead_path}/deepfeature32/{video_name}.npy'   # 构建深度特征路径
            deeps = np.load(deep_path)                                  # 加载深度特征数据

            wave_path = f'{mead_path}/wav_16000/{video_name}.wav'       # 构建音频路径
            out['wave_path'] = wave_path               # 将音频路径存储到输出字典中
            wave_tensor = self._load_tensor(wave_path) # 加载音频数据并转换为张量
            if len(wave_tensor.shape) > 1:             # 如果音频张量的形状大于1维
                wave_tensor = wave_tensor[:, 0]        # 只保留第一个通道的数据
            mel_tensor = self.to_melspec(wave_tensor)  # 将音频张量转换为梅尔频谱图
            mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mean) / self.std  # 对梅尔频谱图进行归一化处理

            lable_index = self.emo_label.index(vsp[1])   # 获取情绪标签对应的索引值

            # print(out['drivinglmk'].shape)


            # out['y_trg'] = self.emo_label.index(vsp[1])
            # z_trg = torch.randn(self.latent_dim)
            # out['z_trg'] = z_trg

            # select gt frames
            frames = os.listdir(path)     # 获取视频的文件列表。path:视频路径
            num_frames = len(frames)      # 获取视频帧的数量
            num_frames = min(num_frames, len(deeps), mel_tensor.shape[1])  # 取视频帧、深度特征和梅尔频谱图中最小的数量

            # 在可选择的帧数不满足要求时,通过递增索引并循环访问视频列表来获取更多的帧。
            if num_frames - self.syncnet_T + 1 <= 0:  # 如果可选择的帧数不足以满足要求
                # print(num_frames)
                idx += 1
                idx = idx%len(self.videos)         # 对索引取余,以循环访问视频列表
                continue
            
            # 可选择的帧数满足要求时
            frame_idx = np.random.choice(num_frames- self.syncnet_T+1, replace=True, size=1)[0]+1   # 随机选择一个起始帧的索引
            choose = join(path, '{:04}.jpg'.format(frame_idx))     # 构建所选帧的完整路径
            
            # driving latent with pretrained
            driving_latent = np.load(path.replace('images', 'latent')+'.npy', allow_pickle=True)
            he_driving = driving_latent[1]  # 获取driving_latent的第二个元素
            
            ### poseimg after AntiAliasInterpolation2d: num_frames, 1, 64, 64
            fposeimg = gzip.GzipFile(f'{poseimg_path}/poseimg/{video_name}.npy.gz', "r") # 打开poseimg文件
            poseimg = np.load(fposeimg) 


            try:
                window_fnames, he_d, poses = self.get_window(choose, he_driving, poseimg)
            except:
                print(choose, path)
                idx += 1
                idx = idx%len(self.videos)
                continue
            out['he_driving'] = he_d
            
            # neutral frames
            video_name_neu = vsp[0]+'_neu_1_'+'*'   # 构建中性样本视频名称模式
            path_neus = path.replace(video_name, video_name_neu)   # 构建中性样本视频路径模式
            path_neu = random.choice(glob.glob(path_neus))
            source_latent = np.load(path_neu.replace('images', 'latent')+'.npy', allow_pickle=True) # 加载中性样本的潜在特征数据。这里使用中性样本的图像路径生成对应的潜在特征路径,并加载该路径下的潜在特征数据。
            num_frames_source = source_latent[1]['yaw'].shape[0]  # 获取中性样本的帧数,这里以 'yaw' 特征的形状的第一个维度作为帧数。
            source_index=np.random.choice(num_frames_source, replace=True, size=1)[0]+1   # 随机选择一个中性样本的帧索引,从 1 到中性样本的帧数。
            video_array_source = img_as_float32(io.imread(join(path_neu, '{:04}.jpg'.format(source_index)))) #  加载中性样本的图像数据转换为浮点型数组。
            
            # neutral source latent with pretrained
            he_source = {}                      # 存储中性样本的潜在特征。
            for k in source_latent[1].keys():   #  遍历中性样本的潜在特征的键(即特征类型)。
                he_source[k] = torch.from_numpy(source_latent[1][k][source_index-1])    # 将中性样本的潜在特征转换为张量并存储到字典中
            out['he_source'] = he_source        # 将字典 he_source 存储到输出字典 out 的键 'he_source' 中,以保存中性样本的潜在特征。

            out['source'] = video_array_source.transpose((2, 0, 1))    # 调整中性样本的图像数据维度顺序,并存储到字典out中




            mel, poses_f, deep_frames = self.crop_audio_window(mel_tensor, poses, deeps, choose, num_frames)    # 对音频、姿态和深度特征进行裁剪
            out['mel'] = mel.unsqueeze(1)    # 在梅尔频谱图上增加一个维度,并存储到输出字典中
            out['pose'] = poses_f            # 存储裁剪后的姿态数据到输出字典中
            out['name'] = video_name
            out['deep'] = deep_frames

            return out

1.2 Emotional Adaptation Training

prompt_st_dp_eam3d.py:

大致同上

2. 数据处理

根据第一帧中检测到的人脸,排除了一些模糊或人脸太小的视频。

视频预处理:preprocess_video.py

from glob import glob
import os

# 使用glob模块获取指定目录下所有的.mp4文件路径
allmp4s = glob('./video/*.mp4')

# 设置目标文件夹路径,并确保该文件夹存在
path_fps25='./video_fps25'
os.makedirs(path_fps25, exist_ok=True)

# 遍历每个.mp4文件
for mp4 in allmp4s:
    # 获取文件名(不带路径)
    name = os.path.basename(mp4)
    
    # 使用ffmpeg命令将视频转换为25帧每秒的视频,并设置音频参数
    os.system(f'ffmpeg -y -i {mp4} -filter:v fps=25 -ac 1 -ar 16000 -crf 10 {path_fps25}/{name}')
    
    # 使用ffmpeg命令将上一步生成的视频转换为.wav格式的音频文件
    os.system(f'ffmpeg -y -i {path_fps25}/{name} {path_fps25}/{name[:-4]}.wav')

#============== extract lmk for crop =================
# 提取关键点信息用于裁剪
print('============== extract lmk for crop =================')
os.system(f'python extract_lmks_eat.py {path_fps25}')

#======= extract speech in deepspeech_features =======
# 提取语音特征
print('======= extract speech in deepspeech_features =======')
os.chdir('./deepspeech_features/')
os.system(f'python extract_ds_features.py --input=../{path_fps25}')
os.chdir('../')
os.system('python deepfeature32.py')

#=================== crop videos =====================
# 裁剪视频
print('=================== crop videos =====================')
os.chdir('./vid2vid/')
os.system('python data_preprocess.py --dataset_mode preprocess_eat')
os.chdir('../')

#========== extract latent from cropped videos =======
#从裁剪的视频中提取潜在特征
print('========== extract latent from cropped videos =======')
os.system('python videos2img.py')
os.system('python latent_extractor.py')

#=========== extract poseimg from latent =============
# 从潜在特征中提取姿势图像
print('=========== extract poseimg from latent =============')
os.system('python generate_poseimg.py')

之后,Extract the bbox for training:preprocess/extract_bbox.py

fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda')
#初始化了一个人脸对齐(Face Alignment)模型对象,CUDA加速。

def detect_bbox(img_names):
    bboxs = []
    for img_name in img_names:
        img = img_as_float32(io.imread(img_name)).transpose((2, 0, 1))
        img = np.transpose(img[np.newaxis], (0,2,3,1))[...,::-1]
        bbox = fa.get_detections_for_batch(img*255)
        if bbox is not None:
            bboxs.append(bbox[0])
        else:
            bboxs.append(None)
    assert(len(bboxs)==len(img_names))
    return bboxs

这个函数用于检测一组图像中的人脸边界框。它接受一个图像文件名列表作为输入,并返回相应的人脸边界框列表。该函数首先加载图像文件,然后将其转换为指定的格式,并调用人脸对齐模型的get_detections_for_batch方法来获取人脸边界框。如果检测到了人脸边界框,则将其添加到bboxs列表中;否则,将None添加到列表中。

def main(args):
    file_images = glob('/data2/gy/lrw/lrw_images/*')
    file_images.sort()
    p = args.part
    t = len(file_images)
    for fi in tqdm(file_images[t*p:t*(p+1)]):
        out = basename(fi)
        outpath =f'/data2/gy/lrw/lrw_bbox/{out}.npy'
        if exists(outpath):
            continue
        images = glob(fi+'/*.jpg')
        images.sort()
        bboxs = detect_bbox(images)
        np.save(outpath, bboxs)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--files", default="*", help="filenames")
    parser.add_argument("--part", default="0", type=int, help="part")
    args = parser.parse_args()
    main(args)

利用人脸对齐模型进行人脸边界框检测,并将结果保存到.npy文件中。

3. 模型结构

augmentation.py

crop_clip 函数用于裁剪视频片段中的帧。根据输入参数 min_hmin_whw,分别表示起始高度、起始宽度、裁剪后的高度和宽度。根据 clip[0] 的类型进行判断,如果是 np.ndarray 类型,则使用切片操作裁剪每一帧;如果是 PIL.Image.Image 类型,则使用 crop() 方法裁剪每一帧;否则抛出类型错误异常。

def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
            ]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return cropped

pad_clip 函数用于填充视频片段中的帧。根据输入参数 hw,分别表示期望的高度和宽度。通过获取第一帧的形状信息 im_him_w,根据目标尺寸与原始尺寸的比较,计算需要填充的上下左右边界大小,并使用 pad() 方法进行填充。

def pad_clip(clip, h, w):
    im_h, im_w = clip[0].shape[:2]
    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)

    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')

resize_clip 函数用于调整视频片段中的帧大小。根据输入参数 sizeinterpolation,分别表示期望的尺寸和插值方法。根据 clip[0] 的类型进行判断,如果是 np.ndarray 类型,则使用 resize() 方法对每一帧进行调整;如果是 PIL.Image.Image 类型,则使用 resize() 方法对每一帧进行调整;否则抛出类型错误异常。 

def resize_clip(clip, size, interpolation='bilinear'):
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]

        scaled = [
            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
                   mode='constant', anti_aliasing=True) for img in clip
            ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.NEAREST
        else:
            pil_inter = PIL.Image.BILINEAR
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return scaled

get_resize_sizes 函数用于根据原始帧的高度和宽度以及期望的尺寸,计算调整后的新尺寸。根据原始帧的长宽比例,通过比较大小来确定应该调整的维度,并计算调整后的新尺寸。 

def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow

class RandomFlip(object):
    def __init__(self, time_flip=False, horizontal_flip=False):
        self.time_flip = time_flip
        self.horizontal_flip = horizontal_flip

    def __call__(self, clip):
        """
        随机根据给定的参数水平翻转或在时间上翻转输入的视频片段。

        参数:
            clip (list):表示视频片段的图像或numpy数组的列表。

        返回值:
            list:翻转后的视频片段。
        """

        if random.random() < 0.5 and self.time_flip:
            return clip[::-1]
        if random.random() < 0.5 and self.horizontal_flip:
            return [np.fliplr(img) for img in clip]

        return clip


class RandomResize(object):
    """将一组(H x W x C)大小的numpy.ndarray调整为最终大小
    原始图像越大,插值的次数越多
    参数:
        interpolation (str):可以是'nearest'、'bilinear'中的一个,默认为nearest
        size (tuple):(宽度, 高度)
    """

    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
        self.ratio = ratio
        self.interpolation = interpolation

    def __call__(self, clip):
        """
        将输入的视频片段调整为新的大小。

        参数:
            clip (list):表示视频片段的图像或numpy数组的列表。

        返回值:
            list:调整大小后的视频片段。
        """
        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])

        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size

        new_w = int(im_w * scaling_factor)
        new_h = int(im_h * scaling_factor)
        new_size = (new_w, new_h)
        resized = resize_clip(
            clip, new_size, interpolation=self.interpolation)

        return resized


class RandomCrop(object):
    """从一组视频中提取相同位置的随机裁剪。
    参数:
        size (sequence or int):期望的裁剪输出尺寸,格式为(h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """
        从输入的视频片段中提取随机裁剪。

        参数:
            clip (list):表示视频片段的图像或numpy数组的列表。

        返回值:
            list:裁剪后的视频片段。
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        clip = pad_clip(clip, h, w)
        im_h, im_w = clip.shape[1:3]
        x1 = 0 if h == im_h else random.randint(0, im_w - w)
        y1 = 0 if w == im_w else random.randint(0, im_h - h)
        cropped = crop_clip(clip, y1, x1, h, w)

        return cropped


class RandomRotation(object):
    """随机旋转整个视频片段,角度在给定范围内。
    参数:
        degrees (sequence or int):选择角度的范围
        如果degrees是一个数字而不是(min, max)形式的序列,
        角度的范围将为(-degrees, +degrees)。
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError('If degrees is a single number,'
                                 'must be positive')
            degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError('If degrees is a sequence,'
                                 'it must be of len 2.')

        self.degrees = degrees

    def __call__(self, clip):
        """
        将输入的视频片段按随机角度旋转。

        参数:
            clip (list):表示视频片段的图像或numpy数组的列表。

        返回值:
            list:旋转后的视频片段。
        """
        angle = random.uniform(self.degrees[0], self.degrees[1])
        if isinstance(clip[0], np.ndarray):
            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
        elif isinstance(clip[0], PIL.Image.Image):
            rotated = [img.rotate(angle) for img in clip]
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        return rotated

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值