人脸表情系列:代码阅读——GANimation: Anatomically-aware Facial Animation from a Single Image

本文详细解读了GANimation算法的代码实现,包括数据预处理、模型结构(生成器和判别器)、训练过程以及损失函数。通过分析train.py和相关模块,阐述了如何从单一图像生成具有特定表情的面部动画。
摘要由CSDN通过智能技术生成

代码地址:https://github.com/albertpumarola/GANimation

该表情生成算法是基于Action Unit的,因此首先要有images和其AU标签。根据readme文件,要根据data目录下的prepare_au_annotations.py文件生成pkl文件,该文件包含了AU标签。

直接看训练过程,运行train.py文件,命令行输入参数data_dir,name,batch_size,接下来根据train.py进行分析。

if __name__ == "__main__":
    Train()

首先实例化Train()类,该类定义在train.py文件中,构造器代码如下:

class Train:
    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

首先解析参数,然后分别实例化CustomDatasetDataLoader类的两个对象:data_loader_traindata_loader_test,类CustomDatasetDataLoader定义在文件custom_dataset_data_loader.py中,代码如下:

class CustomDatasetDataLoader:
    def __init__(self, opt, is_for_train=True):
        self._opt = opt
        self._is_for_train = is_for_train
        self._num_threds = opt.n_threads_train if is_for_train else opt.n_threads_test
        self._create_dataset()

    def _create_dataset(self):
        # default='aus'
        self._dataset = DatasetFactory.get_by_name(self._opt.dataset_mode, self._opt, self._is_for_train)
        # 按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入
        self._dataloader = torch.utils.data.DataLoader(
            self._dataset,
            batch_size=self._opt.batch_size,
            shuffle=not self._opt.serial_batches,
            num_workers=int(self._num_threds),
            drop_last=True)

    def load_data(self):
        return self._dataloader

    def __len__(self):
        return len(self._dataset)

根据传入参数is_for_train分别加载训练集和测试集。构造函数最后调用了方法_create_dataset(),该方法通过get_by_name()获取数据然后封装为Tensor,get_by_name()的返回值是类AusDataset的实例,类AusDataset构造函数如下:

class AusDataset(DatasetBase):
    def __init__(self, opt, is_for_train):
        super(AusDataset, self).__init__(opt, is_for_train)
        self._name = 'AusDataset'

        # read dataset
        self._read_dataset_paths()

最后调用了函数_read_dataset_paths():

def _read_dataset_paths(self):
        self._root = self._opt.data_dir
        self._imgs_dir = os.path.join(self._root, self._opt.images_folder)

        # read ids
        use_ids_filename = self._opt.train_ids_file if self._is_for_train else self._opt.test_ids_file
        use_ids_filepath = os.path.join(self._root, use_ids_filename)
        self._ids = self._read_ids(use_ids_filepath)
        
        # read aus
        conds_filepath = os.path.join(self._root, self._opt.aus_file)
        self._conds = self._read_conds(conds_filepath)
        print('self cond:', set(self._conds))
        # print('self ids:', set(self._ids))


        self._ids = list(set(self._ids).intersection(set(self._conds.keys())))
        # print('self ids:', self._ids)

        # dataset size
        self._dataset_size = len(self._ids)
        # print('dataset size:', self._dataset_size)

该函数分别从train_ids.csv,test_ids.csv和aus_openface.pkl文件中读取图像名称/id和AU标签信息。其中图像id的结果self._ids是一个列表,self._conds是一个dict,每一个图像的id作为key,其AU标签作为value,格式为numpy的array。

回到train.py文件,现在通过如下代码获得Tensor形式的数据集,已经按照batch_size划分好:

self._dataset_train = data_loader_train.load_data()
self._dataset_test = data_loader_test.load_data()
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值