错误写法1 (在本案例中,此错误写法导致第1个epoch的测试集精度降低5%)
class Net(nn.Module):
"""
简写约定:
小格子:cell, (一个28*28的图片被分割成7*7个 4*4的小格子)
层级:level:lv
选择器:selector:sel
分类器:classify:clsfy
尺寸:size:sz
输出:output:o
激活:active:a, 经过非线性函数后的输出 称为 激活 或 激活值
输入:input:in
图片高度:高度:height:h
图片宽度:宽度:width:w
图片高度*图片宽度:height*width:HW:hw
全连接:full connection:fc
self.level*_fc*:lv*_fc*
self.level1_selector:lv1_sel
self.classify_fc:clsfy_fc
LV1_FC_O_SZ:LEVEL1_FC_OUTPUT_SIZE
"""
LEVEL1_FC_CNT =5
lv1_w_pair_cnt=int(LEVEL1_FC_CNT*(LEVEL1_FC_CNT-1)/2)
def __init__(self):
super(Net, self).__init__()
CELL_HW=MnistDim.CELL_HW
CELL_H = MnistDim.CELL_H
CELL_W = MnistDim.CELL_W
CELL_CNT_H = MnistDim.CELL_CNT_H
CELL_CNT_W = MnistDim.CELL_CNT_W
LV1_FC_O_SZ = MnistDim.LV1_FC_O_SZ
CLSFY_FC_IN_SZ = MnistDim.CLSFY_FC_IN_SZ
self.lv1_sel=nn.Linear(CELL_HW, Net.LEVEL1_FC_CNT)
self.lv1_fc = [nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)]*Net.LEVEL1_FC_CNT
错误写法2 (在本案例中,此错误写法导致第1个epoch的测试集精度降低2%)
class Net(nn.Module):
"""
简写约定:
小格子:cell, (一个28*28的图片被分割成7*7个 4*4的小格子)
层级:level:lv
选择器:selector:sel
分类器:classify:clsfy
尺寸:size:sz
输出:output:o
激活:active:a, 经过非线性函数后的输出 称为 激活 或 激活值
输入:input:in
图片高度:高度:height:h
图片宽度:宽度:width:w
图片高度*图片宽度:height*width:HW:hw
全连接:full connection:fc
self.level*_fc*:lv*_fc*
self.level1_selector:lv1_sel
self.classify_fc:clsfy_fc
LV1_FC_O_SZ:LEVEL1_FC_OUTPUT_SIZE
"""
LEVEL1_FC_CNT =5
lv1_w_pair_cnt=int(LEVEL1_FC_CNT*(LEVEL1_FC_CNT-1)/2)
def __init__(self):
super(Net, self).__init__()
CELL_HW=MnistDim.CELL_HW
CELL_H = MnistDim.CELL_H
CELL_W = MnistDim.CELL_W
CELL_CNT_H = MnistDim.CELL_CNT_H
CELL_CNT_W = MnistDim.CELL_CNT_W
LV1_FC_O_SZ = MnistDim.LV1_FC_O_SZ
CLSFY_FC_IN_SZ = MnistDim.CLSFY_FC_IN_SZ
self.lv1_sel=nn.Linear(CELL_HW, Net.LEVEL1_FC_CNT)
self.lv1_fc = [nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False) for _ in range(Net.LEVEL1_FC_CNT) ]
正确写法
class Net(nn.Module):
"""
简写约定:
小格子:cell, (一个28*28的图片被分割成7*7个 4*4的小格子)
层级:level:lv
选择器:selector:sel
分类器:classify:clsfy
尺寸:size:sz
输出:output:o
激活:active:a, 经过非线性函数后的输出 称为 激活 或 激活值
输入:input:in
图片高度:高度:height:h
图片宽度:宽度:width:w
图片高度*图片宽度:height*width:HW:hw
全连接:full connection:fc
self.level*_fc*:lv*_fc*
self.level1_selector:lv1_sel
self.classify_fc:clsfy_fc
LV1_FC_O_SZ:LEVEL1_FC_OUTPUT_SIZE
"""
LEVEL1_FC_CNT =5
lv1_w_pair_cnt=int(LEVEL1_FC_CNT*(LEVEL1_FC_CNT-1)/2)
def __init__(self):
super(Net, self).__init__()
CELL_HW=MnistDim.CELL_HW
CELL_H = MnistDim.CELL_H
CELL_W = MnistDim.CELL_W
CELL_CNT_H = MnistDim.CELL_CNT_H
CELL_CNT_W = MnistDim.CELL_CNT_W
LV1_FC_O_SZ = MnistDim.LV1_FC_O_SZ
CLSFY_FC_IN_SZ = MnistDim.CLSFY_FC_IN_SZ
self.lv1_sel=nn.Linear(CELL_HW, Net.LEVEL1_FC_CNT)
self.lv1_fc1 = nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)
self.lv1_fc2 = nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)
self.lv1_fc3 = nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)
self.lv1_fc4 = nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)
self.lv1_fc5 = nn.Linear(CELL_HW, LV1_FC_O_SZ, bias=False)
self.lv1_fc = [self.lv1_fc1,self.lv1_fc2,self.lv1_fc3,self.lv1_fc4,self.lv1_fc5]