I want to find the immediate neighbours around a given point in a multidimensional space (up to 7 dimensions).
Important facts about the space:
- non-linear spacing of points within a single dimension: as shown in the screenshot below, the distance between consecutive points varies

- unequal spacing among different dimensions
(sample code to generate a grid with uneven spacing across dimensions)
x_values = np.linspace(-0.3, 0.3, 5)
y_values = np.linspace(-0.3, 0.3, 5)
z_values = np.linspace(1, 6, 6)  # unequal spacing (large spacing in the z-direction)
MWE:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import KDTree
import numpy as np
# Define ranges for X, Y, and Z values
x_values = np.linspace(-0.3, 0.3, 5)
y_values = np.linspace(-0.3, 0.3, 5)
z_values = np.linspace(1, 6, 6)  # unequal spacing (large spacing in the z-direction)
# z_values = np.linspace(-0.3, 0.3, 5) # equal spacing case
# Create meshgrid to generate combinations of X, Y, and Z values
X, Y, Z = np.meshgrid(x_values, y_values, z_values)
# Reshape the meshgrid arrays to create a single array of all combinations
points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
# Create a KDTree object with the sample points
kdtree = KDTree(points, leaf_size=30, metric='euclidean')
# Query point for which nearest neighbors will be found
# query_point = np.array([[0, 0, 0]]) # test query point for equal spacing
query_point = np.array([[0, 0, 2]]) # test query point for unequal spacing
# Find the indices of the nearest neighbors and their distances
distances, indices = kdtree.query(query_point, k=27)
# Plot all points in 3D
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(points[:, 0], points[:, 1], points[:, 2], color='blue', label='All Points')
# Plot the query point in 3D
ax.scatter(query_point[:, 0], query_point[:, 1], query_point[:, 2], color='red', label='Query Point')
# Plot the nearest neighbors in 3D
nearest_neighbors = points[indices[0]] # Get nearest neighbors using indices
ax.scatter(nearest_neighbors[:, 0], nearest_neighbors[:, 1], nearest_neighbors[:, 2], color='green', label='Nearest Neighbors')
# Connect the query point with its nearest neighbors in 3D
for neighbor in nearest_neighbors:
    ax.plot([query_point[0, 0], neighbor[0]], [query_point[0, 1], neighbor[1]], [query_point[0, 2], neighbor[2]], color='gray', linestyle='--')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('KD-Tree Nearest Neighbors in 3D')
ax.legend()
plt.show()
Results with the above code:
Required results: immediate neighbours should be selected along each dimension irrespective of their actual distance.


The trick is generally to rescale the axes so that each axis contributes equally to the distances used by the nearest-neighbour algorithm. More generally, you can treat the scale of each axis as a hyper-parameter that you can tune to give more or less importance to a feature.
Then you have to 'go back' to the original space for the original distances / plotting (either apply an inverse transform, or simply keep the indices and use them on the original data).
Here is the code:
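A minimal sketch of the idea, assuming sklearn's MinMaxScaler is used to bring every axis onto the same [0, 1] range (the scaler choice is just one option; StandardScaler or a manual per-axis division works the same way). The tree is built and queried in the scaled space, and the returned indices are used to recover the neighbours in the original space:

import numpy as np
from sklearn.neighbors import KDTree
from sklearn.preprocessing import MinMaxScaler  # assumption: any per-axis rescaling would do

# Same uneven grid as in the question
x_values = np.linspace(-0.3, 0.3, 5)
y_values = np.linspace(-0.3, 0.3, 5)
z_values = np.linspace(1, 6, 6)  # large spacing in the z-direction
X, Y, Z = np.meshgrid(x_values, y_values, z_values)
points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))

# Rescale so every axis spans the same range and contributes equally to the distance
scaler = MinMaxScaler()
points_scaled = scaler.fit_transform(points)

# Build the KD-tree in the scaled space
kdtree = KDTree(points_scaled, leaf_size=30, metric='euclidean')

# The query point must be transformed with the same scaler
query_point = np.array([[0, 0, 2]])
query_scaled = scaler.transform(query_point)

# Query in the scaled space
distances_scaled, indices = kdtree.query(query_scaled, k=27)

# 'Go back' to the original space: the indices are valid for the original data,
# so no inverse transform is needed for plotting or for the original distances
nearest_neighbors = points[indices[0]]
original_distances = np.linalg.norm(nearest_neighbors - query_point, axis=1)

print(nearest_neighbors)
print(original_distances)

Because only the indices are carried back, the plotting part of the MWE can be reused unchanged on nearest_neighbors and query_point.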
And the result in the original space: