STARFM(Python版)

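对应文献:Gao F., Masek J., Schwaller M., Hall F. (2006). On the Blending of the Landsat and MODIS Surface Reflectance: Predicting Daily Landsat Surface Reflectance. IEEE Transactions on Geoscience and Remote Sensing.(即 STARFM 算法)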

代码:

import numpy as np
import torch
import torch.nn as nn
import time
#import skimage.measure as  sm
import skimage.metrics  as  sm
import cv2
from osgeo import gdal,gdalconst
import matplotlib.pyplot as plt

###img read tool###############################################################
def imgread(file,mode='gdal'):
    """Read a raster file and divide every value by 10000.

    The /10000 assumes the file stores reflectance scaled by 10000; the
    inverse scaling is applied by ``trans`` before writing results back.

    Args:
        file: Path to the image file.
        mode: Reader backend, ``'gdal'`` (default) or ``'cv2'``.

    Returns:
        numpy.ndarray holding the file's pixel values divided by 10000.

    Raises:
        ValueError: If ``mode`` is not ``'gdal'`` or ``'cv2'``.
        IOError: If the backend could not read ``file``.
    """
    if mode == 'cv2':
        # -1 (IMREAD_UNCHANGED) keeps the original bit depth and channels.
        raw = cv2.imread(file, -1)
    elif mode == 'gdal':
        dataset = gdal.Open(file)
        raw = None if dataset is None else dataset.ReadAsArray()
    else:
        # Previously an unknown mode fell through and raised UnboundLocalError.
        raise ValueError("unsupported mode {!r}; expected 'gdal' or 'cv2'".format(mode))
    if raw is None:
        # Both backends signal a failed read by returning None.
        raise IOError('could not read image: {}'.format(file))
    return raw / 10000.

###weight caculate tools######################################################
def weight_caculate(data):
    """Return the log-scaled weight term ``ln(|data| * 10000 + 1.00001)``.

    The extra 0.00001 keeps the result strictly positive where ``data`` is
    zero, so the reciprocal taken later in the pipeline stays finite.
    """
    magnitude = abs(data) * 10000
    return torch.log(magnitude + 1.00001)

def caculate_weight(l1m1,m1m2):
    """Combine the spectral and temporal difference terms into one weight.

    Args:
        l1m1: Absolute fine/coarse difference at the base date (|L1 - M1| as
            computed by ``starfm_onepart``).
        m1m2: Absolute coarse-image change between dates (|M2 - M1|).

    Returns:
        Element-wise product of the log-scaled terms of both inputs.
    """
    spectral_term = weight_caculate(l1m1)
    temporal_term = weight_caculate(m1m2)
    return spectral_term * temporal_term

###space distance caculate tool################################################
def indexdistance(window):
    """Build the relative spatial-distance matrix for one moving window.

    Each entry is ``1 + d / A``, where ``d`` is the Euclidean pixel distance
    to the window centre and ``A`` is the half-width ``(window[0] - 1) // 2``
    (clamped to at least 1). The centre therefore gets 1, and values grow
    with distance.

    Args:
        window: ``(size0, size1)`` window size.

    Returns:
        float numpy.ndarray of shape ``(window[1], window[0])``.
    """
    xs, ys = np.meshgrid(np.arange(window[0]), np.arange(window[1]))
    cx, cy = (window[0] - 1) // 2, (window[1] - 1) // 2
    # Window sizes 1 or 2 give a zero half-width; clamp it so the matrix is
    # not filled with inf/NaN by a division by zero.
    half_width = max((window[0] - 1) // 2, 1)
    return 1 + (((xs - cx) ** 2 + (ys - cy) ** 2) ** 0.5) / half_width

###threshold select tool######################################################
def weight_bythreshold(weight,data,threshold):
    """Set ``weight`` to 1 wherever ``data <= threshold``.

    ``weight`` is modified in place; the same tensor is returned.
    """
    within = data <= threshold
    weight[within] = 1
    return weight
def weight_bythreshold_allbands(weight,l1m1,m1m2,thresholdmax):
    """Build a band-agnostic similar-pixel mask from two difference tests.

    A band entry is flagged (set to 1 in ``weight``, in place) when its
    ``l1m1`` is within ``thresholdmax[0]`` or its ``m1m2`` is within
    ``thresholdmax[1]``. A window pixel is kept only when every band is
    flagged.

    NOTE(review): both tests write 1 into the same tensor, so a pixel passes
    if EITHER test holds; confirm this is intended versus requiring both.

    Args:
        weight: Zero-initialised tensor shaped (bands, window_pixels, n_windows)
            as created in ``starfm_onepart``; modified in place.
        l1m1, m1m2: Difference tensors broadcastable against ``weight``.
        thresholdmax: ``(l1m1_threshold, m1m2_threshold)``.

    Returns:
        Tensor of shape (1, window_pixels, n_windows) holding 1 where all bands
        were flagged and 0 elsewhere.
    """
    spectral_limit, temporal_limit = thresholdmax[0], thresholdmax[1]
    weight[l1m1 <= spectral_limit] = 1
    weight[m1m2 <= temporal_limit] = 1
    nbands = weight.shape[0]
    # Mean over bands equals 1 only when every band entry is 1.
    agreement = weight.sum(0, keepdim=True) / nbands
    return agreement.masked_fill_(agreement != 1, 0)


###initial similar pixels tools################################################
def spectral_similar_threshold(clusters,NIR,red):
    """Return the ``(NIR, red)`` spectral-similarity thresholds.

    Each threshold is the band's standard deviation times 2, divided by
    ``clusters``.
    """
    return tuple(band.std() * 2 / clusters for band in (NIR, red))

def caculate_similar(l1,threshold,window):
    """Flag window pixels whose value is close to the window's centre pixel.

    Args:
        l1: Single-band image tensor; the caller passes shape (1, 1, H, W).
        threshold: Maximum allowed |pixel - centre| (compared with ``<=``).
        window: Moving-window size passed to ``unfold``.

    Returns:
        float32 tensor with 1 for similar pixels and 0 otherwise. For a
        (1, 1, H, W) input, its shape is (1, window_pixels, n_windows).
    """
    cols = nn.functional.unfold(l1, window)
    # Allocate on the input's own device. The old "cuda if available" choice
    # broke for CPU inputs on a GPU machine.
    weight = torch.zeros(cols.shape, dtype=torch.float32, device=cols.device)
    centerloc = (cols.size()[1] - 1) // 2
    diff = abs(cols - cols[:, centerloc:centerloc + 1, :])
    return weight_bythreshold(weight, diff, threshold)

def classifier(l1):
    """Placeholder for classifier-based similar-pixel selection.

    Not used: similar pixels are selected by thresholding instead. Always
    returns ``None``.
    """
    return None

###similar pixels filter tools#################################################
def allband_arrayindex(arraylist,indexarray,rawindexshape):
    """Gather the same pixel subset from every band of several images.

    Args:
        arraylist: Tensors shaped (1, bands, H, W). The band count is taken
            from the first tensor.
        indexarray: Pair of integer index arrays (e.g. from ``np.meshgrid``)
            giving the (dim-2, dim-3) position of each output pixel. Both
            must have shape ``rawindexshape[2:]``.
        rawindexshape: Output shape ``(1, bands, h, w)``.

    Returns:
        list of float32 tensors of shape ``rawindexshape``, one per input. Each
        is on the same device as its source tensor.
    """
    nbands = arraylist[0].shape[1]
    # A tuple is unambiguous advanced indexing (np.meshgrid may return a list).
    index = tuple(indexarray)
    datalist = []
    for array in arraylist:
        # Allocate next to the source. The old hard-coded .cuda() crashed on
        # CPU-only machines.
        newarray = torch.zeros(rawindexshape, dtype=torch.float32, device=array.device)
        for band in range(nbands):
            newarray[0, band] = array[0, band][index]
        datalist.append(newarray)
    return datalist

def similar_filter(datalist,sital,sitam):
    """Compute per-pixel thresholds for the similar-pixel filter.

    For each pixel, take the largest across-band absolute difference of
    ``l1`` vs ``m1`` and of ``m2`` vs ``m1``. Then add the combined
    uncertainty term.

    Args:
        datalist: ``[l1, m1, m2]`` tensors shaped (N, bands, H, W).
        sital: Uncertainty term for the ``l`` images (presumably Landsat).
        sitam: Uncertainty term for the ``m`` images (presumably MODIS).

    Returns:
        ``(l1m1_limit, m1m2_limit)``, each of shape (N, H*W).
    """
    l1, m1, m2 = datalist
    sigma_lm = (sital ** 2 + sitam ** 2) ** 0.5
    # Both temporal terms come from ``m`` images, hence sitam appears twice.
    sigma_mm = (sitam ** 2 + sitam ** 2) ** 0.5

    def band_max(diff):
        # A 1x1 unfold flattens to (N, bands, H*W); reduce over bands.
        return nn.functional.unfold(diff, (1, 1)).max(1)[0]

    return (band_max(abs(l1 - m1)) + sigma_lm,
            band_max(abs(m2 - m1)) + sigma_mm)

###starfm for onepart##########################################################
def starfm_onepart(datalist,similar,thresholdmax,window,outshape,dist):
    """Predict the t2 fine image for one padded image part (tile).

    Every output pixel is a normalised weighted sum of ``l1 + m2 - m1`` over
    its moving window. The weight is ``1 / (spectral * temporal * distance)``,
    masked by the similar-pixel filters.

    Args:
        datalist: ``[l1, m1, m2]`` padded tile tensors shaped (1, bands, Hp, Wp).
        similar: Spectral-similarity mask, (1, window_pixels, n_windows).
        thresholdmax: ``(l1m1_limit, m1m2_limit)`` from ``similar_filter``.
        window: Moving-window size passed to ``unfold``.
        outshape: ``(h, w)`` of the output tile; ``h * w`` must equal the
            number of windows ``unfold`` produces.
        dist: Relative distance per window position; ``starfm_main`` builds
            it with shape (1, window_pixels, 1).

    Returns:
        Tensor of shape (1, bands, outshape[0], outshape[1]).
    """
    l1, m1, m2 = datalist
    bandsize = l1.shape[1]
    blocksize = outshape[0] * outshape[1]
    # img -> columns: unfold gives (1, bands*window_pixels, n_windows), which
    # is regrouped to (bands, window_pixels, n_windows).
    l1 = nn.functional.unfold(l1, window).view(bandsize, -1, blocksize)
    m1 = nn.functional.unfold(m1, window).view(bandsize, -1, blocksize)
    m2 = nn.functional.unfold(m2, window).view(bandsize, -1, blocksize)
    l1m1 = abs(l1 - m1)
    m1m2 = abs(m2 - m1)
    # Inverse of the combined spectral, temporal and spatial-distance terms.
    w = caculate_weight(l1m1, m1m2)
    w = 1 / (w * dist)
    # Allocate the mask on the data's device (not "cuda if available") so
    # CPU inputs also work on GPU machines.
    wmask = torch.zeros(l1.shape, dtype=torch.float32, device=l1.device)
    # Filter similar pixels jointly over all bands -> (1, window_pixels, n_windows).
    # (A per-band alternative would be two weight_bythreshold calls instead.)
    wmasknew = weight_bythreshold_allbands(wmask, l1m1, m1m2, thresholdmax)
    w = w * wmasknew * similar
    # Normalise so the weights of each window sum to 1.
    w = w / (w.sum(1).view(w.shape[0], 1, w.shape[2]))
    # Weighted prediction of l2, then columns -> image via a 1x1 fold.
    l2 = (l1 + m2 - m1) * w
    l2 = l2.sum(1).reshape(1, bandsize, l2.shape[2])
    l2 = nn.functional.fold(l2.view(1, -1, blocksize), outshape, (1, 1))
    return l2
###starfm for allpart#########################################################
def starfm_main(l1r,m1r,m2r,param=None):
    """Run STARFM over a whole image by processing it in tiles.

    Args:
        l1r: Fine-resolution image at the base date, shape (1, bands, H, W).
        m1r: Coarse image at the base date; indexed with the same pixel
            indices as ``l1r``, so it must have the same shape.
        m2r: Coarse image at the prediction date, same shape.
        param: Optional dict. Any key given overrides these defaults:
            ``part_shape`` (140, 140) - tile size in (rows, cols);
            ``window_size`` (31, 31) - moving window;
            ``clusters`` 5 - divisor for the similarity thresholds;
            ``NIRindex`` 3 / ``redindex`` 2 - bands used for similarity;
            ``sital`` / ``sitam`` 0.001 - uncertainty terms.

    Returns:
        Predicted fine image at the prediction date, shape (1, bands, H, W).

    NOTE(review): only square, odd windows appear to be supported. Tiles are
    gathered transposed, and the margin is assumed to be ``window // 2`` per side.
    """
    defaults = {'part_shape': (140, 140),
                'window_size': (31, 31),
                'clusters': 5,
                'NIRindex': 3, 'redindex': 2,
                'sital': 0.001, 'sitam': 0.001}
    # Build a fresh dict. This avoids a shared mutable default and lets
    # callers override only some keys.
    param = {**defaults, **(param or {})}
    time_start = time.time()
    # Keep helper tensors on the inputs' device instead of assuming CUDA.
    device = l1r.device
    parts_shape = param['part_shape']
    window = param['window_size']
    clusters = param['clusters']
    NIRindex = param['NIRindex']
    redindex = param['redindex']
    sital = param['sital']
    sitam = param['sitam']

    # Initial spectral-similarity thresholds from the whole base image.
    threshold = spectral_similar_threshold(clusters, l1r[:, NIRindex:NIRindex + 1],
                                           l1r[:, redindex:redindex + 1])
    print('similar threshold (NIR,red)', threshold)

    imageshape = (l1r.shape[1], l1r.shape[2], l1r.shape[3])
    print('datashape:', imageshape)
    row = imageshape[1] // parts_shape[0] + 1
    col = imageshape[2] // parts_shape[1] + 1
    padrow = window[0] // 2
    padcol = window[1] // 2

    # Pad with distinct constants, so padded |l1-m1| and |m2-m1| are both 10.
    # That is far above reflectance-scale differences, so padded pixels get
    # tiny weights and presumably fail the similar-pixel filter. Non-zero
    # differences also avoid 1/0 in the inverse-distance weights.
    constant1, constant2, constant3 = 10, 20, 30
    # F.pad takes (left, right, top, bottom): columns first, then rows.
    pad = (padcol, padcol, padrow, padrow)
    l1 = nn.functional.pad(l1r, pad, 'constant', constant1)
    m1 = nn.functional.pad(m1r, pad, 'constant', constant2)
    m2 = nn.functional.pad(m2r, pad, 'constant', constant3)

    # Split pixel indices into tiles and run STARFM on each tile.
    row_part = np.array_split(np.arange(imageshape[1]), row, axis=0)
    col_part = np.array_split(np.arange(imageshape[2]), col, axis=0)
    print('Split into {} parts,row number: {},col number: {}'.format(
        len(row_part) * len(col_part), len(row_part), len(col_part)))
    dist = nn.functional.unfold(
        torch.tensor(indexdistance(window), dtype=torch.float32).reshape(1, 1, window[0], window[1]),
        window).to(device)

    row_strips = []
    for rnumber, row_index in enumerate(row_part):
        tiles = []
        for cnumber, col_index in enumerate(col_part):
            print('now for part{}'.format((rnumber, cnumber)))
            # Output pixels of this tile. meshgrid(rows, cols) yields index
            # arrays shaped (n_cols, n_rows), i.e. the tile is transposed.
            rawindex = tuple(np.meshgrid(row_index, col_index))
            rawindexshape = (col_index.shape[0], row_index.shape[0])
            # The same tile extended by the window margin, for reading padded data.
            row_pad = np.arange(row_index[0], row_index[-1] + window[0])
            col_pad = np.arange(col_index[0], col_index[-1] + window[1])
            padindex = tuple(np.meshgrid(row_pad, col_pad))
            padindexshape = (col_pad.shape[0], row_pad.shape[0])
            # Initial similar pixels: close to the window centre in both bands.
            NIR_similar = caculate_similar(
                l1[0, NIRindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]),
                threshold[0], window)
            red_similar = caculate_similar(
                l1[0, redindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]),
                threshold[1], window)
            similar = NIR_similar * red_similar
            # Per-pixel limits used by the similar-pixel filter.
            thresholdmax = similar_filter(
                allband_arrayindex([l1r, m1r, m2r], rawindex,
                                   (1, imageshape[0], rawindexshape[0], rawindexshape[1])),
                sital, sitam)
            tiles.append(starfm_onepart(
                allband_arrayindex([l1, m1, m2], padindex,
                                   (1, imageshape[0], padindexshape[0], padindexshape[1])),
                similar, thresholdmax, window, rawindexshape, dist))
        # Tiles are transposed, so column neighbours stack along dim 2.
        row_strips.append(torch.cat(tiles, 2))
    # Row strips stack along dim 3; transpose back to (1, bands, H, W).
    l2_fake = torch.cat(row_strips, 3).transpose(3, 2)

    time_end = time.time()
    print('now over,use time {:.4f}'.format(time_end - time_start))
    return l2_fake


def trans(datafile):
    """Multiply every value by 10000, in place (inverse of ``imgread``'s /10000).

    Args:
        datafile: Array-like supporting in-place ``*=`` (numpy array or torch
            tensor) of any shape.

    Returns:
        ``datafile`` itself, after scaling.
    """
    # One vectorised in-place multiply instead of a Python loop over every
    # element. The result is identical and much faster for full images.
    datafile *= 10000
    return datafile

def test():
    """Demo: predict the t2 Landsat image from example files and evaluate it.

    Reads the four example rasters from the working directory and runs
    STARFM. It prints PSNR and SSIM against the ground truth, then writes the
    prediction, rescaled by 10000, as the GeoTIFF ``fake``.
    """
    # Three-band example data. The spectral band order of these files is not
    # documented here, so the 'NIR' and 'red' indices below are only examples.
    l1file = 'L72000306_SZ_B432_30m.tif'
    l2file = 'L72002311_SZ_B432_30m.tif'
    m1file = 'MOD09_2000306_SZ_B214_250m.tif'
    m2file = 'MOD09_2002311_SZ_B214_250m.tif'

    param = {'part_shape': (75, 75),
             'window_size': (31, 31),
             'clusters': 5,
             'NIRindex': 1, 'redindex': 0,
             'sital': 0.001, 'sitam': 0.001}

    # Read images (numpy, scaled by 1/10000).
    l1 = imgread(l1file)
    m1 = imgread(m1file)
    m2 = imgread(m2file)
    l2_gt = imgread(l2file)

    # numpy -> (1, bands, H, W) float32 tensors on the selected device.
    shape = l1.shape
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    l1r = torch.tensor(l1.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32).to(device)
    m1r = torch.tensor(m1.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32).to(device)
    m2r = torch.tensor(m2.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32).to(device)

    # Predict (tensor input -> tensor output).
    l2_fake = starfm_main(l1r, m1r, m2r, param)
    print(l2_fake.shape)

    # tensor -> numpy. .cpu() is a no-op for CPU tensors, so no branch is needed.
    l2_fake = l2_fake[0].cpu().numpy()

    # (channel, H, W) -> (H, W, channel)
    l2_fake = l2_fake.transpose(1, 2, 0)
    l2_gt = l2_gt.transpose(1, 2, 0)
    l1 = l1.transpose(1, 2, 0)
    m1 = m1.transpose(1, 2, 0)
    m2 = m2.transpose(1, 2, 0)
    # plot
    # plt.figure('landsat:t1')
    # plt.imshow(l1)
    # plt.figure('landsat:t2_fake')
    # plt.imshow(l2_fake)
    # plt.figure('landsat:t2_groundtrue')
    # plt.imshow(l2_gt)

    # Evaluation. channel_axis replaces the `multichannel` flag, which is
    # deprecated since scikit-image 0.19.
    psnr = 10. * np.log10(1. / np.mean((l2_fake - l2_gt) ** 2))
    ssim1 = sm.structural_similarity(l2_fake, l2_gt, data_range=1, channel_axis=-1)
    ssim2 = sm.structural_similarity(l1, l2_gt, data_range=1, channel_axis=-1)
    ssim3 = sm.structural_similarity(l1 + m2 - m1, l2_gt, data_range=1, channel_axis=-1)
    print('psnr:{:.4f};with-similarpixels ssim: {:.4f};landsat_t1 ssim: {:.4f};non-similarpixels ssim: {:.4f}'.format(psnr, ssim1, ssim2, ssim3))

    # Back to the stored integer scale, then write with l2file's georeferencing.
    trans(l2_fake)
    writetif(l2_fake, 'fake', l2file)

    return


def writetif(dataset, target_file, reference_file):
    """Write an array to a GeoTIFF that copies another raster's georeferencing.

    The projection and geotransform are copied from ``reference_file``. The
    pixel data type is taken from that file's band 1.

    Args:
        dataset: numpy array shaped (rows, cols, bands).
        target_file: Output path for the GeoTIFF.
        reference_file: Existing raster whose projection, geotransform and
            data type are reused.

    Raises:
        IOError: If GDAL cannot open ``reference_file`` (returns None).
    """
    reference = gdal.Open(reference_file, gdalconst.GA_ReadOnly)
    if reference is None:
        raise IOError('could not open reference raster: {}'.format(reference_file))
    band_count = dataset.shape[2]  # number of bands
    print("波段数:",band_count)
    target = gdal.GetDriverByName("GTiff").Create(target_file, xsize=dataset.shape[1],
                                                  ysize=dataset.shape[0],
                                                  bands=band_count,
                                                  eType=reference.GetRasterBand(1).DataType)
    target.SetProjection(reference.GetProjection())  # copy projection
    target.SetGeoTransform(reference.GetGeoTransform())  # copy geotransform
    for index in range(band_count):
        out_band = target.GetRasterBand(index + 1)  # GDAL bands are 1-based
        out_band.WriteArray(dataset[:, :, index])
        out_band.FlushCache()
        out_band.ComputeBandStats(False)  # compute band statistics
    # Flush and release the datasets explicitly so the file is complete on
    # return. The old `del dataset` only dropped the local name of the input array.
    target.FlushCache()
    target = None
    reference = None
    print("正在写入完成")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test()

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值