SpatialSoftmax implementation

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import matplotlib.pyplot as plt


class SpatialSoftmax(torch.nn.Module):
    """Spatial soft-argmax: reduce each feature-map channel to an expected (x, y).

    For every (sample, channel) image of size ``height x width`` a softmax is
    taken over all spatial positions, and the resulting attention map is used
    to average a fixed coordinate grid spanning ``[-1, 1]`` along each axis.
    ``x`` follows the column (width) axis and ``y`` follows the row (height)
    axis.

    Args:
        height: Spatial height ``H`` of the input feature maps.
        width: Spatial width ``W`` of the input feature maps.
        channel: Number of channels ``C`` in the input feature maps.
        temperature: Softmax temperature (logits are divided by it). When
            ``None``, a learnable scalar ``Parameter`` initialised to 1 is used.
        data_format: ``'NHWC'`` for channels-last input; any other value is
            treated as ``'NCHW'``.
        debug: When true, ``forward`` prints its intermediate tensors.
    """

    def __init__(self, height, width, channel, temperature=None, data_format='NCHW', debug=False):
        super(SpatialSoftmax, self).__init__()
        self.height = height
        self.width = width
        self.channel = channel
        self.data_format = data_format
        self.debug = debug

        if temperature is None:
            # Learnable temperature, initialised to 1.
            self.temperature = Parameter(torch.ones(1))
        else:
            self.temperature = temperature

        # Build (H, W) coordinate grids so that, once flattened row-major,
        # entry r * W + c corresponds to pixel (r, c) of a feature map
        # flattened to H * W: pos_x depends only on the column c and pos_y
        # only on the row r. With meshgrid's default 'xy' indexing the
        # x-axis (width) samples must be passed first to get (H, W) arrays.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.width),
            np.linspace(-1., 1., self.height),
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        # Buffers move with the module on .to(device) and are not trained.
        # NOTE: they are persistent, so load_state_dict() replaces them with
        # whatever grid was saved in the checkpoint.
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        """Return the softmax-weighted expected (x, y) position of every channel.

        Args:
            feature: Tensor of shape ``(N, C, H, W)``, or ``(N, H, W, C)`` when
                ``data_format == 'NHWC'``, with ``C, H, W`` matching the values
                given to the constructor.

        Returns:
            Tensor of shape ``(N, C * 2)`` laid out ``x_0, y_0, x_1, y_1, ...``.
        """
        if self.debug:
            print("input:\n{}".format(feature))
        if self.data_format == 'NHWC':
            # Channels-last -> channels-first. permute() yields a
            # non-contiguous view, which is why reshape() (copies only when
            # needed) is used below instead of view().
            feature = feature.permute(0, 3, 1, 2)
        # Flatten to N*C images of H*W each.
        feature = feature.reshape(-1, self.height * self.width)
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        # (N*C, 2) rows of (x, y) -> (N, C*2) laid out x_0, y_0, x_1, y_1, ...
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel * 2)
        if self.debug:
            print("softmax_attention:\n{}".format(softmax_attention))
            print("self.pos_x:\n{}\nself.pos_y:\n{}".format(self.pos_x, self.pos_y))
            print("expected_x:\n{}\nexpected_y:\n{}".format(expected_x, expected_y))
            print("expected_xy:\n{}".format(expected_xy))
            print("feature_keypoints:\n{}".format(feature_keypoints))
        return feature_keypoints
    
    
if __name__ == '__main__':
    # Minimal debug example (prints every intermediate tensor):
    #   data = torch.zeros([3, 3, 3, 3])
    #   data[0, 0, 0, 1] = 10
    #   data[0, 1, 1, 1] = 10
    #   data[0, 2, 1, 2] = 10
    #   layer = SpatialSoftmax(3, 3, 3, temperature=3, debug=True)
    #   layer(data)

    # Synthetic "conv features": 6 samples x 3 channels of 28x28. Sample i
    # has a 1x4 bar of ones on row 10, spanning columns 4*i .. 4*i+3.
    num_samples, num_channels, size = 6, 3, 28
    feature_from_conv = torch.zeros(num_samples, num_channels, size, size)
    for i in range(num_samples):
        feature_from_conv[i, :, 10, 4 * i:4 * i + 4] = 1

    # Show channel 0 of each sample in a 2x3 grid.
    for i in range(num_samples):
        plt.subplot(2, 3, i + 1)
        plt.imshow(feature_from_conv[i][0], interpolation='none')
        plt.title("Feature of: {}".format(i))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists, so the final
    # subplot's title is included in the layout computation.
    plt.tight_layout()
    plt.show()

    layer = SpatialSoftmax(size, size, num_channels, debug=False)
    # (N, C*2) keypoints, shown as an image with one row per sample.
    feature_points = layer(feature_from_conv).detach().numpy()
    plt.imshow(feature_points)
    plt.title("Feature Points")
    plt.xticks([])
    plt.yticks([])
    plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
YOLO系列是基于深度学习的端到端实时目标检测方法。 PyTorch版的YOLOv5轻量而性能高,更加灵活和易用,当前非常流行。 本课程将手把手地教大家使用labelImg标注和使用YOLOv5训练自己的数据集。课程实战分为两个项目:单目标检测(足球目标检测)和多目标检测(足球和梅西同时检测)。 本课程的YOLOv5使用ultralytics/yolov5,在Windows系统上做项目演示。包括:安装YOLOv5、标注自己的数据集、准备自己的数据集、修改配置文件、使用wandb训练可视化工具、训练自己的数据集、测试训练出的网络模型和性能统计。 希望学习Ubuntu上演示的同学,请前往 《YOLOv5(PyTorch)实战:训练自己的数据集(Ubuntu)》课程链接:https://edu.csdn.net/course/detail/30793  本人推出了有关YOLOv5目标检测的系列课程。请持续关注该系列的其它视频课程,包括:《YOLOv5(PyTorch)目标检测实战:训练自己的数据集》Ubuntu系统 https://edu.csdn.net/course/detail/30793Windows系统 https://edu.csdn.net/course/detail/30923《YOLOv5(PyTorch)目标检测:原理与源码解析》课程链接:https://edu.csdn.net/course/detail/31428《YOLOv5目标检测实战:Flask Web部署》课程链接:https://edu.csdn.net/course/detail/31087《YOLOv5(PyTorch)目标检测实战:TensorRT加速部署》课程链接:https://edu.csdn.net/course/detail/32303《YOLOv5目标检测实战:Jetson Nano部署》课程链接:https://edu.csdn.net/course/detail/32451《YOLOv5+DeepSORT多目标跟踪与计数精讲》课程链接:https://edu.csdn.net/course/detail/32669《YOLOv5实战口罩佩戴检测》课程链接:https://edu.csdn.net/course/detail/32744《YOLOv5实战中国交通标志识别》课程链接:https://edu.csdn.net/course/detail/35209《YOLOv5实战垃圾分类目标检测》课程链接:https://edu.csdn.net/course/detail/35284       
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值