Data layer
preprocess.py: logic for normalizing the raw text (a hypothetical sketch follows)
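The exact normalization rules aren't shown in these notes; a rough sketch of what such a step typically covers (NFKC full-width→half-width folding, lowercasing, whitespace cleanup), under that assumption:

import unicodedata

def normalize_text(text):
    # hypothetical: the repo's actual rules may differ
    text = unicodedata.normalize('NFKC', text)  # full-width -> half-width, etc.
    text = text.lower()                         # case folding
    return ' '.join(text.split())               # collapse runs of whitespace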
data.py: data handling, including building the DataLoaders and the logic for randomly masking input sequences at training time (see the collate_fn sketch after the class)
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm

# val_size and batch_size are config constants; SamplerDataset wraps a list of
# examples into a torch Dataset (both defined elsewhere in the repo)
class DataModule():
    def __init__(self, file_path):
        self.file_path = file_path
        self.train_dataset = self.val_dataset = self.test_dataset = None
        self.tokenizer = transformers.BertTokenizer('vocab.txt')

    def build_data(self, data):
        # reshape the raw lines --> [{'input': ['..x..'], 'label': [int]}, ..]
        output = []
        for item in tqdm(data):
            out_item = {}
            x, y = item.strip().split('\t')
            # tokenize
            ...
            output.append(out_item)
        return output

    def prepare_data(self, train_path, test_path):
        # load the data; readlines() gives ['..x..\t..y..\n', ..]
        with open(train_path, encoding='utf-8') as f:
            train_lines = f.readlines()
        with open(test_path, encoding='utf-8') as f:
            test_lines = f.readlines()
        self.train_data = self.build_data(train_lines)
        self.test_data = self.build_data(test_lines)

    def setup(self):
        # hold out the last val_size fraction of the training data for validation
        split = int(len(self.train_data) * (1 - val_size))
        self.train_dataset = SamplerDataset(self.train_data[:split])
        self.val_dataset = SamplerDataset(self.train_data[split:])
        self.test_dataset = SamplerDataset(self.test_data)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size, shuffle=False, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size, shuffle=False, collate_fn=self.test_collate_fn)

    def collate_fn(self, batch):
        # pad the batch and apply random masking (sketched below)
        ...

    def test_collate_fn(self, batch):
        # pad only: no masking and no labels at test time
        ...
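A minimal sketch of what the training-time random masking inside collate_fn could look like. The 15% masking rate and the plain [MASK]-replacement rule follow the standard BERT recipe and are assumptions here, as is the item layout ('input' already holding token ids and 'label' a plain int); the repo's exact code may differ.

import random
import torch

MASK_RATE = 0.15  # assumption: standard BERT masking rate

def collate_fn(self, batch):
    # batch: [{'input': [token ids], 'label': int}, ..]
    max_len = max(len(item['input']) for item in batch)
    mask_id = self.tokenizer.mask_token_id
    pad_id = self.tokenizer.pad_token_id
    x_id, x_type, x_mask, y = [], [], [], []
    for item in batch:
        ids = list(item['input'])
        # randomly replace ~15% of the inner tokens with [MASK] (skip [CLS]/[SEP])
        for i in range(1, len(ids) - 1):
            if random.random() < MASK_RATE:
                ids[i] = mask_id
        pad_len = max_len - len(ids)
        x_id.append(ids + [pad_id] * pad_len)
        x_type.append([0] * max_len)                   # single-segment input
        x_mask.append([1] * len(ids) + [0] * pad_len)  # attention mask
        y.append(item['label'])
    return (torch.tensor(x_id), torch.tensor(x_type),
            torch.tensor(x_mask), torch.tensor(y))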
Model layer
model.py: defines the model and the forward logic, producing the training loss
adam.py: implementation of the optimizer, with the fix to how L2 weight_decay is applied (a sketch of the decoupled update follows)
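The weight_decay fix is presumably the AdamW-style decoupling (Loshchilov & Hutter): instead of folding L2 regularization into the gradient before the Adam moments are computed, the decay is applied to the weights directly after the Adam step. A minimal per-parameter sketch under that assumption:

import math
import torch

@torch.no_grad()
def adamw_update(p, grad, exp_avg, exp_avg_sq, step,
                 lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    # ordinary Adam first/second moment estimates
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    # bias-corrected step size
    step_size = lr * math.sqrt(1 - betas[1] ** step) / (1 - betas[0] ** step)
    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-step_size)
    # the fix: weight decay is applied to the weights directly,
    # not added to the gradient the way vanilla L2 regularization is
    p.mul_(1 - lr * weight_decay)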
Training layer
train.py: organizes the model training workflow
import json
import numpy as np
import torch
from torch import nn
from sklearn.metrics import accuracy_score, f1_score

def train_step(model, device, train_loader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (x_id, x_type, x_mask, y) in enumerate(train_loader):
        x_id, x_type, x_mask, y = x_id.to(device), x_type.to(device), x_mask.to(device), y.to(device)
        y_pred = model([x_id, x_type, x_mask])
        optimizer.zero_grad()
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        # log the loss every 100 batches
        if batch_idx % 100 == 0:
            print(f'epoch {epoch} batch {batch_idx} loss {loss.item():.4f}')
def valid_step(model, device, val_loader):
    model.eval()
    valid_loss = 0.0
    val_pre = []
    val_true = []  # [int, ..]
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (x_id, x_type, x_mask, y) in enumerate(val_loader):
        x_id, x_type, x_mask, y = x_id.to(device), x_type.to(device), x_mask.to(device), y.to(device)
        with torch.no_grad():
            y_pre = model([x_id, x_type, x_mask])  # [[c1,c2,c3,c4,c5], ..]
            valid_loss += criterion(y_pre, y).item()
        for item in y.cpu():
            val_true.append(item.item())
        for item in y_pre.cpu():
            val_pre.append(item.argmax(0).item())  # [0.1,0.6,0.3] --> 1
    valid_loss /= len(val_loader)
    # report the average validation loss
    print(f'valid loss {valid_loss:.4f}')
    valid_true = np.array(val_true)
    valid_pre = np.array(val_pre)
    avg_acc = accuracy_score(valid_true, valid_pre)
    avg_f1 = f1_score(valid_true, valid_pre, average='macro')
    # report validation accuracy and macro F1
    print(f'valid acc {avg_acc:.4f} macro-F1 {avg_f1:.4f}')
    return avg_acc, avg_f1, valid_loss
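How train.py wires these steps into an epoch loop is not spelled out above; one plausible sketch, assuming the DataModule from data.py, a Model class from model.py, the AdamW-style optimizer from adam.py, and checkpointing on the best macro-F1 (all of which are assumptions, not the repo's confirmed code):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dm = DataModule(file_path)
dm.prepare_data(train_path, test_path)
dm.setup()
model = Model().to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
best_f1 = 0.0
for epoch in range(epochs):
    train_step(model, device, dm.train_dataloader(), optimizer, epoch)
    acc, f1, loss = valid_step(model, device, dm.val_dataloader())
    if f1 > best_f1:
        # keep the checkpoint that scores best on validation macro-F1
        best_f1 = f1
        torch.save(model.state_dict(), 'best_model.pt')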
def test_saveResult(model, device, test_loader, test_entitys, test_ids, result_path):
    model.eval()
    test_pre = []
    for batch_idx, (x1, x2, x3) in enumerate(test_loader):
        x1, x2, x3 = x1.to(device), x2.to(device), x3.to(device)
        with torch.no_grad():
            y_pre = model([x1, x2, x3])
        batch_pre = y_pre.cpu().numpy()
        for item in batch_pre:
            test_pre.append(int(item.argmax(0)))  # cast to int so json.dumps can serialize it
    result = {}  # {'id': {'entity': label, ..}}
    for id, entity, pre_y in zip(test_ids, test_entitys, test_pre):
        if id in result:
            result[id][entity] = pre_y
        else:
            result[id] = {entity: pre_y}
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write('id result\n')
        for k, v in result.items():
            f.write(str(k) + ' ' + json.dumps(v, ensure_ascii=False) + '\n')
    print(f'Saved results to: {result_path}')
How the BERT source code is organized
import torch
from torch import nn

# vocab_size, embed_dim, max_seq_len and vocab_padding_idx are config constants
class BertLM(nn.Module):
    def __init__(self):
        super().__init__()
        # three embeddings that get summed
        self.tok_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab_padding_idx)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.seg_embed = nn.Embedding(2, embed_dim)
        # a stack of 12 transformer layers (elided in the notes)
        self.layers = nn.ModuleList(...)
        # layer-normalize the summed embeddings
        self.emb_layer_norm = nn.LayerNorm(embed_dim)
        self.embed_dim = embed_dim

    def forward(self, truth, inp, seg, msk, nxt_snt_flag):
        # truth: the original tokens      seg: segment ids from sentence splitting
        # inp: the masked token sequence  msk: bool positions of the masked-out tokens
        pos = torch.arange(inp.size(1), device=inp.device).unsqueeze(0)
        x = self.tok_embed(inp) + self.pos_embed(pos) + self.seg_embed(seg)
        x = self.emb_layer_norm(x)
        # the padding mask comes from the input ids, not the embeddings
        padding_mask = torch.eq(inp, vocab_padding_idx)
        for layer in self.layers:
            x, _, _ = layer(x, padding_mask=padding_mask)
        # pick out the masked positions and compute the loss against the gold tokens
        masked_x = x.masked_select(msk.unsqueeze(-1))
        masked_x = masked_x.view(-1, self.embed_dim)
        gold = truth.masked_select(msk)
        # compute the loss between the masked predictions and gold
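        # The notes stop here; presumably the masked hidden states are projected
        # back onto the vocabulary and scored with cross entropy, with
        # nxt_snt_flag supplying the next-sentence label. The head names below
        # (lm_head, nsp_head) are assumptions, not the repo's actual code;
        # both would be declared in __init__:
        #   self.lm_head  = nn.Linear(embed_dim, vocab_size)
        #   self.nsp_head = nn.Linear(embed_dim, 2)
        scores = self.lm_head(masked_x)                    # (num_masked, vocab_size)
        mlm_loss = nn.functional.cross_entropy(scores, gold)
        nsp_logits = self.nsp_head(x[:, 0])                # [CLS] hidden state
        nsp_loss = nn.functional.cross_entropy(nsp_logits, nxt_snt_flag.long())
        return mlm_loss + nsp_loss                         # total pre-training loss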