CodeFuse-MFTCoder 提升 CodeGeeX2-6B 代码能力

CodeFuse-MFTCoder 项目地址:
GitHub - codefuse-ai/MFTCoder: High Accuracy and efficiency multi-task fine-tuning framework for Code LLMs

CodeFuse-Codegeex2-6B 模型地址:
https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeeX2-6B https://huggingface.co/codefuse-ai/CodeFuse-CodeGeeX2-6B
 

一、CodeGeeX2-6B底座代码能力总览

CodeGeeX2-6B 是由智普AI开源的代码大模型。它是在自然语言大模型ChatGLM2-6B的基础上,将GLM中双向attention的部分变成单向以后(该结论由笔者分析CodeGeeX2-6B GitHub issue讨论得出),加入大量代码相关数据进行了Causal Language Model的加训,最终获取的代码大模型。

相比于上一个版本的CodeGeeX-13B,尽管模型变小了,但在多项指标上取得了性能的提升。相比于之前开源界最好的代码大模型Starcoder-15B,也有10%的性能提升。详细的HumanEval-X数据集中五种语言的评测结果如下(摘自官方github repo):

Model

Python

C++

Java

JavaScript

Go

Overall

CodeGen-16B-multi

19.2

18.1

15.0

18.4

13.0

CodeGeeX-13B

22.9

17.1

20.0

17.6

14.4

Replit-code-v1-3B

22.0

20.1

20.1

20.1

12.2

CodeGen2.5-7B-multi

30.6

24.3

29.0

27.5

18.9

StarCoder-15B

35.5

28.2

31.5

33.2

21.3

CodeGeeX2-6B

35.9

29.3

30.8

32.2

22.5


CodeGeeX2的成功,反应出在自然语言大模型底座上,进一步加训代码,也许比起从零开始训练一个代码模型,能够取得更好的结果。毕竟从课程学习(Curriculum Learning)的角度,学好中英文,再学习代码,会更加简单且能够学得更好。在CodeGeeX2之后,CodeLLama也用类似的思路,再一次刷新了开源模型在HumanEvalX榜单上的成绩。
 

二、CodeFuse-MFTCoder 多任务微调CodeGeeX2-6B

借助CodeFuse-MFTCoder的多任务微调能力,我们可以使用多个代码任务数据集对CodeGeex2-6B进行多任务微调(MFT)。由于CodeGeex2已经适合单向的Causal Language Model的模式,因此训练采用和GPT模型一样的微调形式。在任务选择上,我们精选了3个核心代码任务数据,即代码补全(Code Completion),代码生成(Text2Code), 单测生成(Unittest Generation)一共60w条指令问答数据。该数据组合包含代码生成的三个基础任务,用基础任务微调对齐过的模型,在各类未训练过的代码任务上也有不错的泛化能力。

由于CodeGeex2-6B参数量不大,训练采用MFTCoder的多任务LoRA微调模式而不是QLoRA,且代码任务属于相对复杂任务,我们对更多的模块进行微调,包括Attention和MLP,相应的配置如下:

对以上数据进行了约2个Epoch的训练到收敛。训练过程loss情况如下图所示:


通过多任务微调,CodeFuse-CodeGeex2-6B的各方面代码能力均有比较大的提升。
 

三、CodeFuse-CodeGeeX2-6B模型效果

接下来看一些CodeGeeX2-6B-MFT微调后的模型的生成效果。在代码补全/代码生成/代码修复/SQL生成/单元测试生成等代码任务上有较好表现,更丰富的生成样例,请看下面详细展示。
MFT后的CodeGeex2-6b有更好的代码生成能力。在五种编程语言的代码补全测试集HumanEval-x上进行了相关评测(见表2),测试结果显示与CodeGen-16B、CodeGeeX-13B、StarCoder-15B等模型相比,微调后的CodeFuse-CodeGeex2-6B在Java/Python/Cpp/JavaScript均是Top1,相对于底座平均提升8%+,在HumanEval-Python和HumanEval-Java上pass@1均能达到45%+,在同系列的6B代码大模型上表现突出。

表2:CodeFuse-CodeGeex2-6B和其他模型代码补全能力对比

Model

Python

C++

Java

JavaScript

Go

Overall

CodeFuse-CodeGeeX2-6B

45.12

37.2

45.73

37.2

28.05

38.7

CodeGen-16B-multi

19.2

18.1

15.0

18.4

13.0

16.7

CodeGeeX-13B

22.9

17.1

20.0

17.6

14.4

18.4

Replit-code-v1-3B

22.0

20.1

20.1

20.1

12.2

18.9

CodeGen2.5-7B-multi

30.6

24.3

29.0

27.5

18.9

26.1

