'''
Date: 2020-07-11
Author: Deniu He
Email: hedeniu@163.com
Organization: CQUPT
该版本用于测试调用matlab代码
(This version is used to test calling MATLAB code from Python.)
'''
import matlab
import matlab.engine
import pandas as pd
import numpy as np
import os
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import KFold,StratifiedKFold,StratifiedShuffleSplit,train_test_split
from sklearn.metrics import accuracy_score,f1_score
# Dataset used by default. Another dataset used earlier with this script:
#   r"D:\Program Files\MATLAB\orca-master\exampledata\ERA.csv"
DEFAULT_DATA_PATH = r"D:\ExperimentalData\Jain\jain.csv"


def load_dataset(csv_path):
    """Load a header-less numeric CSV and split it into features and labels.

    The last column is taken as the label; every other column is a feature.

    Args:
        csv_path: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        ``(X, y)``: ``X`` is a 2-D float array of shape
        (n_samples, n_columns - 1), ``y`` a 1-D float array of shape
        (n_samples,).

    Raises:
        ValueError: if the data has fewer than two columns (no feature
            column), or holds values that cannot be converted to float.
    """
    data = pd.read_csv(csv_path, header=None).to_numpy(dtype=float)
    if data.shape[1] < 2:
        raise ValueError(
            "expected at least 2 columns (features + label), got {}".format(
                data.shape[1]))
    return data[:, :-1], data[:, -1]


def to_matlab_column(labels):
    """Return *labels* as an (n, 1) float column array.

    A 1-D label vector of shape (n,) becomes a 1 x n row once handed to
    ``matlab.double``; reshaping it to (n, 1) yields an n x 1 column vector
    instead. Same result as ``np.vstack(labels)`` for 1-D input, but states
    the intent directly.
    """
    return np.asarray(labels, dtype=float).reshape(-1, 1)


def main(csv_path=DEFAULT_DATA_PATH, test_size=0.25, random_state=None):
    """Split a dataset and pass it to the MATLAB function ``Diaoyong``.

    Args:
        csv_path: dataset to load (last column = label).
        test_size: fraction of samples held out for testing.
        random_state: seed forwarded to ``train_test_split``; ``None`` (the
            default) gives a different random split on every run.

    Returns:
        Whatever ``Diaoyong(train_X, train_y, test_X, test_y)`` returns --
        presumably the test accuracy (the script names it ``acc``);
        TODO confirm against Diaoyong.m.
    """
    X, y = load_dataset(csv_path)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    # Labels must be column vectors on the MATLAB side.
    y_train = to_matlab_column(y_train)
    y_test = to_matlab_column(y_test)
    print(y_train.shape)

    train_X = matlab.double(X_train.tolist())
    train_y = matlab.double(y_train.tolist())
    test_X = matlab.double(X_test.tolist())
    test_y = matlab.double(y_test.tolist())

    # Start MATLAB only after the data is prepared, so a failure while
    # loading/splitting never leaves an engine running, and always shut it
    # down afterwards, even if the MATLAB call raises.
    engine = matlab.engine.start_matlab()
    try:
        # Diaoyong is resolved by the engine (presumably Diaoyong.m on the
        # engine's MATLAB path / current folder).
        return engine.Diaoyong(train_X, train_y, test_X, test_y)
    finally:
        engine.quit()


if __name__ == '__main__':
    acc = main()
    print(acc)
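# 在python中标签向量是ndarray格式的,形状为(1000,),它本质上是一个行向量。
# 需要先转换为(1000,1)的格式。
# In Python the label vector is an ndarray of shape (1000,), i.e. a 1-D array
# that reaches MATLAB as a row vector. It must first be reshaped to (1000, 1),
# a column vector, before being passed to MATLAB.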