# encoding: utf-8
from gpt_model import GPTConfig, GPTModel
import numpy as np
import sys
import torch
from data_set import load_tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-3
max_iters = 100
config = GPTConfig()
config.dropout = 0.1
config.batch_size = 32
model = GPTModel(config).to(device)
# load the trained model weights
model.load_state_dict(torch.load("bird_shooter_step_15000.ckpt", map_location=device))
model.eval()  # set the model to evaluation mode
# load tokenizer
model_file = "bird_shooter.model"
flag, sp = load_tokenizer(model_file)  # load_tokenizer is assumed to return (success flag, SentencePiece processor)
if not flag:
    print(f"load tokenizer model from: {model_file} failed")
    sys.exit(1)
# generate from a prompt
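# Minimal generation sketch (assumptions): GPTModel is assumed to expose a
# nanoGPT-style generate(idx, max_new_tokens) method, and sp is assumed to be a
# SentencePieceProcessor with encode()/decode(); adjust these calls if the
# actual APIs in gpt_model / data_set differ.
prompt = "郭靖"  # hypothetical prompt text
ids = sp.encode(prompt)  # encode the prompt into token ids
x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, T)
with torch.no_grad():
    y = model.generate(x, max_new_tokens=200)  # assumed generate() signature
print(sp.decode(y[0].tolist()))  # decode generated ids back to text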