import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import keras
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
def get_data(dataset_name=None, channel=1):
    """Load a supported ``tf.keras.datasets`` dataset as train/test splits.

    Supported ``(dataset_name, channel)`` combinations are
    ``('mnist', 1)``, ``('fashion_mnist', 1)`` and ``('cifar10', 3)``.
    For the single-channel datasets a trailing length-1 channel axis is
    appended to both image arrays via ``np.expand_dims(..., -1)``; CIFAR-10
    arrays are returned exactly as loaded.

    Args:
        dataset_name: Name of the dataset to load.
        channel: Expected number of image channels; must match the dataset.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as produced by the
        dataset's ``load_data()``, with the channel axis added where needed.

    Raises:
        ValueError: If ``(dataset_name, channel)`` is not a supported
            combination.
    """
    # Each supported (name, channels) pair -> whether the loaded images
    # need an explicit channel axis appended.
    needs_channel_axis = {
        ('mnist', 1): True,
        ('fashion_mnist', 1): True,
        ('cifar10', 3): False,
    }
    key = (dataset_name, channel)
    # Validate before touching tf so bad arguments fail fast (no download)
    # instead of silently returning None and breaking the caller's unpack.
    if key not in needs_channel_axis:
        supported = ', '.join(
            f'{name!r} (channel={ch})' for name, ch in needs_channel_axis
        )
        raise ValueError(
            f'Unsupported dataset/channel combination: '
            f'dataset_name={dataset_name!r}, channel={channel!r}. '
            f'Supported: {supported}.'
        )

    loader = getattr(tf.keras.datasets, dataset_name)
    (x_train, y_train), (x_test, y_test) = loader.load_data()
    if needs_channel_axis[key]:
        x_train = np.expand_dims(x_train, -1)
        x_test = np.expand_dims(x_test, -1)
    return (x_train, y_train), (x_test, y_test)
# Load MNIST once when this module runs; get_data appends a trailing
# length-1 channel axis to the image arrays. These module-level arrays are
# presumably consumed further down the file.
(x_train_mnist, y_train_mnist), (x_test_mnist, y_test_mnist) = get_data(
    dataset_name='mnist',
    channel=1,
)
def show_imgs(x_train, y_train, col, row, cmap='gray'):
    """Display the first ``col * row`` images in a ``row`` x ``col`` grid.

    Each subplot shows one image from ``x_train`` with its label from
    ``y_train`` as the x-axis label; tick marks are hidden. The figure is
    ``col`` inches wide and ``row`` inches tall. If fewer than ``col * row``
    samples are available, only the available ones are drawn.

    Args:
        x_train: Indexable collection of images (e.g. a NumPy array).
            Images with a trailing length-1 axis are squeezed to 2-D.
        y_train: Labels aligned with ``x_train``. Labels stored as
            length-1 arrays are shown as plain scalars.
        col: Number of grid columns.
        row: Number of grid rows.
        cmap: Colormap applied to single-channel (2-D) images; not used
            for multi-channel images.
    """
    plt.figure(figsize=(col, row))
    # Never index past the available samples/labels.
    count = min(col * row, len(x_train), len(y_train))
    for i in range(count):
        plt.subplot(row, col, i + 1)
        plt.xticks([])
        plt.yticks([])

        label = np.asarray(y_train[i])
        # Show "6" rather than "[6]" when a label carries an extra axis.
        plt.xlabel(label.item() if label.size == 1 else label)

        img = np.asarray(x_train[i])
        if img.ndim == 3 and img.shape[-1] == 1:
            # Drop a trailing length-1 channel axis (as produced by
            # np.expand_dims(..., -1)) so imshow treats it as scalar data.
            img = img[..., 0]
        plt.imshow(img, cmap=cmap if img.ndim == 2 else None)
    plt.tight_layout()
    plt.show()
# Batch size constant; presumably used by the data/augmentation or training
# pipeline further down the file — confirm against its uses.
batch_size = 256
augment_images=keras