论文:Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation
1. 训练
1.1 A2KP Training
training A2KP transformer with latent and pca loss:pretrain_a2kp.py
if __name__ == "__main__":
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
# 调用parser.parse_args()来解析命令行参数,并将结果存储在opt变量中。
parser = ArgumentParser()
parser.add_argument("--config", default="config/vox-transformer.yaml", help="path to config")
parser.add_argument("--mode", default="train", choices=["train",])
parser.add_argument("--gen", default="spade", choices=["original", "spade"])
parser.add_argument("--log_dir", default='./output/', help="path to log into")
parser.add_argument("--checkpoint", default='./00000189-checkpoint.pth.tar', help="path to checkpoint to restore")
#parser.add_argument("--device_ids", default="0, 1, 2, 3, 4, 5, 6, 7", type=lambda x: list(map(int, x.split(','))),
parser.add_argument("--device_ids", default="0, 1", type=lambda x: list(map(int, x.split(','))),
help="Names of the devices comma separated.")
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
parser.set_defaults(verbose=False)
opt = parser.parse_args()
# 打开配置文件,并使用yaml库加载配置文件中的内容,并将结果存储在config变量中。
with open(opt.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
# log dir when checkpoint is set
# if opt.checkpoint is not None:
# log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
# else:
# 根据配置文件的路径和当前时间生成一个日志目录。它使用了os.path模块来操作路径,并使用strftime函数生成日期和时间字符串。
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
# 根据选择的opt.gen参数创建不同类型的生成器模型对象。根据配置文件中的参数,调用相应的生成器类进行初始化。
if opt.gen == 'original':
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
elif opt.gen == 'spade':
generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
# 检查CUDA是否可用,并将生成器模型移动到指定的设备上(如果可用)。如果设置了verbose标志,则打印生成器模型的结构。
if torch.cuda.is_available():
print('cuda is available')
generator.to(opt.device_ids[0])
if opt.verbose:
print(generator)
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
**config['model_params']['common_params'])
if torch.cuda.is_available():
discriminator.to(opt.device_ids[0])
if opt.verbose:
print(discriminator)
# 创建关键点检测器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的关键点检测器类进行初始化。
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if torch.cuda.is_available():
kp_detector.to(opt.device_ids[0])
if opt.verbose:
print(kp_detector)
# 创建音频到关键点转换器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的音频到关键点转换器类进行初始化。
audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'])
if torch.cuda.is_available():
audio2kptransformer.to(opt.device_ids[0])
# 创建数据集对象。根据配置文件中的参数,调用相应的数据集类进行初始化。
dataset = FramesWavsDatasetMEL25(is_train=(opt.mode == 'train'), **config['dataset_params'])
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
copy(opt.config, log_dir)
# 根据opt.mode参数的值决定进行训练或其他操作。如果设置为"train",则调用train函数进行模型训练,传递所需的参数。
if opt.mode == 'train':
print("Training...")
train(config, generator, discriminator, kp_detector, audio2kptransformer, opt.checkpoint, log_dir, dataset, opt.device_ids)
training A2KP transformer with all loss :pretrain_a2kp_img.py
大致同上
frames_dataset_transformer25.py:
有三个主要的类:
1.class FramesWavsDatasetMEL25VoxBoxQG2(Dataset):处理包含视频和音频数据的数据集。
-
__getitem__
方法:根据索引获取数据集中的一个样本。- 如果是训练集且满足概率条件,则获取无情感的样本;否则获取中性样本。
- 对于无情感样本,根据视频路径和索引获取相关数据,包括音频、Mel频谱、姿势等信息,并返回一个包含这些信息的字典。
- 对于中性样本,根据视频路径和索引获取相关数据,包括音频、Mel频谱、姿势等信息,并返回一个包含这些信息的字典。
-
其他辅助函数:
get_frame_id
:根据帧路径获取帧的索引。get_window
:根据起始帧、头部姿势和姿势图像获取窗口帧、头部姿势和姿势图像。crop_audio_window
:对音频、姿势和深度特征进行裁剪。_load_tensor
:加载音频数据并转换为张量。getitem_neu
:获取训练集中的中性样本。getitem_vox_woemo
:获取训练集中的无情感样本。
2.class FramesWavsDatasetMEL25VoxBoxQG2ImgAll(Dataset):
ImgAll表示数据集中包含所有的图像样本。
3.class FramesWavsDatasetMEL25VoxBoxQG2ImgPrompt(Dataset):"ImgPrompt" 表示数据集中的样本是根据图像提示或问题进行选择的。
getitem_vox_emo
:获取训练集中具有情感标签的样本。getitem_vox_woemo
:获取训练集中无情感标签的样本。
def getitem_neu(self, idx): 在训练集中获取中性样本数据,并返回一个包含各种特征和路径的字典,以供后续处理和训练使用:
def getitem_neu(self, idx): # 获取训练集中的中性样本
while 1:
idx = idx%len(self.videos) # 对索引进行取余操作,以循环访问视频列表
name = self.videos[idx]
path = os.path.join(self.root_dir, name) # 构建视频路径
video_name = os.path.basename(path) # 获取视频文件名
vsp = video_name.split('_') # 将视频文件名按下划线分割成列表
out = {}
deep_path = f'{mead_path}/deepfeature32/{video_name}.npy' # 构建深度特征路径
deeps = np.load(deep_path) # 加载深度特征数据
wave_path = f'{mead_path}/wav_16000/{video_name}.wav' # 构建音频路径
out['wave_path'] = wave_path # 将音频路径存储到输出字典中
wave_tensor = self._load_tensor(wave_path) # 加载音频数据并转换为张量
if len(wave_tensor.shape) > 1: # 如果音频张量的形状大于1维
wave_tensor = wave_tensor[:, 0] # 只保留第一个通道的数据
mel_tensor = self.to_melspec(wave_tensor) # 将音频张量转换为梅尔频谱图
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mean) / self.std # 对梅尔频谱图进行归一化处理
lable_index = self.emo_label.index(vsp[1]) # 获取情绪标签对应的索引值
# print(out['drivinglmk'].shape)
# out['y_trg'] = self.emo_label.index(vsp[1])
# z_trg = torch.randn(self.latent_dim)
# out['z_trg'] = z_trg
# select gt frames
frames = os.listdir(path) # 获取视频的文件列表。path:视频路径
num_frames = len(frames) # 获取视频帧的数量
num_frames = min(num_frames, len(deeps), mel_tensor.shape[1]) # 取视频帧、深度特征和梅尔频谱图中最小的数量
# 在可选择的帧数不满足要求时,通过递增索引并循环访问视频列表来获取更多的帧。
if num_frames - self.syncnet_T + 1 <= 0: # 如果可选择的帧数不足以满足要求
# print(num_frames)
idx += 1
idx = idx%len(self.videos) # 对索引取余,以循环访问视频列表
continue
# 可选择的帧数满足要求时
frame_idx = np.random.choice(num_frames- self.syncnet_T+1, replace=True, size=1)[0]+1 # 随机选择一个起始帧的索引
choose = join(path, '{:04}.jpg'.format(frame_idx)) # 构建所选帧的完整路径
# driving latent with pretrained
driving_latent = np.load(path.replace('images', 'latent')+'.npy', allow_pickle=True)
he_driving = driving_latent[1] # 获取driving_latent的第二个元素
### poseimg after AntiAliasInterpolation2d: num_frames, 1, 64, 64
fposeimg = gzip.GzipFile(f'{poseimg_path}/poseimg/{video_name}.npy.gz', "r") # 打开poseimg文件
poseimg = np.load(fposeimg)
try:
window_fnames, he_d, poses = self.get_window(choose, he_driving, poseimg)
except:
print(choose, path)
idx += 1
idx = idx%len(self.videos)
continue
out['he_driving'] = he_d
# neutral frames
video_name_neu = vsp[0]+'_neu_1_'+'*' # 构建中性样本视频名称模式
path_neus = path.replace(video_name, video_name_neu) # 构建中性样本视频路径模式
path_neu = random.choice(glob.glob(path_neus))
source_latent = np.load(path_neu.replace('images', 'latent')+'.npy', allow_pickle=True) # 加载中性样本的潜在特征数据。这里使用中性样本的图像路径生成对应的潜在特征路径,并加载该路径下的潜在特征数据。
num_frames_source = source_latent[1]['yaw'].shape[0] # 获取中性样本的帧数,这里以 'yaw' 特征的形状的第一个维度作为帧数。
source_index=np.random.choice(num_frames_source, replace=True, size=1)[0]+1 # 随机选择一个中性样本的帧索引,从 1 到中性样本的帧数。
video_array_source = img_as_float32(io.imread(join(path_neu, '{:04}.jpg'.format(source_index)))) # 加载中性样本的图像数据转换为浮点型数组。
# neutral source latent with pretrained
he_source = {} # 存储中性样本的潜在特征。
for k in source_latent[1].keys(): # 遍历中性样本的潜在特征的键(即特征类型)。
he_source[k] = torch.from_numpy(source_latent[1][k][source_index-1]) # 将中性样本的潜在特征转换为张量并存储到字典中
out['he_source'] = he_source # 将字典 he_source 存储到输出字典 out 的键 'he_source' 中,以保存中性样本的潜在特征。
out['source'] = video_array_source.transpose((2, 0, 1)) # 调整中性样本的图像数据维度顺序,并存储到字典out中
mel, poses_f, deep_frames = self.crop_audio_window(mel_tensor, poses, deeps, choose, num_frames) # 对音频、姿态和深度特征进行裁剪
out['mel'] = mel.unsqueeze(1) # 在梅尔频谱图上增加一个维度,并存储到输出字典中
out['pose'] = poses_f # 存储裁剪后的姿态数据到输出字典中
out['name'] = video_name
out['deep'] = deep_frames
return out
1.2 Emotional Adaptation Training
prompt_st_dp_eam3d.py:
大致同上
2. 数据处理
根据第一帧中检测到的人脸,排除了一些模糊或人脸太小的视频。
视频预处理:preprocess_video.py
from glob import glob
import os
# 使用glob模块获取指定目录下所有的.mp4文件路径
allmp4s = glob('./video/*.mp4')
# 设置目标文件夹路径,并确保该文件夹存在
path_fps25='./video_fps25'
os.makedirs(path_fps25, exist_ok=True)
# 遍历每个.mp4文件
for mp4 in allmp4s:
# 获取文件名(不带路径)
name = os.path.basename(mp4)
# 使用ffmpeg命令将视频转换为25帧每秒的视频,并设置音频参数
os.system(f'ffmpeg -y -i {mp4} -filter:v fps=25 -ac 1 -ar 16000 -crf 10 {path_fps25}/{name}')
# 使用ffmpeg命令将上一步生成的视频转换为.wav格式的音频文件
os.system(f'ffmpeg -y -i {path_fps25}/{name} {path_fps25}/{name[:-4]}.wav')
#============== extract lmk for crop =================
# 提取关键点信息用于裁剪
print('============== extract lmk for crop =================')
os.system(f'python extract_lmks_eat.py {path_fps25}')
#======= extract speech in deepspeech_features =======
# 提取语音特征
print('======= extract speech in deepspeech_features =======')
os.chdir('./deepspeech_features/')
os.system(f'python extract_ds_features.py --input=../{path_fps25}')
os.chdir('../')
os.system('python deepfeature32.py')
#=================== crop videos =====================
# 裁剪视频
print('=================== crop videos =====================')
os.chdir('./vid2vid/')
os.system('python data_preprocess.py --dataset_mode preprocess_eat')
os.chdir('../')
#========== extract latent from cropped videos =======
#从裁剪的视频中提取潜在特征
print('========== extract latent from cropped videos =======')
os.system('python videos2img.py')
os.system('python latent_extractor.py')
#=========== extract poseimg from latent =============
# 从潜在特征中提取姿势图像
print('=========== extract poseimg from latent =============')
os.system('python generate_poseimg.py')
之后,Extract the bbox for training:preprocess/extract_bbox.py
fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda')
#初始化了一个人脸对齐(Face Alignment)模型对象,CUDA加速。
def detect_bbox(img_names):
bboxs = []
for img_name in img_names:
img = img_as_float32(io.imread(img_name)).transpose((2, 0, 1))
img = np.transpose(img[np.newaxis], (0,2,3,1))[...,::-1]
bbox = fa.get_detections_for_batch(img*255)
if bbox is not None:
bboxs.append(bbox[0])
else:
bboxs.append(None)
assert(len(bboxs)==len(img_names))
return bboxs
这个函数用于检测一组图像中的人脸边界框。它接受一个图像文件名列表作为输入,并返回相应的人脸边界框列表。该函数首先加载图像文件,然后将其转换为指定的格式,并调用人脸对齐模型的get_detections_for_batch
方法来获取人脸边界框。如果检测到了人脸边界框,则将其添加到bboxs列表中;否则,将None添加到列表中。
def main(args):
file_images = glob('/data2/gy/lrw/lrw_images/*')
file_images.sort()
p = args.part
t = len(file_images)
for fi in tqdm(file_images[t*p:t*(p+1)]):
out = basename(fi)
outpath =f'/data2/gy/lrw/lrw_bbox/{out}.npy'
if exists(outpath):
continue
images = glob(fi+'/*.jpg')
images.sort()
bboxs = detect_bbox(images)
np.save(outpath, bboxs)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--files", default="*", help="filenames")
parser.add_argument("--part", default="0", type=int, help="part")
args = parser.parse_args()
main(args)
利用人脸对齐模型进行人脸边界框检测,并将结果保存到.npy文件中。
3. 模型结构
augmentation.py
crop_clip
函数用于裁剪视频片段中的帧。根据输入参数 min_h
、min_w
、h
和 w
,分别表示起始高度、起始宽度、裁剪后的高度和宽度。根据 clip[0]
的类型进行判断,如果是 np.ndarray
类型,则使用切片操作裁剪每一帧;如果是 PIL.Image.Image
类型,则使用 crop()
方法裁剪每一帧;否则抛出类型错误异常。
def crop_clip(clip, min_h, min_w, h, w):
if isinstance(clip[0], np.ndarray):
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
cropped = [
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return cropped
pad_clip
函数用于填充视频片段中的帧。根据输入参数 h
和 w
,分别表示期望的高度和宽度。通过获取第一帧的形状信息 im_h
和 im_w
,根据目标尺寸与原始尺寸的比较,计算需要填充的上下左右边界大小,并使用 pad()
方法进行填充。
def pad_clip(clip, h, w):
im_h, im_w = clip[0].shape[:2]
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
resize_clip
函数用于调整视频片段中的帧大小。根据输入参数 size
和 interpolation
,分别表示期望的尺寸和插值方法。根据 clip[0]
的类型进行判断,如果是 np.ndarray
类型,则使用 resize()
方法对每一帧进行调整;如果是 PIL.Image.Image
类型,则使用 resize()
方法对每一帧进行调整;否则抛出类型错误异常。
def resize_clip(clip, size, interpolation='bilinear'):
if isinstance(clip[0], np.ndarray):
if isinstance(size, numbers.Number):
im_h, im_w, im_c = clip[0].shape
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
scaled = [
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
mode='constant', anti_aliasing=True) for img in clip
]
elif isinstance(clip[0], PIL.Image.Image):
if isinstance(size, numbers.Number):
im_w, im_h = clip[0].size
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
if interpolation == 'bilinear':
pil_inter = PIL.Image.NEAREST
else:
pil_inter = PIL.Image.BILINEAR
scaled = [img.resize(size, pil_inter) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return scaled
get_resize_sizes
函数用于根据原始帧的高度和宽度以及期望的尺寸,计算调整后的新尺寸。根据原始帧的长宽比例,通过比较大小来确定应该调整的维度,并计算调整后的新尺寸。
def get_resize_sizes(im_h, im_w, size):
if im_w < im_h:
ow = size
oh = int(size * im_h / im_w)
else:
oh = size
ow = int(size * im_w / im_h)
return oh, ow
class RandomFlip(object):
def __init__(self, time_flip=False, horizontal_flip=False):
self.time_flip = time_flip
self.horizontal_flip = horizontal_flip
def __call__(self, clip):
"""
随机根据给定的参数水平翻转或在时间上翻转输入的视频片段。
参数:
clip (list):表示视频片段的图像或numpy数组的列表。
返回值:
list:翻转后的视频片段。
"""
if random.random() < 0.5 and self.time_flip:
return clip[::-1]
if random.random() < 0.5 and self.horizontal_flip:
return [np.fliplr(img) for img in clip]
return clip
class RandomResize(object):
"""将一组(H x W x C)大小的numpy.ndarray调整为最终大小
原始图像越大,插值的次数越多
参数:
interpolation (str):可以是'nearest'、'bilinear'中的一个,默认为nearest
size (tuple):(宽度, 高度)
"""
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
self.ratio = ratio
self.interpolation = interpolation
def __call__(self, clip):
"""
将输入的视频片段调整为新的大小。
参数:
clip (list):表示视频片段的图像或numpy数组的列表。
返回值:
list:调整大小后的视频片段。
"""
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
new_w = int(im_w * scaling_factor)
new_h = int(im_h * scaling_factor)
new_size = (new_w, new_h)
resized = resize_clip(
clip, new_size, interpolation=self.interpolation)
return resized
class RandomCrop(object):
"""从一组视频中提取相同位置的随机裁剪。
参数:
size (sequence or int):期望的裁剪输出尺寸,格式为(h, w)
"""
def __init__(self, size):
if isinstance(size, numbers.Number):
size = (size, size)
self.size = size
def __call__(self, clip):
"""
从输入的视频片段中提取随机裁剪。
参数:
clip (list):表示视频片段的图像或numpy数组的列表。
返回值:
list:裁剪后的视频片段。
"""
h, w = self.size
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
clip = pad_clip(clip, h, w)
im_h, im_w = clip.shape[1:3]
x1 = 0 if h == im_h else random.randint(0, im_w - w)
y1 = 0 if w == im_w else random.randint(0, im_h - h)
cropped = crop_clip(clip, y1, x1, h, w)
return cropped
class RandomRotation(object):
"""随机旋转整个视频片段,角度在给定范围内。
参数:
degrees (sequence or int):选择角度的范围
如果degrees是一个数字而不是(min, max)形式的序列,
角度的范围将为(-degrees, +degrees)。
"""
def __init__(self, degrees):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError('If degrees is a single number,'
'must be positive')
degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError('If degrees is a sequence,'
'it must be of len 2.')
self.degrees = degrees
def __call__(self, clip):
"""
将输入的视频片段按随机角度旋转。
参数:
clip (list):表示视频片段的图像或numpy数组的列表。
返回值:
list:旋转后的视频片段。
"""
angle = random.uniform(self.degrees[0], self.degrees[1])
if isinstance(clip[0], np.ndarray):
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
rotated = [img.rotate(angle) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return rotated