遥感融合残差网络多光谱成分提取-自建Lambda层
之前基于《Target-Adaptive CNN-Based Pansharpening》建立了TaPNN的框架,最近在拿来跑数据库的时候发现,当时为了省力,直接在skip connection里加了一个conv2d卷积层使输出一个4层的图像,尽管训练出来的图像是无误的,但是本着尽可能尊重文献的原则,还是希望可以从5层输入图像中提取出一个4层的多光谱图像,然后直连到最后一个卷积层的输出。
将所有层封装,需要所有的封装单位是Layer而非Tensor,在字符包埋后,由于Conv2D需要提供channel维度,所以一开始使用reshape命令使tensor维度增加,这个函数使x的属性变为Tensor而非Layer,所以引起封装错误,使用Lambda函数或Reshape 层函数解决
创建函数:
def get_ms_components(x): #x:Input image
return x[:,:,:,:4] #extract the ms component
自建一个Lambda层:
img_input1=Lambda(get_ms_components)(img_input)
skip connection:
conv4=layers.add([img_input1,conv3])
整体框架:
def get_ms_components(x):
return x[:,:,:,:4]
img_input = Input(shape=(image_height_size, image_width_size, n_bands))
img_input1=Lambda(get_ms_components)(img_input)
conv1 = Conv2D(n1, (f1, f1), padding='same', activation='relu')(img_input)
conv2 = Conv2D(n2, (f2, f2), padding='same', activation='relu')(conv1)
conv3 = Conv2D(n_bands - 1, (f3, f3), padding='same')(conv2)
conv4=layers.add([img_input1,conv3])
model = Model(inputs=img_input, outputs=conv4)
model.compile(optimizer=Adam(lr=l_r), loss='mae', metrics=['mse'])
return model