深度网络在做拟合乘法除法这些操作时候很难去使用像Dense这样得网络去凑出来。
使用二进制得角度的确是可以使用线性的方式来处理乘法除法甚至sin等非线性运算,二想要使用线性运算加激活函数来去拟合这样非线性组合的函数,就只能多重的非线性组合来进行操作。我这里使用几种组合来尝试拟除法。
使用单层网络:
# network weights
input_layer = Input(shape=(2,), name="unet_input")
fc1 = Dense(100)(input_layer)
fc1 = ReLU()(fc1)
# fc2 = Dense(32)(fc1)
# fc2 = ReLU()(fc2)
#get Q_value
output_layer= Dense(1)(fc1)
model = Model(input=[input_layer], output=[output_layer], name='Q_net')
model.compile(loss='mse',optimizer=Adam(lr=0.0001))
使用两层网络:
# network weights
input_layer = Input(shape=(2,), name="unet_input")
fc1 = Dense(100)(input_layer)
fc1 = ReLU()(fc1)
fc2 = Dense(32)(fc1)
fc2 = ReLU()(fc2)
#get Q_value
output_layer= Dense(1)(fc2)
model = Model(input=[input_layer], output=[output_layer], name='Q_net')
model.compile(loss='mse',optimizer=Adam(lr=0.0001))
使用三层网络:
input_layer = Input(shape=(2,), name="unet_input")
fc1 = Dense(100)(input_layer)
fc1 = ReLU()(fc1)
fc2 = Dense(32)(fc1)
fc2 = ReLU()(fc2)
fc3 = Dense(16)(fc2)
fc3 = ReLU()(fc3)
output_layer= Dense(1)(fc3)
model = Model(input=[input_layer], output=[output_layer], name='v_net')
model.compile(loss='mse',optimizer=Adam(lr=0.0001))
差不多到800次epochs后下降很明显的。
修改一下三层网络的参数
# network weights
input_layer = Input(shape=(2,), name="unet_input")
fc1 = Dense(64)(input_layer)
fc1 = ReLU()(fc1)
fc2 = Dense(32)(fc1)
fc2 = ReLU()(fc2)
fc3 = Dense(16)(fc2)
fc3 = ReLU()(fc3)
#get Q_value
output_layer= Dense(1)(fc3)
model = Model(input=[input_layer], output=[output_layer], name='Q_net')
model.compile(loss='mse',optimizer=Adam(lr=0.0001))
更换激活函数
# network weights
input_layer = Input(shape=(2,), name="unet_input")
fc1 = Dense(64)(input_layer)
fc1 = LeakyReLU(0.2)(fc1)
fc2 = Dense(32)(fc1)
fc2 = LeakyReLU(0.2)(fc2)
fc3 = Dense(16)(fc2)
fc3 = LeakyReLU(0.2)(fc3)
#get Q_value
output_layer= Dense(1)(fc3)
model = Model(input=[input_layer], output=[output_layer], name='Q_net')
model.compile(loss='mse',optimizer=Adam(lr=0.0001))
效果略差于使用Relu做激活函数,这个也可以理解,因为Relu比LeakyReLU更加非线性。
得到的几轮可以得出的结论:
- 越深的网络对于非线性的拟合效果越好。
- 单元多一些带来的改变很小。
- 若想要更非线性一点用Relu更好一点。
代码目录
https://github.com/hlzy/operator_dnn