Engineering Code Writing Conventions

Data Layer

preprocess.py: the logic for normalizing the raw text
data.py: data handling, including building the DataLoaders and the logic for randomly masking input sequences during training (a masking sketch follows the DataModule skeleton below)

import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm

class DataModule:
	def __init__(self, file_path, batch_size=32, val_size=0.9):
		# batch_size / val_size defaults are placeholders
		self.file_path = file_path
		self.batch_size = batch_size
		self.val_size = val_size
		self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
		self.tokenizer = transformers.BertTokenizer.from_pretrained('vocab.txt')

	def build_data(self, data):
		# arrange the data --> [{'input': ['..x..'], 'label': [int]}, ..]
		output = []
		for item in tqdm(data):
			out_item = {}
			x, y = item.strip().split('\t')
			# tokenize
			...
			output.append(out_item)
		return output

	def prepare_data(self, train_path, test_path):
		# load the data
		# readlines: ['..x..\t..y..\n', ..]
		train_load = open(train_path, encoding='utf-8').readlines()
		test_load = open(test_path, encoding='utf-8').readlines()
		self.train_data = self.build_data(train_load)
		self.test_data = self.build_data(test_load)

	def setup(self):
		split = int(len(self.train_data) * self.val_size)
		self.train_dataset = SamplerDataset(self.train_data[:split])
		self.val_dataset = SamplerDataset(self.train_data[split:])
		self.test_dataset = SamplerDataset(self.test_data)

	def train_dataloader(self):
		return DataLoader(self.train_dataset, self.batch_size, shuffle=True, collate_fn=self.collate_fn)

	def val_dataloader(self):
		return DataLoader(self.val_dataset, self.batch_size, shuffle=False, collate_fn=self.collate_fn)

	def test_dataloader(self):
		return DataLoader(self.test_dataset, self.batch_size, shuffle=False, collate_fn=self.test_collate_fn)

	def collate_fn(self, batch):
		# pad/tensorize a training batch and apply random masking
		...

	def test_collate_fn(self, batch):
		# pad/tensorize a test batch (no labels, no masking)
		...
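
The two collate functions are left open above. As a reference only, here is a minimal sketch of what the training collate_fn with random masking could look like. The 15%/80%/10% masking rates, the padding scheme, and the assumption that build_data already produced token ids are not stated in the post; the output order (x_id, x_type, x_mask, y) is chosen to match how train_step unpacks a batch below.

import random
import torch

def collate_fn(self, batch):
	# batch: [{'input': [token ids], 'label': int}, ..] (assumes build_data already converted text to ids)
	pad_id = self.tokenizer.pad_token_id
	mask_id = self.tokenizer.mask_token_id
	max_len = max(len(item['input']) for item in batch)
	x_id, x_type, x_mask, y = [], [], [], []
	for item in batch:
		ids = list(item['input'])
		# BERT-style random masking (rates are assumptions): pick ~15% of tokens,
		# replace 80% of the picks with [MASK], 10% with a random token, keep the rest
		for i in range(len(ids)):
			if random.random() < 0.15:
				r = random.random()
				if r < 0.8:
					ids[i] = mask_id
				elif r < 0.9:
					ids[i] = random.randrange(self.tokenizer.vocab_size)
		pad = max_len - len(ids)
		x_id.append(ids + [pad_id] * pad)
		x_type.append([0] * max_len)                # single-segment input
		x_mask.append([1] * len(ids) + [0] * pad)   # attention mask over real tokens
		y.append(item['label'])                     # label assumed to be a single class id
	return torch.tensor(x_id), torch.tensor(x_type), torch.tensor(x_mask), torch.tensor(y)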

Model Layer

model.py: defines the model and its forward logic, and produces the training loss (a classifier sketch matching train_step's interface follows below)
adam.py: the optimizer implementation, with the corrected L2 weight_decay method (a sketch of the decoupled update also follows below)
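
model.py itself is not shown in the post. Purely as a reference, here is a minimal sketch of a BERT-based classifier that matches how the model is called in train_step below (a list [x_id, x_type, x_mask] in, class logits out). The class name BertClassifier, the 5-way output, the hidden size, and the use of transformers.BertModel are assumptions, not the post's model.py.

import torch.nn as nn
import transformers

class BertClassifier(nn.Module):
	# hypothetical classifier; the real model.py may differ
	def __init__(self, pretrained_path, num_classes=5, hidden_size=768):
		super().__init__()
		self.bert = transformers.BertModel.from_pretrained(pretrained_path)
		self.classifier = nn.Linear(hidden_size, num_classes)

	def forward(self, inputs):
		# inputs: [x_id, x_type, x_mask], matching the call in train_step
		x_id, x_type, x_mask = inputs
		out = self.bert(input_ids=x_id, token_type_ids=x_type, attention_mask=x_mask)
		cls = out.last_hidden_state[:, 0]   # [CLS] token representation
		return self.classifier(cls)         # raw logits; CrossEntropyLoss applies log-softmax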
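The "corrected L2 weight_decay" refers to decoupling weight decay from the adaptive update (the AdamW fix): instead of folding wd * p into the gradient before the moment estimates, the decay is applied directly to the parameters. A minimal sketch of one such update step, assuming a single parameter tensor; the function name, hyperparameters, and structure are illustrative, not the post's adam.py.

import torch

@torch.no_grad()
def adamw_update(p, exp_avg, exp_avg_sq, step,
                 lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
	# one AdamW-style (decoupled weight decay) update for a single parameter tensor p
	grad = p.grad
	# plain L2 regularization would instead fold the decay into the gradient:
	#   grad = grad + weight_decay * p
	# which then gets rescaled by the adaptive denominator below; that coupling is what needs correcting
	p.mul_(1 - lr * weight_decay)                                        # decoupled weight decay
	exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])                # first moment
	exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])   # second moment
	denom = (exp_avg_sq / (1 - betas[1] ** step)).sqrt().add_(eps)
	p.addcdiv_(exp_avg / (1 - betas[0] ** step), denom, value=-lr)       # bias-corrected Adam step

In practice torch.optim.AdamW implements this decoupling, so adam.py can also simply wrap it.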

Training Layer

train.py: organizes the model training task; the three functions below are tied together by a driving loop (sketched after them)

import json
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score

def train_step(model, device, train_loader, optimizer, epoch):
	model.train()
	criterion = nn.CrossEntropyLoss()
	for batch_idx, (x_id, x_type, x_mask, y) in enumerate(train_loader):
		x_id, x_type, x_mask, y = x_id.to(device), x_type.to(device), x_mask.to(device), y.to(device)
		y_pred = model([x_id, x_type, x_mask])
		optimizer.zero_grad()
		loss = criterion(y_pred, y)
		loss.backward()
		optimizer.step()
		# print the loss every 100 batches
		if batch_idx % 100 == 0:
			print(f'epoch {epoch} batch {batch_idx} loss {loss.item():.4f}')
		
def valid_step(model, device, val_loader):
	model.eval()
	valid_loss = 0
	val_pre = []   # predicted labels, [int, ..]
	val_true = []  # gold labels, [int, ..]
	criterion = nn.CrossEntropyLoss()
	for batch_idx, (x_id, x_type, x_mask, y) in enumerate(val_loader):
		x_id, x_type, x_mask, y = x_id.to(device), x_type.to(device), x_mask.to(device), y.to(device)
		with torch.no_grad():
			y_pre = model([x_id, x_type, x_mask])  # [[c1,c2,c3,c4,c5], ..]
		valid_loss += criterion(y_pre, y).item()
		batch_true = y.cpu()
		batch_pre = y_pre.cpu()
		for item in batch_true:
			val_true.append(item)
		for item in batch_pre:
			val_pre.append(item.argmax(0))  # [0.1,0.6,0.3] --> 1
	valid_loss /= len(val_loader)
	# print the average validation loss
	print(f'valid loss: {valid_loss:.4f}')
	valid_true = np.array(val_true)
	valid_pre = np.array(val_pre)
	avg_acc = accuracy_score(valid_true, valid_pre)
	avg_f1 = f1_score(valid_true, valid_pre, average='macro')
	# print the validation accuracy and f1
	print(f'valid acc: {avg_acc:.4f}, macro f1: {avg_f1:.4f}')
	return avg_acc, avg_f1, valid_loss

def test_saveResult(model, device, test_loader, test_entitys, test_ids, result_path):
	model.eval()
	test_pre = []
	for batch_idx, (x1, x2, x3) in enumerate(test_loader):
		x1_g, x2_g, x3_g = x1.to(device), x2.to(device), x3.to(device)
		with torch.no_grad():
			y_pre = model([x1_g, x2_g, x3_g])
		batch_pre = y_pre.detach().cpu().numpy()
		for item in batch_pre:
			test_pre.append(int(item.argmax(0)))  # cast to int so json.dumps can serialize it
	result = {}  # {'id': {'entity': label, ..}}
	for id, entity, pre_y in zip(test_ids, test_entitys, test_pre):
		if id in result:
			result[id][entity] = pre_y
		else:
			result[id] = {entity: pre_y}
	with open(result_path, 'w', encoding='utf-8') as f:
		f.write('id\tresult\n')
		for k, v in result.items():
			f.write(str(k) + '\t' + json.dumps(v, ensure_ascii=False) + '\n')
	print(f"results saved to: {result_path}")

BERT Source Code Writing Logic

import torch
import torch.nn as nn

class BertLM(nn.Module):
	def __init__(self, vocab_size, embed_dim, vocab_padding_idx):
		super().__init__()
		self.embed_dim = embed_dim
		self.vocab_padding_idx = vocab_padding_idx
		# three embedding vectors that are summed
		self.tok_embed = nn.Embedding(vocab_size, embed_dim, vocab_padding_idx)
		self.pos_embed = ...  # positional embedding (elided in the post)
		self.seg_embed = ...  # segment (sentence A/B) embedding (elided in the post)
		# a stack of 12 transformer layers
		self.layers = nn.ModuleList(...)
		# layer-normalize the summed embeddings
		self.emb_layer_norm = LayerNorm(embed_dim)

	def forward(self, truth, inp, seg, msk, nxt_snt_flag):
		# truth: the original (unmasked) tokens    seg: segment ids after sentence splitting
		# inp: the masked token sequence           msk: marks the tokens that were masked out
		# nxt_snt_flag: next-sentence label (not used in this snippet)
		x = self.tok_embed(inp) + self.pos_embed(inp) + self.seg_embed(seg)
		x = self.emb_layer_norm(x)
		padding_mask = torch.eq(inp, self.vocab_padding_idx)  # built from token ids, not embeddings
		for layer in self.layers:
			x, _, _ = layer(x, self_padding_mask=padding_mask)
		# select the masked-out positions and compute the loss against the gold tokens
		masked_x = x.masked_select(msk.unsqueeze(-1))
		masked_x = masked_x.view(-1, self.embed_dim)
		gold = truth.masked_select(msk)
		# compute the loss between the masked predictions and gold (see sketch below)
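
The forward above ends at the loss comment. A minimal sketch of how that masked-LM loss could be computed from masked_x and gold, assuming an output projection back to the vocabulary; the out_proj layer is an assumption, not shown in the post.

import torch.nn as nn
import torch.nn.functional as F

# assumed to exist in __init__: self.out_proj = nn.Linear(embed_dim, vocab_size)
def masked_lm_loss(self, masked_x, gold):
	# masked_x: (num_masked, embed_dim) hidden states at the masked positions
	# gold:     (num_masked,) original token ids at those positions
	logits = self.out_proj(masked_x)       # project hidden states back to vocabulary size
	return F.cross_entropy(logits, gold)   # mean negative log-likelihood over masked tokens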