matlab版
clc;
clear;
close all;
% Build the GCN inputs from the Houston scene: per-band min-max normalised
% spectra of every labelled pixel (training pixels first, then test pixels),
% their labels, the graph ALL_W returned by creatLap, and the propagation
% matrix ALL_L = D^(-1/2) * ALL_W * D^(-1/2) + I, where D holds the row sums
% of ALL_W.
% hyperConvert2d and creatLap are external helpers not defined here.
% NOTE(review): the indexing below assumes hyperConvert2d returns a
% (bands x pixels) matrix - confirm against its definition.
I = double(imread('Houston.tif'));
[m, n, z] = size(I);
TR = double(imread('Houston_train.tif'));
TE = double(imread('Houston_test.tif'));
I2d = hyperConvert2d(I);
for i = 1 : z
    I2d(i, :) = mat2gray(I2d(i, :)); % min-max normalise each band (row) to [0, 1]
end
TR2d = hyperConvert2d(TR);
TE2d = hyperConvert2d(TE);
% Keep only pixels with a positive label; training samples come first.
TR_sample = I2d(:, TR2d > 0);
TE_sample = I2d(:, TE2d > 0);
TR_temp = TR2d(:, TR2d > 0);
TE_temp = TE2d(:, TE2d > 0);
X = [TR_sample, TE_sample];
Y = [TR_temp, TE_temp];
% Graph construction parameters passed to creatLap.
K = 10;
si = 1;
ALL_W = creatLap(X, K, si);
% Disabled experiment, kept for reference.
% NOTE(review): as written, "for x=length(Y)" runs a single iteration, and the
% offsets are pixel-grid offsets applied to sample indices; fix both before
% re-enabling.
% neibor=zeros(size(ALL_W));
% for x=length(Y)
% for y=length(Y)
% for erros=[1,-1,145,-145,144,146,-144,-146]
% if x-y==erros
% neibor(x,y)=0.125;
% end
% end
% end
% end
% ALL_W=ALL_W+neibor;
ALL_W = sparse(ALL_W);
% Symmetric normalisation. D^(-1/2) and the identity are built as sparse
% matrices: adding a full eye() to a sparse matrix produces a full N x N
% matrix (N^2 * 8 bytes) before it is converted back to sparse.
N = size(ALL_W, 1);
ALL_D = spdiags(full(sum(ALL_W, 2)) .^ (-1/2), 0, N, N);
L_temp = ALL_W * ALL_D;
ALL_L = ALL_D * L_temp;
ALL_L = ALL_L + speye(N);
ALL_X = X';
ALL_Y = Y';
%% Please replace the following route with your own one
save_dir = 'D:\desk\高光谱\GCN-twobranch\IEEE_TGRS_GCN-master\HSI_GCN';
save(fullfile(save_dir, 'ALL_X.mat'), 'ALL_X');
save(fullfile(save_dir, 'ALL_Y.mat'), 'ALL_Y');
save(fullfile(save_dir, 'ALL_L.mat'), 'ALL_L');
save(fullfile(save_dir, 'ALL_W.mat'), 'ALL_W');
第一版
import numpy as np
import tifffile as tiff
import os
import cv2
import scipy.io as sio
from scipy.cluster.vq import whiten
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import pandas as pd
# ---- Configuration ----
NUM_CLASS = 15  # number of classes in the label maps
PATH = './data/Houston/'  # directory holding the input TIFF files
SAVA_PATH = './file/'  # output directory for generated files
# NOTE: BATCH_SIZE, r, upscale, LiDarName, lchn and hchn are not used below.
BATCH_SIZE = 100
r = 5
upscale = 2
LiDarName = 'Houston_LiDAR.tif'
HsiName = 'Houston.tif'
gth_train = 'Houston_train.tif'
gth_test = 'Houston_test.tif'
lchn = 1
hchn = 144  # presumably the number of HSI bands - confirm
# exist_ok avoids the exists()/mkdir() race; makedirs also creates parents.
os.makedirs(SAVA_PATH, exist_ok=True)
def read_image(filename):
    """Read the TIFF file at ``filename`` and return its data as a float32 array."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
def show_image(image):
    """Draw ``image`` with matplotlib, hide the axes, then show the figure."""
    plt.imshow(image)
    plt.gca().set_axis_off()
    plt.show()
# ---- Load the image cube and the training / test label maps ----
I = read_image(PATH + HsiName)
m, n, z = I.shape  # I must be 3-D: m*n positions with z values each
TR_map = read_image(PATH + gth_train)
TE_map = read_image(PATH + gth_test)
print('***************原始的**************')
print("HSI.shape={}".format(I.shape))
print("TR_map.shape={} ".format(TR_map.shape))
print("TE_map.shape={} ".format(TE_map.shape))

# Flatten to one row per pixel so that pixels can be picked with 1-D masks.
I2d = I.reshape(m * n, z)
TR2d = TR_map.reshape(m * n)
TE2d = TE_map.reshape(m * n)
print('***************展平的**************')
print("I2d.shape={}".format(I2d.shape))
print("TR2d.shape={} ".format(TR2d.shape))
print("TE2d.shape={} ".format(TE2d.shape))

# Keep only pixels that carry a positive label.
tr_labelled = TR2d > 0
te_labelled = TE2d > 0
TR_sample = I2d[tr_labelled, :]
TE_sample = I2d[te_labelled, :]
TR_temp = TR2d[tr_labelled]
TE_temp = TE2d[te_labelled]
print('***************挑选后**************')
print("TR_sample.shape={} ".format(TR_sample.shape))
print("TE_sample.shape={} ".format(TE_sample.shape))
print("TR_temp.shape={} ".format(TR_temp.shape))
print("TE_temp.shape={} ".format(TE_temp.shape))

# All samples: training rows first, then test rows.
X = np.concatenate((TR_sample, TE_sample), axis=0)
Y = np.concatenate((TR_temp, TE_temp))
print('***************合并后**************')
print("X.shape={}".format(X.shape))
print("Y.shape={}".format(Y.shape))

K = 10
print('***************邻居数**************')
print("k={}".format(K))
# Pearson correlation between every pair of samples (rows of X).
A = np.corrcoef(X)
def pick(num=10, corr=None):
    """Return a binary nearest-neighbour mask built from a correlation matrix.

    In row ``i`` of the result, the ``num`` columns holding the largest values
    of ``corr[i, :]`` are set to 1.0 and every other entry is 0.0. Ties are
    broken by column order (``rank(method='first')``), so at most ``num``
    entries per row are selected; NaN entries are never selected.

    Args:
        num: number of neighbours kept per row.
        corr: 2-D array of pairwise similarities; defaults to the module-level
            ``A``.

    Returns:
        scipy.sparse.coo_matrix with the same shape as ``corr``.
    """
    if corr is None:
        corr = A
    corr = np.asarray(corr)
    n_rows, n_cols = corr.shape
    weight = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        # Compare *ranks* (1 = smallest) with the threshold; raw correlations
        # are at most 1 and can never exceed n_cols - num.
        ranks = pd.Series(corr[i, :]).rank(method='first').to_numpy()
        # Index row i explicitly; ``weight[cols]`` would fill whole rows.
        weight[i, ranks > n_cols - num] = 1.0
    return coo_matrix(weight)
# Replace the correlation matrix by the binary neighbour mask, then write
# features, labels and graph to the current working directory.
A = coo_matrix(pick())
for _fname, _arr in (('X.npy', X), ('Y.npy', Y), ('A.npy', A)):
    np.save(_fname, _arr, allow_pickle=True)
del _fname, _arr
第二版
import numpy as np
import tifffile as tiff
import os
import cv2
import scipy.io as sio
from scipy.cluster.vq import whiten
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import pandas as pd
# ---- Configuration ----
NUM_CLASS = 15  # number of classes in the label maps
PATH = './data/Houston/'  # directory holding the input TIFF files
SAVA_PATH = './file/'  # output directory for generated files
# NOTE: BATCH_SIZE, r, upscale, LiDarName, lchn and hchn are not used below.
BATCH_SIZE = 100
r = 5
upscale = 2
LiDarName = 'Houston_LiDAR.tif'
HsiName = 'Houston.tif'
gth_train = 'Houston_train.tif'
gth_test = 'Houston_test.tif'
lchn = 1
hchn = 144  # presumably the number of HSI bands - confirm
# exist_ok avoids the exists()/mkdir() race; makedirs also creates parents.
os.makedirs(SAVA_PATH, exist_ok=True)
def read_image(filename):
    """Read the TIFF file at ``filename`` and return its data as a float32 array."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
