数据集分割保存至本地
https://www.lfd.uci.edu/~gohlke/pythonlibs/#rasterio
https://www.lfd.uci.edu/~gohlke/pythonlibs/#gdal
pip install GDAL
pip install rasterio
== 引入工具包 ==
import rasterio
from rasterio.windows import Window
from torch.utils.data import Dataset,Subset
import pandas as pd
import pathlib, sys, os, random, time
from torchvision import transforms
import tqdm
import numpy as np
import albumentations
import torch
import os
import numpy as np
import random
import torch
import numba
== step 0 参数配置==
from major_models import FCN,SegNet
from major_models.U_Net import UNet
from major_models.U_Net_plus_plus_plus import UNet_3Plus
import torch
# 1. batchsize: batch size
mc_batchsize = 2
# 2. num_epoch: number of training epochs, usually 200 by default
mc_num_epoch = 200
# 3. num_classes: number of classes
mc_num_classes = 6
# 4. crop_size: crop size
mc_crop_size = (512, 512)
# 5. Path to the training-set images and labels
mc_data_path = '/home/jason/major_s/kaggle:肾小球分割/数据集/kaggle-hubmap-kidney-segmentation'
# Tile window size and minimum tile overlap.
# NOTE(review): not referenced by the visible code, which relies on the
# defaults of make_grid / BuildSlicesAndSavedImage instead — confirm intent.
mc_window = 512
mc_min_overlap = 32
# 9. path_test_model: path of the model used for testing
mc_path_test_model = "best_model.pth"
# 10. path_predict_model: path of the model used for prediction
mc_path_predict_model = "best_model.pth"
# 11. Path where the trained model is saved
mc_path_saved_model = 'best_model.pth'
# 12. color2class_table: table mapping colour values to class values
mc_path_color2class_table = "./major_dataset_repo/major_collected_dataset/color2class_table.csv"
# 13. Device selection: CUDA when available, otherwise CPU
mc_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# 14. (norm_mean, norm_std): mean and standard deviation of the dataset
# NOTE(review): these values look like the usual ImageNet statistics — confirm
# they are appropriate for this dataset.
mc_norm_mean = [0.485, 0.456, 0.406]
mc_norm_std = [0.229, 0.224, 0.225]
# 15. model: model selection (instantiated here, at import time)
mc_model = UNet_3Plus()
== step 1 辅助函数 ==
# Set the random seeds
def set_seeds(seed=42):
    """Seed every random number generator used here for reproducible runs.

    seed: integer seed applied to `random`, NumPy and PyTorch (CPU and CUDA).
    Returns None; the effect is the global RNG / backend state.
    """
    random.seed(seed)
    # Hash randomisation is fixed at interpreter start-up, so this only takes
    # effect in child processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may autotune to a different convolution
    # algorithm on each run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
# Encode: image -> string
# Used for converting the decoded image to an RLE mask.
def rle_encode(im):
    """Run-length encode a binary mask.

    im: numpy array, 1 - mask, 0 - background.
    Returns the run lengths as a space-separated "start length ..." string,
    with 1-based starts counted in column-major order.
    """
    flat = im.flatten(order='F')
    # Pad with a zero on each side so runs touching either end still produce
    # both an opening and a closing transition.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd positions hold run ends; turn them into lengths in place.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))
# Decode: string -> image
def rle_decode(mask_rle, shape=(512, 512)):
    """Decode a run-length string into a binary mask.

    mask_rle: run-length string formatted as "start length start length ...",
        with 1-based starts.
    shape: (height, width) of the array to return.
    Returns a numpy uint8 array, 1 - mask, 0 - background.
    """
    tokens = mask_rle.split()
    # Even tokens are starts (shifted to 0-based), odd tokens are lengths.
    begins = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    stops = begins + run_lengths
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(begins, stops):
        flat[begin:stop] = 1
    # The encoding counts pixels column by column, hence the Fortran order.
    return flat.reshape(shape, order='F')
