遥感融合残差网络多光谱成分提取-自建Lambda层

该博客介绍了如何基于《Target-AdaptiveCNN-BasedPansharpening》建立的TaPNN框架,改进skipconnection层,以从5层输入图像中精确提取出4层的多光谱图像。通过使用Lambda层实现这一功能,解决了因维度调整导致的封装错误。最终构建了一个包含Lambda层的模型,并进行了训练和编译。
摘要由CSDN通过智能技术生成

遥感融合残差网络多光谱成分提取-自建Lambda层

之前基于《Target-Adaptive CNN-Based Pansharpening》建立了TaPNN的框架,最近在拿来跑数据库的时候发现,当时为了省力,直接在skip connection里加了一个conv2d卷积层使输出一个4层的图像,尽管训练出来的图像是无误的,但是本着尽可能尊重文献的原则,还是希望可以从5层输入图像中提取出一个4层的多光谱图像,然后直连到最后一个卷积层的输出。
将所有层封装,需要所有的封装单位是Layer而非Tensor,在字符包埋后,由于Conv2D需要提供channel维度,所以一开始使用reshape命令使tensor维度增加,这个函数使x的属性变为Tensor而非Layer,所以引起封装错误,使用Lambda函数或Reshape 层函数解决

创建函数:

def get_ms_components(x): #x:Input image
    return x[:,:,:,:4]   #extract the ms component

自建一个Lambda层:

img_input1=Lambda(get_ms_components)(img_input)

skip connection:

conv4=layers.add([img_input1,conv3])

整体框架:

    def get_ms_components(x):
        return x[:,:,:,:4]

    img_input = Input(shape=(image_height_size, image_width_size, n_bands))
    img_input1=Lambda(get_ms_components)(img_input)
    conv1 = Conv2D(n1, (f1, f1), padding='same', activation='relu')(img_input)
    conv2 = Conv2D(n2, (f2, f2), padding='same', activation='relu')(conv1)
    conv3 = Conv2D(n_bands - 1, (f3, f3), padding='same')(conv2)

    conv4=layers.add([img_input1,conv3])

    model = Model(inputs=img_input, outputs=conv4)
    model.compile(optimizer=Adam(lr=l_r), loss='mae', metrics=['mse'])

    return model
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值