代码理解
main.py
if __name__ == "__main__":
if len(sys.argv) < 2:
print("menu:\n\ttrain\n\tpredict")
exit()
if sys.argv[1] == "train":
cn = ChineseNER("train")
cn.train()
elif sys.argv[1] == "predict":
cn = ChineseNER("predict")
print(cn.predict())
#sys.argv参数输几,如果是train,创建ChineseNER(“train”)类名为cn。
#调用cn.train。如果是predict,另说。
def __init__(self, entry="train"):
self.load_config()#配置文件
self.__init_model(entry)#初始化模型
# 初始化模型 :def __init_model(self, entry):
if entry == "train":
self.train_manager = DataManager(batch_size=self.batch_size, tags=self.tags)
#**train_manager两个变量,小批量数据,标签,在哪定义的??**
self.total_size = len(self.train_manager.batch_data)
# **数据总量?**
data = {
"batch_size": self.train_manager.batch_size,
"input_size": self.train_manager.input_size,
"vocab": self.train_manager.vocab,
"tag_map": self.train_manager.tag_map,
}
self.save_params(data)
dev_manager = DataManager(batch_size=30, data_type="dev")
self.dev_batch = dev_manager.iteration()
def save_params(self, data):
with open(“models/data.pkl”, “wb”) as fopen:
pickle.dump(data, fopen)
pickle模块中常用的方法有:1. pickle.dump(obj, file, protocol=None,) 必填参数obj表示将要封装的对象 必填参数file表示obj要写入的文件对象,file必须以二进制可写模式打开,即“wb” 可选参数protocol表示告知pickler使用的协议
self.model = BiLSTMCRF(
tag_map=self.train_manager.tag_map,
batch_size=self.batch_size,
vocab_size=len(self.train_manager.vocab),
dropout=self.dropout,
embedding_dim=self.embedding_size,
hidden_dim=self.hidden_size,
)
self.restore_model()