Suppose that we want to multiply four matrices, A x B x C x D, of dimensions 50 x 20, 20 x 1, 1 x 10, and 10 x 100, respectively. This will involve iteratively multiplying two matrices at a time. Matrix multiplication is not commutative but it is associative. Thus we can compute our product of four matrices in many different ways, depending on how we parenthesize it. Are some of these better than others?
Multiplying an m x n matrix by an n x p matrix takes m*n*p multiplications, to a good enough approximation. Using this formula, let's compare several different ways of evaluating A x B x C x D:
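A x ((B x C) x D): 20*1*10 + 20*10*100 + 50*20*100 = 120,200 multiplications
(A x (B x C)) x D: 20*1*10 + 50*20*10 + 50*10*100 = 60,200 multiplications
(A x B) x (C x D): 50*20*1 + 1*10*100 + 50*1*100 = 7,000 multiplications
As you can see, the order of multiplications makes a big difference in the final running time!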
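The m*n*p estimate is easy to apply mechanically. The following is a minimal sketch, not part of the original solution, that tallies the cost of three parenthesizations of A x B x C x D. The names Shape, Multiply, CostRightHeavy, CostMiddleFirst and CostBalanced are hypothetical helpers introduced only for illustration.

```cpp
#include <cassert>

// Shape of a matrix (hypothetical helper type for this sketch).
struct Shape {
    long long rows;
    long long cols;
};

// Returns the shape of a x b and adds the m*n*p cost estimate to `cost`.
Shape Multiply(Shape a, Shape b, long long& cost)
{
    assert(a.cols == b.rows); // inner dimensions must agree
    cost += a.rows * a.cols * b.cols;
    return Shape{a.rows, b.cols};
}

// Cost of A x ((B x C) x D).
long long CostRightHeavy(Shape a, Shape b, Shape c, Shape d)
{
    long long cost = 0;
    Shape bc = Multiply(b, c, cost);
    Shape bcd = Multiply(bc, d, cost);
    Multiply(a, bcd, cost);
    return cost;
}

// Cost of (A x (B x C)) x D.
long long CostMiddleFirst(Shape a, Shape b, Shape c, Shape d)
{
    long long cost = 0;
    Shape bc = Multiply(b, c, cost);
    Shape abc = Multiply(a, bc, cost);
    Multiply(abc, d, cost);
    return cost;
}

// Cost of (A x B) x (C x D).
long long CostBalanced(Shape a, Shape b, Shape c, Shape d)
{
    long long cost = 0;
    Shape ab = Multiply(a, b, cost);
    Shape cd = Multiply(c, d, cost);
    Multiply(ab, cd, cost);
    return cost;
}
```

Calling the three functions with Shape{50, 20}, Shape{20, 1}, Shape{1, 10} and Shape{10, 100} shows a spread of more than an order of magnitude between the cheapest and the most expensive order.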
Dynamic Programming Solution:
Suppose we want to compute A1 x A2 x ... x An, where the Ai's are matrices with dimensions M0 x M1, M1 x M2, ..., Mn-1 x Mn, respectively. The first thing to notice is that a particular parenthesization can be represented by a binary tree in which the individual matrices correspond to the leaves, the root is the final product, and interior nodes are intermediate products.
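To make the tree view concrete, here is a small sketch of one possible representation. The type Node and the helpers Leaf, Product, ToString and CountProducts are hypothetical names chosen for this illustration, and the code assumes a C++14 compiler for std::make_unique.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// A node of a parenthesization tree (hypothetical type for this sketch).
// A leaf stands for a single matrix Ai; an interior node stands for the
// product of its two children.
struct Node {
    int index = 0;               // matrix number i, meaningful only for leaves
    std::unique_ptr<Node> left;  // null for leaves
    std::unique_ptr<Node> right; // null for leaves
};

// Creates the leaf for matrix A<index>.
std::unique_ptr<Node> Leaf(int index)
{
    auto node = std::make_unique<Node>();
    node->index = index;
    return node;
}

// Creates the interior node for (left x right).
std::unique_ptr<Node> Product(std::unique_ptr<Node> left, std::unique_ptr<Node> right)
{
    assert(left && right); // an interior node always has two children
    auto node = std::make_unique<Node>();
    node->left = std::move(left);
    node->right = std::move(right);
    return node;
}

// Writes the tree back out as a fully parenthesized expression.
std::string ToString(const Node& node)
{
    if (!node.left)
        return "A" + std::to_string(node.index);
    return "(" + ToString(*node.left) + " x " + ToString(*node.right) + ")";
}

// Counts interior nodes, i.e. the number of pairwise products performed.
int CountProducts(const Node& node)
{
    if (!node.left)
        return 0;
    return 1 + CountProducts(*node.left) + CountProducts(*node.right);
}
```

For example, Product(Product(Leaf(1), Leaf(2)), Product(Leaf(3), Leaf(4))) is the tree for (A1 x A2) x (A3 x A4). Like every tree over four leaves, it has three interior nodes.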
These binary trees are suggestive: for a tree to be optimal, its sub-trees must also be optimal. What are the sub-problems corresponding to the sub-trees? They are products of the form Ai x Ai+1 x ... x Aj. Let's see if this works: for 1 <= i <= j <= n, define
C(i, j) = minimum cost of multiplying Ai x Ai+1 x ... x Aj.
The size of this sub-problem is the number of matrix multiplications, |j - i|. The smallest sub-problem is when i = j, in which case there's nothing to multiply, so C(i, i) = 0. For j > i, consider the optimal sub-tree for C(i, j). The first branch in this sub-tree, the one at the top, splits the product into two pieces of the form Ai x ... x Ak and Ak+1 x ... x Aj, for some k with i <= k < j. The cost of the sub-tree is then the cost of these two partial products plus the cost of combining them: C(i, k) + C(k + 1, j) + Mi-1 * Mk * Mj. We just need to find the splitting point k for which this is smallest:
C(i, j) = min over i <= k < j of { C(i, k) + C(k + 1, j) + Mi-1 * Mk * Mj }
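The recurrence for C(i, j) can be transcribed almost literally as a recursive function. The sketch below evaluates it top-down with memoization, which is a different evaluation order from a bottom-up table but computes the same values. The names Cost, MinChainCost and memo are hypothetical, and the vector M is assumed to hold M0, M1, ..., Mn.

```cpp
#include <cassert>
#include <climits>
#include <vector>

// Memoized evaluation of C(i, j); a memo entry of -1 means "not computed yet".
long long Cost(const std::vector<int>& M, int i, int j,
               std::vector<std::vector<long long>>& memo)
{
    if (i == j)
        return 0; // a single matrix: nothing to multiply
    if (memo[i][j] >= 0)
        return memo[i][j];

    long long best = LLONG_MAX;
    for (int k = i; k < j; ++k) // try every split point i <= k < j
    {
        long long cost = Cost(M, i, k, memo) + Cost(M, k + 1, j, memo)
                       + 1LL * M[i - 1] * M[k] * M[j];
        if (cost < best)
            best = cost;
    }
    memo[i][j] = best;
    return best;
}

// Returns C(1, n) for the chain whose dimensions are M0..Mn.
long long MinChainCost(const std::vector<int>& M)
{
    assert(M.size() >= 2); // at least one matrix is required
    int n = static_cast<int>(M.size()) - 1;
    std::vector<std::vector<long long>> memo(n + 1, std::vector<long long>(n + 1, -1));
    return Cost(M, 1, n, memo);
}
```

For the four matrices from the opening example, M is {50, 20, 1, 10, 100} and MinChainCost(M) evaluates to 7000, the cost of (A x B) x (C x D).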
Implementation:
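#include <climits>
#include <vector>

// dims holds the `size` values M0..Mn (so size = n + 1); matrix Ai is dims[i - 1] x dims[i].
// Returns the minimum number of scalar multiplications needed for A1 x ... x An.
int ChainMatrixMultiplication(int dims[], int size)
{
    if (size < 2)
        return 0; // no matrices to multiply

    // C[i][j] is the minimum cost of Ai x ... x Aj; the diagonal C[i][i] stays 0.
    std::vector<std::vector<int>> C(size, std::vector<int>(size, 0));

    for (int s = 1; s < size; ++s) // s = j - i, the sub-problem size
    {
        for (int i = 1; i < size - s; ++i)
        {
            int j = i + s;
            C[i][j] = INT_MAX;
            for (int k = i; k < j; ++k) // try every split point
            {
                int count = C[i][k] + C[k + 1][j] + dims[i - 1] * dims[k] * dims[j];
                if (count < C[i][j])
                    C[i][j] = count;
            }
        }
    }
    return C[1][size - 1];
}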
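The function above returns only the minimum cost. A common extension, sketched here as an assumption rather than as part of the original solution, also records the best split point k for every (i, j) and then rebuilds the parenthesization from those records. The names OptimalParenthesization, BuildExpression and split are hypothetical, and the sketch uses long long for costs to reduce the risk of overflow.

```cpp
#include <cassert>
#include <climits>
#include <string>
#include <vector>

// Rebuilds the expression for Ai x ... x Aj from the recorded split points.
std::string BuildExpression(const std::vector<std::vector<int>>& split, int i, int j)
{
    if (i == j)
        return "A" + std::to_string(i);
    int k = split[i][j];
    return "(" + BuildExpression(split, i, k) + " x "
               + BuildExpression(split, k + 1, j) + ")";
}

// Fills the same bottom-up table, remembering in split[i][j] the k that
// achieved the minimum, and returns the optimal parenthesization as text.
std::string OptimalParenthesization(const std::vector<int>& dims)
{
    assert(dims.size() >= 2); // at least one matrix is required
    int size = static_cast<int>(dims.size());
    std::vector<std::vector<long long>> C(size, std::vector<long long>(size, 0));
    std::vector<std::vector<int>> split(size, std::vector<int>(size, 0));

    for (int s = 1; s < size; ++s)
    {
        for (int i = 1; i < size - s; ++i)
        {
            int j = i + s;
            C[i][j] = LLONG_MAX;
            for (int k = i; k < j; ++k)
            {
                long long count = C[i][k] + C[k + 1][j]
                                + 1LL * dims[i - 1] * dims[k] * dims[j];
                if (count < C[i][j])
                {
                    C[i][j] = count;
                    split[i][j] = k; // remember where the best split was
                }
            }
        }
    }
    return BuildExpression(split, 1, size - 1);
}
```

For dims = {50, 20, 1, 10, 100} this yields "((A1 x A2) x (A3 x A4))", which matches the cheapest order from the opening example.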
Time Complexity: O(n^3), since there are O(n^2) sub-problems C(i, j) and each one takes O(n) time to evaluate.