StarCoder-15B

35.5

28.2

31.5

33.2

21.3

29.9

CodeGeeX2-6B

35.9

29.3

30.8

32.2

22.5

30.1


代码生成(Code Completion / Text2Code)

以下是多语言代码补全 & Text2Code任务上的具体样例:
代码补全(Code Completion )

Python

Question:

from typing import List, Tuple def rolling_max(numbers: List[int]) -> List[int]: 
    """ From a given list of integers, generate a list of rolling maximum element 
    found until given moment in the sequence. 
    >>> rolling_max([1, 2, 3, 2, 3, 4, 2]) [1, 2, 3, 3, 3, 4, 4] """

Answer:

    max_numbers = [] 
    for i in range(len(numbers)): 
        if i == 0: 
            max_numbers.append(numbers[i]) 
        else: 
            max_numbers.append(max(numbers[i], max_numbers[i-1])) 
    return max_numbers 

Question:

    from typing import List 
def parse_music(music_string: str) -> List[int]: 
    """ Input to this function is a string representing musical notes in a special ASCII format. Your task is to parse this string and return list of integers corresponding to how many beats does each not last. Here is a legend: 'o' - whole note, lasts four beats 'o|' - half note, lasts two beats '.|' - quater note, lasts one beat >>> parse_music('o o| .| o| o| .| .| .| .| o o') [4, 2, 1, 2, 2, 1, 1, 1, 1, 4, 4] """ 

Answer:

    notes = music_string.split() 
    parsed_notes = [] 
    for note in notes: 
        if note == 'o': 
            parsed_notes.append(4) 
        elif note == 'o|': 
            parsed_notes.append(2) 
        elif note == '.|': 
            parsed_notes.append(1) 
    return parsed_notes 

C++

Question:

/* Input are two strings a and b consisting only of 1s and 0s. Perform binary XOR on these inputs and return result also as a string. >>> string_xor("010", "110") "100" */ 
#include<stdio.h> #include<string> using namespace std; 
string string_xor(string a,string b){ 

Answer:

int n = a.size(); 
string c = ""; 
for(int i=0;i<n;i++){ 
    if(a[i]==b[i]){ 
        c+='0'; 
    } 
    else{ 
        c+='1'; 
    } 
} 
return c; 
} 

Question:

/* Given a string 'text", return its md5 hash equivalent string. If 'text" is an empty string, return None. >>> string_to_md5("Hello world") == "3e25960a79dbc69b674cd4ec67a72c62" */ 
#include<stdio.h> #include<string> 
#include<openssl/md5.h> 
using namespace std; 
string string_to_md5(string text){ 

Answer:

if(text.empty()){ 
    return "None"; 
} 
MD5_CTX md5; 
MD5_Init(&md5);
MD5_Update(&md5, text.c_str(), text.length());
unsigned char md5_hash[MD5_DIGEST_LENGTH]; MD5_Final(md5_hash, &md5); 
char md5_hash_str[MD5_DIGEST_LENGTH*2+1]; 
for(int i=0; i<MD5_DIGEST_LENGTH; i++){
    sprintf(md5_hash_str+i*2, "%02x", md5_hash[i]); 
} 
return md5_hash_str; 
}

Java

Question:

import java.util.*; 
import java.lang.*; 
class Solution { /** From a given list of integers, generate a list of rolling maximum element found until given moment in the sequence. >>> rollingMax(Arrays.asList(1, 2, 3, 2, 3, 4, 2)) [1, 2, 3, 3, 3, 4, 4] */ 
    public List<Integer> rollingMax(List<Integer> numbers) { 

Answer:

    List<Integer> result = new ArrayList<>(); 
    int max = Integer.MIN_VALUE; 
    for (int i = 0; i < numbers.size(); i++) { 
        if (numbers.get(i) > max) { 
            max = numbers.get(i); 
        } 
        result.add(max); 
    } 
    return result; } 
} 

Question:

import java.util.*; 
import java.lang.*; 
class Solution { 
    /** Given a string, find out how many distinct characters (regardless of case) does it consist of >>> countDistinctCharacters("xyzXYZ") 3 >>> countDistinctCharacters("Jerry") 4 */ 
    public int countDistinctCharacters(String string) { 

Answer:

    Set<Character> distinctChars = new HashSet<>(); 
    for (char c : string.toCharArray()) { 
        distinctChars.add(Character.toUpperCase(c)); 
    } 
    return distinctChars.size(); } 
} 

JavaScript

Question:

/* For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list. Empty sum should be equal to 0 and empty product should be equal to 1. >>> sumProduct([]) (0, 1) >>> sumProduct([1, 2, 3, 4]) (10, 24) */ 
const sumProduct = (numbers, int) => { 
  

Answer:

let sum = 0; 
  let product = 1; 
  for (let i = 0; i < numbers.length; i++) { 
    sum += numbers[i]; product *= numbers[i]; 
  } 
  return [sum, product]; 
}

Question:

/*Given a list of positive integers x. return a sorted list of all elements that hasn't any even digit. Note: Returned list should be sorted in increasing order. For example: >>> uniqueDigits([15, 33, 1422, 1]) [1, 15, 33] >>> uniqueDigits([152, 323, 1422, 10]) [] */ 
const uniqueDigits = (x) => { 

Answer:

let result = []; 
for (let i = 0; i < x.length; i++) { 
  let str = x[i].toString(); 
  let isEven = false; 
  for (let j = 0; j < str.length; j++) { 
    if (Number(str[j]) % 2 === 0) { 
      isEven = true; break; } 
  } 
  if (!isEven) { 
    result.push(x[i]); } 
} 
return result.sort((a, b) => a - b); 
}; 

Go

Question:

import ( "fmt" ) 
// Input are two strings a and b consisting only of 1s and 0s. // Perform binary XOR on these inputs and return result also as a string. // >>> StringXor('010', '110') // '100' func StringXor(a string, b string) string { 

Answer:

result := "" 
for i := 0; i < len(a); i++ { 
     if a[i] != b[i] { 
         result += "1" 
     } 
     else { 
         result += "0" } 
 } 
return result 
} 

Question:

import ( "crypto/md5" "fmt" ) 
// Given a string 'text', return its md5 hash equivalent string. // If 'text' is an empty string, return nil. // 
// >>> StringToMd5('Hello world') == '3e25960a79dbc69b674cd4ec67a72c62' func StringToMd5(text string) interface{} { 

Answer:

 if text == "" { 
     return nil 
 } 
hash := md5.Sum([]byte(text)) 
return fmt.Sprintf("%x", hash) 
}

Text2Code 测试

摘录一些Python测试cases,模型不论在中文还是英文生成代码指令表现都比较优秀,摘录部分cases如下

Python-中文

Question:

"编写一个python函数,统计A出现的个数"

Answer:

def count_A(s):
    return s.count('A')

Python-英文

Question:

"Write a python function to remove first and last occurrence of a given character from the string."

Answer:

def remove_Occ(s, c): if c in s: return s.replace(c, '') else: return s 

You can test this function with the following code

assert remove_Occ("hello","l") == "heo" assert remove_Occ("abcda","a") == "bcd" assert remove_Occ("PHP","P") == "H"

In the first test, "l" is the first and last character in "hello", so the function removes "l" and returns "heo". In the second test, "a" is not the first or last character in "abcda", so the function returns "abcda". In the third test, "P" is the first and last character in "PHP", so the function removes "P" and returns "H".

单元测试生成(Unittest Generation)

模型可以根据prompt生成函数对应的测试单元,以下是生成单元测试的例子,这些单测是经过人工验证通过的cases:

Python单元测试生成

Question:

为以下Python代码生成单元测试
def words_in_sentence(sentence):
    new_lst = []
    for word in sentence.split():
        flg = 0
        if len(word) == 1:
            flg = 1
        for i in range(2, len(word)):
            if len(word)%i == 0:
                flg = 1
        if flg == 0 or len(word) == 2:
            new_lst.append(word)
    return " ".join(new_lst)

Answer:

import unittest

class TestWordsInSentence(unittest.TestCase):
    
    def test_empty_sentence(self):
        # 测试空句子
        sentence = ""
        expected_output = ""
        self.assertEqual(words_in_sentence(sentence), expected_output)
    
    def test_single_word(self):
        # 测试只有一个单词的句子
        sentence = "hello"
        expected_output = "hello"
        self.assertEqual(words_in_sentence(sentence), expected_output)
    
    def test_multiple_words(self):
        # 测试多个单词的句子
        sentence = "hello world"
        expected_output = "hello world"
        self.assertEqual(words_in_sentence(sentence), expected_output)

Question:

为以下Python代码生成单元测试
def pairs_sum_to_zero(l):
    for i, l1 in enumerate(l):
        for j in range(i + 1, len(l)):
            if l1 + l[j] == 0:
                return True
    return False

Answer:

import unittest

class TestPairsSumToZero(unittest.TestCase):
    def test_empty_list(self):
        # 测试空列表
        self.assertFalse(pairs_sum_to_zero([]))

    def test_one_element_list(self):
        # 测试只有一个元素的列表
        self.assertFalse(pairs_sum_to_zero([1]))

    def test_pairs_sum_to_zero(self):
        # 测试列表中有两个元素,使得和为0
        self.assertTrue(pairs_sum_to_zero([1, -1]))