def show_image(image):
    """Draw ``image`` with matplotlib, hide the axes, then show the figure."""
    plt.imshow(image)
    plt.gca().set_axis_off()
    plt.show()
# ---- Load the image cube and the training / test label maps ----
I = read_image(PATH + HsiName)
m, n, z = I.shape  # I must be 3-D: m*n positions with z values each
TR_map = read_image(PATH + gth_train)
TE_map = read_image(PATH + gth_test)
print('***************原始的**************')
print("HSI.shape={}".format(I.shape))
print("TR_map.shape={} ".format(TR_map.shape))
print("TE_map.shape={} ".format(TE_map.shape))

# Flatten to one row per pixel so that pixels can be picked with 1-D masks.
I2d = I.reshape(m * n, z)
TR2d = TR_map.reshape(m * n)
TE2d = TE_map.reshape(m * n)
print('***************展平的**************')
print("I2d.shape={}".format(I2d.shape))
print("TR2d.shape={} ".format(TR2d.shape))
print("TE2d.shape={} ".format(TE2d.shape))

# Keep only pixels that carry a positive label.
tr_labelled = TR2d > 0
te_labelled = TE2d > 0
TR_sample = I2d[tr_labelled, :]
TE_sample = I2d[te_labelled, :]
TR_temp = TR2d[tr_labelled]
TE_temp = TE2d[te_labelled]
print('***************挑选后**************')
print("TR_sample.shape={} ".format(TR_sample.shape))
print("TE_sample.shape={} ".format(TE_sample.shape))
print("TR_temp.shape={} ".format(TR_temp.shape))
print("TE_temp.shape={} ".format(TE_temp.shape))

