1.主代码
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
# Seed shared by the NumPy shuffle and TensorFlow's global RNG so runs are reproducible.
SEED = 116
# Number of trailing (post-shuffle) samples held out as the test set.
TEST_SIZE = 30
# Mini-batch size for the training dataset.
BATCH_SIZE = 32

# Load the Iris dataset once and take features and labels from the same object
# (previously load_iris() was called twice, loading the data two times).
iris = datasets.load_iris()
x_data = iris.data
y_data = iris.target

# Shuffle features and labels with ONE shared permutation so every row keeps its label.
# With NumPy's legacy global RNG, permutation(n) performs the same swap sequence that
# shuffle() applies to an n-row array. The resulting order, and the RNG state left
# behind, therefore match the old "reseed, shuffle x, reseed, shuffle y" approach.
np.random.seed(SEED)
shuffle_idx = np.random.permutation(len(x_data))
x_data = x_data[shuffle_idx]
y_data = y_data[shuffle_idx]
tf.random.set_seed(SEED)

# Hold out the last TEST_SIZE shuffled samples for testing; the rest are for training.
# Features are cast to float32 for TensorFlow; labels are left in their loaded dtype.
x_train = tf.cast(x_data[:-TEST_SIZE], tf.float32)
y_train = y_data[:-TEST_SIZE]
x_test = tf.cast(x_data[-TEST_SIZE:], tf.float32)
y_test = y_data[-TEST_SIZE:]

# Pair each feature row with its label and group the pairs into mini-batches.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)
test_db = tf.