1. Setting up a Spark + TensorFlow environment:
https://zhuanlan.zhihu.com/p/51819048
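The tfrecord data source used in section 2 below comes from the spark-tensorflow-connector package. A minimal sketch of pulling it in when building the SparkSession; the artifact coordinates and version here are an assumption, match them to your Spark/Scala version (the article above covers the setup in detail):

from pyspark.sql import SparkSession

# Assumption: spark-tensorflow-connector is fetched via spark.jars.packages so that
# .format("tfrecord") is available; adjust artifact/version to your Spark and Scala versions.
spark_session = (SparkSession.builder
    .appName("write_tfrecord")
    .config("spark.jars.packages",
            "org.tensorflow:spark-tensorflow-connector_2.11:1.11.0")
    .getOrCreate())
sc = spark_session.sparkContext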
2. Example: writing TFRecords from Spark
from pyspark.sql import functions as F

rdd = sc.textFile(','.join(inpath))  # load the tracking-log data; logs are usually stored per day, so inpath is a list of paths
# parse each raw line into a row matching record_schema to obtain rdd1 (parsing step omitted)
df = spark_session.createDataFrame(rdd1, record_schema).repartition(200).orderBy(F.rand())
df.write.mode("overwrite").format("tfrecord").option("recordType", "Example").save(outpath_)
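For reference, a minimal sketch of what record_schema could look like for the columns that appear later in this note (label, devices_type, device_brands, goodsid); the exact column set and types are assumptions, a real schema carries many more columns. An ArrayType column such as goodsid ends up as a variable-length feature in the written Example, which is why it is parsed with a VarLenFeature later on.

from pyspark.sql import types as T

# Assumed schema: only the columns visible elsewhere in this note; extend as needed.
record_schema = T.StructType([
    T.StructField("label",         T.FloatType(),               True),
    T.StructField("devices_type",  T.StringType(),              True),
    T.StructField("device_brands", T.StringType(),              True),
    T.StructField("goodsid",       T.ArrayType(T.StringType()), True),  # multi-value column -> variable-length feature
])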
3. Reading the TFRecords with tf.data.Dataset so the data can be fed to the model
- Read in the TFRecords
dataset = tf.data.TFRecordDataset(data_file)  # data_file is usually a list, since the TFRecords are written as multiple shard files rather than one file
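If the shards come from the Spark job above, the file list can be built with a glob. A minimal sketch, where the part-* pattern is an assumption about how the writer names its shard files:

import tensorflow as tf

# Assumption: the shards written by Spark live under outpath_ and are named part-*;
# check the actual output directory for the exact naming scheme.
data_file = tf.gfile.Glob(outpath_ + "/part-*")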
- Define the per-example parsing function
def parse(is_pred=False):
    """Return a per-example parsing function; is_pred is captured as a closure variable."""
    def parser(value):
        """Parse one serialized example (train/eval data with label).
        Args:
            value: a scalar string Tensor holding one serialized tf.Example
        """
        # For CSV input, tf.decode_csv would be used instead; it returns a list of rank-0 tensors,
        # with na_value cells filled from record_defaults:
        # columns = tf.decode_csv(
        #     value, record_defaults=list(csv_defaults.values()),
        #     field_delim=field_delim, use_quote_delim=False, na_value=na_value)
        features = tf.parse_single_example(value, csv_defaults)  # for TFRecord-format data
        # goodsid is a variable-length (multi-value) feature, so it comes back as a SparseTensor;
        # densify it so it can be batched
        features['goodsid'] = tf.sparse.to_dense(features['goodsid'], default_value='')
        if is_pred:
            return features
        label = features.pop('label')
        # binarize the label: anything >= 1.0 becomes 1.0, everything else 0.0
        ctr = tf.cond(tf.less_equal(1.0, label[0]),
                      lambda: tf.constant([1.0]),
                      lambda: tf.constant([0.0]))
        # Optional sample weighting:
        # if use_weight:
        #     pred = labels[0] if multivalue else labels  # pred must be a rank-0 scalar
        #     pos_weight, neg_weight = pos_w or 1, neg_w or 1
        #     weight = tf.cond(pred, lambda: pos_weight, lambda: neg_weight)
        #     features["weight_column"] = [weight]  # padded_batch needs rank 1
        return features, {'ctr': ctr}
    return parser
Here csv_defaults looks like OrderedDict([('label', tf.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=None)), ('devices_type', tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)), ('device_brands', tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)), ...]). Note: when reading CSV instead of TFRecord, csv_defaults takes a different form (per-column default values for tf.decode_csv rather than feature specs).
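Spelled out, a minimal sketch of csv_defaults covering only the columns seen in this note; note that goodsid is declared as a VarLenFeature, which is why the parser applies tf.sparse.to_dense to it:

from collections import OrderedDict
import tensorflow as tf

csv_defaults = OrderedDict([
    ('label',         tf.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=None)),
    ('devices_type',  tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)),
    ('device_brands', tf.FixedLenFeature(shape=[1], dtype=tf.string, default_value=None)),
    ('goodsid',       tf.VarLenFeature(dtype=tf.string)),  # multi-value feature -> parsed as a SparseTensor
])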
- Then map the parser over the whole dataset
dataset = dataset.map(parse(is_pred=(mode == 'pred')), num_parallel_calls=self._num_parallel_calls)
- Finally, set up batching
if mode == 'train':
    dataset = dataset.repeat(self._train_epochs)  # define outside loop
dataset = dataset.prefetch(2 * batch_size)
dataset = dataset.batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
- Done
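To close the loop, a hedged sketch of consuming the tensors returned above in a plain TF 1.x session loop; build_dataset is a hypothetical name for whatever method wraps the dataset code in this note:

import tensorflow as tf

# build_dataset is a hypothetical wrapper around the dataset code above;
# it returns the (features, labels) tensors from get_next()
features, labels = build_dataset(mode='train')

with tf.Session() as sess:
    try:
        while True:
            feat_batch, label_batch = sess.run([features, labels])
            # feed feat_batch and label_batch['ctr'] into the training step here
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator is exhausted after train_epochs passes over the data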