赛题背景
赛题链接
遥感技术已成为获取地表覆盖信息最为行之有效的手段,遥感技术已经成功应用于地表覆盖检测、植被面积检测和建筑物检测任务。本赛题使用航拍数据,需要参赛选手完成地表建筑物识别,将地表航拍图像素划分为有建筑物和无建筑物两类。
如下图,左边为原始航拍图,右边为对应的建筑物标注。
引入库
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
from tqdm.notebook import tqdm
import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision
from torchvision import transforms as T
数据分析
赛题数据为航拍图,需要识别图片中的地表建筑具体像素位置。
- train_mask.csv:存储图片的标注的rle编码;
- train和test文件夹:存储训练集和测试集图片;
rle编码的具体的读取代码如下:
# 将图片编码为rle格式
def rle_encode(im):
'''
im: numpy array, 1 - mask, 0 - background
Returns run length as string formated
'''
pixels = im.flatten(order = 'F')
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return ' '.join(str(x) for x in runs)
# 将rle格式进行解码为图片
def rle_decode(mask_rle, shape=(512, 512)):
'''
mask_rle: run-length as string formated (start length)
shape: (height,width) of array to return
Returns numpy array, 1 - mask, 0 - background
'''
s = mask_rle.split()
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
for lo, hi in zip(starts, ends):
img[lo:hi] = 1
return img.reshape(shape, order='F')
设置常用变量
- DEVICE:这是用于后续选择将数据放到GPU设备或者CPU设备上运行的属性
- IMAGE_SIZE:不同的图像大小,网络中的参数数量不一样。图像越大,参数越多,对算力要求越高。
- BATCH_SIZE: 批处理次数
- EPOCHES: 迭代轮数
DEVICE =