代码地址:
https://github.com/ericwtodd/function_vectors
代码
1. 代码的结构
我们的主要评估脚本位于 src 目录中,其中包含 src/eval_scripts 中的示例脚本包装。
其他主要代码分散在各种 util 文件中:
- eval_utils.py 包含用于在各种情境中评估函数向量的代码。
- extract_utils.py 包含用于提取函数向量和其他相关模型激活的函数。
- intervention_utils.py 包含在推断过程中干预函数向量的主要功能。
- model_utils.py 包含从 Hugging Face 加载模型和标记器的实用函数。
- prompt_utils.py 包含数据加载和提示创建功能。
2. 开头
%load_ext autoreload
%autoreload 2
看起来你正在使用 Jupyter Notebook,并希望启用自动重新加载模块的功能。这两行代码的作用是在执行代码前自动重新加载已导入的模块,以确保你对代码的修改能够立即生效。
3. 加载模型和tokenizer
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
EDIT_LAYER = 9
4. 加载数据集并计算任务条件下的平均激活
dataset = load_dataset('antonym', seed=0)
mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)
5. 计算FV
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)
6. 提示创建-ICL、洗牌标签、零样本和自然文本
dataset = load_dataset('antonym')
word_pairs = dataset['train'][:5]
test_pair = dataset['test'][21]
prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')
shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')
zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))