如何在python中使用tree sitter获得代码片段

最近一年多一直在做大模型代码补全方面的工作,很多东西都是从零学起,也抽空分享一下小白在学习的路上偶尔捕获的碎片。

大模型爆火之后,很多企业都在尝试将大模型落地,尤其是以Copilot为标杆的“助手”概念率先展示了惊艳的落地感以后,大部分团队都开始研发自己的“助手”——一种“问什么答什么,甚至帮你做一些”的角色。这一概念被各个大小团队结合到自己的产品之中,搜索的助手、购物的助手、编程的助手…人们逐渐发现,要当好一个助手,或者说能给风格迥异的用户中任何一位当好助手,他必须具备应对各种问题的“知识”。目前,这种知识的灌输主要有两个思路,一种是训练(靠脑子记住),一种是外挂知识库(靠翻参考书)。

随着开源大模型能够允许的窗口越来越大,不仅仅是对话的轮次能够变长(说实话,如果较少的轮次能够解决问题,没人会更喜欢一直沟通),更重要的是,RAG(retrieval augmented generation)路线中模型能够输入的信息片段越来越多。一方面现在都追求落地,在已有的产品上落地,运用过去已经积累的数据,或让AI来使用自己过去形成的经验,把人在解决某个问题需要的知识放在AI面前,让AI做选择、组合、创造,比让AI自己憋出一个答案要简单,是非常直观的想法;另一方面是RAG的可控感比训练要强的多,大部分我这样的小白对于这种超大模型数十百亿神经元的训练,用一种“历史是螺旋上升”的感觉来说,和早期的机器学习“炼丹”没什么不同。

闲话少说,在代码补全方面,我们也很直观地会想,有的用户的代码是相似的,如果有别人的代码做参考,AI就能更接近正确答案;有的函数的实现、类的定义对于当前生成位置而言是必须的,如果能补充给AI看到,那么AI就更可能使用正确的属性、填入正确的参数、使用正确的方法。而代码的片段,通常有两种获取思路,一种是暴力切割,几行一切或者多少个字符一切,这样做的缺点是,这样得到的片段不一定都是语义完整的,多个片段拼在一起不通顺,和AI在预训练阶段见到的数据大概率不同,如果自己额外用新的片段式输入SFT,可能会浪费预训练模型的超强语义理解能力。而另一种方法,就是以代码解析器能够解析的一个单元进行划分,这样保证每个片段都是语义完整的,可以轻松地用注释或者什么将它们连接在一起,构成的新的代码全文甚至可以再次被代码解析器解析,这更加符合“代码”的本意。

而这个获得代码片段的环节,离不开tree-sitter这类代码解析工具的使用,我在初接触时因为也有不少困惑,因此分享一些学习笔记,帮类似情况的小伙伴减轻痛苦。

什么是tree-sitter & 为什么选择 tree-sitter

tree-sitter是一个C语言编写的代码解析工具,简单来说就是把一个代码片段,解析成一个树结构,通过父节点、子节点的层级关系来展现代码的层级关系,比如一个函数体 -> 函数体中的一句变量声明,通过子节点、兄弟节点的同级关系来展现代码中并列同一层级的结构。因此,我们可以从tree-sitter解析出的树上(Abstract Syntax Tree简称AST)取某个节点,它就是一个完整的片段,比如取一个花括号包裹的函数体、取一个if语句、取一个while循环…,这比基于正则表达式或其他字符串处理策略去获取函数片段(在代码生成RAG的方案中经常看到的code snippet)会方便得多。

tree-sitter的支持非常广,基本上可以解析任何编程语言,并且也非常快,即便对于有语法错误的代码片段也能完成解析,鲁棒性高,因为是纯C语言编写的,基本也没有什么依赖,所以一直是很多RAG代码生成方案所选择的代码片段化工具,比如TabbyML(一个开源的对标copilot的插件产品,能够和开源的代码生成模型结合使用)

tree-sitter如何使用

