问题背景及描述:
在复写transforms.RandomRotation 时,出现了 运行报错
Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant
写法如下
class RandomRotation_once(transforms.RandomRotation):
def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
super().__init__(degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
self.angle = self.get_params(degrees)
解决过程及问题成因:
这个问题其他文章里的出现原因大都为 是 使用 了 PIL库 里的代码导致的,
即在源码 functional里的如下部分:
if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
so,别人传进来的类型就不对,但是我是直接复制了 源码例的这个 InterpolationMode 类:
class InterpolationMode(Enum):
"""Interpolation modes
Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
and ``lanczos``.
"""
NEAREST = "nearest"
NEAREST_EXACT = "nearest-exact"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
# For PIL compatibility
BOX = "box"
HAMMING = "hamming"
LANCZOS = "lanczos"
那为何还是有问题呢?
解决方案:
将我的代码改为:
import torchvision.transforms.functional as F
class RandomRotation_once(transforms.RandomRotation):
def __init__(self, degrees, interpolation=F.InterpolationMode.NEAREST, expand=False, center=None, fill=0):
super().__init__(degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
self.angle = self.get_params(degrees)
结论及重点:
将 interpolation=InterpolationMode.NEAREST 改为 interpolation=F.InterpolationMode.NEAREST,
因为 可能写进自己Module 的 这个 InterpolationMode类与源码的InterpolationMode类 Module 并不相同,所以通过不了instanence 检测