The Python typing library and torch.types
Introduction
PyTorch's torch/_C/_VariableFunctions.pyi contains the following code:
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
What do Sequence, Optional, and Union, as well as _int and _bool, mean in these signatures? The header of torch/_C/_VariableFunctions.pyi.in offers a clue:
import builtins
from typing import (
Any,
Callable,
ContextManager,
Iterator,
List,
Literal,
NamedTuple,
Optional,
overload,
Sequence,
Tuple,
TypeVar,
Union,
)
import torch
from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
from torch.types import (
_bool,
_device,
_dtype,
_float,
_int,
_layout,
_qscheme,
_size,
Device,
Number,
SymInt,
)
So Sequence, Optional, Union, and the like are imported from typing, a module in Python's standard library that provides runtime support for type hints.
_int, _bool, and the other underscore names, on the other hand, are types that PyTorch defines itself in torch.types.
typing
Sequence vs Iterable
According to Type hints cheat sheet - Standard “duck types”, Sequence stands for sequence types that support the __len__ and __getitem__ methods, such as list, tuple, and str. dict and set are not Sequences.
# Use Iterable for generic iterables (anything usable in "for"),
# and Sequence where a sequence (supporting "len" and "__getitem__") is
# required
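A minimal sketch of how the two annotations differ in practice (total and middle are hypothetical functions written for illustration): a function that only loops over its input can accept any Iterable, including a generator or a set, while a function that needs len() or indexing should ask for a Sequence.

```python
from typing import Iterable, Sequence


def total(values: Iterable[int]) -> int:
    # Only iterates, so any Iterable works: list, tuple, set, generator, ...
    result = 0
    for v in values:
        result += v
    return result


def middle(values: Sequence[int]) -> int:
    # Needs len() and indexing, so a Sequence is required
    return values[len(values) // 2]


print(total(x * x for x in range(4)))  # 14 (a generator: 0 + 1 + 4 + 9)
print(total({1, 2, 3}))                # 6 (a set is Iterable but not a Sequence)
print(middle([10, 20, 30]))            # 20
print(middle((1, 2, 3, 4)))            # 3 (index 4 // 2 == 2)
```

Passing a set to middle would be flagged by a type checker, and it would also fail at runtime because sets do not support indexing.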
According to Python Iterable vs Sequence, an Iterable is any type that supports __iter__ or __getitem__, such as range or the object returned by reversed.
r = range(4)
r.__getitem__(0) # 0
r.__iter__() # <range_iterator object at 0x0000015AE7945D30>
l = [1, 2, 3]
rv = reversed(l)
rv.__iter__() # <list_reverseiterator object at 0x0000015AE7980E20>
rv.__getitem__() # no __getitem__ method, but since __iter__ is supported it still counts as an Iterable
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# AttributeError: 'list_reverseiterator' object has no attribute '__getitem__'
Since a Sequence also has __iter__ and __getitem__, every Sequence is, by definition, an Iterable.
l = []
l.__iter__ # <method-wrapper '__iter__' of list object at 0x7f15bb50b5c0>
l.__getitem__ # <built-in method __getitem__ of list object at 0x7f15bb50b5c0>
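These relationships can also be checked at runtime with the abstract base classes in collections.abc, which typing.Sequence and typing.Iterable alias. One caveat worth knowing: isinstance(x, Iterable) only looks for __iter__, so a class that implements only __getitem__ can still be looped over (iter() falls back to __getitem__), yet isinstance does not recognize it as an Iterable. The OldStyle class below is a hypothetical illustration.

```python
from collections.abc import Iterable, Sequence

# list, tuple, str and range are Sequences; dict, set and reversed(...) are only Iterables
for obj in ([1, 2], (1, 2), "ab", range(3), {"a": 1}, {1, 2}, reversed([1, 2])):
    print(type(obj).__name__, isinstance(obj, Sequence), isinstance(obj, Iterable))


class OldStyle:
    # Hypothetical class that only implements __getitem__ (the old sequence protocol)
    def __getitem__(self, index: int) -> int:
        if index >= 3:
            raise IndexError  # tells iter() that iteration is over
        return index * 10


print(list(OldStyle()))                  # [0, 10, 20]: iter() falls back to __getitem__
print(isinstance(OldStyle(), Iterable))  # False: the ABC only checks for __iter__
```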
Callable
typing.Callable
Frameworks expecting callback functions of specific signatures might be type hinted using Callable[[Arg1Type, Arg2Type], ReturnType].
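The documentation is easy to follow; the one thing to watch out for is that the argument types must be enclosed in their own [], as in Callable[[int, float], float].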
Type hints cheat sheet - Functions gives this example:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
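Ignoring the type hint, this line is simply x = f: it assigns the function f to the variable x. The annotation Callable[[int, float], float] states that f is a function that takes an int and a float and returns a float.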
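A runnable version of the same idea, assuming a hypothetical function f and a hypothetical helper apply: the annotation documents the expected signature, and a static checker such as mypy would reject a function whose signature does not match. Python itself does not enforce the annotation at runtime.

```python
from typing import Callable


def f(a: int, b: float) -> float:
    # Hypothetical function matching Callable[[int, float], float]
    return a * b


def apply(func: Callable[[int, float], float], a: int, b: float) -> float:
    # Calls whatever function it receives with an int and a float
    return func(a, b)


x: Callable[[int, float], float] = f
print(apply(x, 3, 0.5))  # 1.5
```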
Union
typing.Union
Union type; Union[X, Y] is equivalent to X | Y and means either X or Y.
To define a union, use e.g. Union[int, str] or the shorthand int | str. Using that shorthand is recommended.
Union[X, Y] means the type may be either X or Y. Since Python 3.10, the shorter X | Y syntax can be used instead.
Example from Type hints cheat sheet - Useful built-in types:
# On Python 3.10+, use the | operator when something could be one of a few types
x: list[int | str] = [3, 5, "test", "fun"] # Python 3.10+
# On earlier versions, use Union
x: list[Union[int, str]] = [3, 5, "test", "fun"]
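When a value is annotated with a Union, code usually narrows it with isinstance before using type-specific operations, and type checkers such as mypy understand this narrowing. A small sketch with a hypothetical describe function:

```python
from typing import List, Union


def describe(item: Union[int, str]) -> str:
    # isinstance narrows Union[int, str] to a single type in each branch
    if isinstance(item, int):
        return f"int {item * 2}"
    return f"str {item.upper()}"


values: List[Union[int, str]] = [3, 5, "test", "fun"]
print([describe(v) for v in values])
# ['int 6', 'int 10', 'str TEST', 'str FUN']
```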
Optional
Optional type.
Optional[X] is equivalent to X | None (or Union[X, None]).
Optional[X] means the variable may be either of type X or None.
Type hints cheat sheet - Useful built-in types gives a good example:
# Use Optional[X] for a value that could be None
# Optional[X] is the same as X | None or Union[X, None]
x: Optional[str] = "something" if some_condition() else None
Here x is either a string or None, depending on what some_condition() returns, so it is annotated with Optional[str].
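Note that Optional[X] only says the value may be None; it does not by itself make a parameter optional to pass, which is the job of a default value. A sketch with a hypothetical greet function that combines both and uses a None check to narrow Optional[str] down to str:

```python
from typing import Optional


def greet(name: Optional[str] = None) -> str:
    # Before the check name is Optional[str]; after it a checker treats it as str
    if name is None:
        return "Hello, stranger"
    return "Hello, " + name.title()


print(greet())         # Hello, stranger
print(greet("alice"))  # Hello, Alice
```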
Functions
Annotating parameter and return types:
from typing import Callable, Iterator, Union, Optional
# This is how you annotate a function definition
def stringify(num: int) -> str:
    return str(num)
Multiple parameters:
# And here's how you specify multiple arguments
def plus(num1: int, num2: int) -> int:
    return num1 + num2
A function that returns no value uses None as its return type, and a parameter's default value goes after its type annotation:
# If a function does not return a value, use None as the return type
# Default value for an argument goes after the type annotation
def show(value: str, excitement: int = 10) -> None:
    print(value + "!" * excitement)
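Parameters without annotations are dynamically typed (treated as Any), and a function with no annotations at all is not type-checked: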
# Note that arguments without a type are dynamically typed (treated as Any)
# and that functions without any annotations are not checked
def untyped(x):
    x.anything() + 1 + "string"  # no errors
Callable
A function that takes a Callable as a parameter:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
def register(callback: Callable[[str], int]) -> None: ...
Iterator/generator
A generator function is effectively a function that returns an Iterator:
# A generator function that yields ints is secretly just a function that
# returns an iterator of ints, so that's how we annotate it
def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1
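Calling such a generator function returns a generator object, which is an Iterator[int]: it supports next() and can be consumed by list(). A self-contained check (gen is repeated so the snippet runs on its own):

```python
import collections.abc
from typing import Iterator


def gen(n: int) -> Iterator[int]:
    # Yields 0, 1, ..., n - 1
    i = 0
    while i < n:
        yield i
        i += 1


it = gen(3)
print(isinstance(it, collections.abc.Iterator))  # True
print(next(it))  # 0
print(list(it))  # [1, 2], the values not consumed yet
```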
A function annotation can be split across multiple lines:
# You can of course split a function annotation over multiple lines
def send_email(address: Union[str, list[str]],
               sender: str,
               cc: Optional[list[str]],
               bcc: Optional[list[str]],
               subject: str = '',
               body: Optional[list[str]] = None
               ) -> bool:
    ...
Positional & keyword arguments
# Mypy understands positional-only and keyword-only arguments
# Positional-only arguments can also be marked by using a name starting with
# two underscores
def quux(x: int, /, *, y: int) -> None:
    pass
quux(3, y=5) # Ok
quux(3, 5) # error: Too many positional arguments for "quux"
quux(x=3, y=5) # error: Unexpected keyword argument "x" for "quux"
Note the two symbols / and * in the parameter list. From What Are Python Asterisk and Slash Special Parameters For?:
Left side | Divider | Right side |
---|---|---|
Positional-only arguments | / | Positional or keyword arguments |
Positional or keyword arguments | * | Keyword-only arguments |
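Python parameters therefore come in three kinds: positional-only, keyword-only, and positional-or-keyword (which can be passed either way). Parameters to the left of / must be passed positionally, and parameters to the right of * must be passed by keyword. In the example above, x must therefore be passed positionally and y by keyword.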
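The / and * markers are enforced by the interpreter itself, not only by mypy: the calls flagged as errors in the quux example raise TypeError at runtime (positional-only parameters need Python 3.8 or later). A quick check with the same signature; raises_type_error is a hypothetical helper:

```python
from typing import Callable


def quux(x: int, /, *, y: int) -> None:
    # x is positional-only, y is keyword-only
    print(x, y)


def raises_type_error(call: Callable[[], None]) -> bool:
    # Hypothetical helper: True if calling `call` raises TypeError
    try:
        call()
    except TypeError as exc:
        print("TypeError:", exc)
        return True
    return False


quux(3, y=5)                                      # OK, prints "3 5"
print(raises_type_error(lambda: quux(3, 5)))      # True: y passed positionally
print(raises_type_error(lambda: quux(x=3, y=5)))  # True: x passed by keyword
```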
Giving every extra positional and keyword argument the same type:
# This says each positional arg and each keyword arg is a "str"
def call(self, *args: str, **kwargs: str) -> str:
    reveal_type(args)  # Revealed type is "tuple[str, ...]"
    reveal_type(kwargs)  # Revealed type is "dict[str, str]"
    request = make_request(*args, **kwargs)
    return self.do_api_query(request)
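The cheat sheet snippet relies on names it does not define (reveal_type is understood only by mypy, and make_request and do_api_query are not shown). A self-contained sketch with a hypothetical join_all function shows what the annotations mean at runtime: args arrives as a tuple of str and kwargs as a dict mapping str to str.

```python
from typing import Dict, Tuple


def join_all(*args: str, **kwargs: str) -> str:
    # args is Tuple[str, ...] and kwargs is Dict[str, str]
    positional: Tuple[str, ...] = args
    keywords: Dict[str, str] = kwargs
    parts = list(positional) + [f"{k}={v}" for k, v in keywords.items()]
    return " ".join(parts)


print(join_all("a", "b", sep="-", end="!"))  # a b sep=- end=!
```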
Classes
self
class BankAccount:
    # The "__init__" method doesn't return anything, so it gets return
    # type "None" just like any other method that doesn't return anything
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        # mypy will infer the correct types for these instance variables
        # based on the types of the parameters.
        self.account_name = account_name
        self.balance = initial_balance
    # For instance methods, omit type for "self"
    def deposit(self, amount: int) -> None:
        self.balance += amount
    def withdraw(self, amount: int) -> None:
        self.balance -= amount
The self parameter of an instance method needs no type annotation.
User-defined classes
A user-defined class can be used as a type annotation:
# User-defined classes are valid as types in annotations
account: BankAccount = BankAccount("Alice", 400)
def transfer(src: BankAccount, dst: BankAccount, amount: int) -> None:
    src.withdraw(amount)
    dst.deposit(amount)
# Functions that accept BankAccount also accept any subclass of BankAccount!
class AuditedBankAccount(BankAccount):
    # You can optionally declare instance variables in the class body
    audit_log: list[str]
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        super().__init__(account_name, initial_balance)
        self.audit_log: list[str] = []
    def deposit(self, amount: int) -> None:
        self.audit_log.append(f"Deposited {amount}")
        self.balance += amount
    def withdraw(self, amount: int) -> None:
        self.audit_log.append(f"Withdrew {amount}")
        self.balance -= amount
audited = AuditedBankAccount("Bob", 300)
transfer(audited, account, 100) # type checks!
The first parameter of transfer is annotated as BankAccount; since AuditedBankAccount is a subclass of BankAccount, passing an AuditedBankAccount passes type checking.
ClassVar
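Attributes of a Python class come in two kinds: class variables and instance variables. To mark an attribute as a class variable, annotate it with ClassVar[type].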
from typing import ClassVar
# You can use the ClassVar annotation to declare a class variable
class Car:
    seats: ClassVar[int] = 4
    passengers: ClassVar[list[str]]
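A sketch of the runtime difference between the two kinds of attributes, based on the Car example; the name instance attribute and the empty passengers list are additions for illustration. A ClassVar lives on the class and is shared by every instance, while an attribute assigned in __init__ belongs to each object.

```python
from typing import ClassVar, List


class Car:
    seats: ClassVar[int] = 4              # class variable, shared by all cars
    passengers: ClassVar[List[str]] = []  # one list stored on the class

    def __init__(self, name: str) -> None:
        self.name = name                  # instance variable, one per object


a, b = Car("a"), Car("b")
a.passengers.append("Alice")  # mutates the single list stored on the class
print(b.passengers)           # ['Alice']
Car.seats = 5                 # changing the class variable...
print(a.seats, b.seats)       # 5 5  ...is visible from every instance
print(vars(a))                # {'name': 'a'}: only the instance variable
```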
__setattr__ and __getattr__
# If you want dynamic attributes on your class, have it
# override "__setattr__" or "__getattr__"
class A:
    # This will allow assignment to any A.x, if x is the same type as "value"
    # (use "value: Any" to allow arbitrary types)
    def __setattr__(self, name: str, value: int) -> None: ...
    # This will allow access to any A.x, if x is compatible with the return type
    def __getattr__(self, name: str) -> int: ...
a = A()
a.foo = 42  # Works
a.bar = 'Ex-parrot'  # Fails type checking
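Overriding __setattr__ and __getattr__ lets a class accept attributes that are not declared in advance; with the annotations above, a type checker allows assigning or reading any attribute of A as long as its type is int.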
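The class A in the cheat sheet only declares method signatures for the type checker. A minimal runtime sketch, assuming dynamic attributes are kept in a hypothetical _values dict: __getattr__ is only called when normal attribute lookup fails, and __init__ uses object.__setattr__ so that creating the dict does not go through the overridden __setattr__.

```python
from typing import Dict


class A:
    _values: Dict[str, int]  # annotation only, tells the checker what _values is

    def __init__(self) -> None:
        # Bypass our own __setattr__ when creating the backing dict
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name: str, value: int) -> None:
        # Every `a.x = ...` assignment lands here
        self._values[name] = value

    def __getattr__(self, name: str) -> int:
        # Only reached when normal lookup (instance dict, class) fails
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name) from None


a = A()
a.foo = 42
print(a.foo)    # 42
print(vars(a))  # {'_values': {'foo': 42}}
```

Nothing here checks types at runtime: a.bar = 'Ex-parrot' would also be stored, and only a static checker would report it.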
torch.types
Types defined by PyTorch itself.
torch/types.py
import torch
from typing import Any, List, Sequence, Tuple, Union
import builtins
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
# In some cases, these basic types are shadowed by corresponding
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool
_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
_dispatchkey = Union[str, torch._C.DispatchKey]
class SymInt:
    pass
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]
# Meta-type for "device-like" things. Not to be confused with 'device' (a
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, _int, None]
# Storage protocol implemented by ${Type}StorageBase classes
class Storage(object):
    _cdata: int
    device: torch.device
    dtype: torch.dtype
    _torch_load_uninitialized: bool
    def __deepcopy__(self, memo) -> 'Storage':
        ...
    def _new_shared(self, int) -> 'Storage':
        ...
    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
        ...
    def element_size(self) -> int:
        ...
    def is_shared(self) -> bool:
        ...
    def share_memory_(self) -> 'Storage':
        ...
    def nbytes(self) -> int:
        ...
    def cpu(self) -> 'Storage':
        ...
    def data_ptr(self) -> int:
        ...
    def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
        ...
    def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
        ...
    ...
builtins
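The _int, _float, and _bool in torch.types are simply Python's built-in builtins.int, builtins.float, and builtins.bool. As the comment in the source explains, the underscore aliases let PyTorch refer to these types even where the plain names are shadowed by top-level values of the same name (torch.float, for example, is a dtype rather than the built-in float).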
_size
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
torch.Size
torch/_C/__init__.pyi
# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
    # TODO: __reduce__
    @overload  # type: ignore[override]
    def __getitem__(self: Size, key: _int) -> _int: ...
    @overload
    def __getitem__(self: Size, key: slice) -> Size: ...
    def numel(self: Size) -> _int: ...
torch/csrc/Size.cpp
This file binds torch.Size to THPSizeType:
PyTypeObject THPSizeType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch.Size", /* tp_name */
sizeof(THPSize), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPSize_repr, /* tp_repr */
nullptr, /* tp_as_number */
&THPSize_as_sequence, /* tp_as_sequence */
&THPSize_as_mapping, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPSize_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
&PyTuple_Type, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPSize_pynew, /* tp_new */
};
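The tp_base entry &PyTuple_Type is what makes torch.Size a subclass of tuple, matching the Size(Tuple[_int, ...]) stub. The pure-Python sketch below imitates the behavior the stub's overloads describe (an int index returns an int, a slice returns another Size, numel() is the product of the dimensions). MySize is a hypothetical stand-in, not PyTorch's implementation, and it needs Python 3.8+ for math.prod.

```python
import math
from typing import Union, overload


class MySize(tuple):
    # Hypothetical stand-in for torch.Size: a tuple of ints with extra helpers

    @overload  # type: ignore[override]
    def __getitem__(self, key: int) -> int: ...
    @overload
    def __getitem__(self, key: slice) -> "MySize": ...

    def __getitem__(self, key: Union[int, slice]) -> Union[int, "MySize"]:
        result = super().__getitem__(key)
        # Slicing a plain tuple gives a tuple, so wrap it back into MySize
        return MySize(result) if isinstance(key, slice) else result

    def numel(self) -> int:
        # Number of elements in a tensor of this shape (1 for an empty shape)
        return math.prod(self)


s = MySize((2, 3, 4))
print(isinstance(s, tuple))  # True
print(s[1])                  # 3
print(type(s[1:]).__name__)  # MySize
print(s.numel())             # 24
```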
SymInt
In the torch/types.py shown above, SymInt is defined as an empty class.
Number
PyTorch's Number is a Union of builtins.int, builtins.float, and builtins.bool, i.e. any one of _int, _float, or _bool.
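A sketch of how such an alias is used in an annotation. Number is redefined locally here, copying the definition from torch/types.py, so the snippet runs without PyTorch; scale is a hypothetical function. Since bool is a subclass of int in Python, True and False behave as 1 and 0 in arithmetic.

```python
import builtins
from typing import Union

# Same definition as in torch/types.py
Number = Union[builtins.int, builtins.float, builtins.bool]


def scale(value: Number, factor: Number = 2) -> Number:
    # Accepts int, float or bool
    return value * factor


print(scale(3))     # 6
print(scale(1.5))   # 3.0
print(scale(True))  # 2
```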
builtins
This module provides direct access to all ‘built-in’ identifiers of Python; for example, builtins.open is the full name for the built-in function open().
The builtins module gives direct access to Python's built-in identifiers; for example, the built-in open() function can also be reached as builtins.open.
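This is handy when a built-in name has been shadowed by a local definition. In this sketch a hypothetical module-level function named len hides the built-in one, yet builtins.len still reaches the original:

```python
import builtins


def len(obj: object) -> int:
    # Hypothetical function that shadows the built-in len() in this module
    return -1


print(len([1, 2, 3]))           # -1: the module-level definition wins
print(builtins.len([1, 2, 3]))  # 3: the real built-in
```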
The * before a parameter
See What does the Star operator mean in Python?
Single asterisk as used in function declaration allows variable number of arguments passed from calling environment. Inside the function it behaves as a tuple.
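Prefixing a function parameter with * lets the function accept any number of positional arguments; inside the function, that parameter is a tuple.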
def function(*arg):
    print(type(arg))
    for i in arg:
        print(i)
function(1, 2, 3)
# <class 'tuple'>
# 1
# 2
# 3
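Combined with a type hint, *nums: int means each positional argument should be an int, and inside the function nums is a tuple of ints. The * operator also works in the other direction at the call site, unpacking a list into separate arguments. A sketch with a hypothetical total function:

```python
def total(*nums: int) -> int:
    # nums collects all positional arguments into a tuple
    print(type(nums).__name__, nums)
    return sum(nums)


print(total(1, 2, 3))  # prints "tuple (1, 2, 3)", then 6
values = [4, 5]
print(total(*values))  # the list is unpacked: "tuple (4, 5)", then 9
print(total())         # "tuple ()", then 0
```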