# All samples: training rows first, then test rows.
X = np.concatenate((TR_sample, TE_sample), axis=0)
Y = np.concatenate((TR_temp, TE_temp))
print('***************合并后**************')
print("X.shape={}".format(X.shape))
print("Y.shape={}".format(Y.shape))

K = 10
print('***************邻居数**************')
print("k={}".format(K))
# Pearson correlation between every pair of samples (rows of X).
A = np.corrcoef(X)
def pick(num=10, corr=None):
    """Return a binary nearest-neighbour mask built from a correlation matrix.

    In row ``i`` of the result, the ``num`` columns holding the largest values
    of ``corr[i, :]`` are set to 1.0 and every other entry is 0.0. Ties are
    broken by column order (``rank(method='first')``), so at most ``num``
    entries per row are selected; NaN entries are never selected.

    Args:
        num: number of neighbours kept per row.
        corr: 2-D array of pairwise similarities; defaults to the module-level
            ``A``.

    Returns:
        scipy.sparse.coo_matrix with the same shape as ``corr``.
    """
    if corr is None:
        corr = A
    corr = np.asarray(corr)
    n_rows, n_cols = corr.shape
    weight = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        # Compare *ranks* (1 = smallest) with the threshold; raw correlations
        # are at most 1 and can never exceed n_cols - num.
        ranks = pd.Series(corr[i, :]).rank(method='first').to_numpy()
        # Index row i explicitly; ``weight[cols]`` would fill whole rows.
        weight[i, ranks > n_cols - num] = 1.0
    return coo_matrix(weight)
# Replace the correlation matrix by the binary neighbour mask (sparse COO).
A = coo_matrix(pick())
# Fraction of each class's training pixels used for training; the rest are
# used for validation.
per = 0.89
def creat(validation=False):
    """Collect one split of the training map and save it under SAVA_PATH.

    For every class c in 1..NUM_CLASS, the pixels of ``TR_map`` equal to c are
    taken in ``np.where`` order. The first ``int(per * count)`` of them form the
    training split and the rest form the validation split. The chosen pixels
    are reordered by a permutation drawn right after ``np.random.seed(820)``;
    the seed is reset for every class. Their spectra ``I[row, col]`` and
    labels ``TR_map[row, col] - 1`` are saved as float32 / int8 arrays.

    NOTE(review): the permutation is applied *after* the split, so it only
    reorders samples inside each split; the split itself is always the first
    ``per`` fraction of each class in ``np.where`` order. Confirm this is
    intended.

    Args:
        validation: False for the training split, True for the validation split.
    """
    spectra = []
    labels = []
    for cls in range(1, NUM_CLASS + 1):
        rows, cols = np.where(TR_map == cls)
        cut = int(per * len(rows))
        part = slice(cut, None) if validation else slice(None, cut)
        rows, cols = rows[part], cols[part]
        np.random.seed(820)
        order = np.random.permutation(len(rows))
        for row, col in zip(rows[order], cols[order]):
            spectra.append(I[row, col])
            labels.append(TR_map[row, col] - 1)
    Xh = np.asarray(spectra, dtype=np.float32)
    ys = np.asarray(labels, dtype=np.int8)
    if validation:
        np.save(SAVA_PATH + 'validation_Xh.npy', Xh, allow_pickle=True)
        np.save(SAVA_PATH + 'validation_y', ys, allow_pickle=True)
    else:
        np.save(SAVA_PATH + 'train_Xh.npy', Xh, allow_pickle=True)
        np.save(SAVA_PATH + 'train_y.npy', ys, allow_pickle=True)
# Save the training split, then the validation split.
for _is_validation in (False, True):
    creat(validation=_is_validation)
del _is_validation

# Cast to compact dtypes, then write all remaining outputs under SAVA_PATH.
# NOTE(review): train_y / validation_y store labels shifted down by one
# (label - 1), whereas Y.npy and test_y.npy keep the raw positive labels;
# make sure whatever loads these files accounts for the offset.
X, Y = np.asarray(X, dtype=np.float32), np.asarray(Y, dtype=np.int8)
TE_sample, TE_temp = (np.asarray(TE_sample, dtype=np.float32),
                      np.asarray(TE_temp, dtype=np.int8))
for _fname, _arr in (('X.npy', X), ('Y.npy', Y), ('A.npy', A),
                     ('test_Xh.npy', TE_sample), ('test_y.npy', TE_temp)):
    np.save(SAVA_PATH + _fname, _arr, allow_pickle=True)
