Hierarchical structure correlation inference for pose estimation 实现CCIM代码(word文档HRNet总结里面)

https://blog.csdn.net/qq_41456721/article/details/102943499(看到gram矩阵)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging

import torch
import torch.nn as nn

# import utils

def conv3x3(in_plane,out_plane,stride=1):

    return nn.Conv2d(in_plane,out_plane,kernel_size=3,stride=stride,
                     padding=1,bias=False)
def gram_matrix(y):
	"""
	使用Gram矩阵来表示图像的风格特征
	输入B,C,H,W
	输出B,C,C
	"""
	(b,ch,h,w) = y.size() # 比如1,8,2,2
	features = y.view(b,ch,w*h)  # 得到1,8,4
	features_t = features.transpose(1,2) # 调换第二维和第三维的顺序,即矩阵的转置,得到1,4,8


    # gram = features.bmm(features_t) / (ch * h * w)  # bmm()用来做矩阵乘法,及未转置的矩阵乘以转置后的矩阵,得到的就是1,8,8了
    # 由于要对batch中的每一个样本都计算Gram Matrix,因此使用bmm()来计算矩阵乘法,而不是mm()

	gram = features_t.bmm(features) # bmm()用来做矩阵乘法,及未转置的矩阵乘以转置后的矩阵,得到的就是1,8,8了
    # 由于要对batch中的每一个样本都计算Gram Matrix,因此使用bmm()来计算矩阵乘法,而不是mm()
	return gram


# b = torch.arange(0,32)
# b=b.view(1,4,4,2)
# a = gram_matrix(b)
# print(a)
#
class HW(nn.Module):
    def __init__(self,inplane,outplane,stride):
        super(HW,self).__init__()
        self.conv = conv3x3(inplane,outplane,stride)
        self.max = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up =nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self,x):
        out = self.conv(x)
        out = self.max(out)
        # print('out2',out.shape)

        # out=out.permate(0,2,3,1)
        out1 = gram_matrix(out)
        print('out',out1.shape)

        # out = t.bmm(out.T,out)
        scale = torch.sigmoid(out1)
        print('scale', scale.shape)
        b,d,h,w = out.size()
        out2=out.view(b,d, h*w)
        print('out3', out2.shape)
        out1 = torch.bmm(out2,scale)
        out1 =out1.view(out.size(0),out.size(1),out.size(2),out.size(3))
        print('out4', out1.shape)
        # out1=out1.unsqueeze(0)
        out1 = out1.expand(out.size(0), -1, -1, -1)
        print('out5', out1.shape)
        out1 =self.up(out1)
        out = out1.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3])

        out =out+x

        return out


# class CW(nn.Module):
#     def __init__(self, inplane, outplane, stride):
#         super(CW, self).__init__()
#         self.conv = conv3x3(inplane, outplane, stride)
#         self.max = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.up = nn.Upsample(scale_factor=2, mode='nearest')
#
#     def forward(self, x):
#         out = self.conv(x)
#         out = self.max(out)
#
#         # out=out.permate(0,2,3,1)
#         out = out.view((out.shape[0], out.shape[1]*out.shape[2], out.shape[3] ))
#         print('out =',out.shape)
#         print('out.t=',out.transpose(2,1).shape)
#         out = t.bmm(out.transpose(2,1), out)
#         scale = t.sigmoid(out)
#         out = out * scale
#
#         out = self.up(out)
#         out = out.view(x.shape[0],x.shape[1],x.shape[2],x.shape[3])
#         out = out + x
#
#         return out
#
# class CH(nn.Module):
#     def __init__(self, inplane, outplane, stride):
#         super(CH, self).__init__()
#         self.conv = conv3x3(inplane, outplane, stride)
#         self.max = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.up = nn.Upsample(scale_factor=2, mode='nearest')
#
#     def forward(self, x):
#         out = self.conv(x)
#         out = self.max(out)
#
#         # out=out.permate(0,2,3,1)
#         out = out.view((out.shape[0], out.shape[1]*out.shape[3], out.shape[2]))
#         out = t.bmm(out.t, out)
#         scale = t.sigmoid(out)
#         out = out * scale
#
#         out = self.up(out)
#         out = out.view(x.shape[0],x.shape[1],x.shape[2],x.shape[3])
#         out = out + x
#
#         return out


class Triple(nn.Module):
    def __init__(self, inplane, outplane, stride):
        super(Triple,self).__init__()
        self.hw = HW(inplane, outplane, stride)
        # self.ch = CH(inplane, outplane, stride)
        # self.cw = CW(inplane, outplane, stride)

    def forward(self,x):

        # x_out1 = self.cw(x)
        # x_out3 = self.ch(x)
        x_out2 = self.hw(x)


        out = x_out2+x

        return out
#
#
if __name__ == '__main__':
    import torch
    from time import time
    from tensorboardX import SummaryWriter

    pose = Triple(4, 4, 1)  # .cuda()

    input = torch.randn(5, 4, 32, 32)  # .cuda()
    # print(pose)
    output = pose(input)  ## type: torch.Tensor
    with SummaryWriter(comment='pose') as w:
        w.add_graph(pose, (input,))
        # w.add_graph(pose,input)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值