网络参数初始化(整体)
from torch.nn import init
def init_net(net, init_type='normal', gain=0.02):
    """Apply ``init_weights`` to ``net`` and return the same object.

    Thin convenience wrapper so callers can write
    ``net = init_net(Model())``. The accepted ``init_type`` values and the
    exact use of ``gain`` are defined by ``init_weights``, not here.

    Args:
        net: The network whose parameters should be initialized.
        init_type: Name of the initialization scheme, forwarded unchanged
            to ``init_weights``.
        gain: Scale factor forwarded to ``init_weights``. The default of
            0.02 matches ``init_weights``'s own default, so existing callers
            that omit it see identical behavior.

    Returns:
        The same ``net`` object that was passed in.
    """
    # Forward gain explicitly; previously it was unreachable through this
    # wrapper and always fell back to init_weights' default.
    init_weights(net, init_type, gain=gain)
    return net
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
# this will apply to each layer
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname