Recursive Matrix Multiplication

2/15/2008 4:41 PM
This example compares several matrix multiplication implementations, all of which implement the same interface:

public interface MatrixOperations {
 /** Add the product of two matrices to a third: c += a*b.
  * @param a left operand
  * @param b right operand
  * @param c matrix that accumulates the product
  */
 void multiplyAccumulate(MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c);
}

We implement multiplyAccumulate rather than a simple multiply because the recursive implementations add several partial products into the same block of the result. The matrix class is MatrixDouble2D.
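
To see why the accumulate form suits recursion: when the inner dimension is split in two, a*b is the sum of two partial products, and each can be added straight into c without a temporary matrix. The following is a minimal sketch of that idea only. It uses plain double[][] arrays in place of MatrixDouble2D, and the class and method names are invented for illustration.

```java
// Illustrative sketch only: plain arrays stand in for MatrixDouble2D.
class AccumulateDemo {
    /** c += (columns k0..k1-1 of a) * (rows k0..k1-1 of b). */
    static void multiplyAccumulate(double[][] a, double[][] b, double[][] c, int k0, int k1) {
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                double s = c[i][j];              // start from what is already in c
                for (int k = k0; k < k1; k++) {
                    s += a[i][k] * b[k][j];
                }
                c[i][j] = s;
            }
        }
    }

    /** Computes a*b by accumulating the two halves of the inner dimension into a zeroed c. */
    static double[][] multiplyInTwoHalves(double[][] a, double[][] b) {
        int m = a[0].length;
        double[][] c = new double[a.length][b[0].length]; // Java zero-initialises the array
        multiplyAccumulate(a, b, c, 0, m / 2);   // c += a_left * b_top
        multiplyAccumulate(a, b, c, m / 2, m);   // c += a_right * b_bottom
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        // prints [[19.0, 22.0], [43.0, 50.0]]
        System.out.println(java.util.Arrays.deepToString(multiplyInTwoHalves(a, b)));
    }
}
```

A plain multiply would need a scratch matrix for the second partial product and an extra addition pass.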

Basic Multiply

This is the trivial implementation:

 public void multiplyAccumulate(MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c) {
    for (int i=0; i<a.rowSize(); i++) {
       for (int j=0; j<b.columnSize(); j++) {
          double s=c.get(i,j);
          for (int k=0; k<a.columnSize(); k++)
              s += a.get(i,k)*b.get(k,j);
          c.set(i,j, s);
       }
    }
 }

Locally Optimised Multiply

Here we copy each column of matrix b into a contiguous vector (which helps keep it in the cache) and unroll the loop over the rows of a by two. The unrolling allows the CPU to keep its pipelines busier. Originally I unrolled the innermost loop as well, but this provided no improvement. Using the VectorDouble class instead of a raw double[] makes no difference to performance (at least with Java 6u4 server).

public class MatrixOps2b implements MatrixOperations {
    public static final MatrixOps2b INSTANCE = new MatrixOps2b();

    public void multiplyAccumulate(MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c) {
        VectorDouble column = new VectorDouble(b.rowSize());
        int L = a.rowSize();
        int M = a.columnSize();
        int N = b.columnSize();
        for (int j = 0; j < N; j++) {
            b.copyColumn(j, column);
            int i;
            for (i = 0; i < L - 1; i += 2) {
                VectorDouble r0 = a.getRow(i);
                VectorDouble r1 = a.getRow(i + 1);
                double s0 = c.get(i, j);
                double s1 = c.get(i + 1, j);
                for (int k = 0; k < M; k++) {
                    s0 += r0.get(k) * column.get(k);
                    s1 += r1.get(k) * column.get(k);
                }
                c.set(i, j, s0);
                c.set(i + 1, j, s1);
            }
            for (;i < L; i++) {
                VectorDouble row = a.getRow(i);
                double s = c.get(i, j);
                for (int k = 0; k < M; k++) {
                    s += row.get(k) * column.get(k);
                }
                c.set(i, j, s);
            }
        }
    }
}

Recursive Sequential Multiply

This implementation delegates multiplication of small matrices to another instance of MatrixOperations. Each stage slices each matrix approximately in half along each axis, unless that axis is already no longer than innerSize. The slice point is rounded up to a multiple of 8 elements to avoid splitting cache lines (at least that is the theory; I haven't tested its effectiveness).

public class MatrixOps5 implements MatrixOperations {
    private final MatrixOperations inner;
    private final int innerSize;

    public static final MatrixOps5 INSTANCE = new MatrixOps5(MatrixOps4.INSTANCE, 400);

    public MatrixOps5(MatrixOperations inner, int innerSize) {
        this.inner = inner;
        this.innerSize = innerSize;
    }

    private int slice(int n) {
        return ((n >> 1) + 7) & ~7;
    }

    public void multiplyAccumulate(MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c) {
        if (a.rowSize() <= innerSize && a.columnSize() <= innerSize && b.columnSize() <= innerSize) {
            inner.multiplyAccumulate(a, b, c);
            return;
        }
        // try to force pieces towards square?
        MatrixDouble2D[] aa;
        if (a.rowSize() <= innerSize)
            aa = new MatrixDouble2D[]{a};
        else {
            int n = slice(a.rowSize());
            aa = new MatrixDouble2D[]{a.rowSlice(0, n), a.rowSlice(n, a.rowSize() - n)};
        }
        MatrixDouble2D[] bb;
        if (b.columnSize() <= innerSize)
            bb = new MatrixDouble2D[]{b};
        else {
            int n = slice(b.columnSize());
            bb = new MatrixDouble2D[]{b.columnSlice(0, n), b.columnSlice(n, b.columnSize() - n)};
        }
        for (MatrixDouble2D aRows : aa) {
            int firstRow = aRows.firstRowIndex() - a.firstRowIndex();
            int rowSize = aRows.rowSize();
            for (MatrixDouble2D bCols : bb) {
                MatrixDouble2D cc = c.submatrix(firstRow, rowSize, bCols.firstColumnIndex() - b.firstColumnIndex(), bCols.columnSize());
                if (aRows.columnSize() <= innerSize) {
                    multiplyAccumulate(aRows, bCols, cc);
                }
                else {
                    int n = slice(aRows.columnSize());
                    multiplyAccumulate(aRows.columnSlice(0, n), bCols.rowSlice(0, n), cc);
                    int r = aRows.columnSize() - n;
                    multiplyAccumulate(aRows.columnSlice(n, r), bCols.rowSlice(n, r), cc);
                }
            }
        }

    }
}
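
As a quick check of the rounding, the sketch below evaluates the same slice expression for a few sizes. The class name and the sample sizes are chosen for illustration. Eight doubles occupy 64 bytes, which is a common cache line size, though whether a slice actually starts on a line boundary also depends on how the underlying array is aligned.

```java
class SliceDemo {
    // Same expression as slice() in the listing: half of n, rounded up to a multiple of 8.
    static int slice(int n) {
        return ((n >> 1) + 7) & ~7;
    }

    public static void main(String[] args) {
        // Prints each size with the lengths of its first and second slice:
        // 401 -> 200 + 201
        // 410 -> 208 + 202
        // 1200 -> 600 + 600
        // 1210 -> 608 + 602
        for (int n : new int[]{401, 410, 1200, 1210}) {
            int first = slice(n);
            System.out.println(n + " -> " + first + " + " + (n - first));
        }
    }
}
```

For n of 8 or less the second slice would be empty, but the code above only slices a dimension that is larger than innerSize.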

Recursive Concurrent Multiply

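This algorithm is identical to the recursive sequential multiply except that the subproblems are executed via the fork/join framework (jsr166y.forkjoin). Subtasks that write to different blocks of c run in parallel. When the inner dimension is split, both halves accumulate into the same block of c, so they are wrapped in a Seq task that runs them one after the other.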

public class ConcurrentMatrixOps implements MatrixOperations {
    private static ConcurrentMatrixOps instance;
    private final ForkJoinPool pool;
    private final int innerSize;
    private final MatrixOperations inner;

    public ConcurrentMatrixOps(MatrixOperations inner, int innerSize) {
        pool = new ForkJoinPool();
        this.inner = inner;
        this.innerSize = innerSize;
    }

    int innerSize() {
        return innerSize;
    }

    MatrixOperations inner() {
        return inner;
    }

    public void multiplyAccumulate(MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c) {
        if (a.rowSize() <= innerSize && a.columnSize() <= innerSize && b.columnSize() <= innerSize)
            inner.multiplyAccumulate(a, b, c);
        else
            pool.invoke(new ConcurrentMultiplyAccumulate(this, a, b, c));
    }
}

class ConcurrentMultiplyAccumulate extends RecursiveAction {
    private ConcurrentMatrixOps ops;
    private MatrixDouble2D a, b, c;

    ConcurrentMultiplyAccumulate(ConcurrentMatrixOps ops, MatrixDouble2D a, MatrixDouble2D b, MatrixDouble2D c) {
        this.ops = ops;
        this.a = a;
        this.b = b;
        this.c = c;
    }

    private int innerSize() {
        return ops.innerSize();
    }

    private int slice(int n) {
        return ((n >> 1) + 7) & ~7;
    }

    protected void compute() {
        if (a.rowSize() <= innerSize() && a.columnSize() <= innerSize() && b.columnSize() <= innerSize()) {
            ops.inner().multiplyAccumulate(a, b, c);
            return;
        }
        MatrixDouble2D[] aa;
        if (a.rowSize() <= innerSize())
            aa = new MatrixDouble2D[]{a};
        else {
            int n = slice(a.rowSize());
            aa = new MatrixDouble2D[]{a.rowSlice(0, n), a.rowSlice(n, a.rowSize() - n)};
        }
        MatrixDouble2D[] bb;
        if (b.columnSize() <= innerSize())
            bb = new MatrixDouble2D[]{b};
        else {
            int n = slice(b.columnSize());
            bb = new MatrixDouble2D[]{b.columnSlice(0, n), b.columnSlice(n, b.columnSize() - n)};
        }
        RecursiveAction[] subtasks = new RecursiveAction[aa.length * bb.length];
        int index = 0;
        for (MatrixDouble2D aRows : aa) {
            int firstRow = aRows.firstRowIndex() - a.firstRowIndex();
            int rowSize = aRows.rowSize();
            for (MatrixDouble2D bCols : bb) {
                MatrixDouble2D cc = c.submatrix(firstRow, rowSize, bCols.firstColumnIndex() - b.firstColumnIndex(), bCols.columnSize());
                if (aRows.columnSize() <= innerSize()) {
                    subtasks[index++] = new ConcurrentMultiplyAccumulate(ops, aRows, bCols, cc);
                }
                else {
                    int n = slice(aRows.columnSize());
                    int r = aRows.columnSize() - n;
                    subtasks[index++] = new Seq(new ConcurrentMultiplyAccumulate(ops, aRows.columnSlice(0, n), bCols.rowSlice(0, n), cc),
                            new ConcurrentMultiplyAccumulate(ops, aRows.columnSlice(n, r), bCols.rowSlice(n, r), cc));
                }
            }
        }
        forkJoin(subtasks);
    }

    private static class Seq extends RecursiveAction {
        private RecursiveAction a, b;

        Seq(RecursiveAction a, RecursiveAction b) {
            this.a = a;
            this.b = b;
        }

        protected void compute() {
            a.forkJoin();
            b.forkJoin();
        }
    }
}
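
The listing above uses the preview jsr166y API. For readers on a current JDK, here is a rough sketch of the same idea against java.util.concurrent (Java 8 or later, because of ForkJoinPool.commonPool()), where invokeAll and invoke play the roles of the forkJoin calls above. It is not a port of the code above: it works on plain double[][] arrays with index ranges instead of MatrixDouble2D slices, splits only the largest dimension at each step, and the class name and the threshold of 64 are assumptions made for the demo.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Sketch only: half-open index ranges over plain arrays stand in for MatrixDouble2D slices.
class ForkJoinMultiply extends RecursiveAction {
    static final int INNER_SIZE = 64; // hypothetical threshold, kept small so the demo recurses

    private final double[][] a, b, c;
    private final int i0, i1, j0, j1, k0, k1; // rows of a, columns of b, inner dimension

    ForkJoinMultiply(double[][] a, double[][] b, double[][] c,
                     int i0, int i1, int j0, int j1, int k0, int k1) {
        this.a = a; this.b = b; this.c = c;
        this.i0 = i0; this.i1 = i1; this.j0 = j0; this.j1 = j1; this.k0 = k0; this.k1 = k1;
    }

    private ForkJoinMultiply sub(int i0, int i1, int j0, int j1, int k0, int k1) {
        return new ForkJoinMultiply(a, b, c, i0, i1, j0, j1, k0, k1);
    }

    @Override
    protected void compute() {
        int rows = i1 - i0, cols = j1 - j0, inner = k1 - k0;
        if (rows <= INNER_SIZE && cols <= INNER_SIZE && inner <= INNER_SIZE) {
            multiplyDirect();
        } else if (rows >= cols && rows >= inner) {
            // The two halves write disjoint row blocks of c, so they may run in parallel.
            int mid = i0 + rows / 2;
            invokeAll(sub(i0, mid, j0, j1, k0, k1), sub(mid, i1, j0, j1, k0, k1));
        } else if (cols >= inner) {
            // The two halves write disjoint column blocks of c, so they may run in parallel.
            int mid = j0 + cols / 2;
            invokeAll(sub(i0, i1, j0, mid, k0, k1), sub(i0, i1, mid, j1, k0, k1));
        } else {
            // Both halves add into the same block of c, so run them one after the other.
            int mid = k0 + inner / 2;
            sub(i0, i1, j0, j1, k0, mid).invoke();
            sub(i0, i1, j0, j1, mid, k1).invoke();
        }
    }

    // Basic triple loop on the current block: c += a*b restricted to the index ranges.
    private void multiplyDirect() {
        for (int i = i0; i < i1; i++) {
            for (int j = j0; j < j1; j++) {
                double s = c[i][j];
                for (int k = k0; k < k1; k++) {
                    s += a[i][k] * b[k][j];
                }
                c[i][j] = s;
            }
        }
    }

    /** c += a*b, run on the common fork/join pool. */
    static void multiplyAccumulate(double[][] a, double[][] b, double[][] c) {
        ForkJoinPool.commonPool().invoke(
                new ForkJoinMultiply(a, b, c, 0, a.length, 0, b[0].length, 0, b.length));
    }

    public static void main(String[] args) {
        int l = 150, m = 130, n = 170;
        double[][] a = new double[l][m], b = new double[m][n], c = new double[l][n];
        for (int i = 0; i < l; i++) for (int k = 0; k < m; k++) a[i][k] = (i + k) % 7;
        for (int k = 0; k < m; k++) for (int j = 0; j < n; j++) b[k][j] = (k * j) % 5;
        multiplyAccumulate(a, b, c);
        System.out.println("c[0][0] = " + c[0][0]);
    }
}
```

The sequential branch does the same job as the Seq class above: it prevents two tasks from updating the same elements of c at once.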

Alternative subdivision strategy

This time the test for a small matrix is based on the total number of elements in a and b. A dimension that is no more than half the largest dimension is not split, which forces the submatrices towards square. Only the compute method of the new task class, ConcurrentMultiplyAccumulate2, is shown.

    protected void compute() {
        if (a.columnSize() * (a.rowSize() + b.columnSize()) <= innerSize()) {
            ops.inner().multiplyAccumulate(a, b, c);
            return;
        }
        int minimumSplit = Math.max(Math.max(a.rowSize(), a.columnSize()), b.columnSize()) / 2;
        MatrixDouble2D[] aa;
        if (a.rowSize() <= minimumSplit)
            aa = new MatrixDouble2D[]{a};
        else {
            int n = slice(a.rowSize());
            aa = new MatrixDouble2D[]{a.rowSlice(0, n), a.rowSlice(n, a.rowSize() - n)};
        }
        MatrixDouble2D[] bb;
        if (b.columnSize() <= minimumSplit)
            bb = new MatrixDouble2D[]{b};
        else {
            int n = slice(b.columnSize());
            bb = new MatrixDouble2D[]{b.columnSlice(0, n), b.columnSlice(n, b.columnSize() - n)};
        }
        RecursiveAction[] subtasks = new RecursiveAction[aa.length * bb.length];
        int index = 0;
        int aOffset = 0;
        for (MatrixDouble2D aRows : aa) {
            int rowSize = aRows.rowSize();
            int bOffset = 0;
            for (MatrixDouble2D bCols : bb) {
                MatrixDouble2D cc = c.submatrix(aOffset, rowSize, bOffset, bCols.columnSize());
                if (aRows.columnSize() <= minimumSplit) {
                    subtasks[index++] = new ConcurrentMultiplyAccumulate2(ops, aRows, bCols, cc);
                }
                else {
                    int n = slice(aRows.columnSize());
                    int r = aRows.columnSize() - n;
                    subtasks[index++] = new ConcurrentMultiplyAccumulate2.Seq(new ConcurrentMultiplyAccumulate2(ops, aRows.columnSlice(0, n), bCols.rowSlice(0, n), cc),
                            new ConcurrentMultiplyAccumulate2(ops, aRows.columnSlice(n, r), bCols.rowSlice(n, r), cc));
                }
                bOffset += bCols.columnSize();
            }
            aOffset += rowSize;
        }
        forkJoin(subtasks);
    }

Because innerSize is now compared against an element count rather than a side length, it should be something like 20000. Otherwise the rest of the concurrent multiply doesn't change.
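
To make the rule concrete, the sketch below reproduces just the two tests from compute() on bare dimensions and reports which of the three dimensions would be split. The class and method names are invented for illustration. Note that a.columnSize() * (a.rowSize() + b.columnSize()) is the number of elements in a plus the number in b, so a threshold of 20000 corresponds to a pair of 100 x 100 operands.

```java
import java.util.Arrays;

class SplitDecisionDemo {
    /**
     * Returns {split rows of a, split inner dimension, split columns of b},
     * or all false when the operands are small enough for the inner multiply.
     */
    static boolean[] splits(int rows, int inner, int cols, int innerSize) {
        if (inner * (rows + cols) <= innerSize) {
            return new boolean[]{false, false, false};
        }
        int minimumSplit = Math.max(Math.max(rows, inner), cols) / 2;
        return new boolean[]{rows > minimumSplit, inner > minimumSplit, cols > minimumSplit};
    }

    public static void main(String[] args) {
        // Square operands: every dimension is split.
        System.out.println(Arrays.toString(splits(1200, 1200, 1200, 20000))); // [true, true, true]
        // Short inner dimension: only the rows of a and the columns of b are split.
        System.out.println(Arrays.toString(splits(1200, 300, 1200, 20000)));  // [true, false, true]
        // 100 * (100 + 100) = 20000 elements: handed to the inner multiply.
        System.out.println(Arrays.toString(splits(100, 100, 100, 20000)));    // [false, false, false]
    }
}
```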

Results

Using Java 6u4 server on my quad-core Intel Q6600, with matrices of around 1200 x 1200 (which exceed the size of the level 2 cache), the basic algorithm achieves around 90 MFLOPS, the recursive sequential code 1600 MFLOPS, and the concurrent code 6000 MFLOPS. The subdividing strategy used above can be improved. For example, basing it on the number of elements rather than assessing each side gives a slightly more consistent result.
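
For anyone wanting to reproduce such figures, the sketch below shows one way to convert between a run time and a MFLOPS rate. It assumes the conventional count of 2*L*M*N floating-point operations for an L x M by M x N product (one multiply and one add per inner step). The article does not say how its figures were computed, so that count and the helper names are assumptions.

```java
class MflopsDemo {
    /** Floating-point operations for an l x m by m x n product, assuming one multiply and one add per step. */
    static double flops(int l, int m, int n) {
        return 2.0 * l * m * n;
    }

    /** Rate in MFLOPS for a product that took the given number of seconds. */
    static double mflops(int l, int m, int n, double seconds) {
        return flops(l, m, n) / seconds / 1e6;
    }

    /** Run time in seconds implied by a given MFLOPS rate. */
    static double seconds(int l, int m, int n, double mflops) {
        return flops(l, m, n) / (mflops * 1e6);
    }

    public static void main(String[] args) {
        // Implied run times for a 1200 x 1200 product at the three quoted rates.
        for (double rate : new double[]{90, 1600, 6000}) {
            System.out.printf("%.0f MFLOPS -> %.3f s%n", rate, seconds(1200, 1200, 1200, rate));
        }
    }
}
```

Under that assumption a 1200 x 1200 product is about 3.5 billion operations. The ratio of the concurrent to the sequential rate, 6000 / 1600 = 3.75, is close to the four cores available.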

Author: Mark Thornton (mthornton@optrak.co.uk).