利用tree-sitter提取代码文件中的函数和注释

利用tree-sitter提取代码文件中的函数和注释

1. 需求

提取.c或.cpp文件中的带有注释的函数,作为训练数据喂给大语言模型。要求是能够批量处理,提取函数前带有注释的函数和注释,并将函数中的注释同样提取出来作为辅助训练数据,结果保存在JSON文件中。
真心建议大家如果要干这件事之前先了解一下要提取代码的格式

2. 工具

3. 实现

利用节点的type属性判断节点类型是否为function_definition,如果是函数定义的话,判断上一个节点类型是否是comment(如果前序多个连续的节点类型都为comment,则将所有的comment全部保存)。如果上一个节点类型是comment的话,先提取类型为function_definition节点的信息,即函数定义;再提取函数前的类型为comment的节点信息,即函数前注释;最后,深度优先遍历,提取函数体中的类型为comment的节点信息,即函数体中的注释。此外,能够批量处理文件夹中的所有.c或者.cpp文件。

from tree_sitter import Language, Parser
import json
import os
import re


# 加载C语言模块
Language.build_library(
    'build/my-languages.so',
    [
        'vendor/tree-sitter-c'
    ]
)


C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)


# 提取代码信息
def extract_code_information(node, code):
    functions = []  # 存放最终的代码提取结果
    comment = ''  # 存放函数前的注释
    in_comment = ''  # 存放函数中的注释
    function = ''  # 存放函数
    for child in node.children:
        # 只保存函数前存在注释的函数及其注释
        if child.type == 'function_definition' and child.prev_sibling and child.prev_sibling.type == 'comment':
            # 首先处理函数
            function = extract_node_information(child, code)
            # 然后处理函数中的注释
            in_comment = traverse_children(child, code)
            # 最后处理函数前的注释
            temp_node = child.prev_sibling
            while temp_node.type == 'comment':
                comment += extract_node_information(temp_node, code)
                if temp_node.prev_sibling:
                    temp_node = temp_node.prev_sibling
                else:
                    break
            # 将函数和其注释保存到最终的结果中
            functions.append({
                'comment_before_function': comment,
                'comment_in_function': in_comment,
                'function': function
            })
            comment = ''
            in_comment = ''
            function = ''
    return functions


# 深度优先遍历节点的全部孩子节点
def traverse_children(node, code):
    if node is None:
        return ''
    comment = ''
    if node.type == 'comment':
        comment += extract_node_information(node, code)
    for child in node.children:
        comment += traverse_children(child, code)
    return comment


# 提取节点信息
def extract_node_information(node, code):
    try:
        start_row, start_col = node.start_point
        end_row, end_col = node.end_point
        # 将源代码按行进行拆分
        code_lines = code.split('\n')
        # 如果起始行和结束行在同一行
        if start_row == end_row:
            extracted_code = code_lines[start_row][start_col:end_col]
        else:
            # 提取起始行到结束行中的内容
            extracted_code = code_lines[start_row][start_col:]
            for i in range(start_row + 1, end_row):
                extracted_code += code_lines[i] + '\n'
            extracted_code += code_lines[end_row][:end_col]
        return extracted_code
    except AttributeError as e:
        return ''


# 查找文件夹中的.c和.cpp文件
def get_c_files(folder):
    c_files = []
    for root, dirs, files in os.walk(folder):
        for file in files:
            if re.search(r'\.c$|\.cpp$', file):
                c_files.append(os.path.join(root, file))
    return c_files


