MindQuantum 0.7.0 源代码阅读理解(0)
因本人在研究中需要对 MindQuantum 源代码进行改写,以实现所需功能,在接下来的几篇博文里,我会以笔记的形式充分阅读并理解其源代码,欢迎读者一起学习。本人能力有限,肯定不乏错误,建议读者在参考时要时刻保持批判的态度。
关于 MindQuantum 的安装、API、版本信息等,请参考 MindQuantum 官网。
1. utils (常用工具)文件
该部分是源代码中常见的内容,下面先解决他们,打好基础。
1.1 type_value_check.py
用于检查 类型 和 数值
import numbers
import numpy as np
_num_type = (int, float, complex, np.int32, np.int64, np.float32, np.float64)
# 间接检查输入类型 dtype 是否为合法的 numpy 的 dtype,如果是的话,就什么都不返回,否则报错。
def _check_np_dtype(dtype):
"""Check dtype is a valid numpy dtype."""
np.array([0], dtype=dtype)
# 检查是否为有效的种子数,需要为整数,且在有效区间内
def _check_seed(seed):
"""Check seed."""
_check_int_type("seed", seed)
_check_value_should_between_close_set("seed", 0, 2**23, seed)
# 检查是输入 arg 是否为指定类型 require_type, arg_ms 是描述 arg 的字符串
def _check_input_type(arg_msg, require_type, arg):
"""Check input type."""
if not isinstance(arg, require_type):
raise TypeError(f"{arg_msg} requires a {require_type}, but get {type(arg)}")
# 检查输入 arg 是否为整数,args_msg 为描述 arg 的字符串
def _check_int_type(args_msg, arg):
"""Check int type."""
if not isinstance(arg, (int, np.int64)) or isinstance(arg, bool):
raise TypeError(f"{args_msg} requires an int, but get {type(arg)}")
# 检查输入 arg 是否小于指定数值 require_value,args_msg 为描述 arg 的字符串
def _check_value_should_not_less(arg_msg, require_value, arg):
"""Check value should not less."""
if arg < require_value:
raise ValueError(f'{arg_msg} should be not less than {require_value}, but get {arg}')
# 检查输入数值 arg 是否在指定的区间内。arg_ms 是描述 arg 的字符串
def _check_value_should_between_close_set(arg_ms, min_value, max_value, arg):
"""Check value should between."""
if arg < min_value or arg > max_value:
raise ValueError(f"{arg_ms} should between {min_value} and {max_value}, but get {arg}")
# 检查并生成合法的参数, pr 参数, names 参数名。
# 合法的参数可以为 ParameterResolver, np.ndarray, list, dict
def _check_and_generate_pr_type(pr, names=None):
"""Check and generate PR type."""
# pylint: disable=import-outside-toplevel,cyclic-import
from ..core.parameterresolver import ParameterResolver
if isinstance(pr, _num_type): # 检查输入参数 pr 是否为单个的合法数值。当 pr 为单个数值时,参数名 names 也应唯一。之后将单个数值的 pr 修饰为单元素数组。
if len(names) != 1:
raise ValueError(f"number of given parameters value is less than parameters ({len(names)})")
pr = np.array([pr])
_check_input_type('parameter', (ParameterResolver, np.ndarray, list, dict), pr)
if isinstance(pr, dict):
pr = ParameterResolver(pr)
elif isinstance(pr, (np.ndarray, list)):
pr = np.array(pr)
if len(pr) != len(names) or len(pr.shape) != 1:
raise ValueError(f"given parameter value size ({pr.shape}) not match with parameter size ({len(names)})")
pr = ParameterResolver(dict(zip(names, pr)))
if isinstance(pr, ParameterResolver):
if names is not None:
for n in names:
if n not in pr:
raise ValueError(f"Parameter {n} not in given parameter resolver.")
return pr
# 检查是否为数值
def _check_number_type(arg_msg, arg):
"""Check number type."""
if not isinstance(arg, numbers.Number):
raise TypeError(f"{arg_msg} requires a number, but get {type(arg)}")
# 检查门的类型是否属于基本门
def _check_gate_type(gate):
# pylint: disable=import-outside-toplevel,cyclic-import
from ..core.gates.basic import BasicGate
if not isinstance(gate, BasicGate):
raise TypeError(f"Require a quantum gate, but get {type(gate)}")
# 检查量子门是否有目标比特(Barrier 门是可以没有目标比特的)
def _check_gate_has_obj(gate):
from ..core.gates.basicgate import ( # pylint: disable=import-outside-toplevel,cyclic-import
BarrierGate,
)
if not isinstance(gate, BarrierGate):
if not gate.obj_qubits:
raise ValueError("Gate shuould act on some qubits first.")
# 检查比特序号是否为自然数
def _check_qubit_id(qubit_id):
if not isinstance(qubit_id, (int, np.int64)):
raise TypeError(f"Qubit should be a non negative int, but get {type(qubit_id)}!")
if qubit_id < 0:
raise ValueError(f"Qubit should be non negative int, but get {qubit_id}!")
# 检查控制比特和目标比特是否合法:不能相同,彼此以及各自是否重复
def _check_obj_and_ctrl_qubits(obj_qubits, ctrl_qubits):
if set(obj_qubits) & set(ctrl_qubits):
raise ValueError("obj_qubits and ctrl_qubits cannot have same qubits.")
if len(set(obj_qubits)) != len(obj_qubits):
raise ValueError("obj_qubits cannot have same qubits")
if len(set(ctrl_qubits)) != len(ctrl_qubits):
raise ValueError("ctrl_qubits cannot have same qubits")
# 检查控制比特的数量
def _check_control_num(ctrl_qubits, require_n):
from .quantifiers import ( # pylint: disable=import-outside-toplevel,cyclic-import
s_quantifier,
)
if len(ctrl_qubits) != require_n:
raise RuntimeError(f"requires {s_quantifier(require_n,'control qubit')}, but get {len(ctrl_qubits)}")