apache-flink KMeans operation on UnsortedGrouping

I have a Flink DataSet (read from a file) that contains sensor readings from many different sensors. I use Flink's groupBy() method to organize the data as an UnsortedGrouping per sensor. Next, I would like to run the KMeans algorithm on every UnsortedGrouping in my DataSet in a distributed way.

My question is: how can this functionality be implemented efficiently with Flink? Below is my current implementation: I have written my own groupReduce() method that applies Flink's KMeans algorithm to every UnsortedGrouping. This code works, but it is very slow and uses large amounts of memory.

I think this has to do with the amount of data reorganization involved: several conversions have to be performed to make the code run, because I don't know how to do it more efficiently:

  • UnsortedGrouping to Iterable (start of groupReduce() method)
  • Iterable to LinkedList (need this to use the fromCollection() method)
  • LinkedList to DataSet (required as input to KMeans)
  • resulting KMeans DataSet back to LinkedList (to be able to iterate and emit via the Collector)

Surely there must be a more efficient and performant way to implement this? Can anybody show me how to do it in a clean and efficient Flink way?

// *************************************************************************
// VARIABLES
// *************************************************************************

static int numberClusters = 10;
static int maxIterations = 10;
static int sensorCount = 117;
static ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

// *************************************************************************
// PROGRAM
// *************************************************************************

public static void main(String[] args) throws Exception {

    final long startTime = System.currentTimeMillis();

    String fileName = "C:/tmp/data.nt";
    DataSet<String> text = env.readTextFile(fileName);

    // filter relevant DataSet from text file input
    UnsortedGrouping<Tuple2<Integer,Point>> points = text
            .filter(x -> x.contains("Value") && x.contains("valueLiteral")).filter(x -> !x.contains("#string"))
            .map(x -> new Tuple2<Integer, Point>(
                    Integer.parseInt(x.substring(x.indexOf("_") + 1, x.indexOf(">"))) % sensorCount,
                    new Point(Double.parseDouble(x.split("\"")[1]))))
            .filter(x -> x.f0 < 10)
            .groupBy(0);

    DataSet<Tuple2<Integer, Point>> output = points.reduceGroup(new DistinctReduce());
    output.print();

    // print the execution time
    final long endTime = System.currentTimeMillis();
    System.out.println("Total execution time: " + (endTime - startTime) + "ms");
}

public static class DistinctReduce implements GroupReduceFunction<Tuple2<Integer, Point>, Tuple2<Integer, Point>> {

    private static final long serialVersionUID = 1L;

    @Override
    public void reduce(Iterable<Tuple2<Integer, Point>> in, Collector<Tuple2<Integer, Point>> out) throws Exception {

        AtomicInteger counter = new AtomicInteger(0);
        List<Point> pointsList = new LinkedList<Point>();

        for (Tuple2<Integer, Point> t : in) {
            pointsList.add(new Point(t.f1.x));
        }
        DataSet<Point> points = env.fromCollection(pointsList);

        DataSet<Centroid> centroids = points
                .distinct()
                .first(numberClusters)
                .map(x -> new Centroid(counter.incrementAndGet(), x));
        // DataSet<String> test = centroids.map(x -> String.format("Centroid %s", x));
        // test.print();

        IterativeDataSet<Centroid> loop = centroids.iterate(maxIterations);
        DataSet<Centroid> newCentroids = points
                // compute closest centroid for each point
                .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
                // count and sum point coordinates for each centroid
                .map(new CountAppender())
                .groupBy(0)
                .reduce(new CentroidAccumulator())
                // compute new centroids from point counts and coordinate sums
                .map(new CentroidAverager());

        // feed new centroids back into next iteration
        DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

        // assign points to final clusters
        DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

        // emit result
        System.out.println("Results from the KMeans algorithm:");
        clusteredPoints.print();

        // forward all clustered points to the surrounding reduce's collector
        List<Tuple2<Integer, Point>> clusteredPointsList = clusteredPoints.collect();
        for (Tuple2<Integer, Point> t : clusteredPointsList) {
            out.collect(t);
        }
    }
}
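
For completeness: Point, Centroid, SelectNearestCenter, CountAppender, CentroidAccumulator and CentroidAverager are the helper classes from Flink's KMeans example. Below is a minimal sketch of the variants assumed here (imports omitted, as elsewhere in this post; one-dimensional Point, matching the single-value constructor used in the map() above):

// sketch of the helper classes, following Flink's KMeans example; 1-D Point assumed
public static class Point implements Serializable {
    public double x;

    public Point() {}
    public Point(double x) { this.x = x; }

    public Point add(Point other) { x += other.x; return this; }
    public Point div(long val) { x /= val; return this; }
    public double euclideanDistance(Point other) { return Math.abs(x - other.x); }
}

public static class Centroid extends Point {
    public int id;

    public Centroid() {}
    public Centroid(int id, Point p) { super(p.x); this.id = id; }
}

// determines the closest centroid for each point; centroids arrive via broadcast set
public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
    private Collection<Centroid> centroids;

    @Override
    public void open(Configuration parameters) throws Exception {
        this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
    }

    @Override
    public Tuple2<Integer, Point> map(Point p) throws Exception {
        double minDistance = Double.MAX_VALUE;
        int closestCentroidId = -1;
        for (Centroid centroid : centroids) {
            double distance = p.euclideanDistance(centroid);
            if (distance < minDistance) {
                minDistance = distance;
                closestCentroidId = centroid.id;
            }
        }
        return new Tuple2<>(closestCentroidId, p);
    }
}

// appends a count of 1 to each (centroid id, point) pair
public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {
    @Override
    public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
        return new Tuple3<>(t.f0, t.f1, 1L);
    }
}

// sums the point coordinates and counts per centroid
public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {
    @Override
    public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> v1, Tuple3<Integer, Point, Long> v2) {
        return new Tuple3<>(v1.f0, v1.f1.add(v2.f1), v1.f2 + v2.f2);
    }
}

// averages the coordinate sums to form the new centroid
public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {
    @Override
    public Centroid map(Tuple3<Integer, Point, Long> value) {
        return new Centroid(value.f0, value.f1.div(value.f2));
    }
}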

1 Answer

You have to group the data points and the centroids first. Then you iterate over the centroids and co-group them with the data points. Each point in a group is assigned to its closest centroid. Then you group on the initial group index and the centroid index to reduce all points assigned to the same centroid. That is the result of one iteration.

The code could look like this:

DataSet<Tuple2<Integer, Point>> groupedPoints = ...

DataSet<Tuple2<Integer, Centroid>> groupCentroids = ...

IterativeDataSet<Tuple2<Integer, Centroid>> groupLoop = groupCentroids.iterate(10);

DataSet<Tuple2<Integer, Centroid>> newGroupCentroids = groupLoop
    .coGroup(groupedPoints).where(0).equalTo(0)
    .with(new CoGroupFunction<Tuple2<Integer, Centroid>, Tuple2<Integer, Point>, Tuple4<Integer, Integer, Point, Integer>>() {
        @Override
        public void coGroup(Iterable<Tuple2<Integer, Centroid>> centroidsIterable, Iterable<Tuple2<Integer, Point>> points, Collector<Tuple4<Integer, Integer, Point, Integer>> out) throws Exception {
            // cache the group's centroids so they only have to be scanned, not re-iterated, per point
            List<Tuple2<Integer, Centroid>> centroids = new ArrayList<>();
            for (Tuple2<Integer, Centroid> centroidTuple : centroidsIterable) {
                centroids.add(centroidTuple);
            }

            for (Tuple2<Integer, Point> pointTuple : points) {
                double minDistance = Double.MAX_VALUE;
                int minIndex = -1;
                Point point = pointTuple.f1;

                // find the nearest centroid for this point
                for (Tuple2<Integer, Centroid> centroidTuple : centroids) {
                    Centroid centroid = centroidTuple.f1;
                    double distance = point.euclideanDistance(centroid);

                    if (distance < minDistance) {
                        minDistance = distance;
                        minIndex = centroid.id;
                    }
                }

                // emit (centroid id, group index, point, count = 1)
                out.collect(Tuple4.of(minIndex, pointTuple.f0, point, 1));
            }
        }
    })
    // sum the assigned points and counts per (centroid id, group index)
    .groupBy(0, 1).reduce(new ReduceFunction<Tuple4<Integer, Integer, Point, Integer>>() {
        @Override
        public Tuple4<Integer, Integer, Point, Integer> reduce(Tuple4<Integer, Integer, Point, Integer> value1, Tuple4<Integer, Integer, Point, Integer> value2) throws Exception {
            return Tuple4.of(value1.f0, value1.f1, value1.f2.add(value2.f2), value1.f3 + value2.f3);
        }
    })
    // average the coordinate sums to obtain the new centroid of each group
    .map(new MapFunction<Tuple4<Integer, Integer, Point, Integer>, Tuple2<Integer, Centroid>>() {
        @Override
        public Tuple2<Integer, Centroid> map(Tuple4<Integer, Integer, Point, Integer> value) throws Exception {
            return Tuple2.of(value.f1, new Centroid(value.f0, value.f2.div(value.f3)));
        }
    });

DataSet<Tuple2<Integer, Centroid>> result = groupLoop.closeWith(newGroupCentroids);
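
result then holds the converged centroids per group index. If you also need the final point-to-cluster assignments (the analogue of clusteredPoints in the question), one more co-group pass over the final centroids does it. A sketch along the same lines as the loop body, emitting (group index, centroid id, point) triples:

// one more pass: tag every point with the id of the nearest final centroid of its group
DataSet<Tuple3<Integer, Integer, Point>> clusteredPoints = result
    .coGroup(groupedPoints).where(0).equalTo(0)
    .with(new CoGroupFunction<Tuple2<Integer, Centroid>, Tuple2<Integer, Point>, Tuple3<Integer, Integer, Point>>() {
        @Override
        public void coGroup(Iterable<Tuple2<Integer, Centroid>> centroidsIterable, Iterable<Tuple2<Integer, Point>> points, Collector<Tuple3<Integer, Integer, Point>> out) throws Exception {
            // cache the group's centroids before scanning the points
            List<Centroid> centroids = new ArrayList<>();
            for (Tuple2<Integer, Centroid> centroidTuple : centroidsIterable) {
                centroids.add(centroidTuple.f1);
            }

            for (Tuple2<Integer, Point> pointTuple : points) {
                double minDistance = Double.MAX_VALUE;
                int minIndex = -1;
                for (Centroid centroid : centroids) {
                    double distance = pointTuple.f1.euclideanDistance(centroid);
                    if (distance < minDistance) {
                        minDistance = distance;
                        minIndex = centroid.id;
                    }
                }
                // emit (group index, centroid id, point)
                out.collect(Tuple3.of(pointTuple.f0, minIndex, pointTuple.f1));
            }
        }
    });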