How can I plot a histogram on a rotated axis in projection space (LDA)?


I used the mean separation criterion (find the parameter w that maximizes the distance between the class means) as well as Fisher LDA to find a good line to separate two linearly separable classes. Then I plotted the histogram of the original data, and it shows a lot of overlap between the classes. Now I want to project the points onto a line (parallel to the line connecting the class means) and plot a histogram on top of this line (Reproduce Fisher linear discriminant figure). Somebody did it in Matlab, but I don't have any idea how to translate this into Python. Does anyone know how to do that?

Thanks in advance!


import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

# set seed
np.random.seed(8)
X, y = datasets.make_blobs(n_samples=100, centers=2, n_features=2, center_box=(0, 10))

# Mean separation vector
mu1 = np.mean(X[y == 0], axis=0)
mu2 = np.mean(X[y == 1], axis=0)
w = (mu2 - mu1) / np.linalg.norm(mu2 - mu1)

# Plot the data, the class means and the mean separation vector, with the same tick spacing on both axes
plt.figure(figsize=(5, 5))
plt.xlim(0, 15)
plt.ylim(0, 15)
plt.xticks(np.arange(0, 15, 1))
plt.yticks(np.arange(0, 15, 1))
plt.grid()
plt.scatter(X[:, 0][y == 0], X[:, 1][y == 0], label='Class 1')
plt.scatter(X[:, 0][y == 1], X[:, 1][y == 1], label='Class 2')
plt.plot(mu1[0], mu1[1], 'X', color='red', markersize=10, label='Mean of class 1')
plt.plot(mu2[0], mu2[1], 'X', color='red', markersize=10, label='Mean of class 2')
plt.legend()
plt.plot([mu1[0], mu2[0]], [mu1[1], mu2[1]], 'k--')
plt.arrow((mu1[0] + mu2[0]) / 2, (mu1[1] + mu2[1]) / 2, w[1], -w[0], head_width=0.3, head_length=0.3, fc='k', ec='k')
# Histogram of the first feature (X[:, 0]) per class, drawn on the same axes as the scatter plot
plt.hist(X[:, 0][y == 0], bins=10, alpha=0.5, label='Class 1')
plt.hist(X[:, 0][y == 1], bins=10, alpha=0.5, label='Class 2')
plt.legend()

# Project the data on the mean separation vector
X_proj = np.dot(X, w)