```
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 定义路径
train_dir = r'C:\Users\29930\Desktop\结构参数图'
# 数据增强配置
train_datagen = ImageDataGenerator(
rescale=1./255,
validation_split=0.2,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
# 生成训练集和验证集
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(224, 224),
batch_size=32,
class_mode='binary',
subset='training'
)
val_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(224, 224),
batch_size=32,
class_mode='binary',
subset='validation'
)
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(224,224,3)),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# 添加早停法
early_stop = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=5,
restore_best_weights=True
)
# 训练模型
history = model.fit(
train_generator,
validation_data=val_generator,
epochs=30,
callbacks=[early_stop]
)
# 保存模型
model.save('copd_cnn_model.h5')
# 评估指标可视化
import matplotlib.pyplot as plt
plt.plot(history.history['auc'], label='Training AUC')
plt.plot(history.history['val_auc'], label='Validation AUC')
plt.title('模型AUC曲线')
plt.ylabel('AUC值')
plt.xlabel('Epoch')
plt.legend()
plt.show()```运行结果是Found 213 images belonging to 2 classes.
Found 52 images belonging to 2 classes.
Warning (from warnings module):
File "C:\Users\29930\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\layers\convolutional\base_conv.py", line 107
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
Warning (from warnings module):
File "C:\Users\29930\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py", line 121
self._warn_if_super_not_called()
UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
Traceback (most recent call last):
File "D:/建模/cnn.py", line 61, in <module>
history = model.fit(
File "C:\Users\29930\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:\Users\29930\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\utils\image_utils.py", line 227, in load_img
raise ImportError(
ImportError: Could not import PIL.Image. The use of `load_img` requires PIL.
请根据结果修改代码使其能正常运行
最新发布