del _fname, _arr
第三版 改成onehot标签
import numpy as np
import tifffile as tiff
import os
import cv2
import scipy.io as sio
from scipy.cluster.vq import whiten
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import pandas as pd
# ---- Configuration ----
NUM_CLASS = 15  # number of classes in the label maps
PATH = './data/Houston/'  # directory holding the input TIFF files
SAVA_PATH = './file/'  # output directory for generated files
# NOTE: BATCH_SIZE, r, upscale, LiDarName, lchn and hchn are not used below.
BATCH_SIZE = 100
r = 5
upscale = 2
LiDarName = 'Houston_LiDAR.tif'
HsiName = 'Houston.tif'
gth_train = 'Houston_train.tif'
gth_test = 'Houston_test.tif'
lchn = 1
hchn = 144  # presumably the number of HSI bands - confirm
# exist_ok avoids the exists()/mkdir() race; makedirs also creates parents.
os.makedirs(SAVA_PATH, exist_ok=True)
def read_image(filename):
    """Read the TIFF file at ``filename`` and return its data as a float32 array."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
def show_image(image):
    """Draw ``image`` with matplotlib, hide the axes, then show the figure."""
    plt.imshow(image)
    plt.gca().set_axis_off()
    plt.show()
# ---- Load the image cube and the training / test label maps ----
I = read_image(PATH + HsiName)
m, n, z = I.shape  # I must be 3-D: m*n positions with z values each
TR_map = read_image(PATH + gth_train)
TE_map = read_image(PATH + gth_test)
print('***************原始的**************')
print("HSI.shape={}".format(I.shape))
print("TR_map.shape={} ".format(TR_map.shape))
print("TE_map.shape={} ".format(TE_map.shape))

# Flatten to one row per pixel so that pixels can be picked with 1-D masks.
I2d = I.reshape(m * n, z)
TR2d = TR_map.reshape(m * n)
TE2d = TE_map.reshape(m * n)
print('***************展平的**************')
print("I2d.shape={}".format(I2d.shape))
print("TR2d.shape={} ".format(TR2d.shape))
print("TE2d.shape={} ".format(TE2d.shape))

# Keep only pixels that carry a positive label.
tr_labelled = TR2d > 0
te_labelled = TE2d > 0
TR_sample = I2d[tr_labelled, :]
TE_sample = I2d[te_labelled, :]
TR_temp = TR2d[tr_labelled]
TE_temp = TE2d[te_labelled]
print('***************挑选后**************')
print("TR_sample.shape={} ".format(TR_sample.shape))
print("TE_sample.shape={} ".format(TE_sample.shape))
print("TR_temp.shape={} ".format(TR_temp.shape))
print("TE_temp.shape={} ".format(TE_temp.shape))

# All samples: training rows first, then test rows.
X = np.concatenate((TR_sample, TE_sample), axis=0)
Y = np.concatenate((TR_temp, TE_temp))
print('***************合并后**************')
print("X.shape={}".format(X.shape))
print("Y.shape={}".format(Y.shape))

K = 10
print('***************邻居数**************')
print("k={}".format(K))
# Pearson correlation between every pair of samples (rows of X).
A = np.corrcoef(X)
def pick(num=10, corr=None):
    """Return a binary nearest-neighbour mask built from a correlation matrix.

    In row ``i`` of the result, the ``num`` columns holding the largest values
    of ``corr[i, :]`` are set to 1.0 and every other entry is 0.0. Ties are
    broken by column order (``rank(method='first')``), so at most ``num``
    entries per row are selected; NaN entries are never selected.

    Args:
        num: number of neighbours kept per row.
        corr: 2-D array of pairwise similarities; defaults to the module-level
            ``A``.

    Returns:
        scipy.sparse.coo_matrix with the same shape as ``corr``.
    """
    if corr is None:
        corr = A
    corr = np.asarray(corr)
    n_rows, n_cols = corr.shape
    weight = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        # Compare *ranks* (1 = smallest) with the threshold; raw correlations
        # are at most 1 and can never exceed n_cols - num.
        ranks = pd.Series(corr[i, :]).rank(method='first').to_numpy()
        # Index row i explicitly; ``weight[cols]`` would fill whole rows.
        weight[i, ranks > n_cols - num] = 1.0
    return coo_matrix(weight)
# Replace the correlation matrix by the binary neighbour mask (sparse COO).
A = coo_matrix(pick())
# Fraction of each class's training pixels used for training; the rest are
# used for validation.
per = 0.89
def creat(validation=False):
    """Collect one split of the training map and save it under SAVA_PATH.

    For every class c in 1..NUM_CLASS, the pixels of ``TR_map`` equal to c are
    taken in ``np.where`` order. The first ``int(per * count)`` of them form the
    training split and the rest form the validation split. The chosen pixels
    are reordered by a permutation drawn right after ``np.random.seed(820)``;
    the seed is reset for every class. Their spectra ``I[row, col]`` and
    labels ``TR_map[row, col] - 1`` are saved as float32 / int8 arrays.

    NOTE(review): the permutation is applied *after* the split, so it only
    reorders samples inside each split; the split itself is always the first
    ``per`` fraction of each class in ``np.where`` order. Confirm this is
    intended.

    Args:
        validation: False for the training split, True for the validation split.
    """
    spectra = []
    labels = []
    for cls in range(1, NUM_CLASS + 1):
        rows, cols = np.where(TR_map == cls)
        cut = int(per * len(rows))
        part = slice(cut, None) if validation else slice(None, cut)
        rows, cols = rows[part], cols[part]
        np.random.seed(820)
        order = np.random.permutation(len(rows))
        for row, col in zip(rows[order], cols[order]):
            spectra.append(I[row, col])
            labels.append(TR_map[row, col] - 1)
    Xh = np.asarray(spectra, dtype=np.float32)
    ys = np.asarray(labels, dtype=np.int8)
    if validation:
        np.save(SAVA_PATH + 'validation_Xh.npy', Xh, allow_pickle=True)
        np.save(SAVA_PATH + 'validation_y', ys, allow_pickle=True)
    else:
        np.save(SAVA_PATH + 'train_Xh.npy', Xh, allow_pickle=True)
        np.save(SAVA_PATH + 'train_y.npy', ys, allow_pickle=True)
# Save the training split, then the validation split.
for _is_validation in (False, True):
    creat(validation=_is_validation)
del _is_validation
def encode_onehot(labels):
    """Return the one-hot encoding of ``labels`` as an int32 array.

    Columns follow the distinct label values in ascending order, so column j
    always stands for the j-th smallest class present. (Iterating a ``set``,
    as before, gives no ordering guarantee, so columns could be permuted.)

    Args:
        labels: 1-D sequence of comparable class labels.

    Returns:
        np.ndarray of shape (len(labels), number of distinct labels), int32.
    """
    classes, positions = np.unique(np.asarray(labels).reshape(-1),
                                   return_inverse=True)
    return np.eye(len(classes), dtype=np.int32)[positions.reshape(-1)]
# One-hot encode the labels of all samples.
Y = encode_onehot(Y)
print('***************onehot后**************')
print("Y.shape={}".format(Y.shape))

# Cast to compact dtypes, then write all outputs under SAVA_PATH.
# NOTE(review): test_y.npy keeps the raw positive labels while train_y /
# validation_y hold labels shifted down by one; check the consumer.
X, Y = np.asarray(X, dtype=np.float32), np.asarray(Y, dtype=np.int8)
TE_sample, TE_temp = (np.asarray(TE_sample, dtype=np.float32),
                      np.asarray(TE_temp, dtype=np.int8))
for _fname, _arr in (('X.npy', X), ('Y.npy', Y), ('A.npy', A),
                     ('test_Xh.npy', TE_sample), ('test_y.npy', TE_temp)):
    np.save(SAVA_PATH + _fname, _arr, allow_pickle=True)
del _fname, _arr
第四版 mask_tr 等等
import numpy as np
import tifffile as tiff
import os
import cv2
import scipy.io as sio
from scipy.cluster.vq import whiten
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import pandas as pd
# ---- Configuration ----
NUM_CLASS = 15  # number of classes in the label maps
PATH = './data/Houston/'  # directory holding the input TIFF files
SAVA_PATH = './file/'  # output directory for generated files
# NOTE: BATCH_SIZE, r, upscale, LiDarName, lchn and hchn are not used below.
BATCH_SIZE = 100
r = 5
upscale = 2
LiDarName = 'Houston_LiDAR.tif'
HsiName = 'Houston.tif'
gth_train = 'Houston_train.tif'
gth_test = 'Houston_test.tif'
lchn = 1
hchn = 144  # presumably the number of HSI bands - confirm
# exist_ok avoids the exists()/mkdir() race; makedirs also creates parents.
os.makedirs(SAVA_PATH, exist_ok=True)
def read_image(filename):
    """Read the TIFF file at ``filename`` and return its data as a float32 array."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
