import torch
from torch import nn
class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_learnable=False):
        """
        weight = gamma, bias = beta
        beta, gamma:
            Variables of shape [1, 1, 1, C] in TensorFlow.
            Variables of shape [1, C, 1, 1] in PyTorch.
        eps: a scalar constant or a learnable variable.
        """
        super(FRN, self).__init__()
        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_learnable = is_eps_learnable

        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if is_eps_learnable:
            self.eps = nn.Parameter(torch.empty(1))
        else:
            self.register_buffer('eps', torch.tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_learnable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """
        Axes 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow.
        Axes 0, 1, 2, 3 -> (B, C, H, W) in PyTorch.

        Reference TensorFlow code:
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            # This code also includes the TLU function max(y, tau):
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean squared norm of the activations per channel,
        # averaged over the spatial dimensions (H, W).
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # Perform FRN: normalize by the root mean square.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        # Apply the learned per-channel scale (gamma) and bias (beta).
        x = self.weight * x + self.bias
        return x
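The forward pass above stops at the scale-and-bias step, while the TensorFlow snippet in the docstring also applies the TLU activation max(y, tau). Below is a minimal companion sketch of TLU as a separate module; the class name and the zero initialization of tau follow the FRN paper's description and are assumptions, not part of this post's code.

class TLU(nn.Module):
    """Thresholded Linear Unit: y = max(x, tau), with a learnable
    per-channel threshold tau of shape [1, C, 1, 1] for NCHW inputs."""

    def __init__(self, num_features):
        super(TLU, self).__init__()
        # tau initialized to zero (assumption: the original post does
        # not specify an initialization; the FRN paper uses zero).
        self.tau = nn.Parameter(torch.zeros(1, num_features, 1, 1))

    def forward(self, x):
        # Elementwise maximum; tau broadcasts across B, H, and W.
        return torch.max(x, self.tau)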
PyTorch implementation of the FRN layer
This code defines a PyTorch module named FRN that implements a normalization layer with a learnable weight (gamma), a learnable bias (beta), and an optionally learnable ε parameter. The layer first computes the per-channel mean of the squared activations over the spatial dimensions, then performs FRN by normalizing with the root mean square and applying the scale and bias; the paper's full scheme additionally applies the TLU function (the max operation sketched above). A short usage example follows.
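A minimal smoke test combining the two modules; the channel count, batch shape, and variable names here are illustrative assumptions, not from the original post.

# Run a random NCHW feature map through FRN followed by TLU and
# confirm that the shape is preserved.
frn = FRN(num_features=64)
tlu = TLU(num_features=64)
x = torch.randn(8, 64, 32, 32)
y = tlu(frn(x))
print(y.shape)  # torch.Size([8, 64, 32, 32])
print(frn)      # FRN(num_features=64, eps=1e-06)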