import torch
import math
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import ModuleList
from torch.nn.modules.normalization import LayerNorm
import numpy as np
import os
from tqdm import tqdm_notebook, trange
import logging
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger()
Each decoder block contains two operations, masked self-attention and feed-forward, and each of them performs two linear projections.
Inside the attention module, the input embedding is first passed through a Conv1D that projects the last dimension from d_model to 3 × d_model (for q, k and v):
self.c_attn = Conv1D(d_model, d_model*3)
After the attention computation, one more projection is applied at the end:
self.c_proj = Conv1D(d_model, d_model)
Linear Projection
class Conv1D(nn.Module):
def __init__(self, nx, nf):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
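A quick, illustrative shape check (random input, purely to show the dimensions): Conv1D(d_model, d_model*3) expands the last dimension, and split then recovers q, k and v.
x = torch.rand(1, 4, 768)            # [batch, seq_len, d_model]
c_attn = Conv1D(768, 768*3)          # project d_model -> 3*d_model
qkv = c_attn(x)                      # [1, 4, 2304]
q, k, v = qkv.split(768, dim=-1)     # three tensors of shape [1, 4, 768]
print(qkv.shape, q.shape, k.shape, v.shape)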
Feed Forward
In the feed-forward block, the input is first projected from d_model to 4 × d_model and then projected back from 4 × d_model to d_model.
class FeedForward(nn.Module):
def __init__(self, dropout, d_model=768, nx=768*4):
super().__init__()
self.c_fc = Conv1D(d_model, nx)
self.c_proj = Conv1D(nx, d_model)
self.act = F.gelu
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.c_proj(self.act(self.c_fc(x))))
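A small illustrative check of that emb -> 4*emb -> emb expansion (random input, shapes only):
ffn = FeedForward(dropout=0.1, d_model=768, nx=768*4)
x = torch.rand(1, 4, 768)
print(ffn.c_fc(x).shape)   # torch.Size([1, 4, 3072]) -- expanded hidden size
print(ffn(x).shape)        # torch.Size([1, 4, 768])  -- projected back to d_model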
Masked Self Attention
class Attention(nn.Module):
def __init__(self, d_model=768, n_head=12, n_ctx=1024, d_head=64, bias=True, scale=False):
super().__init__()
self.n_head = n_head
self.d_model = d_model
self.c_attn = Conv1D(d_model, d_model*3)
self.scale = scale
self.softmax = nn.Softmax(dim=-1)
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.dropout = nn.Dropout(0.1)
self.c_proj = Conv1D(d_model, d_model)
def split_heads(self, x):
"return shape [`batch`, `head`, `sequence`, `features`]"
new_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
x = x.view(*new_shape)
return x.permute(0, 2, 1, 3)
def _attn(self, q, k, v, attn_mask=None):
scores = torch.matmul(q, k.transpose(-2, -1))
if self.scale: scores = scores/math.sqrt(v.size(-1))
nd, ns = scores.size(-2), scores.size(-1)
# apply the causal mask so each position can only attend to itself and earlier positions
causal_mask = self.bias[:, :, ns-nd:ns, :ns]
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
if attn_mask is not None: scores = scores + attn_mask
scores = self.softmax(scores)
scores = self.dropout(scores)
outputs = torch.matmul(scores, v)
return outputs
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_shape = x.size()[:-2] + (x.size(-2)*x.size(-1),)
return x.view(*new_shape)
def forward(self, x):
x = self.c_attn(x) # new `x` shape - [batch, seq_len, 3*d_model], e.g. `[1,3,2304]`
q, k, v = x.split(self.d_model, dim=2)
q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
out = self._attn(q, k, v)
out = self.merge_heads(out)
out = self.c_proj(out)
return out
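The registered bias buffer is what makes this self-attention "masked": a lower-triangular 0/1 matrix, so position t can only attend to positions <= t. A quick illustrative peek at its top-left corner and at the module's output shape (random input):
attn = Attention(d_model=768, n_head=12, n_ctx=1024, d_head=64, bias=True, scale=False)
print(attn.bias[0, 0, :5, :5])            # lower-triangular causal mask
print(attn(torch.rand(1, 4, 768)).shape)  # torch.Size([1, 4, 768])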
Decoder Block
class TransformerBlock(nn.Module):
def __init__(self, d_model=768, n_head=12, dropout=0.1):
super(TransformerBlock, self).__init__()
self.attn = Attention(d_model=d_model, n_head=n_head, d_head=64, n_ctx=1024, bias=True, scale=False)
self.feedforward = FeedForward(dropout=dropout, d_model=d_model, nx=d_model*4)
self.ln_1 = LayerNorm(d_model)
self.ln_2 = LayerNorm(d_model)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.feedforward(self.ln_2(x))
return x
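Note the pre-LayerNorm residual structure: each sub-layer sees a normalized input and its output is added back to the original x, so the block preserves the [batch, seq_len, d_model] shape. A quick illustrative check (random input):
block = TransformerBlock(d_model=768, n_head=12, dropout=0.1)
print(block(torch.rand(1, 4, 768)).shape)   # torch.Size([1, 4, 768])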
GPT2 architecture
def _get_clones(module, n):
return ModuleList([copy.deepcopy(module) for i in range(n)])
class GPT2(nn.Module):
def __init__(self, nlayers=12, n_ctx=1024, d_model=768, vcb_sz=50257):
super(GPT2, self).__init__()
self.nlayers = nlayers
block = TransformerBlock(d_model=d_model, n_head=12, dropout=0.1)
self.h = _get_clones(block, nlayers)
self.wte = nn.Embedding(vcb_sz, d_model)
self.wpe = nn.Embedding(n_ctx, d_model)
self.drop = nn.Dropout(0.1)
self.ln_f = LayerNorm(d_model)
self.out = nn.Linear(d_model, vcb_sz, bias=False)
self.loss_fn = nn.CrossEntropyLoss()
self.init_weights()
def init_weights(self):
self.out.weight = self.wte.weight
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, src, labels=None, pos_ids=None):
if pos_ids is None: pos_ids = torch.arange(0, src.size(-1), device=src.device).unsqueeze(0)
inp = self.drop((self.wte(src)+self.wpe(pos_ids)))
for i in range(self.nlayers): inp = self.h[i](inp)
inp = self.ln_f(inp)
logits = self.out(inp)
outputs = (logits,) + (inp,)
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs
return logits
Here, loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).
See the PyTorch documentation for nn.CrossEntropyLoss: when the target contains class indices, the input may take one of three shapes, and N is the first dimension of the input.
Our logits have shape (B, T, C), where B is the batch size, T the sequence length, and C the number of classes, i.e. the vocabulary size (50257), not the embedding dimension. After shifting, shift_logits is (B, T-1, C) and shift_labels is (B, T-1), so both flatten to a first dimension of N = B × (T-1), which is why the target's flattened size matches the first dimension of shift_logits.view(-1, shift_logits.size(-1)).
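A small worked example of that shape bookkeeping (random logits and labels, purely illustrative):
B, T, C = 2, 5, 50257
logits = torch.rand(B, T, C)
labels = torch.randint(0, C, (B, T))
shift_logits = logits[..., :-1, :].contiguous()   # (B, T-1, C)
shift_labels = labels[..., 1:].contiguous()       # (B, T-1)
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, C), shift_labels.view(-1))
print(shift_logits.view(-1, C).shape, shift_labels.view(-1).shape, loss.item())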
Example
model = GPT2()
# load pretrained_weights from hugging face
# download file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin to `.`
model_dict = model.state_dict() #currently with random initialization
state_dict = torch.load("./gpt2-pytorch_model.bin") #pretrained weights
old_keys = []
new_keys = []
for key in state_dict.keys():
if "mlp" in key: #The hugging face state dict references the feedforward network as mlp, need to replace to `feedforward` be able to reuse these weights
new_key = key.replace("mlp", "feedforward")
new_keys.append(new_key)
old_keys.append(key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key]=state_dict.pop(old_key)
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
context = torch.tensor([tokenizer.encode("The planet earth")])
def generate(context, ntok=20):
for _ in range(ntok):
out = model(context)
logits = out[:, -1, :]
indices_to_remove = logits < torch.topk(logits, 10)[0][..., -1, None]
logits[indices_to_remove] = float("-inf")
next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(1)
context = torch.cat([context, next_tok.unsqueeze(-1)], dim=-1)
return context
out = generate(context, ntok=20)
tokenizer.decode(out[0])