@numba.njit()
def rle_numba(pixels):
    """Run-length encode a flat 0/1 sequence into a list of integers.

    pixels: 1-D sequence of 0/1 values (the caller flattens the mask first).
    Returns a list of alternating 1-based run starts and run lengths, the same
    layout rle_encode writes into its string.
    """
    size = len(pixels)
    points = []
    if size == 0:
        return points
    # flag is True while waiting for a run to open, False while a run is open.
    flag = True
    if pixels[0] == 1:
        # A run beginning on the first pixel starts at position 1 (1-based)
        # and is already open, so the next change must close it.
        points.append(1)
        flag = False
    for i in range(1, size):
        if pixels[i] != pixels[i - 1]:
            if flag:
                points.append(i + 1)
                flag = False
            else:
                points.append(i + 1 - points[-1])
                flag = True
    # A run still open at the end of the data extends to the last pixel.
    if pixels[-1] == 1:
        points.append(size - points[-1] + 1)
    return points
def rle_numba_encode(image):
    """Return the run-length string of `image` using the numba encoder.

    The image is flattened in column-major order, passed to rle_numba, and the
    resulting integers are joined with single spaces.
    """
    runs = rle_numba(image.flatten(order='F'))
    return ' '.join(map(str, runs))
# Tiling
def make_grid(shape, window=512, min_overlap=32):
    """Compute overlapping tile bounds that cover a 2-D area.

    shape: (x, y) extent of the area to tile.
    window: tile edge length.
    min_overlap: minimum overlap between neighbouring tiles along each axis.
    Returns an int64 array of size (N, 4), where N is the number of tiles and
    the second axis holds the slice bounds x1, x2, y1, y2.
    """
    x, y = shape

    def _axis_bounds(length):
        # Evenly spaced starts; the last tile is pinned so that it ends on the
        # border of the area.
        count = length // (window - min_overlap) + 1
        starts = np.linspace(0, length, num=count, endpoint=False, dtype=np.int64)
        # Clamp to 0: when the axis is shorter than the window,
        # length - window would be a negative start.
        starts[-1] = max(length - window, 0)
        ends = (starts + window).clip(0, length)
        return starts, ends

    x1, x2 = _axis_bounds(x)
    y1, y2 = _axis_bounds(y)
    nx, ny = len(x1), len(y1)
    slices = np.zeros((nx, ny, 4), dtype=np.int64)
    # Broadcast the per-axis bounds over the (nx, ny) grid of tiles.
    slices[:, :, 0] = x1[:, None]
    slices[:, :, 1] = x2[:, None]
    slices[:, :, 2] = y1[None, :]
    slices[:, :, 3] = y2[None, :]
    return slices.reshape(nx * ny, 4)