def show_image(image):
    """Draw ``image`` with matplotlib, hide the axes, then show the figure."""
    plt.imshow(image)
    plt.gca().set_axis_off()
    plt.show()
# ---- Load the image cube and the training / test label maps ----
I = read_image(PATH + HsiName)
m, n, z = I.shape  # I must be 3-D: m*n positions with z values each
TR_map = read_image(PATH + gth_train)
TE_map = read_image(PATH + gth_test)
print('***************原始的**************')
print("HSI.shape={}".format(I.shape))
print("TR_map.shape={} ".format(TR_map.shape))
print("TE_map.shape={} ".format(TE_map.shape))

# Flatten to one row per pixel so that pixels can be picked with 1-D masks.
I2d = I.reshape(m * n, z)
TR2d = TR_map.reshape(m * n)
TE2d = TE_map.reshape(m * n)
print('***************展平的**************')
print("I2d.shape={}".format(I2d.shape))
print("TR2d.shape={} ".format(TR2d.shape))
print("TE2d.shape={} ".format(TE2d.shape))

# Keep only pixels that carry a positive label.
tr_labelled = TR2d > 0
te_labelled = TE2d > 0
TR_sample = I2d[tr_labelled, :]
TE_sample = I2d[te_labelled, :]
TR_temp = TR2d[tr_labelled]
TE_temp = TE2d[te_labelled]
print('***************挑选后**************')
print("TR_sample.shape={} ".format(TR_sample.shape))
print("TE_sample.shape={} ".format(TE_sample.shape))
print("TR_temp.shape={} ".format(TR_temp.shape))
print("TE_temp.shape={} ".format(TE_temp.shape))

