利用tree-sitter提取代码文件中的函数和注释
1. 需求
提取.c或.cpp文件中的带有注释的函数,作为训练数据喂给大语言模型。要求是能够批量处理,提取函数前带有注释的函数和注释,并将函数中的注释同样提取出来作为辅助训练数据,结果保存在JSON文件中。
真心建议大家如果要干这件事之前先了解一下要提取代码的格式
2. 工具
- tree-sitter
如何配置tree-sitter的使用环境 - pycharm
如何将conda的虚拟python环境添加到pycharm中
3. 实现
利用节点的type
属性判断节点类型是否为function_definition
,如果是函数定义的话,判断上一个节点类型是否是comment
(如果前序多个连续的节点类型都为comment
,则将所有的comment
全部保存)。如果上一个节点类型是comment
的话,先提取类型为function_definition
节点的信息,即函数定义;再提取函数前的类型为comment
的节点信息,即函数前注释;最后,深度优先遍历,提取函数体中的类型为comment
的节点信息,即函数体中的注释。此外,能够批量处理文件夹中的所有.c或者.cpp文件。
from tree_sitter import Language, Parser
import json
import os
import re
# 加载C语言模块
Language.build_library(
'build/my-languages.so',
[
'vendor/tree-sitter-c'
]
)
C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)
# 提取代码信息
def extract_code_information(node, code):
functions = [] # 存放最终的代码提取结果
comment = '' # 存放函数前的注释
in_comment = '' # 存放函数中的注释
function = '' # 存放函数
for child in node.children:
# 只保存函数前存在注释的函数及其注释
if child.type == 'function_definition' and child.prev_sibling and child.prev_sibling.type == 'comment':
# 首先处理函数
function = extract_node_information(child, code)
# 然后处理函数中的注释
in_comment = traverse_children(child, code)
# 最后处理函数前的注释
temp_node = child.prev_sibling
while temp_node.type == 'comment':
comment += extract_node_information(temp_node, code)
if temp_node.prev_sibling:
temp_node = temp_node.prev_sibling
else:
break
# 将函数和其注释保存到最终的结果中
functions.append({
'comment_before_function': comment,
'comment_in_function': in_comment,
'function': function
})
comment = ''
in_comment = ''
function = ''
return functions
# 深度优先遍历节点的全部孩子节点
def traverse_children(node, code):
if node is None:
return ''
comment = ''
if node.type == 'comment':
comment += extract_node_information(node, code)
for child in node.children:
comment += traverse_children(child, code)
return comment
# 提取节点信息
def extract_node_information(node, code):
try:
start_row, start_col = node.start_point
end_row, end_col = node.end_point
# 将源代码按行进行拆分
code_lines = code.split('\n')
# 如果起始行和结束行在同一行
if start_row == end_row:
extracted_code = code_lines[start_row][start_col:end_col]
else:
# 提取起始行到结束行中的内容
extracted_code = code_lines[start_row][start_col:]
for i in range(start_row + 1, end_row):
extracted_code += code_lines[i] + '\n'
extracted_code += code_lines[end_row][:end_col]
return extracted_code
except AttributeError as e:
return ''
# 查找文件夹中的.c和.cpp文件
def get_c_files(folder):
c_files = []
for root, dirs, files in os.walk(folder):
for file in files:
if re.search(r'\.c$|\.cpp$', file):
c_files.append(os.path.join(root, file))
return c_files
# 处理文件夹中的.c和.cpp文件
def pipeline(folder_path):
c_files = get_c_files(folder_path)
functions = []
for c_file in c_files:
print(c_file)
temp = []
try:
try:
with open(c_file, 'r', encoding='gbk') as file:
code = file.read()
tree = parser.parse(bytes(code, 'gbk'))
root_node = tree.root_node
temp = extract_code_information(root_node, code)
functions.append(temp)
except UnicodeDecodeError as e:
with open(c_file, 'r', encoding='utf8') as file:
code = file.read()
tree = parser.parse(bytes(code, 'utf8'))
root_node = tree.root_node
temp = extract_code_information(root_node, code)
functions.append(temp)
except UnicodeDecodeError as e:
print("UnicodeDecodeError!")
# 将结果保存在functions.json中
with open('functions.json', 'w', encoding='utf8') as json_file:
json.dump(functions, json_file, indent=4, ensure_ascii=False)
if __name__ == '__main__':
folder_path = '文件夹的绝对路径'
pipeline(folder_path)
4. 测试代码 + 问题记录 + 修改代码
测试问题记录:
- 批量处理文件夹中的文件时服务器报错
UnicodeEncoderError: 'utf-8' codec can't encode character '\udcd0' in position 7: surrogates not allowed
,测试了一下,应该是文件名中包含中文以及中文字符导致的。而且文件的后缀只有两种情况.c.txt
和.h.txt
- 忽略了函数前没有注释,但是与函数定义同一行有注释的情况。还好我们的书写格式还是比较规整的,基本除了函数前注释就是下述这种情况
void func( void ) /* function description */
{
...
}
- 忽略了
#ifndef #else #endif
和#if defined() #endif
之间函数无法识别的情况
/* 第一种情况 */
#if defined( __FLOAT__ )
...
#endif
/* 第二种情况 */
#ifndef __***__
...
#else
...
#endif
- 用于保存结果的json文件一直被覆盖
- 如果文件不存在注释,会输出空
list
- 只声明了函数,但是函数体没有东西
- 每个函数的最后一行加
\n
问题解决方案: - 首先文件的数量很大,所以不想处理
.h.txt
的文件,而对于.c.txt
的文件,由于其中的中文以及中文字符导致服务器无法识别,所以就按照最简单的方式处理,将文件名按照12345的顺序重新命名,并把修改后的文件保存到另一个文件夹中。如果在运行中遇到什么问题的话,可以参考解决办法。如果执行权限不够的话,可以使用chmod +x rename.sh
命令
#!/bin/bash
source_dir="./txtfromcfiles"
target_dir="./handledcfiles"
counter=1
for file in "$source_dir"/*.c.txt; do
if [ -f "$file" ]; then
new_name="${target_dir}/${counter}.c.txt"
mv "$file" "$new_name"
((counter++))
fi
done
- 对于函数前的注释,同样把连续的注释提取出来;对于与函数定义在同一行的注释,由于函数前可能没有注释或者注释内容没有实际的意义,所以决定把这种注释和函数前的注释结合在一起,作为函数前的注释。此外,这种注释同样作为函数中的注释;对于其他情况,直接使用深度优先遍历,只要是函数节点则进行判断
from tree_sitter import Language, Parser
import json
import os
import re
# 加载C语言模块
Language.build_library(
'build/my-languages.so',
[
'vendor/tree-sitter-c'
]
)
C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)
# 判断函数是否应该被处理
def judge_function(node):
# 如果函数前有注释,函数同行没有注释
# 如果函数前有注释,函数同行有注释
if node.prev_sibling and node.prev_sibling.type == 'comment':
return 1
# 如果函数前没有注释,函数同行有注释
else:
# 判断函数同行有没有注释
for child in node.children:
if child.type == 'comment':
return 1
# 如果函数前没有注释,函数同行没有注释
return 0
# 提取函数和注释
def extract_comment_and_function(node, code):
comment_before_function = ''
# 提取函数前的注释
temp_node = node.prev_sibling
# 找到函数前最开始的注释
while temp_node.prev_sibling and temp_node.prev_sibling.type == 'comment':
temp_node = temp_node.prev_sibling
while temp_node.type == 'comment':
comment_before_function += extract_node_information(temp_node, code)
temp_node = temp_node.next_sibling
for child in node.children:
if child.type == 'comment':
comment_before_function += extract_node_information(child, code)
# 提取函数中的注释
comment_in_function = traverse_children(node, code)
# 提取函数
function = extract_node_information(node, code) + '\n'
return comment_before_function, comment_in_function, function
# 提取代码信息
def extract_code_information(node, code):
result = []
# 非递归的深度优先遍历
if node is None:
return
stack = [node]
while stack:
node = stack.pop()
if node.type == 'function_definition' and judge_function(node):
comment_before_function, comment_in_function, function = extract_comment_and_function(node, code)
# 处理结果
result.append({
'comment_before_function': comment_before_function,
'comment_in_function': comment_in_function,
'function': function
})
for child in reversed(node.children):
stack.append(child)
return result
# 深度优先遍历节点的全部孩子节点
def traverse_children(node, code):
if node is None:
return ''
comment = ''
if node.type == 'comment':
comment += extract_node_information(node, code)
for child in node.children:
comment += traverse_children(child, code)
return comment
# 提取节点信息
def extract_node_information(node, code):
try:
start_row, start_col = node.start_point
end_row, end_col = node.end_point
# 将源代码按行进行拆分
code_lines = code.split('\n')
# 如果起始行和结束行在同一行
if start_row == end_row:
extracted_code = code_lines[start_row][start_col:end_col]
else:
# 提取起始行到结束行中的内容
extracted_code = code_lines[start_row][start_col:]
for i in range(start_row + 1, end_row):
extracted_code += code_lines[i] + '\n'
extracted_code += code_lines[end_row][:end_col]
return extracted_code
except AttributeError as e:
return ''
# 查找文件夹中的.c和.cpp文件
def get_c_files(folder):
c_files = []
for root, dirs, files in os.walk(folder):
for file in files:
if re.search(r'\.c.txt$', file):
c_files.append(os.path.join(root, file))
return c_files
# 处理文件夹中的.c和.cpp文件
def pipeline(folder_path):
c_files = get_c_files(folder_path)
functions = []
count = 0
for c_file in c_files:
count += 1
print(str(count) + ": " + c_file)
try:
with open(c_file, 'r', encoding='utf8') as file:
code = file.read()
tree = parser.parse(bytes(code, 'utf8'))
root_node = tree.root_node
functions.extend(extract_code_information(root_node, code))
except UnicodeDecodeError as e:
print("UnicodeDecodeError!")
# 将结果保存在functions.json中
with open('functions.json', 'w', encoding='utf8') as json_file:
json.dump(functions, json_file, indent=4, ensure_ascii=False)
if __name__ == '__main__':
folder_path = 'C:\\Users\\86139\\Desktop\\pythonProject3\\venv'
pipeline(folder_path)