环境:
python3.8 torch1.12.1 torchvision0.13.1
在RandomCrop源码中,crop之前用了get_params函数获取crop的入参,防止crop出错。
但我在看get_params源码过程中发现,抛异常的条件应该是:if h < th or w < tw(代码14行),否则h=th-1时,不会抛异常,但20行和21行randint会报错。
如果在调用RandomCrop之前没有pad,此处会有问题。
# torchvision中transforms.py源码截取
class RandomCrop(torch.nn.Module):
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
_, h, w = F.get_dimensions(img)
th, tw = output_size
if h + 1 < th or w + 1 < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
调用测试:

结果:

randint报错。