Mistral AI has released Codestral, a new 22B-parameter code model.
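Codestral-22B-v0.1 is trained on a diverse dataset covering more than 80 programming languages, including the most popular ones such as Python, Java, C, C++, JavaScript, and Bash (see the blog post for more details). The model can be queried: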
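- As instruct, for example to answer any question about a code snippet (write documentation, explain, refactor) or to generate code following specific instructions
- As Fill-in-the-Middle (FIM), to predict the middle tokens between a prefix and a suffix (very useful for software development add-ons, such as VS Code extensions)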
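To make the FIM mode concrete: an editor add-on typically takes the text before the cursor as the prefix and the text after it as the suffix, then asks the model for what belongs in between. The sketch below shows only that splitting step in plain Python; the helper name `split_at_cursor` and the use of a character offset for the cursor are illustrative assumptions, not part of any Mistral API.

```python
def split_at_cursor(source: str, cursor: int) -> tuple[str, str]:
    """Split source text at a character offset into (prefix, suffix).

    Hypothetical helper: an editor add-on would send the prefix and suffix
    to a FIM model and insert the predicted middle at the cursor position.
    """
    if not 0 <= cursor <= len(source):
        raise ValueError("cursor must lie within the source text")
    return source[:cursor], source[cursor:]


# The cursor sits right after "def add(", where the parameters are missing.
source = "def add(\n    return sum"
prefix, suffix = split_at_cursor(source, source.index("\n"))
print(repr(prefix))  # 'def add('
print(repr(suffix))  # '\n    return sum'
```

Concatenating `prefix`, the model's middle, and `suffix` reproduces the completed file.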
Installation
pip install mistral_inference
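The Fill-in-the-middle example further down requires mistral_common 1.2 or newer. A minimal way to check the installed version without importing the package is to read its metadata with the standard library, as sketched below. The helper names are hypothetical, and the version parsing is deliberately simplified: it compares only the leading numeric components and ignores pre-release tags, so `packaging.version.Version` is the more robust choice if that package is available.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str) -> tuple[int, ...]:
    """Convert the leading numeric part of a version string, e.g. '1.2.3rc1' -> (1, 2, 3)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    return tuple(int(part) for part in match.group().split("."))


def has_min_version(package: str, minimum: str = "1.2") -> bool:
    """Return True if `package` is installed at `minimum` or newer (simplified comparison)."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return False
    # Tuple comparison: (1, 10) >= (1, 2) and (1, 2, 0) >= (1, 2) are both True.
    return release_tuple(installed) >= release_tuple(minimum)


if __name__ == "__main__":
    print("mistral_common >= 1.2:", has_min_version("mistral_common"))
```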
Download
from huggingface_hub import snapshot_download
from pathlib import Path
mistral_models_path = Path.home().joinpath('mistral_models', 'Codestral-22B-v0.1')
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(repo_id="mistralai/Codestral-22B-v0.1", allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], local_dir=mistral_models_path)
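Because `allow_patterns` restricts the download to three files, a quick sanity check is to confirm that all of them landed in the target directory before loading the model. This is a minimal sketch using only `pathlib`; the helper name `missing_model_files` is hypothetical, and the file names are the ones listed in the `snapshot_download` call above.

```python
from pathlib import Path

# File names taken from the allow_patterns list passed to snapshot_download.
EXPECTED_FILES = ("params.json", "consolidated.safetensors", "tokenizer.model.v3")


def missing_model_files(model_dir: Path) -> list[str]:
    """Return the expected Codestral files that are absent from model_dir."""
    return [name for name in EXPECTED_FILES if not (model_dir / name).is_file()]


if __name__ == "__main__":
    model_dir = Path.home() / "mistral_models" / "Codestral-22B-v0.1"
    missing = missing_model_files(model_dir)
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All model files present in", model_dir)
```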
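Chat
After installing mistral_inference, the mistral-chat command-line tool is available in your environment. Start an interactive session with: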
mistral-chat $HOME/mistral_models/Codestral-22B-v0.1 --instruct --max_tokens 256
For example, given the prompt "Write me a function that computes fibonacci in Rust", the model responds:
Sure, here's a simple implementation of a function that computes the Fibonacci sequence in Rust. This function takes an integer `n` as an argument and returns the `n`th Fibonacci number.
fn fibonacci(n: u32) -> u32 {
    match n {
        0 => 0,
        1 => 1,
        _ => fibonacci(n - 1) + fibonacci(n - 2),
    }
}

fn main() {
    let n = 10;
    println!("The {}th Fibonacci number is: {}", n, fibonacci(n));
}
This function uses recursion to calculate the Fibonacci number. However, it's not the most efficient solution because it performs a lot of redundant calculations. A more efficient solution would use a loop to iteratively calculate the Fibonacci numbers.
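The loop-based alternative mentioned in the answer can be sketched as follows. It is written in Python to match the other examples here rather than in Rust, and it keeps the recursive version's convention (fibonacci(0) = 0, fibonacci(1) = 1). Each value is computed once, so the running time grows linearly with n instead of exponentially.

```python
def fibonacci(n: int) -> int:
    """Return the n-th Fibonacci number iteratively, with fibonacci(0) == 0."""
    if n < 0:
        raise ValueError("n must be non-negative")
    a, b = 0, 1  # invariant: a = F(i), b = F(i + 1), starting at i = 0
    for _ in range(n):
        a, b = b, a + b
    return a


n = 10
print(f"The {n}th Fibonacci number is: {fibonacci(n)}")  # The 10th Fibonacci number is: 55
```

One further difference from the Rust version: Python integers are unbounded, whereas a `u32` result overflows for n greater than 47.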
Fill-in-the-middle (FIM)
After installing mistral_inference, run pip install --upgrade mistral_common to make sure mistral_common>=1.2 is installed:
from pathlib import Path
from mistral_inference.model import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.instruct.request import FIMRequest
# Load the tokenizer and the weights downloaded earlier to ~/mistral_models/Codestral-22B-v0.1
tokenizer = MistralTokenizer.v3()
model = Transformer.from_folder(Path.home() / "mistral_models" / "Codestral-22B-v0.1")
# Code before and after the gap the model should fill
prefix = """def add("""
suffix = """    return sum"""
request = FIMRequest(prompt=prefix, suffix=suffix)
tokens = tokenizer.encode_fim(request).tokens
# Greedy decoding (temperature 0.0), stopping at the end-of-sequence token
out_tokens, _ = generate([tokens], model, max_tokens=256, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
result = tokenizer.decode(out_tokens[0])
# Keep only the generated middle, dropping any repeated suffix
middle = result.split(suffix)[0].strip()
print(middle)
Result
num1, num2):
    # Add two numbers
    sum = num1 + num2
    # return the sum
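The FIM post-processing keeps everything the model produced before the suffix (in case the model repeats it) and trims surrounding whitespace. To turn that middle back into a complete function, you concatenate prefix, middle, and suffix; because `.strip()` also removes the trailing newline, one has to be re-inserted before the suffix. The sketch below replays these steps on the strings from this example without loading the model. The `generated` string is a hand-written stand-in for decoded model output, not a captured log, and the suffix assumes four-space indentation.

```python
prefix = "def add("
suffix = "    return sum"

# Hypothetical stand-in for the decoded model output: the middle,
# followed by a repeat of the suffix that the model may also emit.
generated = (
    "num1, num2):\n"
    "    # Add two numbers\n"
    "    sum = num1 + num2\n"
    "    # return the sum\n"
    "    return sum\n"
)

# Keep the text before the suffix and trim surrounding whitespace.
middle = generated.split(suffix)[0].strip()

# .strip() dropped the trailing newline, so add it back before the suffix.
completed = prefix + middle + "\n" + suffix
print(completed)

namespace: dict = {}
exec(completed, namespace)  # define add() from the reassembled source
print(namespace["add"](2, 3))  # 5
```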