Core references
[Code] Pretrained Anime StyleGAN2 — convert to pytorch and editing images by encoder
[Theory] http://www.seeprettyface.com/research_notes.html#step6
Code for this post
0 Overview
Based on the stylegan2-pytorch port of a pretrained anime model, combined with a tag classifier and the labels it produces, we can edit certain attributes of generated anime portraits.
Edited images: the middle picture (the 4th) is the original; interpolating to the left closes the mouth, interpolating to the right opens it.
1. Download the project
git clone https://github.com/viuts/stylegan2_pytorch.git
1.1 Download the anime-face model
2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl
- Google Drive
- Baidu Netdisk: https://pan.baidu.com/s/1yvrPQVGOlAgr8PczI-Z_Ww?pwd=0415 (extraction code: 0415)
- Other download links
1.2 Convert the TF model to PyTorch
cd your_path/stylegan2_pytorch
python run_convert_from_tf.py --input=2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl --output checkpoint
1.3 Run inference to generate images
python run_generator.py generate_images --network=checkpoint/Gs.pth \
--seeds=1-5 --truncation_psi=1.0
- Generated images are saved under stylegan2_pytorch/results; a programmatic alternative is sketched below.
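Alternatively, the converted generator can be loaded directly in Python. A minimal sketch, assuming the same stylegan2.models.load API that the appendix scripts use:

import torch
import stylegan2  # from the cloned stylegan2_pytorch repo

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
G = stylegan2.models.load('checkpoint/Gs.pth', map_location=device)
G.to(device)

torch.manual_seed(1)
qlatents = torch.randn(1, G.latent_size, device=device)  # one random z latent
with torch.no_grad():
    generated = G(qlatents)                # output roughly in [-1, 1]
image = (generated.clamp(-1, 1) + 1) / 2   # rescale to [0, 1] for saving/display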
2. Obtain anime attribute labels
Illustration: web demo of the tagger.
2.1 Test the local tag classifier
python edit_convert_cntk_2_onnx.py
- Code: see the appendix (edit_convert_cntk_2_onnx.py). A standalone tagging example follows.
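To tag an arbitrary image with the exported model, something like the following should work; the 299x299 input size, [0, 1] pixel range, and 0.5 threshold mirror the appendix scripts, and test.png is a hypothetical input file:

import numpy as np
import onnxruntime
from PIL import Image

session = onnxruntime.InferenceSession('checkpoint/model.onnx')
with open('checkpoint/tags.txt') as f:
    tags = [line.strip() for line in f if line.strip()]

img = Image.open('test.png').convert('RGB').resize((299, 299))  # hypothetical input
x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0  # 1x3x299x299 in [0, 1]

[scores] = session.run(None, {session.get_inputs()[0].name: x})
print([tags[i] for i, s in enumerate(scores[0]) if s > 0.5])  # tags above threshold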
2.2 Tagging
- Generate 4 x 5000 images and obtain labels for each one.
- Save dlatents, labels, and tags, where tags maps each label index to its human-readable name.
- Time taken: about 2 hours 40 minutes.
- A sanity check on the saved file is sketched below.
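A quick check on the saved file; the keys and shapes follow edit_generate-labeled-anime-data.py in the appendix (batch 4 x 5000 iterations = 20000 samples):

import torch

state = torch.load('checkpoint/latents.pth', map_location='cpu')
print(state['qlatents_data'].shape)  # (20000, 512): sampled z latents
print(state['dlatents_data'].shape)  # (20000, 16, 512): mapped w latents
print(state['labels_data'].shape)    # (20000, n_tags): classifier scores per image
print(len(state['tags']))            # one name per label index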
3 Train latent-code directions for editing
3.0 Theoretical basis
As illustrated below, take the median of a tag's scores as the threshold: scores below it become 0 and scores above it become 1. Then build the objective w·x + b = y and solve this binary classification problem with logistic regression; the learned weight vector w approximates the direction vector we need. A compact sketch of the idea follows.
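The same recipe in a few lines, as a sketch using scikit-learn's LogisticRegression instead of the appendix's PyTorch training loop (both fit the same w); it assumes the latents.pth produced in section 2.2:

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

state = torch.load('checkpoint/latents.pth', map_location='cpu')
X = state['dlatents_data'].reshape(-1, 16 * 512).numpy()  # flattened w latents
[idx], = np.where(state['tags'] == 'pink_hair')           # any tag works here
scores = state['labels_data'][:, idx].numpy()

y = (scores > np.median(scores)).astype(np.int64)  # binarize at the median
clf = LogisticRegression(max_iter=1000).fit(X, y)
direction = clf.coef_.reshape(16, 512)             # w approximates the edit direction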
3.1 Example: editing hair color
Code: edit_find_latend_direction.py (see the appendix).
- Training results
3.1.1 Result 1: pink hair
- The middle image is the original; interpolation runs from non-pink on the left to pink on the right.
3.1.2 Result 2: black hair
3.1.3 Result 3: brown hair
3.2 Possible issues
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Fix: move all tensors to a single device (either all CPU or all GPU), as in the sketch below.
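For example, move the model and every tensor that feeds it onto one device; a fragment assuming G, dlatents, and direction from the appendix script:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
G = G.to(device)                  # the generator...
dlatents = dlatents.to(device)    # ...and every tensor fed into it
direction = direction.to(device)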
One more thing
Editing other pretrained models (not open-sourced): for example, interpolating from short hair to long hair.
Other references
Appendix
edit_convert_cntk_2_onnx.py
import cntk as C
import numpy as np
import onnxruntime

def convert():
    # Load the CNTK tagger and re-export it in ONNX format.
    model_path = 'checkpoint/danbooru-resnet_custom_v1-p3/model.cntk'
    cntk_model = C.load_model(model_path)
    cntk_model.save('checkpoint/model.onnx', format=C.ModelFormat.ONNX)
    return cntk_model

def test_cntk(cntk_model):
    # Check that ONNX Runtime and CNTK agree on a random input.
    ort_session = onnxruntime.InferenceSession('checkpoint/model.onnx')
    x = np.random.rand(1, 3, 299, 299).astype(np.float32)
    # compute the ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: x}
    ort_outs = ort_session.run(None, ort_inputs)
    # compute the CNTK output
    cntk_out = cntk_model.eval(x)
    np.testing.assert_allclose(np.array(cntk_out), ort_outs[0], rtol=1e-03, atol=1e-05)

if __name__ == '__main__':
    cntk_model = convert()
    test_cntk(cntk_model)
edit_generate-labeled-anime-data.py
import os
import onnxruntime
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

input_path = 'checkpoint'
tags_path = os.path.join(input_path, 'tags.txt')
model_path = os.path.join(input_path, 'model.onnx')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 4
seed = 0

# ONNX tagger exported by edit_convert_cntk_2_onnx.py
C = onnxruntime.InferenceSession(model_path)
with open(tags_path, 'r') as tags_stream:
    tags = np.array([tag for tag in (tag.strip() for tag in tags_stream) if tag])

import stylegan2
from stylegan2 import utils

G = stylegan2.models.load(generator_path, map_location=device)
G.to(device)

def to_image_tensor(image_tensor, pixel_min=-1, pixel_max=1):
    if pixel_min != 0 or pixel_max != 1:
        image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
    return image_tensor.clamp(min=0, max=1)

# Run one image first to check that everything works.
torch.manual_seed(seed)
qlatents = torch.randn(1, G.latent_size).to(device=device, dtype=torch.float32)
generated = G(qlatents)
images = to_image_tensor(generated)
# 299 is the input size of the tagger
images = F.interpolate(images, size=(299, 299), mode='bilinear')
ort_inputs = {C.get_inputs()[0].name: images.detach().cpu().numpy()}
[predicted_labels] = C.run(None, ort_inputs)
# print out some tags
plt.imshow(images[0].detach().cpu().permute(1, 2, 0))
labels = [tags[i] for i, score in enumerate(predicted_labels[0]) if score > 0.5]
print(labels)

# reset seed
torch.manual_seed(seed)
iteration = 5000
progress = utils.ProgressWriter(iteration)
progress.write('Generating images...', step=False)

# Accumulators, grown batch by batch.
qlatents_data = torch.Tensor(0, G.latent_size).to(device=device, dtype=torch.float32)
dlatents_data = torch.Tensor(0, 16, G.latent_size).to(device=device, dtype=torch.float32)
labels_data = torch.Tensor(0, len(tags)).to(device=device, dtype=torch.float32)

for i in range(iteration):
    qlatents = torch.randn(batch_size, G.latent_size).to(device=device, dtype=torch.float32)
    with torch.no_grad():
        generated, dlatents = G(latents=qlatents, return_dlatents=True)
    generated = to_image_tensor(generated)
    # resize to the tagger's 299 x 299 input
    images = F.interpolate(generated, size=(299, 299), mode='bilinear')
    labels = []
    # the tagger does not take batched input, feed one image at a time
    for image in images:
        ort_inputs = {C.get_inputs()[0].name: image.reshape(1, 3, 299, 299).detach().cpu().numpy()}
        [[predicted_labels]] = C.run(None, ort_inputs)
        labels.append(predicted_labels)
    # store the results
    labels_tensor = torch.Tensor(labels).to(device=device, dtype=torch.float32)
    qlatents_data = torch.cat((qlatents_data, qlatents))
    dlatents_data = torch.cat((dlatents_data, dlatents))
    labels_data = torch.cat((labels_data, labels_tensor))
    progress.step()

progress.write('Done!', step=False)
progress.close()

# Save under checkpoint/ so edit_find_latend_direction.py can find it.
torch.save({
    'qlatents_data': qlatents_data.cpu(),
    'dlatents_data': dlatents_data.cpu(),
    'labels_data': labels_data.cpu(),
    'tags': tags,
}, 'checkpoint/latents.pth')
edit_find_latend_direction.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pylab as plt
import stylegan2

input_path = 'checkpoint'
latents_path = os.path.join(input_path, 'latents.pth')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# adjust to your hardware
batch_size = 16
seed = 0

state = torch.load(latents_path, map_location=device)
qlatents_data = state['qlatents_data']
dlatents_data = state['dlatents_data']
labels_data = state['labels_data']
tags = state['tags']

G = stylegan2.models.load(generator_path).to(device)
dlatents_data = dlatents_data.to(device=device, dtype=torch.float32)
labels_data = labels_data.to(device=device, dtype=torch.float32)
print('dlatents_data.size()', dlatents_data.size())
print('labels_data.size()', labels_data.size())

# 70% train / 20% valid / 10% test split over (dlatents, labels) pairs
zipped = list(zip(dlatents_data, labels_data))
train_size = int(0.7 * len(zipped))
valid_size = int(0.2 * len(zipped))
test_size = len(zipped) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    zipped, [train_size, valid_size, test_size])
# the reference code used num_workers=4, which can error; adjust to your setup
datasets = dict(
    train=torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    valid=torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    test=test_dataset,
)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # logistic regression over the flattened 16 x 512 dlatent
        self.main = nn.Sequential(
            nn.Linear(in_features=16 * 512, out_features=1),
            nn.LeakyReLU(0.2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)

def train_coeff(tag, total=5):
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()
    [tag_index], = np.where(tags == tag)
    epoch = 0
    epoch_val_loss_min = 100
    while True:
        epoch += 1
        for phase in ['train', 'valid']:
            dataset = datasets[phase]
            if phase == 'train':
                model.train()  # set model to training mode
            else:
                model.eval()  # set model to evaluation mode
            running_loss = 0.0
            for (dlatents, labels) in dataset:
                optimizer.zero_grad()  # zero the gradient buffers
                with torch.set_grad_enabled(phase == 'train'):
                    inputs = dlatents.reshape(-1, 16 * 512).to(device)
                    output = model(inputs)
                    # build the target column from this tag's classifier scores
                    targets = torch.Tensor(0, 1).to(device)
                    for label in labels:
                        value = label[tag_index]
                        # value = 1.0 if value > 0.5 else 0.0
                        new_label = torch.Tensor([[value]]).to(device)
                        targets = torch.cat((targets, new_label))
                    loss = criterion(output, targets)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()  # does the update
                # statistics
                running_loss += loss.item() * inputs.size(0)
            epoch_loss = running_loss / (len(dataset) * batch_size)
            print(f'Epoch:{epoch}/{total}, {phase} Loss: {epoch_loss:.4f}')
            # keep the weights with the best validation loss
            if phase == 'valid':
                if epoch_loss < epoch_val_loss_min:
                    epoch_val_loss_min = epoch_loss
                    weight_val_min = model.state_dict()['main.0.weight']  # not bias
                    direction_path = f'checkpoint/directions_{tag}_val{epoch_loss}.pth'
                    # torch.save(weight_val_min, direction_path)
                    print(f'the best val is: {epoch_loss}')
        if epoch == total:
            break
    weight = weight_val_min
    return weight.detach().cpu().reshape(16, 512)
def generate_image(dlatents, pixel_min=-1, pixel_max=1):
    generated = G(dlatents=dlatents)
    if pixel_min != 0 or pixel_max != 1:
        generated = (generated - pixel_min) / (pixel_max - pixel_min)
    generated.clamp_(min=0, max=1)
    return generated.detach().cpu().reshape(3, 512, 512).permute(1, 2, 0)

def move_and_show(latent_vector, direction, coeffs):
    # shift the latent along the direction, one image per coefficient
    img_list = []
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector.clone()
        new_latent_vector[:8] = (latent_vector + coeff * direction)[:8]
        img = generate_image(new_latent_vector)
        img_list.append(img)
    return img_list

# core: visualize a few test samples edited across a range of coefficients
def move_and_show_samples(direction, direction_name, sample=3, coeffs=[-10, -5, -2, 0, 2, 5, 10]):
    fig, ax = plt.subplots(sample, 1, figsize=(50, 50), dpi=80)
    for i, (latents, labels) in enumerate(list(datasets['test'])[:sample]):
        inputs = latents.clone().reshape(1, 16, 512)
        direction = direction.to(device)
        img_list = move_and_show(inputs, direction, coeffs)
        ax[i].imshow(np.hstack(img_list))
    plt.suptitle(f'Edit: {direction_name}', size=16)
    [x.axis('off') for x in ax]  # hide the axes
    plt.tight_layout()  # let the images fill the figure
    save_folder = './edit/'
    os.makedirs(save_folder, exist_ok=True)
    plt.savefig(f'{save_folder}/{direction_name}combine_{sample}.png')
if __name__ == '__main__':
    result = {}
    # training epochs per tag; 3 is usually enough
    flag_train = 1
    direction_save_path = 'checkpoint/direction.pth'
    if flag_train:
        picked_tags = ['black_hair', 'pink_hair', 'open_mouth', 'brown_hair']
        # keep only tags the classifier actually knows
        picked_tags = [tag for tag in picked_tags if tag in tags]
        print(picked_tags)
        for tag in picked_tags:
            print(f'training {tag}')
            result[tag] = train_coeff(tag, 3)
        # save all trained directions
        torch.save(result, direction_save_path)
    else:
        result = torch.load(direction_save_path, map_location=device)
    # visualize the edit results
    for name in result.keys():
        move_and_show_samples(result[name], name, sample=5)
'''
## Extra: pick more tags and train them all
colors = ['aqua', 'black', 'blue', 'brown', 'green', 'grey', 'lavender', 'light_brown', 'multicolored', 'orange',
          'pink', 'purple', 'red', 'silver', 'white', 'yellow']
switches = ['open', 'closed', 'covered']
# generate compositions of elements
components = ['eyes', 'hair', 'mouth']
picked_tags = []
for component in components:
    picked_tags = picked_tags + [f'{color}_{component}' for color in colors]
    picked_tags = picked_tags + [f'{switch}_{component}' for switch in switches]
# keep only tags the classifier actually knows
picked_tags = [tag for tag in picked_tags if tag in tags]
print(picked_tags)
## Train all these tags!
for tag in picked_tags:
    print(f'training {tag}')
    result[tag] = train_coeff(tag, 3)
# try some of them
move_and_show_samples(result['open_mouth'], 'open_mouth')
# play a bit more: train a character-specific direction?
charas = ['hakurei_reimu', 'kirisame_marisa']
# check how many samples we have for each character
for chara in charas:
    [chara_index], = np.where(tags == chara)
    count = [x[chara_index] for x in labels_data if x[chara_index] > 0.5]
    print(f'{chara}: {len(count)}, {(len(count) / len(labels_data)) * 100}%')
    result[chara] = train_coeff(chara, 3)
# too rare, probably won't work
move_and_show_samples(result['hakurei_reimu'], 'hakurei_reimu')
move_and_show_samples(result['kirisame_marisa'], 'kirisame_marisa')
# store the results
torch.save(result, 'checkpoint/directions.pth')
'''
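Finally, a usage sketch: once checkpoint/direction.pth exists, a single sample can be edited without retraining. It is meant to run in the same session as the script above (it reuses G, datasets, move_and_show, np, and plt from that script):

# apply the learned 'pink_hair' direction to one held-out test sample
result = torch.load('checkpoint/direction.pth', map_location=device)
direction = result['pink_hair'].to(device)  # any trained tag works

dlatents, _ = next(iter(datasets['test']))  # one (dlatents, labels) pair
imgs = move_and_show(dlatents.reshape(1, 16, 512).to(device), direction,
                     coeffs=[-5, 0, 5])     # left: less pink, middle: original, right: more pink
plt.imshow(np.hstack(imgs))
plt.axis('off')
plt.show()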