HybridSN代码修改

研究小组里初学习深度学习的同学我都布置写过 HybridSN 的代码:https://github.com/OUCTheoryGroup/colab_demo/blob/master/202003_models/HybridSN_GRSL2020.ipynb

最近做SRDP的同学反映,跑 Pavia 数据集的时候内存会爆,主要原因是 createImageCubes 这个函数有个地方:

patchesData = np.zeros([ width*height, windowSize, windowsSize, spectral_num])

因为 Pavia 数据集尺寸较大,width*height 就比较大了,内存会爆掉。

其实图像中大部分为是0,没有label,我们要取的,只是有label的部分。现在我改了改,先做个循环,看看有多少个像素有 label,然后记录在 count 里,分配内存时:

patchesData = np.zeros([count, windowSize, windowSize, spectral_num])

这样 count 比以前的 width*height 要小很多,内存就不会爆了

修改后的代码如下,供感兴趣的同学参考(Github上的我就不改了,留给以后新同学排雷):

# 在每个像素周围提取 patch 
def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
    # 给 X 做 padding
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # 获得 y 中的标记样本数
    count = 0
    for r in range(0, y.shape[0]):
        for c in range(0, y.shape[1]):
            if y[r, c] != 0:
                count = count+1

    # split patches
    patchesData = np.zeros([count, windowSize, windowSize, X.shape[2]])
    patchesLabels = np.zeros(count)

    count = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            if y[r-margin, c-margin] != 0:
                patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]   
                patchesData[count, :, :, :] = patch
                patchesLabels[count] = y[r-margin, c-margin]
                count = count + 1

    return patchesData, patchesLabels
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值