1.知识讲解
- 内容:定义一个字典,在python中一切皆对象,将所有的函数进行封装,然后定一个分发函数进行分发,将原来if…else全部干掉。
- 角色:
- 函数(function)
- 函数工厂(function factory)
- 客户端 (client)
- 举个例子:
需求:封装一个函数,能够同时进行加减乘除运算。
加减乘除函数:
# 定义一个计算器的相关功能
def plus(a, b):
return a + b
def substact(a, b):
return a - b
def multiply(a, b):
return a * b
def divide(a, b):
return a / b
定义封装函数:
# 定义一个计算函数
def cal(a, b, how):
if how == 1:
return plus(a, b)
elif how == 2:
return substact(a, b)
elif how == 3:
return multiply(a, b)
else:
return None
从上面这个封装函数来看,太多了if…else…很冗余
于是定义一个函数工厂,将所有函数进行封装,然后根据函数名进行调用
# 定义函数工厂
# 在python里面一切皆是对象
# 定义了一个字典,key是函数名称,value是函数对象
func_map = {
"plus": plus,
"substract": substact,
"multiply": multiply,
"divide": divide
}
# 函数工厂模式就是一种对函数进行动态分发的模式
def cal(a,b,how):
if how in func_map.keys():
return func_map[how](a,b)
else:
return None
- 优点:
- 对函数进行动态分发,减少了函数的冗余代码。
2.实战
2.1 demo1
需求:这个是我在写深度学习项目的时候遇到的一个设计模式,当初不明白,现在明白了这个设计模式。自然语言处理中,有一次有一个实验,需要同时验证Bert,roberta,gpt,Xnet等预训练模型的相关功能的性能,他们大致分以下几个模块
- config
- tokenizer
- 掩码模型:Bert,roberta,gpt使用的是mlm掩码模型,而Xnet使用的是plm掩码模型
- 自带的分类模型:sequence_classifier ,但是GPT没有
因为他们每个的这四个部分的功能实现都不相同,但是在实验过程中都需要用到,因此就用到了函数工厂模式。
from torch import nn
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, BertForMaskedLM, RobertaConfig, \
RobertaTokenizer, RobertaForSequenceClassification, RobertaForMaskedLM, XLMRobertaConfig, XLMRobertaTokenizer, \
XLMRobertaForSequenceClassification, XLMRobertaForMaskedLM, XLNetConfig, XLNetTokenizer, \
XLNetForSequenceClassification, XLNetLMHeadModel, AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification, \
AlbertForMaskedLM, GPT2Config, GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer
# 定义一个函数工厂,将所有的函数全部用一个字典封装好,到时候用到那个预训练模型,则就根据预训练模型的名称调用对应的函数。
MODEL_CLASSES = {
'bert': {
'config': BertConfig,
'tokenizer': BertTokenizer,
"sequence_classifier": BertForSequenceClassification,
"mlm":BertForMaskedLM
},
'roberta': {
'config': RobertaConfig,
'tokenizer': RobertaTokenizer,
"sequence_classifier": RobertaForSequenceClassification,
"mlm": RobertaForMaskedLM
},
'xlm-roberta': {
'config': XLMRobertaConfig,
'tokenizer': XLMRobertaTokenizer,
"sequence_classifier": XLMRobertaForSequenceClassification,
"mlm": XLMRobertaForMaskedLM
},
'xlnet': {
'config': XLNetConfig,
'tokenizer': XLNetTokenizer,
"sequence_classifier": XLNetForSequenceClassification,
"plm": XLNetLMHeadModel
},
'albert': {
'config': AlbertConfig,
'tokenizer': AlbertTokenizer,
"sequence_classifier": AlbertForSequenceClassification,
"mlm": AlbertForMaskedLM
},
'gpt2': {
'config': GPT2Config,
'tokenizer': GPT2Tokenizer,
"mlm": GPT2LMHeadModel
},
}
class TransformerModelWrapper(nn.Module):
# 基于Transformer的语言模型的包装器。
'''WrapperConfig封装了:
model_type为Bert,roberta,gpt,Xnet,
wrapper_type为mlm和plm两种类型'''
def __init__(self, config: WrapperConfig):
super(TransformerModelWrapper, self).__init__()
self.config = config
config_class = MODEL_CLASSES[self.config.model_type]['config']
tokenizer_class = MODEL_CLASSES[self.config.model_type]['tokenizer']
model_class = MODEL_CLASSES[self.config.model_type][self.config.wrapper_type]