ResNet-18 特征图可视化分析

在这里插入图片描述


import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18


class ResBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.  If the main
    path's output shape differs from the input's shape, the input is sent
    through a 1x1 conv + BN projection (``shrink``) before the residual
    addition; the sum is then passed through ReLU.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        # Sub-modules are created in a fixed order: this order decides both the
        # state_dict key order and the order of random weight initialisation.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

        self.in_channel, self.out_channel, self.stride = in_channel, out_channel, stride

        # 1x1 projection that maps the input onto the main path's shape.
        # NOTE(review): it is built unconditionally, so blocks whose input and
        # output shapes already match carry parameters that are never used.
        projection = [
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        ]
        self.shrink = nn.Sequential(*projection)

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        # Only project the shortcut when the shapes cannot be added directly.
        shortcut = self.shrink(x) if residual.shape != x.shape else x
        return F.relu(residual + shortcut)


class ResNet(nn.Module):
    """ResNet-style classifier built from ``ResBlock`` stages.

    Layout: conv3x3(3 -> 64, stride 1) -> BN -> 2x2 max-pool -> four residual
    stages (64, 128, 256, 512 channels; stages 2-4 downsample with stride 2)
    -> global average pool -> linear classifier.

    Args:
        num_blocks: sequence of four ints, the number of ``ResBlock``s in each
            stage (``[2, 2, 2, 2]`` gives the ResNet-18 layout).
        num_classes: number of output logits.
    """

    def __init__(self, num_blocks, num_classes):
        super(ResNet, self).__init__()
        self.initial_output_channel = 64
        # Stem: a single 3x3 convolution with stride 1.
        self.conv1 = nn.Conv2d(3, self.initial_output_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(self.initial_output_channel)

        # Stage 1 keeps 64 channels and the spatial size.
        self.res_layer1 = self._make_layer(
            input_channel=self.initial_output_channel,
            output_channel=self.initial_output_channel,
            num_block=num_blocks[0],
        )

        # Stages 2-4 double the channel count and halve the spatial size.
        self.res_layer2 = self._make_layer(
            input_channel=self.initial_output_channel,
            output_channel=128,
            num_block=num_blocks[1],
            if_downside=True
        )

        self.res_layer3 = self._make_layer(
            input_channel=128,
            output_channel=256,
            num_block=num_blocks[2],
            if_downside=True
        )

        self.res_layer4 = self._make_layer(
            input_channel=256,
            output_channel=512,
            num_block=num_blocks[3],
            if_downside=True
        )

        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, input_channel, output_channel, num_block, if_downside=False):
        """Build one residual stage of ``num_block`` ``ResBlock``s.

        The first block uses stride ``int(if_downside) + 1`` (2 when
        downsampling) and maps ``input_channel`` to ``output_channel``; every
        later block uses stride 1 and keeps ``output_channel`` channels.

        NOTE(review): ``num_block`` must be >= 1; with 0 the ``strides[0]``
        assignment raises IndexError.
        """
        strides = [1] * num_block
        strides[0] = int(if_downside) + 1  # only the first block may downsample

        layers = []
        for stride in strides:
            layers.append(ResBlock(input_channel, output_channel, stride))
            input_channel = output_channel  # later blocks see the new width

        return nn.Sequential(*layers)

    def forward(self, x):
        # NOTE(review): there is no ReLU between the stem BN and the max-pool,
        # unlike the standard ResNet stem (conv -> BN -> ReLU -> pool). Kept
        # as-is so existing weights behave identically; confirm it is intended.
        out = self.bn(self.conv1(x))
        out = F.max_pool2d(out, stride=2, kernel_size=2)

        out = self.res_layer1(out)
        out = self.res_layer2(out)
        out = self.res_layer3(out)
        out = self.res_layer4(out)

        # Global average pooling collapses whatever spatial size remains to
        # 1x1, so the flattened vector always has 512 features. For a 32x32
        # input the final map is 2x2 and this equals ``avg_pool2d(out, 2)``;
        # unlike that fixed 2x2 window it neither drops the border of a 3x3
        # map nor leaves >1x1 maps (larger inputs) that would mismatch
        # ``self.linear``.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)

        out = self.linear(out)

        return out


class ResNet18:
    """Callable wrapper around a ``ResNet`` with the ResNet-18 block layout.

    The underlying network is built lazily on the first call and then reused,
    so every call runs the same weights. Previously each call constructed a
    brand-new, randomly initialised network, which made outputs of successive
    calls unrelated and the model impossible to train or save.

    Args:
        num_classes: number of output logits (default 5, as before).
    """

    def __init__(self, num_classes=5):
        self.num_classes = num_classes
        self.model = None  # created on first call, then reused

    def __call__(self, input_):
        if self.model is None:
            self.model = ResNet(num_blocks=[2, 2, 2, 2], num_classes=self.num_classes)
        return self.model(input_)


def show_one_model(model, input_, output):
    """Forward hook that plots every channel of a layer's output.

    Meant for ``Module.register_forward_hook``, which calls it with
    ``(module, inputs, output)``. Only the first sample of the batch
    (``output[0]``) is drawn, one subplot per channel in a grid ``width``
    columns wide. The console waits for Enter before the figure is shown.

    Args:
        model: the module whose forward pass triggered the hook.
        input_: tuple of positional inputs passed to ``model``.
        output: the module's output; assumed (batch, channels, H, W) -- TODO
            confirm if hooked onto non-conv modules.
    """
    width = 8
    # detach + cpu so matplotlib can read the data even for GPU tensors.
    feature_maps = output[0].detach().cpu()
    num_channels = feature_maps.shape[0]

    # Round up so channel counts below or not divisible by ``width`` still get
    # enough rows; squeeze=False keeps ``axes`` 2-D even for a single row.
    rows = (num_channels + width - 1) // width
    fig, axes = plt.subplots(rows, width, figsize=(10, 10), squeeze=False)

    for i, axis in enumerate(axes.flat):
        if i >= num_channels:
            axis.axis('off')  # unused cell in the last row
            continue
        axis.title.set_text('channel-{}'.format(i))
        axis.imshow(feature_maps[i].numpy())

    input('this is conv: {}, received a {} tensor, press Enter to show next: '.format(model, input_[0].shape))

    plt.show()
    # Release the figure; with non-interactive backends show() does not close
    # it and figures would otherwise accumulate across hook calls.
    plt.close(fig)


if __name__ == '__main__':
    # Plot the output of every Conv2d layer of torchvision's ImageNet-pretrained
    # ResNet-18 for a single input image.

    # convert('RGB') drops an alpha channel (common in PNGs) or expands a
    # grayscale image: the first conv expects exactly 3 input channels.
    # The ``with`` block closes the underlying file.
    with Image.open('img.png') as img:
        image = img.convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        # torchvision's ImageNet-pretrained weights expect inputs normalised
        # with these per-channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13 in
    # favour of ``weights=ResNet18_Weights.DEFAULT``; kept for older installs.
    pre_trained_model = resnet18(pretrained=True)
    # Inference mode: BatchNorm uses its stored running statistics instead of
    # the statistics of this one-image batch (and does not update them).
    pre_trained_model.eval()

    batch = transform(image).unsqueeze(0)  # add the batch dimension

    conv_models = [m for m in pre_trained_model.modules()
                   if isinstance(m, nn.Conv2d)]

    for conv in conv_models:
        conv.register_forward_hook(show_one_model)

    with torch.no_grad():
        output = pre_trained_model(batch)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值