代码理解(1)

代码理解

main.py

if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("menu:\n\ttrain\n\tpredict")
        exit()  
    if sys.argv[1] == "train":
        cn = ChineseNER("train")
        cn.train()
    elif sys.argv[1] == "predict":
        cn = ChineseNER("predict")
        print(cn.predict())

#sys.argv参数输几,如果是train,创建ChineseNER(“train”)类名为cn。
#调用cn.train。如果是predict,另说。

def __init__(self, entry="train"):
        self.load_config()#配置文件
        self.__init_model(entry)#初始化模型
  # 初始化模型 :def __init_model(self, entry):
        if entry == "train":
            self.train_manager = DataManager(batch_size=self.batch_size, tags=self.tags)
            #**train_manager两个变量,小批量数据,标签,在哪定义的??**
            self.total_size = len(self.train_manager.batch_data)
           # **数据总量?**        
data = {
            "batch_size": self.train_manager.batch_size,
            "input_size": self.train_manager.input_size,
            "vocab": self.train_manager.vocab,
            "tag_map": self.train_manager.tag_map,
        }
        self.save_params(data)
        dev_manager = DataManager(batch_size=30, data_type="dev")
        self.dev_batch = dev_manager.iteration()

def save_params(self, data):
with open(“models/data.pkl”, “wb”) as fopen:
pickle.dump(data, fopen)
pickle模块中常用的方法有:

1. pickle.dump(obj, file, protocol=None,)

必填参数obj表示将要封装的对象

必填参数file表示obj要写入的文件对象,file必须以二进制可写模式打开,即“wb”

可选参数protocol表示告知pickler使用的协议
self.model = BiLSTMCRF(
    tag_map=self.train_manager.tag_map,
    batch_size=self.batch_size,
    vocab_size=len(self.train_manager.vocab),
    dropout=self.dropout,
    embedding_dim=self.embedding_size,
    hidden_dim=self.hidden_size,
)
self.restore_model()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值