Code: https://github.com/albertpumarola/GANimation
This expression-synthesis algorithm is conditioned on Action Units (AUs), so we first need images together with their AU labels. According to the README, the script prepare_au_annotations.py in the data directory generates a pkl file containing those AU labels.
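As a rough illustration of what such an annotations file might look like, here is a hedged sketch: a dict mapping each image id to a numpy array of AU activations, serialized with pickle. The ids, vector length, and values below are made up for illustration; the real file is produced by prepare_au_annotations.py from OpenFace output.

```python
import pickle
import numpy as np

# Hypothetical layout of the AU annotations file: image id -> AU vector.
# (Shortened vectors; real AU label vectors are longer.)
aus = {
    "000001": np.array([0.0, 1.2, 0.3, 2.1]),
    "000002": np.array([0.5, 0.0, 1.8, 0.0]),
}

with open("aus_openface.pkl", "wb") as f:
    pickle.dump(aus, f)

with open("aus_openface.pkl", "rb") as f:
    loaded = pickle.load(f)

print(sorted(loaded.keys()))
```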
Let's go straight to the training procedure. Training is started by running train.py with the command-line arguments data_dir, name, and batch_size; the analysis below follows train.py.
if __name__ == "__main__":
    Train()
This instantiates the Train class, which is defined in train.py; its constructor is:
class Train:
    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
The constructor first parses the options, then instantiates two CustomDatasetDataLoader objects, data_loader_train and data_loader_test. The CustomDatasetDataLoader class is defined in custom_dataset_data_loader.py:
class CustomDatasetDataLoader:
    def __init__(self, opt, is_for_train=True):
        self._opt = opt
        self._is_for_train = is_for_train
        self._num_threds = opt.n_threads_train if is_for_train else opt.n_threads_test
        self._create_dataset()

    def _create_dataset(self):
        # default='aus'
        self._dataset = DatasetFactory.get_by_name(self._opt.dataset_mode, self._opt, self._is_for_train)
        # Batch the dataset into Tensors; they only need to be wrapped as
        # Variables later to serve as model input.
        self._dataloader = torch.utils.data.DataLoader(
            self._dataset,
            batch_size=self._opt.batch_size,
            shuffle=not self._opt.serial_batches,
            num_workers=int(self._num_threds),
            drop_last=True)

    def load_data(self):
        return self._dataloader

    def __len__(self):
        return len(self._dataset)
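The DataLoader wrapping done in _create_dataset() above is standard torch.utils.data.DataLoader behavior. A minimal, self-contained sketch with a toy dataset (not the repo's AusDataset) shows what shuffle, num_workers, and drop_last control:

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Toy stand-in for the dataset returned by DatasetFactory.get_by_name().
class ToyDataset(Dataset):
    def __init__(self, n):
        self._n = n

    def __len__(self):
        return self._n

    def __getitem__(self, idx):
        return torch.tensor([idx], dtype=torch.float32)

# drop_last=True discards the final partial batch (10 items, batch 4 -> 2 batches).
loader = DataLoader(ToyDataset(10), batch_size=4, shuffle=False,
                    num_workers=0, drop_last=True)

batches = list(loader)
print(len(batches), batches[0].shape)
```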
The is_for_train argument selects the training or the test split. The constructor ends by calling _create_dataset(), which obtains the dataset via get_by_name() and wraps it in a DataLoader that batches it into Tensors. With the default dataset_mode 'aus', get_by_name() returns an instance of AusDataset, whose constructor is:
class AusDataset(DatasetBase):
    def __init__(self, opt, is_for_train):
        super(AusDataset, self).__init__(opt, is_for_train)
        self._name = 'AusDataset'

        # read dataset
        self._read_dataset_paths()
It in turn calls _read_dataset_paths():
def _read_dataset_paths(self):
    self._root = self._opt.data_dir
    self._imgs_dir = os.path.join(self._root, self._opt.images_folder)

    # read ids
    use_ids_filename = self._opt.train_ids_file if self._is_for_train else self._opt.test_ids_file
    use_ids_filepath = os.path.join(self._root, use_ids_filename)
    self._ids = self._read_ids(use_ids_filepath)

    # read aus
    conds_filepath = os.path.join(self._root, self._opt.aus_file)
    self._conds = self._read_conds(conds_filepath)
    print('self cond:', set(self._conds))
    # print('self ids:', set(self._ids))
    self._ids = list(set(self._ids).intersection(set(self._conds.keys())))
    # print('self ids:', self._ids)

    # dataset size
    self._dataset_size = len(self._ids)
    # print('dataset size:', self._dataset_size)
This method reads the image names/ids from train_ids.csv (or test_ids.csv) and the AU labels from aus_openface.pkl. The resulting self._ids is a list of image ids, and self._conds is a dict keyed by image id whose values are the AU labels as numpy arrays; self._ids is then restricted to the ids that also have an AU label.
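The id/label bookkeeping can be sketched as follows. The ids and AU vectors below are made up; the point is the set intersection that keeps only ids present in both the split file and the labels dict:

```python
import numpy as np

# Stand-ins for what _read_ids() and _read_conds() might return.
ids = ["000001", "000002", "000003"]       # e.g. parsed from train_ids.csv
conds = {
    "000001": np.array([0.0, 1.2, 0.3]),
    "000003": np.array([0.5, 0.0, 1.8]),
    "000009": np.array([2.0, 0.1, 0.0]),   # labeled, but not in this split
}

# Keep only ids that have both an image entry and an AU label.
usable_ids = list(set(ids).intersection(set(conds.keys())))
dataset_size = len(usable_ids)
print(sorted(usable_ids), dataset_size)
```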
Back in train.py, the following lines now expose the dataset as batched Tensors, already split by batch_size:
self._dataset_train = data_loader_train.load_data()
self._dataset_test = data_loader_test.load_data()
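These loaders are then consumed batch by batch in the training loop. A minimal sketch of that iteration pattern (the tensor shapes and the use of TensorDataset are stand-ins, not taken from the source):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

imgs = torch.randn(8, 3, 4, 4)    # stand-in images
conds = torch.randn(8, 17)        # stand-in AU condition vectors
loader = DataLoader(TensorDataset(imgs, conds), batch_size=4, drop_last=True)

n_batches = 0
for img_batch, cond_batch in loader:
    n_batches += 1                # the forward/backward pass would go here
print(n_batches)
```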