Git repository: https://github.com/TMElyralab/MuseTalk
1. Understanding the Model
1.1 Training setup
The pipeline resembles Stable Diffusion: the original face image and the lip-masked face image are first encoded into latents by a pre-trained encoder and concatenated; a UNet then predicts the output generated at each step, and the audio features (extracted with ./models/whisper/tiny.pt) are injected into training as a conditioning signal.
1.2 Loss
- $L_1$: the distance between the predicted latent and the latent of the ground-truth image.
- $L_2$: the difference between the original image and the generated image.
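Below is a minimal sketch of how this conditioning and the two losses could be wired up. It is my own illustration, not the repository's training code: the channel layout (masked latent + reference latent concatenated along the channel axis), the use of the diffusers UNet/VAE classes, the fixed timestep, and the MSE stand-in for the image-space loss are all assumptions.

# Sketch only: assumed 4-channel SD VAE latents, Whisper features of dim 384 fed via cross-attention.
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, UNet2DConditionModel

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
unet = UNet2DConditionModel(
    sample_size=32,           # 256x256 face crop -> 32x32 latent
    in_channels=8,            # 4 (lip-masked latent) + 4 (reference latent)
    out_channels=4,
    cross_attention_dim=384,  # Whisper tiny embedding size
)

face = torch.rand(1, 3, 256, 256) * 2 - 1     # ground-truth face crop, scaled to [-1, 1]
ref = torch.rand(1, 3, 256, 256) * 2 - 1      # reference frame of the same identity (assumed)
masked = face.clone()
masked[:, :, 128:, :] = 0                     # zero out the lower half (lip region)

with torch.no_grad():
    z_gt = vae.encode(face).latent_dist.mode()        # (1, 4, 32, 32)
    z_masked = vae.encode(masked).latent_dist.mode()
    z_ref = vae.encode(ref).latent_dist.mode()

audio_cond = torch.randn(1, 50, 384)          # one whisper chunk (50 x 384), see section 3
timesteps = torch.zeros(1, dtype=torch.long)  # a single fixed timestep

pred_latents = unet(torch.cat([z_masked, z_ref], dim=1), timesteps,
                    encoder_hidden_states=audio_cond).sample
loss_l1 = F.l1_loss(pred_latents, z_gt)       # L1: latent-space reconstruction loss
recon = vae.decode(pred_latents).sample
loss_l2 = F.mse_loss(recon, face)             # L2: image-space loss (MSE used as a stand-in here)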
2. Inference Code Walkthrough
2.1 The main function
Overall flow:
- Input: a source video (or image) and an audio file
- For the audio, use the Audio2Feature class to encode it and split it into chunks (video frame i corresponds to the audio chunk [i-2, i+2])
- If the input is a video: split it into frames, extract the face landmarks, and compute the face bounding box (the lower half of the face is masked manually)
- The model generates every frame
- Each generated crop is composited back into the original background frame
- FFmpeg assembles the frames into the output video
def main(args):
    global pe
    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    # {'task_0': {'video_path': './data/video/yongen.mp4', 'audio_path': './data/audio/yongen.wav'}, 'task_1': {'video_path': './data/video/sun.mp4', 'audio_path': './data/audio/sun.wav', 'bbox_shift': -7}}
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            # e.g. './results/yongen_yongen.mp4'
            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################################## extract frames from source video ##############################################
        # If the input is a video, split it into frames
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            # Extract every frame of the video; total frame count = fps * duration (s)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        # print(input_img_list)
        ############################################## extract audio feature ##############################################
        # Encode the audio with the Audio2Feature class
        whisper_feature = audio_processor.audio2feat(audio_path)
        # Split the audio features into per-frame chunks
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        ############################################## preprocess input image ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            # Get the landmarks of every frame; keep landmarks 24 to 91 and the face bounding box
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        i = 0
        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            # Crop the face region
            crop_frame = frame[y1:y2, x1:x2]
            # Resize the cropped frame to 256x256 with OpenCV, using LANCZOS4 interpolation
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            # Obtain the latent representation of the crop via the VAE's get_latents_for_unet method
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        ############################################## inference batch by batch ##############################################
        print("start inference")
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
        res_frame_list = []
        for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
            # Convert the batch of audio features to a PyTorch tensor and move it to the target device
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)  # torch, B, 5*N, 384
            audio_feature_batch = pe(audio_feature_batch)

            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            # Predict the latents of the reconstructed frames
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            # Decode the latents back into frames
            recon = vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_list.append(res_frame)
        ############################################## pad to full image ##############################################
        print("pad talking image to original video")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            # Get the bounding box and the corresponding original frame for the current result frame
            bbox = coord_list_cycle[i % (len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                # Resize the generated frame so it matches the bounding-box size in the original frame
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except:
                # print(bbox)
                continue
            # Blend the original frame and the generated frame with get_image to obtain the final composite
            combine_frame = get_image(ori_frame, res_frame, bbox)
            # Write the composited frame to the temporary image folder, named with an 8-digit index
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        # Use FFmpeg to turn the composited frames into a temporary video file temp.mp4
        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        print(cmd_img2video)
        os.system(cmd_img2video)

        # Use FFmpeg to mux the original audio with the temporary video into the final output_vid_name
        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        print(cmd_combine_audio)
        os.system(cmd_combine_audio)

        os.remove("temp.mp4")
        shutil.rmtree(result_img_save_path)
        print(f"result is save to {output_vid_name}")
3. Extracting Audio Features with Whisper
3.1 How Audio2Feature extracts the audio features
Whisper is used to extract the audio features. The overall flow is shown below; the role of each function is explained in the following subsections.
class Audio2Feature():
    def __init__(self,
                 whisper_model_type="tiny",
                 model_path="./models/whisper/tiny.pt"):
        self.whisper_model_type = whisper_model_type
        self.model = load_model(model_path)

    def get_sliced_feature(self,
                           feature_array,
                           vid_idx,
                           audio_feat_length=[2, 2],
                           fps=25):
        """
        Get sliced features based on a given index
        :param feature_array:
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return:
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        center_idx = int(vid_idx*50/fps)
        left_idx = center_idx - audio_feat_length[0]*2
        right_idx = center_idx + (audio_feat_length[1]+1)*2
        for idx in range(left_idx, right_idx):
            idx = max(0, idx)
            idx = min(length-1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 384)  # 50*384
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
        """
        Get sliced features based on a given index
        :param feature_array:
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return:
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        for dt in range(-audio_feat_length[0], audio_feat_length[1]+1):
            left_idx = int((vid_idx+dt)*50/fps)
            if left_idx < 1 or left_idx > length-1:
                left_idx = max(0, left_idx)
                left_idx = min(length-1, left_idx)
                x = feature_array[left_idx]
                x = x[np.newaxis, :, :]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx-1:left_idx+1]
                selected_feature.append(x)
                selected_idx.append(left_idx-1)
                selected_idx.append(left_idx)
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 384)  # 50*384
        return selected_feature, selected_idx

    def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
        whisper_chunks = []
        whisper_idx_multiplier = 50./fps
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")
        while 1:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps)
            # print(f"i:{i},selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx > len(feature_array):
                break
        return whisper_chunks

    def audio2feat(self, audio_path):
        # get the sample rate of the audio
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result['segments']:
            encoder_embeddings = emb['encoder_embeddings']
            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb['start'])
            end_idx = int(emb['end'])
            emb_end_idx = int((end_idx - start_idx)/2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        concatenated_array = np.concatenate(embed_list, axis=0)
        return concatenated_array
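A quick usage sketch, assuming the class above and its load_model dependency are importable; "demo.wav" is a placeholder path:

audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
whisper_feature = audio_processor.audio2feat("demo.wav")      # shape roughly (T, 5, 384)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=25)
print(len(whisper_chunks), whisper_chunks[0].shape)           # roughly one chunk per video frame, each (50, 384)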
3.2 The audio is first converted into a log-Mel spectrogram
3.2.1 Use FFmpeg to convert the audio file to 16 kHz mono, then turn it into a NumPy array
def load_audio(file: str, sr: int = SAMPLE_RATE):
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        # "-" writes to stdout; format="s16le" selects 16-bit little-endian PCM output,
        # acodec="pcm_s16le" sets the codec, ac=1 requests mono, and ar=sr sets the sample rate.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # np.int16 holds 16-bit signed integers; 16-bit PCM samples live in [-32768, 32767].
    # np.frombuffer(out, np.int16) converts the byte buffer into an int16 NumPy array,
    # flatten() makes it one-dimensional, astype(np.float32) casts to float32,
    # and dividing by 32768.0 normalizes the values into the float range [-1.0, 1.0).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
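A tiny worked example of the int16-to-float normalization; the byte string is made up for illustration:

import numpy as np

# Two little-endian int16 samples: 0x8000 = -32768 and 0x7FFF = 32767.
raw = b"\x00\x80\xff\x7f"
samples = np.frombuffer(raw, np.int16).flatten().astype(np.float32) / 32768.0
print(samples)   # [-1.         0.9999695]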
3.2.2 Compute the log-Mel spectrogram
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
    """Compute the log-Mel spectrogram of the input audio."""
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    # Create a Hann window of length N_FFT (400)
    window = torch.hann_window(N_FFT).to(audio.device)
    # Short-time Fourier transform (STFT): N_FFT (400) is the FFT window size,
    # HOP_LENGTH (160) is the hop between windows, window is the Hann window,
    # return_complex=True returns complex values.
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Squared magnitudes of the STFT, dropping the last frame
    magnitudes = stft[:, :-1].abs() ** 2

    # Build the mel filterbank with n_mels (80) filters
    filters = mel_filters(audio.device, n_mels)
    # Apply the mel filterbank to the squared magnitudes to obtain the mel spectrogram
    mel_spec = filters @ magnitudes

    # Log-compress the mel spectrogram, clamp the dynamic range to 8 decades, and rescale
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
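As a quick sanity check on the shapes (constants as in Whisper: 16 kHz sampling, N_FFT=400, HOP_LENGTH=160, 80 mel bins), 8 seconds of audio gives 128000 samples and a mel spectrogram of shape [80, 800]:

import torch

SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS = 16000, 400, 160, 80
audio = torch.zeros(8 * SAMPLE_RATE)                  # 8 s -> 128000 samples
stft = torch.stft(audio, N_FFT, HOP_LENGTH,
                  window=torch.hann_window(N_FFT), return_complex=True)
print(stft.shape)                                     # torch.Size([201, 801])
magnitudes = stft[:, :-1].abs() ** 2                  # drop the last frame -> [201, 800]
# multiplying by the 80 mel filters (shape [80, 201]) then yields [80, 800]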
3.2.3 Extract the audio features
audio [128000] -> Mel [80, 800] -> pad_or_trim [1, 80, 3000] -> audio features (encoder_embeddings) [1, 5, 1500, 384] -> whisper_feature [400, 5, 384]
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    # Pad or trim the array along the given axis so that it reaches the requested length
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis, index=torch.arange(length))
        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)
        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)
    return array
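For example (assuming the pad_or_trim above is in scope), padding the [80, 800] mel spectrogram up to Whisper's expected 3000 frames:

import numpy as np

mel = np.zeros((80, 800), dtype=np.float32)   # 8 s mel spectrogram from the previous step
padded = pad_or_trim(mel, length=3000)        # zero-pad along the last axis
print(padded.shape)                           # (80, 3000); a batch dimension is added later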
3.3 Audio chunking
3.3.1 feature2chunks
- Use the video frame rate fps and the assumed audio feature rate (50 FPS) to compute the audio feature index for each video frame.
- Use the get_sliced_feature method to extract a fixed-length feature block from the feature array.
- Append each extracted block to the whisper_chunks list until all feature data has been consumed.
def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
    whisper_chunks = []
    # Convert video frame indices into audio feature indices; the audio features are assumed to run at 50 FPS
    whisper_idx_multiplier = 50./fps
    i = 0
    print(f"video in {fps} FPS, audio idx in 50FPS")
    while 1:
        start_idx = int(i * whisper_idx_multiplier)
        # Extract one feature block and its indices from feature_array
        selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps)
        # print(f"i:{i},selected_idx {selected_idx}")
        whisper_chunks.append(selected_feature)
        i += 1
        if start_idx > len(feature_array):
            break
    return whisper_chunks
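A quick check of the chunking behaviour on dummy data (my own snippet; it bypasses __init__ so no Whisper checkpoint is needed, and assumes the Audio2Feature class from section 3.1 plus numpy are in scope):

import numpy as np

proc = Audio2Feature.__new__(Audio2Feature)                    # skip loading the model
whisper_feature = np.zeros((400, 5, 384), dtype=np.float32)    # ~8 s of features at 50 FPS
chunks = proc.feature2chunks(feature_array=whisper_feature, fps=25)
print(len(chunks), chunks[0].shape)   # ~200 chunks (roughly one per video frame), each (50, 384)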
3.3.2 get_sliced_feature
- Compute the audio-feature center index corresponding to the given video frame index.
- Extract a window from the feature array according to the specified audio feature length.
- Return the extracted feature block together with its indices.
def get_sliced_feature(self,
                       feature_array,
                       vid_idx,
                       audio_feat_length=[2, 2],
                       fps=25):
    """
    Get sliced features based on a given index
    :param feature_array:
    :param start_idx: the start index of the feature
    :param audio_feat_length:
    :return:
    """
    length = len(feature_array)
    selected_feature = []
    selected_idx = []
    center_idx = int(vid_idx*50/fps)
    # Compute the left and right boundaries left_idx and right_idx of the audio feature window
    left_idx = center_idx - audio_feat_length[0]*2
    right_idx = center_idx + (audio_feat_length[1]+1)*2
    # Walk from left_idx to right_idx, clamp each index into the valid range, and collect the features
    for idx in range(left_idx, right_idx):
        idx = max(0, idx)          # never below 0
        idx = min(length-1, idx)   # never beyond the last valid index
        x = feature_array[idx]
        selected_feature.append(x)
        selected_idx.append(idx)
    selected_feature = np.concatenate(selected_feature, axis=0)
    # Reshape the concatenated features to (-1, 384), i.e. rows of 384-dim vectors
    selected_feature = selected_feature.reshape(-1, 384)  # 50*384
    return selected_feature, selected_idx
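A standalone worked example of the index arithmetic on dummy data (it does not call the class): at fps=25 the audio features run at 50 "frames" per second, so video frame i maps to audio index 2*i.

import numpy as np

feature_array = np.zeros((400, 5, 384), dtype=np.float32)   # dummy whisper_feature
vid_idx, fps, audio_feat_length = 3, 25, [2, 2]

center_idx = int(vid_idx * 50 / fps)                        # 6: frame 3 -> audio index 6
left_idx = center_idx - audio_feat_length[0] * 2            # 2
right_idx = center_idx + (audio_feat_length[1] + 1) * 2     # 12
idxs = [min(len(feature_array) - 1, max(0, i)) for i in range(left_idx, right_idx)]
print(idxs)                                                 # [2, 3, ..., 11] -> 10 indices
chunk = np.concatenate([feature_array[i] for i in idxs], axis=0).reshape(-1, 384)
print(chunk.shape)                                          # (50, 384): 10 indices x 5 rows each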
4. Generation results
Because the frames are generated one by one and then assembled into a video, the result can be compared frame by frame.
Left: ground truth; right: generated result.