# 处理文件夹中的.c和.cpp文件
def pipeline(folder_path):
    c_files = get_c_files(folder_path)
    functions = []
    for c_file in c_files:
        print(c_file)
        temp = []
        try:
            try:
                with open(c_file, 'r', encoding='gbk') as file:
                    code = file.read()
                    tree = parser.parse(bytes(code, 'gbk'))
                    root_node = tree.root_node
                    temp = extract_code_information(root_node, code)
                    functions.append(temp)
            except UnicodeDecodeError as e:
                with open(c_file, 'r', encoding='utf8') as file:
                    code = file.read()
                    tree = parser.parse(bytes(code, 'utf8'))
                    root_node = tree.root_node
                    temp = extract_code_information(root_node, code)
                    functions.append(temp)
        except UnicodeDecodeError as e:
            print("UnicodeDecodeError!")
        # 将结果保存在functions.json中
        with open('functions.json', 'w', encoding='utf8') as json_file:
            json.dump(functions, json_file, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    folder_path = '文件夹的绝对路径'
    pipeline(folder_path)

4. 测试代码 + 问题记录 + 修改代码

测试问题记录:

  • 批量处理文件夹中的文件时服务器报错 UnicodeEncoderError: 'utf-8' codec can't encode character '\udcd0' in position 7: surrogates not allowed,测试了一下,应该是文件名中包含中文以及中文字符导致的。而且文件的后缀只有两种情况.c.txt.h.txt
  • 忽略了函数前没有注释,但是与函数定义同一行有注释的情况。还好我们的书写格式还是比较规整的,基本除了函数前注释就是下述这种情况
void func( void )  /* function description */
{
	...
}
  • 忽略了#ifndef #else #endif#if defined() #endif之间函数无法识别的情况
/* 第一种情况 */
#if defined( __FLOAT__ )
	...
#endif
/* 第二种情况 */
#ifndef __***__
	...
#else
	...
#endif
  • 用于保存结果的json文件一直被覆盖
  • 如果文件不存在注释,会输出空list
  • 只声明了函数,但是函数体没有东西
  • 每个函数的最后一行加\n
    问题解决方案:
  • 首先文件的数量很大,所以不想处理.h.txt的文件,而对于.c.txt的文件,由于其中的中文以及中文字符导致服务器无法识别,所以就按照最简单的方式处理,将文件名按照12345的顺序重新命名,并把修改后的文件保存到另一个文件夹中。如果在运行中遇到什么问题的话,可以参考解决办法。如果执行权限不够的话,可以使用chmod +x rename.sh命令
#!/bin/bash
source_dir="./txtfromcfiles"
target_dir="./handledcfiles"
counter=1
for file in "$source_dir"/*.c.txt; do
    if [ -f "$file" ]; then
        new_name="${target_dir}/${counter}.c.txt"
        mv "$file" "$new_name"
        ((counter++))
    fi
done
  • 对于函数前的注释,同样把连续的注释提取出来;对于与函数定义在同一行的注释,由于函数前可能没有注释或者注释内容没有实际的意义,所以决定把这种注释和函数前的注释结合在一起,作为函数前的注释。此外,这种注释同样作为函数中的注释;对于其他情况,直接使用深度优先遍历,只要是函数节点则进行判断
from tree_sitter import Language, Parser
import json
import os
import re


# 加载C语言模块
Language.build_library(
    'build/my-languages.so',
    [
        'vendor/tree-sitter-c'
    ]
)


C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)


# 判断函数是否应该被处理
def judge_function(node):
    # 如果函数前有注释,函数同行没有注释
    # 如果函数前有注释,函数同行有注释
    if node.prev_sibling and node.prev_sibling.type == 'comment':
        return 1
    # 如果函数前没有注释,函数同行有注释
    else:
        # 判断函数同行有没有注释
        for child in node.children:
            if child.type == 'comment':
                return 1
    # 如果函数前没有注释,函数同行没有注释
    return 0


# 提取函数和注释
def extract_comment_and_function(node, code):
    comment_before_function = ''
    # 提取函数前的注释
    temp_node = node.prev_sibling
    # 找到函数前最开始的注释
    while temp_node.prev_sibling and temp_node.prev_sibling.type == 'comment':
        temp_node = temp_node.prev_sibling
    while temp_node.type == 'comment':
        comment_before_function += extract_node_information(temp_node, code)
        temp_node = temp_node.next_sibling
    for child in node.children:
        if child.type == 'comment':
            comment_before_function += extract_node_information(child, code)
    # 提取函数中的注释
    comment_in_function = traverse_children(node, code)
    # 提取函数
    function = extract_node_information(node, code) + '\n'
    return comment_before_function, comment_in_function, function


# 提取代码信息
def extract_code_information(node, code):
    result = []
    # 非递归的深度优先遍历
    if node is None:
        return
    stack = [node]
    while stack:
        node = stack.pop()
        if node.type == 'function_definition' and judge_function(node):
            comment_before_function, comment_in_function, function = extract_comment_and_function(node, code)
            # 处理结果
            result.append({
                'comment_before_function': comment_before_function,
                'comment_in_function': comment_in_function,
                'function': function
            })
        for child in reversed(node.children):
            stack.append(child)
    return result


# 深度优先遍历节点的全部孩子节点
def traverse_children(node, code):
    if node is None:
        return ''
    comment = ''
    if node.type == 'comment':
        comment += extract_node_information(node, code)
    for child in node.children:
        comment += traverse_children(child, code)
    return comment


# 提取节点信息
def extract_node_information(node, code):
    try:
        start_row, start_col = node.start_point
        end_row, end_col = node.end_point
        # 将源代码按行进行拆分
        code_lines = code.split('\n')
        # 如果起始行和结束行在同一行
        if start_row == end_row:
            extracted_code = code_lines[start_row][start_col:end_col]
        else:
            # 提取起始行到结束行中的内容
            extracted_code = code_lines[start_row][start_col:]
            for i in range(start_row + 1, end_row):
                extracted_code += code_lines[i] + '\n'
            extracted_code += code_lines[end_row][:end_col]
        return extracted_code
    except AttributeError as e:
        return ''


# 查找文件夹中的.c和.cpp文件
def get_c_files(folder):
    c_files = []
    for root, dirs, files in os.walk(folder):
        for file in files:
            if re.search(r'\.c.txt$', file):
                c_files.append(os.path.join(root, file))
    return c_files


# 处理文件夹中的.c和.cpp文件
def pipeline(folder_path):
    c_files = get_c_files(folder_path)
    functions = []
    count = 0
    for c_file in c_files:
        count += 1
        print(str(count) + ": " + c_file)
        try:
            with open(c_file, 'r', encoding='utf8') as file:
                code = file.read()
                tree = parser.parse(bytes(code, 'utf8'))
                root_node = tree.root_node
                functions.extend(extract_code_information(root_node, code))
        except UnicodeDecodeError as e:
            print("UnicodeDecodeError!")
    # 将结果保存在functions.json中
    with open('functions.json', 'w', encoding='utf8') as json_file:
        json.dump(functions, json_file, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    folder_path = 'C:\\Users\\86139\\Desktop\\pythonProject3\\venv'
    pipeline(folder_path)

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值