import torch.nn as nn
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import torch
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
# ``tools/common_tools.py`` is expected two directories above this script. When it is
# there, put that directory on sys.path so the import below works regardless of the
# directory the script is launched from. Appending (not prepending) lets an already
# importable ``tools`` package keep precedence.
_tools_root = os.path.dirname(os.path.dirname(path_tools))
if os.path.exists(path_tools) and _tools_root not in sys.path:
    sys.path.append(_tools_root)
from tools.common_tools import transform_invert, set_seed
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate runtime" abort;
# it may need to be set before torch/numpy initialise OpenMP to take effect — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# set_seed comes from the project's tools module; presumably seeds the RNGs for
# reproducibility (e.g. the Conv2d default init) — verify in common_tools.
set_seed(1)
# Load the demo image, expected to sit next to this script, as RGB.
path_img = os.path.join(BASE_DIR, 'lena.png')
with Image.open(path_img) as src:
    # convert() returns a new, fully loaded image, so it stays usable after the file closes.
    img = src.convert('RGB')

# Whole image as a batched float tensor (ToTensor gives C x H x W in [0, 1]; unsqueeze_
# adds the leading batch dimension in place).
img_transform = transforms.Compose([transforms.ToTensor()])
img_tensor = img_transform(img)
img_tensor.unsqueeze_(dim=0)

# Each colour channel separately as a batched single-channel tensor.
r, g, b = img.split()
to_tensor = transforms.ToTensor()
img_r, img_g, img_b = (to_tensor(channel) for channel in (r, g, b))
print(img_r)
for channel_tensor in (img_r, img_g, img_b):
    channel_tensor.unsqueeze_(dim=0)  # in place: adds the batch dimension
flag = True
# flag = False
if flag:
    # Demo: a 1x1 convolution with weight 1 and bias 0 is an identity map. Each colour
    # channel is passed through it separately and the three outputs are merged back
    # into an RGB image, which should match the original.
    conv_layer = nn.Conv2d(1, 1, (1, 1), stride=1)  # (in_channels, out_channels, kernel_size)
    # nn.init.* fills the existing parameters in place (under no_grad) instead of
    # swapping in new tensors through the discouraged ``.data`` attribute.
    nn.init.ones_(conv_layer.weight)
    nn.init.zeros_(conv_layer.bias)
    print(conv_layer.weight)
    print(conv_layer.bias)

    # calculation (forward only, so no autograd graph is needed)
    with torch.no_grad():
        img_convr = conv_layer(img_r)
        img_convg = conv_layer(img_g)
        img_convb = conv_layer(img_b)
    img_raw = transform_invert(img_tensor.squeeze(), img_transform)

    # Scale back to the 0-255 pixel range.
    img_convr = img_convr.detach().numpy() * 255
    img_convg = img_convg.detach().numpy() * 255
    img_convb = img_convb.detach().numpy() * 255
    print(img_convr)

    # Round instead of truncating: (v / 255) * 255 in float32 can land a hair below v,
    # and a bare astype('uint8') would then yield v - 1. Clip so kernels whose output
    # leaves [0, 255] saturate rather than wrap around in the uint8 cast.
    img_convr, img_convg, img_convb = (
        Image.fromarray(arr.round().clip(0, 255).astype('uint8').squeeze())
        for arr in (img_convr, img_convg, img_convb)
    )

    # Left: original image; right: the three conv outputs merged back into RGB.
    img_cond = Image.merge('RGB', (img_convr, img_convg, img_convb))
    plt.subplot(122).imshow(img_cond)
    plt.subplot(121).imshow(img_raw)
    plt.show()