The Python typing library and torch.types

Introduction

PyTorch's torch/_C/_VariableFunctions.pyi contains the following code:

@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...

What do Sequence, Optional, Union, _int, and _bool in this code mean? The imports in torch/_C/_VariableFunctions.pyi.in give a clue:

import builtins
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterator,
    List,
    Literal,
    NamedTuple,
    Optional,
    overload,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import torch
from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
from torch.types import (
    _bool,
    _device,
    _dtype,
    _float,
    _int,
    _layout,
    _qscheme,
    _size,
    Device,
    Number,
    SymInt,
)

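So Sequence, Optional, Union, and the others are imported from a library called typing. typing is part of the Python standard library and provides runtime support for type hints.

_int, _bool, and the like are types that PyTorch defines itself, in torch.types.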

typing

Sequence vs Iterable

According to Type hints cheat sheet - Standard “duck types”, Sequence stands for sequence types that support the __len__ and __getitem__ methods, such as list, tuple, and str. dict and set are not Sequences.

# Use Iterable for generic iterables (anything usable in "for"),
# and Sequence where a sequence (supporting "len" and "__getitem__") is
# required
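A quick runtime check of that definition via collections.abc.Sequence, the abstract base class that typing.Sequence aliases. first_and_last is a hypothetical function that uses exactly the two operations a Sequence guarantees:

```python
import collections.abc
from typing import Sequence, Tuple


def first_and_last(items: Sequence[int]) -> Tuple[int, int]:
    # len() relies on __len__; items[0] and items[-1] rely on __getitem__
    if len(items) == 0:
        raise ValueError("empty sequence")
    return items[0], items[-1]


print(first_and_last([3, 4, 5]))  # (3, 5)

# list, tuple and str are Sequences; dict and set are not
for obj in ([1, 2], (1, 2), "ab", {"a": 1}, {1, 2}):
    print(type(obj).__name__, isinstance(obj, collections.abc.Sequence))
# list True, tuple True, str True, dict False, set False
```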

According to Python Iterable vs Sequence:

An Iterable is a type that supports __iter__ or __getitem__, such as range and reversed:

r = range(4)
r.__getitem__(0) # 0
r.__iter__() # <range_iterator object at 0x0000015AE7945D30>
l = [1, 2, 3]
rv = reversed(l)
rv.__iter__() # <list_reverseiterator object at 0x0000015AE7980E20>
rv.__getitem__() # no __getitem__ method, but since it supports __iter__ it still counts as an Iterable
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# AttributeError: 'list_reverseiterator' object has no attribute '__getitem__'

Because a Sequence also has __iter__ and __getitem__, by this definition every Sequence is an Iterable:

l = []
l.__iter__ # <method-wrapper '__iter__' of list object at 0x7f15bb50b5c0>
l.__getitem__ # <built-in method __getitem__ of list object at 0x7f15bb50b5c0>
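One nuance worth knowing: collections.abc.Iterable, the ABC behind typing.Iterable, only looks for __iter__. As the Python documentation notes, an isinstance check therefore misses classes that iterate through __getitem__ alone, even though a for loop accepts them. Squares is a hypothetical such class:

```python
import collections.abc


class Squares:
    """Hypothetical class that defines only __getitem__ (the legacy iteration protocol)."""

    def __getitem__(self, index: int) -> int:
        if index >= 3:
            raise IndexError(index)  # tells the for loop / list() to stop
        return index * index


rv = reversed([1, 2, 3])
print(isinstance(rv, collections.abc.Iterable))  # True: it has __iter__
print(isinstance(rv, collections.abc.Sequence))  # False: no __len__ or __getitem__

sq = Squares()
print(list(sq))                                  # [0, 1, 4]: iteration falls back to __getitem__
print(isinstance(sq, collections.abc.Iterable))  # False: the ABC only checks __iter__
```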

Callable

typing - Callable

Callable
Frameworks expecting callback functions of specific signatures might be type hinted using Callable[[Arg1Type, Arg2Type], ReturnType].

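The documentation is easy to follow. One thing to note: the argument types must be wrapped in square brackets, i.e. Callable[[Arg1Type, Arg2Type], ReturnType], not Callable[Arg1Type, Arg2Type, ReturnType].

Type hints cheat sheet - Functions gives an example: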

# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f

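Setting the type hint aside, this line is simply x = f, which assigns the function f to the variable x. The annotation Callable[[int, float], float] says that f is a function that takes an int and a float and returns a float.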
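A runnable version of the snippet, where f is a hypothetical function matching the hint. It also shows Callable[..., float], whose literal ellipsis accepts any argument list as long as the result is a float; call_twice is likewise hypothetical:

```python
from typing import Callable


def f(a: int, b: float) -> float:
    """Hypothetical function matching Callable[[int, float], float]."""
    return a * b


x: Callable[[int, float], float] = f


def call_twice(func: Callable[..., float], *args: float) -> float:
    # Callable[..., float]: any parameters, float result
    return func(*args) + func(*args)


print(x(2, 1.5))              # 3.0
print(call_twice(f, 2, 1.5))  # 6.0
```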

Union

typing - Union

typing.Union
Union type; Union[X, Y] is equivalent to X | Y and means either X or Y.

To define a union, use e.g. Union[int, str] or the shorthand int | str. Using that shorthand is recommended.

Union[X, Y] means the type can be either X or Y. Since Python 3.10, the more concise X | Y syntax can be used instead.

The example from Type hints cheat sheet - Useful built-in types:

# On Python 3.10+, use the | operator when something could be one of a few types
x: list[int | str] = [3, 5, "test", "fun"]  # Python 3.10+
# On earlier versions, use Union
x: list[Union[int, str]] = [3, 5, "test", "fun"]
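Inside a function, isinstance() narrows a Union, so each branch is checked against a single type. double is a hypothetical example:

```python
from typing import Union


def double(value: Union[int, str]) -> Union[int, str]:
    if isinstance(value, int):
        # In this branch the checker treats value as int
        return value * 2
    # Here value can only be str, so + means concatenation
    return value + value


print(double(21))    # 42
print(double("ab"))  # abab
```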

Optional

typing - Optional

Optional type.

Optional[X] is equivalent to X | None (or Union[X, None]).

Optional[X] means the variable can be either of type X or None.

Type hints cheat sheet - Useful built-in types gives a good example:

# Use Optional[X] for a value that could be None
# Optional[X] is the same as X | None or Union[X, None]
x: Optional[str] = "something" if some_condition() else None

Here x is either a string or None depending on the return value of some_condition(), so Optional[str] is the appropriate type hint.
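A common pitfall, visible in the rand stubs at the top where generator: Optional[Generator] has no default value: Optional[X] widens the accepted type to include None, but it does not make the argument optional. Only a default value does that, so as far as those stubs say, the overloads with generator expect it to be passed explicitly (possibly as None). A minimal sketch with hypothetical functions greet and greet_default:

```python
from typing import Optional


def greet(name: Optional[str]) -> str:
    # Optional[str] allows None, but the argument is still required
    if name is None:
        return "Hello, stranger"
    return f"Hello, {name}"


def greet_default(name: Optional[str] = None) -> str:
    # The "= None" default is what makes the argument omittable
    return greet(name)


print(greet(None))      # Hello, stranger
print(greet("Ann"))     # Hello, Ann
print(greet_default())  # Hello, stranger
```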

Functions

mypy - Functions

Specifying parameter and return types:

from typing import Callable, Iterator, Union, Optional

# This is how you annotate a function definition
def stringify(num: int) -> str:
    return str(num)

Multiple arguments:

# And here's how you specify multiple arguments
def plus(num1: int, num2: int) -> int:
    return num1 + num2

A function that returns nothing uses None as its return type, and a parameter's default value goes after its type annotation:

# If a function does not return a value, use None as the return type
# Default value for an argument goes after the type annotation
def show(value: str, excitement: int = 10) -> None:
    print(value + "!" * excitement)
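Annotations like these are stored on the function object, which is the runtime support typing provides, but the interpreter does not enforce them; only a checker such as mypy reports a mismatch. A small demonstration with the plus function from above:

```python
from typing import get_type_hints


def plus(num1: int, num2: int) -> int:
    return num1 + num2


# The hints are readable at runtime...
print(plus.__annotations__)  # {'num1': <class 'int'>, 'num2': <class 'int'>, 'return': <class 'int'>}
print(get_type_hints(plus) == {"num1": int, "num2": int, "return": int})  # True

# ...but not enforced: this runs fine, only a type checker would flag it
print(plus("a", "b"))  # ab
```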

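Arguments without a type annotation are dynamically typed (treated as Any), and a function with no annotations at all is not type-checked:

# Note that arguments without a type are dynamically typed (treated as Any)
# and that functions without any annotations are not checked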
def untyped(x):
    x.anything() + 1 + "string"  # no errors

Callable

A function that takes a Callable as a parameter:

# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
def register(callback: Callable[[str], int]) -> None: ...

Iterator/generator

A generator function is effectively a function that returns an Iterator:

# A generator function that yields ints is secretly just a function that
# returns an iterator of ints, so that's how we annotate it
def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1
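Running gen confirms it behaves as an Iterator. When the send or return type matters, typing.Generator[YieldType, SendType, ReturnType] is the more precise annotation; gen_with_total is a hypothetical example whose return value is carried by StopIteration.value:

```python
from typing import Generator, Iterator


def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1


def gen_with_total(n: int) -> Generator[int, None, int]:
    # Yields 0..n-1, then returns their sum
    total = 0
    for i in range(n):
        total += i
        yield i
    return total  # delivered as StopIteration.value


print(list(gen(4)))  # [0, 1, 2, 3]
it = gen(2)
print(next(it), next(it))  # 0 1
```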

Splitting a function annotation over multiple lines:

# You can of course split a function annotation over multiple lines
def send_email(address: Union[str, list[str]],
               sender: str,
               cc: Optional[list[str]],
               bcc: Optional[list[str]],
               subject: str = '',
               body: Optional[list[str]] = None
               ) -> bool:
    ...

Positional and keyword arguments

# Mypy understands positional-only and keyword-only arguments
# Positional-only arguments can also be marked by using a name starting with
# two underscores
def quux(x: int, /, *, y: int) -> None:
    pass

quux(3, y=5)  # Ok
quux(3, 5)  # error: Too many positional arguments for "quux"
quux(x=3, y=5)  # error: Unexpected keyword argument "x" for "quux"

Note the two symbols / and * in the parameter list; see What Are Python Asterisk and Slash Special Parameters For?

Left side | Divider | Right side
Positional-only arguments | / | Positional or keyword arguments
Positional or keyword arguments | * | Keyword-only arguments

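Python parameters come in three kinds: positional-only, keyword-only, and positional-or-keyword (which can be passed either by position or by keyword).

Parameters to the left of / are positional-only, and parameters to the right of * are keyword-only.

So in the example above, x must be passed positionally and y must be passed as a keyword argument.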
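The same rules hold at runtime (Python 3.8+), where a violating call raises TypeError. A sketch with a hypothetical combine function that has one parameter of each kind:

```python
def combine(x: int, /, z: int, *, y: int) -> int:
    # x: positional-only (left of /)
    # z: positional-or-keyword (between / and *)
    # y: keyword-only (right of *)
    return x + z + y


print(combine(1, 2, y=3))    # 6
print(combine(1, z=2, y=3))  # 6

for bad_call in (lambda: combine(1, 2, 3), lambda: combine(x=1, z=2, y=3)):
    try:
        bad_call()
    except TypeError as err:
        print("TypeError:", err)
```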

Annotating *args and **kwargs (the annotation gives the type of each individual argument):

# This says each positional arg and each keyword arg is a "str"
def call(self, *args: str, **kwargs: str) -> str:
    reveal_type(args)  # Revealed type is "tuple[str, ...]"
    reveal_type(kwargs)  # Revealed type is "dict[str, str]"
    request = make_request(*args, **kwargs)
    return self.do_api_query(request)
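The cheat-sheet snippet relies on names it does not define (make_request, self.do_api_query) and on reveal_type, which mypy understands and which exists at runtime only as typing.reveal_type from Python 3.11. A self-contained variant with a hypothetical call function:

```python
from typing import Dict, Tuple


def call(*args: str, **kwargs: str) -> str:
    # Each positional argument is a str, so args is Tuple[str, ...]
    positional: Tuple[str, ...] = args
    # Each keyword argument's value is a str, so kwargs is Dict[str, str]
    keywords: Dict[str, str] = kwargs
    parts = list(positional) + [f"{k}={v}" for k, v in keywords.items()]
    return "&".join(parts)


print(call("a", "b", sort="asc", page="2"))  # a&b&sort=asc&page=2
```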

Classes

mypy - Classes

self

class BankAccount:
    # The "__init__" method doesn't return anything, so it gets return
    # type "None" just like any other method that doesn't return anything
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        # mypy will infer the correct types for these instance variables
        # based on the types of the parameters.
        self.account_name = account_name
        self.balance = initial_balance

    # For instance methods, omit type for "self"
    def deposit(self, amount: int) -> None:
        self.balance += amount

    def withdraw(self, amount: int) -> None:
        self.balance -= amount

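The self parameter of an instance method needs no type annotation.

User-defined classes

A user-defined class can be used as a variable's type: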

# User-defined classes are valid as types in annotations
account: BankAccount = BankAccount("Alice", 400)
def transfer(src: BankAccount, dst: BankAccount, amount: int) -> None:
    src.withdraw(amount)
    dst.deposit(amount)
# Functions that accept BankAccount also accept any subclass of BankAccount!
class AuditedBankAccount(BankAccount):
    # You can optionally declare instance variables in the class body
    audit_log: list[str]

    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        super().__init__(account_name, initial_balance)
        self.audit_log: list[str] = []

    def deposit(self, amount: int) -> None:
        self.audit_log.append(f"Deposited {amount}")
        self.balance += amount

    def withdraw(self, amount: int) -> None:
        self.audit_log.append(f"Withdrew {amount}")
        self.balance -= amount

audited = AuditedBankAccount("Bob", 300)
transfer(audited, account, 100)  # type checks!

The first parameter of transfer is typed BankAccount, and AuditedBankAccount is a subclass of BankAccount, so passing audited passes the type check.

ClassVar

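Variables in a Python class come in two kinds: class variables and instance variables. To mark a member variable as a class variable, use ClassVar[type]:

from typing import ClassVar

# You can use the ClassVar annotation to declare a class variable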
class Car:
    seats: ClassVar[int] = 4
    passengers: ClassVar[list[str]]
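ClassVar is mostly information for type checkers, but the dataclasses module also honors it at runtime: ClassVar attributes are not turned into dataclass fields. A sketch that extends the Car example with a hypothetical model field and registry list:

```python
from dataclasses import dataclass, fields
from typing import ClassVar, List


@dataclass
class Car:
    model: str                          # instance variable: becomes a dataclass field
    seats: ClassVar[int] = 4            # class variable: shared, not a field
    registry: ClassVar[List[str]] = []  # shared by every instance

    def __post_init__(self) -> None:
        # Every instance appends to the same class-level list
        Car.registry.append(self.model)


a = Car("sedan")
b = Car("coupe")
print([f.name for f in fields(Car)])  # ['model']
print(Car.registry)                   # ['sedan', 'coupe']
print(a.seats, b.seats)               # 4 4
```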

__setattr__ and __getattr__

# If you want dynamic attributes on your class, have it
# override "__setattr__" or "__getattr__"
class A:
    # This will allow assignment to any A.x, if x is the same type as "value"
    # (use "value: Any" to allow arbitrary types)
    def __setattr__(self, name: str, value: int) -> None: ...

    # This will allow access to any A.x, if x is compatible with the return type
    def __getattr__(self, name: str) -> int: ...

a = A()
a.foo = 42  # Works
a.bar = 'Ex-parrot'  # Fails type checking

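Overriding __setattr__ tells the type checker that arbitrary attributes may be assigned on instances (as long as the value matches the annotated type), and overriding __getattr__ tells it that arbitrary attributes may be read.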
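The stub above only declares the two methods. A minimal runnable sketch with a hypothetical Config class that keeps dynamic attributes in a dict and also enforces at runtime the int type that the annotations promise to the checker:

```python
from typing import Dict


class Config:
    """Hypothetical class that stores arbitrary int attributes in a dict."""

    _values: Dict[str, int]

    def __init__(self) -> None:
        # Bypass our own __setattr__ so _values becomes a normal attribute
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name: str, value: int) -> None:
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int")  # runtime mirror of the hint
        self._values[name] = value

    def __getattr__(self, name: str) -> int:
        # Only called when normal attribute lookup fails
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name) from None


c = Config()
c.foo = 42
print(c.foo)  # 42
```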

torch.types

Types defined by PyTorch itself.

torch/types.py

import torch
from typing import Any, List, Sequence, Tuple, Union

import builtins

# Convenience aliases for common composite types that we need
# to talk about in PyTorch

_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]

# In some cases, these basic types are shadowed by corresponding
# top-level values.  The underscore variants let us refer to these
# types.  See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool

_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
_dispatchkey = Union[str, torch._C.DispatchKey]

class SymInt:
    pass

# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]

# Meta-type for "device-like" things.  Not to be confused with 'device' (a
# literal device object).  This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, _int, None]

# Storage protocol implemented by ${Type}StorageBase classes

class Storage(object):
    _cdata: int
    device: torch.device
    dtype: torch.dtype
    _torch_load_uninitialized: bool

    def __deepcopy__(self, memo) -> 'Storage':
        ...

    def _new_shared(self, int) -> 'Storage':
        ...

    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
        ...

    def element_size(self) -> int:
        ...

    def is_shared(self) -> bool:
        ...

    def share_memory_(self) -> 'Storage':
        ...

    def nbytes(self) -> int:
        ...

    def cpu(self) -> 'Storage':
        ...

    def data_ptr(self) -> int:
        ...

    def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
        ...

    def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
        ...

    ...

builtins

The _int, _float, and _bool in torch.types are simply Python's built-in builtins.int, builtins.float, and builtins.bool.
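Why the aliases are needed: inside torch the names int, float, and bool are taken by dtype objects (torch.int is a torch.dtype, not the built-in type), so a bare int in those stubs would not mean the built-in. A sketch of the same situation without PyTorch, using a hypothetical Namespace class whose body plays the role of the torch module:

```python
import builtins
from typing import Tuple


class Namespace:
    # Hypothetical stand-in for a module that defines its own "int"
    int = "int32 dtype placeholder"
    # The underscore alias still points at the built-in type
    _int = builtins.int

    # Annotations here are evaluated in the class namespace, where a bare
    # "int" would be the placeholder string; "_int" is the real type.
    def numel(self, shape: Tuple[_int, ...]) -> _int:
        result = 1
        for dim in shape:
            result *= dim
        return result


print(Namespace.int)                   # int32 dtype placeholder
print(Namespace._int is builtins.int)  # True
print(Namespace().numel((2, 3, 4)))    # 24
```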

_size

_size = Union[torch.Size, List[_int], Tuple[_int, ...]]

torch.Size

torch/_C/__init__.pyi

# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
    # TODO: __reduce__

    @overload  # type: ignore[override]
    def __getitem__(self: Size, key: _int) -> _int: ...
    @overload
    def __getitem__(self: Size, key: slice) -> Size: ...
    def numel(self: Size) -> _int: ...

torch/csrc/Size.cpp

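Here torch.Size is bound to THPSizeType. Note that tp_base is &PyTuple_Type, so torch.Size is a subclass of tuple, consistent with class Size(Tuple[_int, ...]) in the stub above: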

PyTypeObject THPSizeType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.Size", /* tp_name */
    sizeof(THPSize), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPSize_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    &THPSize_as_sequence, /* tp_as_sequence */
    &THPSize_as_mapping, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPSize_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    &PyTuple_Type, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPSize_pynew, /* tp_new */
};
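To see what the stub's two __getitem__ overloads promise (an int key gives an int, a slice gives a Size), here is a hypothetical pure-Python stand-in, MySize. It is only an illustration that runs without PyTorch, not how torch.Size is implemented (that is the C++ code above):

```python
from typing import Union, overload


class MySize(tuple):
    """Hypothetical pure-Python stand-in for torch.Size, for illustration only."""

    @overload
    def __getitem__(self, key: int) -> int: ...
    @overload
    def __getitem__(self, key: slice) -> "MySize": ...
    def __getitem__(self, key: Union[int, slice]) -> Union[int, "MySize"]:
        result = super().__getitem__(key)
        # A plain tuple slice is a tuple; re-wrap it so slices keep the MySize type
        if isinstance(key, slice):
            return MySize(result)
        return result

    def numel(self) -> int:
        # Product of all dimensions (1 for an empty size)
        n = 1
        for dim in self:
            n *= dim
        return n


s = MySize((2, 3, 4))
print(isinstance(s, tuple))  # True
print(s[0], s.numel())       # 2 24
print(type(s[1:]).__name__)  # MySize
```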

SymInt

Why is SymInt just an empty class here?

Number

Number, as defined in PyTorch, is any one of builtins.int, builtins.float, or builtins.bool (the same types as _int, _float, and _bool).
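Number is an ordinary Union alias, so it can be used in annotations and inspected with typing.get_args. A sketch that rebuilds the alias without importing torch, with hypothetical helpers scale and is_number. Since bool is already a subclass of int, listing bool mainly documents intent:

```python
from typing import Union, get_args

# Same definition as torch.types.Number, rebuilt without torch
Number = Union[int, float, bool]


def scale(value: Number, factor: Number = 2) -> float:
    return float(value * factor)


def is_number(obj: object) -> bool:
    # get_args() returns the Union's member types: (int, float, bool)
    return isinstance(obj, get_args(Number))


print(scale(3), scale(1.5, 4), scale(True))  # 6.0 6.0 2.0
print(is_number(3.5), is_number("3"))        # True False
```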

builtins

builtins — Built-in objects

This module provides direct access to all ‘built-in’ identifiers of Python; for example, builtins.open is the full name for the built-in function open().

The builtins module gives access to Python's built-in identifiers; for example, the built-in function open() can also be reached as builtins.open.
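This matters whenever a module shadows a built-in name, the same problem torch.types works around. A sketch with a hypothetical module-level open wrapper that still needs the real built-in:

```python
import builtins
import os
import tempfile
from typing import IO, Any


def open(path: str, mode: str = "r") -> IO[Any]:
    """Hypothetical wrapper that shadows the built-in open() in this module."""
    # The bare name "open" now means this wrapper (calling it would recurse),
    # so the original built-in is reached through the builtins module.
    print(f"opening {path!r} in mode {mode!r}")
    return builtins.open(path, mode)


path = os.path.join(tempfile.mkdtemp(), "demo.txt")
with open(path, "w") as fh:
    fh.write("hello")
with open(path) as fh:
    print(fh.read())  # hello
```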

The * before a parameter

See What does the Star operator mean in Python?

Single asterisk as used in function declaration allows variable number of arguments passed from calling environment. Inside the function it behaves as a tuple.

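Putting * before a parameter in a function definition lets the function accept any number of positional arguments; inside the function, that parameter is a tuple.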

def function(*arg):
    print(type(arg))
    for i in arg:
        print(i)

function(1, 2, 3)
# <class 'tuple'>
# 1
# 2
# 3
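This is what the rand stubs at the top rely on: each size: Sequence[...] overload is paired with a *size: _int overload, which is why both torch.rand(2, 3) and torch.rand((2, 3)) are accepted. A minimal pure-Python sketch of that calling convention with a hypothetical normalize_size helper (not PyTorch's actual argument parsing):

```python
import collections.abc
from typing import Sequence, Tuple, Union, overload


@overload
def normalize_size(size: Sequence[int]) -> Tuple[int, ...]: ...
@overload
def normalize_size(*size: int) -> Tuple[int, ...]: ...
def normalize_size(*size: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        # A single sequence argument, as in rand((2, 3)) or rand([2, 3])
        return tuple(size[0])
    # Otherwise each positional argument is one dimension, as in rand(2, 3)
    dims = []
    for d in size:
        if not isinstance(d, int):
            raise TypeError(f"expected int dimensions, got {type(d).__name__}")
        dims.append(d)
    return tuple(dims)


print(normalize_size(2, 3))    # (2, 3)
print(normalize_size((2, 3)))  # (2, 3)
print(normalize_size([4]))     # (4,)
```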