== step 2 数据处理 ==
# Preprocessing
# Affine transform with coefficients (1, 0, 0, 0, 1, 0); it is passed as
# `transform=` to rasterio.open when the training images are read.
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
一 数据集分割保存至本地
class BuildSlicesAndSavedImage():
    """Cut the training images into tiles and save tile/mask pairs to disk.

    The index of `train.csv` (first column) gives the image names and its
    'encoding' column the run-length encoded mask of each image. Tiles whose
    mask covers more than `threshold` pixels are written as .npy files to
    ./image/ and ./mask/.
    """

    def __init__(self, root_dir=mc_data_path, window=512, overlap=32, threshold=100):
        """Read train.csv from `root_dir` and store the tiling parameters.

        root_dir: dataset directory containing train.csv and a train/ folder.
        window: tile edge length handed to make_grid.
        overlap: minimum tile overlap handed to make_grid.
        threshold: a tile is kept only if its mask sum exceeds this value.
        """
        self.path = pathlib.Path(root_dir)
        self.overlap = overlap
        self.window = window
        self.csv = pd.read_csv((self.path / 'train.csv').as_posix(),
                               index_col=[0])
        self.threshold = threshold
        self.x, self.y = [], []
        self.len = len(self.x)

    def build_slices(self):
        """Tile every image listed in train.csv and save the selected tiles.

        Fills self.files (image paths), self.masks (decoded full-size masks)
        and self.slices ([image index, x1, x2, y1, y2] per saved tile).
        Image tiles go to ./image/<i>_<j>.npy in channel-last layout and the
        matching masks to ./mask/mask_<i>_<j>.npy.
        """
        # np.save does not create missing parent directories.
        pathlib.Path("./image/").mkdir(parents=True, exist_ok=True)
        pathlib.Path("./mask/").mkdir(parents=True, exist_ok=True)

        self.masks = []
        self.files = []
        self.slices = []
        for i, filename in enumerate(self.csv.index.values):
            filepath = (self.path / 'train' / (filename + '.tiff')).as_posix()
            self.files.append(filepath)
            print('Transform', filename)
            with rasterio.open(filepath, transform=identity) as dataset:
                full_mask = rle_decode(self.csv.loc[filename, 'encoding'], dataset.shape)
                self.masks.append(full_mask)
                slices = make_grid(dataset.shape, window=self.window, min_overlap=self.overlap)
                print("len(slices):" + str(len(slices)))
                for j, slc in enumerate(slices):
                    x1, x2, y1, y2 = slc
                    mask = full_mask[x1:x2, y1:y2]
                    # NOTE(review): np.random.randint(100) yields 0..99, so the
                    # second clause can never be true; it is kept so that tile
                    # selection stays exactly as before.
                    if mask.sum() > self.threshold or np.random.randint(100) > 120:
                        self.slices.append([i, x1, x2, y1, y2])
                        image = dataset.read([1, 2, 3],
                                             window=Window.from_slices((x1, x2), (y1, y2)))
                        # Save the tile locally.
                        print("./image/"+str(i)+'_'+str(j))
                        # Bands come first after read(); move them last.
                        image = np.moveaxis(image, 0, -1)
                        np.save("./image/"+str(i)+'_'+str(j), image)
                        np.save("./mask/"+"mask"+"_"+str(i)+"_"+str(j), mask)
# Run the tiling step at module level: read train.csv under mc_data_path (the
# default root_dir) and write the selected tiles to ./image/ and ./mask/.
buildslices = BuildSlicesAndSavedImage()
buildslices.build_slices()
# -*- coding: utf-8 -*-
"""
将数据集划分为训练集,验证集,测试集
"""
import os
import random
import shutil
def makedir(new_dir):
    """Create `new_dir` (including missing parents) if it does not exist yet.

    new_dir: directory path to create.
    Raises FileExistsError if the path exists but is not a directory.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(new_dir, exist_ok=True)
def read_file(path):
    """Return the sorted paths of all entries in the directory `path`.

    Each entry name from os.listdir is joined onto `path`; the resulting list
    is sorted lexicographically.
    """
    return sorted(os.path.join(path, entry) for entry in os.listdir(path))
imgs = read_file("./image/")
print(len(imgs))
labels = read_file("./mask/")
print(len(labels))
import shutil

# Images and masks are paired by their position in the two sorted listings,
# so every image needs a mask at the same index.
if len(labels) < len(imgs):
    raise ValueError("found %d images but only %d masks" % (len(imgs), len(labels)))

# shutil.copyfile does not create missing destination directories.
for out_dir in ("./train/image/", "./train/label/",
                "./validation/image/", "./validation/label/"):
    os.makedirs(out_dir, exist_ok=True)

# The first `train_count` pairs go to the training set, the rest to validation.
# Iterating over the pairs (instead of a fixed range) also works when fewer
# than `train_count` tiles exist.
train_count = 3000
for i, (img_path, label_path) in enumerate(zip(imgs, labels)):
    subset = "train" if i < train_count else "validation"
    shutil.copyfile(img_path, "./" + subset + "/image/" + os.path.basename(img_path))
    shutil.copyfile(label_path, "./" + subset + "/label/" + os.path.basename(label_path))