TFX完整流水线示例以及补充(下)

TFX完整流水线示例以及补充(上)

Trainer

_taxi_trainer_module_file = 'taxi_trainer.py'
%%writefile {_taxi_trainer_module_file}

from typing import Dict, List, Text

import os
import glob
from absl import logging

import datetime
import tensorflow as tf
import tensorflow_transform as tft

from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from tensorflow_transform import TFTransformOutput

# Imported files such as taxi_constants are normally cached, so changes are
# not honored after the first import.  Normally this is good for efficiency, but
# during development when we may be iterating code it can be a problem. To
# avoid this problem during development, reload the file.
import taxi_constants
import sys
if 'google.colab' in sys.modules:  # Testing to see if we're doing development
  import importlib
  importlib.reload(taxi_constants)

_LABEL_KEY = taxi_constants.LABEL_KEY

_BATCH_SIZE = 40


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Builds the batched dataset used for tuning/training/evaluation.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor whose `tf_dataset_factory` builds the dataset.
    tf_transform_output: A TFTransformOutput; its transformed-metadata schema
      is handed to the dataset factory.
    batch_size: Number of consecutive elements of the returned dataset to
      combine in a single batch.

  Returns:
    The dataset produced by `data_accessor.tf_dataset_factory`, configured
    with `_LABEL_KEY` as the label key.
  """
  dataset_options = tfxio.TensorFlowDatasetOptions(
      batch_size=batch_size, label_key=_LABEL_KEY)
  transformed_schema = tf_transform_output.transformed_metadata.schema
  return data_accessor.tf_dataset_factory(
      file_pattern, dataset_options, transformed_schema)

def _get_tf_examples_serving_signature(model, tf_transform_output):
  """Builds a serving function that accepts serialized `tensorflow.Example`."""

  # Attach the transform layer to the model so that it is tracked on save.
  # TODO(b/162357359): Revise once the bug is resolved.
  model.tft_layer_inference = tf_transform_output.transform_features_layer()

  examples_spec = tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')

  @tf.function(input_signature=[examples_spec])
  def serve_tf_examples_fn(serialized_tf_example):
    """Parses, transforms and scores a batch of serialized examples."""
    # The label is not present at serving time, so drop it from the spec
    # before parsing.
    serving_feature_spec = tf_transform_output.raw_feature_spec()
    serving_feature_spec.pop(_LABEL_KEY)

    parsed_features = tf.io.parse_example(
        serialized_tf_example, serving_feature_spec)
    transformed_features = model.tft_layer_inference(parsed_features)
    logging.info('serve_transformed_features = %s', transformed_features)

    predictions = model(transformed_features)
    # TODO(b/154085620): Convert the predicted labels from the model using a
    # reverse-lookup (opposite of transform.py).
    return {'outputs': predictions}

  return serve_tf_examples_fn


def _get_transform_features_signature(model, tf_transform_output):
  """Builds a signature function that applies tf.Transform to raw examples."""

  # Attach the transform layer to the model so that it is tracked on save.
  # TODO(b/162357359): Revise once the bug is resolved.
  model.tft_layer_eval = tf_transform_output.transform_features_layer()

  examples_spec = tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')

  @tf.function(input_signature=[examples_spec])
  def transform_features_fn(serialized_tf_example):
    """Returns the transformed features to be fed as input to the evaluator."""
    # The full raw feature spec is used here; unlike the serving signature,
    # the label key is not removed before parsing.
    full_feature_spec = tf_transform_output.raw_feature_spec()
    parsed_features = tf.io.parse_example(
        serialized_tf_example, full_feature_spec)
    transformed_features = model.tft_layer_eval(parsed_features)
    logging.info('eval_transformed_features = %s', transformed_features)
    return transformed_features

  return transform_features_fn


def export_serving_model(tf_transform_output, model, output_dir):
  """Saves a keras model together with its two serving signatures.

  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    model: A keras model to export for serving.
    output_dir: A directory where the model will be exported to.
  """
  # The layer has to be attached to the model for keras tracking purposes.
  model.tft_layer = tf_transform_output.transform_features_layer()

  # 'serving_default' scores raw serialized examples; 'transform_features'
  # only applies the transform and returns the transformed features.
  serving_default = _get_tf_examples_serving_signature(
      model, tf_transform_output)
  transform_features = _get_transform_features_signature(
      model, tf_transform_output)

  model.save(
      output_dir,
      save_format='tf',
      signatures={
          'serving_default': serving_default,
          'transform_features': transform_features,
      })


def _build_keras_model(tf_transform_output: TFTransformOutput
                       ) -> tf.keras.Model:
  """Creates a DNN Keras model over the transformed taxi features.

  Args:
    tf_transform_output: [TFTransformOutput], the outputs from Transform

  Returns:
    A keras Model with one Input per transformed feature (label excluded),
    four ReLU hidden layers and a single-unit output without activation.

  Raises:
    ValueError: If a feature spec is neither VarLenFeature nor
      FixedLenFeature.
  """
  feature_spec = tf_transform_output.transformed_feature_spec().copy()
  # The label is a training target, not a model input.
  feature_spec.pop(_LABEL_KEY)

  inputs = {}
  for name, spec in feature_spec.items():
    if isinstance(spec, tf.io.VarLenFeature):
      layer_kwargs = dict(shape=[None], sparse=True)
    elif isinstance(spec, tf.io.FixedLenFeature):
      # TODO(b/208879020): Move into schema such that spec.shape is [1] and not
      # [] for scalars.
      layer_kwargs = dict(shape=spec.shape or [1])
    else:
      raise ValueError('Spec type is not supported: ', name, spec)
    inputs[name] = tf.keras.layers.Input(
        name=name, dtype=spec.dtype, **layer_kwargs)

  hidden = tf.keras.layers.Concatenate()(tf.nest.flatten(inputs))
  for units in (100, 70, 50, 20):
    hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)
  prediction = tf.keras.layers.Dense(1)(hidden)
  return tf.keras.Model(inputs=inputs, outputs=prediction)


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args and export it for serving.

  Args:
    fn_args: Holds args used to train the model as name/value pairs. This
      function reads `transform_output`, `train_files`, `eval_files`,
      `data_accessor`, `train_steps`, `eval_steps`, `model_run_dir` and
      `serving_model_dir` from it.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                            tf_transform_output, _BATCH_SIZE)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                           tf_transform_output, _BATCH_SIZE)

  model = _build_keras_model(tf_transform_output)

  # The model's last layer has no activation, so it emits logits: the loss is
  # configured with from_logits=True, and the accuracy metric must threshold
  # at 0.0 (logit 0 <=> probability 0.5) instead of the default 0.5.
  model.compile(
      loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])

  # Write per-batch summaries to the model-run directory for TensorBoard.
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir, update_freq='batch')

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps,
      callbacks=[tensorboard_callback])

  # Export the model.
  export_serving_model(tf_transform_output, model, fn_args.serving_model_dir)
Overwriting taxi_trainer.py
# Configure the Trainer component: it runs `run_fn` from the module file
# written above, consuming the transformed examples, the transform graph and
# the schema, with 10000 training steps and 5000 evaluation steps.
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_taxi_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000))
# Execute the component through `context` (presumably the interactive context
# created in the first part of this article) with caching enabled.
context.run(trainer, enable_cache=True)
INFO:absl:Generating ephemeral wheel package for '/mnt/c/Users/DELL/jupyter_notebook_code/chicago_taxi_pipeline/taxi_trainer.py' (including modules: ['taxi_constants', 'taxi_trainer', 'taxi_transform']).
INFO:absl:User module package has hash fingerprint version 79f9cdb8dcb0633411b76b3906a3770b749e6c7c16484cb1f26a1a8c7cbf516a.
INFO:absl:Executing: ['/home/xzy/anaconda3/envs/tfx/bin/python', '/tmp/tmptydrvwrp/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmp/tmp9d7rg1o9', '--dist-dir', '/tmp/tmpmyq6a9n7']
/home/xzy/anaconda3/envs/tfx/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer, pypa/build or
        other standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at './pipeline_output_root/_wheels/tfx_user_code_Trainer-0.0+79f9cdb8dcb0633411b76b3906a3770b749e6c7c16484cb1f26a1a8c7cbf516a-py3-none-any.whl'; target user module is 'taxi_trainer'.
INFO:absl:Full user module path is 'taxi_trainer@./pipeline_output_root/_wheels/tfx_user_code_Trainer-0.0+79f9cdb8dcb0633411b76b3906a3770b749e6c7c16484cb1f26a1a8c7cbf516a-py3-none-any.whl'
INFO:absl:Running driver for Trainer
INFO:absl:MetadataStore with DB connection initialized


running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying taxi_constants.py -> build/lib
copying taxi_trainer.py -> build/lib
copying taxi_transform.py -> build/lib
installing to /tmp/tmp9d7rg1o9
running install
running install_lib
copying build/lib/taxi_transform.py -> /tmp/tmp9d7rg1o9
copying build/lib/taxi_trainer.py -> /tmp/tmp9d7rg1o9
copying build/lib/taxi_constants.py -> /tmp/tmp9d7rg1o9
running install_egg_info
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
......省略
ExecutionResult at 0x7f0295ad3790
.execution_id6
.component
model_run_artifact_dir = trainer.outputs['model_run'].get()[0].uri

%load_ext tensorboard
%tensorboard --logdir {model_run_artifact_dir}

在这里插入图片描述

Evaluator

# Imported files such as taxi_constants are normally cached, so changes are
# not honored after the first import.  Normally this is good for efficiency, but
# during development when we may be iterating code it can be a problem. To
# avoid this problem during development, reload the file.
import taxi_constants
import sys
if 'google.colab' in sys.modules:  # Testing to see if we're doing development
  import importlib
  importlib.reload(taxi_constants)

# TFMA evaluation configuration: which model signature to evaluate, which
# metrics (and thresholds) to compute, and how to slice the data.
eval_config = tfma.EvalConfig(
    model_specs=[
        # This assumes a serving model with signature 'serving_default'. If
        # using estimator based EvalSavedModel, add signature_name: 'eval' and
        # remove the label_key.
        tfma.ModelSpec(
            signature_name='serving_default',
            label_key=taxi_constants.LABEL_KEY,
            # Name of the saved-model signature used for preprocessing; the
            # trainer module exports one called 'transform_features'.
            preprocessing_function_names=['transform_features'],
            )
        ],
    metrics_specs=[
        tfma.MetricsSpec(
            # The metrics added here are in addition to those saved with the
            # model (assuming either a keras model or EvalSavedModel is used).
            # Any metrics added into the saved model (for example using
            # model.compile(..., metrics=[...]), etc) will be computed
            # automatically.
            # To add validation thresholds for metrics saved with the model,
            # add them keyed by metric name to the thresholds map.
            metrics=[
                tfma.MetricConfig(class_name='ExampleCount'),
                tfma.MetricConfig(class_name='BinaryAccuracy',
                  threshold=tfma.MetricThreshold(
                      # Absolute requirement: accuracy of at least 0.5.
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.5}),
                      # Change threshold will be ignored if there is no
                      # baseline model resolved from MLMD (first run).
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-10})))
            ]
        )
    ],
    slicing_specs=[
        # An empty slice spec means the overall slice, i.e. the whole dataset.
        tfma.SlicingSpec(),
        # Data can be sliced along a feature column. In this case, data is
        # sliced along feature column trip_start_hour.
        tfma.SlicingSpec(
            feature_keys=['trip_start_hour'])
    ])
# Use TFMA to compute a evaluation statistics over features of a model and
# validate them against a baseline.

# The model resolver is only required if performing model validation in addition
# to evaluation. In this case we validate against the latest blessed model. If
# no model has been blessed before (as in this case) the evaluator will make our
# candidate the first blessed model.
# Resolver node that looks up the latest blessed model; its 'model' output is
# used as the Evaluator's baseline below.
model_resolver = tfx.dsl.Resolver(
      strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
      model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
      model_blessing=tfx.dsl.Channel(
          type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
              'latest_blessed_model_resolver')
context.run(model_resolver, enable_cache=True)

# Evaluate the freshly trained model on the raw examples, using the resolved
# model as baseline and the eval_config defined above.
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)
# The output is very long (it is truncated below).
context.run(evaluator, enable_cache=True)
INFO:absl:Running driver for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running publisher for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running driver for Evaluator
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Evaluator
INFO:absl:udf_utils.get_fn {'eval_config': '{\n  "metrics_specs": [\n    {\n      "metrics": [\n        {\n          "class_name": "ExampleCount"\n        },\n        {\n          "class_name": "BinaryAccuracy",\n          "threshold": {\n            "change_threshold": {\n              "absolute": -1e-10,\n              "direction": "HIGHER_IS_BETTER"\n            },\n            "value_threshold": {\n              "lower_bound": 0.5\n            }\n          }\n        }\n      ]\n    }\n  ],\n  "model_specs": [\n    {\n      "label_key": "tips",\n      "preprocessing_function_names": [\n        "transform_features"\n      ],\n      "signature_name": "serving_default"\n    }\n  ],\n  "slicing_specs": [\n    {},\n    {\n      "feature_keys": [\n        "trip_start_hour"\n      ]\n    }\n  ]\n}', 'feature_slicing_spec': None, 'fairness_indicator_thresholds': 'null', 'example_splits': 'null', 'module_file': None, 'module_path': None} 'custom_eval_shared_model'
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "serving_default"
  label_key: "tips"
  preprocessing_function_names: "transform_features"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  metrics {
    class_name: "BinaryAccuracy"
    threshold {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}

......省略
ExecutionResult at 0x7f02535d2d90
.execution_id8
.component
context.show(evaluator.outputs['evaluation'])

在这里插入图片描述

SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'Overall', 'metrics':…
import tensorflow_model_analysis as tfma

# Get the TFMA output result path and load the result.
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)

# Show data sliced along feature column trip_start_hour.
tfma.view.render_slicing_metrics(
    tfma_result, slicing_column='trip_start_hour')

在这里插入图片描述

SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'trip_start_hour:19',…
# Load the validation result stored next to the evaluation output and print
# it (whether the thresholds configured in eval_config were met).
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
print(tfma.load_validation_result(PATH_TO_RESULT))
validation_ok: true
validation_details {
  slicing_details {
    slicing_spec {
    }
    num_matching_slices: 25
  }
}

Pusher

# Configure the Pusher: it takes the trained model and the Evaluator's
# blessing and pushes to a filesystem destination rooted at
# _serving_model_dir.
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher, enable_cache=True)
INFO:absl:Running driver for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Pusher
INFO:absl:Model version: 1687834975
INFO:absl:Model written to serving path ./serving_model/taxi_simple/1687834975.
INFO:absl:Model pushed to ./pipeline_output_root/Pusher/pushed_model/9.
INFO:absl:Running publisher for Pusher
INFO:absl:MetadataStore with DB connection initialized
ExecutionResult at 0x7f0253563550
.execution_id9
.component
# Load the pushed SavedModel from the Pusher's output artifact.
push_uri = pusher.outputs['pushed_model'].get()[0].uri
model = tf.saved_model.load(push_uri)

# Print every (signature name, concrete function) pair the model exposes.
# `pp` is presumably a pprint.PrettyPrinter created earlier -- confirm.
for item in model.signatures.items():
  pp.pprint(item)
('serving_default',
 <ConcreteFunction signature_wrapper(*, examples) at 0x7F0265217C70>)
('transform_features',
 <ConcreteFunction signature_wrapper(*, examples) at 0x7F02652D0D90>)

pipeline

def _create_pipeline(pipeline_name:str,pipeline_root:str,metadata_path:str):
    """Assembles the TFX pipeline from the components defined in this notebook.

    Args:
        pipeline_name: Name given to the pipeline.
        pipeline_root: Root directory passed to the pipeline as `pipeline_root`.
        metadata_path: Path of the SQLite file used for the metadata connection.

    Returns:
        A `tfx.dsl.Pipeline` containing all components, including the baseline
        model resolver.
    """
    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        # The Evaluator reads its baseline from this resolver, so the resolver
        # must be part of the pipeline too; leaving it out triggers TFX's
        # "Node Evaluator depends on the output of node
        # latest_blessed_model_resolver" warning.
        model_resolver,
        evaluator,
        pusher,
    ]
    metadata_connection_config = (
        tfx.orchestration.metadata.sqlite_metadata_connection_config(
            metadata_path))
    return tfx.dsl.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        metadata_connection_config=metadata_connection_config,
        components=components)
# Build the pipeline and execute it with the local DAG runner.
# PIPELINE_NAME, PIPELINE_ROOT and METADATA_PATH are defined outside this
# section (presumably in the first part of the article).
tfx.orchestration.LocalDagRunner().run(
  _create_pipeline(
      pipeline_name=PIPELINE_NAME,
      pipeline_root=PIPELINE_ROOT,
      metadata_path=METADATA_PATH))
/home/xzy/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/orchestration/pipeline.py:408: UserWarning: Node Evaluator depends on the output of node latest_blessed_model_resolver, but latest_blessed_model_resolver is not included in the components of pipeline. Did you forget to add it?
  warnings.warn(
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "Evaluator"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.evaluator.executor.Executor"
      }
    }
  }
}
......省略
_serving_model_dir
'./serving_model/taxi_simple'

inspect ServingModel

!saved_model_cli show --dir {_serving_model_dir}/$(ls -1 {_serving_model_dir}|sort -nr|head -1) --tag_set serve --signature_def serving_default
2023-06-27 11:05:35.695269: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/TensorRT/lib:/usr/local/cuda-11.7/lib64
2023-06-27 11:05:35.695379: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/TensorRT/lib:/usr/local/cuda-11.7/lib64
2023-06-27 11:05:35.695409: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
The given SavedModel SignatureDef contains the following input(s):
  inputs['examples'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: serving_default_examples:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['outputs'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: StatefulPartitionedCall_27:0
Method name is: tensorflow/serving/predict

Test ServingModel

# Collect the sub-directories of the serving directory and pick the one whose
# name is the largest integer (presumably the most recent push, since version
# directories appear to be named by timestamp -- confirm).
model_dirs = [item for item in os.scandir(_serving_model_dir) if item.is_dir()]
model_path = max(model_dirs,key=lambda i:int(i.name)).path
model_path
'./serving_model/taxi_simple/1687835132'
# Load the exported model and grab its two signatures.
loaded_model = tf.keras.models.load_model(model_path) # called directly, it takes the dict of already-transformed features
inference_fn = loaded_model.signatures['serving_default'] # takes raw serialized examples (the label may be omitted)
transform_fn=loaded_model.signatures['transform_features'] # takes raw serialized examples (the label must be present)
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f026a194c10> and <keras.engine.input_layer.InputLayer object at 0x7f02dd6632e0>).
from typing import List,Union,Optional
from tensorflow_metadata.proto.v0 import schema_pb2
import pandas as pd
import tensorflow_data_validation as tfdv
import numpy as np
def create_example_by_schema_from_dataframe(row:pd.Series,column_names:List[str],schema_or_schemapath:Union[str,schema_pb2.Schema]):
    """Convert one row of data into a serialized `tf.train.Example`.

    The feature type of every column is taken from the data's schema. Missing
    values are replaced by an empty byte string, 0 or 0.0 depending on the
    feature type. Columns whose schema type is not BYTES, INT or FLOAT are
    left out of the example.

    Args:
        row: One row of data as a `pd.Series`.
        column_names: Names of the columns to convert.
        schema_or_schemapath: The data's Schema instance, or the path to the
            schema text file (must point at schema.pbtxt itself).

    Returns:
        The serialized example (bytes).
    """
    schema = schema_or_schemapath
    if isinstance(schema, str):
        schema = tfdv.load_schema_text(schema)

    features = {}
    for column_name in column_names:
        feature_type = tfdv.get_feature(schema, column_name).type
        # Use the schema enum names instead of the raw codes 1/2/3.
        if feature_type not in (schema_pb2.BYTES, schema_pb2.INT, schema_pb2.FLOAT):
            continue

        value = row[column_name]
        is_missing = pd.isna(value)
        if feature_type == schema_pb2.BYTES:
            if is_missing:
                encoded = b''
            elif isinstance(value, bytes):
                encoded = value
            else:
                # str() keeps plain strings unchanged and avoids an
                # AttributeError when pandas parsed the cell as a number.
                encoded = str(value).encode()
            features[column_name] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded]))
        elif feature_type == schema_pb2.INT:
            int_value = 0 if is_missing else int(value)
            features[column_name] = tf.train.Feature(int64_list=tf.train.Int64List(value=[int_value]))
        else:  # schema_pb2.FLOAT
            float_value = 0.0 if is_missing else float(value)
            features[column_name] = tf.train.Feature(float_list=tf.train.FloatList(value=[float_value]))

    example_proto = tf.train.Example(features=tf.train.Features(feature=features))
    return example_proto.SerializeToString()

def convert_from_csv_to_tfrecord(csv_path:str,tfrecord_path:str,
                                 schema_or_schemapath:Optional[Union[str,schema_pb2.Schema]]=None,
                                 infer_schema_bool:bool=False):
    """Convert a CSV data file (with a header row) into a TFRecord file.

    Args:
        csv_path: Path of the CSV file to read.
        tfrecord_path: Path of the TFRecord file to write.
        schema_or_schemapath: Schema instance, or path to the schema text file.
        infer_schema_bool: Whether to infer the schema from the CSV data right
            before converting. If True, `schema_or_schemapath` is ignored.

    Returns:
        'OK' on success.

    Raises:
        ValueError: If no schema is given and schema inference is disabled.
    """
    if not infer_schema_bool and schema_or_schemapath is None:
        raise ValueError('infer_schema_bool or schema_or_schemapath is error!!!')

    # Read the file the caller asked for (previously a hard-coded path was
    # read and `csv_path` was ignored).
    data = pd.read_csv(csv_path)

    if infer_schema_bool:
        # The data is already loaded as a DataFrame, so use the DataFrame
        # statistics API; generate_statistics_from_csv expects a file path.
        statistics = tfdv.generate_statistics_from_dataframe(data)
        schema = tfdv.infer_schema(statistics)
    elif isinstance(schema_or_schemapath, str):
        # Parse the schema file once here rather than once per row.
        schema = tfdv.load_schema_text(schema_or_schemapath)
    else:
        schema = schema_or_schemapath

    column_names = data.columns
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for _, row in data.iterrows():
            writer.write(
                create_example_by_schema_from_dataframe(row, column_names, schema))
    return 'OK'

import tensorflow_data_validation as tfdv
# Reload the schema produced by SchemaGen from its schema.pbtxt file.
reload_schema=tfdv.load_schema_text(schema_gen.outputs['schema'].get()[0].uri+'/schema.pbtxt')
data = pd.read_csv('tfx-data/data.csv')
# Columns to put into the example: every feature group plus the label.
ColumnNames = taxi_constants.NUMERICAL_FEATURES + taxi_constants.BUCKET_FEATURES + \
    taxi_constants.CATEGORICAL_NUMERICAL_FEATURES + taxi_constants.CATEGORICAL_STRING_FEATURES + \
    [taxi_constants.LABEL_KEY]
# Serialize row 207 of the CSV and run it through the serving signature.
example=create_example_by_schema_from_dataframe(data.iloc[207],ColumnNames,reload_schema)
inference_fn(tf.constant([example]))['outputs'].numpy() # the input does not need to contain the label, because this is inference
array([[11.537613]], dtype=float32)
transform_fn(tf.constant([example])) # the input must contain the label, because this is preprocessing (used for training)
{'dropoff_census_tract_xf': <tf.Tensor: shape=(1, 216), dtype=float32, numpy=
 array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 'company_xf': <tf.Tensor: shape=(1, 55), dtype=float32, numpy=
 array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 'dropoff_longitude_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([3.], dtype=float32)>,
 'dropoff_latitude_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([8.], dtype=float32)>,
 'trip_miles_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.1588674], dtype=float32)>,
 'trip_start_day_xf': <tf.Tensor: shape=(1, 17), dtype=float32, numpy=
 array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]], dtype=float32)>,
 'payment_type_xf': <tf.Tensor: shape=(1, 16), dtype=float32, numpy=
 array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       dtype=float32)>,
 'pickup_longitude_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([3.], dtype=float32)>,
 'tips': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>,
 'trip_seconds_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.54713595], dtype=float32)>,
 'pickup_census_tract_xf': <tf.Tensor: shape=(1, 11), dtype=float32, numpy=array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 'pickup_latitude_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([8.], dtype=float32)>,
 'trip_start_hour_xf': <tf.Tensor: shape=(1, 34), dtype=float32, numpy=
 array([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]], dtype=float32)>,
 'dropoff_community_area_xf': <tf.Tensor: shape=(1, 79), dtype=float32, numpy=
 array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       dtype=float32)>,
 'pickup_community_area_xf': <tf.Tensor: shape=(1, 66), dtype=float32, numpy=
 array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]], dtype=float32)>,
 'trip_start_month_xf': <tf.Tensor: shape=(1, 22), dtype=float32, numpy=
 array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 'fare_xf': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.60789293], dtype=float32)>}
convert_from_csv_to_tfrecord('./tfx-data/data.csv','./tt.tfrecord',schema_gen.outputs['schema'].get()[0].uri+'/schema.pbtxt')
'OK'

Test Train

# Read the TFRecord file written above, batch the serialized examples by 32
# and apply the exported 'transform_features' signature to each batch.
dataset =tf.data.TFRecordDataset('./tt.tfrecord')
dataset = dataset.batch(32).map(transform_fn)
2023-06-27 11:07:38.139610: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'args_0' with dtype string and shape [?]
	 [[{{node args_0}}]]
def make_input_and_label(Batch, label_key='tips'):
    """Split a batch of features into a ``(features, label)`` pair.

    The features stay in dictionary form, while the label is returned on its
    own as a single value (a tensor when this is used with ``Dataset.map``).

    Args:
        Batch: Mapping from feature name to value; must contain ``label_key``.
        label_key: Name of the entry to return as the label. Defaults to
            ``'tips'``.

    Returns:
        A tuple ``(features, label)`` where ``features`` is a new dict with
        every entry of ``Batch`` except ``label_key`` and ``label`` is
        ``Batch[label_key]``.

    Raises:
        KeyError: If ``label_key`` is not present in ``Batch``.
    """
    label = Batch[label_key]
    # Build a new dict instead of pop()-ing so the caller's mapping is left
    # untouched and can safely be reused.
    features = {name: value for name, value in Batch.items()
                if name != label_key}
    return features, label
# Turn every batch into a (features, label) pair, then print the first
# batch as a quick sanity check.
dataset = dataset.map(make_input_and_label)
for first_batch in dataset.take(1):
    print(first_batch)
2023-06-27 11:07:38.314296: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_15' with dtype float and shape [1,9]
	 [[{{node Placeholder/_15}}]]


({'dropoff_census_tract_xf': <tf.Tensor: shape=(32, 216), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, 'company_xf': <tf.Tensor: shape=(32, 55), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, 'dropoff_longitude_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9.,
       9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 1., 1., 1., 1.],
      dtype=float32)>, 'dropoff_latitude_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 7., 7., 7., 7.],
      dtype=float32)>, 'trip_miles_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([-0.1588674 , -0.1588674 ,  0.53216076, -0.1588674 , -0.1588674 ,
        0.21955279,  0.68572253,  0.6418478 ,  0.15977335, -0.1588674 ,
       -0.1588674 , -0.0491804 ,  0.13180317, -0.00475717, -0.13693   ,
        0.18664669, -0.0694725 , -0.04369606, -0.1588674 ,  0.08847681,
       -0.1588674 , -0.00969308, -0.0897646 , -0.07002094,  0.02101928,
        1.0970489 , -0.1588674 ,  0.4169894 ,  0.03308485, -0.1462534 ,
       -0.09853955, -0.06563345], dtype=float32)>, 'trip_start_day_xf': <tf.Tensor: shape=(32, 17), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.]], dtype=float32)>, 'payment_type_xf': <tf.Tensor: shape=(32, 16), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)>, 'pickup_longitude_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([9., 9., 3., 0., 0., 0., 0., 0., 0., 8., 0., 0., 0., 8., 0., 0., 4.,
       6., 1., 2., 8., 0., 4., 3., 8., 9., 9., 9., 2., 1., 1., 1.],
      dtype=float32)>, 'trip_seconds_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([-7.1184874e-01, -4.3732741e-01,  5.5094934e-01, -5.4713595e-01,
        2.7642804e-01,  1.9067146e-03,  3.8623658e-01,  1.6490347e+00,
        2.7642804e-01, -7.1184874e-01,  2.4725988e+00, -3.2751888e-01,
        2.2152378e-01,  5.6810979e-02, -6.0204017e-01,  1.2098006e+00,
        5.6810979e-02, -4.9223167e-01, -7.1184874e-01, -1.6280608e-01,
       -7.1184874e-01, -2.1771035e-01, -4.3732741e-01, -2.1771035e-01,
        5.6810979e-02,  2.1980774e+00,  1.7039390e+00,  3.8623658e-01,
        1.1171524e-01, -6.0204017e-01, -3.8242313e-01, -2.7261460e-01],
      dtype=float32)>, 'pickup_census_tract_xf': <tf.Tensor: shape=(32, 11), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>, 'pickup_latitude_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([0., 0., 0., 9., 9., 9., 8., 0., 9., 0., 7., 3., 9., 0., 3., 7., 7.,
       0., 9., 8., 0., 9., 7., 6., 0., 0., 0., 0., 8., 6., 7., 7.],
      dtype=float32)>, 'trip_start_hour_xf': <tf.Tensor: shape=(32, 34), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, 'dropoff_community_area_xf': <tf.Tensor: shape=(32, 79), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, 'pickup_community_area_xf': <tf.Tensor: shape=(32, 66), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, 'trip_start_month_xf': <tf.Tensor: shape=(32, 22), dtype=float32, numpy=
array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.]], dtype=float32)>, 'fare_xf': <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([ 0.0610606 , -0.9546066 ,  1.2521242 , -0.4773654 ,  0.40369532,
        0.38737947,  1.6600226 ,  2.1821327 ,  0.24053593, -0.6894726 ,
        2.932666  , -0.3957857 ,  0.22422   , -0.15104656, -0.70578855,
        0.74633014, -0.29789004, -0.3957857 , -0.9546066 ,  0.04474468,
       -0.6894726 , -0.23262626, -0.4773654 , -0.36315382, -0.08578287,
        2.867402  ,  1.855814  ,  0.9910692 ,  0.04474468, -0.6405248 ,
       -0.46104944, -0.36315382], dtype=float32)>}, <tf.Tensor: shape=(32,), dtype=int64, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 1])>)
# Count the elements of the dataset. It has already been batched, so this is
# the number of batches rather than the number of individual examples.
num_batches = dataset.reduce(0, lambda count, _: count + 1)
num_batches.numpy()
2023-06-27 11:07:38.782656: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_9' with dtype float
	 [[{{node Placeholder/_9}}]]





469
# Use the first 100 batches for validation and train on everything after them.
eval_dataset = dataset.take(100)
train_dataset = dataset.skip(100)

# from_logits=True tells the loss that predictions are raw logits --
# NOTE(review): confirm this matches the model's output layer.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
metrics = [tf.keras.metrics.BinaryAccuracy()]

loaded_model.compile(loss=loss_fn, optimizer=optimizer, metrics=metrics)
loaded_model.fit(
    train_dataset,
    epochs=10,
    validation_data=eval_dataset,
)
2023-06-27 11:07:39.767286: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_31' with dtype int64
	 [[{{node Placeholder/_31}}]]


Epoch 1/10
    355/Unknown - 2s 3ms/step - loss: 0.1946 - binary_accuracy: 0.9396

2023-06-27 11:07:42.644332: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_47' with dtype int64
	 [[{{node Placeholder/_47}}]]


369/369 [==============================] - 3s 6ms/step - loss: 0.1922 - binary_accuracy: 0.9401 - val_loss: 0.1629 - val_binary_accuracy: 0.9459
Epoch 2/10
369/369 [==============================] - 2s 4ms/step - loss: 0.1210 - binary_accuracy: 0.9553 - val_loss: 0.1560 - val_binary_accuracy: 0.9444
Epoch 3/10
369/369 [==============================] - 2s 4ms/step - loss: 0.1022 - binary_accuracy: 0.9611 - val_loss: 0.1599 - val_binary_accuracy: 0.9403
Epoch 4/10
369/369 [==============================] - 2s 4ms/step - loss: 0.0879 - binary_accuracy: 0.9659 - val_loss: 0.1878 - val_binary_accuracy: 0.9303
Epoch 5/10
369/369 [==============================] - 2s 5ms/step - loss: 0.0770 - binary_accuracy: 0.9712 - val_loss: 0.2153 - val_binary_accuracy: 0.9241
Epoch 6/10
369/369 [==============================] - 2s 5ms/step - loss: 0.0761 - binary_accuracy: 0.9709 - val_loss: 0.2313 - val_binary_accuracy: 0.9141
Epoch 7/10
369/369 [==============================] - 2s 4ms/step - loss: 0.0772 - binary_accuracy: 0.9692 - val_loss: 0.2275 - val_binary_accuracy: 0.9287
Epoch 8/10
369/369 [==============================] - 2s 4ms/step - loss: 0.0667 - binary_accuracy: 0.9742 - val_loss: 0.2536 - val_binary_accuracy: 0.9228
Epoch 9/10
369/369 [==============================] - 2s 4ms/step - loss: 0.0495 - binary_accuracy: 0.9818 - val_loss: 0.2533 - val_binary_accuracy: 0.9278
Epoch 10/10
369/369 [==============================] - 2s 4ms/step - loss: 0.0428 - binary_accuracy: 0.9848 - val_loss: 0.3007 - val_binary_accuracy: 0.9197





<keras.callbacks.History at 0x7f02980733a0>

Test TFMA

# from google.protobuf import text_format
# import tensorflow_model_analysis as tfma
# import os
# keras_eval_config = text_format.Parse("""
#     model_specs {
#         signature_name: "serving_default"
#         label_key: "tips"
#         preprocessing_function_names: "transform_features"
#     }
#     metrics_specs {
#         metrics {
#             class_name: "ExampleCount"
#         }
#         metrics {
#             class_name: "Calibration"
#         }
#         metrics {
#             class_name: "CalibrationPlot"
#         }
#         metrics {
#             class_name: "ConfusionMatrixPlot"
#         }
#         metrics {
#             class_name: "FairnessIndicators"
#             config: '{"thresholds":[0.1, 0.3, 0.5, 0.7, 0.9]}'
#         }
#         metrics {
#             class_name: "AUC"
#             threshold {
#                 value_threshold {
#                     lower_bound {
#                         value: 0.5
#                     }
#                 }
#             }
#         }
#     }
#     slicing_specs {}
#     slicing_specs {
#         feature_keys: ["trip_start_hour"]
#     }
#     slicing_specs {
#         feature_keys: ["payment_type"]
#     }
#     options {
#         compute_confidence_intervals { value: False }
#         disabled_outputs { values: "analysis" }
#     }
# """,tfma.EvalConfig())

