本次记录是在上一次的基础上进行的,上一次实验:【实验记录】在本地运行基于stable diffusion的图像压缩案例Stable Diffusion based lossy image compression-CSDN博客
在上一次的文末已经提到过一些优化方案,此次优化目标:使得案例可以适配所有分辨率的图像
理论知识
在 VAE 编码器中,经过编码后图像分辨率的计算主要取决于卷积层的配置,尤其是步幅 (stride) 和 padding(填充),这些参数决定了每层卷积后图像分辨率的变化。
通常情况下,VAE 编码器通过若干层的卷积和下采样将输入图像压缩为潜在空间表示。假设编码器每层卷积的步幅和填充都是一致的(如 stride=2),你可以通过递归计算每层卷积后的输出分辨率来确定最终编码后的分辨率。
计算方法:
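如果编码器通过 N 层卷积来下采样,且每一层卷积的步幅为 2,那么图像在每经过一层卷积后,宽度和高度都会缩小一半。公式如下:
$$H_{out} = \frac{H_{in}}{2^{N}}, \qquad W_{out} = \frac{W_{in}}{2^{N}}$$
其中:
- $H_{in}$ 和 $W_{in}$ 分别是输入图像的高度和宽度,
- $N$ 是编码器的卷积层数(每层都进行下采样),
- $H_{out}$ 和 $W_{out}$ 是编码后图像的高度和宽度。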
本次模型使用的是 Stable Diffusion v1-4 VAE 的标准配置,其编码器共进行 3 次步幅为 2 的下采样(即 N = 3),因此经过编码器处理后,潜在表示的宽和高都会缩小到原始图像的 1/8(2^3 = 8)。这也是下文要先把图像的宽和高调整为 8 的倍数的原因。
实践:直接修改源代码
再次提醒:以下代码是在前文的基础上修改的,这里只展示有改动的 cell。
def resize_to_8like(input_file, output_file):
    """Resize an image so that its width and height are both multiples of 8.

    Each dimension is rounded down to the nearest multiple of 8. A dimension
    smaller than 8 is raised to 8, because rounding it down would give 0 and
    the resize would fail.

    Args:
        input_file: Source image, passed to ``Image.open``.
        output_file: Destination, passed to ``Image.save``.
    """
    # Close the source file before saving so that input_file and output_file
    # may refer to the same path.
    with Image.open(input_file) as src:
        img = src.convert('RGB')

    tmp_w = max(8, (img.width // 8) * 8)
    tmp_h = max(8, (img.height // 8) * 8)
    print(f"target width: {tmp_w}")
    print(f"target height: {tmp_h}")

    img = img.resize((tmp_w, tmp_h), Image.LANCZOS)
    # NOTE(review): `lossless` is presumably only honoured by formats that
    # support it (e.g. WebP) - confirm for the output format in use.
    img.save(output_file, lossless=True, quality=100)
    print(f"img was set to size: {img.size}")
def quantize(latents):
    """Convert the first latent of a batch into an 8-bit channel-last array.

    Values are divided by ``255 * 0.18215``, shifted by 0.5 and clamped to
    [0, 1]; the result is then scaled to 0..255 with round-half-up.

    Args:
        latents: 4-D tensor; only index 0 of the first axis is used, and
            axis 1 becomes the last axis of the output.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``.
    """
    scale = 255 * 0.18215
    unit_range = torch.clamp(latents / scale + 0.5, 0, 1)
    # Take the first batch entry and move the channel axis to the end.
    channel_last = unit_range[0].detach().cpu().permute(1, 2, 0).numpy()
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    return (channel_last * 255.0 + 0.5).astype(np.uint8)
def unquantize(quantized):
    """Map an 8-bit channel-last array back to a latent tensor.

    Reverses the scaling used by ``quantize``: values are divided by 255,
    shifted by -0.5 and multiplied by ``255 * 0.18215``.

    Args:
        quantized: 3-D ``numpy.ndarray`` whose last axis is the channel axis.

    Returns:
        ``torch.Tensor`` with a leading batch axis of size 1 and the channel
        axis moved to position 1, placed on ``torch_device``.
    """
    scale = 255 * 0.18215
    unit_range = quantized.astype(np.float32) / 255.0
    # Add the batch axis, then go from channel-last to channel-first.
    channel_first = np.transpose(np.expand_dims(unit_range, 0), (0, 3, 1, 2))
    restored = torch.from_numpy((channel_first - 0.5) * scale)
    return restored.to(torch_device)
@torch.no_grad()
def denoise(latents):
    """Run the last few scheduler steps of the U-Net over the given latents.

    The latents are multiplied by 0.18215 on entry and divided by it on exit.
    Only the tail of the scheduler's timestep list is executed, its length
    derived from ``strength``. The U-Net is conditioned on the module-level
    ``uncond_embeddings``.

    Args:
        latents: Latent tensor to clean up.

    Returns:
        The processed latents, rescaled back by ``1 / 0.18215``.
    """
    # NOTE(review): 0.18215 is presumably the Stable Diffusion latent
    # scaling factor - confirm against the VAE config.
    latents = latents * 0.18215

    step_size = 15
    num_inference_steps = scheduler.config.get("num_train_timesteps", 1000) // step_size
    strength = 0.04
    scheduler.set_timesteps(num_inference_steps)

    offset = scheduler.config.get("steps_offset", 0)
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)

    # Only pass `eta` to schedulers whose step() accepts it.
    extra_step_kwargs = {}
    if "eta" in inspect.signature(scheduler.step).parameters:
        extra_step_kwargs["eta"] = 0.9

    latents = latents.to(unet.dtype).to(torch_device)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    with autocast():
        for t in scheduler.timesteps[t_start:]:
            noise_pred = unet(latents, t, encoder_hidden_states=uncond_embeddings).sample
            latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    # reset scheduler to free cached noise predictions
    scheduler.set_timesteps(1)
    return latents / 0.18215
def _jpeg_at_least(img, min_bytes, max_quality=100):
    """Return (buffer, quality) for the lowest JPEG quality reaching min_bytes."""
    for quality in range(max_quality + 1):
        raw = io.BytesIO()
        img.save(raw, format="JPEG", quality=quality, optimize=True, subsampling=1)
        encoded = io.BytesIO(mozjpeg_lossless_optimization.optimize(raw.getvalue()))
        if encoded.getbuffer().nbytes >= min_bytes:
            break
    return encoded, quality


def _webp_at_least(img, min_bytes, max_quality=100):
    """Return (buffer, quality) for the lowest WebP quality reaching min_bytes."""
    for quality in range(max_quality + 1):
        encoded = io.BytesIO()
        img.save(encoded, format="WEBP", quality=quality, method=6)
        if encoded.getbuffer().nbytes >= min_bytes:
            break
    return encoded, quality


def compress_input(input_file, output_path):
    """Compress an image through the VAE latent space and show each stage.

    Stages, each displayed and scored against the input with
    ``print_metrics``:
      1. VAE round trip of the unmodified latents.
      2. Latents quantized to 8 bit (also saved as lossless WebP).
      3. Quantized latents reduced to a 256-colour palette with dithering
         (saved as ``.npz`` and as a zlib-compressed ``.bin``).
      4. Palettized latents after ``denoise``.
    For comparison, the input is then encoded as JPEG and WebP at the lowest
    quality whose file size is at least the size of the ``.bin`` payload.

    Args:
        input_file: Image to compress, passed to ``Image.open``.
        output_path: Path prefix; the suffixes
            ``_sd_quantized_latents.webp``, ``.npz``, ``.bin``, ``.jpg`` and
            ``.webp`` are appended for the files written.
    """
    gt_img = Image.open(input_file)
    target_size = gt_img.size  # original (width, height)
    display(gt_img)
    print(f'Ground Truth, size:{target_size}')
    print("max PSNR and SSIM:")
    print_metrics(gt_img, gt_img)

    # Display VAE roundtrip image
    latents = to_latents(gt_img)
    img_from_latents = to_img(latents, target_size=target_size)
    display(img_from_latents)
    print(f'VAE roundtrip, size:{img_from_latents.size}')
    print_metrics(gt_img, img_from_latents)

    # Quantize latent representation and save as lossless webp image
    quantized = quantize(latents)
    del latents
    quantized_img = Image.fromarray(quantized)
    quantized_img.save(output_path + "_sd_quantized_latents.webp", lossless=True, quality=100)

    # Display VAE decoded image from 8-bit quantized latents
    unquantized_latents = unquantize(quantized)
    unquantized_img = to_img(unquantized_latents, target_size=target_size)
    display(unquantized_img)
    del unquantized_latents
    print('VAE decoded from 8-bit quantized latents')
    print_metrics(gt_img, unquantized_img)

    # further quantize to palette. Use libimagequant for Dithering
    attr = liq.Attr()
    attr.speed = 1
    attr.max_colors = 256
    input_image = attr.create_rgba(quantized.flatten('C').tobytes(),
                                   quantized_img.width,
                                   quantized_img.height,
                                   0)
    quantization_result = input_image.quantize(attr)
    quantization_result.dithering_level = 1.0

    # Get the quantization result
    out_pixels = quantization_result.remap_image(input_image)
    out_palette = quantization_result.get_palette()
    np_indices = np.frombuffer(out_pixels, np.uint8)
    np_palette = np.array([c for color in out_palette for c in color], dtype=np.uint8)

    w, h = quantized_img.size
    print(f"quantized img size: width:{w}, height:{h}")
    sd_palettized_bytes = io.BytesIO()
    np.savez_compressed(sd_palettized_bytes, w=w, h=h, i=np_indices.flatten(), p=np_palette)
    with open(output_path + ".npz", "wb") as f:
        f.write(sd_palettized_bytes.getbuffer())

    # Compress the dithered 8-bit latents using zlib and save them to disk
    compressed_bytes = zlib.compress(
        np.concatenate((np_palette, np_indices), dtype=np.uint8).tobytes(),
        level=9
    )
    with open(output_path + ".bin", "wb") as f:
        f.write(compressed_bytes)
    sd_bytes = len(compressed_bytes)

    # Display VAE decoding of dithered 8-bit latents
    np_indices = np_indices.reshape((h, w))
    palettized_latent_img = Image.fromarray(np_indices, mode='P')
    palettized_latent_img.putpalette(np_palette, rawmode='RGBA')
    latents = np.array(palettized_latent_img.convert('RGBA'))
    latents = unquantize(latents)
    # Pass target_size here too, as the two earlier to_img calls do, so every
    # decoded image is compared with gt_img at the same size.
    palettized_img = to_img(latents, target_size=target_size)
    display(palettized_img)
    print('VAE decoding of palettized and dithered 8-bit latents')
    print(f"decoded img size: {palettized_img.size}")
    print(f"origin img size: {target_size}")
    print_metrics(gt_img, palettized_img)

    # Use Stable Diffusion U-Net to de-noise the dithered latents
    latents = denoise(latents)
    denoised_img = to_img(latents, target_size=target_size)
    display(denoised_img)
    del latents
    print('VAE decoding of de-noised dithered 8-bit latents')
    print(f"decoded img size: {denoised_img.size}")
    print(f"origin img size: {target_size}")
    print('size: {}b = {}kB'.format(sd_bytes, sd_bytes/1024.0))
    print_metrics(gt_img, denoised_img)
    # To keep the result on disk, call e.g. denoised_img.save('denoised_image.png').

    # Everything below is optional and only serves as a comparison: it shows
    # what JPEG and WebP achieve at a similar file size.

    # Find JPG compression settings that result in closest data size that is
    # larger than SD compressed data. The reported quality is the one that
    # was actually used for the saved file.
    jpg_bytes, jpg_quality = _jpeg_at_least(gt_img, sd_bytes)
    with open(output_path + ".jpg", "wb") as f:
        f.write(jpg_bytes.getbuffer())
    jpg = Image.open(jpg_bytes)
    try:
        display(jpg)
        print('JPG compressed with quality setting: {}'.format(jpg_quality))
        print('size: {}b = {}kB'.format(jpg_bytes.getbuffer().nbytes, jpg_bytes.getbuffer().nbytes / 1024.0))
        print_metrics(gt_img, jpg)
    except Exception as exc:
        # Best effort: the comparison is optional, but report the cause.
        print('something went wrong compressing {}.jpg'.format(output_path))
        print(repr(exc))

    webp_bytes, webp_quality = _webp_at_least(gt_img, sd_bytes)
    with open(output_path + ".webp", "wb") as f:
        f.write(webp_bytes.getbuffer())
    try:
        webp = Image.open(webp_bytes)
        display(webp)
        print('WebP compressed with quality setting: {}'.format(webp_quality))
        print('size: {}b = {}kB'.format(webp_bytes.getbuffer().nbytes, webp_bytes.getbuffer().nbytes / 1024.0))
        print_metrics(gt_img, webp)
    except Exception as exc:
        # Best effort: the comparison is optional, but report the cause.
        print('something went wrong compressing {}.webp'.format(output_path))
        print(repr(exc))
以上是优化的第一部分,剩余部分有待继续学习实践。
敬请期待。
PS:从结果上来看,对源代码的修改部分很少,结果论的话感觉我并没有投入多少精力,以及这个工作量看起来很少。但是实际上花了我挺多的时间去翻阅文档、查看修改源码和测试,最终才得到了本文。