从spark 写tfrecord到dataset读入到模型

1、spark tensorflow 环境搭建:

https://zhuanlan.zhihu.com/p/51819048

2、spark 写tfrecord示例

rdd = sc.textFile(','.join(inpath)) #读入埋点数据,埋点数据一般按天存储,所以inpath是个列表
df = spark_session.createDataFrame(rdd1, record_schema).repartition(200).orderBy(F.rand())
df.write.mode("overwrite").format("tfrecord").option("recordType", "Example").save(outpath_)

3、dataset读tfrecord,使其满足可以喂入模型的数据类型

  • 读入tfrecord 
    dataset = tf.data.TFRecordDataset(data_file) #data_file一般是个列表,因为tfrecord是一块块的,不只一个文件
  • 定义每个样本处理模板函数
            def parser(value):
                """Parse train and eval data with label
                Args:
                    value: Tensor("arg0:0", shape=(), dtype=string)
                """
                # `tf.decode_csv` return rank 0 Tensor list: <tf.Tensor 'DecodeCSV:60' shape=() dtype=string>
                # na_value fill with record_defaults
                # columns = tf.csv_defaults(
                #     value, record_defaults=list(csv_defaults.values()),
                #     field_delim=field_delim, use_quote_delim=False, na_value=na_value)#针对读取csv格式的数据
                features = tf.parse_single_example(value, csv_defaults) #针对读取tfrecord格式的数据
                features['goodsid']=tf.sparse.to_dense(features['goodsid'],default_value='')
    
                if is_pred:
                    return features
                else:
                    label = features.pop('label')
           
                    ctr = tf.cond(tf.less_equal(1.0, label[0]), lambda:tf.constant([1.0]), lambda:tf.constant([0.0]))
                    # if use_weight:
                    #     pred = labels[0] if multivalue else labels  # pred must be rank 0 scalar
                    #     pos_weight, neg_weight = pos_w or 1, neg_w or 1
                    #     weight = tf.cond(pred, lambda: pos_weight, lambda: neg_weight)
                    #     features["weight_column"] = [weight]  # padded_batch need rank 1
                    return features, {'ctr':ctr}
            return parser

    其中csv_default 形如 OrderedDict([('label', tf.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=None)), ('devices_type', tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)), ('device_brands', tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)),...)  注意:读取csv时,csv_default是另外一种形式

  • 然后对整个数据进行map

    dataset = dataset.map(parse(is_pred=(mode == 'pred')), num_parallel_calls=self._num_parallel_calls)
  • 最后设置批次
    if mode == 'train':
        dataset = dataset.repeat(self._train_epochs)  # define outside loop
    
    dataset = dataset.prefetch(2 * batch_size)
    dataset = dataset.batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()
  • 完成

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值