[Experiment Log] Optimizing the Stable Diffusion based image compression case

This log builds on the previous one. Previous experiment: [Experiment Log] Running the Stable Diffusion based lossy image compression case locally (CSDN blog).

Some optimization ideas were already mentioned at the end of the previous post. The goal this time: make the case work with images of any resolution.

Theory

In a VAE encoder, the resolution of the encoded image depends mainly on the configuration of the convolutional layers, in particular the stride and padding; these parameters determine how the resolution changes after each convolutional layer.

Typically, a VAE encoder compresses the input image into a latent-space representation through several layers of convolution and downsampling. If every convolutional layer uses the same stride and padding (e.g. stride=2), you can determine the final encoded resolution by recursively computing the output resolution of each layer.
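
For reference (this is the standard convolution output-size formula, not specific to this model), a single layer with kernel size K, stride S and padding P maps an input height H_{in} to:

H_{out} = \left\lfloor \frac{H_{in} + 2P - K}{S} \right\rfloor + 1

and the width is computed the same way.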

Calculation:

If the encoder downsamples through N convolutional layers, each with stride 2, then the width and height are halved after every layer. The formulas are as follows:

H_{out} = \frac{H_{in}}{2^{N}}

W_{out} = \frac{W_{in}}{2^{N}}

where:
- H_{in} and W_{in} are the height and width of the input image,
- N is the number of (downsampling) convolutional layers in the encoder,
- H_{out} and W_{out} are the height and width of the encoded image.

This experiment uses the standard Stable Diffusion v1-4 VAE configuration, whose encoder reduces the input resolution to 1/8 of the original (three stride-2 downsampling stages, so N = 3).
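
As a quick sanity check (a minimal sketch; the helper name is mine, not from the case code), the latent resolution for an arbitrary input follows directly from the formula:

def latent_resolution(h_in, w_in, n=3):
  # each of the n stride-2 stages halves the height and width;
  # integer division mirrors the floor in the conv output-size formula
  return h_in // (2 ** n), w_in // (2 ** n)

print(latent_resolution(512, 512))  # (64, 64)
print(latent_resolution(600, 400))  # (75, 50)

This is also why the input must first be brought to a multiple of 8 in each dimension, which the code below takes care of.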

Practice: modifying the source code directly

A reminder: these changes build on the previous post, and only the notebook cells that were modified are shown.

def resize_to_8like(input_file, output_file):
  # make both the width and the height of the image multiples of 8
  img = Image.open(input_file).convert('RGB')
  # round each dimension down to the nearest multiple of 8
  tmp_w = (img.width // 8) * 8
  tmp_h = (img.height // 8) * 8
  print(f"target width: {tmp_w}")
  print(f"target height: {tmp_h}")
  img = img.resize((tmp_w, tmp_h), Image.LANCZOS)
  # lossless/quality apply to formats such as WebP; other formats ignore them
  img.save(output_file, lossless=True, quality=100)
  print(f"img was set to size: {img.size}")
def quantize(latents):
  # map the latents into [0, 1], then round to 8-bit integers
  quantized_latents = (latents / (255 * 0.18215) + 0.5).clamp(0, 1)
  quantized = quantized_latents.cpu().permute(0, 2, 3, 1).detach().numpy()[0]
  quantized = (quantized * 255.0 + 0.5).astype(np.uint8)
  return quantized

def unquantize(quantized):
  # inverse of quantize(): 8-bit integers back to float latents
  unquantized = quantized.astype(np.float32) / 255.0
  unquantized = unquantized[None].transpose(0, 3, 1, 2)
  unquantized_latents = (unquantized - 0.5) * (255 * 0.18215)
  unquantized_latents = torch.from_numpy(unquantized_latents)
  return unquantized_latents.to(torch_device)
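
quantize() and unquantize() are inverses up to clamping and rounding. A quick sanity check (a sketch, assuming a latents tensor produced by to_latents() from the previous post's cells):

latents = to_latents(gt_img)
roundtrip = unquantize(quantize(latents))
print((roundtrip - latents).abs().max())  # should be small: quantization error only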

@torch.no_grad()
def denoise(latents):
  # scale into the latent space the U-Net operates in
  latents = latents * 0.18215
  step_size = 15
  num_inference_steps = scheduler.config.get("num_train_timesteps", 1000) // step_size
  strength = 0.04
  scheduler.set_timesteps(num_inference_steps)
  offset = scheduler.config.get("steps_offset", 0)
  init_timestep = int(num_inference_steps * strength) + offset
  init_timestep = min(init_timestep, num_inference_steps)
  # carried over from the img2img pipeline; not used in the loop below
  timesteps = scheduler.timesteps[-init_timestep]
  timesteps = torch.tensor([timesteps], dtype=torch.long, device=torch_device)
  extra_step_kwargs = {}
  if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
    extra_step_kwargs["eta"] = 0.9
  latents = latents.to(unet.dtype).to(torch_device)
  t_start = max(num_inference_steps - init_timestep + offset, 0)
  with autocast():
    # run only the last init_timestep (low-noise) scheduler steps
    for i, t in enumerate(scheduler.timesteps[t_start:]):
      noise_pred = unet(latents, t, encoder_hidden_states=uncond_embeddings).sample
      latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
  # reset scheduler to free cached noise predictions
  scheduler.set_timesteps(1)
  # scale back out of the latent space
  return latents / 0.18215
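
For intuition, with the defaults above (1000 training timesteps, step_size = 15, strength = 0.04) the loop only executes a couple of steps, i.e. a very light de-noising pass:

num_inference_steps = 1000 // 15  # 66
init_timestep = int(66 * 0.04)    # 2 (plus any steps_offset)
# so only the last ~2 scheduler timesteps are run over the latents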

def compress_input(input_file, output_path):
  gt_img = Image.open(input_file)
  target_size = gt_img.size # remember the original size

  display(gt_img)
  print(f'Ground Truth, size:{target_size}')
  print("max PSNR and SSIM:")
  print_metrics(gt_img, gt_img)

  # Display VAE roundtrip image
  latents = to_latents(gt_img)
  img_from_latents = to_img(latents,target_size=target_size)
  display(img_from_latents)
  print(f'VAE roundtrip, size:{img_from_latents.size}')
  print_metrics(gt_img, img_from_latents)

  # Quantize latent representation and save as lossless webp image
  quantized = quantize(latents)
  del latents
  quantized_img = Image.fromarray(quantized)
  quantized_img.save(output_path + "_sd_quantized_latents.webp", lossless=True, quality=100)

  # Display VAE decoded image from 8-bit quantized latents
  unquantized_latents = unquantize(quantized)
  unquantized_img = to_img(unquantized_latents,target_size=target_size)
  display(unquantized_img)
  del unquantized_latents
  print('VAE decoded from 8-bit quantized latents')
  print_metrics(gt_img, unquantized_img)

  # Further quantize to a 256-color palette; use libimagequant for dithering
  attr = liq.Attr()
  attr.speed = 1
  attr.max_colors = 256
  input_image = attr.create_rgba(quantized.flatten('C').tobytes(),
                                 quantized_img.width,
                                 quantized_img.height,
                                 0)
  quantization_result = input_image.quantize(attr)
  quantization_result.dithering_level = 1.0
  # Get the quantization result
  out_pixels = quantization_result.remap_image(input_image)
  out_palette = quantization_result.get_palette()
  np_indices = np.frombuffer(out_pixels, np.uint8)
  np_palette = np.array([c for color in out_palette for c in color], dtype=np.uint8)
    
  w, h = quantized_img.size
  print(f"quantized img size: width:{w}, height:{h}")

  sd_palettized_bytes = io.BytesIO()
  np.savez_compressed(sd_palettized_bytes, w=w, h=h, i=np_indices.flatten(), p=np_palette)
  with open(output_path + ".npz", "wb") as f:
    f.write(sd_palettized_bytes.getbuffer())

  # Compress the dithered 8-bit latents using zlib and save them to disk
  compressed_bytes = zlib.compress(
      np.concatenate((np_palette, np_indices), dtype=np.uint8).tobytes(),
      level=9
      )
  with open(output_path + ".bin", "wb") as f:
    f.write(compressed_bytes)
  sd_bytes = len(compressed_bytes)

  # Display VAE decoding of dithered 8-bit latents
  np_indices = np_indices.reshape((h,w))
  palettized_latent_img = Image.fromarray(np_indices, mode='P')
  palettized_latent_img.putpalette(np_palette, rawmode='RGBA')
  latents = np.array(palettized_latent_img.convert('RGBA'))
  latents = unquantize(latents)
  palettized_img = to_img(latents, target_size=target_size)
  display(palettized_img)
  print('VAE decoding of palettized and dithered 8-bit latents')
  print(f"decoded img size: {palettized_img.size}")
  print(f"origin img size: {target_size}")
  print_metrics(gt_img, palettized_img)

  # Use Stable Diffusion U-Net to de-noise the dithered latents
  latents = denoise(latents)
  denoised_img = to_img(latents, target_size=target_size)
  display(denoised_img)
  del latents
  print('VAE decoding of de-noised dithered 8-bit latents')
  print(f"decoded img size: {denoised_img.size}")
  print(f"origin img size: {target_size}")
  print('size: {}b = {}kB'.format(sd_bytes, sd_bytes/1024.0))

  print_metrics(gt_img, denoised_img)

#   Optionally save the de-noised image, e.g.:
#   denoised_img.save('denoised_image.png')   # save as PNG
#   denoised_img.save('denoised_image.jpg')   # save as JPEG
#   denoised_img.save('denoised_image.webp')  # save as WebP
#   (note: PIL cannot write SVG, so saving as .svg would fail)


  # Everything below can be commented out; it is only for comparison,
  # producing JPEG- and WebP-compressed versions at a similar file size.
    
  # Find the JPEG quality setting that gives the smallest file still larger than the SD-compressed data
  jpg_bytes = io.BytesIO()
  q = 0
  while jpg_bytes.getbuffer().nbytes < sd_bytes:
    jpg_bytes = io.BytesIO()
    gt_img.save(jpg_bytes, format="JPEG", quality=q, optimize=True, subsampling=1)
    jpg_bytes.flush()
    jpg_bytes.seek(0)
    jpg_bytes = io.BytesIO(mozjpeg_lossless_optimization.optimize(jpg_bytes.read()))
    jpg_bytes.flush()
    q += 1

  with open(output_path + ".jpg", "wb") as f:
    f.write(jpg_bytes.getbuffer())
  jpg = Image.open(jpg_bytes)
  try:
    display(jpg)
    print('JPG compressed with quality setting: {}'.format(q))
    print('size: {}b = {}kB'.format(jpg_bytes.getbuffer().nbytes, jpg_bytes.getbuffer().nbytes / 1024.0))
    print_metrics(gt_img, jpg)
  except Exception:
    print('something went wrong compressing {}.jpg'.format(output_path))

  webp_bytes = io.BytesIO()
  q = 0
  while webp_bytes.getbuffer().nbytes < sd_bytes:
    webp_bytes = io.BytesIO()
    gt_img.save(webp_bytes, format="WEBP", quality=q, method=6)
    webp_bytes.flush()
    q += 1

  with open(output_path + ".webp", "wb") as f:
    f.write(webp_bytes.getbuffer())
  try:
    webp = Image.open(webp_bytes)
    display(webp)
    print('WebP compressed with quality setting: {}'.format(q))
    print('size: {}b = {}kB'.format(webp_bytes.getbuffer().nbytes, webp_bytes.getbuffer().nbytes / 1024.0))
    print_metrics(gt_img, webp)
  except Exception:
    print('something went wrong compressing {}.webp'.format(output_path))
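
Putting the pieces together, a usage sketch (the file names are placeholders of my own; to_latents, to_img and print_metrics come from the previous post's cells):

resize_to_8like("photo.png", "photo_8like.webp")  # make dimensions multiples of 8
compress_input("photo_8like.webp", "out/photo")   # writes out/photo.bin, .npz, .jpg, .webp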

This concludes the first part of the optimization; the remaining parts await further study and practice.

Stay tuned.

PS: Judging from the result, very little of the source code was changed, so it may look as if I put in little effort and the workload was small. In reality it took quite a lot of time reading documentation, studying and modifying the source code, and testing before this post came together.
