Smallest sum of triplet products where the middle element is removed using Dynamic Programming


I am given a sequence of N numbers (4 ≤ N ≤ 150). An index i is picked whose element has both a left and a right neighbour (i.e. it is neither the first nor the last element), and the i-th number is multiplied by those two neighbours, the numbers at i-1 and i+1. Then the i-th number is removed. This is repeated until only two numbers are left. The goal is to find the smallest possible sum of these products, which depends on the order in which the indices are picked.

For example, for the sequence 44, 45, 5, 39, 15, 22, 10 the smallest sum is 17775. It is obtained by picking the indices in the order 1 -> 3 -> 4 -> 5 -> 2, where each index refers to the element's position in the original sequence: 44*45*5 + 5*39*15 + 5*15*22 + 5*22*10 + 44*5*10 = 9900 + 2925 + 1650 + 1100 + 2200 = 17775
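To double-check the arithmetic of the example, here is a minimal sketch that replays a removal order and adds up the products. It assumes, as the example does, that each index refers to a position in the original sequence (index 3 is still 39 after 45 has been removed). The class and method names are hypothetical, and long is used for the sum to reduce the risk of overflow.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RemovalOrderCost {
  // Hypothetical helper: replays a removal order and returns the sum of the triplet products.
  // Assumption (matching the example): each entry of `order` is a position in the ORIGINAL list.
  public static long costOfOrder(List<Integer> values, int[] order) {
    // Original positions of the elements that are still present, left to right.
    List<Integer> alive = new ArrayList<>();
    for (int p = 0; p < values.size(); p++) {
      alive.add(p);
    }
    long sum = 0;
    for (int original : order) {
      int pos = alive.indexOf(original); // current position of that element, or -1
      if (pos <= 0 || pos >= alive.size() - 1) {
        throw new IllegalArgumentException("index " + original + " is not an interior element");
      }
      long left = values.get(alive.get(pos - 1));
      long middle = values.get(original);
      long right = values.get(alive.get(pos + 1));
      sum += left * middle * right;
      alive.remove(pos); // int argument: removes by position, not by value
    }
    return sum;
  }

  public static void main(String[] args) {
    List<Integer> values = Arrays.asList(44, 45, 5, 39, 15, 22, 10);
    System.out.println(costOfOrder(values, new int[] {1, 3, 4, 5, 2})); // prints 17775
  }
}
```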

I have found a solution using a recursive function:

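// Brute force: try every interior index as the next one to remove,
// then solve the shorter sequence recursively.
public static int smallestSum(List<Integer> values) {
    // Base case: exactly one interior element is left to remove.
    if (values.size() == 3)
        return values.get(0) * values.get(1) * values.get(2);
    else {
        int ret = Integer.MAX_VALUE;

        for (int i = 1; i < values.size() - 1; i++) {
            // Work on a copy so the caller's list is not modified.
            List<Integer> copy = new ArrayList<Integer>(values);
            copy.remove(i); // int argument: removes by position

            int val = smallestSum(copy) + values.get(i - 1) * values.get(i) * values.get(i + 1);
            if (val < ret) ret = val;
        }

        return ret;
    }
}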

However, this solution is only feasible for small N: it tries every possible removal order, so the number of recursive calls grows factorially with N. What I am looking for is an iterative dynamic programming approach.


There are 2 best solutions below


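The optimal substructure needed for a DP is that, given the identity of the last element removed, the elimination strategy for the elements to the left is independent of the elimination strategy for the elements to the right. If i is the last element removed strictly between positions first and last, then at that moment its only neighbours are values[first] and values[last], so its removal costs values[first] * values[i] * values[last]; and since i stays in place until then, no removal on its left ever involves an element on its right, or vice versa. Here's a new recursive function (smallestSumA, together with the version from the question and a test harness comparing the two) incorporating this observation: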

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Foo {
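  public static void main(String[] args) {
    // Randomized check: compare both versions on 10,000 random lists
    // of 3 to 10 values in the range 0..99.
    Random r = new Random();
    for (int i = 0; i < 10000; i++) {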
      List<Integer> values = new ArrayList<Integer>();
      for (int n = 3 + r.nextInt(8); n > 0; n--) {
        values.add(r.nextInt(100));
      }
      int a = smallestSumA(values, 0, values.size() - 1);
      int q = smallestSumQ(values);
      if (q != a) {
        System.err.println("oops");
        System.err.println(q);
        System.err.println(a);
        System.err.println(values);
      }
    }
  }

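  // Minimum cost of removing every element strictly between positions first and last
  // (values[first] and values[last] themselves are kept).
  public static int smallestSumA(List<Integer> values, int first, int last) {
    if (first + 2 > last)
      return 0; // no interior element left
    int ret = Integer.MAX_VALUE;
    // Try each i as the LAST interior element removed: its neighbours are then
    // values[first] and values[last], and the two sides are solved independently.
    for (int i = first + 1; i <= last - 1; i++) {
      int val = (smallestSumA(values, first, i)
          + values.get(first) * values.get(i) * values.get(last) + smallestSumA(values, i, last));
      if (val < ret)
        ret = val;
    }
    return ret;
  }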

  public static int smallestSumQ(List<Integer> values) {
    if (values.size() == 3)
      return values.get(0) * values.get(1) * values.get(2);
    else {
      int ret = Integer.MAX_VALUE;

      for (int i = 1; i < values.size() - 1; i++) {
        List<Integer> copy = new ArrayList<Integer>(values);
        copy.remove(i);

        int val = smallestSumQ(copy) + values.get(i - 1) * values.get(i) * values.get(i + 1);
        if (val < ret)
          ret = val;
      }

      return ret;
    }
  }
}

Invoke as smallestSumA(values, 0, values.size() - 1).
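Written as a recurrence, with S(f, l) standing for smallestSumA(values, f, l) and v_k for values.get(k) (notation introduced here, not in the code), the function computes:

```latex
S(f, l) =
\begin{cases}
  0 & \text{if } l - f < 2,\\[4pt]
  \min\limits_{f < i < l} \bigl( S(f, i) + v_f \, v_i \, v_l + S(i, l) \bigr) & \text{otherwise,}
\end{cases}
\qquad \text{answer} = S(0, N-1).
```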

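To get the DP, observe that there are only N choose 2 different settings for first and last, and memoize. Each (first, last) pair does O(N) work in the loop over i, so the running time is O(N^3), with O(N^2) memory for the memo table.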
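A minimal top-down sketch of that memoization, assuming the same recursion as smallestSumA. The class and method names are hypothetical, and costs are accumulated in long (an assumption, to postpone overflow for larger values). A Long table with null as "not computed yet" is used instead of a sentinel number, so it also works if products can be negative.

```java
import java.util.Arrays;
import java.util.List;

public class SmallestSumMemo {
  // Minimum total cost for the whole list: clear everything between the two ends.
  public static long smallestSum(List<Integer> values) {
    int n = values.size();
    // memo[first][last] caches the result for that pair; null means "not computed yet".
    Long[][] memo = new Long[n][n];
    return solve(values, 0, n - 1, memo);
  }

  // Same recursion as smallestSumA, but each (first, last) pair is solved only once.
  private static long solve(List<Integer> values, int first, int last, Long[][] memo) {
    if (first + 2 > last) {
      return 0; // no element strictly between first and last
    }
    if (memo[first][last] != null) {
      return memo[first][last];
    }
    long best = Long.MAX_VALUE;
    // Try every i as the LAST element removed between first and last.
    for (int i = first + 1; i < last; i++) {
      long val = solve(values, first, i, memo)
          + (long) values.get(first) * values.get(i) * values.get(last)
          + solve(values, i, last, memo);
      best = Math.min(best, val);
    }
    memo[first][last] = best;
    return best;
  }

  public static void main(String[] args) {
    System.out.println(smallestSum(Arrays.asList(44, 45, 5, 39, 15, 22, 10))); // prints 17775
  }
}
```

The recursion depth is at most about N, so for N ≤ 150 the call stack is not a concern.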


If anyone is interested, here is an iterative DP solution based on David Eisenstat's recursive one (for large numbers, replace the ints with longs to avoid overflow):

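public static int smallestSum(List<Integer> values) {
    // table[j][m] = minimum cost of removing every element strictly between j and m.
    // Intervals with no interior element keep the default value 0.
    int[][] table = new int[values.size()][values.size()];

    // i is the interval length (m - j); shorter intervals are filled first.
    for (int i = 2; i < values.size(); i++) {
        for (int j = 0; j + i < values.size(); j++) {
            int ret = Integer.MAX_VALUE;

            // k is the last interior element removed from the interval [j, j + i].
            for (int k = j + 1; k <= j + i - 1; k++) {
                int val = table[j][k] + values.get(j) * values.get(k) * values.get(j + i) + table[k][j + i];
                if (val < ret) ret = val;
            }

            table[j][j + i] = ret;
        }
    }

    return table[0][values.size() - 1];
}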
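The table above yields only the minimum cost. If the removal order itself is wanted (like 1 -> 3 -> 4 -> 5 -> 2 in the question), one possible sketch, assuming an extra int table is acceptable, is to remember which k achieved the minimum for each interval and replay those choices. It also uses long throughout, as suggested above for big numbers; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SmallestSumWithOrder {
  private final long[][] cost; // cost[f][l]: min cost of removing everything strictly between f and l
  private final int[][] split; // split[f][l]: the element removed LAST in that interval

  public SmallestSumWithOrder(List<Integer> values) {
    int n = values.size();
    cost = new long[n][n];
    split = new int[n][n];
    // Same bottom-up order as the iterative answer: shorter intervals first.
    for (int len = 2; len < n; len++) {
      for (int f = 0; f + len < n; f++) {
        int l = f + len;
        cost[f][l] = Long.MAX_VALUE;
        for (int k = f + 1; k < l; k++) {
          long val = cost[f][k] + (long) values.get(f) * values.get(k) * values.get(l) + cost[k][l];
          if (val < cost[f][l]) { // strict <, so ties keep the smallest k
            cost[f][l] = val;
            split[f][l] = k;
          }
        }
      }
    }
  }

  public long minCost() {
    int n = cost.length;
    return n < 3 ? 0 : cost[0][n - 1];
  }

  // One optimal removal order, as positions in the ORIGINAL list (the question's convention).
  public List<Integer> removalOrder() {
    List<Integer> out = new ArrayList<>();
    appendOrder(0, cost.length - 1, out);
    return out;
  }

  // Clear the left part, then the right part, then remove the split element last.
  private void appendOrder(int f, int l, List<Integer> out) {
    if (l - f < 2) {
      return;
    }
    int k = split[f][l];
    appendOrder(f, k, out);
    appendOrder(k, l, out);
    out.add(k);
  }

  public static void main(String[] args) {
    SmallestSumWithOrder dp = new SmallestSumWithOrder(Arrays.asList(44, 45, 5, 39, 15, 22, 10));
    System.out.println(dp.minCost());      // prints 17775
    System.out.println(dp.removalOrder()); // prints [1, 3, 4, 5, 2]
  }
}
```

Clearing the left part before the right part is an arbitrary choice: the two sides are independent, so any interleaving of them gives the same total.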