# _serving_model_dir = os.path.join(
#     '.', 'serving_model/taxi_simple')
# model_dirs = [item for item in os.scandir(_serving_model_dir) if item.is_dir()]
# model_path = max(model_dirs,key=lambda i:int(i.name)).path

# keras_eval_shared_model = tfma.default_eval_shared_model(
#     eval_saved_model_path=model_path,
#     eval_config=keras_eval_config
# )

# keras_output_path = os.path.join('./evalresult','keras')
# keras_eval_result = tfma.run_model_analysis(
#     eval_shared_model=keras_eval_shared_model,
#     eval_config=keras_eval_config,
#     data_location='./tt.tfrecord',
#     output_path=keras_output_path
# )




# Build the same EvalConfig with the tfma Python classes instead of parsing a text-format proto
from google.protobuf.wrappers_pb2 import BoolValue

# The evaluation only passes validation if AUC is at least 0.5.
auc_threshold = tfma.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.5}))

# Metrics listed here are computed in addition to those saved with the model
# (assuming a keras model or EvalSavedModel is used); anything passed to
# model.compile(..., metrics=[...]) is picked up automatically. Validation
# thresholds for metrics saved with the model go into the thresholds map,
# keyed by metric name.
custom_metrics_spec = tfma.MetricsSpec(
    metrics=[
        tfma.MetricConfig(class_name='ExampleCount'),
        tfma.MetricConfig(class_name='Calibration'),
        tfma.MetricConfig(class_name='CalibrationPlot'),
        tfma.MetricConfig(class_name='ConfusionMatrixPlot'),
        tfma.MetricConfig(
            class_name='FairnessIndicators',
            config='{ "thresholds": [0.1, 0.3, 0.5, 0.7, 0.9] }',
        ),
        tfma.MetricConfig(class_name='AUC', threshold=auc_threshold),
    ],
)

