python 计算TIFF图像相关性 根据相关性筛选特征 使用PCA提取主要特征

该博客主要介绍了如何对遥感图像进行预处理,包括读取tif文件、修复NaN值、执行PCA降维以及计算像素间的相关性。通过PCA将数据转换为3个主要成分,并保存为TIFF格式。同时,生成了相关性矩阵并筛选出相关性低于0.5的特征,将这些特征对应的图像保存到新的文件夹中。
摘要由CSDN通过智能技术生成
import os
import shutil
import imageio
import numpy as np
from osgeo import gdal
from sklearn.decomposition import PCA


def save_tiff(path, data, tiff_sample, dtype=gdal.GDT_Float32):
    """
    保存tiff文件
    :param path: 文件路径
    :param data: 数据array(c,h,w)
    :param tiff_sample: tiff样本
    :param dtype: gdal数据类型
    :return:
    """
    tiff_driver = gdal.GetDriverByName('GTiff').Create(path,
                                                       data.shape[2], data.shape[1], data.shape[0], dtype)
    tiff_driver.SetProjection(tiff_sample.GetProjection())
    tiff_driver.SetGeoTransform(tiff_sample.GetGeoTransform())
    for i in range(data.shape[0]):
        tiff_driver.GetRasterBand(i + 1).WriteArray(data[i])
    tiff_driver.FlushCache()
    return


def export_corr(out_path, coe: np.ndarray, feature_name):
    """
    导出相关性矩阵为CSV文件
    :param out_path: 文件名
    :param coe: 2d相关性矩阵
    :param feature_name: 1d列表 参数名
    :return:
    """
    f = open(out_path, 'w')
    f.write('\t'.join(['correlation'] + list(feature_name)) + '\n')
    # 每行前加上feature_name
    for i in range(len(feature_name)):
        f.write('\t'.join([feature_name[i]] + [str(j) for j in coe[i]]) + '\n')
    f.close()


def feature_select_first(coe: np.ndarray, coe_limit=0.5):
    """
    优先保留排位靠前的参数
    :param coe: 2d相关性矩阵
    :param coe_limit: 相关性阈值
    :return: 剩余参数索引
    """
    num_feature = coe.shape[0]
    coe_mask = np.abs(coe) - np.eye(num_feature) > coe_limit
    indices = []
    for i in range(num_feature):
        if not coe_mask[i, indices].any():
            indices.append(i)
    return indices


def feature_select_more(coe: np.ndarray, coe_limit=0.5, is_del_more_first=True):
    """
    优先剔除相关数量较多/较少的参数
    :param coe: 2d相关性矩阵
    :param coe_limit: 相关性阈值
    :param is_del_more_first: 优先剔除相关多还是少的参数, True:多的, False:少的 (推荐True, False会导致相近特征全部删除)
    :return: 剩余参数索引
    """
    num_feature = coe.shape[0]
    coe_mask = np.abs(coe) - np.eye(num_feature) > coe_limit
    indices = []
    while True:
        cor_num = [coe_mask[i].sum() for i in range(num_feature)]  # 计算高相关的数量
        if not is_del_more_first:
            cor_num = [i + num_feature if i == 0 else i for i in cor_num]
        index = np.argmax(cor_num) if is_del_more_first else np.argmin(cor_num)  # 取最高/最低相关数量的参数
        indices.append(index)
        coe_mask[index, :] = False
        coe_mask[:, index] = False
        if not coe_mask.any():
            break
    indices = [i for i in range(num_feature) if i not in indices]  # 计算剩余参数索引
    return indices


# ************************************** 参数设置 **************************************
# ######## 输入设置 ########
# 原始tif文件夹路径, 可以是多个tif文件夹
path = [r'D:\data\0.25\a',
		r'D:\data\0.25\b',
        ]
