该楼层疑似违规已被系统折叠 隐藏此楼查看此楼
训练数据是自制的，如下图所示：左图为输入数据，右图为输出标签。
优化器使用 Adam，损失函数使用 MSE；训练集只有 15 张图片，测试集有 2 张。
一开始用标准 U-Net 训练，没有任何问题；后来在 U-Net 的每一层中加入 ResNet 残差块后，loss 从训练一开始就出现震荡。学习率在 1e-5～0.1 之间都试过，batch size 在 4～7 之间也试过，但问题一直存在。
loss 和 accuracy 都一直在震荡，情况如下图所示：
想请教各位大佬，这种情况该如何解决？
加入 ResNet 后的网络结构代码如下：
def Res_block(cov_input, filter_size):
    """Residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a projected shortcut.

    Args:
        cov_input: Input Keras tensor.
        filter_size: Number of filters (output channels) used by every
            convolution in the block. Despite the name, this is not a
            kernel size.

    Returns:
        Keras tensor with ``filter_size`` channels and the same spatial size
        as ``cov_input`` (every convolution is stride 1 with ``padding='same'``).
    """
    # 1x1 projection on the shortcut so it always has ``filter_size``
    # channels and can be added to the main branch even when the input's
    # channel count differs (e.g. after a concatenation in the decoder).
    con_shortcut = Conv2D(filter_size, (1, 1), padding='same',
                          kernel_initializer='he_normal')(cov_input)

    # Main (bottleneck-style) branch.
    # NOTE(review): BatchNormalization is applied *after* the ReLU here, and
    # the block also ends with BN after the post-addition ReLU. The
    # conventional ResNet ordering is Conv -> BN -> ReLU, with a plain ReLU
    # after the add. With small batch sizes the BN batch statistics are
    # noisy, which can make training unstable; consider the conventional
    # ordering (or fewer BN layers) if the loss oscillates.
    con1 = Conv2D(filter_size, (1, 1), activation='relu', padding='same',
                  kernel_initializer='he_normal')(cov_input)
    con1 = BatchNormalization()(con1)

    con2 = Conv2D(filter_size, (3, 3), activation='relu', padding='same',
                  kernel_initializer='he_normal')(con1)
    con2 = BatchNormalization()(con2)

    # Last conv of the branch is linear so the residual sum is not clipped
    # before the addition.
    con3 = Conv2D(filter_size, (1, 1), padding='same',
                  kernel_initializer='he_normal')(con2)

    con = Add()([con3, con_shortcut])
    con = Activation("relu")(con)
    cov_out = BatchNormalization()(con)
    return cov_out
def Unet(input_height=224, input_width=224, loss_names="binary_crossentropy",
         optimizer_names="adam", input_channels=6):
    """Build and compile a U-Net whose convolution stages are residual blocks.

    The encoder has five stages (32, 64, 128, 256, 512 filters), each made of
    two ``Res_block`` calls followed by 2x2 max-pooling, and a 1024-filter
    bottleneck. The decoder mirrors the encoder: 2x upsampling, a 2x2 conv,
    concatenation with the matching encoder stage, then two ``Res_block``
    calls. A stack of 3x3 convs narrows the result to a single channel.

    Args:
        input_height: Input image height in pixels, or ``None`` for a
            variable size. Must be divisible by 32 (five pooling stages).
        input_width: Input image width in pixels, or ``None`` for a variable
            size. Must be divisible by 32.
        loss_names: Loss passed to ``model.compile`` (name or loss object).
        optimizer_names: Optimizer passed to ``model.compile`` (name or
            optimizer instance).
        input_channels: Number of input channels (defaults to 6, the value
            previously hard-coded).

    Returns:
        The compiled Keras ``Model`` mapping
        ``(input_height, input_width, input_channels)`` inputs to a
        one-channel output of the same spatial size.

    Raises:
        ValueError: If a fixed height or width is not divisible by 32. Such
            sizes would otherwise fail later with an opaque shape mismatch in
            the skip-connection ``concatenate``.
    """
    # Each of the 5 max-pool layers floors odd sizes, while UpSampling2D
    # exactly doubles them, so skip tensors only line up for multiples of 32.
    for dim_name, dim in (("input_height", input_height),
                          ("input_width", input_width)):
        if dim is not None and dim % 32 != 0:
            raise ValueError(
                "{} must be divisible by 32 (5 pooling stages), got {}".format(
                    dim_name, dim))

    main_inputs = Input(shape=(input_height, input_width, input_channels))

    # Encoder: keep each stage's output as a skip connection before pooling.
    x = main_inputs
    skips = []
    for filters in (32, 64, 128, 256, 512):
        x = Res_block(x, filters)
        x = Res_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    # Bottleneck at 1/32 of the input resolution.
    x = Res_block(x, 1024)
    x = Res_block(x, 1024)

    # Decoder: upsample, reduce channels with a 2x2 conv, fuse with the
    # encoder stage of the same resolution (deepest skip first).
    for filters, skip in zip((512, 256, 128, 64, 32), reversed(skips)):
        up = UpSampling2D((2, 2))(x)
        up = Conv2D(filters, 2, activation='relu', padding='same',
                    kernel_initializer='he_normal')(up)
        x = concatenate([up, skip], axis=-1)
        x = Res_block(x, filters)
        x = Res_block(x, filters)

    # Output head: (filters, followed_by_batchnorm) for each 3x3 conv, in order.
    head_layers = ((32, True), (32, True), (16, True), (16, True),
                   (8, False), (4, False), (2, False), (1, False))
    o = x
    for filters, use_bn in head_layers:
        o = Conv2D(filters, (3, 3), activation='relu', padding='same',
                   kernel_initializer='he_normal')(o)
        if use_bn:
            o = BatchNormalization()(o)

    # NOTE(review): every head layer, including this final 1x1 conv, uses
    # ReLU. If a 1- or 2-channel ReLU layer outputs zeros everywhere, its
    # gradient is zero and the prediction can get stuck. A ReLU output is
    # also unbounded above, whereas the default ``binary_crossentropy`` loss
    # expects values in [0, 1]. Consider a linear (MSE regression) or
    # sigmoid (BCE) final activation.
    o = Conv2D(1, (1, 1), activation='relu', kernel_initializer='he_normal')(o)

    model = Model(inputs=main_inputs, outputs=o)
    # loss_names / optimizer_names were previously accepted but ignored, and
    # the model was never returned.
    model.compile(loss=loss_names, optimizer=optimizer_names)
    return model