class compute_scatter_matrix(object):
def __init__(self,X,labels):
self.X = X
self.labels = labels
def compute_scatter(self):
meanAll = np.mean(self.X,axis=0)
St = np.dot((self.X-meanAll).T,(self.X-meanAll))
xclasses = {}
meanclasses = {}
Sw = np.zeros((self.X.shape[1],self.X.shape[1]))
for label in np.unique(self.labels):
xclasses[label] = [self.X[i] for i in range(len(self.labels)) if self.labels[i]==label]
meanclasses[label] = np.mean(xclasses[label],axis=0)
Sw += np.dot((xclasses[label]-meanclasses[label]).T,(xclasses[label]-meanclasses[label]))
SB = np.zeros((self.X.shape[1],self.X.shape[1]))
#第二种SB的计算方法
for label in np.unique(self.labels):
n = np.array(xclasses[label]).shape[0]
mean_vec = np.array(meanclasses[label]