cv2读入的BGR图像转换成torch.tensor格式,torch.tensor格式的图像转换成ndarray格式并保存

代码如下:

包含两个函数

1.ndarray2tensor :

        img(W,H,C):cv2读取的一张图像,ndarray格式,注意读取的为BGR图像,函数中没有转换成RGB图像,如需自行可以更换;

        输出(1,C,W,H):torch.Tensor格式

2.torch2ndarray_save:

        input(1,C,W,H):torch.Tensor格式

        filename:str格式,保存的路径

         denormal:bool格式,是否反标准化

注意:标准化和反标准化已经实例出来,可参考使用

import cv2
import torch
import numpy as np
import os
import torchvision

c_mean = [0.4914, 0.4822, 0.4465]
c_std = [0.2023, 0.1994, 0.2010]
de_mean = [-mean / std for mean, std in zip(c_mean, c_std)]
de_std = [1 / std for std in c_std]
normalize = torchvision.transforms.Normalize(c_mean, c_std)
denormalize = torchvision.transforms.Normalize(de_mean, de_std)

def ndarray2tensor(img):
    img = cv2.resize(img, (256, 256))
    img = img / 255.
    img = torch.tensor(img, dtype=torch.float32)
    img = img.unsqueeze(0)
    img = img.permute(0, 3, 1, 2)
    img = normalize(img)
    return img

def torch2ndarray_save(input:torch.Tensor, filename, denormal = True):
    assert (len(input.shape) == 4 and input.shape[0] == 1)
    input = input.clone().detach()
    if input.is_cuda == True:
        input = input.to(torch.device('cpu'))
    if denormal:
        input = denormalize(input) * 255
    input = torch.tensor(input, dtype=torch.int)
    img = input.permute(0, 2, 3, 1)
    img = torch.reshape(img, (img.shape[1], img.shape[2], img.shape[3])).numpy()
    cv2.imwrite(filename,img)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值