Implementing EEG signal extraction and machine learning in Python

Course video: https://www.bilibili.com/video/BV1xq4y127JR/?p=4&vd_source=5be7027f4aa2ef056c31c89ffc9464f8
Study notes
Dataset: https://repod.icm.edu.pl/dataset.xhtml?persistentId=doi:10.18150/repod.0107441

1. Read and process data

!pip install mne
from glob import glob
import os
import mne
import numpy as np
import pandas
import matplotlib.pyplot as plt
all_file_path=glob("D:/lab/eegtest1/dataverse_files/*.edf")
print(len(all_file_path))
28
all_file_path[0]
'D:/lab/eegtest1/dataverse_files\\h01.edf'
# hXX.edf = healthy controls, sXX.edf = patients
healthy_file_path = [i for i in all_file_path if os.path.basename(i).startswith('h')]
patient_file_path = [i for i in all_file_path if os.path.basename(i).startswith('s')]
print(len(healthy_file_path),len(patient_file_path))
14 14
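def read_data(file_path):
    # Load the EDF recording into memory
    data = mne.io.read_raw_edf(file_path, preload=True)
    # Re-reference all EEG channels to their average
    data.set_eeg_reference()
    # Band-pass filter to 0.5-45 Hz
    data.filter(l_freq=0.5, h_freq=45)
    # Cut the recording into 5 s epochs; consecutive epochs overlap by 1 s
    epochs = mne.make_fixed_length_epochs(data, duration=5, overlap=1)
    # Array of shape (n_epochs, n_channels, n_times)
    array = epochs.get_data()
    return array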
sample_data = read_data(healthy_file_path[0])
Extracting EDF parameters from D:\lab\eegtest1\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 s)

Not setting metadata
231 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 231 events and 1250 original time points ...
0 bad epochs dropped
sample_data.shape   # (number of epochs, channels, samples per epoch)
(231, 19, 1250)
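The epoch count can be checked by hand. The log shows samples 0 ... 231249 spanning 924.996 s, i.e. a sampling rate of 250 Hz, so a 5 s epoch is 1250 samples and, with 1 s of overlap, consecutive epochs start 4 s (1000 samples) apart. The helper below (`n_fixed_length_epochs` is a hypothetical name, not an MNE function) is a minimal sketch of that arithmetic, assuming epochs start at sample 0 and an incomplete epoch at the end of the recording is dropped:

```python
def n_fixed_length_epochs(n_samples, sfreq, duration, overlap):
    """Number of fixed-length epochs that fit into a recording.

    Hypothetical helper: each epoch spans `duration` seconds and consecutive
    epochs start (duration - overlap) seconds apart, beginning at sample 0.
    """
    epoch_len = int(round(duration * sfreq))           # samples per epoch
    step = int(round((duration - overlap) * sfreq))    # samples between epoch starts
    if n_samples < epoch_len:
        return 0                                       # not even one full epoch
    return (n_samples - epoch_len) // step + 1

# h01.edf: 231250 samples at 250 Hz, 5 s epochs with 1 s overlap
print(n_fixed_length_epochs(231250, 250, duration=5, overlap=1))  # 231
```

For h01.edf this reproduces the 231 epochs reported in the log above.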
%%capture
control_epochs_array=[read_data(i) for i in healthy_file_path]
patient_epochs_array=[read_data(i) for i in patient_file_path]
control_epochs_array[0].shape,control_epochs_array[1].shape
((231, 19, 1250), (227, 19, 1250))
control_epochs_labels = [len(i)*[0] for i in control_epochs_array]
patient_epochs_labels = [len(i)*[1] for i in patient_epochs_array]
len(control_epochs_labels),len(patient_epochs_labels)
(14, 14)
data_list = control_epochs_array + patient_epochs_array
label_list = control_epochs_labels + patient_epochs_labels

# Same lists under the names used in the CNN section
epochs_array = control_epochs_array + patient_epochs_array
epochs_labels = control_epochs_labels + patient_epochs_labels
# One group id per subject, repeated for each of that subject's epochs
group_list = [[i] * len(j) for i, j in enumerate(data_list)]
groups = group_list   # alias used in the CNN section
len(group_list)
28
data_array = np.vstack(data_list)
label_array = np.hstack(label_list)
group_array = np.hstack(group_list)
print(data_array.shape,label_array.shape,group_array.shape)
(7201, 19, 1250) (7201,) (7201,)
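Each subject's index is repeated once per epoch so that `GroupKFold` can later keep all epochs of one subject in the same fold. A toy-sized sketch of the same bookkeeping, using a hypothetical helper `build_labels_and_groups` and made-up epoch counts:

```python
import numpy as np

def build_labels_and_groups(epoch_counts, subject_labels):
    """Expand per-subject information into per-epoch label and group arrays.

    epoch_counts[i]   : number of epochs extracted for subject i
    subject_labels[i] : 0 for a healthy control, 1 for a patient
    """
    labels = np.hstack([[lab] * n for n, lab in zip(epoch_counts, subject_labels)])
    groups = np.hstack([[i] * n for i, n in enumerate(epoch_counts)])
    return labels, groups

# Three toy subjects: two controls (3 and 2 epochs) and one patient (2 epochs)
labels, groups = build_labels_and_groups([3, 2, 2], [0, 0, 1])
print(labels)  # [0 0 0 0 0 1 1]
print(groups)  # [0 0 0 1 1 2 2]
```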

2. ML

from scipy import stats

# Every function reduces over the last (time) axis, so an epoch of shape
# (n_channels, n_times) yields one value per channel.
def mean(x):
    return np.mean(x, axis=-1)
def std(x):
    return np.std(x, axis=-1)
def ptp(x):
    return np.ptp(x, axis=-1)        # peak-to-peak amplitude
def var(x):
    return np.var(x, axis=-1)
def minim(x):
    return np.min(x, axis=-1)
def maxim(x):
    return np.max(x, axis=-1)

def argminim(x):
    return np.argmin(x, axis=-1)     # sample index of the minimum
def argmaxim(x):
    return np.argmax(x, axis=-1)     # sample index of the maximum

def rms(x):
    return np.sqrt(np.mean(x**2, axis=-1))   # root mean square

def abs_diff_signal(x):
    # summed absolute sample-to-sample change
    return np.sum(np.abs(np.diff(x, axis=-1)), axis=-1)

def skewness(x):
    return stats.skew(x, axis=-1)
def kurtosis(x):
    return stats.kurtosis(x, axis=-1)

# 12 features x 19 channels = 228 values per epoch
def concatenate_features(x):
    return np.concatenate((mean(x), std(x), ptp(x), var(x), minim(x), maxim(x), argminim(x), argmaxim(x), rms(x),
                           abs_diff_signal(x), skewness(x), kurtosis(x)), axis=-1)

features=[]
for d in data_array:
    features.append(concatenate_features(d))
features_array=np.array(features)
features_array.shape
(7201, 228)
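Because every feature reduces over the last (time) axis, the loop is not strictly necessary: calling `concatenate_features(data_array)` on the full `(7201, 19, 1250)` array should produce the same `(7201, 228)` matrix in one vectorized call. The sketch below checks that equivalence on random toy data; to keep it self-contained it uses a subset of the features (numpy only, no scipy) in a hypothetical `batch_features` function:

```python
import numpy as np

def batch_features(x):
    """Per-channel features for one epoch or a whole batch of epochs.

    Every statistic reduces over the last (time) axis, so an input of shape
    (n_epochs, n_channels, n_times) gives (n_epochs, n_features * n_channels)
    with the same feature-major ordering as a per-epoch loop.
    """
    feats = (
        np.mean(x, axis=-1),
        np.std(x, axis=-1),
        np.ptp(x, axis=-1),
        np.sqrt(np.mean(x ** 2, axis=-1)),               # RMS
        np.sum(np.abs(np.diff(x, axis=-1)), axis=-1),    # summed absolute difference
    )
    return np.concatenate(feats, axis=-1)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 19, 1250))        # 4 toy epochs, 19 channels, 1250 samples
batched = batch_features(x)                   # one vectorized call
looped = np.array([batch_features(epoch) for epoch in x])   # per-epoch loop
print(batched.shape)                  # (4, 95)
print(np.allclose(batched, looped))   # True
```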
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold,GridSearchCV
clf = LogisticRegression()
gkf = GroupKFold(5)
pipe = Pipeline([('scaler',StandardScaler()),('clf',clf)])
param_grid={'clf__C':[0.1,0.5,0.7,1,3,5,7]}
gscv=GridSearchCV(pipe,param_grid,cv=gkf,n_jobs=12)
gscv.fit(features_array,label_array,groups=group_array)

D:\Anaconda\Anaconda3\Lib\site-packages\sklearn\linear_model\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
GridSearchCV(cv=GroupKFold(n_splits=5),
             estimator=Pipeline(steps=[('scaler', StandardScaler()),
                                       ('clf', LogisticRegression())]),
             n_jobs=12, param_grid={'clf__C': [0.1, 0.5, 0.7, 1, 3, 5, 7]})
gscv.best_score_
0.6636027871434461
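The `groups` argument is what makes this a subject-wise evaluation: epochs cut from the same recording tend to be similar, so if one subject's epochs landed in both the training and the validation folds, the score could partly reflect subject identity rather than the healthy/patient difference. The sketch below illustrates the property with a hypothetical `subject_wise_folds` helper that deals whole subjects to folds round-robin. This is not the assignment rule sklearn's `GroupKFold` uses, only the same guarantee:

```python
import numpy as np

def subject_wise_folds(groups, n_splits=5):
    """Yield (train_idx, val_idx) pairs in which no subject is split.

    Hypothetical helper for illustration: whole subjects are dealt to folds
    round-robin, so each subject's epochs are either all in training or all
    in validation for a given fold.
    """
    subjects = np.unique(groups)
    fold_of_subject = {s: i % n_splits for i, s in enumerate(subjects)}
    fold = np.array([fold_of_subject[g] for g in groups])
    for k in range(n_splits):
        yield np.flatnonzero(fold != k), np.flatnonzero(fold == k)

groups = np.repeat(np.arange(28), 10)   # toy: 28 subjects, 10 epochs each
for train_idx, val_idx in subject_wise_folds(groups):
    # no subject id may appear on both sides of the split
    assert set(groups[train_idx]).isdisjoint(groups[val_idx])
print("no subject appears in both train and validation")
```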

3. Deep learning (CNN)

epochs_array = np.vstack(epochs_array)    # (n_epochs, n_channels, n_times)
epochs_labels = np.hstack(epochs_labels)  # 0 = healthy control, 1 = patient
groups_array = np.hstack(groups)          # subject index of every epoch
epochs_array.shape,epochs_labels.shape,groups_array.shape
((7201, 19, 1250), (7201,), (7201,))
np.moveaxis(epochs_array, 1, 2).shape     # layout the CNN expects: (samples, time_steps, channels)
(7201, 1250, 19)
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv1D,BatchNormalization,LeakyReLU,MaxPool1D,\
GlobalAveragePooling1D,Dense,Dropout,AveragePooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.backend import clear_session
def cnnmodel():
    clear_session()
    model=Sequential()
    # Input: 5 s x 250 Hz = 1250 time steps, 19 channels
    model.add(Input(shape=(1250,19)))
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#1
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2,strides=2))#2
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#3
    model.add(LeakyReLU())
    model.add(MaxPool1D(pool_size=2,strides=2))#4
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#5
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2,strides=2))#6
    model.add(Dropout(0.5))
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#7
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2,strides=2))#8
    model.add(Conv1D(filters=5,kernel_size=3,strides=1))#9
    model.add(LeakyReLU())
    model.add(GlobalAveragePooling1D())#10
    model.add(Dense(1,activation='sigmoid'))#11
    model.compile('adam',loss='binary_crossentropy',metrics=['accuracy'])
    return model

model=cnnmodel()
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ conv1d (Conv1D)                      │ (None, 1248, 5)             │             290 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization                  │ (None, 1248, 5)             │              20 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ leaky_re_lu (LeakyReLU)              │ (None, 1248, 5)             │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling1d (MaxPooling1D)         │ (None, 624, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv1d_1 (Conv1D)                    │ (None, 622, 5)              │              80 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ leaky_re_lu_1 (LeakyReLU)            │ (None, 622, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling1d_1 (MaxPooling1D)       │ (None, 311, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout (Dropout)                    │ (None, 311, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv1d_2 (Conv1D)                    │ (None, 309, 5)              │              80 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ leaky_re_lu_2 (LeakyReLU)            │ (None, 309, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ average_pooling1d (AveragePooling1D) │ (None, 154, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_1 (Dropout)                  │ (None, 154, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv1d_3 (Conv1D)                    │ (None, 152, 5)              │              80 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ leaky_re_lu_3 (LeakyReLU)            │ (None, 152, 5)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ average_pooling1d_1                  │ (None, 76, 5)               │               0 │
│ (AveragePooling1D)                   │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv1d_4 (Conv1D)                    │ (None, 74, 5)               │              80 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ leaky_re_lu_4 (LeakyReLU)            │ (None, 74, 5)               │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ global_average_pooling1d             │ (None, 5)                   │               0 │
│ (GlobalAveragePooling1D)             │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ (None, 1)                   │               6 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 636 (2.48 KB)
 Trainable params: 626 (2.45 KB)
 Non-trainable params: 10 (40.00 B)
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

gkf = GroupKFold(n_splits=5)
accuracy = []

# epochs_array: (n_epochs, 19, 1250); groups_array holds one subject id per epoch,
# so no subject appears in both the training and the validation fold
for train_index, val_index in gkf.split(epochs_array, epochs_labels, groups=groups_array):
    # Get training and validation sets
    train_features, train_labels = epochs_array[train_index], epochs_labels[train_index]
    val_features, val_labels = epochs_array[val_index], epochs_labels[val_index]

    # Reshape to (samples, time_steps, channels) = (batch_size, 1250, 19)
    train_features = train_features.transpose(0, 2, 1)
    val_features = val_features.transpose(0, 2, 1)

    # Standardize each channel; the scaler is fit on the training fold only
    scaler = StandardScaler()
    train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)

    # Build a fresh CNN for this fold and recompile it with a higher learning rate
    model = cnnmodel()
    model.compile(optimizer=Adam(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])

    # Train model
    model.fit(
        train_features, train_labels,
        epochs=20,
        batch_size=1024,
        validation_data=(val_features, val_labels),
        callbacks=[EarlyStopping(monitor='val_loss', patience=3)]
    )

    # Evaluate model
    acc = model.evaluate(val_features, val_labels)[1]
    accuracy.append(acc)

    # Stop after the first fold for a quick test; remove to run all 5 folds
    break

# Print the validation accuracy
print(f'Validation accuracy: {accuracy}')

Epoch 1/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 4s 172ms/step - accuracy: 0.5505 - loss: 0.6723 - val_accuracy: 0.5330 - val_loss: 0.6447
Epoch 2/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.5841 - loss: 0.6330 - val_accuracy: 0.7974 - val_loss: 0.5718
Epoch 3/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 91ms/step - accuracy: 0.7568 - loss: 0.5638 - val_accuracy: 0.8196 - val_loss: 0.4783
Epoch 4/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.7695 - loss: 0.5119 - val_accuracy: 0.8730 - val_loss: 0.3938
Epoch 5/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 90ms/step - accuracy: 0.8301 - loss: 0.4361 - val_accuracy: 0.8661 - val_loss: 0.4034
Epoch 6/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8227 - loss: 0.4102 - val_accuracy: 0.7245 - val_loss: 0.6450
Epoch 7/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8448 - loss: 0.4044 - val_accuracy: 0.8709 - val_loss: 0.3678
Epoch 8/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 91ms/step - accuracy: 0.8264 - loss: 0.3938 - val_accuracy: 0.6433 - val_loss: 0.8617
Epoch 9/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8475 - loss: 0.3768 - val_accuracy: 0.5857 - val_loss: 1.2216
Epoch 10/20
6/6 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.8649 - loss: 0.3546 - val_accuracy: 0.6114 - val_loss: 0.9996
46/46 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.3651 - loss: 1.5616
Validation accuracy: [0.6113809943199158]
np.mean(accuracy)
0.6113809943199158
train_features.shape
(5760, 1250, 19)
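Each training example therefore enters the CNN as 1250 time steps by 19 channels. The time-axis lengths printed by `model.summary()` follow from two rules for the default `'valid'` padding: a `Conv1D` with kernel size 3 and stride 1 shortens the sequence by 2, and a pooling layer with size 2 and stride 2 maps a length n to (n - 2) // 2 + 1. A small sketch that reproduces the sequence of lengths for `cnnmodel()` (the helper names are hypothetical):

```python
def conv_valid(n, kernel_size=3, stride=1):
    # 'valid' padding (the Keras default): no zero padding at the edges
    return (n - kernel_size) // stride + 1

def pool_valid(n, pool_size=2, stride=2):
    # MaxPool1D / AveragePooling1D with the default 'valid' padding
    return (n - pool_size) // stride + 1

def cnn_time_lengths(n_times):
    """Time-axis length after each Conv1D / pooling stage of cnnmodel()."""
    lengths = [n_times]
    stages = ("conv", "pool", "conv", "pool", "conv", "pool", "conv", "pool", "conv")
    for stage in stages:
        n = lengths[-1]
        lengths.append(conv_valid(n) if stage == "conv" else pool_valid(n))
    return lengths

print(cnn_time_lengths(1250))      # [1250, 1248, 624, 622, 311, 309, 154, 152, 76, 74]
print(cnn_time_lengths(6250)[-1])  # 386
```

An input length of 1250 ends at 74 steps before `GlobalAveragePooling1D`; an input declared as 6250 steps would end at 386.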