keras_eval_config = tfma.EvalConfig(
    # Assumes a serving model exposing the 'serving_default' signature. For an
    # estimator based EvalSavedModel, use signature_name 'eval' and drop the
    # label_key.
    model_specs=[
        tfma.ModelSpec(
            signature_name='serving_default',
            label_key="tips",
            preprocessing_function_names=['transform_features'],
        ),
    ],
    # Custom metrics first, followed by TFMA's default binary-classification
    # metric specs.
    metrics_specs=(
        [custom_metrics_spec]
        + tfma.metrics.default_binary_classification_specs()
    ),
    slicing_specs=[
        # An empty spec is the overall slice, i.e. the whole dataset.
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['trip_start_hour']),
        tfma.SlicingSpec(feature_keys=['payment_type']),
    ],
    options=tfma.Options(
        compute_confidence_intervals=BoolValue(value=False)),
)
# Same lookup as in the commented-out variant above: pick the export whose
# directory name is the largest integer (sub-directories are expected to be
# named by a numeric version).
_serving_model_dir = os.path.join(
    '.', 'serving_model/taxi_simple')
# scandir() holds an OS directory handle; the with-block releases it promptly.
with os.scandir(_serving_model_dir) as entries:
    # Keep only purely numeric directories so that int() below cannot fail on
    # stray entries (temp dirs, hidden folders, ...).
    model_dirs = [item for item in entries
                  if item.is_dir() and item.name.isdecimal()]
