问题描述:
自己写了一个GeneratorDataset,用于迭代数据,使用create_dict_iterator(),可以迭代。但是放入model来train的时候model进程就卡住了。之前我有mindrecord dataset,是完全可以运行的。
生成dataset的代码,一部分是mindrecord。另一部分是generatordataset def load_adEAST_dataset(mindrecord_file, batch_size=64, device_num=8, rank_id=0, is_training=True, num_parallel_workers=8): ''' ds = de.GeneratorDataset(gen, column_names=['images', 'label'], num_parallel_workers=num_parallel_workers, shuffle=is_training) # hwc_to_chw = vision.HWC2CHW() #ds = ds.map(operations=[hwc_to_chw], input_columns=["images"], num_parallel_workers=num_parallel_workers) #ds = ds.batch(batch_size, drop_remainder=True) with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train: f_list = f_train.readlines() batch_num = len(f_list) // cfg.batch_size #ds = ds.shuffle(batch_num) return ds, batch_num ''' ds = de.MindDataset(mindrecord_file, columns_list=["image", "label"], num_shards=device_num, shard_id=rank_id, num_parallel_workers=8, shuffle=is_training) hwc_to_chw = vision.HWC2CHW() ds = ds.map(operations=[hwc_to_chw], input_columns=["image"], num_parallel_workers=num_parallel_workers) ds = ds.batch(batch_size, drop_remainder=True) batch_num = ds.get_dataset_size() ds = ds.shuffle(batch_num) return ds, batch_num
解决方案:
-
的model.train卡住是一直吗?经过一段时间打不打印报错信息?如果有报错信息,也麻烦发一下。
-
建议在 GeneratorDataset/MindDataset这个接口里面shuffle=True是比较高效的,那么ds.shuffle这个操作就没有必要了。
你可以把ds.shuffle操作中的参数调小一点/去掉这行代码,我看现在是把整个数据集大小填在里面了,这会导致整个数据集都缓存在shuflle这个算子里,一是占太多内存,二是shuffle效率不高。