实现:向降噪模型输入单张图片进行测试

import random
import clip
from utils import NoiseImageDataSet, initialize_network, eval_request, NoiseDataSetWithClip
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from multiprocessing import Process
from config import config
import shutil
from torch.optim import lr_scheduler
from utils.loss import PSNRLoss, MSELoss
from experiments.base_enum import NoieType
import torchvision.transforms as transforms
from PIL import Image
from  experiments.model import model_bn_large_attention

import numpy as np
from random import random


def image_to_numpy(img: "Image.Image") -> np.ndarray:
    """Return *img* converted to a new NumPy array.

    For an RGB ``PIL.Image.Image`` the result is an ``(H, W, 3)`` array.
    ``np.array`` copies its input, so the result does not alias the source.
    Any array-like accepted by ``np.array`` works as well.
    """
    return np.array(img)

class RandomResizedCrop:
    """Randomly crop a PIL image and resize the crop, with probability ``p``.

    With probability ``1 - p`` the image is returned unchanged.  Otherwise a
    random crop extent and output size are drawn, and torchvision's
    ``transforms.RandomResizedCrop`` is applied with ``scale`` and ``ratio``
    pinned to single values.
    """

    def __init__(self, p):
        # Probability of applying the crop: it is applied when random() <= p.
        self.p = p

    def __call__(self, img: "Image.Image"):
        if random() > self.p:
            return img

        # PIL reports size as (width, height).
        width, height = img.size
        # Base crop extent: [50%, 100%) of each side; max(1, ...) keeps tiny
        # images from producing a zero side (and a division by zero below).
        base_w = max(1, int(width * (0.5 + 0.5 * random())))
        base_h = max(1, int(height * (0.5 + 0.5 * random())))

        # Output size: a further [80%, 100%) of the base extent, never 0.
        final_w = max(1, int(base_w * (0.8 + 0.2 * random())))
        final_h = max(1, int(base_h * (0.8 + 0.2 * random())))

        ratio = base_w / base_h

        # NOTE(review): torchvision documents `scale` as a fraction of the
        # image *area*, but the value computed here is a ratio of side
        # lengths (base_h / height or base_w / width).  If the intent is a
        # crop of exactly base_w x base_h, the area fraction would be
        # (base_w * base_h) / (width * height) -- confirm before changing.
        if ratio > width / height:
            scale = base_h / height
        else:
            scale = base_w / width

        return transforms.RandomResizedCrop(size=(final_h, final_w), scale=(scale, scale),
                                            ratio=(ratio, ratio))(img)

# --- Single-image inference with the trained denoising UNet -----------------
# Machine-specific paths from the original author; adjust them locally.
CHECKPOINT_PATH = 'C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/checkpoint/l2/145.pth'
img_path = 'C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/samper_test/2.jpg'
OUTPUT_PATH = r'C:\Users\qhq\PycharmProjects\lossy-compression-denoise-master\lossy-compression-denoise-master\samper_test\1_1.jpg'


def _strip_module_prefix(state_dict, prefix='module.'):
    """Rename, in place, every key starting with *prefix* to the key without it.

    Checkpoints saved from a wrapped model (presumably ``nn.DataParallel`` or
    DDP -- confirm against the training code) prefix each key with
    ``'module.'``.  ``str.strip('module')`` must not be used for this: it
    strips any of the characters m/o/d/u/l/e from *both* ends of the key,
    which corrupts names ending in those letters.  Prints each new key.
    """
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            print(new_key)
            state_dict[new_key] = state_dict.pop(key)
    return state_dict


def _drop_keys_containing(state_dict, substring):
    """Delete, in place, every key of *state_dict* that contains *substring*."""
    for key in [k for k in state_dict if substring in k]:
        del state_dict[key]
    return state_dict


def _tensor_to_pil(batch_output):
    """Convert a ``(1, C, H, W)`` model output tensor to a ``PIL`` image.

    ``permute(1, 2, 0)`` is the inverse of the ``permute([2, 0, 1])`` used on
    the input (HWC -> CHW).  The previous ``permute(2, 1, 0)`` produced
    ``(W, H, C)`` and transposed the picture, which goes unnoticed on square
    inputs.  Values are clipped to [0, 255] before the uint8 cast so
    out-of-range values saturate instead of overflowing.  This assumes the
    model outputs the same 0-255 range as its (unnormalized) input -- confirm.
    """
    hwc = batch_output.squeeze(dim=0).permute(1, 2, 0).detach().cpu().numpy()
    return Image.fromarray(np.clip(hwc, 0, 255).round().astype(np.uint8))


model = model_bn_large_attention.UNet_n2n_un()
# map_location='cpu': the model stays on the CPU here, and a checkpoint saved
# from GPU tensors would otherwise fail to load on a CPU-only machine.
all_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
state_dict = all_dict['net']  # the weights are stored under the 'net' key
_strip_module_prefix(state_dict)
# Keys containing 'track' (presumably BatchNorm num_batches_tracked buffers)
# are discarded before loading, as in the original script -- TODO confirm why.
_drop_keys_containing(state_dict, 'track')
model.load_state_dict(state_dict)
# Inference mode: layers such as BatchNorm/Dropout (the model name suggests
# BatchNorm) must not use training behaviour on a single test image.
model.eval()
print("load true!")

