CS Mines Flood Fill problem - Java solution times out while C++ and Python succeed (performance improvement)


I'm working on this problem:

https://mines20.kattis.com/problems/mines20.paintbucket

My Java solution using a TreeSet times out (the time limit is under 3 seconds):

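import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Comparator;
import java.util.Set;
import java.util.Stack;
import java.util.StringTokenizer;
import java.util.TreeSet;

public class PaintBucketCSMinesSolution {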

    public static void main(String[] args) throws IOException {
        // maintain distinct list of points that have been visited
        Set<int[]> exploredSet = new TreeSet<>(
                Comparator.comparingInt((int[] el) -> el[1]).thenComparingInt(el -> el[0]));
        int W;
        int H;
        int X;
        int Y;
        int[][] picture;
        BufferedReader r = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(r.readLine());
        W = Integer.parseInt(st.nextToken());
        H = Integer.parseInt(st.nextToken());
        X = Integer.parseInt(st.nextToken());
        Y = Integer.parseInt(st.nextToken());
        picture = new int[H][W];
        for (int i = 0; i < H; i++) {
            st = new StringTokenizer(r.readLine());
            for (int j = 0; j < W; j++) {
                picture[i][j] = Integer.parseInt(st.nextToken());
            }
        }

        int color = picture[Y][X];

        Stack<int[]> toExploreStack = new Stack<>();

        toExploreStack.add(new int[] { X, Y });

        while (!toExploreStack.isEmpty()) {
            int[] point = toExploreStack.pop();
            int px = point[0];
            int py = point[1];
            // System.out.println(exploredSet.contains(point));
            boolean execute = true;

            if (exploredSet.contains(point)) {
                execute = false;
            }

            if (execute) {
                exploredSet.add(point);

                if (px > 0 && picture[py][px - 1] == color) {
                    toExploreStack.add(new int[] { px - 1, py });
                }

                if (px < (W - 1) && picture[py][px + 1] == color) {
                    toExploreStack.add(new int[] { px + 1, py });
                }

                if (py > 0 && picture[py - 1][px] == color) {
                    toExploreStack.add(new int[] { px, py - 1 });
                }

                if (py < (H - 1) && picture[py + 1][px] == color) {
                    toExploreStack.add(new int[] { px, py + 1 });
                }
            }

        }
        // The TreeSet iterates in comparator order (y, then x), so the output is already sorted

        for (int[] v : exploredSet) {
            System.out.println(v[0] + " " + v[1]);
        }

    }
}

A Java HashSet solution, which also times out:

package datastructures.algorithms;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;

public class PaintBucketCSMinesSolution {

    public static void main(String[] args) throws IOException {
        // maintain distinct list of points that have been visited
//      Set<int[]> exploredSet = new TreeSet<>(
//              Comparator.comparingInt((int[] el) -> el[1]).thenComparingInt(el -> el[0]));
        Set<List<Integer>> exploredSet = new HashSet<>();
        int W;
        int H;
        int X;
        int Y;
        int[][] picture;
        Scanner sc = new Scanner(System.in);
        W = sc.nextInt();
        H = sc.nextInt();
        X = sc.nextInt();
        Y = sc.nextInt();
        picture = new int[H][W];
        for (int i = 0; i < H; i++) {
            for (int j = 0; j < W; j++) {
                picture[i][j] = sc.nextInt();
            }
        }
        sc.close();
        int color = picture[Y][X];

        ArrayDeque<List<Integer>> toExploreStack = new ArrayDeque<>();

        toExploreStack.add(Arrays.asList(X, Y));

        while (!toExploreStack.isEmpty()) {
            List<Integer> point = toExploreStack.pop();
            int px = point.get(0);
            int py = point.get(1);
            // System.out.println(exploredSet.contains(point));
            boolean execute = true;

            if (exploredSet.contains(point)) {
                execute = false;
            }

            if (execute) {
                exploredSet.add(point);

                if (px > 0 && picture[py][px - 1] == color) {
                    toExploreStack.add(Arrays.asList(px - 1, py));
                }

                if (px < (W - 1) && picture[py][px + 1] == color) {
                    toExploreStack.add(Arrays.asList(px + 1, py));
                }

                if (py > 0 && picture[py - 1][px] == color) {
                    toExploreStack.add(Arrays.asList(px, py - 1));
                }

                if (py < (H - 1) && picture[py + 1][px] == color) {
                    toExploreStack.add(Arrays.asList(px, py + 1));
                }
            }

        }

        // Sorting HashSet using List

        // exploredSet.sort(Comparator.comparingInt((List<Integer> el) ->
        // el.get(1)).thenComparingInt(el -> el.get(0)));
        Comparator<List<Integer>> comp = Comparator.comparingInt((List<Integer> x) -> x.get(1))
                .thenComparingInt(x -> x.get(0));
        exploredSet.stream().sorted(comp).map(v -> v.get(0) + " " + v.get(1)).forEach(System.out::println);

    }
}
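A side note on the HashSet version: toExploreStack is an ArrayDeque filled with add() and drained with pop(). add() appends at the tail while pop() removes from the head, so the deque actually behaves as a FIFO queue (breadth-first order), not a stack. Either order fills the same region, so correctness is unaffected. A small sketch showing the difference (the class name DequeOrderDemo is only for illustration):

```java
import java.util.ArrayDeque;

public class DequeOrderDemo {
    // add() appends at the tail and pop() removes from the head: FIFO (queue) order
    public static String addThenPop() {
        ArrayDeque<Integer> deque = new ArrayDeque<>();
        deque.add(1);
        deque.add(2);
        deque.add(3);
        StringBuilder order = new StringBuilder();
        while (!deque.isEmpty()) {
            order.append(deque.pop());
        }
        return order.toString();
    }

    // push() inserts at the head and pop() removes from the head: LIFO (stack) order
    public static String pushThenPop() {
        ArrayDeque<Integer> deque = new ArrayDeque<>();
        deque.push(1);
        deque.push(2);
        deque.push(3);
        StringBuilder order = new StringBuilder();
        while (!deque.isEmpty()) {
            order.append(deque.pop());
        }
        return order.toString();
    }

    public static void main(String[] args) {
        System.out.println(addThenPop());  // 123
        System.out.println(pushThenPop()); // 321
    }
}
```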

A Java HashSet<String> solution also timed out:

package datastructures.algorithms;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;

public class PaintBucketStringSolution {

    public static void main(String[] args) throws IOException {
        // maintain distinct list of points that have been visited
//      Set<int[]> exploredSet = new TreeSet<>(
//              Comparator.comparingInt((int[] el) -> el[1]).thenComparingInt(el -> el[0]));
        Set<String> exploredSet = new HashSet<>();
        int W;
        int H;
        int X;
        int Y;
        int[][] picture;
        Scanner sc = new Scanner(System.in);
        W = sc.nextInt();
        H = sc.nextInt();
        X = sc.nextInt();
        Y = sc.nextInt();
        picture = new int[H][W];
        for (int i = 0; i < H; i++) {
            for (int j = 0; j < W; j++) {
                picture[i][j] = sc.nextInt();
            }
        }
        sc.close();
        int color = picture[Y][X];

        ArrayDeque<List<Integer>> toExploreStack = new ArrayDeque<>();

        toExploreStack.add(Arrays.asList(X, Y));

        while (!toExploreStack.isEmpty()) {
            List<Integer> point = toExploreStack.pop();

            int px = point.get(0);
            int py = point.get(1);
            // String.valueOf(
            String pointString = String.valueOf(px) + " " + String.valueOf(py);
            // System.out.println(exploredSet.contains(point));
            boolean execute = true;

            if (exploredSet.contains(pointString)) {
                execute = false;
            }

            if (execute) {
                exploredSet.add(pointString);

                if (px > 0 && picture[py][px - 1] == color) {
                    toExploreStack.add(Arrays.asList(px - 1, py));
                }

                if (px < (W - 1) && picture[py][px + 1] == color) {
                    toExploreStack.add(Arrays.asList(px + 1, py));
                }

                if (py > 0 && picture[py - 1][px] == color) {
                    toExploreStack.add(Arrays.asList(px, py - 1));
                }

                if (py < (H - 1) && picture[py + 1][px] == color) {
                    toExploreStack.add(Arrays.asList(px, py + 1));
                }
            }

        }

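        // Sort numerically by y, then x (comparing the strings as text would put row 10 before row 2)
        exploredSet.stream()
                .sorted(Comparator.comparingInt((String x) -> Integer.parseInt(x.split(" ")[1]))
                        .thenComparingInt(x -> Integer.parseInt(x.split(" ")[0])))
                .forEach(System.out::println);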

    }
}
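Why the sort key matters for the "x y" strings: a comparator that swaps them into "y x" strings and compares those as text orders rows lexicographically, so "10 0" sorts before "2 0". A minimal sketch of a numeric comparator for such strings, assuming the points are stored exactly as "x y" with a single space (the class and method names are hypothetical):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class PointStringSort {
    // Parses an "x y" point string and returns coordinate i (0 = x, 1 = y).
    static int coord(String point, int i) {
        return Integer.parseInt(point.split(" ")[i]);
    }

    // Orders "x y" strings by y, then by x, comparing the numbers numerically.
    public static final Comparator<String> BY_Y_THEN_X =
            Comparator.comparingInt((String p) -> coord(p, 1))
                      .thenComparingInt(p -> coord(p, 0));

    public static List<String> sorted(List<String> points) {
        List<String> copy = new ArrayList<>(points);
        copy.sort(BY_Y_THEN_X);
        return copy;
    }

    public static void main(String[] args) {
        List<String> points = Arrays.asList("0 10", "0 2", "5 2");
        // Text comparison of the swapped keys would give "10 0" < "2 0" < "2 5"
        System.out.println(sorted(points)); // [0 2, 5 2, 0 10]
    }
}
```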

Meanwhile, the Python solution passes all test cases in 2.65 seconds:

W, H,X,Y = map(int, input().split())

picture=[]

for y in range(H):
    picture.append([int(c) for c in input().split()])

#print(picture)
# implement a stack
to_explore = [(X , Y)]
explored=set()
color = picture[Y][X]
#print("color" , color)

while len(to_explore) >0:
    px, py = to_explore.pop()

    if (px, py) in explored:
        continue
    explored.add((px, py))

    if (px > 0 and picture[py][px - 1] == color) :
       to_explore.append( (px - 1, py ))

    if (px < (W - 1) and picture[py][px + 1] == color) :
        to_explore.append((px + 1, py ))


    if (py > 0 and picture[py - 1][px] == color) :
        to_explore.append( (px, py - 1) )

    if (py < (H - 1) and picture[py + 1][px] == color) :
        to_explore.append(( px, py + 1) )
        
for x,y in sorted(explored, key=lambda x: (x[1], x[0])):
    print(x, y)    

The C++ solution finishes in 2.05 seconds:

#include <algorithm>
#include <array>
#include <iostream>
#include <set>
#include <stack>
#include <vector>

int main() {
    // comparison to sort the set naturally on inserting elements
    const auto compare = [](const std::array<int, 2> &lhs,
                            const std::array<int, 2> &rhs) {
        return lhs[1] < rhs[1] || (lhs[1] == rhs[1] && lhs[0] < rhs[0]);
    };

    // set to maintain unique elements as it is used for lookup to check if we have visited the node already
    std::set<std::array<int, 2>, decltype(compare)> exploredSet(compare);
    // std::set<std::array<int, 2>> exploredSet;
    int W;
    int H;
    int X;
    int Y;
    std::cin >> W;
    std::cin >> H;
    std::cin >> X;
    std::cin >> Y;
    // H rows of W columns, indexed picture[y][x] (a variable-length array is not standard C++)
    std::vector<std::vector<int>> picture(H, std::vector<int>(W));
    for (int i = 0; i < H; i++) {
        for (int j = 0; j < W; j++) {
            std::cin >> picture[i][j];
        }
    }
    int color = picture[Y][X];

    std::stack<std::array<int, 2>> toExploreStack;
    toExploreStack.push({X, Y});

    while (!toExploreStack.empty()) {
        std::array<int, 2> point = toExploreStack.top();
        toExploreStack.pop();
        int px = point[0];
        int py = point[1];
        bool execute = true;

        // plain for loop was slow, this method is fast
        if (exploredSet.count(point) == 1) {
            execute = false;
        }

        if (execute) {
            exploredSet.insert(point);

            if (px > 0 && picture[py][px - 1] == color) {
                toExploreStack.push({px - 1, py});
            }

            if (px < (W - 1) && picture[py][px + 1] == color) {
                toExploreStack.push({px + 1, py});
            }

            if (py > 0 && picture[py - 1][px] == color) {
                toExploreStack.push({px, py - 1});
            }

            if (py < (H - 1) && picture[py + 1][px] == color) {
                toExploreStack.push({px, py + 1});
            }
        }
    }

    for (std::array<int, 2> v : exploredSet) {
        std::cout << v[0] << " " << v[1] << std::endl;
    }
}

Please help me optimize the Java solution. I have run out of ideas.

Already tried:

  • a HashMap, sorted at the end
  • a HashSet, sorted at the end

It is also surprising that Python finishes faster than the Java implementation in this case.

Answer by IWilms

TreeSet seems to have too much overhead in combination with the custom comparator: every lookup and insert is an O(log n) tree search that calls the comparator lambdas. The HashSet version should work and uses the same logic as the other versions. You just need keys that are objects with a value-based hashCode and equals. You could try changing the relevant lines to:

Set<List<Integer>> exploredSet = new HashSet<>();
...
if (exploredSet.contains(Arrays.asList(px, py))) {
...
exploredSet.add(Arrays.asList(px, py));
...
Comparator<List<Integer>> comp = Comparator.comparingInt((List<Integer> x) -> x.get(1)).thenComparingInt(x -> x.get(0));
exploredSet.stream()
        .sorted(comp)
        .map(v -> v.get(0) + " " + v.get(1))
        .forEach(System.out::println);
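To illustrate why the key type matters: int[] inherits identity-based equals and hashCode from Object, so two arrays holding the same coordinates are different elements of a HashSet, while List<Integer> compares by value. A small sketch (the class name HashCodeDemo is only for illustration):

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class HashCodeDemo {
    // Adds the same coordinate twice as int[]; returns how many elements the set keeps.
    public static int arraySetSize() {
        Set<int[]> set = new HashSet<>();
        set.add(new int[] {3, 4});
        set.add(new int[] {3, 4}); // different array object: identity equals, so not a duplicate
        return set.size();
    }

    // Adds the same coordinate twice as List<Integer>; returns how many elements the set keeps.
    public static int listSetSize() {
        Set<List<Integer>> set = new HashSet<>();
        set.add(Arrays.asList(3, 4));
        set.add(Arrays.asList(3, 4)); // equal by value, so this is a duplicate
        return set.size();
    }

    public static void main(String[] args) {
        System.out.println("int[] set size: " + arraySetSize());        // 2
        System.out.println("List<Integer> set size: " + listSetSize()); // 1
    }
}
```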

EDIT

In general, if the Python solution is within the time limit, an equivalent Java version with the same logic should be too. So why wasn't it? "Proper hashCode": that was exactly the problem with my suggestion... sorry, I was blind. Java's List.hashCode is very simple: for a list of coordinates [X, Y] it is 31*(31 + X) + Y, i.e. 961 + 31*X + Y. That causes a lot of collisions: for 0 <= X, Y <= 999 there are only 31969 distinct hash codes for one million points, which slows down the whole process. Using Strings should be faster and hopefully within the time limit:

Set<String> exploredSet = new HashSet<>();
...
if (exploredSet.contains(px + " " + py)) {
...
exploredSet.add(px + " " + py);
...
exploredSet.stream()
        .sorted(Comparator.comparingInt((String x) -> Integer.parseInt(x.split(" ")[1]))
                .thenComparingInt(x -> Integer.parseInt(x.split(" ")[0])))
        .forEach(System.out::println);
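The collision count can be checked directly. A quick sketch, assuming coordinates range over 0..999 as in the estimate above; it compares List.hashCode with a single packed int key y*N + x, whose Integer hash code is the value itself (the class name and the packing scheme are illustrative, not taken from the question):

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class HashCollisionDemo {
    static final int N = 1000; // coordinates 0..999 in both directions

    // Number of distinct List.hashCode() values over all (x, y) pairs.
    public static int distinctListHashes() {
        Set<Integer> hashes = new HashSet<>();
        for (int x = 0; x < N; x++) {
            for (int y = 0; y < N; y++) {
                // List.hashCode for [x, y]: 31 * (31 * 1 + x) + y = 961 + 31*x + y
                hashes.add(Arrays.asList(x, y).hashCode());
            }
        }
        return hashes.size();
    }

    // Number of distinct hash codes when each point is packed into one int key.
    public static int distinctPackedHashes() {
        Set<Integer> hashes = new HashSet<>();
        for (int x = 0; x < N; x++) {
            for (int y = 0; y < N; y++) {
                hashes.add(Integer.hashCode(y * N + x)); // Integer.hashCode(v) == v
            }
        }
        return hashes.size();
    }

    public static void main(String[] args) {
        System.out.println(distinctListHashes());   // 31969
        System.out.println(distinctPackedHashes()); // 1000000
    }
}
```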

EDIT2

This is really weird. Perhaps the Python judge runs on a faster machine. OK, one very last try: use a boolean array explored instead of the Set exploredSet. That should be much faster, although it doesn't explain why the HashSet version is so slow.

boolean[] explored = new boolean[H*W]; // <- put this line after the 2 for-loops
...
if (explored[py*W+px]) { // row-major index: W pixels per row
...
explored[py*W+px] = true;
...
int index = 0;
for(int y=0; y<H; y++){
    for(int x=0; x<W; x++){
        if(explored[index]){
            System.out.println(x + " " + y);
        }
        index++;
    }
}
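For completeness, a self-contained sketch that combines the boolean-array idea with an int-array stack of packed pixel indices. Two further changes are my own assumptions, not part of the answer: that input parsing with Scanner and one System.out.println per pixel are also bottlenecks, so it reads with StreamTokenizer and writes the whole result at once. The input format ("W H X Y" followed by H rows of W integers) is inferred from the question's code, and the class name PaintBucketFast is hypothetical:

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StreamTokenizer;

public class PaintBucketFast {
    static int next(StreamTokenizer tok) throws IOException {
        tok.nextToken();
        return (int) tok.nval; // StreamTokenizer parses numbers as doubles
    }

    // Returns the filled pixels as "x y" lines, sorted by y, then x.
    public static String solve(BufferedReader in) throws IOException {
        StreamTokenizer tok = new StreamTokenizer(in);
        int w = next(tok), h = next(tok), x = next(tok), y = next(tok);
        int[] picture = new int[w * h]; // row-major: index = row * w + col
        for (int i = 0; i < w * h; i++) {
            picture[i] = next(tok);
        }
        int start = y * w + x;
        int color = picture[start];
        boolean[] explored = new boolean[w * h];
        int[] stack = new int[w * h]; // each pixel is pushed at most once
        int top = 0;
        explored[start] = true;
        stack[top++] = start;
        while (top > 0) {
            int idx = stack[--top];
            int px = idx % w, py = idx / w;
            // Mark neighbours when pushing them, so no pixel enters the stack twice
            if (px > 0 && !explored[idx - 1] && picture[idx - 1] == color) {
                explored[idx - 1] = true;
                stack[top++] = idx - 1;
            }
            if (px < w - 1 && !explored[idx + 1] && picture[idx + 1] == color) {
                explored[idx + 1] = true;
                stack[top++] = idx + 1;
            }
            if (py > 0 && !explored[idx - w] && picture[idx - w] == color) {
                explored[idx - w] = true;
                stack[top++] = idx - w;
            }
            if (py < h - 1 && !explored[idx + w] && picture[idx + w] == color) {
                explored[idx + w] = true;
                stack[top++] = idx + w;
            }
        }
        // Walking indices in row-major order yields the (y, x) sort for free
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < w * h; i++) {
            if (explored[i]) {
                sb.append(i % w).append(' ').append(i / w).append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out = new PrintWriter(System.out);
        out.print(solve(in));
        out.flush();
    }
}
```

Marking pixels when they are pushed, rather than when they are popped, bounds the stack at W*H entries, so a plain int[] is enough and no per-pixel objects are allocated.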