paddlepaddle异步读取数据+保存模型-只有关键代码


def list_sample_generator(sample_file_list):
    """sample_gene"""
    
    def dfm_pointwise_generator(sample_file_list):
        """
        deepfm pointwise样本生成器
        """
        for sample_file in sample_file_list:
            with open(sample_file, "r") as f:
                for line in f:
                    toks = line.strip('\n').split('\t')

                    
                    user_fea = toks[2].split("\1")
                    item_fea = toks[3].split("\1")
            
                    label = toks[1]
                    # main_poi = json.loads(toks[5])
                    # target_poi = json.loads(toks[6])
                    # his = json.loads(toks[7])
                    # mask = json.loads(toks[8])

                    main_poi = toks[5].split("\1")
                    target_poi = toks[6].split("\1")
                    his = [x.split("\1") for x in toks[7].split("\2")]
                    mask = toks[8].split("\1")

                    
                    featue = user_fea + item_fea
                    yield np.array([label]).astype('float32'), np.array(featue).astype('float32'), np.array(main_poi).astype("float32"), np.array(target_poi).astype("float32"), np.array(his).astype("float32"), np.array(mask).astype("float32")
                    # yield np.array([label]).astype("int64"), np.array(featue).astype("float32")

    def sample_generator():
        """
        样本生成器
        """
        for sample in dfm_pointwise_generator(sample_file_list):
            yield sample

    return sample_generator



# 定义训练网络
with fluid.program_guard(train_prog, train_startup):
    # fluid.unique_name.guard() to share parameters with test network
    with fluid.unique_name.guard():
        [train_loss, prob, l2_reg_cross_loss, rank_loss, auc_var], train_loader = network(args, feature_name2conf)
        adam = fluid.optimizer.Adam(learning_rate=0.01)
        adam.minimize(train_loss)

# 创建预测的main_program和startup_program
test_prog = fluid.Program()
test_startup = fluid.Program()

# 定义预测网络
# with fluid.program_guard(test_prog, test_startup):
#     # Use fluid.unique_name.guard() to share parameters with train network
#     with fluid.unique_name.guard():
#         test_loss, test_loader = network()

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)

# 运行startup_program进行初始化
exe.run(train_startup)
# exe.run(test_startup)

# Compile programs
train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(loss_name=train_loss.name)
# test_prog = fluid.CompiledProgram(test_prog).with_data_parallel(share_vars_from=train_prog)

# 设置DataLoader的数据源
places = fluid.cuda_places() if ITERABLE else None

reader_list = [("demo_data" + "/" + file_name) for file_name in os.listdir("./demo_data")]
print("read_list:", reader_list)
sample_list_reader = fluid.io.batch(list_sample_generator(reader_list), batch_size=128)

sample_list_reader = fluid.io.shuffle(sample_list_reader, buf_size=64) # 还可以进行适当的shuffle
train_loader.set_sample_list_generator(sample_list_reader, places=places)

# train_loader.set_sample_list_generator(
#     fluid.io.shuffle(fluid.io.batch(mnist.train(), 512), buf_size=1024), places=places)

# test_loader.set_sample_list_generator(fluid.io.batch(mnist.test(), 512), places=places)


def run_iterable(program, exe, loss, data_loader):
    for data in data_loader():
        loss_value = exe.run(program=program, feed=data, fetch_list=[loss])
        print('loss is {}'.format(loss_value))

for epoch_id in six.moves.range(1):
    # print("main program is: {}".format(train_prog))
    run_iterable(train_prog, exe, train_loss, train_loader)
    # run_iterable(test_prog, exe, test_loss, test_loader)
    fluid.io.save_inference_model(dirname="./model_test", feeded_var_names=['doc'], target_vars=[train_loss], executor=exe, main_program=train_prog)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值