def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):#判断m对象是否属于这些类
nn.init.xavier_uniform(m.weight.data)
nn.init.constant(m.bias, 0.1)
class Net:
def __init():
self.apply(weights_init)#权重初始化
def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):#判断m对象是否属于这些类
nn.init.xavier_uniform(m.weight.data)
nn.init.constant(m.bias, 0.1)
class Net:
def __init():
self.apply(weights_init)#权重初始化