Action
先验知识:
SerializationMixin:
Pydantic在序列化一个对象时,会将其序列化为其最顶层父类的形式,而不是它实际子类的形式。这意味着,如果你有一个父类和一个继承自该父类的子类,当你将子类的实例序列化时,得到的字典将只包含父类的字段,而不会包含子类特有的字段。同样,在反序列化时,Pydantic也无法根据数据内容自动选择正确的子类来实例化,而是只能实例化父类。
因此定义了一个名为SerializationMixin的Python类,用于在Pydantic模型中实现多态序列化和反序列化的混合类
@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
# 调用传入的handler函数,获取模型的默认核心架构
schema = handler(source)
# 保存原始的核心架构引用,这个引用是Pydantic用于识别模型的一个唯一标识
og_schema_ref = schema["ref"]
# 在原始的核心架构引用后面添加一个后缀`:mixin`,以便在序列化和反序列化过程中能够识别这个被修改过的架构
schema["ref"] += ":mixin"
# 创建一个验证器函数,它将在序列化和反序列化过程中被调用,先于标准的Pydantic验证器执行
return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__, # 这个类方法将在反序列化过程中被调用
schema=schema, # 修改后的核心架构
ref=og_schema_ref, # 原始的核心架构引用
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__), # 包装序列化函数
)
覆盖了Pydantic的__get_pydantic_core_schema__方法,用于自定义模型的序列化和反序列化过程
@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
# 调用传入的handler函数,这个函数是Pydantic用于序列化模型的默认函数
ret = handler(value)
# 检查当前类是否有子类,如果没有子类,说明它是一个具体的子类而不是基类
if not len(cls.__subclasses__()):
# 只有具体的子类才添加`__module_class_name`字段,这个字段包含了子类的全限定类名
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
# 返回修改后的字典,这个字典将包含额外的类型信息
return ret
在序列化过程中被调用,目的是在序列化过程中添加额外的类型信息,以便在反序列化时能够恢复正确的子类类型
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
# 如果传入的值不是字典类型,直接返回该值,因为只有字典类型的值才可能包含序列化的模型数据
if not isinstance(value, dict):
return value
# 如果当前类不是多态基类,或者有子类且序列化的数据中没有`__module_class_name`字段,
# 直接返回传入的值,不进行特殊处理
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
return value
# 从序列化的数据中获取`__module_class_name`字段的值,这个值是子类的全限定类名
module_class_name = value.get("__module_class_name", None)
# 如果没有找到`__module_class_name`字段,抛出ValueError异常
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")
# 从`__subclasses_map__`中获取与全限定类名对应的类类型
class_type = cls.__subclasses_map__.get(module_class_name, None)
# 如果没有找到对应的类类型,抛出TypeError异常
if class_type is None:
raise TypeError(f"Trying to instantiate {module_class_name} which not defined yet.")
# 使用找到的类类型和传入的数据来实例化子类
return class_type(**value)
目的是在反序列化过程中使用之前序列化时保存的类型信息来实例化正确的子类
def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
# 将is_polymorphic_base参数设置为子类的多态基类标志
cls.__is_polymorphic_base = is_polymorphic_base
# 将当前子类添加到__subclasses_map__映射中,以便在反序列化时能够找到正确的子类
# __subclasses_map__是一个字典,用于存储子类与其全限定类名的映射关系
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
# 调用基类的__init_subclass__方法,以便子类可以继承基类的其他设置
super().__init_subclass__(**kwargs)
在定义子类时被自动调用,确保在创建子类时子类被正确地注册到__subclasses_map__中,从而在__deserialize_with_real_type__能找到正确的与全限定类名字、对应的类类型
ContexMixin:
用于处理上下文、配置和大型语言模型的相关操作。这个混入类提供了对上下文和配置的访问和设置方法,以及一个大型语言模型的实例
class ContextMixin(BaseModel):
"""Mixin class for context and config"""
# 定义了一个ConfigDict类型的模型配置,允许任意类型的字段。
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
# - https://github.com/pydantic/pydantic/issues/7142
# - https://github.com/pydantic/pydantic/issues/7083
# - https://github.com/pydantic/pydantic/issues/7091
# Env/Role/Action will use this context as private context, or use self.context as public context
# 用于存储一个Context类型的私有上下文。
private_context: Optional[Context] = Field(default=None, exclude=True)
# Env/Role/Action will use this config as private config, or use self.context.config as public config
# 用于存储一个Config类型的私有配置。
private_config: Optional[Config] = Field(default=None, exclude=True)
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
# 用于存储一个BaseLLM类型的私有大型语言模型。
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
定义了三个私有字段:private_context、private_config和private_llm,用于存储私有上下文、配置和大型语言模型。
@model_validator(mode="after")
def validate_context_mixin_extra(self):
self._process_context_mixin_extra()
return self
def _process_context_mixin_extra(self):
"""Process the extra field"""
# 从model_extra字段中获取额外的参数,这是一个字典,包含了模型创建时传入的所有额外字段
kwargs = self.model_extra or {}
# 如果context键存在于字典中,则使用self.set_context方法来设置上下文
self.set_context(kwargs.pop("context", None))
# 如果config键存在于字典中,则使用self.set_config方法来设置配置
self.set_config(kwargs.pop("config", None))
# 如果llm键存在于字典中,则使用self.set_llm方法来设置大型语言模型
self.set_llm(kwargs.pop("llm", None))
确保在模型实例化后,能够正确地设置上下文、配置和大型语言模型
def set(self, k, v, override=False):
"""Set attribute"""
# 这个方法用于设置模型的属性。如果override参数为True或者当前没有这个属性的值,
# 它将设置这个属性。
if override or not self.__dict__.get(k):
self.__dict__[k] = v
def set_context(self, context: Context, override=True):
"""Set context"""
# 这个方法用于设置上下文。如果override参数为True或者当前没有上下文,
# 它将设置私有上下文。
self.set("private_context", context, override)
def set_config(self, config: Config, override=False):
"""Set config"""
# 这个方法用于设置配置。如果override参数为True或者当前没有配置,
# 它将设置私有配置。如果配置不为None,它还会初始化LLM。
self.set("private_config", config, override)
if config is not None:
_ = self.llm # init llm
def set_llm(self, llm: BaseLLM, override=False):
"""Set llm"""
# 这个方法用于设置大型语言模型。如果override参数为True或者当前没有大型语言模型,
# 它将设置私有大型语言模型。
self.set("private_llm", llm, override)
@property
def config(self) -> Config:
"""Role config: role config > context config"""
# 这个属性用于获取配置。它首先检查是否有私有配置,如果没有,
# 则从上下文获取配置。
if self.private_config:
return self.private_config
return self.context.config
@config.setter
def config(self, config: Config) -> None:
"""Set config"""
# 这个属性设置器用于设置配置。
self.set_config(config)
@property
def context(self) -> Context:
"""Role context: role context > context"""
# 这个属性用于获取上下文。它首先检查是否有私有上下文,如果没有,
# 则创建一个新的上下文实例。
if self.private_context:
return self.private_context
return Context()
@context.setter
def context(self, context: Context) -> None:
"""Set context"""
# 这个属性设置器用于设置上下文。
self.set_context(context)
@property
def llm(self) -> BaseLLM:
"""Role llm: if not existed, init from role.config"""
# 这个属性用于获取大型语言模型(LLM)。如果私有LLM不存在,
# 它会从角色的配置中初始化一个LLM。
if not self.private_llm:
self.private_llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm)
return self.private_llm
@llm.setter
def llm(self, llm: BaseLLM) -> None:
"""Set llm"""
# 这个属性设置器用于设置LLM。
self.private_llm = llm
不断设置和获取那三个属性
ProjectRepo:
继承自FileRepository类,FileRepository提供了一系列与文件操作相关的方法,这些操作包括保存文件、获取文件依赖、获取已更改的依赖项、获取文件内容、列出所有文件、保存文档、删除文件等。这些方法主要用于处理存储在Git仓库中的文件,这些功能对于维护Git仓库中的文件和跟踪文件之间的依赖关系非常有用。
GitRepository对象提供了一个全面的接口,用于在Python中与Git仓库交互,包括管理仓库、跟踪变更、提交更改、获取文件列表等
class ProjectRepo(FileRepository):
def __init__(self, root: str | Path | GitRepository):
# 如果传入的root参数是字符串或Path对象,则创建一个新的GitRepository对象
# 如果传入的root参数是一个已存在的GitRepository对象,则直接使用该对象
# 如果传入的root参数无效,则抛出一个ValueError
if isinstance(root, str) or isinstance(root, Path):
git_repo_ = GitRepository(local_path=Path(root))
elif isinstance(root, GitRepository):
git_repo_ = root
else:
raise ValueError("Invalid root")
# 调用父类的构造函数,初始化FileRepository对象
# git_repo_是GitRepository对象,relative_path是相对于Git仓库根目录的相对路径
super().__init__(git_repo=git_repo_, relative_path=Path("."))
# 初始化ProjectRepo的属性
self._git_repo = git_repo_
self.docs = DocFileRepositories(self._git_repo)
self.resources = ResourceFileRepositories(self._git_repo)
self.tests = self._git_repo.new_file_repository(relative_path=TEST_CODES_FILE_REPO)
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
self._srcs_path = None
self.code_files_exists()
def __str__(self):
# 返回一个字符串表示,包括Git仓库的工作目录、文档、资源、测试代码和测试输出
repo_str = f"ProjectRepo({self._git_repo.workdir})"
docs_str = f"Docs({self.docs.all_files})"
srcs_str = f"Srcs({self.srcs.all_files})"
return f"{repo_str}\n{docs_str}\n{srcs_str}"
@property
async def requirement(self):
# 异步获取REQUIREMENT_FILENAME文件的内容,通常是一个依赖列表
return await self.docs.get(filename=REQUIREMENT_FILENAME)
@property
def git_repo(self) -> GitRepository:
# 返回与Git仓库交互的GitRepository对象
return self._git_repo
@property
def workdir(self) -> Path:
# 返回Git仓库的工作目录的路径
return Path(self.git_repo.workdir)
@property
def srcs(self) -> FileRepository:
# 返回一个用于访问源代码文件的FileRepository对象
if not self._srcs_path:
raise ValueError("Call with_srcs first.")
return self._git_repo.new_file_repository(self._srcs_path)
def code_files_exists(self) -> bool:
# 检查Git仓库中是否存在代码文件
git_workdir = self.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
return bool(code_files)
def with_src_path(self, path: str | Path) -> ProjectRepo:
# 设置源代码文件的路径,并返回当前ProjectRepo对象
try:
self._srcs_path = Path(path).relative_to(self.workdir)
except ValueError:
self._srcs_path = Path(path)
return self
@property
def src_relative_path(self) -> Path | None:
# 返回源代码文件的相对路径
return self._srcs_path
Cofig:
提供了一种集中化的方式来存储和管理项目配置,包括从YAML文件加载配置、更新配置、验证LLM密钥等,设置未提前设定的LLM的API需要在Config中增加获取config.yaml配置的代码
Field:
用于为 Pydantic 模型字段提供丰富的配置选项,从而使得模型能够更好地适应不同的数据验证和序列化需求
ActionNode:
用于构建一个表示动作或操作的节点树。这个类的主要目的是为了在编程任务中组织和存储操作步骤,以及它们的参数和期望的输出类型。
def __init__(
self,
key: str,
expected_type: Type,
instruction: str,
example: Any,
content: str = "",
children: dict[str, "ActionNode"] = None,
schema: str = "",
):
"""
初始化 ActionNode 对象。
:param key: 节点的唯一标识符。
:param expected_type: 节点期望的输入类型。
:param instruction: 节点的操作说明或指令。
:param example: 节点的示例输入。
:param content: 节点的输出内容。
:param children: 一个字典,用于存储子节点。
:param schema: 节点的数据格式。
"""
self.key = key
self.expected_type = expected_type
self.instruction = instruction
self.example = example
self.content = content
self.children = children if children is not None else {}
self.schema = schema
def __str__(self):
"""
返回 ActionNode 对象的字符串表示。
字符串包含节点的 key、expected_type、instruction、example、content 和 children。
使用 repr(self.expected_type) 来获取期望类型的字符串表示。
"""
return (
f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}"
f", {self.content}, {self.children}"
)
def __repr__(self):
"""
返回 ActionNode 对象的“官方”字符串表示。
__repr__ 方法默认调用 __str__ 方法,所以它的行为与 __str__ 方法相同。
通常,__repr__ 方法用于在交互式环境中打印对象,或者在文档字符串中作为对象的字符串表示。
"""
return self.__str__()
定义了如何将ActionNode对象转换为字符串表示
def add_child(self, node: "ActionNode"):
"""
向 ActionNode 添加一个子 ActionNode。
:param node: 要添加的子 ActionNode 对象。
"""
self.children[node.key] = node
def add_children(self, nodes: List["ActionNode"]):
"""
批量添加子 ActionNode。
:param nodes: 一个 ActionNode 对象的列表,用于批量添加子节点。
"""
for node in nodes:
self.add_child(node)
@classmethod
def from_children(cls, key, nodes: List["ActionNode"]):
"""
直接从一系列的子 nodes 初始化。
:param key: 新的 ActionNode 对象的 key。
:param nodes: 一个 ActionNode 对象的列表,用于初始化新的 ActionNode 对象。
"""
obj = cls(key, str, "", "")
obj.add_children(nodes)
return obj
添加了三个方法:add_child、add_children和from_children,用于处理节点树中子节点的添加和初始化
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""
获得子 ActionNode 的字典,以 key 索引。
:param exclude: 需要排除的子节点 key 列表。
:return: 一个字典,其中包含子节点的 key 和对应的类型元组。
"""
exclude = exclude or []
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""
获取自身的键类型映射。
:return: 一个字典,其中包含自身的 key 和对应的类型元组。
"""
return {self.key: (self.expected_type, ...)}
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""
获取键类型映射。
:param mode: 指定返回的映射类型,可以是 "children" 或 "auto"。
:param exclude: 需要排除的节点 key 列表。
:return: 根据模式返回相应的键类型映射,如果模式不匹配则返回空字典。
"""
if mode == "children" or (mode == "auto" and self.children):
return self.get_children_mapping(exclude=exclude)
return {} if exclude and self.key in exclude else self.get_self_mapping()
获取节点及其子节点的键到类型的映射
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""
基于pydantic v1的模型动态生成,用来检验结果类型正确性。
"""
def check_fields(cls, values):
# 检查缺失的字段
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
# 检查未识别的字段
unrecognized_fields = set(values.keys()) - required_fields
if unrecognized_fields:
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
return values
# 创建一个带有检查器的新模型类
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
new_class = create_model(class_name, __validators__=validators, **mapping)
return new_class
def create_children_class(self, exclude=None):
"""
使用object内有的字段直接生成model_class。
"""
class_name = f"{self.key}_AN"
mapping = self.get_children_mapping(exclude=exclude)
return self.create_model_class(class_name, mapping)
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
"""
将当前节点与子节点都按照node: format的格式组织成字典。
"""
# 如果没有提供格式化函数,使用默认的格式化方式
if format_func is None:
format_func = lambda node: f"{node.instruction}"
# 使用提供的格式化函数来格式化当前节点的值
formatted_value = format_func(self)
# 创建当前节点的键值对
if mode == "children" or (mode == "auto" and self.children):
node_dict = {}
else:
node_dict = {self.key: formatted_value}
if mode == "root":
return node_dict
# 遍历子节点并递归调用 to_dict 方法
exclude = exclude or []
for _, child_node in self.children.items():
if child_node.key in exclude:
continue
node_dict.update(child_node.to_dict(format_func))
return node_dict
def compile_to(self, i: Dict, schema, kv_sep) -> str:
# 根据给定的模式(json或markdown)将字典转换为字符串
if schema == "json":
return json.dumps(i, indent=4)
elif schema == "markdown":
return dict_to_markdown(i, kv_sep=kv_sep)
else:
return str(i)
def tagging(self, text, schema, tag="") -> str:
# 给文本添加标签,标签的格式取决于模式(json或markdown)
if not tag:
return text
if schema == "json":
return f"[{tag}]\n" + text + f"\n[/{tag}]"
else: # markdown
return f"[{tag}]\n" + text + f"\n[/{tag}]"
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
# 用于编译节点数据
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
text = self.compile_to(nodes, schema, kv_sep)
return self.tagging(text, schema, tag)
def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
# 编译节点的指令
"""
compile to raw/json/markdown template with all/root/children nodes
"""
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
# 编译节点的示例
"""
compile to raw/json/markdown examples with all/root/children nodes
"""
# 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str
format_func = lambda i: i.example
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
# 主要的编译方法,它将上下文、示例、指令和约束组合成一个模板
"""
mode: all/root/children
mode="children": 编译所有子节点为一个统一模板,包括instruction与example
mode="all": NotImplemented
mode="root": NotImplemented
schmea: raw/json/markdown
schema="raw": 不编译,context, lang_constaint, instruction
schema="json":编译context, example(json), instruction(markdown), constraint, action
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
"""
if schema == "raw":
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
# compile example暂时不支持markdown
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
# nodes = ", ".join(self.to_dict(mode=mode).keys())
constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
constraint = "\n".join(constraints)
prompt = template.format(
context=context,
example=example,
instruction=instruction,
constraint=constraint,
)
return prompt
class Action(SerializationMixin, ContextMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = ""
i_context: Union[
dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None
] = ""
prefix: str = "" # aask* 时会加上 prefix,作为 system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
# 返回一个 ProjectRepo 实例,如果上下文中没有设置 repo
@property
def repo(self) -> ProjectRepo:
if not self.context.repo:
self.context.repo = ProjectRepo(self.context.git_repo)
return self.context.repo
# 返回配置中的 prompt_schema
@property
def prompt_schema(self):
return self.config.prompt_schema
# 返回配置中的 project_name
@property
def project_name(self):
return self.config.project_name
# 设置配置中的 project_name
@project_name.setter
def project_name(self, value):
self.config.project_name = value
# 返回配置中的 project_path
@property
def project_path(self):
return self.config.project_path
# 模型验证器,在验证之前设置 name 为类名,如果未设置
@model_validator(mode="before")
@classmethod
def set_name_if_empty(cls, values):
if "name" not in values or not values["name"]:
values["name"] = cls.__name__
return values
# 模型验证器,在验证之前初始化 node
@model_validator(mode="before")
@classmethod
def _init_with_instruction(cls, values):
if "instruction" in values:
name = values["name"]
i = values.pop("instruction")
values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw")
return values
# 设置前缀,并更新 system_prompt 和 node 的 llm 属性
def set_prefix(self, prefix):
"""设置前缀,用于后续使用"""
self.prefix = prefix
self.llm.system_prompt = prefix
if self.node:
self.node.llm = self.llm
return self
# 定义 __str__ 方法,返回类的名称
def __str__(self):
return self.__class__.__name__
# 定义 __repr__ 方法,返回类的字符串表示
def __repr__(self):
return self.__str__()
# 定义异步方法 _aask,向语言模型发送请求,并添加默认前缀
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
"""添加默认前缀"""
return await self.llm.aask(prompt, system_msgs)
# 定义异步方法 _run_action_node,运行 action node
async def _run_action_node(self, *args, **kwargs):
"""运行 action node"""
msgs = args[0]
context = "## 历史消息\n"
context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
return await self.node.fill(context=context, llm=self.llm)
# 定义异步方法 run,执行动作
async def run(self, *args, **kwargs):
"""执行动作"""
if self.node:
return await self._run_action_node(*args, **kwargs)
raise NotImplementedError("The run method should be implemented in a subclass.")