导语:使用 PyTorch 搭好网络之后,如果没有显式地初始化网络参数,是否可以直接训练网络呢?下面通过卷积层、BatchNorm 层和 Linear 层的源码来回答这个问题。
class _ConvNd(Module):
    """Excerpt of the PyTorch base class for N-dimensional convolutions.

    Only the part relevant to default parameter initialisation is shown:
    the constructor ends by calling ``reset_parameters()``, so every
    instance is initialised even when the user never does so explicitly.

    NOTE: most of the constructor body is elided in this excerpt, so the
    class is illustrative and cannot be instantiated as shown. The omitted
    code presumably creates ``self.weight`` and ``self.bias`` (not visible
    here).
    """

    __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size']

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode):
        super(_ConvNd, self).__init__()
        ...  # dozens of lines omitted here
        self.reset_parameters()  # key line: default init runs at construction

    def reset_parameters(self):
        """Re-initialise ``weight`` (and ``bias``, if present) in place.

        The weight is filled by ``init.kaiming_uniform_`` with
        ``a=math.sqrt(5)``. When ``self.bias`` is not ``None`` it is drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``, where
        ``fan_in`` is computed from the weight tensor.
        """
        # With kaiming_uniform_'s defaults (per the torch.nn.init docs),
        # a = sqrt(5) gives the same bound as the bias below: 1/sqrt(fan_in).
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
Remark :
- class Convkd(_ConvNd), where k ∈ {1, 2, 3}.
- class ConvTransposekd(_ConvTransposeMixin, _ConvNd), where k ∈ {1, 2, 3}.
BatchNorm
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm.

    Excerpt showing only the default initialisation: the constructor ends
    by calling ``reset_parameters()``, which resets the running statistics
    and, for affine layers, sets ``weight`` to ones and ``bias`` to zeros.

    NOTE: most of the constructor body is elided in this excerpt, so the
    class is illustrative and cannot be instantiated as shown. The omitted
    code presumably sets the attributes read below (not visible here).
    """

    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight',
                     'bias', 'running_mean', 'running_var',
                     'num_batches_tracked', 'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        ...  # dozens of lines omitted here
        self.reset_parameters()  # key line: default init runs at construction

    def reset_running_stats(self):
        """Reset the running statistics in place.

        Does nothing unless ``self.track_running_stats`` is truthy; otherwise
        zeroes ``running_mean``, fills ``running_var`` with 1 and zeroes
        ``num_batches_tracked``.
        """
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics, then the affine parameters if enabled.

        When ``self.affine`` is truthy, ``weight`` is set to ones and
        ``bias`` to zeros.
        """
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)
class _BatchNorm(_NormBase):
    """Batch-norm base class that defers all construction to ``_NormBase``.

    The constructor only forwards its arguments, unchanged and in order, to
    ``_NormBase.__init__``; no initialisation logic is added here, so the
    default parameter values are whatever ``_NormBase`` sets up.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
Remark : class BatchNormkd(_BatchNorm), where k ∈ {1, 2, 3}.
Linear
class Linear(Module):
    """Excerpt of the PyTorch fully connected layer.

    Only the part relevant to default parameter initialisation is shown:
    the constructor ends by calling ``reset_parameters()``, so every
    instance is initialised even when the user never does so explicitly.

    NOTE: most of the constructor body is elided in this excerpt, so the
    class is illustrative and cannot be instantiated as shown. The omitted
    code presumably creates ``self.weight`` and ``self.bias`` (not visible
    here).
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        ...  # dozens of lines omitted here
        self.reset_parameters()  # key line: default init runs at construction

    def reset_parameters(self):
        """Re-initialise ``weight`` (and ``bias``, if present) in place.

        The weight is filled by ``init.kaiming_uniform_`` with
        ``a=math.sqrt(5)``. When ``self.bias`` is not ``None`` it is drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``, where
        ``fan_in`` is computed from the weight tensor.
        """
        # With kaiming_uniform_'s defaults (per the torch.nn.init docs),
        # a = sqrt(5) gives the same bound as the bias below: 1/sqrt(fan_in).
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
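结语:使用 PyTorch 搭建网络后,如果不显式地指定网络参数的初始化方式,各层会在构造函数中调用 reset_parameters(),采用默认的初始化方式,因此可以直接训练网络。具体来说:卷积层和 Linear 层的权重使用 kaiming_uniform_(a=√5) 初始化,偏置服从均匀分布 U(-1/√fan_in, 1/√fan_in);BatchNorm 层的 weight 初始化为 1,bias 初始化为 0,running_mean 置 0,running_var 置 1。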