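Original PyTorch implementation on GitHub: https://github.com/yunjey/StarGAN
TensorFlow implementation on GitHub: https://github.com/taki0112/StarGAN-Tensorflow
The two implementations differ little, so this walkthrough follows the PyTorch version.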
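Taking the CelebA dataset as an example, the downloaded data consists of a label (attribute) file and the images.
The first line of the label file is the total number of images, 202599.
The second line lists the attribute names, 40 in total: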
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
The third and every subsequent line holds an image filename followed by its labels for the 40 attributes; each label is 1 or -1:
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
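The file layout described above can be read with a few lines of Python. The following is a minimal sketch, not code from the repository: parse_attr_lines is a hypothetical helper, and the sample uses 3 attributes with made-up values instead of the real 40.

```python
def parse_attr_lines(lines):
    """Parse CelebA-style attribute lines into (count, names, rows).

    Hypothetical helper for illustration; not part of the StarGAN repository.
    """
    count = int(lines[0])            # line 1: total number of images
    names = lines[1].split()         # line 2: attribute names
    rows = {}
    for line in lines[2:]:           # remaining lines: filename followed by 1/-1 labels
        parts = line.split()
        filename, values = parts[0], parts[1:]
        if len(values) != len(names):
            raise ValueError(
                f"{filename}: expected {len(names)} labels, got {len(values)}")
        # Store each label as a boolean: '1' -> True, '-1' -> False.
        rows[filename] = {name: value == '1' for name, value in zip(names, values)}
    return count, names, rows


# Abbreviated sample with 3 attributes instead of 40 (made-up values).
sample = [
    "2",
    "Black_Hair Male Young",
    "000001.jpg -1 -1 1",
    "000002.jpg 1 1 -1",
]
count, names, rows = parse_attr_lines(sample)
print(rows["000001.jpg"])
```

Keying the labels by attribute name rather than by column position makes it harder to pick the wrong column when only a few of the 40 attributes are used.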
1. Data preprocessing
1) Preparing the data
all_attr_names is the full set of 40 attributes.
self.selected_attrs is the set of attributes selected for training; the default is ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'].
def preprocess(self):
    """Preprocess the CelebA attribute file."""
    lines = [line.rstrip() for line in open(self.attr_path, 'r')]
    all_attr_names = lines[1].split()
    # Build the two-way mapping between attribute names and column indices.
    for i, attr_name in enumerate(all_attr_names):
        self.attr2idx[attr_name] = i
        self.idx2attr[i] = attr_name

    lines = lines[2:]
    random.seed(1234)
    random.shuffle(lines)  # shuffle the images (fixed seed, so the split is reproducible)
    for i, line in enumerate(lines):
        split = line.split()
        filename = split[0]  # image filename
        values = split[1:]  # the 40 labels of this image

        label = []
        for attr_name in self.selected_attrs:  # keep only the attributes selected for training
            idx = self.attr2idx[attr_name]
            label.append(values[idx] == '1')  # True if the label is '1', False if it is '-1'

        if (i+1) < 2000:  # the first 1999 shuffled images (i = 0..1998) form the test set
            self.test_dataset.append([filename, label])
        else:
            self.train_dataset.append([filename, label])

    print('Finished preprocessing the CelebA dataset...')
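To see what the shuffle-and-split step produces, here is a sketch that reproduces only the split condition on synthetic filenames. split_lines is a hypothetical standalone helper; the threshold 2000 and the seed 1234 are taken from the code above. Because (i+1) < 2000 holds for i = 0 through 1998, the test set ends up with 1999 entries rather than 2000.

```python
import random


def split_lines(lines, test_threshold=2000, seed=1234):
    """Shuffle lines with a fixed seed and split them with the condition used above.

    Hypothetical standalone helper; it mirrors only the split condition.
    """
    lines = list(lines)               # copy so the caller's list is not reordered
    rng = random.Random(seed)         # local RNG instead of seeding the global one
    rng.shuffle(lines)
    test, train = [], []
    for i, line in enumerate(lines):
        if (i + 1) < test_threshold:  # true for i = 0 .. test_threshold - 2
            test.append(line)
        else:
            train.append(line)
    return train, test


# Synthetic filenames standing in for the 202599 CelebA entries.
names = [f"{n:06d}.jpg" for n in range(1, 202600)]
train, test = split_lines(names)
print(len(test), len(train))  # 1999 200600
```

If exactly 2000 test images are wanted, the condition would have to be i < 2000; the sketch keeps the original condition to show its actual effect.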
2) Creating a data loader
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())  # random horizontal flip (training only)
    transform.append(T.CenterCrop(crop_size))  # crop a crop_size x crop_size square from the center
    transform.append(T.Resize(image_size))  # resize the image to image_size
    transform.append(T.ToTensor())  # convert to a tensor with values in [0, 1]
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))  # normalize each channel to [-1, 1]
    transform = T.Compose(transform)

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    elif dataset == 'RaFD':
        dataset = ImageFolder(image_dir, transform)

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode=='train'),  # shuffle only during training
                                  num_workers=num_workers)
    return data_loader
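T.ToTensor() scales 8-bit pixel values to [0, 1], and T.Normalize with mean 0.5 and std 0.5 then applies (x - mean) / std per channel, which maps them to [-1, 1]. The arithmetic can be sketched in plain Python as follows; normalize and denormalize are hypothetical helper names, and the inverse mapping is the kind of step needed before generated images are saved or displayed.

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a value in [0, 1] to [-1, 1], the per-channel formula (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping from [-1, 1] back to [0, 1], clamped to the valid range."""
    x = y * std + mean
    return min(max(x, 0.0), 1.0)


for pixel in (0, 128, 255):
    scaled = pixel / 255              # scaling of an 8-bit pixel value to [0, 1]
    print(pixel, round(normalize(scaled), 4))
```

A range of [-1, 1] is a common choice when the generator ends in a tanh activation, since its output then covers the same range as the real images.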
2. Building the network
def build_model(self):
    """Create a generator and a discriminator."