FileManager.write_with_template in PyTorch's File Generation Mechanism

Preface

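Some files in PyTorch are generated by running scripts during the build. For example, .pyi files are generated from .pyi.in files, and the .cpp files under torch/csrc/autograd/generated are generated from the template .cpp files under tools/autograd/templates.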

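Under the hood, all of them call FileManager.write_with_template. This function replaces specific placeholder strings in a template file, following the instructions of a callback function, and thereby produces the corresponding .pyi or .cpp file.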

This article first looks at how FileManager.write_with_template is called, then examines its implementation in detail.

Calls to FileManager.write_with_template

gen_pyi

tools/pyi/gen_pyi.py

The main function calls gen_pyi:

def main() -> None:
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out", metavar="OUT", default=".", help="path to output directory"
    )
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(
        args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm
    )


if __name__ == "__main__":
    main()

It first creates a FileManager object fm, whose first two arguments are:

  • install_dir: args.out, which defaults to '.'
  • template_dir: '.'

Inside gen_pyi, fm.write_with_template is then called several times:

    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/return_types.pyi",
            **env,
        },
    )
    gen_nn_functional(fm)

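These four fm.write_with_template calls use three .pyi.in templates under torch/_C (_VariableFunctions.pyi.in is used twice). They generate __init__.pyi and _VariableFunctions.pyi under torch/_C, and _VF.pyi and return_types.pyi under torch.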

Finally, gen_pyi calls gen_nn_functional, which also calls fm.write_with_template.

gen_nn_functional

tools/pyi/gen_pyi.py

def gen_nn_functional(fm: FileManager) -> None:
    # ...
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
    # ...
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )

These two fm.write_with_template calls generate torch/nn/functional.pyi and torch/_C/_nn.pyi from torch/nn/functional.pyi.in and torch/_C/_nn.pyi.in, respectively.

write_sharded

tools/autograd/gen_python_functions.py

gen_python_functions.gen first creates a FileManager object:

def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

Its first two arguments are:

  • install_dir: ./torch/csrc/autograd/generated
  • template_dir: tools/autograd/templates

After passing through create_python_bindings_sharded in tools/autograd/gen_python_functions.py, execution reaches FileManager.write_sharded:

torchgen/utils.py

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        #...
        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )
        #...

Here all_shards is:

[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]

So write_with_template here uses filename (python_torch_functions.cpp) as the template. From it, it generates four files: python_torch_functionsEverything.cpp, python_torch_functions_0.cpp, python_torch_functions_1.cpp, and python_torch_functions_2.cpp.
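The loop above relies on base_filename and extension, which are computed in the elided #... part. Below is a minimal sketch of how the four output names could be derived. It assumes that filename is split at its last dot and that all_shards is one "Everything" entry followed by num_shards numbered entries, which is consistent with the list shown above. The helpers split_extension, make_shards and shard_filenames are hypothetical and are not torchgen functions.

```python
from typing import Any, Dict, List, Tuple


def split_extension(filename: str) -> Tuple[str, str]:
    """Split at the last dot; a name without a dot gets an empty extension."""
    dot_pos = filename.rfind(".")
    if dot_pos == -1:
        dot_pos = len(filename)
    return filename[:dot_pos], filename[dot_pos:]


def make_shards(num_shards: int) -> List[Dict[str, Any]]:
    """An 'Everything' shard followed by num_shards numbered shards."""
    everything: Dict[str, Any] = {"shard_id": "Everything"}
    numbered = [{"shard_id": f"_{i}"} for i in range(num_shards)]
    return [everything] + numbered


def shard_filenames(filename: str, num_shards: int) -> List[str]:
    base_filename, extension = split_extension(filename)
    return [f"{base_filename}{s['shard_id']}{extension}" for s in make_shards(num_shards)]


print(shard_filenames("python_torch_functions.cpp", 3))
# ['python_torch_functionsEverything.cpp', 'python_torch_functions_0.cpp',
#  'python_torch_functions_1.cpp', 'python_torch_functions_2.cpp']
```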

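Note that in all three examples above, the third argument of write_with_template (env_callable) is a lambda that returns a dict when called. This is the substitution callback used during generation.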
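env_callable is a zero-argument callable rather than a ready-made dict. As a result, the dict is only built when something actually calls it: substitute_with_template does, and write_with_template skips that call in a dry run.

The small sketch below uses hypothetical env contents. It also shows a side effect of the {"generated_comment": ..., **env} display. If env already contains a generated_comment key, that key wins, because later entries override earlier ones.

```python
from typing import Any, Callable, Dict

# Hypothetical contents; the real env is assembled inside gen_pyi.
env: Dict[str, Any] = {"function_hints": ["def abs(input: Tensor) -> Tensor: ..."]}

env_callable: Callable[[], Dict[str, Any]] = lambda: {
    "generated_comment": "@" + "generated from torch/_C/_VariableFunctions.pyi.in",
    **env,
}

# The dict only comes into existence here, when the callable is invoked.
first = env_callable()
print(sorted(first))  # ['function_hints', 'generated_comment']

# Later entries in a dict display win, so a key inside env overrides the literal one.
env["generated_comment"] = "overridden"
print(env_callable()["generated_comment"])  # overridden
```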

Before diving into FileManager.write_with_template, let's look at the FileManager constructor.

FileManager constructor

torchgen/utils.py

# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run
    # ...

Its first two parameters are:

  • install_dir: the directory in which the generated files are placed
  • template_dir: the directory containing the input template files

Implementation of FileManager.write_with_template

torchgen/utils.py

FileManager.write_with_template

Besides self, write_with_template takes three parameters:

  • filename: the name of the generated .pyi or .cpp file
  • template_fn: the name of the input .pyi.in file or template .cpp file
  • env_callable: the callback function used for substitution
    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

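The filename passed in is a relative path. Prefixing it with self.install_dir gives the full path (more precisely, the path relative to the PyTorch root directory).

In gen_pyi and gen_nn_functional, install_dir is '.', so the filename passed in already contains the full path. In write_sharded, by contrast, install_dir is ./torch/csrc/autograd/generated and the filename passed in is python_torch_functions_i.cpp, so only the combination of the two gives the full path.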

The core of this code is the call to substitute_with_template, which produces substitute_out.

The substituted result, substitute_out, is then written to filename (the generated .pyi or .cpp file) via _write_if_changed.
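The control flow of write_with_template can be imitated with a stripped-down, hypothetical MiniFileManager. The sketch makes two assumptions:

  • env_callable returns the finished contents as a str, so no template is read.
  • _write_if_changed, whose body is not shown in this article, compares against the file's current contents and writes only when they differ.

```python
import os
import tempfile
from typing import Callable, Optional, Set


class MiniFileManager:
    """Hypothetical, stripped-down imitation of FileManager.write_with_template.

    env_callable returns the finished contents as a str, so no template is read.
    """

    def __init__(self, install_dir: str, dry_run: bool = False) -> None:
        self.install_dir = install_dir
        self.dry_run = dry_run
        self.filenames: Set[str] = set()
        self.writes = 0  # number of files actually written (illustration only)

    def write_with_template(self, filename: str, env_callable: Callable[[], str]) -> None:
        filename = "{}/{}".format(self.install_dir, filename)  # relative -> full path
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:  # a dry run only records names; env_callable is never called
            self._write_if_changed(filename, env_callable())

    def _write_if_changed(self, filename: str, contents: str) -> None:
        # Assumed behaviour: leave the file untouched if it already holds `contents`.
        old: Optional[str]
        try:
            with open(filename) as f:
                old = f.read()
        except OSError:
            old = None
        if contents != old:
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)
            self.writes += 1


dup_error = ""
with tempfile.TemporaryDirectory() as out:
    fm = MiniFileManager(install_dir=out)
    fm.write_with_template("torch/_VF.pyi", lambda: "# stub\n")
    rerun = MiniFileManager(install_dir=out)  # e.g. a second build with identical output
    rerun.write_with_template("torch/_VF.pyi", lambda: "# stub\n")
    dry = MiniFileManager(install_dir=out, dry_run=True)
    dry.write_with_template("torch/other.pyi", lambda: str(1 // 0))  # never evaluated
    try:
        fm.write_with_template("torch/_VF.pyi", lambda: "# again\n")
    except AssertionError as e:
        dup_error = str(e)

print(fm.writes, rerun.writes, dry.writes)  # 1 0 0
print(dup_error.startswith("duplicate file write"))  # True
```

The second manager stands in for a later build run. Its contents are identical, so it performs no write. Reusing the same manager for the same path trips the duplicate-write assertion.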

Note: for type checking, the callback function is annotated with typing.Callable; see the Python typing library and torch.types for details.

FileManager.substitute_with_template

torchgen/utils.py

self外有兩個參數:

  • template_fn:作為輸入的.pyi.in的檔名或template .cpp的檔名
  • env_callable:在做替換時會用到的callback function
    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                comment = "@" + "generated by torchgen/gen.py"
                comment += " from {}".format(os.path.basename(template_path))
                env["generated_comment"] = comment
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

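Since env_callable is a lambda that returns a dict, execution enters the isinstance(env, dict) branch. If env has no generated_comment key, a default comment naming the template file is added. _read_template then reads the template file (the .pyi.in or template .cpp file), and template.substitute is called on the result.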
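Both branches can be exercised without torchgen. The sketch below mirrors the branching of substitute_with_template, with three differences:

  • string.Template stands in for CodeTemplate. It also understands $name and ${name}, but not the list and indentation rules.
  • The template text is passed in directly instead of being read from template_path.
  • A TypeError replaces assert_never.

The template name foo.cpp is hypothetical.

```python
import os
from string import Template
from typing import Any, Callable, Dict, Union


def substitute_sketch(
    template_path: str,
    template_text: str,
    env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> str:
    """Mirrors the branching of substitute_with_template (string.Template stands in for CodeTemplate)."""
    env = env_callable()
    if isinstance(env, dict):
        if "generated_comment" not in env:
            # Same default comment as substitute_with_template, naming the template file.
            comment = "@" + "generated by torchgen/gen.py"
            comment += " from {}".format(os.path.basename(template_path))
            env["generated_comment"] = comment
        return Template(template_text).substitute(env)
    elif isinstance(env, str):
        return env  # a str is used verbatim; the template is ignored
    # The original calls assert_never here.
    raise TypeError(f"unsupported env type: {type(env).__name__}")


tmpl = "// ${generated_comment}\n$body\n"
print(substitute_sketch("templates/foo.cpp", tmpl, lambda: {"body": "int x;"}))
# // @generated by torchgen/gen.py from foo.cpp
# int x;
print(substitute_sketch("templates/foo.cpp", tmpl, lambda: "verbatim"))  # verbatim
```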

_read_template

torchgen/utils.py

參數template_fnpyi或template cpp的檔名。

@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)

It reads template_fn, creates a CodeTemplate object from it, and returns that object.
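_read_template is decorated with functools.lru_cache(maxsize=None), so each template path is read and parsed at most once per process. This matters because some templates are rendered repeatedly: _VariableFunctions.pyi.in twice in gen_pyi, and python_torch_functions.cpp once per shard in write_sharded. The sketch below uses a hypothetical read_template that returns raw text and counts how often the file is actually opened.

```python
import functools
import os
import tempfile

reads = 0  # how many times the file was really opened


@functools.lru_cache(maxsize=None)
def read_template(path: str) -> str:
    """Hypothetical stand-in for _read_template; returns raw text, not a CodeTemplate."""
    global reads
    reads += 1
    with open(path) as f:
        return f.read()


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "_VariableFunctions.pyi.in")
    with open(path, "w") as f:
        f.write("# ${generated_comment}\n")
    first = read_template(path)   # miss: reads the file
    second = read_template(path)  # hit: served from the cache

print(reads)                            # 1
print(first is second)                  # True
print(read_template.cache_info().hits)  # 1
```

Every caller receives the same cached object. That is harmless here because substitute builds a new string and does not modify the CodeTemplate.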


CodeTemplate

torchgen/code_template.py

First, let's look at what the CodeTemplate class does.

# match $identifier or ${identifier} and replace with value in env
# If this identifier is at the beginning of whitespace on a line
# and its value is a list then it is treated as
# block substitution by indenting to that depth and putting each element
# of the list on its own line
# if the identifier is on a line starting with non-whitespace and a list
# then it is comma separated ${,foo} will insert a comma before the list
# if this list is not empty and ${foo,} will insert one after.


class CodeTemplate:
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

    pattern: str
    filename: str
    
    # ...

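The comment explains what CodeTemplate does: it replaces $identifier or ${identifier} in the template with the corresponding value in env. The handling of list values depends on where the placeholder sits:

  • If only whitespace precedes the placeholder on its line, each list element is put on its own line at that indentation.
  • If the line starts with non-whitespace, the list is joined with commas. ${,foo} or ${foo,} additionally inserts a comma before or after the list when it is not empty.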
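Which case applies is decided entirely by the two groups of substitution_str. A quick way to see this is to run the regex (copied from the class above) over a few sample lines with hypothetical placeholder names. The loop prints group(1), the leading indentation or None, and group(2), the identifier part.

```python
import re

# Copied from CodeTemplate.substitution_str above.
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
substitution = re.compile(substitution_str, re.MULTILINE)

sample = (
    "$ops_headers\n"             # at column 0
    "  ${py_method_defs}\n"      # only whitespace before it on the line
    "f(self${,args}) -> $ret\n"  # preceded by other text on the line
)

groups = [(m.group(1), m.group(2)) for m in substitution.finditer(sample)]
for indent, key in groups:
    print(repr(indent), repr(key))
# '' 'ops_headers'
# '  ' '{py_method_defs}'
# None '{,args}'
# None 'ret'
```

A placeholder at column 0 yields an empty string rather than None. So $ops_headers in the python_torch_functions.cpp template is still treated as a block substitution, just with no indentation.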

torch/_C/_VariableFunctions.pyi.in contains placeholders such as:

# ${generated_comment}
# ...
${function_hints}

${all_directive}

The template python_torch_functions.cpp (in tools/autograd/templates) contains placeholders such as:

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
    
// ...
// generated forward declarations start here

${py_forwards}

// ...
static PyMethodDef torch_functions_shard[] = {
  ${py_method_defs}
};

// ...
// generated methods start here

${py_methods}

CodeTemplate.from_file

torchgen/code_template.py

class CodeTemplate:
    # ...

    @staticmethod
    def from_file(filename: str) -> "CodeTemplate":
        with open(filename, "r") as f:
            return CodeTemplate(f.read(), filename)
        
    # ...

This calls the CodeTemplate constructor, passing in the contents and the name of filename.

CodeTemplate.__init__

  • filename: the name of the input .pyi.in file or template .cpp file
  • pattern: CodeTemplate.from_file calls the constructor as CodeTemplate(f.read(), filename), so the pattern member is set to the contents read from filename
class CodeTemplate:
    # ...
    
    def __init__(self, pattern: str, filename: str = "") -> None:
        self.pattern = pattern
        self.filename = filename
        
    # ...

substitute

torchgen/code_template.py

回顧torchgen/utils.pysubstitute_with_template中的:

            template = _read_template(template_path)

After creating the CodeTemplate object template, it calls:

            return template.substitute(env)

which performs regex-based substitution:

class CodeTemplate:
    # ...
    def substitute(
        self, env: Optional[Mapping[str, object]] = None, **kwargs: object
    ) -> str:
        if env is None:
            env = {}

        def lookup(v: str) -> object:
            assert env is not None
            return kwargs[v] if v in kwargs else env[v]

        def indent_lines(indent: str, v: Sequence[object]) -> str:
            return "".join(
                [indent + l + "\n" for e in v for l in str(e).splitlines()]
            ).rstrip()

        def replace(match: Match[str]) -> str:
            indent = match.group(1)
            key = match.group(2)
            comma_before = ""
            comma_after = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    comma_before = ", "
                    key = key[1:]
                if key[-1] == ",":
                    comma_after = ", "
                    key = key[:-1]
            v = lookup(key)
            if indent is not None:
                if not isinstance(v, list):
                    v = [v]
                return indent_lines(indent, v)
            elif isinstance(v, list):
                middle = ", ".join([str(x) for x in v])
                if len(v) == 0:
                    return middle
                return comma_before + middle + comma_after
            else:
                return str(v)

        return self.substitution.sub(replace, self.pattern)
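To see what this method produces, here is a standalone condensed copy of it. The **kwargs lookup is dropped so it runs without torchgen installed. It is applied to a small template with hypothetical keys:

  • A list placed after leading whitespace becomes one element per line.
  • ${,args} inside a line becomes a comma-separated list with a leading comma.

```python
import re
from typing import Mapping, Sequence

# Same regex as CodeTemplate.substitution_str.
SUBSTITUTION = re.compile(r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})", re.MULTILINE)


def substitute(pattern: str, env: Mapping[str, object]) -> str:
    """Condensed copy of CodeTemplate.substitute, minus the **kwargs lookup."""

    def indent_lines(indent: str, v: Sequence[object]) -> str:
        return "".join(indent + line + "\n" for e in v for line in str(e).splitlines()).rstrip()

    def replace(match: "re.Match[str]") -> str:
        indent, key = match.group(1), match.group(2)
        comma_before = comma_after = ""
        if key[0] == "{":  # ${name}, ${,name} or ${name,}
            key = key[1:-1]
            if key[0] == ",":
                comma_before, key = ", ", key[1:]
            if key[-1] == ",":
                comma_after, key = ", ", key[:-1]
        v = env[key]
        if indent is not None:  # block substitution: one element per line, same indent
            return indent_lines(indent, v if isinstance(v, list) else [v])
        if isinstance(v, list):  # inline list: comma-separated, optional outer commas
            middle = ", ".join(str(x) for x in v)
            return comma_before + middle + comma_after if v else middle
        return str(v)

    return SUBSTITUTION.sub(replace, pattern)


template = (
    "static PyMethodDef defs[] = {\n"
    "  ${py_method_defs}\n"
    "};\n"
    "f(self${,args})\n"
)
print(substitute(template, {"py_method_defs": ['{"abs", ...},', '{"add", ...},'], "args": ["a", "b"]}))
# static PyMethodDef defs[] = {
#   {"abs", ...},
#   {"add", ...},
# };
# f(self, a, b)
print(substitute("f(self${,args})", {"args": []}))  # f(self)
```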

In the last line, self.substitution.sub(replace, self.pattern), self.substitution is a class member of CodeTemplate:

    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

re.compile returns a re.Pattern object, which is stored in substitution.

To see what re.Pattern.sub does, here is an example adapted from "Passing a function to re.sub in Python" and "Python: re.compile and re.sub":

import re
substitution = re.compile(r'\d')
number_mapping = {'1': 'one', '2': 'two', '3': 'three'}
s = "1 testing 2 3"
substitution.sub(lambda x: number_mapping[x.group()], s) # 'one testing two three'

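re.Pattern.sub takes two arguments:

  • The first is the replacement. Here it is a function that receives each match object and returns the replacement string.
  • The second is the string to process.

sub finds every non-overlapping substring that matches the pattern (here r'\d'), replaces each one, and returns the resulting string.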

So self.substitution.sub(replace, self.pattern) searches self.pattern (the contents of the .pyi.in or template .cpp file) for strings matching substitution_str. It replaces each one in the way specified by the replace function.
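Stripped of indentation and list handling, the mechanism boils down to one regex plus a replacement function that looks up each captured name in env. The simplified sketch below uses a hypothetical pattern that only recognizes ${identifier}. It is applied to the generated_comment line from _VariableFunctions.pyi.in.

```python
import re
from typing import Dict

# Simplified pattern (hypothetical): only ${identifier}, no indentation or list rules.
placeholder = re.compile(r"\$\{(\w+)\}")

env: Dict[str, str] = {
    "generated_comment": "@" + "generated from torch/_C/_VariableFunctions.pyi.in",
}
template_text = "# ${generated_comment}\n"

# sub calls the lambda once per match; its return value replaces the matched text.
out = placeholder.sub(lambda m: env[m.group(1)], template_text)
print(out, end="")  # # @generated from torch/_C/_VariableFunctions.pyi.in
```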

With the substituted result in hand, we return to substitute_with_template:

            return template.substitute(env)

which passes the result back up to write_with_template:

            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

There, the substituted result substitute_out is written to filename, i.e. the generated .pyi or .cpp file.

As an example, consider ${generated_comment} in torch/_C/_VariableFunctions.pyi.in.

Recall that when gen_pyi calls write_with_template, it passes a generated_comment key-value pair along with env:

    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )

So by the time substitute is called, its env argument is a dict that contains the generated_comment key-value pair.

After substitution, the template line # ${generated_comment} becomes the first line of the generated torch/_C/_VariableFunctions.pyi:

# @generated from torch/_C/_VariableFunctions.pyi.in