Course video: https://www.bilibili.com/video/BV1xq4y127JR/?p=4&vd_source=5be7027f4aa2ef056c31c89ffc9464f8
Study notes
Dataset: https://repod.icm.edu.pl/dataset.xhtml?persistentId=doi:10.18150/repod.0107441
1. Read and process data
!pip install mne
Defaulting to user installation because normal site-packages is not writeable
Collecting mne
Downloading mne-1.8.0-py3-none-any.whl.metadata (21 kB)
Requirement already satisfied: decorator in d:\anaconda\anaconda3\lib\site-packages (from mne) (5.1.1)
Requirement already satisfied: jinja2 in d:\anaconda\anaconda3\lib\site-packages (from mne) (3.1.4)
Requirement already satisfied: lazy-loader>=0.3 in d:\anaconda\anaconda3\lib\site-packages (from mne) (0.4)
Requirement already satisfied: matplotlib>=3.6 in d:\anaconda\anaconda3\lib\site-packages (from mne) (3.8.4)
Requirement already satisfied: numpy<3,>=1.23 in d:\anaconda\anaconda3\lib\site-packages (from mne) (1.26.4)
Requirement already satisfied: packaging in d:\anaconda\anaconda3\lib\site-packages (from mne) (23.2)
Collecting pooch>=1.5 (from mne)
Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: scipy>=1.9 in d:\anaconda\anaconda3\lib\site-packages (from mne) (1.13.1)
Requirement already satisfied: tqdm in d:\anaconda\anaconda3\lib\site-packages (from mne) (4.66.4)
Requirement already satisfied: contourpy>=1.0.1 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (1.2.0)
Requirement already satisfied: cycler>=0.10 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (0.11.0)
Requirement already satisfied: fonttools>=4.22.0 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (4.51.0)
Requirement already satisfied: kiwisolver>=1.3.1 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (1.4.4)
Requirement already satisfied: pillow>=8 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (10.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (3.0.9)
Requirement already satisfied: python-dateutil>=2.7 in d:\anaconda\anaconda3\lib\site-packages (from matplotlib>=3.6->mne) (2.9.0.post0)
Requirement already satisfied: platformdirs>=2.5.0 in d:\anaconda\anaconda3\lib\site-packages (from pooch>=1.5->mne) (3.10.0)
Requirement already satisfied: requests>=2.19.0 in d:\anaconda\anaconda3\lib\site-packages (from pooch>=1.5->mne) (2.32.2)
Requirement already satisfied: MarkupSafe>=2.0 in d:\anaconda\anaconda3\lib\site-packages (from jinja2->mne) (2.1.3)
Requirement already satisfied: colorama in d:\anaconda\anaconda3\lib\site-packages (from tqdm->mne) (0.4.6)
Requirement already satisfied: six>=1.5 in d:\anaconda\anaconda3\lib\site-packages (from python-dateutil>=2.7->matplotlib>=3.6->mne) (1.16.0)
Requirement already satisfied: charset-normalizer<4,>=2 in d:\anaconda\anaconda3\lib\site-packages (from requests>=2.19.0->pooch>=1.5->mne) (2.0.4)
Requirement already satisfied: idna<4,>=2.5 in d:\anaconda\anaconda3\lib\site-packages (from requests>=2.19.0->pooch>=1.5->mne) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda\anaconda3\lib\site-packages (from requests>=2.19.0->pooch>=1.5->mne) (2.2.2)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda\anaconda3\lib\site-packages (from requests>=2.19.0->pooch>=1.5->mne) (2024.6.2)
Downloading mne-1.8.0-py3-none-any.whl (7.4 MB)
Downloading pooch-1.8.2-py3-none-any.whl (64 kB)
Installing collected packages: pooch, mne
Successfully installed mne-1.8.0 pooch-1.8.2
WARNING: The script mne.exe is installed in 'C:\Users\picasso\AppData\Roaming\Python\Python312\Scripts' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
from glob import glob
import os
import mne
import numpy as np
import pandas
import matplotlib.pyplot as plt
all_file_path=glob("D:/lab/eegtest1/dataverse_files/*.edf")
print(len(all_file_path))
28
all_file_path[0]
'D:/lab/eegtest1/dataverse_files\\h01.edf'
healthy_file_path = [i for i in all_file_path if 'h' in i.split('\\')[1]]
patient_file_path = [i for i in all_file_path if 's' in i.split('\\')[1]]
print(len(healthy_file_path),len(patient_file_path))
14 14
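The split above relies on whether the letter 'h' appears in the filename. A slightly more defensive variant (a sketch with hypothetical toy filenames mirroring the dataset's h*/s* naming) checks the prefix of the basename, so a directory name that happens to contain 'h' cannot leak into the match:

```python
import os

# Toy filenames mimicking the dataset's scheme: h* = healthy, s* = patient.
names = ["h01.edf", "h02.edf", "s01.edf", "s02.edf"]

# startswith on the basename only matches the leading letter of the file
# name itself, not any directory component of the full path.
healthy = [n for n in names if os.path.basename(n).startswith("h")]
patients = [n for n in names if os.path.basename(n).startswith("s")]
print(healthy, patients)
```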
def read_data(file_path):
    data = mne.io.read_raw_edf(file_path, preload=True)
    data.set_eeg_reference()
    data.filter(l_freq=0.5, h_freq=45)
    epochs = mne.make_fixed_length_epochs(data, duration=5, overlap=1)
    array = epochs.get_data()
    return array
sample_data = read_data(healthy_file_path[0])
Extracting EDF parameters from D:\lab\eegtest1\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249 = 0.000 ... 924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 s)
Not setting metadata
231 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 231 events and 1250 original time points ...
0 bad epochs dropped
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.0s
sample_data.shape  # (number of epochs, channels, samples per epoch)
(231, 19, 1250)
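The shape follows directly from the recording parameters shown in the log above. With 5 s windows and 1 s overlap, consecutive window starts are 4 s apart, and the last start must leave room for a full window; 1250 samples per 5 s epoch implies a 250 Hz sampling rate. A quick arithmetic check:

```python
# Sanity check: the ~925 s recording length and 5 s / 1 s epoching
# parameters are taken from the run above.
duration, overlap, total_secs = 5, 1, 925
step = duration - overlap                      # 4 s between window starts
n_epochs = (total_secs - duration) // step + 1 # last start must fit a full window
print(n_epochs)            # 231 epochs, matching sample_data.shape[0]
sfreq = 250                                    # 1250 samples / 5 s
print(duration * sfreq)    # 1250 samples per epoch
```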
%%capture
control_epochs_array=[read_data(i) for i in healthy_file_path]
patient_epochs_array=[read_data(i) for i in patient_file_path]
control_epochs_array[0].shape,control_epochs_array[1].shape
((231, 19, 1250), (227, 19, 1250))
control_epochs_labels = [len(i)*[0] for i in control_epochs_array]
patient_epochs_labels = [len(i)*[1] for i in patient_epochs_array]
len(control_epochs_labels),len(patient_epochs_labels)
(14, 14)
data_list = control_epochs_array +patient_epochs_array
label_list = control_epochs_labels+patient_epochs_labels
# keep separate copies for the CNN section, which stacks these arrays in place
epochs_array = control_epochs_array + patient_epochs_array
epochs_labels = control_epochs_labels + patient_epochs_labels
groups = [[i]*len(j) for i, j in enumerate(data_list)]
group_list = [[i]*len(j) for i, j in enumerate(data_list)]
len(group_list)
28
data_array = np.vstack(data_list)
label_array = np.hstack(label_list)
group_array = np.hstack(group_list)
print(data_array.shape,label_array.shape,group_array.shape)
(7201, 19, 1250) (7201,) (7201,)
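The stacking step can be illustrated on a toy example: two "subjects" with 2 and 3 epochs each, where `vstack` concatenates epochs along the first axis while `hstack` flattens the per-subject label and group-id lists into one entry per epoch (toy shapes are made up for illustration):

```python
import numpy as np

# Two subjects with 2 and 3 epochs of shape (channels=4, samples=6)
a = np.zeros((2, 4, 6))
b = np.ones((3, 4, 6))
data = np.vstack([a, b])                # all epochs stacked: (5, 4, 6)
labels = np.hstack([[0]*2, [1]*3])      # one class label per epoch
groups = np.hstack([[0]*2, [1]*3])      # subject id per epoch, for GroupKFold
print(data.shape, labels.shape, groups.shape)
```

Keeping a subject id per epoch is what later lets GroupKFold hold out whole subjects, so epochs from one person never appear in both train and validation folds.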
2. Machine learning
from scipy import stats

def mean(x):
    return np.mean(x, axis=-1)
def std(x):
    return np.std(x, axis=-1)
def ptp(x):
    return np.ptp(x, axis=-1)
def var(x):
    return np.var(x, axis=-1)
def minim(x):
    return np.min(x, axis=-1)
def maxim(x):
    return np.max(x, axis=-1)
def argminim(x):
    return np.argmin(x, axis=-1)
def argmaxim(x):
    return np.argmax(x, axis=-1)
def rms(x):
    return np.sqrt(np.mean(x**2, axis=-1))
def abs_diff_signal(x):
    return np.sum(np.abs(np.diff(x, axis=-1)), axis=-1)
def skewness(x):
    return stats.skew(x, axis=-1)
def kurtosis(x):
    return stats.kurtosis(x, axis=-1)
def concatenate_features(x):
    return np.concatenate((mean(x), std(x), ptp(x), var(x), minim(x), maxim(x),
                           argminim(x), argmaxim(x), rms(x), abs_diff_signal(x),
                           skewness(x), kurtosis(x)), axis=-1)
features = []
for d in data_array:
    features.append(concatenate_features(d))
features_array = np.array(features)
features_array.shape
(7201, 228)
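The 228 columns are no accident: each statistic collapses the time axis to one value per channel, and 12 statistics × 19 channels = 228. A self-contained check on a random epoch (using the same NumPy/SciPy reductions as the functions above):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
x = rng.standard_normal((19, 1250))  # one epoch: 19 channels x 1250 samples

# 12 per-channel statistics, concatenated along the channel axis
feats = np.concatenate([
    np.mean(x, -1), np.std(x, -1), np.ptp(x, -1), np.var(x, -1),
    np.min(x, -1), np.max(x, -1), np.argmin(x, -1), np.argmax(x, -1),
    np.sqrt(np.mean(x**2, -1)),                  # RMS
    np.sum(np.abs(np.diff(x, axis=-1)), -1),     # total absolute variation
    stats.skew(x, axis=-1), stats.kurtosis(x, axis=-1),
])
print(feats.shape)  # (228,) = 12 statistics * 19 channels
```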
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold,GridSearchCV
clf = LogisticRegression()
gkf = GroupKFold(5)
pipe = Pipeline([('scaler',StandardScaler()),('clf',clf)])
param_grid={'clf__C':[0.1,0.5,0.7,1,3,5,7]}
gscv=GridSearchCV(pipe,param_grid,cv=gkf,n_jobs=12)
gscv.fit(features_array,label_array,groups=group_array)
D:\Anaconda\Anaconda3\Lib\site-packages\sklearn\linear_model\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
(The features are already standardized by the pipeline, so the fix here is to pass a larger max_iter, e.g. LogisticRegression(max_iter=1000).)
GridSearchCV(cv=GroupKFold(n_splits=5),
             estimator=Pipeline(steps=[('scaler', StandardScaler()),
                                       ('clf', LogisticRegression())]),
             n_jobs=12, param_grid={'clf__C': [0.1, 0.5, 0.7, 1, 3, 5, 7]})
gscv.best_score_
0.6636027871434461
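Beyond `best_score_`, `GridSearchCV` also exposes `best_params_` and the per-setting mean scores in `cv_results_`, which show how sensitive the classifier is to C. A runnable sketch on synthetic stand-ins for `features_array`/`label_array`/`group_array` (the synthetic data and the 10-subject grouping are made up for illustration):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, GroupKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in: 10 "subjects" with 20 epochs each
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
g = np.repeat(np.arange(10), 20)

pipe = Pipeline([("scaler", StandardScaler()),
                 ("clf", LogisticRegression(max_iter=1000))])
gscv = GridSearchCV(pipe, {"clf__C": [0.1, 1, 10]}, cv=GroupKFold(5))
gscv.fit(X, y, groups=g)           # whole subjects held out per fold
print(gscv.best_params_)
print(gscv.cv_results_["mean_test_score"])  # one mean score per C value
```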
3. Deep learning (CNN)
epochs_array = np.vstack(epochs_array)
epochs_labels = np.hstack(epochs_labels)
groups_array = np.hstack(groups)
epochs_array.shape, epochs_labels.shape, groups_array.shape
((7201, 19, 1250), (7201,), (7201,))
# Bug in the original cell: the moveaxis result was assigned to groups_array,
# silently overwriting the subject ids computed above. Use a separate name.
epochs_array_moved = np.moveaxis(epochs_array, 1, 2)  # (epochs, samples, channels)
epochs_array_moved.shape
(7201, 1250, 19)
from tensorflow.keras.layers import Conv1D, BatchNormalization, LeakyReLU, MaxPool1D,\
    GlobalAveragePooling1D, Dense, Dropout, AveragePooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.backend import clear_session

def cnnmodel():
    clear_session()
    model = Sequential()
    # Epochs are 5 s x 250 Hz = 1250 samples; the original run used
    # input_shape=(6250, 19), which is why the summary below shows 6248-wide maps.
    model.add(Conv1D(filters=5, kernel_size=3, strides=1, input_shape=(1250, 19)))  # 1
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))  # 2
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 3
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2, strides=2))  # 4
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 5
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))  # 6
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 7
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))  # 8
    model.add(Conv1D(filters=5, kernel_size=3, strides=1))  # 9
    model.add(LeakyReLU())
    model.add(GlobalAveragePooling1D())  # 10
    model.add(Dense(1, activation='sigmoid'))  # 11
    model.compile('adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
model=cnnmodel()
model.summary()
Model: "sequential"
Layer (type)                                        Output Shape        Param #
conv1d (Conv1D)                                     (None, 6248, 5)         290
batch_normalization (BatchNormalization)            (None, 6248, 5)          20
leaky_re_lu (LeakyReLU)                             (None, 6248, 5)           0
max_pooling1d (MaxPooling1D)                        (None, 3124, 5)           0
conv1d_1 (Conv1D)                                   (None, 3122, 5)          80
leaky_re_lu_1 (LeakyReLU)                           (None, 3122, 5)           0
max_pooling1d_1 (MaxPooling1D)                      (None, 1561, 5)           0
dropout (Dropout)                                   (None, 1561, 5)           0
conv1d_2 (Conv1D)                                   (None, 1559, 5)          80
leaky_re_lu_2 (LeakyReLU)                           (None, 1559, 5)           0
average_pooling1d (AveragePooling1D)                (None, 779, 5)            0
dropout_1 (Dropout)                                 (None, 779, 5)            0
conv1d_3 (Conv1D)                                   (None, 777, 5)           80
leaky_re_lu_3 (LeakyReLU)                           (None, 777, 5)            0
average_pooling1d_1 (AveragePooling1D)              (None, 388, 5)            0
conv1d_4 (Conv1D)                                   (None, 386, 5)           80
leaky_re_lu_4 (LeakyReLU)                           (None, 386, 5)            0
global_average_pooling1d (GlobalAveragePooling1D)   (None, 5)                 0
dense (Dense)                                       (None, 1)                 6
Total params: 636 (2.48 KB)
Trainable params: 626 (2.45 KB)
Non-trainable params: 10 (40.00 B)
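The parameter counts in the summary can be verified by hand: a Conv1D layer has (kernel_size × in_channels + 1) × filters weights, BatchNormalization has 4 values per channel (gamma, beta, and the two non-trainable moving statistics), and the final Dense sees the 5 globally pooled features:

```python
in_ch, k, f = 19, 3, 5                 # input channels, kernel size, filters
conv1 = (k * in_ch + 1) * f            # first Conv1D: (3*19 + 1)*5 = 290
bn = 4 * f                             # BatchNormalization: 20 (10 non-trainable)
conv_rest = (k * f + 1) * f            # each later Conv1D: (3*5 + 1)*5 = 80
dense = f + 1                          # Dense(1) on 5 pooled features: 6
total = conv1 + bn + 4 * conv_rest + dense
trainable = total - 2 * f              # moving mean/variance are non-trainable
print(total, trainable)                # 636 626, matching the summary
```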
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
gkf = GroupKFold()  # n_splits defaults to 5
accuracy = []
# Assuming gkf, epochs_array, epochs_labels, and groups are already defined.
# The original passed groups_array[:, 0, 0] here, but after the moveaxis cell
# that array held signal values, not subject ids; rebuild the ids from groups.
for train_index, val_index in gkf.split(epochs_array, epochs_labels,
                                        groups=np.hstack(groups)):
    # Get training and validation sets
    train_features, train_labels = epochs_array[train_index], epochs_labels[train_index]
    val_features, val_labels = epochs_array[val_index], epochs_labels[val_index]
    # Reshape data to (samples, time_steps, channels)
    train_features = train_features.transpose(0, 2, 1)  # (batch_size, 1250, 19)
    val_features = val_features.transpose(0, 2, 1)      # (batch_size, 1250, 19)
    # Standardize per channel, fitting the scaler on the training folds only
    scaler = StandardScaler()
    train_features = scaler.fit_transform(
        train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    val_features = scaler.transform(
        val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)
    # Define and train the CNN (recompiled with a higher learning rate)
    model = cnnmodel()
    model.compile(optimizer=Adam(learning_rate=0.01), loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.fit(train_features, train_labels,
              epochs=20,
              batch_size=1024,
              validation_data=(val_features, val_labels),
              callbacks=[EarlyStopping(monitor='val_loss', patience=3)])
    # Evaluate on the held-out subjects
    acc = model.evaluate(val_features, val_labels)[1]
    accuracy.append(acc)
    # Break after one fold for testing
    break
# Print the validation accuracy
print(f'Validation accuracy: {accuracy}')
C:\Users\picasso\AppData\Roaming\Python\Python312\site-packages\keras\src\layers\convolutional\base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
Epoch 1/20
6/6 ━━━━━━━━━━ 4s 172ms/step - accuracy: 0.5505 - loss: 0.6723 - val_accuracy: 0.5330 - val_loss: 0.6447
Epoch 2/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.5841 - loss: 0.6330 - val_accuracy: 0.7974 - val_loss: 0.5718
Epoch 3/20
6/6 ━━━━━━━━━━ 1s 91ms/step - accuracy: 0.7568 - loss: 0.5638 - val_accuracy: 0.8196 - val_loss: 0.4783
Epoch 4/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.7695 - loss: 0.5119 - val_accuracy: 0.8730 - val_loss: 0.3938
Epoch 5/20
6/6 ━━━━━━━━━━ 1s 90ms/step - accuracy: 0.8301 - loss: 0.4361 - val_accuracy: 0.8661 - val_loss: 0.4034
Epoch 6/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8227 - loss: 0.4102 - val_accuracy: 0.7245 - val_loss: 0.6450
Epoch 7/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8448 - loss: 0.4044 - val_accuracy: 0.8709 - val_loss: 0.3678
Epoch 8/20
6/6 ━━━━━━━━━━ 1s 91ms/step - accuracy: 0.8264 - loss: 0.3938 - val_accuracy: 0.6433 - val_loss: 0.8617
Epoch 9/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8475 - loss: 0.3768 - val_accuracy: 0.5857 - val_loss: 1.2216
Epoch 10/20
6/6 ━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8649 - loss: 0.3546 - val_accuracy: 0.6114 - val_loss: 0.9996
46/46 ━━━━━━━━━━ 0s 2ms/step - accuracy: 0.3651 - loss: 1.5616
Validation accuracy: [0.6113809943199158]
np.mean(accuracy)
0.6113809943199158
train_features.shape
(5760, 1250, 19)