import matplotlib.pyplot as plt
import numpy as np
from sklearn.mixture import GaussianMixture
# (mean, covariance, n_samples) for each synthetic 2-D Gaussian cluster.
BLOB_SPECS = (
    ([10, 25], [[10, 0], [0, 2]], 1000),
    ([20, 13], [[30, 0], [0, 40]], 2000),
    # A covariance matrix must be symmetric positive-semidefinite. The original
    # [[100, 20], [0, 10]] was not symmetric, which makes numpy warn and sample
    # from an ill-defined distribution; mirror the off-diagonal term instead.
    ([15, -5], [[100, 20], [20, 10]], 2000),
)


def sample_blobs(specs=BLOB_SPECS, rng=None):
    """Draw one Gaussian sample array per (mean, cov, n_samples) spec.

    Args:
        specs: Iterable of ``(mean, cov, n_samples)`` tuples.
        rng: Optional ``numpy.random.Generator``. A fresh unseeded generator
            is used when ``None``. Pass a seeded one for reproducible runs.

    Returns:
        A list with one ``(n_samples, len(mean))`` array per spec, in order.
    """
    if rng is None:
        rng = np.random.default_rng()
    return [rng.multivariate_normal(mean, cov, n) for mean, cov, n in specs]


def plot_groups(groups, title, fmt='.'):
    """Plot each 2-D point array in *groups* on the current figure.

    Args:
        groups: Iterable of ``(n_i, 2)`` arrays; column 0 is x, column 1 is y.
        title: Title for the current axes.
        fmt: Matplotlib format string applied to every group. For example,
            ``'b.'`` draws every group in blue, hiding the group structure.
    """
    for pts in groups:
        plt.plot(pts[:, 0], pts[:, 1], fmt)
    plt.axis('equal')
    plt.title(title)


def main(n_components=3, rng=None):
    """Fit a Gaussian mixture to synthetic blobs and plot the result.

    Prints the predicted label of every point and the number of points.
    It then shows three figures:

    - 'orig': the true generating groups.
    - 'data': all points in a single colour.
    - 'gmm': the points regrouped by GMM label.
    """
    groups = sample_blobs(rng=rng)
    # Stack all clusters into one (N, 2) design matrix for the GMM.
    X = np.concatenate(groups, axis=0)

    gm = GaussianMixture(n_components=n_components,
                         covariance_type='full',
                         verbose=10)
    gm.fit(X)
    y = gm.predict(X)
    print(y)
    print(len(y))

    plot_groups(groups, 'orig')
    plt.figure()
    plot_groups(groups, 'data', fmt='b.')
    plt.figure()
    # One group per mixture component, so this also works for n_components != 3.
    plot_groups([X[y == k] for k in range(n_components)], 'gmm')
    plt.show()


if __name__ == '__main__':
    main()
# sklearn GMM demo
# Page footer from the source article (translated): "Latest recommended article published 2023-06-03 10:38:42".