因为tree-sitter是C编写的,其他语言比如Python通常会使用python bindings的方式使用这样的C语言工具包中的函数(language bindings可以理解为,有一些这个语言内置的库,可以把python里的数据结构转换为目标工具库语言的数据结构,这样就可以传参传进那些函数里),同时,这样的其它语言的库使用时要load进python程序中,通常会通过.so文件来完成导入,后面会讲到。

1. 基本配置

首先需要获得你要解析的目标语言的“语法知识”,比如python需要python对应的知识、javascript需要javascript对应的知识。获得的途径有两个(建议去github搜 {语言} tree-sitter 就可以看到readme里的教程),以python为例

  • 第一种是直接通过pip install (推荐)

    pip install tree-sitter, tree-sitter-python
    
    import tree_sitter_python as tspython
    from tree_sitter import Language, Parser
    
    PY_LANGUAGE = Language(tspython.language(), "python")
    

但是这个方法我实测还是需要.so文件的,目前看来应该是py-tree-sitter这个包想放弃build_library这个帮助编译其他语言的.so文件的函数,让用户自行在其他地方编译。

  • 第二种是通过源文件编译出.so文件

    首先保证自己的电脑上装有C编译器(一般windows下载了 Visual Studio就okay),在github上找到对应知识的实现,拉取

    git clone https://github.com/tree-sitter/tree-sitter-python
    

    接下来写一个小脚本,对它进行编译,获得的.so文件就是解析目标语言所需的“知识”了

    from tree_sitter import Language, Parser
    
    Language.build_library(
        # Store the library in the `build` directory
        "build/my-languages.so",
        # Include one or more languages
        ["vendor/tree-sitter-go", "vendor/tree-sitter-javascript", "vendor/tree-sitter-python"],
    )
    

    这个.so文件的位置要记住,每次在代码里使用都要先加载它

    PY_LANGUAGE = Language("build/my-languages.so", "python")

但是这个build_library函数马上就会被去除了,我们可能要自行解决编译的问题,核心目的就是要得到一个动态库。
这个取决于你在什么地方使用这个库,比如你在windows上用,用Visual Studio编译一下(可能会得到.dll文件,相当于Linux上的.so文件,也可以直接编译成.so文件,这个需要在visual studio里面配置一下,网上有很多教程,不再赘述);如果在Linux上用,直接用gcc应该也是可以的,自行参考其他教程吧。
目前对于只会python的使用者来说,上面这个函数还是挺方便的

2. parser的使用

完成配置后,就可以把这个“知识”装载到parser中,接着就可以用parser来解析代码了。

parser = Parser()
parser.set_language(PY_LANGUAGE)

通常来说,我们掌握解析字符串类型的代码内容就够了,如果有其他的数据类型的代码,也可以传递一个read的方法到parser的parse函数里(具体参考官方github的官方仓库吧)。注意,parser是基于bytes工作的,所以我们需要先把字符串类型转换为bytes,之后node内容的解析等都要记住在bytes上操作

code = f'{所有代码内容}'
src = bytes(code, 'utf8')
tree = parser.parse(src)

这样就很轻易地获得解析完的语法树了,是一个形如这样的结构,子节点我用缩进表示,兄弟节点就是同一层缩进的

program
	comment
	export statement
	...
	

AST的使用场景

获得这样一棵语法树之后,通常会有以下几种使用场景

1. 看看树里的每个node是什么,以及node在代码原文中对应的片段是什么

这里要注意两个概念

  • node type:解析代码的时候实际上会对各种代码片段归类,比如说“statement block”这种node type,通常表示一个由花括号包裹的片段,不管是if,while还是函数体的花括号,里面可能都会包含一个子节点,它的node type都是"statement block"。
  • node text:每个node都会对应代码原文中的一段,父节点对应的片段会包含子节点对应的片段。注意,parser是在bytes上运作的,所以我们想要还原字符串的时候,还要从bytes还原回来:src_bytes[cursor.node.start_byte: cursor.node.end_byte].decode('utf-8')

举一个遍历语法树,打印node type 和 node text的例子

