# Preview the first few images of one training batch with their class names.
# The figure is 15 wide by 10 tall (matplotlib figsize units are inches).
plt.figure(figsize=(15, 10))
for images, labels in train_ds.take(1):
    # Show at most 8 images; guards against a batch smaller than 8.
    for i in range(min(8, len(images))):
        plt.subplot(5, 8, i + 1)
        # NOTE(review): imshow expects uint8 pixels or floats in [0, 1];
        # confirm train_ds preprocessing yields one of those.
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        plt.axis("off")
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
def VGG16(nb_classes, input_shape):
    """Build a randomly initialised VGG-16 classifier with the Keras functional API.

    Architecture: five convolutional blocks with (2, 2, 3, 3, 3) conv layers
    of (64, 128, 256, 512, 512) filters. Every conv is 3x3 with ReLU and
    'same' padding. Each block ends in a 2x2 max-pool with stride 2. The
    head is Flatten, two 4096-unit ReLU Dense layers and a softmax output.
    Layer names follow the ``block{b}_conv{c}``, ``block{b}_pool``, ``fc1``,
    ``fc2`` and ``predictions`` scheme. No pretrained weights are loaded.

    Args:
        nb_classes: number of units in the softmax output layer.
        input_shape: shape of a single input sample (no batch dimension),
            passed directly to ``Input``. Under Keras' default channels-last
            format this is (height, width, channels). Each spatial size is
            halved five times by the pooling layers, so it should be >= 32.

    Returns:
        An uncompiled ``Model`` mapping the input to class probabilities.
    """
    # (filters, number of conv layers) for VGG-16 blocks 1..5.
    block_config = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    input_tensor = Input(shape=input_shape)
    x = input_tensor
    for block, (filters, n_convs) in enumerate(block_config, start=1):
        for conv in range(1, n_convs + 1):
            x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                       name=f'block{block}_conv{conv}')(x)
        x = MaxPooling2D((2, 2), strides=(2, 2), name=f'block{block}_pool')(x)

    # Fully connected classifier head.
    x = Flatten()(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)

    model = Model(input_tensor, output_tensor)
    return model
# Build the network with one softmax unit per dataset class (labels index
# class_names), rather than the hard-coded 1000 ImageNet classes.
# NOTE(review): Keras expects input_shape as (height, width, channels); the
# (img_width, img_height, 3) order below is only correct when the two sizes
# are equal - confirm.
model = VGG16(len(class_names), (img_width, img_height, 3))
model.summary()

# sparse_categorical_crossentropy expects integer class indices as labels.
model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
pip install tqdm
Collecting tqdm
Downloading tqdm-4.64.1-py2.py3-none-any.whl (78 kB)
-------------------------------------- 78.5/78.5 kB 396.7 kB/s eta 0:00:00
Requirement already satisfied: colorama in c:\users\administrator\appdata\local\programs\python\python38\lib\site-packages (from tqdm) (0.4.5)
Installing collected packages: tqdm
Successfully installed tqdm-4.64.1
Note: you may need to restart the kernel to use updated packages.
from tqdm import tqdm
import tensorflow.keras.backend as K
epochs = 10
lr = 1e-4  # learning rate for the first epoch; multiplied by 0.92 after each epoch

# Per-epoch metrics, recorded for later analysis.
history_train_loss = []
history_train_accuracy = []
history_val_loss = []
history_val_accuracy = []

# Batches per pass, used as the progress-bar totals. They do not change
# between epochs, so compute them once. len() on a tf.data.Dataset requires
# known cardinality - TODO confirm for train_ds / val_ds.
train_total = len(train_ds)
val_total = len(val_ds)

for epoch in range(epochs):
    # tqdm arguments:
    #   total:       expected number of iterations
    #   ncols:       width of the progress bar
    #   mininterval: minimum seconds between display refreshes (default 0.1)
    with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=1, ncols=100) as pbar:
        K.set_value(model.optimizer.learning_rate, lr)
        current_lr = K.get_value(model.optimizer.learning_rate)  # constant for this epoch

        # train_on_batch runs one optimisation step on a single batch, giving
        # finer control over the loop than model.fit(). For a detailed
        # write-up see https://www.yuque.com/mingtian-fkmxf/hv4lcq/ztt4gy
        # With reset_metrics=False Keras accumulates loss/accuracy across
        # batches, so the values are running means over the epoch instead of
        # the last batch only; reset_metrics() clears the previous pass.
        model.reset_metrics()
        train_loss = train_accuracy = float("nan")  # stays NaN if train_ds is empty
        for image, label in train_ds:
            history = model.train_on_batch(image, label, reset_metrics=False)
            train_loss = history[0]
            train_accuracy = history[1]
            pbar.set_postfix({"loss": "%.4f" % train_loss,
                              "accuracy": "%.4f" % train_accuracy,
                              "lr": current_lr})
            pbar.update(1)
    history_train_loss.append(train_loss)
    history_train_accuracy.append(train_accuracy)

    print('开始验证!')  # "starting validation"
    with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=0.3, ncols=100) as pbar:
        # Train and test steps share the compiled metric objects, so reset
        # them before accumulating validation statistics.
        model.reset_metrics()
        val_loss = val_accuracy = float("nan")  # stays NaN if val_ds is empty
        for image, label in val_ds:
            history = model.test_on_batch(image, label, reset_metrics=False)
            val_loss = history[0]
            val_accuracy = history[1]
            pbar.set_postfix({"loss": "%.4f" % val_loss,
                              "accuracy": "%.4f" % val_accuracy})
            pbar.update(1)
    history_val_loss.append(val_loss)
    history_val_accuracy.append(val_accuracy)

    print('结束验证!')  # "validation finished"
    print("验证loss为:%.4f" % val_loss)  # epoch-mean validation loss
    print("验证准确率为:%.4f" % val_accuracy)  # epoch-mean validation accuracy

    # Decay after the epoch so the first epoch trains at the configured initial lr.
    lr = lr * 0.92