# This version builds the graph over the training samples only.
X, Y = TR_sample, TR_temp
print('***************合并后**************')
print("X.shape={}".format(X.shape))
print("Y.shape={}".format(Y.shape))

K = 10
print('***************邻居数**************')
print("k={}".format(K))
# Pearson correlation between every pair of samples (rows of X).
A = np.corrcoef(X)
def pick(num=10, corr=None):
    """Return a binary nearest-neighbour mask built from a correlation matrix.

    In row ``i`` of the result, the ``num`` columns holding the largest values
    of ``corr[i, :]`` are set to 1.0 and every other entry is 0.0. Ties are
    broken by column order (``rank(method='first')``), so at most ``num``
    entries per row are selected; NaN entries are never selected.

    Args:
        num: number of neighbours kept per row.
        corr: 2-D array of pairwise similarities; defaults to the module-level
            ``A``.

    Returns:
        scipy.sparse.coo_matrix with the same shape as ``corr``.
    """
    if corr is None:
        corr = A
    corr = np.asarray(corr)
    n_rows, n_cols = corr.shape
    weight = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        # Compare *ranks* (1 = smallest) with the threshold; raw correlations
        # are at most 1 and can never exceed n_cols - num.
        ranks = pd.Series(corr[i, :]).rank(method='first').to_numpy()
        # Index row i explicitly; ``weight[cols]`` would fill whole rows.
        weight[i, ranks > n_cols - num] = 1.0
    return coo_matrix(weight)
# Replace the correlation matrix by the binary neighbour mask (sparse COO).
A = coo_matrix(pick())
# Fraction of each class's training pixels used for training; the rest are
# used for validation.
per = 0.89
def creat(validation=False):
    """Build one split over the training samples and save it under SAVA_PATH.

    For every class c in 1..NUM_CLASS, the positions where ``TR_temp == c`` are
    taken in ascending order. The first ``int(per * count)`` positions form the
    training split and the rest form the validation split. The following are
    saved:

    - a boolean mask over ``Y`` marking the selected positions;
    - the selected feature rows ``X[idx]`` (float32);
    - their labels ``Y[idx] - 1`` (int8).

    Args:
        validation: False for the training split, True for the validation split.
    """
    idx_trval = []
    for c in range(1, NUM_CLASS + 1):
        idx = np.flatnonzero(TR_temp == c)
        cut = int(per * len(idx))
        idx_trval.extend(idx[cut:] if validation else idx[:cut])
    idx_trval = np.asarray(idx_trval, dtype=np.intp)

    mask_trval = np.zeros(Y.shape, dtype=bool)
    mask_trval[idx_trval] = True
    Xh = np.asarray(X[idx_trval], dtype=np.float32)
    Yh = np.asarray(Y[idx_trval] - 1, dtype=np.int8)
    n_selected = int(np.count_nonzero(mask_trval))

    if not validation:
        print('***********************************')
        print("mask_tr.shape={0}, TRUE :{1}".format(len(mask_trval), n_selected))
        print("train_Xh.shape={}".format(Xh.shape))
        print("train_y.shape={}".format(Yh.shape))
        np.save(SAVA_PATH + 'mask_tr.npy', mask_trval, allow_pickle=True)
        np.save(SAVA_PATH + 'train_Xh.npy', Xh, allow_pickle=True)
        np.save(SAVA_PATH + 'train_y.npy', Yh, allow_pickle=True)
    else:
        print('***********************************')
        print("mask_val.shape={0}, TRUE :{1}".format(len(mask_trval), n_selected))
        print("validation_Xh.shape={}".format(Xh.shape))
        print("validation_y.shape={}".format(Yh.shape))
        np.save(SAVA_PATH + 'mask_val.npy', mask_trval, allow_pickle=True)
        np.save(SAVA_PATH + 'validation_Xh.npy', Xh, allow_pickle=True)
        # Save the validation labels (previously the global Y was saved here).
        np.save(SAVA_PATH + 'validation_y.npy', Yh, allow_pickle=True)
# Save the training split, then the validation split.
for _is_validation in (False, True):
    creat(validation=_is_validation)
del _is_validation
def encode_onehot(labels):
    """Return the one-hot encoding of ``labels`` as an int32 array.

    Columns follow the distinct label values in ascending order, so column j
    always stands for the j-th smallest class present. (Iterating a ``set``,
    as before, gives no ordering guarantee, so columns could be permuted.)

    Args:
        labels: 1-D sequence of comparable class labels.

    Returns:
        np.ndarray of shape (len(labels), number of distinct labels), int32.
    """
    classes, positions = np.unique(np.asarray(labels).reshape(-1),
                                   return_inverse=True)
    return np.eye(len(classes), dtype=np.int32)[positions.reshape(-1)]