def display_tree(src_bytes, cursor, ident=4, node_text=True):
    '''
    display syntax node of the tree and its corresponding text

    Args:
        src_bytes: input bytes of the source code
        cursor
        ident: use to format the display info
        node_text: whether to display the text of the node
    '''
    if node_text:
        print(' ' * ident + cursor.node.type + ' [' + repr(
            src_bytes[cursor.node.start_byte: cursor.node.end_byte].decode('utf-8')) + ']')
    else:
        print(' ' * ident + cursor.node.type)

    if cursor.goto_first_child():
        display_tree(src_bytes, cursor, ident + 4)
        cursor.goto_parent()

    while cursor.goto_next_sibling():
        display_tree(src_bytes, cursor, ident)

cursor = tree.walk()
display_tree(src, cursor)

注意,遍历语法树使用到的cursor,是一个引用,在调用cursor.goto_parent()等方法时,会修改其指向的值,这一点在递归的时候要加以注意,在递归内部修改,也会影响到外部的cursor

此外,你可能会注意到,这个例子中使用了很多node的方法、cursor的方法,但是py-tree-sitter的github readme并没有给出很清晰的解说,你可以从这里看到更详细的,node / cursor 有哪些可供使用的方法、可供访问的属性
在这里插入图片描述

2. 使用query直接捕获树中的某一类节点

另一个需要掌握的使用场景是,不通过遍历的方式,直接捕获树上所有某种节点,比如我要拿到所有的函数调用、变量声明的代码片段。使用query时要注意,query是基于sexp来捕获的,它和我们上面遍历出来的各个node会有一点出入。sexp是形如 (program (comment) (export_statement declaration: (function_declaration name: (identifier)...)))的一种字符串表示,用括号来表示层级,附带一些形如declaration: 的子节点标识,这种标识在我们上面的遍历并打印node type的过程中一般是看不见的

因此,推荐想要使用query时,先使用print(tree.root_node.sexp()),看一下整体的情况,比如得到一个如下的内容在这里插入图片描述

再找到自己感兴趣的部分,这样使用query

query = TS_LANGUAGE.query("(lexical_declaration) @lexical")

# 捕获更多节点时可以这样写:这里捕获了name标识的子节点identifier和lexical_declaration这个节点
query = TS_LANGUAGE.query("(lexical_declaration (variable_declarator name: (identifier) @name type: (type_annotation (type_identifier)))) @lexical")

captures = query.captures(tree.root_node)
# 返回一个list,[(Node -> 对应上面的lexical_declaration, @的名字 -> 对应上面的lexical)]

这里有几个坑要注意:

  • query里的@是给捕获的节点起一个“名字”,必须在query中使用了@标识了的括号(每个括号就是一个节点)才会捕获,比如第二个query中,lexical_declaration和它的子节点identifier都会被捕获,但是同样写在query里但是没有标识@取名字的type_annotaion就不会被捕获,未被捕获的节点不会出现在captures结果中
  • 可能直觉上会认为,如果我在一个query里面捕获三个节点,那么如果代码原文中有N个符合这个条件的捕获,它们应该在captures结果中三个三个一组,像这样[[(Node1, @1), (Node2, @2), (Node3, @3)], [第二处捕获...], ...]。实际上不是的,capture不会进行group,它们会以同级的方式被捕获 [(第一处捕获的Node1), Node2, Node3, 第二处捕获的Node1, Node2, Node3, ...]所以我们需要@的名字来辅助判断

获得了node之后,我们又可以再次使用src_bytes[capture_node.start_byte: capture_node.end_byte].decode('utf-8')来获得它们对应的代码片段了

总结

掌握本文中的内容后,基本上就可以应付大部分的代码片段获取场景了,至少我在大模型代码生成任务中获取代码片段的环节,遇到卡壳的地方大致就是这些:

  • tree-sitter的配置和加载导入
  • 使用parser进行解析,并使用cursor在tree上遍历
  • 使用query在tree上捕获node

对本文有任何疑问、发现任何错误或有任何建议,还望评论区或私信多多指教!谢谢!

  • 27
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MetLightt

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值