对图片进行1*1的卷积,分通道卷积再合并,原图到原图

import torch.nn as nn
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import torch
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
from tools.common_tools import transform_invert, set_seed
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
# assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
set_seed(1)

path_img = os.path.join(os.path.dirname(os.path.abspath(__file__)),'lena.png')
img = Image.open(path_img).convert('RGB')
img_transform = transforms.Compose([transforms.ToTensor()])
img_tensor = img_transform(img)
img_tensor.unsqueeze_(dim=0)
r,g,b=img.split()
img_r = transforms.ToTensor()(r)
img_g = transforms.ToTensor()(g)
img_b = transforms.ToTensor()(b)
print(img_r)
img_r.unsqueeze_(dim=0)
img_g.unsqueeze_(dim=0)
img_b.unsqueeze_(dim=0)

flag = True
# flag = False
if flag:
    conv_layer = nn.Conv2d(1, 1, (1,1),stride=1)  # input:(i, o, size)
    conv_layer.weight.data=torch.Tensor([[[[1]]]])
    conv_layer.bias.data=torch.Tensor([0])
    print(conv_layer.weight)
    print(conv_layer.bias)
    # calculation
    img_convr = conv_layer(img_r)
    img_convg = conv_layer(img_g)
    img_convb = conv_layer(img_b)
    img_raw = transform_invert(img_tensor.squeeze(), img_transform)
    # print(img_convr)
# img_conv = transform_invert(img_conv.squeeze(), img_transform)
# print(img_convr.shape)

    img_convr = img_convr.detach().numpy() * 255
    img_convg = img_convg.detach().numpy() * 255
    img_convb = img_convb.detach().numpy() * 255
    print(img_convr)
    img_convr = Image.fromarray(img_convr.astype('uint8').squeeze())
    img_convg = Image.fromarray(img_convg.astype('uint8').squeeze())
    img_convb = Image.fromarray(img_convb.astype('uint8').squeeze())

# img_convr = transform_invert(img_convr.squeeze(), img_transform)
# img_convg = transform_invert(img_convg.squeeze(), img_transform)
# img_convb = transform_invert(img_convb.squeeze(), img_transform)
# print(img_raw.numpy())
# print(img_conv)
# img_conv=img_conv.convert('BGR')
    img_cond = Image.merge('RGB',(img_convr,img_convg,img_convb))
    plt.subplot(122).imshow(img_cond)
    plt.subplot(121).imshow(img_raw)
    plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值