Code link: PPLM_code
I. Example command for bag-of-words control:
python run_pplm.py -B military --cond_text "The potato" --length 50 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.03 --window_length 5 --kl_scale 0.01 --gm_scale 0.99 --colorama --sample
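Here -B selects the bag of words ("military"); --stepsize, --num_iterations, and --window_length control how strongly, how many times per token, and over how much of the past the hidden states are perturbed; --kl_scale and --gm_scale control how closely the perturbed distribution is kept to the unperturbed one.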
Code logic (PPLM, bag-of-words model):
1. run_pplm_example():
2. Load the model and the tokenizer.
3. full_text_generation(): returns the generation results, i.e. the unperturbed sentence and the perturbed sentences.
3.1 generate_text_pplm (no perturbation): returns the unperturbed sentence.
3.1.1 for i in range_func: each iteration generates one token (a sketch follows these bullets):
- If the input context has, say, 3 tokens, this step runs the model on only the first 2 to build the hidden-state cache and returns it.
- No perturbation: pert_past = past.
- Pass last and pert_past into the model; it returns pert_logits, pert_past, pert_all_hidden, so past is updated with the information of the current last token.
- Keep the top-k tokens according to pert_logits.
- Take the most probable token as the next last and append it to output_so_far.
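A minimal sketch of this unperturbed loop, assuming the old-style GPT2LMHeadModel call signature used in the code below (model loaded with output_hidden_states enabled) and that output_so_far already holds the encoded prompt; temperature handling is omitted:

past = None
for i in trange(length):
    if past is None and output_so_far is not None:
        last = output_so_far[:, -1:]
        if output_so_far.shape[1] > 1:
            # run all but the last token once, just to build the k/v cache
            _, past, _ = model(output_so_far[:, :-1])
    pert_past = past  # no perturbation in this pass
    pert_logits, past, pert_all_hidden = model(last, past=pert_past)
    pert_probs = F.softmax(top_k_filter(pert_logits[:, -1, :], k=top_k), dim=-1)
    last = torch.argmax(pert_probs, dim=-1).unsqueeze(-1)  # most probable token
    output_so_far = torch.cat((output_so_far, last), dim=1)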
3.2 In the same way, sample and generate 5 perturbed sentences:
3.2.1 generate_text_pplm (with perturbation)
3.2.1.1 for i in range_func: each iteration generates one token:
- If the input context has 3 tokens, this step runs the model on only the first 2 and returns the hidden-state cache.
- Feed in all three tokens to get the unperturbed unpert_logits, unpert_past, unpert_all_hidden.
- accumulated_hidden = unpert_last_hidden[:, :-1, :]: only the first k-1 positions are considered; accumulated_hidden is then summed over them.
- perturb_past(): returns pert_past; its internals are steps 4.1-4.3.10 below (a condensed sketch follows step 4.3.10).
4.1 Initialize grad_accumulator to zeros, with the same shape as each of the 24 blocks' stacked (k, v): (2, 1, 16, 2, 64).
4.2 Initialize a window_mask of the same shape (2, 1, 16, 2, 64).
4.3 Run three rounds of gradient iteration:
4.3.1 Add the gradients accumulated in earlier rounds (grad_accumulator) to past to get perturbed_past; past itself has already been detached and receives no gradient.
4.3.2 Feed the most recently generated token last and perturbed_past into the model to get all_logits and all_hidden.
4.3.3 Take the hidden state produced for last out of all_hidden and add it to accumulated_hidden, detached from the graph.
4.3.4 Take the logits and probs produced for last; shape (1, 50257).
4.3.5 (1, 50257) matrix-multiplied by (50257, 149) gives (1, 149), a linear map onto the bag-of-words tokens; summing (and taking the negative log) gives the loss for emitting a BoW word at the current step (the one-hot matrix construction is sketched after this outline).
4.3.6 Compute the KL loss between unpert_probs and probs.
4.3.7 Add the two losses and backpropagate.
4.3.8 Apply window_mask to the gradient of curr_perturbation (which is just grad_accumulator made differentiable), compute the gradient norms, and scale by the step size to obtain grad.
4.3.9 Add the resulting grad to grad_accumulator.
4.3.10 Detach every tensor of past into new_past and replace past with it, keeping past out of the graph.
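A condensed sketch of one round of step 4.3, adapted from perturb_past in the full code below (decay, the discriminator-loss branch, the horizon term, gradient zeroing, and the SMALL_CONST guards on the division are simplified; the surrounding perturb_past variables are assumed in scope, and shapes assume gpt2-medium):

for _ in range(num_iterations):
    # 4.3.1: make the accumulated perturbation differentiable and add it to past
    curr_perturbation = [
        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
        for p_ in grad_accumulator
    ]
    perturbed_past = list(map(add, past, curr_perturbation))
    # 4.3.2: forward pass with the perturbed k/v cache
    all_logits, _, all_hidden = model(last, past=perturbed_past)
    # 4.3.3: accumulate the newest hidden state, detached from the graph
    new_accumulated_hidden = accumulated_hidden + torch.sum(all_hidden[-1], dim=1).detach()
    # 4.3.4: distribution over the vocabulary, shape (1, 50257)
    probs = F.softmax(all_logits[:, -1, :], dim=-1)
    # 4.3.5: bag-of-words loss, (1, 50257) x (50257, num_bow_words)
    loss = 0.0
    for one_hot_bow in one_hot_bows_vectors:
        bow_probs = torch.mm(probs, torch.t(one_hot_bow))
        loss += -torch.log(torch.sum(bow_probs))
    # 4.3.6: KL between the perturbed and unperturbed distributions
    unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
    loss += kl_scale * (probs * (probs / unpert_probs).log()).sum()
    # 4.3.7: backpropagate; only curr_perturbation requires grad
    loss.backward()
    # 4.3.8: window-masked, norm-scaled gradient step
    grad = [
        -stepsize * (p_.grad * window_mask /
                     (torch.norm(p_.grad * window_mask) + SMALL_CONST) ** gamma
                     ).data.cpu().numpy()
        for p_ in curr_perturbation
    ]
    # 4.3.9: fold the step into the accumulator
    grad_accumulator = list(map(add, grad, grad_accumulator))
    # 4.3.10: rebuild past from detached tensors so it stays out of the graph
    past = [p_.detach() for p_ in past]
# finally, the perturbed cache handed back to the caller
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device)
                    for p_ in grad_accumulator]
pert_past = list(map(add, past, grad_accumulator))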
- Pass last and pert_past into the model; it returns pert_logits, pert_past, pert_all_hidden, so past is updated with the current last token's information.
- Fuse the perturbed and unperturbed probabilities; keep the top-k tokens; rescale; this yields the new distribution over the vocabulary (see the fusion sketch after this list).
- Take the most probable token as the next last.
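The fusion at the second bullet, essentially as in the full code (pert_logits comes from the perturbed pass, unpert_logits from the unperturbed one; with --sample the last line becomes torch.multinomial(pert_probs, num_samples=1)):

pert_probs = F.softmax(pert_logits[:, -1, :], dim=-1)
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
# weighted geometric mean of the two distributions, controlled by gm_scale
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # zero out non-top-k entries
if torch.sum(pert_probs) <= 1:
    pert_probs = pert_probs / torch.sum(pert_probs)  # rescale to a proper distribution
last = torch.argmax(pert_probs, dim=-1).unsqueeze(-1)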
4. Print the unperturbed sentence and the perturbed sentences.
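For reference, the (50257, 149) matrix used in step 4.3.5 is built roughly as follows (cf. get_bag_of_words_indices and build_bows_one_hot_vectors in the full code; words is the word list downloaded from military.txt, and only words that encode to a single GPT-2 token are kept, 149 of them here):

bow_indices = [tokenizer.encode(word.strip(), add_prefix_space=True)
               for word in words]
single_bow = list(filter(lambda x: len(x) <= 1, bow_indices))  # single-token words only
single_bow = torch.tensor(single_bow).to(device)               # shape (num_words, 1)
one_hot_bow = torch.zeros(single_bow.shape[0], tokenizer.vocab_size).to(device)
one_hot_bow.scatter_(1, single_bow, 1)                         # shape (num_words, 50257)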
The code is as follows:
#! /usr/bin/env python3
# coding=utf-8
# Copyright 2018 The Uber AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example command with bag of words:
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
Example command with discriminator:
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""
import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel
from pplm_classification_head import ClassificationHead
PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10
QUIET = 0
REGULAR = 1
VERBOSE = 2
VERY_VERBOSE = 3
VERBOSITY_LEVELS = {
'quiet': QUIET,
'regular': REGULAR,
'verbose': VERBOSE,
'very_verbose': VERY_VERBOSE,
}
BAG_OF_WORDS_ARCHIVE_MAP = {
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt",
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt",
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}
DISCRIMINATOR_MODELS_PARAMS = {
"clickbait": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
"class_size": 2,
"embed_size": 1024,
"class_vocab": {
"non_clickbait": 0, "clickbait": 1},
"default_class": 1,
"pretrained_model": "gpt2-medium",
},
"sentiment": {
# "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        "path": "/home/xps/huanghong/workdir/PPLM-master/cache/SST_classifier_head.pt",
"class_size": 5,
"embed_size": 1024,
"class_vocab": {
"very_positive": 2, "very_negative": 3},
"default_class": 3,
"pretrained_model": "gpt2-medium",
},
}
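# Note: torch.autograd.Variable and its volatile flag are deprecated in modern
# PyTorch (plain tensors have tracked gradients since 0.4); to_var below
# effectively just moves x to the target device and sets requires_grad.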
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
if torch.cuda.is_available() and device == 'cuda':
x = x.cuda()
elif device != 'cuda':
x = x.to(device)
return Variable(x, requires_grad=requires_grad, volatile=volatile)
def top_k_filter(logits, k, probs=False):
"""
    Masks everything but the k top entries as -infinity (-1e10).
Used to mask logits such that e^-infinity -> 0 won't contribute to the
sum of the denominator.
"""
if k == 0:
return logits
else:
values = torch.topk(logits, k)[0]
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
if probs:
return torch.where(logits < batch_mins,
torch.ones_like(logits) * 0.0, logits)
return torch.where(logits < batch_mins,
torch.ones_like(logits) * -BIG_CONST,
logits)
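# Example of top_k_filter (hypothetical values): with k=2 only the two largest
# logits survive; the rest are pushed to -1e10 so they vanish under softmax:
#   top_k_filter(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), k=2)
#   -> tensor([[-1.0000e+10,  3.0000e+00,  2.0000e+00, -1.0000e+10]])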
def perturb_past(
past,
model,
last,
unpert_past=None,
unpert_logits=None,
accumulated_hidden=None,
grad_norms=None,
stepsize=0.01,
one_hot_bows_vectors=None,
classifier=None,
class_label=None,
loss_type=0,
num_iterations=3,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
kl_scale=0.01,
device='cuda',
verbosity_level=REGULAR
):
    # Generate initial perturbed past; each entry has the same shape as one
    # of the 24 blocks' stacked (k, v), e.g. (2, 1, 16, 2, 64)
grad_accumulator = [
(np.zeros(p.shape).astype("float32"))
for p in past
]
if accumulated_hidden is None:
accumulated_hidden = 0
if decay:
decay_mask = torch.arange(
0.,
1.0 + SMALL_CONST,
1.0 / (window_length)
)[1:]
else:
decay_mask = 1.0
    # Generate a mask so that the gradient perturbation is applied only to a
    # window over the past
_, _, _, curr_length, _ = past[0].shape
if curr_length > window_length and window_length > 0:
ones_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([window_length])
+ tuple(past[0].shape[-1:])
)
zeros_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([curr_length - window_length])
+ tuple(past[0].shape[-1:])
)
ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
window_mask = torch.cat(
(ones_mask, torch.zeros(zeros_key_val_shape)),
dim=-2
).to(device)
else:
        window_mask = torch.ones_like(past[0]).to(device)  # shape (2, 1, 16, 2, 64)
    # accumulate perturbations over num_iterations (three rounds in the example)
loss_per_iter = []
    new_accumulated_hidden = None