Python Kruskal's algorithm for clustering

751 Views Asked by At

I'm trying to figure out how to find the minimum distance between points in different clusters by determining the clusters with a modified Kruskal's for a series of (x, y) coordinates. I've been spending like 3-4 hours a day for the last half a week but I don't feel like I'm that close. My code is below:

#Uses python3
import sys
import math

class Point(object):
    x = ""
    y = ""
    index = ""
    set = ""
    neighbors = None

    def __init__(self, x, y, index, set):
        self.x = x
        self.y = y
        self.index = index
        self.set = set

    def setSet(self, set):
        self.set = set

    def getSet(self):
        return self.set

    def addNeighbor(self, point):
        if self.neighbors == None:
            self.neighbors = [point]
        else:
            self.neighbors.append(point)

    def getNeighbors(self):
        return self.neighbors

class Edge(object):
    point1 = None
    point2 = None
    length = None

    def __init__(self, pt1, pt2):
        length = (((pt1.x - pt2.x)**2) + ((pt1.y - pt2.y)**2))**(.5)
        # print("Pt 1: " + str(pt1.x) + ", " + str(pt1.y))
        # print("Pt 2: " + str(pt2.x) + ", " + str(pt2.y))
        # print(length)
        self.point1 = pt1
        self.point2 = pt2
        self.length = length

    def getLength(self):
        return self.length

def clustering(x, y, k):
    points = []
    numSets = len(x)
    edges = []
    #write your code here
    for i in range(len(x)):
        new_point = Point(x[i], y[i], i, i)
        points.append(new_point)
    for point in points:
        for neighbor in points:
            if not point is neighbor:
                if point.neighbors == None:
                    # print(point + ": " + neighbor)
                    edge = Edge(point, neighbor)
                    edges.append(edge)
                    point.addNeighbor(neighbor)
                    neighbor.addNeighbor(point)
                    # print (edge.getLength())
                else:
                    if neighbor not in point.neighbors:
                        edge = Edge(point, neighbor)
                        edges.append(edge)
                        point.addNeighbor(neighbor)
                        neighbor.addNeighbor(point)
    edges.sort(key=lambda edge: edge.getLength())
    if k == len(points):
        smallestEdge = edges[0]
        return smallestEdge.length
    while numSets > k:
        last_edge = 0
        for edge in edges:
            # print("edge" + edge)
            pointA = edge.point1
            pointB = edge.point2
            if pointA.set != pointB.set:
                # print("merge points")
                if pointA.set < pointB.set:
                    oldSet = pointB.set
                    pointB.setSet(pointA.set)
                    for point in points:
                        if point.set == oldSet:
                            point.setSet(pointA.set)
                    numSets = numSets - 1
                if pointB.set < pointA.set:
                    oldSet = pointA.set
                    pointA.setSet(pointB.set)
                    for point in points:
                        if point.set == oldSet:
                            point.setSet(pointB.set)
                    numSets = numSets - 1
            # print(numSets)s
            # print("Point A: (" + str(pointA.x) + ", " + str(pointA.y) + ")/Point B: (" + str(pointB.x) + ", " + str(pointB.y) + "), Length: " + str(edge.length))
            last_edge += 1
            if numSets <= k:
                break

        # print(len(sets)
    nextEdge = edges[last_edge + 1]
    pointA = nextEdge.point1
    pointB = nextEdge.point2
    # print("Point A: (" + str(pointA.x) + ", " + str(pointA.y) + ")/Point B: (" + str(pointB.x) + ", " + str(pointB.y) + "), Length: " + str(nextEdge.length))

    return nextEdge.length


if __name__ == '__main__':
    input = sys.stdin.read()
    data = list(map(int, input.split()))
    n = data[0]
    data = data[1:]
    x = data[0:2 * n:2]
    y = data[1:2 * n:2]
    data = data[2 * n:]
    k = data[0]
    print("{0:.9f}".format(clustering(x, y, k)))

The input comes in the form of:

Number of points
x1, y1
x2, y2
.
.
.
.
Number of clusters

It works on the test cases I'm provided (this is for a class), but then fails when trying to validate it due to a test case that I'm not sure what it is, so I'm really not sure what is causing this to fail and haven't been able to find out what it is. What/how many things am I doing wrong here (I'm sure there's a lot)?

0

There are 0 best solutions below