图像分割基础网络代码详解,希望对大家有点帮助,少走弯路
import os
import cv2
import numpy as np
import torch
import torch.utils.data
class Dataset(torch.utils.data.Dataset):
def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
self.img_ids = img_ids
self.img_dir = img_dir
self.mask_dir = mask_dir
self.img_ext = img_ext
self.mask_ext = mask_ext
self.num_classes = num_classes
self.transform = transform
def __len__(self):
return len(self.img_ids)
def __getitem__(self, idx):
img_id = self.img_ids[idx]
# self.img_ids大小是536(即8*67),其中的idx指的应该是536中的id,但不是列表中第idx个,img_id获得的是图片的名字
img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
# img(96,96,3),获得一张图片,img_ext指的是拓展名(此代码中都是png)
# cv2.imread读入图像
# os.path.join将多个路径组合后返回,本代码返回的是图片路径+图片的id和图片的扩展名(也就是图片格式png或者jpg)
mask = []
for i in range(self.num_classes):
mk = cv2.imread(os.path.join(self.mask_dir, str(i), img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[
..., None] # mk是一个(96,96,1)的数组,由(96,96)--->(96,96,1)
mask.append(mk)
mask = np.dstack(mask) # 堆叠一个list,这个list是(96,96,1)的mk
# 获得mask图像中的灰度图片
# cv2.IMREAD_GRAYSCALE 以灰度模式加载图片,可以直接写0
# [..., None] 在最后面追加一个新的维度
if self.transform is not None:
augmented = self.transform(image=img, mask=mask) # 这个包比较方便,能把mask也一并做掉
# 把image和mask一起放入augmented中。做了一些变化转化,但是不知道是什么转化
# Augmented 里面有两个字典,分别是transform过的img和mask。img有变化、mask也有变化。
img = augmented['image'] # 参考https://github.com/albumentations-team/albumentations
# 把self.transform转换过的image输入给img,作为img的值
mask = augmented['mask'] # mask没有变化
img = img.astype('float32') / 255 # 又是一次对image的改变
img = img.transpose(2, 0, 1) # (96,96,3)---->(3,96,96)
mask = mask.astype('float32') / 255 # 把mask变成0或1
mask = mask.transpose(2, 0, 1) # (96,96,1)---->(1,96,96)
return img, mask, {'img_id': img_id}