对13章风格迁移任务进行讲解,并对其中的部分操作如pil与tensor呼唤进行具体介绍,方便后续调用!
目录
1.读取原图片:
d2l.set_figsize()
content_img = d2l.Image.open('img/rainier.jpg')
d2l.plt.imshow(content_img);
style_img = d2l.Image.open('img/autumn-oak.jpg')
d2l.plt.imshow(style_img);
2.转换函数
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return transforms(img).unsqueeze(0)
def postprocess(img):
img = img[0].to(rgb_std.device) # 表示移动到rgb_std的device上
img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 此时img为hwc
return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1)) # 此时per后为chw
PIL图像用可以用img.size返回其(weigh,height),与tensor正好相反!!进行逆std与mean运算时应为(h,w,c),tensor中为(bs,c,h,w),进行ToPILImage操作时应为(c,h,w)
应用,注:使用d2l.plt.imshow()传入一个pil图像可以直接显示:
# 加载图像并进行预处理
img_tensor = transforms.ToTensor()(Image.open('img/rainier.jpg'))
# 将张量转换为 PIL 图像对象
img_pil = transforms.ToPILImage()(img_tensor)
# 在 Jupyter Notebook 中显示图像
d2l.plt.imshow(img_pil)
img_pil.size
3.抽取图像特征与net构件
pretrained_net = torchvision.models.vgg19(pretrained=True)
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
net = nn.Sequential(*[pretrained_net.features[i] for i in
range(max(content_layers + style_layers) + 1)])
保留了0-28层,后面的不要了,上一节children嵌套list讲过了。
逐层计算,并保留内容层和风格层的输出fmap:
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).to(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).to(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
4.定义损失
内容损失直接用生成图片内容与真正图片内容进行均方误差。
样式一样怎么理解?通过RGB直方图来理解,每个像素值没必要一样,匹配的是通道之间相关性。这里使用格拉姆矩阵,在gram函数中实现。传入X为(b,c,h,w),首先拉成(c,hw)尺寸,其中每一行可看作x1,x2...xc为长度为hw的向量,表示在该通道上的风格特征,进行矩阵相乘格拉姆矩阵中,i行和j列的元素Xij表示上述i行和j行的内积,表示通道i\j之间的风格相关性。
为了避免风格损失受其中某较大误差值的影响,return的格拉姆矩阵再除矩阵中元素的个数chw(设bs=1)
def content_loss(Y_hat, Y):
# 我们从动态计算梯度的树中分离⽬标:
# 这是⼀个规定的值,⽽不是⼀个变量。
return torch.square(Y_hat - Y.detach()).mean()
def gram(X):
num_channels, n = X.shape[1], X.numel() // X.shape[1]
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
def tv_loss(Y_hat):
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
最后tv_loss是降噪操作,尽可能使临近的像素值相似,避免特别明亮或过暗的像素噪点。
赋予权重并写总loss计算函数:
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# 分别计算内容损失、⻛格损失和全变分损失
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# 对所有损失求和
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
赋予权重,再相加。
sum中传入3个列表则将列表里面的所有元素都相加起来,返回一个值。
5.训练
明确一点,训练的使X,使图片,不是参数。
所以写一个SynthesizedImage的类,将合成图像视为参数,这样就可以利用自动计算梯度更新合成图像,注意这里合成图像一开始用rand随机生成。
class SynthesizedImage(nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
定义inits函数,创建合成图像,在此可以将内容或风格图像当作初始权重,这样就不用随机rand生成了,使用的使data.copy_(X.data)的命令。
然后,再将原始风格图像的gram提前计算出。
def get_inits(X, device, lr, styles_Y):
gen_img = SynthesizedImage(X.shape).to(device)
gen_img.weight.data.copy_(X.data)
trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
训练函数,真正backward的只有l,其他的三个分loss画图print用的:
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
for epoch in range(num_epochs):
trainer.zero_grad()
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
l.backward()
trainer.step()
scheduler.step()
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X))
animator.add(epoch + 1, [float(sum(contents_l)),
float(sum(styles_l)), float(tv_l)])
return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)