TSM: the normal_ method

While reproducing TSM I ran into the following lines. In the past I had only skimmed over code like this, so to get a more thorough grasp of deep learning I decided to work out exactly what it does.


if self.new_fc is None:
    # look up the backbone's last layer by its stored name, then initialize
    # its weights from N(0, std^2) and zero its bias, both in place
    normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
    constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)

These lines involve two methods, normal_ and constant_, which are explained one by one below.
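
For context, here is a minimal, hypothetical sketch of the surrounding pattern: a backbone whose classifier is looked up through a stored attribute name. The attribute names base_model, last_layer_name, and new_fc mirror the TSM source, but the tiny backbone itself is invented for illustration.


import torch.nn as nn
from torch.nn.init import normal_, constant_

# A minimal stand-in for the real backbone (e.g. a ResNet) used by TSM.
class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        self.fc = nn.Linear(16, 10)  # the classifier, reachable by the name 'fc'

    def forward(self, x):
        return self.fc(self.features(x))

base_model = TinyBackbone()
base_model.last_layer_name = 'fc'  # TSM stores the classifier's attribute name like this
std, new_fc = 0.001, None

# the two lines from the TSM source, run against the stand-in backbone
if new_fc is None:
    normal_(getattr(base_model, base_model.last_layer_name).weight, 0, std)
    constant_(getattr(base_model, base_model.last_layer_name).bias, 0)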

Extracting a layer from the network

First, a brief explanation of how a layer is extracted by name.

In an nn.Sequential built from positional arguments, each layer's key is determined by its position in the sequence, not by the layer's type: the submodules are simply numbered in order. So the keys of the layers below are:

  • nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1): 0

  • nn.ReLU(inplace=True): 1

  • nn.MaxPool2d(kernel_size=2, stride=2): 2

  • nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1): 3

  • nn.ReLU(inplace=True): 4

  • nn.MaxPool2d(kernel_size=2, stride=2): 5

  • nn.AdaptiveAvgPool2d((1, 1)): 6

  • nn.Flatten(): 7

  • nn.Linear(128, 64): 8

  • nn.ReLU(inplace=True): 9

  • nn.Linear(64, 10): 10

Note that these keys are just string identifiers that uniquely name each submodule (and therefore its weights and bias) inside the model. The names themselves carry no meaning beyond being unique: a plain nn.Sequential assigns the indices automatically, and you can choose your own names instead (see the OrderedDict sketch after the example below).

For example:


import torch.nn as nn

base_model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),   # 0
    nn.ReLU(inplace=True),                                   # 1
    nn.MaxPool2d(kernel_size=2, stride=2),                   # 2
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 3
    nn.ReLU(inplace=True),                                   # 4
    nn.MaxPool2d(kernel_size=2, stride=2),                   # 5
    nn.AdaptiveAvgPool2d((1, 1)),                            # 6
    nn.Flatten(),                                            # 7
    nn.Linear(128, 64),                                      # 8
    nn.ReLU(inplace=True),                                   # 9
    nn.Linear(64, 10)                                        # 10
)

last_layer = getattr(base_model, '8')  # equivalent to base_model[8]
print(last_layer)

Result:


Linear(in_features=128, out_features=64, bias=True)
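
Incidentally, indexing does the same lookup: base_model[8] returns the same module. And if you would rather have meaningful names than automatic indices, nn.Sequential also accepts an OrderedDict. A minimal sketch (the names here are our own choice, not anything from TSM):


from collections import OrderedDict
import torch.nn as nn

named_model = nn.Sequential(OrderedDict([
    ('flatten', nn.Flatten()),
    ('fc', nn.Linear(128, 64)),  # now reachable by name instead of index
    ('relu', nn.ReLU(inplace=True)),
]))

print(getattr(named_model, 'fc'))  # Linear(in_features=128, out_features=64, bias=True)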

normal_

Concretely, the normal_() call in that line initializes the weights of the self.base_model layer named by self.base_model.last_layer_name from a Gaussian distribution: the first argument is the weight tensor, the second argument 0 is the mean, and the third argument std is the standard deviation. The trailing underscore in normal_ means the weight tensor is modified in place rather than a new tensor being returned.

For example, given a model base_model whose last layer is a fully connected nn.Linear, we can initialize that layer's weights from a Gaussian distribution as follows:


import torch.nn as nn

base_model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 10)
)

last_layer = getattr(base_model, '10')  # the last layer (getattr returns the module itself, not its name)
print(last_layer)

std = 0.01
# initialize the last layer's weights in place; normal_ also returns the tensor, so we can print it
print(nn.init.normal_(last_layer.weight, 0, std))


Result:


Linear(in_features=64, out_features=10, bias=True)
Parameter containing:
tensor([[-1.3272e-02, -2.2359e-03,  1.0343e-02, -1.0792e-02, -1.2242e-02,
         -1.7024e-02, -5.8491e-03,  2.3467e-03,  4.8827e-03, -8.6968e-03,
          1.2297e-03, -1.2535e-03,  1.2796e-02, -2.5463e-03, -3.9475e-03,
         -7.0661e-03, -1.1183e-03, -2.8131e-03,  1.6252e-02, -1.2049e-03,
          8.7546e-03, -6.3224e-03, -4.5075e-03, -3.3601e-03, -5.9193e-03,
          5.8264e-03,  4.2756e-03,  6.2804e-03,  1.7132e-02,  9.2935e-03,
         -1.7752e-02, -1.3216e-03, -3.6371e-04,  5.7609e-04, -6.0026e-03,
         -6.1346e-03, -2.7561e-03, -1.4461e-02, -1.1557e-02,  3.0463e-03,
         -1.0108e-02, -4.7012e-03,  9.2200e-04, -2.0413e-04, -6.1705e-03,
         -3.6149e-03, -1.0387e-02, -2.1668e-03,  2.4830e-02,  1.4840e-02,
          6.0030e-03,  4.5739e-03, -8.8294e-03, -9.5166e-05, -1.7639e-03,
          2.0190e-02, -3.2003e-03, -1.2733e-02, -7.3871e-04, -1.3978e-03,
         -3.9033e-03,  6.4459e-03, -1.1231e-02,  8.8730e-03],
        [-2.3404e-02,  5.0958e-03, -2.8636e-03,  7.4566e-03, -1.3217e-03,
          9.6338e-03,  4.0845e-03, -9.4511e-03, -1.5848e-02,  9.6535e-03,
         -1.9223e-02,  5.6821e-03, -9.5851e-03,  6.4113e-03,  1.2916e-02,
          4.4283e-03, -1.4849e-02,  1.7074e-04, -6.9544e-03, -2.0033e-02,
          1.1659e-03,  8.3090e-03,  1.0376e-02, -1.3552e-03,  5.1678e-04,
         -1.1280e-02, -1.5385e-03,  2.3204e-03,  2.1148e-02, -8.2046e-03,
         -3.1950e-03, -2.6627e-03, -8.0619e-03,  6.8565e-03, -1.5831e-03,
         -5.3802e-03,  4.9502e-03, -1.0993e-03,  1.5999e-03,  1.6042e-02,
          1.7201e-02,  1.0180e-02,  1.6592e-03, -5.1286e-03, -9.9063e-03,
         -1.0357e-02, -1.0337e-02, -6.0520e-03,  2.3609e-02, -5.8884e-03,
          9.9698e-03,  1.3698e-02, -1.2772e-02,  8.2779e-03, -7.6448e-03,
          1.1789e-02,  1.1248e-02, -7.1244e-03,  4.6928e-03, -1.1212e-03,
          2.8188e-03, -1.2279e-02,  6.7612e-03, -3.8642e-03],
        [ 3.1708e-03,  4.8198e-03, -1.3153e-03,  5.7256e-03,  2.9513e-03,
          6.0120e-03, -1.2716e-02, -2.9583e-02, -1.3539e-02,  9.6046e-03,
         -2.3146e-03, -5.9442e-03, -3.5330e-03,  6.3374e-03, -2.2096e-03,
         -3.5567e-03, -1.0496e-02, -9.1474e-03,  1.6573e-02, -2.7625e-03,
          7.2689e-03,  5.5843e-03,  1.9446e-03,  3.9445e-03, -1.0196e-02,
          6.3983e-03, -1.1957e-03,  1.9038e-03, -3.2439e-03, -9.9891e-03,
          1.7751e-03, -1.2842e-02, -1.0921e-02,  7.6490e-03, -6.7258e-03,
          2.7367e-03, -5.8537e-03, -4.4515e-03, -1.8622e-03,  1.8290e-03,
         -2.1976e-02, -1.0761e-02,  8.5432e-03, -7.4048e-04,  9.6255e-03,
          4.9710e-03,  2.6487e-03,  1.1278e-02, -2.2165e-02, -5.3400e-03,
         -1.2628e-02, -9.0693e-03,  1.1717e-04, -3.9173e-03,  7.0556e-03,
         -3.2840e-03, -2.0703e-02,  1.2574e-02, -1.9498e-03,  6.0815e-03,
          2.0596e-03, -3.7182e-03, -1.2596e-02, -1.2627e-02],
        [-9.8718e-03,  5.4179e-03,  1.6379e-03,  5.9276e-03,  3.1523e-04,
         -8.7425e-03, -7.6249e-03, -4.3939e-03, -6.2318e-03,  1.4632e-02,
          2.2710e-03,  9.9909e-03, -1.1965e-02,  1.3041e-02, -9.0492e-03,
         -1.1099e-03,  1.7039e-03,  1.1821e-02,  3.0194e-03, -6.0026e-03,
          1.6889e-02, -5.0959e-03, -1.9247e-03,  1.8685e-03, -1.7513e-02,
          1.6138e-02, -7.7140e-03,  1.0976e-02,  1.6791e-02, -1.3144e-02,
          1.0865e-03,  7.4303e-03, -3.3156e-03,  2.6016e-03, -2.5576e-03,
          6.6795e-03,  3.3842e-03,  8.1809e-03, -1.4860e-02, -9.7519e-03,
         -1.7827e-03,  3.6443e-03, -1.1853e-02, -1.8691e-02,  1.5285e-02,
         -1.9333e-02,  1.2217e-02, -1.9054e-02,  9.4615e-03, -2.2510e-03,
          4.9319e-03,  3.0562e-03,  1.0821e-02,  1.0460e-02,  2.2750e-03,
          1.5753e-02,  3.2901e-03, -2.3684e-02, -2.2839e-03, -2.0044e-03,
          1.7136e-02,  1.1856e-02, -6.5554e-04, -5.5353e-03],
        [-1.7280e-02, -6.1747e-04, -2.2409e-02, -4.7504e-03,  5.8930e-05,
          2.0783e-02, -2.3622e-03, -5.2687e-03, -5.8447e-03, -1.0608e-02,
         -1.0498e-02, -9.5139e-03,  1.3671e-02,  5.2071e-03,  6.2913e-03,
          1.5897e-02,  1.9130e-03, -1.7698e-02, -1.1381e-02, -2.4493e-03,
         -9.0661e-03,  3.0104e-04, -3.1901e-03, -1.4501e-02,  6.6054e-04,
          5.9062e-03,  9.6727e-03, -1.7085e-02, -7.1479e-03,  1.1499e-02,
          2.7148e-02, -1.6385e-02, -2.7822e-03,  1.5409e-02, -1.6170e-03,
          2.0706e-03, -1.0137e-03, -6.9971e-03,  4.5867e-03, -1.7533e-03,
         -1.5169e-02,  7.1676e-03,  7.3481e-04,  3.9260e-03,  2.2786e-02,
         -4.4119e-03, -5.0950e-04,  1.8821e-02, -9.7902e-03, -1.6165e-02,
          1.0704e-02,  1.1139e-03,  1.0848e-02,  8.7063e-03, -1.5427e-02,
          6.0734e-03, -6.5893e-03, -6.0677e-03,  3.6704e-03, -6.9275e-03,
         -6.3169e-03,  5.0168e-03,  5.0394e-03,  9.5354e-03],
        [-2.5179e-04,  3.3235e-03, -2.6350e-03, -2.8805e-03,  1.7708e-02,
          7.3398e-03,  1.8750e-03, -1.1013e-02, -1.4283e-03, -2.4949e-03,
          1.8975e-02, -6.0564e-03, -9.4183e-03,  7.4389e-03,  1.9309e-03,
         -7.5402e-03,  1.4321e-02, -1.4126e-02,  8.1422e-03,  5.2280e-03,
          5.0304e-03, -5.1896e-03,  5.6178e-03,  2.7014e-02,  8.8836e-04,
         -7.5104e-03, -2.3235e-03,  1.8136e-04, -2.9145e-02,  1.4800e-02,
         -1.2273e-04,  6.1381e-03, -1.5862e-02, -6.1995e-04, -2.7211e-03,
         -1.9053e-02,  5.0899e-03, -5.4222e-03,  1.0337e-02,  7.6167e-03,
         -3.2999e-03,  1.0699e-03,  1.3340e-02, -3.8160e-03,  3.0833e-03,
         -5.2089e-03,  9.0612e-04,  1.4491e-02, -4.5410e-03,  5.4649e-03,
         -3.3898e-03,  2.8065e-03,  1.5491e-02,  7.5988e-03,  1.5773e-02,
          1.5484e-02,  5.4282e-03,  2.5454e-03, -2.1613e-02, -1.5429e-02,
         -1.2897e-02, -5.3088e-03,  1.1335e-02, -9.0223e-03],
        [ 1.0117e-02, -3.3639e-03, -5.0150e-03,  3.2073e-03,  1.0271e-02,
         -6.6959e-03, -2.1131e-03, -1.2018e-02,  4.2930e-03, -1.6980e-04,
         -7.7216e-03,  5.6076e-03,  1.4555e-02,  1.7849e-02,  8.0165e-03,
          4.1849e-03,  5.5320e-03,  1.5881e-02, -6.8613e-03,  1.2461e-03,
         -9.2352e-03,  1.1187e-02, -7.9894e-03, -1.6583e-02,  1.4254e-02,
          7.2171e-04,  9.4763e-03, -8.1024e-03, -1.9460e-02,  7.6837e-03,
          2.7487e-04,  1.1689e-02,  4.6567e-03, -1.1756e-02,  2.4855e-03,
         -4.1040e-03, -9.1597e-03,  1.5789e-02, -1.8456e-03,  2.1223e-02,
          9.2912e-03, -1.5335e-02, -4.1271e-03, -6.4253e-04,  9.4843e-03,
         -3.3001e-03, -4.1901e-03, -1.0254e-02,  8.9056e-03,  9.9998e-04,
         -1.8859e-03,  3.7849e-03, -2.8724e-03,  7.0505e-04, -7.0398e-03,
         -2.5400e-03, -5.2459e-03,  7.9450e-03,  2.2319e-02, -1.0812e-02,
          1.5204e-02, -4.8161e-03, -7.8047e-04, -8.8051e-03],
        [-5.7175e-03,  2.0120e-02, -1.1453e-02, -1.2576e-02,  2.1838e-02,
         -1.5232e-02,  3.1967e-03,  1.2660e-02, -2.0898e-02,  6.2341e-03,
          4.6589e-04, -1.0194e-02,  7.9064e-03,  8.0972e-03,  5.9819e-03,
          1.2644e-02,  1.6601e-02, -9.0595e-04, -2.5263e-02, -7.0945e-03,
         -9.3729e-03,  9.0826e-03, -5.7136e-03,  4.9057e-03,  1.1597e-02,
          1.1955e-02, -7.8716e-04,  9.9030e-04,  5.8346e-03, -3.2973e-04,
         -7.7189e-03, -3.7571e-03, -1.3204e-02, -1.0440e-02,  2.9435e-03,
         -2.8907e-04,  1.2057e-03, -1.0351e-03, -5.7206e-03,  1.6962e-03,
         -3.2682e-03,  1.6592e-03, -8.5040e-03,  2.2232e-02, -5.1094e-04,
          1.2425e-02, -3.0755e-03,  1.1618e-02,  3.6595e-03, -1.5270e-02,
          4.9968e-03,  4.4446e-03,  1.1779e-02,  1.5565e-02, -1.2305e-02,
          2.8609e-03, -3.9866e-03,  1.3608e-02, -2.0619e-02, -4.7859e-03,
         -1.6829e-02, -7.4733e-03, -1.9138e-02,  1.7258e-03],
        [-1.7574e-02,  1.2523e-02,  2.0962e-03,  4.2426e-03,  8.0573e-03,
          2.2357e-02,  1.2657e-02,  3.5991e-03,  4.3030e-03, -7.8645e-03,
         -2.4566e-03,  5.2175e-03, -8.6353e-03, -8.3184e-03, -1.1575e-02,
          4.3127e-03, -5.1229e-03,  5.4804e-03, -8.9780e-03,  1.5546e-04,
         -1.4716e-02, -2.0695e-02, -7.6624e-03, -1.1113e-02, -5.7346e-03,
          6.1758e-03, -5.6148e-03,  2.0378e-05,  1.0982e-02,  8.3183e-04,
         -9.1416e-04, -3.9552e-03, -9.1955e-03,  1.1313e-03, -2.9609e-04,
         -1.1788e-02,  3.7255e-04,  1.0457e-02, -1.8796e-02,  2.8319e-03,
         -5.6307e-03, -6.3487e-04, -1.1184e-02, -1.4268e-02,  1.1114e-02,
         -1.2992e-02, -1.8135e-03, -1.0604e-02,  9.8879e-03, -2.3624e-03,
          1.3414e-02,  8.9875e-03,  2.2747e-02,  3.8558e-03, -3.3536e-03,
          9.5849e-03, -2.1084e-03,  6.9714e-04, -6.0838e-03, -8.5648e-03,
         -1.2717e-02, -5.1056e-04,  8.7422e-03, -4.1311e-03],
        [ 1.5267e-02, -1.0984e-02, -2.0117e-02,  1.1511e-02,  7.1848e-03,
         -4.4533e-03,  4.0069e-03,  2.2058e-02,  2.6312e-02, -1.1855e-02,
         -5.7762e-03, -2.0696e-03,  1.5378e-02,  1.8970e-02, -2.1807e-02,
          9.5076e-03,  5.6942e-03,  9.4878e-03,  6.0089e-05, -2.0919e-02,
         -2.0222e-03,  2.0285e-02,  2.7224e-03, -1.0630e-02,  5.5314e-03,
         -5.6111e-03,  1.1118e-02,  3.5967e-03,  8.3760e-03, -3.1836e-03,
         -9.6490e-03,  1.3429e-02,  4.2757e-03,  1.0438e-02, -2.9009e-03,
          2.3239e-03,  5.5037e-03,  1.1511e-02,  1.2017e-02,  1.3676e-03,
          1.5089e-02, -1.3869e-02, -8.1693e-04,  1.3668e-03,  4.1932e-03,
         -2.5586e-04,  9.7049e-05, -1.4709e-03,  8.5214e-03,  1.2624e-02,
          1.6878e-03, -2.0740e-03,  8.4441e-03, -2.2894e-02,  1.7953e-03,
          1.9498e-02,  2.4912e-02, -2.5735e-03, -5.9107e-03, -1.7209e-03,
          1.8648e-02,  5.6585e-03, -1.8484e-03, -9.2792e-03]],
       requires_grad=True)

In the code above, nn.init.normal_() initializes the weights of the model's last layer, nn.Linear(64, 10), from a Gaussian distribution: the first argument, last_layer.weight, is that layer's weight tensor; the second argument 0 is the mean; and the third argument std is the standard deviation, 0.01 here. The end result is that the last layer's weights hold Gaussian-initialized values.
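
As a quick sanity check (our own addition, continuing the example above), the sample mean and standard deviation of the freshly initialized weights should land near the requested 0 and 0.01:


w = getattr(base_model, '10').weight
print(w.mean().item(), w.std().item())  # roughly 0 and 0.01, up to sampling noise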

constant_

This line uses PyTorch's constant_() method to set the bias of the specified layer to a constant value.

Concretely, the constant_() call sets the bias of the self.base_model layer named by self.base_model.last_layer_name to the constant 0: the first argument is the bias tensor, and the second argument 0 is the constant value. As with normal_, the trailing underscore means the bias tensor is modified in place rather than a new tensor being returned.

For example, given the same base_model as before, we can set the bias of its last layer to the constant 0:


import torch.nn as nn

base_model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 10)
)


last_layer = getattr(base_model, '10')  # the last layer
print(last_layer)

# set the last layer's bias to 0 in place; constant_ also returns the tensor
print(nn.init.constant_(last_layer.bias, 0))

Result:


Linear(in_features=64, out_features=10, bias=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In the code above, nn.init.constant_() sets the bias of the model's last layer, nn.Linear(64, 10), to the constant 0: the first argument, last_layer.bias, is that layer's bias tensor, and the second argument 0 is the constant value. The end result is a bias vector of all zeros.
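
One last aside (our own note, not from the original code): for the specific constant 0, nn.init.zeros_ is an equivalent shortcut:


# same effect as nn.init.constant_(last_layer.bias, 0)
nn.init.zeros_(getattr(base_model, '10').bias)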