if not model_dirs:
    raise ValueError(
        f'no numbered model directories found under {_serving_model_dir!r}')
model_path = max(model_dirs, key=lambda item: int(item.name)).path

# Evaluation output is written under ./evalresult/keras.
keras_output_path = os.path.join('./evalresult', 'keras')

# Wrap the selected SavedModel so TFMA can evaluate it under the config.
keras_eval_shared_model = tfma.default_eval_shared_model(
    eval_config=keras_eval_config,
    eval_saved_model_path=model_path,
)

# Evaluate the model against the TFRecord file produced earlier.
keras_eval_result = tfma.run_model_analysis(
    data_location='./tt.tfrecord',
    eval_config=keras_eval_config,
    eval_shared_model=keras_eval_shared_model,
    output_path=keras_output_path,
)
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f026f1bfb20> and <keras.engine.input_layer.InputLayer object at 0x7f02000a0f70>).


WARNING:absl:Tensorflow version (2.12.0) found. Note that TFMA support for TF 2.0 is currently in beta
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "serving_default"
  label_key: "tips"
  preprocessing_function_names: "transform_features"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
slicing_specs {
  feature_keys: "payment_type"
}
...省略 
# Echo the EvalConfig so the notebook renders it as a text-format proto.
keras_eval_config
model_specs {
  signature_name: "serving_default"
  label_key: "tips"
  preprocessing_function_names: "transform_features"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
slicing_specs {
  feature_keys: "payment_type"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  metrics {
    class_name: "Calibration"
  }
  metrics {
    class_name: "CalibrationPlot"
  }
  metrics {
    class_name: "ConfusionMatrixPlot"
  }
  metrics {
    class_name: "FairnessIndicators"
    config: "{ \"thresholds\": [0.1, 0.3, 0.5, 0.7, 0.9] }"
  }
  metrics {
    class_name: "AUC"
    threshold {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
    config: "{\"name\": \"example_count\"}"
  }
  example_weights {
    unweighted: true
  }
}
metrics_specs {
  metrics {
    class_name: "WeightedExampleCount"
    config: "{\"name\": \"weighted_example_count\"}"
  }
  example_weights {
    weighted: true
  }
}
metrics_specs {
  metrics {
    class_name: "BinaryAccuracy"
    config: "{\"name\": \"binary_accuracy\"}"
  }
  metrics {
    class_name: "AUC"
    config: "{\"curve\": \"ROC\", \"name\": \"auc\", \"num_thresholds\": 10000, \"summation_method\": \"interpolation\"}"
  }
  metrics {
    class_name: "AUC"
    config: "{\"curve\": \"PR\", \"name\": \"auc_precison_recall\", \"num_thresholds\": 10000, \"summation_method\": \"interpolation\"}"
  }
  metrics {
    class_name: "Precision"
    config: "{\"name\": \"precision\"}"
  }
  metrics {
    class_name: "Recall"
    config: "{\"name\": \"recall\"}"
  }
  metrics {
    class_name: "MeanLabel"
    config: "{\"name\": \"mean_label\"}"
  }
  metrics {
    class_name: "MeanPrediction"
    config: "{\"name\": \"mean_prediction\"}"
  }
  metrics {
    class_name: "Calibration"
    config: "{\"name\": \"calibration\"}"
  }
  metrics {
    class_name: "ConfusionMatrixPlot"
    config: "{\"name\": \"confusion_matrix_plot\", \"num_thresholds\": 1000}"
  }
  metrics {
    class_name: "CalibrationPlot"
    config: "{\"left\": null, \"name\": \"calibration_plot\", \"num_buckets\": 1000, \"right\": null}"
  }
  metrics {
    class_name: "BinaryCrossentropy"
    config: "{\"dtype\": \"float32\", \"from_logits\": false, \"label_smoothing\": 0, \"name\": \"loss\"}"
  }
}
options {
  compute_confidence_intervals {
  }
}
# Echo the EvalResult returned by tfma.run_model_analysis.
keras_eval_result
EvalResult(slicing_metrics=[((), {'': {'': {'binary_accuracy': {'doubleValue': 0.9538061591787762}, 'loss': {'doubleValue': 0.6815148591995239}, 'example_count': {'doubleValue': 15002.0}, 'weighted_example_count': {'doubleValue': 15002.0}, 'calibration': {'doubleValue': -141.13269912566062}, 'fairness_indicators_metrics/false_positive_rate@0.1': {'doubleValue': 0.03674698795180723}, 'fairness_indicators_metrics/false_negative_rate@0.1': {'doubleValue': 0.07894736842105263}, 'fairness_indicators_metrics/true_positive_rate@0.1': {'doubleValue': 0.9210526315789473}, 'fairness_indicators_metrics/true_negative_rate@0.1': {'doubleValue': 0.9632530120481928}, 'fairness_indicators_metrics/positive_rate@0.1': {'doubleValue': 0.23610185308625517}, 'fairness_indicators_metrics/negative_rate@0.1': {'doubleValue': 0.7638981469137448}, 'fairness_indicators_metrics/false_discovery_rate@0.1': {'doubleValue': 0.12055335968379446}, 'fairness_indicators_metrics/false_omission_rate@0.1': {'doubleValue': 0.02329842931937173}, 'fairness_indicators_metrics/precision@0.1': {'doubleValue': 0.8794466403162056}, 'fairness_indicators_metrics/recall@0.1': {'doubleValue': 0.9210526315789473}, 'fairness_indicators_metrics/false_positive_rate@0.3': {'doubleValue': 0.03571428571428571}, 'fairness_indicators_metrics/false_negative_rate@0.3': {'doubleValue': 0.08042578356002365}, 'fairness_indicators_metrics/true_positive_rate@0.3': {'doubleValue': 0.9195742164399764}, 'fairness_indicators_metrics/true_negative_rate@0.3': {'doubleValue': 0.9642857142857143}, 'fairness_indicators_metrics/positive_rate@0.3': {'doubleValue': 0.23496867084388748}, 'fairness_indicators_metrics/negative_rate@0.3': ......省略
# Load the metrics written to keras_output_path and convert them into pandas
# DataFrames; then show the double-valued metrics.
loaded_metrics = tfma.load_metrics(keras_output_path)
dfs = tfma.experimental.dataframe.metrics_as_dataframes(loaded_metrics)
dfs.double_value
slicesmetric_keysmetric_values
Overalltrip_start_hourpayment_typenamemodel_nameoutput_nameis_diffexample_weighteddouble_value
0NaNNaNbinary_accuracyFalseNaN0.953806
1NaNNaNlossFalseNaN0.358257
2NaNNaNlossFalseFalse0.681515
3NaNNaNexample_countFalseFalse15002.000000
4NaNNaNweighted_example_countFalseTrue15002.000000
..............................
2011NaNNaNb'Pcard'auc_precison_recallFalseFalse0.000000
2012NaNNaNb'Pcard'precisionFalseFalseNaN
2013NaNNaNb'Pcard'recallFalseFalseNaN
2014NaNNaNb'Pcard'mean_labelFalseFalse0.000000
2015NaNNaNb'Pcard'mean_predictionFalseFalse-23.387853

2016 rows × 9 columns

#透视表:以切片为行,以指标为列
tfma.experimental.dataframe.auto_pivot(dfs.double_value)
(metric_keys, name)binary_accuracyfairness_indicators_metrics/precision@0.5fairness_indicators_metrics/recall@0.5fairness_indicators_metrics/false_positive_rate@0.7fairness_indicators_metrics/false_negative_rate@0.7fairness_indicators_metrics/true_positive_rate@0.7fairness_indicators_metrics/true_negative_rate@0.7fairness_indicators_metrics/positive_rate@0.7fairness_indicators_metrics/negative_rate@0.7fairness_indicators_metrics/false_discovery_rate@0.7...fairness_indicators_metrics/positive_rate@0.3fairness_indicators_metrics/negative_rate@0.3fairness_indicators_metrics/false_discovery_rate@0.3fairness_indicators_metrics/false_omission_rate@0.3fairness_indicators_metrics/precision@0.3fairness_indicators_metrics/false_positive_rate@0.3fairness_indicators_metrics/false_positive_rate@0.5fairness_indicators_metrics/false_negative_rate@0.5fairness_indicators_metrics/true_positive_rate@0.5fairness_indicators_metrics/true_negative_rate@0.5
(metric_keys, example_weighted)NaNFalseFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
slices
Overall:0.9538060.8827210.9169130.0350260.0851570.9148430.9649740.2333690.7666310.116252...0.2349690.7650310.1177300.0237000.8822700.0357140.0354560.0830870.9169130.964544
payment_type:b'Cash'0.9996971.0000000.2500000.0000000.7500000.2500001.0000000.0001010.9998990.000000...0.0001010.9998990.0000000.0003031.0000000.0000000.0000000.7500000.2500001.000000
payment_type:b'Credit Card'0.8625130.8833760.9184700.2528090.0836050.9163950.7471910.7027140.2972860.115847...0.7073370.2926630.1170790.1826920.8829210.2571790.2553060.0815300.9184700.744694
payment_type:b'Dispute'1.000000NaNNaN0.000000NaNNaN1.0000000.0000001.000000NaN...0.0000001.000000NaN0.000000NaN0.0000000.000000NaNNaN1.000000
payment_type:b'No Charge'0.9876540.000000NaN0.000000NaNNaN1.0000000.0000001.000000NaN...0.0123460.9876541.0000000.0000000.0000000.0123460.012346NaNNaN0.987654
payment_type:b'Pcard'1.000000NaNNaN0.000000NaNNaN1.0000000.0000001.000000NaN...0.0000001.000000NaN0.000000NaN0.0000000.000000NaNNaN1.000000
payment_type:b'Prcard'0.0000000.000000NaN1.000000NaNNaN0.0000001.0000000.0000001.000000...1.0000000.0000001.000000NaN0.0000001.0000001.000000NaNNaN0.000000
payment_type:b'Unknown'0.8620690.6666670.4000000.0416670.6000000.4000000.9583330.1034480.8965520.333333...0.1034480.8965520.3333330.1153850.6666670.0416670.0416670.6000000.4000000.958333
trip_start_hour:0.00.9684390.9333330.9264710.0193130.0735290.9264710.9806870.2242520.7757480.066667...0.2242520.7757480.0666670.0214130.9333330.0193130.0193130.0735290.9264710.980687
trip_start_hour:1.00.9466910.8832120.9029850.0390240.1044780.8955220.9609760.2500000.7500000.117647...0.2518380.7481620.1167880.0319410.8832120.0390240.0390240.0970150.9029850.960976
trip_start_hour:10.00.9708880.9289940.9515150.0229890.0484850.9515150.9770110.2459970.7540030.071006...0.2474530.7525470.0764710.0154740.9235290.0249040.0229890.0484850.9515150.977011
trip_start_hour:11.00.9515350.8692810.9300700.0420170.0699300.9300700.9579830.2471730.7528270.130719...0.2504040.7495960.1290320.0172410.8709680.0420170.0420170.0699300.9300700.957983
trip_start_hour:12.00.9583890.8959540.9226190.0311960.0773810.9226190.9688040.2322150.7677850.104046...0.2322150.7677850.1040460.0227270.8959540.0311960.0311960.0773810.9226190.968804
trip_start_hour:13.00.9484240.8750000.8974360.0369000.1025640.8974360.9631000.2292260.7707740.125000...0.2292260.7707740.1250000.0297400.8750000.0369000.0369000.1025640.8974360.963100
trip_start_hour:14.00.9429350.8324320.9333330.0507880.0666670.9333330.9492120.2486410.7513590.158470...0.2513590.7486410.1675680.0199640.8324320.0542910.0542910.0666670.9333330.945709
trip_start_hour:15.00.9570640.8851350.9034480.0294630.0965520.9034480.9705370.2049860.7950140.114865...0.2063710.7936290.1208050.0244330.8791950.0311960.0294630.0965520.9034480.970537
trip_start_hour:16.00.9370900.8601040.8877010.0468750.1176470.8823530.9531250.2516380.7483620.140625...0.2529490.7470510.1398960.0368420.8601040.0468750.0468750.1122990.8877010.953125
trip_start_hour:17.00.9481300.8813560.8764040.0322580.1235960.8764040.9677420.2135100.7864900.118644...0.2135100.7864900.1186440.0337420.8813560.0322580.0322580.1235960.8764040.967742
trip_start_hour:18.00.9553670.9025420.9181030.0310300.0818970.9181030.9689700.2497340.7502660.093617...0.2507970.7492030.0974580.0269500.9025420.0324400.0324400.0818970.9181030.967560
trip_start_hour:19.00.9603960.8953490.9467210.0352480.0532790.9467210.9647520.2554460.7445540.104651...0.2564360.7435640.1042470.0159790.8957530.0352480.0352480.0532790.9467210.964752
trip_start_hour:2.00.9594270.8627450.9670330.0396340.0329670.9670330.9603660.2410500.7589500.128713...0.2434370.7565630.1372550.0094640.8627450.0426830.0426830.0329670.9670330.957317
trip_start_hour:20.00.9637680.8969960.9500000.0321720.0500000.9500000.9678280.2412010.7587990.103004...0.2422360.7577640.1025640.0136610.8974360.0321720.0321720.0500000.9500000.967828
trip_start_hour:21.00.9606210.8928570.9358290.0322580.0641710.9358290.9677420.2338900.7661100.107143...0.2350840.7649160.1065990.0171610.8934010.0322580.0322580.0641710.9358290.967742
trip_start_hour:22.00.9546000.9030610.9030610.0296410.0969390.9030610.9703590.2341700.7658300.096939...0.2365590.7634410.0959600.0266040.9040400.0296410.0296410.0969390.9030610.970359
trip_start_hour:23.00.9503450.8918920.8684210.0279230.1513160.8486840.9720770.2000000.8000000.110345...0.2041380.7958620.1081080.0346620.8918920.0279230.0279230.1315790.8684210.972077
trip_start_hour:3.00.9417810.8382350.9047620.0480350.0952380.9047620.9519650.2328770.7671230.161765...0.2397260.7602740.1714290.0225230.8285710.0524020.0480350.0952380.9047620.951965
trip_start_hour:4.00.9593910.8421050.9411760.0368100.0588240.9411760.9631900.1928930.8071070.157895...0.1928930.8071070.1578950.0125790.8421050.0368100.0368100.0588240.9411760.963190
trip_start_hour:5.00.9300700.8421050.8888890.0560750.1111110.8888890.9439250.2657340.7342660.157895...0.2657340.7342660.1578950.0380950.8421050.0560750.0560750.1111110.8888890.943925
trip_start_hour:6.00.9481870.8974360.8536590.0263160.1463410.8536590.9736840.2020730.7979270.102564...0.2020730.7979270.1025640.0389610.8974360.0263160.0263160.1463410.8536590.973684
trip_start_hour:7.00.9569540.8490570.9000000.0317460.1000000.9000000.9682540.1754970.8245030.150943...0.1754970.8245030.1509430.0200800.8490570.0317460.0317460.1000000.9000000.968254
trip_start_hour:8.00.9490570.8536590.9210530.0432690.0789470.9210530.9567310.2320750.7679250.146341...0.2320750.7679250.1463410.0221130.8536590.0432690.0432690.0789470.9210530.956731
trip_start_hour:9.00.9427710.8451610.9034480.0443160.1103450.8896550.9556840.2289160.7710840.151316...0.2349400.7650600.1538460.0255910.8461540.0462430.0462430.0965520.9034480.953757

32 rows × 63 columns

#filter slices
df_double = dfs.double_value
df_filtered = df_double.loc[df_double.slices.trip_start_hour.isin([1,3,5,7])]
tfma.experimental.dataframe.auto_pivot(df_filtered)
(metric_keys, name)binary_accuracyfairness_indicators_metrics/precision@0.5fairness_indicators_metrics/recall@0.5fairness_indicators_metrics/false_positive_rate@0.7fairness_indicators_metrics/false_negative_rate@0.7fairness_indicators_metrics/true_positive_rate@0.7fairness_indicators_metrics/true_negative_rate@0.7fairness_indicators_metrics/positive_rate@0.7fairness_indicators_metrics/negative_rate@0.7fairness_indicators_metrics/false_discovery_rate@0.7...fairness_indicators_metrics/false_positive_rate@0.5fairness_indicators_metrics/precision@0.3fairness_indicators_metrics/recall@0.3fairness_indicators_metrics/false_discovery_rate@0.3fairness_indicators_metrics/negative_rate@0.3fairness_indicators_metrics/positive_rate@0.3fairness_indicators_metrics/true_negative_rate@0.3fairness_indicators_metrics/true_positive_rate@0.3fairness_indicators_metrics/false_negative_rate@0.3fairness_indicators_metrics/false_omission_rate@0.3
(metric_keys, example_weighted)NaNFalseFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
slices
trip_start_hour:1.00.9466910.8832120.9029850.0390240.1044780.8955220.9609760.2500000.7500000.117647...0.0390240.8832120.9029850.1167880.7481620.2518380.9609760.9029850.0970150.031941
trip_start_hour:3.00.9417810.8382350.9047620.0480350.0952380.9047620.9519650.2328770.7671230.161765...0.0480350.8285710.9206350.1714290.7602740.2397260.9475980.9206350.0793650.022523
trip_start_hour:5.00.9300700.8421050.8888890.0560750.1111110.8888890.9439250.2657340.7342660.157895...0.0560750.8421050.8888890.1578950.7342660.2657340.9439250.8888890.1111110.038095
trip_start_hour:7.00.9569540.8490570.9000000.0317460.1000000.9000000.9682540.1754970.8245030.150943...0.0317460.8490570.9000000.1509430.8245030.1754970.9682540.9000000.1000000.020080

4 rows × 63 columns

#sort metric values
tfma.experimental.dataframe.auto_pivot(df_filtered).sort_values(by=('auc',False),ascending=True)
(metric_keys, name)binary_accuracyfairness_indicators_metrics/precision@0.5fairness_indicators_metrics/recall@0.5fairness_indicators_metrics/false_positive_rate@0.7fairness_indicators_metrics/false_negative_rate@0.7fairness_indicators_metrics/true_positive_rate@0.7fairness_indicators_metrics/true_negative_rate@0.7fairness_indicators_metrics/positive_rate@0.7fairness_indicators_metrics/negative_rate@0.7fairness_indicators_metrics/false_discovery_rate@0.7...fairness_indicators_metrics/false_positive_rate@0.5fairness_indicators_metrics/precision@0.3fairness_indicators_metrics/recall@0.3fairness_indicators_metrics/false_discovery_rate@0.3fairness_indicators_metrics/negative_rate@0.3fairness_indicators_metrics/positive_rate@0.3fairness_indicators_metrics/true_negative_rate@0.3fairness_indicators_metrics/true_positive_rate@0.3fairness_indicators_metrics/false_negative_rate@0.3fairness_indicators_metrics/false_omission_rate@0.3
(metric_keys, example_weighted)NaNFalseFalseFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
slices
trip_start_hour:5.00.9300700.8421050.8888890.0560750.1111110.8888890.9439250.2657340.7342660.157895...0.0560750.8421050.8888890.1578950.7342660.2657340.9439250.8888890.1111110.038095
trip_start_hour:1.00.9466910.8832120.9029850.0390240.1044780.8955220.9609760.2500000.7500000.117647...0.0390240.8832120.9029850.1167880.7481620.2518380.9609760.9029850.0970150.031941
trip_start_hour:7.00.9569540.8490570.9000000.0317460.1000000.9000000.9682540.1754970.8245030.150943...0.0317460.8490570.9000000.1509430.8245030.1754970.9682540.9000000.1000000.020080
trip_start_hour:3.00.9417810.8382350.9047620.0480350.0952380.9047620.9519650.2328770.7671230.161765...0.0480350.8285710.9206350.1714290.7602740.2397260.9475980.9206350.0793650.022523

4 rows × 63 columns

#上面的slicing_specs用于配置需要在哪些切片上计算metrics
#如果要显示某个切片上的metrics,需要通过slicing_column指定,否则这里只显示overall
tfma.view.render_slicing_metrics(keras_eval_result,slicing_column='trip_start_hour')

在这里插入图片描述

SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'trip_start_hour:19',…
#Rendering Plots
tfma.view.render_plot(keras_eval_result)

在这里插入图片描述

PlotViewer(config={'sliceName': 'Overall', 'metricKeys': {'calibrationPlot': {'metricName': 'calibrationHistog…
tfma.addons.fairness.view.widget_view.render_fairness_indicator(keras_eval_result)

在这里插入图片描述

FairnessIndicatorViewer(slicingMetrics=[{'sliceValue': 'Overall', 'slice': 'Overall', 'metrics': {'binary_accu…
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: Tensorboard在TensorFlow 2.x中的使用示例代码非常简单,只需要几行代码就可以启动它。下面是一个使用Tensorboard的示例代码:from tensorflow.keras.callbacks import TensorBoardtensorboard_callback = TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True, write_grads=True, write_images=True)model.fit(x, y, epochs=100, callbacks=[tensorboard_callback]) ### 回答2: TensorBoard是一个可视化工具,可以用于分析和监控TensorFlow训练过程中的数据流图和训练性能。下面是一段使用TensorBoard的示例代码: ```python import tensorflow as tf # 创建一个计算图 a = tf.constant(2, name="input_a") b = tf.constant(3, name="input_b") c = tf.multiply(a, b, name="mul_c") d = tf.add(a, b, name="add_d") e = tf.add(c, d, name="add_e") # 创建一个写入TensorBoard日志的文件夹 log_dir = "./logs" # 创建一个TensorBoard的summary writer writer = tf.summary.create_file_writer(log_dir) # 开启一个TensorFlow会话 with tf.compat.v1.Session() as sess: # 初始化变量 sess.run(tf.compat.v1.global_variables_initializer()) # 将计算图写入TensorBoard writer.add_graph(sess.graph) # 运行并计算结果 result = sess.run(e) # 将结果写入TensorBoard with writer.as_default(): tf.summary.scalar("output", result, step=0) # 关闭summary writer writer.close() ``` 上述代码中,我们首先创建了一个计算图,包括一些简单的算术运算。然后,我们指定了一个保存TensorBoard日志文件的文件夹。接着,我们创建了一个TensorBoard的summary writer,用于将计算图和计算结果写入TensorBoard。在TensorFlow会话中,我们使用`writer.add_graph()`将计算图写入TensorBoard,使用`sess.run()`计算结果。最后,我们使用`tf.summary.scalar()`将结果写入TensorBoard。最后,我们需要在合适的地方关闭summary writer。 在运行上述代码后,打开终端并切换到当前目录,输入以下命令启动TensorBoard: ``` tensorboard --logdir=logs/ ``` 然后,在浏览器中输入`http://localhost:6006`即可访问TensorBoard的可视化界面,查看计算图和结果。 ### 回答3: TensorBoard是一个用于可视化和监控机器学习模型训练过程的工具,它可以用于观察模型的图形结构、参数分布、训练过程中的指标等。在TensorFlow 2.x版本中,使用TensorBoard也非常简单。 首先,确保已经安装好TensorFlow 2.x和TensorBoard。 以下是一个使用TensorBoard的示例代码,其中假设我们有一个简单的线性回归模型,并且已经定义好了模型的结构和训练过程: ```python import tensorflow as tf from datetime import datetime # 创建一个日志存储目录 log_dir = "logs/" # 定义一个简单的线性回归模型 model = tf.keras.Sequential([ tf.keras.layers.Dense(1, input_shape=(1,)) ]) # 编译模型 model.compile(optimizer='sgd', loss='mse') # 创建一个TensorBoard回调函数 tensorboard_callback = 
tf.keras.callbacks.TensorBoard(log_dir=log_dir) # 生成输入和标签数据 x_train = [1, 2, 3, 4, 5] y_train = [3, 5, 7, 9, 11] # 训练模型,并将TensorBoard回调传入fit函数中 model.fit(x_train, y_train, epochs=100, callbacks=[tensorboard_callback]) # 启动TensorBoard tensorboard_cmd = f"tensorboard --logdir {log_dir}" print(f"请在终端中运行以下命令启动TensorBoard:\n{tensorboard_cmd}") ``` 在上述示例代码中,我们首先创建一个存储TensorBoard日志的目录`log_dir`,然后定义一个简单的线性回归模型,并编译模型。接下来,我们创建了一个TensorBoard回调函数,并将其传入模型的`fit`函数中,这样在模型训练过程中会自动将相关的日志信息写入指定目录。然后我们生成了一组训练数据,并使用这些数据训练模型。最后,我们通过打印一个命令来提示用户在终端中启动TensorBoard。 要在终端中启动TensorBoard,只需按照提示运行相应的命令即可。TensorBoard会自动打开一个本地服务器,并在浏览器中显示训练过程的可视化结果。 通过使用TensorBoard,我们可以轻松地观察训练过程中的损失变化、权重分布、计算图结构等信息,对模型的训练过程进行优化和调试。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

起名大废废

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值