# 文件后缀, .tif or .tiff
tif_end = '.tif'
# ######## 输出设置 ########
# 输出文件夹
out_path = r'D:\output'
# ************************************** 读取文件 **************************************
# 获取文件夹下所有.tif文件的绝对路径
file_name = np.concatenate([[i[:-4] for i in os.listdir(p) if i.endswith(tif_end)] for p in path])
file_path = np.concatenate([[os.path.join(p, i) for i in os.listdir(p) if i.endswith(tif_end)] for p in path])
# 读取tif文件
img_array = [imageio.imread(i) for i in file_path]
# 文件索引
num_img = [1 if len(img_array[i].shape) == 2 else img_array[i].shape[-1] for i in range(len(img_array))]
img2file_map = np.array([i for i in range(len(file_path)) for j in range(num_img[i])])
# 将所有文件读入一个array
img_array = np.concatenate([i[None, :] if len(i.shape) == 2 else i.transpose(2, 0, 1) for i in img_array])
img_shape = img_array.shape
# img2file_map
# flatten
img_array = img_array.reshape(img_array.shape[0], -1)

# 检查NaN, 构建mask
nan_file = np.isnan(img_array).any(axis=1)
nan_mask = np.isnan(img_array).any(axis=0)
if nan_file.any():
    print('Warning: NaN data in:')
    for i in np.where(nan_file)[0]:
        print(file_name[img2file_map[i]])

# 修复 NaN
fix_nan_with = 0
if fix_nan_with is not False:
    img_array[np.isnan(img_array)] = fix_nan_with
    nan_mask[nan_mask] = False
    print('fix NaN with', fix_nan_with)
# fix bug
img_array[img_array == -9999] = 0

# 创建输出文件夹
if not os.path.exists(out_path):
    os.makedirs(out_path)

# ************************************** 计算PCA **************************************
pca = PCA(n_components=3)
# PCA前需要归一化
x = img_array[:, ~nan_mask]
x -= x.mean(axis=1, keepdims=True)
x /= x.std(axis=1, keepdims=True)
u = pca.fit_transform(x.T).T
print('-' * 80)
print('PCA: %f' % pca.explained_variance_ratio_.sum(), pca.explained_variance_ratio_)
print('-' * 80)

# # 保存PCA结果为RGB-PNG图像(会损失精度)
# pca_output_file = os.path.join(out_path, 'pca.png')
# # u缩放到0,1
# u -= u.min(1, keepdims=True)
# u /= u.max(1, keepdims=True)
# imageio.imwrite(png_output_file, u.reshape([3, *img_shape[1:]]).transpose([1, 2, 0]))
# print('PCA result save in:', png_output_file)

# 保存PCA结果为TIFF格式
pca_output_file = os.path.join(out_path, 'pca.tiff')
u = u.reshape([-1, *img_shape[1:]])
tiff_sample = gdal.Open(file_path[0])
save_tiff(pca_output_file, u, tiff_sample, dtype=gdal.GDT_Float32)
print('PCA result save in:', pca_output_file)

# ************************************** 计算相关性 **************************************
coe = np.corrcoef(img_array[:, ~nan_mask])  # corrcoef不需要归一化
if np.isnan(coe).any():
    raise Exception('Error: NaN data in correlation matrix')

# 导出相关性矩阵
cor_mat_output_file = os.path.join(out_path, 'corr.csv')
export_corr(cor_mat_output_file, coe, file_name)
print('Correlation Matrix save in:', cor_mat_output_file)

# 剔除相关性高于coe_limit的参数
coe_limit = 0.5
indices = feature_select_more(coe, coe_limit=coe_limit)
print('-' * 80)
print('Select: %i Features with coe_limit=%f\n' % (len(indices), coe_limit))
print('Feature:', file_name[img2file_map[indices]])  # TODO: 不能明确显示是多波段文件中的哪一个波段
print('-' * 80)

# 提取相关性高的文件
copy_path = os.path.join(out_path, 'Correlation'+str(coe_limit))
if not os.path.exists(copy_path):
    os.mkdir(copy_path)
for i in img2file_map[indices]:
    save_path = os.path.join(copy_path, file_name[i]+tif_end)
    # # 复制tiff文件
    # print('Copy File to', shutil.copyfile(file_path[i], save_path)
    # 重新保存tiff文件
    if os.path.exists(save_path):
        raise Exception('Error: file exist in', save_path)
    save_tiff(save_path,
              img_array[i].reshape(1, *img_shape[1:]), tiff_sample, dtype=gdal.GDT_Float32)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值