训练集和测试集 loss 和 acc 的变化情况
# (二分类)预测结果
ROC曲线
邮件数据样例
4 万多条数据
训练过程
部分代码
def build_cnn_model(input_shape):
    """Build and compile a small 1D-convolutional binary classifier.

    The network is one ``Conv1D`` layer (64 filters, kernel size 3, ReLU),
    followed by global max pooling and a single-unit sigmoid output. It is
    compiled with binary cross-entropy loss, the Adam optimizer, and
    accuracy as the reported metric.

    Args:
        input_shape: Shape of one input sample; forwarded unchanged as the
            ``input_shape`` argument of the ``Conv1D`` layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    conv_layer = Conv1D(
        filters=64,
        kernel_size=3,
        activation='relu',
        input_shape=input_shape,
    )
    pooling_layer = GlobalMaxPooling1D()
    # One sigmoid unit: the output is a single probability per sample.
    output_layer = Dense(1, activation='sigmoid')

    classifier = Sequential([conv_layer, pooling_layer, output_layer])
    classifier.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return classifier
def train(*, epochs=30, batch_size=1024, embedding_dim=100):
    """Train the CNN classifier, then plot its curves and evaluate it.

    Loads the data with ``load_train_data``, runs it through
    ``preprocess_data``, builds the model with ``build_cnn_model``, fits it,
    and finally hands the results to ``plot_metrics`` and ``evaluate_model``.

    The hyperparameters used to be hard-coded in the body; they are now
    keyword-only arguments whose defaults reproduce the previous behavior,
    so calling ``train()`` with no arguments is unchanged.

    Args:
        epochs: Number of training epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.
        embedding_dim: Size of the last axis of the model input shape.
            NOTE(review): presumably this must match the per-token feature
            width produced by ``preprocess_data`` -- confirm against that
            function.

    Returns:
        None. The fitted model is only passed on to ``evaluate_model``.
    """
    train_set, test_set, train_label, test_label = load_train_data()
    train_data, test_data = preprocess_data(train_set, test_set)

    # Axis 1 of the preprocessed training data is taken as the sequence length.
    max_sequence_length = train_data.shape[1]
    model = build_cnn_model((max_sequence_length, embedding_dim))

    # The test split doubles as the validation set during fitting, which is
    # why the history later contains the ``val_*`` series that get plotted.
    history = model.fit(
        train_data,
        train_label,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(test_data, test_label),
        verbose=1,
    )

    plot_metrics(history)
    evaluate_model(model, test_data, test_label)
def plot_metrics(history):
    """Plot training and validation loss and accuracy side by side.

    Draws a 12x6 figure with two panels: the loss curves on the left and the
    accuracy curves on the right, then shows the figure.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the series ``'loss'``, ``'val_loss'``, ``'accuracy'`` and
            ``'val_accuracy'``.
    """
    # (key in history.history, display name) for each panel, left to right.
    panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))

    plt.figure(figsize=(12, 6))
    for position, (metric_key, metric_name) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric_key], label=f'Training {metric_name}')
        plt.plot(history.history[f'val_{metric_key}'], label=f'Validation {metric_name}')
        plt.xlabel('Epochs')
        plt.ylabel(metric_name)
        plt.legend()
        plt.title(f'{metric_name} Curve')
    plt.show()