# One-hot encode the labels of the training samples.
Y = encode_onehot(Y)
print('***************onehot后**************')
print("Y.shape={}".format(Y.shape))

# Cast to compact dtypes, then write all outputs under SAVA_PATH.
# NOTE(review): test_y.npy keeps the raw positive labels while train_y
# holds labels shifted down by one; check the consumer.
X, Y = np.asarray(X, dtype=np.float32), np.asarray(Y, dtype=np.int8)
TE_sample, TE_temp = (np.asarray(TE_sample, dtype=np.float32),
                      np.asarray(TE_temp, dtype=np.int8))
for _fname, _arr in (('X.npy', X), ('Y.npy', Y), ('A.npy', A),
                     ('test_Xh.npy', TE_sample), ('test_y.npy', TE_temp)):
    np.save(SAVA_PATH + _fname, _arr, allow_pickle=True)
del _fname, _arr
第五版 基本完善
import numpy as np
import tifffile as tiff
import os
import cv2
import scipy.io as sio
from scipy.cluster.vq import whiten
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import pandas as pd
# ---- Configuration ----
NUM_CLASS = 15  # number of classes in the label maps
PATH = './data/Houston/'  # directory holding the input TIFF files
SAVA_PATH = './file/'  # output directory for generated files
# NOTE: BATCH_SIZE, r, upscale, LiDarName, lchn and hchn are not used below.
BATCH_SIZE = 100
r = 5
upscale = 2
LiDarName = 'Houston_LiDAR.tif'
HsiName = 'Houston.tif'
gth_train = 'Houston_train.tif'
gth_test = 'Houston_test.tif'
lchn = 1
hchn = 144  # presumably the number of HSI bands - confirm
# exist_ok avoids the exists()/mkdir() race; makedirs also creates parents.
os.makedirs(SAVA_PATH, exist_ok=True)
def read_image(filename):
    """Read the TIFF file at ``filename`` and return its data as a float32 array."""
    return np.asarray(tiff.imread(filename), dtype=np.float32)
def show_image(image):
    """Draw ``image`` with matplotlib, hide the axes, then show the figure."""
    plt.imshow(image)
    plt.gca().set_axis_off()
    plt.show()
# ---- Load the image cube and the training / test label maps ----
I = read_image(PATH + HsiName)
m, n, z = I.shape  # I must be 3-D: m*n positions with z values each
TR_map = read_image(PATH + gth_train)
TE_map = read_image(PATH + gth_test)
print('***************原始的**************')
print("HSI.shape={}".format(I.shape))
print("TR_map.shape={} ".format(TR_map.shape))
print("TE_map.shape={} ".format(TE_map.shape))

# Flatten to one row per pixel so that pixels can be picked with 1-D masks.
I2d = I.reshape(m * n, z)
TR2d = TR_map.reshape(m * n)
TE2d = TE_map.reshape(m * n)
print('***************展平的**************')
print("I2d.shape={}".format(I2d.shape))
print("TR2d.shape={} ".format(TR2d.shape))
print("TE2d.shape={} ".format(TE2d.shape))

# Keep only pixels that carry a positive label.
tr_labelled = TR2d > 0
te_labelled = TE2d > 0
TR_sample = I2d[tr_labelled, :]
TE_sample = I2d[te_labelled, :]
TR_temp = TR2d[tr_labelled]
TE_temp = TE2d[te_labelled]
print('***************挑选后**************')
print("TR_sample.shape={} ".format(TR_sample.shape))
print("TE_sample.shape={} ".format(TE_sample.shape))
print("TR_temp.shape={} ".format(TR_temp.shape))
print("TE_temp.shape={} ".format(TE_temp.shape))

# All samples: training rows first, then test rows.
X = np.concatenate((TR_sample, TE_sample), axis=0)
Y = np.concatenate((TR_temp, TE_temp))
print('***************合并后**************')
print("X.shape={}".format(X.shape))
print("Y.shape={}".format(Y.shape))

K = 10
print('***************邻居数**************')
print("k={}".format(K))
# Pearson correlation between every pair of samples (rows of X).
A = np.corrcoef(X)
def pick(num, corr=None):
    """Return a binary nearest-neighbour mask built from a correlation matrix.

    In row ``i`` of the result, the ``num`` columns holding the largest values
    of ``corr[i, :]`` are set to 1.0 and every other entry is 0.0. Ties are
    broken by column order (``rank(method='first')``), so at most ``num``
    entries per row are selected; NaN entries are never selected.

    Args:
        num: number of neighbours kept per row.
        corr: 2-D array of pairwise similarities; defaults to the module-level
            ``A``.

    Returns:
        scipy.sparse.coo_matrix with the same shape as ``corr``.
    """
    if corr is None:
        corr = A
    corr = np.asarray(corr)
    n_rows, n_cols = corr.shape
    weight = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        # Compare *ranks* (1 = smallest) with the threshold; raw correlations
        # are at most 1 and can never exceed n_cols - num.
        ranks = pd.Series(corr[i, :]).rank(method='first').to_numpy()
        # Index row i explicitly; ``weight[cols]`` would fill whole rows.
        weight[i, ranks > n_cols - num] = 1.0
    return coo_matrix(weight)
