文章目录
6 PyTorch 官网教材之 TORCHVISION 0.3 OBJECT DETECTION FINETUNING TUTORIAL
0. 官网链接
1. TORCHVISION 0.3 OBJECT DETECTION FINETUNING TUTORIAL
1. 构建数据集
import os
import numpy as np
import torch
from PIL import Image
class PennFudanDataset(object):
def __init__(self, root, transforms): # 根据路径 root 分别加载 imgs 和 masks 的名称列表
self.root = root
self.transforms = transforms
# load all image files, sorting them to
# ensure that they are aligned
self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
def __getitem__(self, idx):
# load images and masks
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = Image.open(img_path).convert("RGB")
# note that we haven't converted the mask to RGB,
# because each color corresponds to a different instance
# with 0 being background
mask = Image.open(mask_path)
# convert the PIL Image into a numpy array
mask = np.array(mask)
# instances are encoded as different colors
obj_ids = np.unique(mask) # 去除一维元组、列表中的重复元素
# first id is the background, so remove it
obj_ids = obj_ids[1:]
# split the color-encoded mask into a set
# of binary masks
masks = mask == obj_ids[:, None, None] # ? 这两个None 不理解
# get bounding box coordinates for each mask
num_objs &