最近在学习transreid 关于viT-pytorch中的这段代码并不是很理解,因此写这个博客进行总结 # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, container_abcs.Iterable): return x return tuple(repeat(x, n)) return parse # 迭代器构造 # isinstance(变量名,变量的类型) # 用于判断一个变量是不是属于输入的变量类型 # repeat(element,n)将一个元素重复n遍,并返回一个迭代器 # 应该是迭代两次,用来下面的生成的图片的X,y,以及patch的x,y IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) to_2tuple = _ntuple(2)
在实际使用中,它主要用于处理输入的尺寸(图片尺寸和 patch 尺寸):
class PatchEmbed(nn.Module):
    """Image to patch embedding: cut the image into patches.

    Only the size bookkeeping of ``__init__`` is shown here. The patch count
    is ``(H // P_h) * (W // P_w)``, i.e. it assumes the patches tile the image
    without overlap (stride equal to the patch size).

    Args:
        img_size: int or pair; presumably (height, width) -- TODO confirm.
        patch_size: int or pair, same axis order as ``img_size``.
        in_chans: not used in the portion shown.
        embed_dim: not used in the portion shown.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # Broadcast scalar sizes to one value per axis.
        img_size, patch_size = to_2tuple(img_size), to_2tuple(patch_size)

        # Whole patches that fit along each axis (floor division drops any
        # remainder at the border).
        rows, cols = (img_size[axis] // patch_size[axis] for axis in (0, 1))

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = cols * rows
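输入的图片尺寸为 224,patch 尺寸为 16。由于需要得到 X、Y 两个方向的尺寸,所以 n 取 2,即 to_2tuple = _ntuple(2)。输入的尺寸是单个 int,而我们需要的是由两个值组成的元组,例如 to_2tuple(224) 得到 (224, 224)。从下面 PatchEmbed_overlap 的代码可以看出,元组的第一个元素 img_size[0] 用于计算 num_y(Y 方向,即高),第二个元素 img_size[1] 用于计算 num_x(X 方向,即宽)。if isinstance(x, container_abcs.Iterable): 用于判断输入 x 是否已经是可迭代对象(例如 (256, 128) 这样的元组):如果是,就直接把它作为两个方向的尺寸;如果不是(例如单个 int),就用 repeat(x, n) 把它重复 n 次,构成元组。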
同样的用法也出现在下面这个带重叠 patch 的实现中:
class PatchEmbed_overlap(nn.Module):
    """Image to patch embedding with overlapping patches.

    A patch of size ``patch_size`` is taken every ``stride_size`` pixels, so
    neighbouring patches overlap whenever the stride is smaller than the
    patch. Only the size bookkeeping of ``__init__`` is shown here.

    Args:
        img_size: int or pair; index 0 is used for the y count (``num_y``),
            index 1 for the x count (``num_x``).
        patch_size: int or pair, same axis order as ``img_size``.
        stride_size: int or pair; step between consecutive patch origins.
        in_chans: not used in the portion shown.
        embed_dim: not used in the portion shown.

    Raises:
        ValueError: if a stride is not positive, or the image is smaller than
            the patch along some axis (no patch position would fit).
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_size_tuple = to_2tuple(stride_size)

        # Validate before counting. Without this, image < patch gives a
        # negative per-axis count (e.g. (8 - 16) // 4 + 1 == -1), and two
        # negative counts multiply into a bogus positive num_patches.
        for axis in (0, 1):
            if stride_size_tuple[axis] <= 0:
                raise ValueError(
                    f"stride_size must be positive, got {stride_size_tuple}")
            if img_size[axis] < patch_size[axis]:
                raise ValueError(
                    f"img_size {img_size} is smaller than patch_size "
                    f"{patch_size} along axis {axis}")

        # Number of patch positions per axis. "//" is floor division: it keeps
        # the integer part of the quotient (rounded down), e.g.
        # (224 - 16) // 20 + 1 = 10 + 1 = 11.
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        print('using stride: {}, and patch number is num_y{} * num_x{}'.format(
            stride_size, self.num_y, self.num_x))
        num_patches = self.num_x * self.num_y  # total number of patches
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches