使用深度学习做分割,出的结果往往是tensor张量,要把tensor转化为tiff文件,并赋予地理参考,便于后续做地理分析
首先是获取存放原始遥感影像数据名(*.tif),代码如下:
# 获取原始遥感影像文件夹
origin_images_path = 'your path'
all_path = os.listdir(origin_images_path) # 将文件夹下的文件列出来
all_len = len(all_path) # 获取文件个数,便于后续获取文件序列
# 获取文件夹下的tiff格式文件
tiffs_path=[] # 空列表存储tiff文件名
for idx in range(all_len):
if all_path[idx].endswith('.tif'):
tiffs_path.append(all_path[idx])
tiffs_len = len(tiffs_path) # 获取tiff格式文件夹个数
使用已经训练好的模型测试:
model = torch.load('your model.pth')
model.eval()
with torch.no_grad():
idx1 = 0
for i, data in enumerate(img_loader):
image = data
image = image.to(device)
image = image.float()
output = model(image)
output = output.unsqueeze(0) # (256*256)转(1*256*256)
output = torch.where(output >= 0.5, 255, 0) # 设置为255,太小无法显示出来
output = output.unsqueeze(0) # 再转(1*1*256*256)
# writer.add_images('result_{}'.format(i), output,1) 在tensorboard中显示
# 生成分类结果png格式文件
output_array = output.cpu().numpy()
output_array = output_array[0, 0]
# output_image =Image.fromarray(output_array.astype('uint8'))
# output_path = 'your png path\\{}.png'.format(idx1)
# output_image.save(output_path)
# 加载完整tiff文件路径
rsimg_path = os.path.join(origin_images_path, tiffs_path[idx1])
# 使用gdal打开tiff文件
rsimg = gdal.Open(rsimg_path)
# 获取原始影像的地理坐标信息
geotrans = rsimg.GetGeoTransform()
project = rsimg.GetProjection()
# 设置输出路径
output_tiff_path = 'your tiff path\\output_{}.tif'.format(idx1)
# 创建 GDAL 数据集,指定 TIFF 驱动
driver = gdal.GetDriverByName('GTiff')
dataset = driver.Create(output_tiff_path, 256, 256, 1, gdal.GDT_Byte) #2568256大小,1个灰度波段
# 设置地理坐标信息
dataset.SetGeoTransform(geotrans)
dataset.SetProjection(project)
# 将数据写入 TIFF 文件
dataset.GetRasterBand(1).WriteArray(output_array)
dataset = None
idx1 += 1
print('-----完成第{}张-----'.format(idx1))
发现处理后的图像的边界部分无结果,差不多一个像素的宽度(因人而异),使用envi对处理后的tiff影像合并,选择羽化即可较好处理边界无值的问题