import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
class ResBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))`` where the
    shortcut is the identity when the residual branch preserves the tensor
    shape, and a strided 1x1 conv + BatchNorm projection otherwise.
    """

    def __init__(self, in_channel, out_channel, stride=1):
        """
        Args:
            in_channel: number of channels of the block input.
            out_channel: number of channels produced by the block.
            stride: stride of the first conv; 2 halves the spatial size.
        """
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride
        # Projection shortcut. For a 3x3/pad-1 conv the output shape differs
        # from the input iff stride != 1 or the channel count changes, so this
        # construction-time test is equivalent to the old runtime shape
        # comparison — but it avoids unconditionally registering projection
        # parameters that an identity-shaped block never uses (wasted memory
        # and dead entries in the optimizer's parameter list).
        if stride != 1 or in_channel != out_channel:
            self.shrink = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )  # convert a x with input size to output size
        else:
            self.shrink = nn.Identity()

    def forward(self, x):
        """Return a tensor of shape (N, out_channel, H/stride, W/stride)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shrink(x)  # residual connection
        return F.relu(out)
class ResNet(nn.Module):
    """ResNet for 3-channel images with a CIFAR-style 3x3 stem.

    Stage widths are 64/128/256/512; stages 2-4 each downsample by 2. A
    global average pool feeds a single linear classification head.
    """

    def __init__(self, num_blocks, num_classes):
        """
        Args:
            num_blocks: sequence of 4 ints — residual blocks per stage
                (e.g. [2, 2, 2, 2] for a ResNet-18 layout).
            num_classes: size of the output logit vector.
        """
        super(ResNet, self).__init__()
        self.initial_output_channel = 64
        self.conv1 = nn.Conv2d(3, self.initial_output_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(self.initial_output_channel)
        self.res_layer1 = self._make_layer(
            input_channel=self.initial_output_channel,
            output_channel=self.initial_output_channel,
            num_block=num_blocks[0],
        )
        self.res_layer2 = self._make_layer(
            input_channel=self.initial_output_channel,
            output_channel=128,
            num_block=num_blocks[1],
            if_downside=True,
        )
        self.res_layer3 = self._make_layer(
            input_channel=128,
            output_channel=256,
            num_block=num_blocks[2],
            if_downside=True,
        )
        self.res_layer4 = self._make_layer(
            input_channel=256,
            output_channel=512,
            num_block=num_blocks[3],
            if_downside=True,
        )
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, input_channel, output_channel, num_block,
                    if_downside=False):
        """Stack `num_block` ResBlocks; only the first block downsamples
        (stride 2) when `if_downside` is set."""
        strides = [1] * num_block
        strides[0] = 2 if if_downside else 1
        layers = []
        for stride in strides:
            layers.append(ResBlock(input_channel, output_channel, stride))
            input_channel = output_channel  # later blocks keep the stage width
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to (N, num_classes) logits."""
        # BUG FIX: the stem previously had no activation after BatchNorm;
        # the standard ResNet stem is conv -> BN -> ReLU -> pool.
        out = F.relu(self.bn(self.conv1(x)))
        out = F.max_pool2d(out, stride=2, kernel_size=2)
        out = self.res_layer1(out)
        out = self.res_layer2(out)
        out = self.res_layer3(out)
        out = self.res_layer4(out)
        # Global average pool. Identical to the old F.avg_pool2d(out, 2) for
        # 32x32 inputs (where the final map is 2x2), but also correct for any
        # input size — the old fixed pool left e.g. a 7x7 map for 224x224
        # inputs, which crashed the 512-input linear layer.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
class ResNet18:
    """Callable wrapper around an 18-layer ResNet (blocks [2, 2, 2, 2]).

    BUG FIX: the original built a brand-new, randomly initialized ResNet on
    every call, so repeated calls ran unrelated weights (and wasted the full
    construction cost each time). The model is now created lazily once and
    reused; ``ResNet18()`` with no arguments still works as before.
    """

    def __init__(self, num_classes=5):
        # Default matches the originally hard-coded class count.
        self.num_classes = num_classes
        self._model = None

    def __call__(self, input_):
        """Run a forward pass; returns (N, num_classes) logits."""
        if self._model is None:
            self._model = ResNet(num_blocks=[2, 2, 2, 2],
                                 num_classes=self.num_classes)
        return self._model(input_)
def show_one_model(model, input_, output):
    """Forward-hook callback: plot every channel of the first sample of a
    layer's output in an 8-column grid, then block until the user presses a
    key, then show the figure.

    Signature matches torch's forward-hook convention (module, input, output).
    Assumes the channel count is a multiple of 8 — TODO confirm for layers
    this is hooked onto.
    """
    n_cols = 8
    channels = output[0].shape[0]
    fig, axes = plt.subplots(channels // n_cols, n_cols, figsize=(10, 10))
    for channel in range(channels):
        pos = np.unravel_index(channel, axes.shape)
        plt.sca(axes[pos])
        axes[pos].title.set_text('channel-{}'.format(channel))
        plt.imshow(output[0][channel].detach())
    prompt = 'this is conv: {}, received a {} tensor, press any key to show next: '
    input(prompt.format(model, input_[0].shape))
    plt.show()
if __name__ == '__main__':
    # Demo: visualize the feature maps produced by every Conv2d layer of a
    # pre-trained torchvision ResNet-18 on a single image, via forward hooks.
    keji = Image.open('img.png')
    channel_num = 64
    # NOTE(review): res_block is constructed but never used below —
    # presumably leftover from experimentation; confirm before removing.
    res_block = ResBlock(3, channel_num)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224))])
    # Downloads the ImageNet weights on first run.
    pre_trained_model = resnet18(pretrained=True)
    keji = transform(keji).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    # Collect every Conv2d submodule of the pre-trained network.
    conv_models = [m for _, m in pre_trained_model.named_modules()
                   if isinstance(m, nn.Conv2d)]
    for conv in conv_models:
        # show_one_model is invoked after each hooked conv's forward pass.
        conv.register_forward_hook(show_one_model)
    with torch.no_grad():
        output = pre_trained_model(keji)
…