代码修复(Excpetion Retry)

除了上述代码补全/代码翻译等代码生成任务之外,微调后的MFT模型在代码修复上表现也相当优秀,而且需要注意的是模型本身没有用代码修复相关的训练数据。为了评测模型在代码修复上的结果,在写代码期间,遇到了jsonl编码问题导致写入json文件出错,由此尝试了codegeex2-6b微调模型的代码修复能力,确实结果惊人,模型的回答解决了编码出错问题,以下cases来自真实经历,具体cases展示如下:

Question:

"写入jsonl文件时,出现下述错误 UnicodeEncodeError: 'ascii' codec can't encode characters in position 273-285: ordinal not in range(128),请问要怎么解决? "

Answer:

在写入jsonl文件时,如果遇到UnicodeEncodeError的问题,可能是因为文件中包含了一些特殊字符,例如非ASCII字符。

解决这个问题的方法是,在写入jsonl文件之前,先对文件进行编码,将非ASCII字符转换为ASCII字符。

以下是一个Python示例代码,演示了如何对文件进行编码:

import codecs
import json

def encode_file(file_path):
    with codecs.open(file_path, 'r', 'utf-8') as f:
        lines = f.readlines()
    
    with codecs.open(file_path, 'w', 'ascii', 'ignore') as f:
        for line in lines:
            f.write(line)

encode_file('file.txt')

在这个代码中,我们首先打开文件,然后读取所有行。然后,我们打开一个新的文件,将所有行写入其中,并在写入之前进行编码。

注意,这个代码假设你的文件是UTF-8编码的。如果你的文件不是UTF-8编码的,你需要修改'utf-8'为你的文件的实际编码。

另外,如果你不想忽略非ASCII字符,而是想直接抛出错误,你可以将'ignore'改为'strict'。

这个代码将文件中的所有非ASCII字符转换为ASCII字符,然后写入一个新的文件中。

SQL生成(Text2Sql)

虽然模型微调数据并不包含SQL数据,在这里对模型生成SQL能力也进行了评测,旨在评测模型在其他代码相关任务上的泛化能力,评测结果说明模型在简单SQL生成(包括MAX/MIN/</>/=等基础SQL语法)上表现不错,但是在多表JOIN ON、 GROUP BY等复杂SQL语法上,并不能生成编译正确的SQL语句,以下是一些具体的SQL生成cases:

图片

从上述展示结果来看,微调后的codegeex2-6b-mft模型在SQL生成上也有不错的效果,而且SQL数据并不在微调训练集内,这个结果也能说明模型在代码生成相关任务具有一定的泛化能力,接下来,尝试几个复杂的SQL语句,发现微调后的模型不能生成正确的SQL语句,这也是在预期内,毕竟没有用SQL的数据做特殊的微调,所以模型理解不了复杂的SQL任务,具体案例展示如下所示:

图片

其他示例

除了代码相关指令,对微调后的模型在中英文翻译,基础QA问答等任务上进行了评测,由于基座模型codegeex2-6b本身也是代码模型,微调的数据又是代码数据,所以模型在日常问答上的表现有所欠缺,这也是在预期范围内的表现,以下是一些日常对话问答的例子,未经过筛选:

Q: 请你写一个大熊猫和小兔子之间的故事,字数限制在50字以内

A: 一只大熊猫和一只小兔子在森林里相遇了。

大熊猫对小兔子说:“小兔子,你和我一起生存吧,我们成为好朋友!”

小兔子说:“好呀好!”

大熊猫和兔子一起生活了一段时间,他们经常一起爬山,一起玩。

有一天,大熊猫问小兔子:“你打算什么时候生儿育女呢?”

小兔子说:“我也不知道,也许在未来的某一天吧。”

大熊猫说:“那好,我先去种树,你先去生儿育女吧。”

小兔子说:“好!”

大熊猫种了很多树,然后他死了。

小兔子生了一只小熊猫,然后他也死了。

大熊猫的孙子们和兔子的小子们继续生活,但是大熊猫和兔子再也没有在一起过。

故事结束。

Q:把李白的《静夜思》翻译成英文

A:

ure, I can help you translate the poem "静夜思" into English.

Here is the translation:

"When the moonlight is at its brightest,

When the moonlight is at its brightest,

When the moonlight is at its brightest,

I'll be dreaming of a blue moonlight.

When the moonlight is at its brightest,

When the moonlight is at its brightest,

When the moonlight is at its brightest,

I'll be dreaming of a blue moonlight."

Q:给自己许久不见的老师写一封邮件,问候他的情况

A:

I'm sorry to hear that. I hope everything is okay now. Please let me know if there's anything else I can do to help.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值