# convert('RGB') guarantees a 3-D (H, W, 3) array; grayscale or RGBA files
# would otherwise break the permute below or the channel count.
with Image.open(img_path) as src:
    img = src.convert('RGB')

# Deterministic preprocessing for inference.  The random ColorJitter copied
# from the training pipeline (applied with p=0.7) is intentionally not used:
# it altered the test image differently on every run.
transfor_valid = transforms.Compose([image_to_numpy])
img = transfor_valid(img)
img_tor = torch.from_numpy(img).permute([2, 0, 1])  # (H, W, C) -> (C, H, W)
img_batch = img_tor.unsqueeze(dim=0)  # add the batch dimension -> (1, C, H, W)

with torch.no_grad():  # inference only: no autograd graph is needed
    outputs = model(img_batch.to(torch.float32))
print(outputs.shape)

pil_image = _tensor_to_pil(outputs)
pil_image.show()
pil_image.save(OUTPUT_PATH)

1.定义模型,实例化模型

# Step 1: instantiate the network; no checkpoint has been loaded at this point.
model= model_bn_large_attention.UNet_n2n_un()

2.加载模型参数(自己训练保存的参数名可能与当前模型的参数名不一致,例如带有 'module.' 前缀,或多出含 'track' 的项,需要先处理再加载)

模型参数是以字典格式存储的,因此可以用 for k in all_dict.keys(): print(k) 逐个输出键名进行查看。

加载模型一般写作 model.load_state_dict(torch.load('模型参数保存的位置.pth'))

# Step 2: load the trained weights.  The checkpoint is a dict whose 'net'
# entry holds the model's state_dict.
model = model_bn_large_attention.UNet_n2n_un()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine.
all_dict = torch.load('C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/checkpoint/l2/145.pth', map_location='cpu')
state_dict = all_dict['net']

# Remove the leading 'module.' that a wrapped model (presumably DataParallel)
# adds to every key.  The prefix is sliced off: str.strip('module') would
# strip any of the characters m/o/d/u/l/e from BOTH ends of the key.
prefix = 'module.'
for key in list(state_dict.keys()):
    if key.startswith(prefix):
        new_key = key[len(prefix):]
        print(new_key)
        state_dict[new_key] = state_dict.pop(key)

# Drop every key containing 'track' (presumably BatchNorm num_batches_tracked
# buffers), as the original snippet does -- TODO confirm why.
for key in [k for k in state_dict if 'track' in k]:
    del state_dict[key]

model.load_state_dict(state_dict)

 3.打开需要输入测试的图片:

# Step 3: open the test image.  convert('RGB') guarantees three channels, so
# grayscale or RGBA files do not break the HWC -> CHW permute later on.  The
# context manager closes the file once the converted copy has been made.
with Image.open(img_path) as src:
    img = src.convert('RGB')

对输入的图片进行预处理(transfor_valid),并转换为 numpy 数组(注意:这一步并不会改变图片尺寸)

# Inference preprocessing: PIL image -> numpy (H, W, C) array.
# The random ColorJitter from the training pipeline (applied with p=0.7) is
# deliberately left out: at test time it would alter the input image
# differently on every run.
transfor_valid = transforms.Compose(
    [
        image_to_numpy,
    ]
)
img = transfor_valid(img)

将 numpy 图片数据转成 torch 张量(numpy 数组没有 permute 方法,需先用 torch.from_numpy 转换),再用 permute 把维度顺序从 [高,宽,通道] 调整为 [通道,高,宽]。

由于输入到模型的数据形状为 [batch_size,通道数,高,宽],所以需要用 .unsqueeze(dim=0) 增加 batch 维度。

# img is already the numpy array produced by transfor_valid above.  Applying
# the transform a second time is redundant.  While the transform contains
# ColorJitter, it can also raise, because ColorJitter rejects numpy arrays.
img_tor = torch.from_numpy(img)
img_tor = img_tor.permute([2, 0, 1])  # (H, W, C) -> (C, H, W)
img_batch = img_tor.unsqueeze(dim=0)  # add the batch dimension -> (1, C, H, W)

4.将符合输入格式的 torch 数据输入到 model 中(推理前应调用 model.eval(),并在 torch.no_grad() 下运行)

# Step 4: run inference.  eval() switches layers such as BatchNorm/Dropout to
# inference behaviour; no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(img_batch.to(torch.float32))

5.为了显示图片,需要先去掉 batch 维度(squeeze),再把维度顺序从 [通道,高,宽] 调回 [高,宽,通道](应使用 permute(1,2,0),它是输入时 permute([2,0,1]) 的逆变换),最后转成 PIL 可以显示的图片数据

# Step 5: (1, C, H, W) tensor -> (H, W, C) uint8 array -> PIL image.
outputs = outputs.squeeze(dim=0)
# Inverse of the permute([2, 0, 1]) applied to the input.  permute(2, 1, 0)
# would give (W, H, C) and transpose the picture.
outputs = outputs.permute(1, 2, 0)
outputs = outputs.detach().cpu().numpy()
# Clip before the uint8 cast so out-of-range values saturate instead of
# overflowing.  This assumes the output is in the input's 0-255 range; confirm.
pil_image = Image.fromarray(np.clip(outputs, 0, 255).round().astype(np.uint8))
pil_image.show()

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值