# Keep the correlation values only on each sample's 10 most-correlated
# samples. For scipy sparse *matrices*, ``ndarray * coo_matrix`` is a matrix
# product, so the old ``A * pick(10)`` computed A @ W. Densify the mask and
# apply it element-wise instead; A stays a dense ndarray.
A = A * pick(10).toarray()
# Fraction of each class's training pixels used for training; the rest are
# used for validation.
per = 0.89
def creat_trval(validation=False):
    """Build and save the training or validation mask over the combined samples.

    X and Y hold every training sample first and the test samples after
    them, so positions in ``TR_temp`` are also row positions in X/Y.

    For every class c in 1..NUM_CLASS, the positions where ``TR_temp == c`` are
    taken in ascending order. The first ``int(per * count)`` positions form the
    training split and the rest form the validation split. A boolean mask over
    ``Y`` marking the selected positions is saved as ``mask_tr.npy`` or
    ``mask_val.npy``. The selected features and zero-based labels are only
    reported through the printed shapes.

    Args:
        validation: False for the training split, True for the validation split.
    """
    idx_trval = []
    for c in range(1, NUM_CLASS + 1):
        idx = np.flatnonzero(TR_temp == c)
        cut = int(per * len(idx))
        idx_trval.extend(idx[cut:] if validation else idx[:cut])
    idx_trval = np.asarray(idx_trval, dtype=np.intp)

    mask_trval = np.zeros(Y.shape, dtype=bool)
    mask_trval[idx_trval] = True
    Xh = np.asarray(X[idx_trval], dtype=np.float32)
    Yh = np.asarray(TR_temp[idx_trval] - 1, dtype=np.int8)
    n_selected = int(np.count_nonzero(mask_trval))

    if not validation:
        print('***********************************')
        print("mask_tr.shape={0}, TRUE :{1}".format(len(mask_trval), n_selected))
        print("train_Xh.shape={}".format(Xh.shape))
        # Report the split's label count, not that of all samples.
        print("train_y.shape={}".format(Yh.shape))
        np.save(SAVA_PATH + 'mask_tr.npy', mask_trval, allow_pickle=True)
    else:
        print('***********************************')
        print("mask_val.shape={0}, TRUE :{1}".format(len(mask_trval), n_selected))
        print("validation_Xh.shape={}".format(Xh.shape))
        print("validation_y.shape={}".format(Yh.shape))
        np.save(SAVA_PATH + 'mask_val.npy', mask_trval, allow_pickle=True)
def creat_test():
    """Save a boolean mask selecting the test rows of the combined samples.

    Rows 0 .. len(TR_temp)-1 of X/Y are training samples and every row after
    them is a test sample. The mask is therefore False before
    ``len(TR_temp)`` and True from there on. It is written to
    ``SAVA_PATH + 'mask_test.npy'``.
    """
    n_train = TR_temp.shape[0]
    mask_test = np.arange(Y.shape[0]) >= n_train
    n_test = int(np.count_nonzero(mask_test))
    print('***********************************')
    print("mask_test.shape={0}, TRUE :{1}".format(len(mask_test), n_test))
    np.save(SAVA_PATH + 'mask_test.npy', mask_test, allow_pickle=True)
# Masks over the combined X/Y rows: training, validation, then test.
for _is_validation in (False, True):
    creat_trval(validation=_is_validation)
del _is_validation
creat_test()
def encode_onehot(labels):
    """Return the one-hot encoding of ``labels`` as an int32 array.

    Columns follow the distinct label values in ascending order, so column j
    always stands for the j-th smallest class present. (Iterating a ``set``,
    as before, gives no ordering guarantee, so columns could be permuted.)

    Args:
        labels: 1-D sequence of comparable class labels.

    Returns:
        np.ndarray of shape (len(labels), number of distinct labels), int32.
    """
    classes, positions = np.unique(np.asarray(labels).reshape(-1),
                                   return_inverse=True)
    return np.eye(len(classes), dtype=np.int32)[positions.reshape(-1)]
# One-hot encode the labels of all samples (training rows, then test rows).
Y = encode_onehot(Y)
print('***************onehot后**************')
print("Y.shape={}".format(Y.shape))

# Cast to compact dtypes, then write features, labels and adjacency.
X, A, Y = (np.asarray(X, dtype=np.float32),
           np.asarray(A, dtype=np.float32),
           np.asarray(Y, dtype=np.int8))
for _fname, _arr in (('X.npy', X), ('Y.npy', Y), ('A.npy', A)):
    np.save(SAVA_PATH + _fname, _arr, allow_pickle=True)
del _fname, _arr