【MindSpore】【自定义数据集】GeneratorDataset 可迭代,但是model.train跑不动

问题描述:

自己写了一个GeneratorDataset,用于迭代数据,使用create_dict_iterator(),可以迭代。但是放入model来train的时候model进程就卡住了。之前我有mindrecord dataset,是完全可以运行的。

生成dataset的代码,一部分是mindrecord。另一部分是generatordataset
def load_adEAST_dataset(mindrecord_file, batch_size=64, device_num=8, rank_id=0,
                              is_training=True, num_parallel_workers=8):
    '''                  
    ds = de.GeneratorDataset(gen, column_names=['images', 'label'], num_parallel_workers=num_parallel_workers, shuffle=is_training)
   # hwc_to_chw = vision.HWC2CHW()

    #ds = ds.map(operations=[hwc_to_chw], input_columns=["images"], num_parallel_workers=num_parallel_workers)
    #ds = ds.batch(batch_size, drop_remainder=True)
    with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
        f_list = f_train.readlines()
        batch_num = len(f_list) // cfg.batch_size
    #ds = ds.shuffle(batch_num)
    return ds, batch_num

    '''
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "label"], num_shards=device_num, shard_id=rank_id, num_parallel_workers=8, shuffle=is_training)
    hwc_to_chw = vision.HWC2CHW()

    ds = ds.map(operations=[hwc_to_chw], input_columns=["image"], num_parallel_workers=num_parallel_workers)

    ds = ds.batch(batch_size, drop_remainder=True)
    batch_num = ds.get_dataset_size()
    ds = ds.shuffle(batch_num)
    return ds, batch_num

解决方案:

  1. 的model.train卡住是一直吗?经过一段时间打不打印报错信息?如果有报错信息,也麻烦发一下。

  2. 建议在 GeneratorDataset/MindDataset这个接口里面shuffle=True是比较高效的,那么ds.shuffle这个操作就没有必要了。

    你可以把ds.shuffle操作中的参数调小一点/去掉这行代码,我看现在是把整个数据集大小填在里面了,这会导致整个数据集都缓存在shuflle这个算子里,一是占太多内存,二是shuffle效率不高。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值