Swift Algorithm Club: Strassen’s Algorithm

In this tutorial, you’ll learn how to implement Strassen’s Matrix Multiplication in Swift. This was the first matrix multiplication algorithm to beat the naive O(n³) implementation, and is a fantastic example of the Divide and Conquer coding paradigm — a favorite topic in coding interviews. By Richard Ash.


Naive Matrix Multiplication

Start by downloading the materials using the Download Materials button found at the top and bottom of this tutorial. Open StrassenAlgorithm-starter.playground.

Instead of dealing with implementing the specifics of a matrix in Swift, the starter project includes helper methods and a Matrix class to help you focus on learning matrix multiplication and Strassen’s algorithm.

Subscript Methods

  • subscript(row:column:): Returns the element at a specified row and column.
  • subscript(row:): Returns the row at a specified index.
  • subscript(column:): Returns the column at a specified index.

Term-by-Term Matrix Math

  • * (lhs:rhs:): Multiplies two matrices term-by-term.
  • + (lhs:rhs:): Adds two matrices term-by-term.
  • - (lhs:rhs:): Subtracts two matrices term-by-term.

Array

  • dot(_:): Computes the dot product with a specified array and returns the result.

Functions

  • printMatrix(_:name:): Pretty-prints a specified matrix to the console.

For more details, look at Matrix.swift and Array+Extension.swift, both located under Sources. These files contain the implementation and documentation of all the methods and functions above!
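To get a feel for what a helper like dot(_:) does, here is a minimal sketch of a dot product on plain arrays. It is not the starter project’s code: the name dotProduct(with:) is hypothetical and chosen so it does not clash with the real method, and the actual implementation in Array+Extension.swift may differ.

```swift
// Hypothetical stand-in for the starter project's dot(_:).
// The real implementation lives in Array+Extension.swift and may differ.
extension Array where Element: Numeric {
  func dotProduct(with other: [Element]) -> Element {
    precondition(count == other.count,
                 "Both arrays must have the same number of elements.")
    // Multiply matching elements pairwise, then add up the products.
    return zip(self, other).reduce(0) { sum, pair in
      sum + pair.0 * pair.1
    }
  }
}

let row = [2, 1, -1, 0]
let column = [3, 2, -1, 2]
print(row.dotProduct(with: column)) // 2*3 + 1*2 + (-1)*(-1) + 0*2 = 9
```

Each call visits every element once, which is why a dot product over arrays of length n costs O(n).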

Implementing matrixMultiply

To begin implementing matrixMultiply, add a new extension to your playground with the following method:

extension Matrix {
  // 1
  public func matrixMultiply(by other: Matrix) -> Matrix {
    // 2
    precondition(columnCount == other.rowCount,
                 """
                 Two matrices can only be multiplied if the first matrix's \
                 column count equals the second matrix's row count.
                 """)
  }
}

Reviewing your work:

  1. You created an extension of Matrix with a public method named matrixMultiply(by:).
  2. You checked that the current matrix’s column count matches the other matrix’s row count. Recall that this is a requirement, as highlighted when you learned about matrix multiplication.

Next, add the following to matrixMultiply(by:) below precondition:

// 1
var result = Matrix(rows: rowCount, columns: other.columnCount)
// 2
for index in result.indices {
  // 3
  let ithRow = self[row: index.row]
  let jthColumn = other[column: index.column]
  // 4
  result[index] = ithRow.dot(jthColumn)
}

return result

Going over what you just added, you:

  1. Initialized the result matrix to have dimensions be the first matrix’s row count by the second matrix’s column count.
  2. Looped through the indices of the matrix. Matrix is a MutableCollection, so this is exactly the same as looping through the indices of a regular Array.
  3. Initialized constants for both the ith row from the first matrix and the jth column from the second matrix.
  4. Set the element at result[index] to be the dot product between the ith row of the first matrix and jth column of the second matrix.

Analyzing Time Complexity

Next, you’ll analyze the time complexity of this implementation. The algorithm above runs in O(n³) time.

But how… ? You only used one for loop!

There are actually three loops in this implementation. By using the Collection iteration and Array method dot, you’ve hidden two. Take a closer look at the following line:

for index in result.indices {

This line deceptively contains TWO for loops! For a result matrix with n rows and n columns, iterating over every row and column position is an O(n²) operation.

Next, look at the following:

result[index] = ithRow.dot(jthColumn)

The dot product takes O(n) time because you need to loop over a row of length n. Since it’s embedded in the O(n²) for loop, you multiply the two time complexities together, resulting in O(n³).
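The three loops described above can be written out explicitly. The following is a minimal sketch on plain nested arrays rather than the tutorial’s Matrix type; naiveMultiply is a hypothetical helper, and it assumes both inputs are rectangular. It also counts the scalar multiplications so you can see the n³ growth directly.

```swift
// Hypothetical helper on nested arrays, not the tutorial's Matrix type.
// Assumes both inputs are rectangular (every row has the same length).
func naiveMultiply(_ lhs: [[Int]], _ rhs: [[Int]])
  -> (product: [[Int]], multiplications: Int) {
  let rowCount = lhs.count
  let innerCount = rhs.count
  let columnCount = rhs.first?.count ?? 0
  precondition(lhs.allSatisfy { $0.count == innerCount },
               "The first matrix's column count must equal the second's row count.")

  var product = Array(repeating: Array(repeating: 0, count: columnCount),
                      count: rowCount)
  var multiplications = 0

  for i in 0..<rowCount {          // loop 1: rows of the result
    for j in 0..<columnCount {     // loop 2: columns of the result
      for k in 0..<innerCount {    // loop 3: the dot product
        product[i][j] += lhs[i][k] * rhs[k][j]
        multiplications += 1
      }
    }
  }
  return (product, multiplications)
}

let result = naiveMultiply([[1, 2, 3], [3, 2, 1], [1, 2, 3]],
                           [[4, 5, 6], [6, 5, 4], [4, 5, 6]])
print(result.multiplications) // 27, which is 3 * 3 * 3
```

Loops 1 and 2 correspond to iterating over result.indices, and loop 3 corresponds to the call to dot.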

Trying It Out

Add the following outside of the extension at the bottom of your playground:

var A = Matrix<Int>(rows: 2, columns: 4)
A[row: 0] = [2, 1, -1, 0]
A[row: 1] = [0, 10, 0, 0]
printMatrix(A, name: "A")

Here, you initialize a 2×4 matrix named A and update its rows using subscript(row:). You then print the matrix to the console with printMatrix(_:name:). You should see the following output on the console:

Matrix A: 
2 1 -1 0
0 10 0 0

Challenge 1

Initialize a 4×2 matrix named B that prints the following to the console:

Matrix B:
3 4
2 1
-1 2
2 7

This time, use subscript(column:) to update the matrix’s elements.

[spoiler title=”Solution”]

var B = Matrix<Int>(rows: 4, columns: 2)
B[column: 0] = [3, 2, -1, 2]
B[column: 1] = [4, 1, 2, 7]
printMatrix(B, name: "B")

[/spoiler]

Next, add the following below printMatrix(B, name: "B"):

let C = A.matrixMultiply(by: B)
printMatrix(C, name: "C")

You should see the following output:

Matrix C: 
9 7
20 10

Cool! Now how do you know that this is the correct matrix? You could write it out yourself (if you want to practice), but that approach becomes quite challenging as the number of rows and columns increases. Fortunately, you can check your answer on WolframAlpha and indeed Matrix C is correct.

Challenge 2

1. Initialize the following 3×3 matrices, D and E.
2. Compute their matrix multiplication, F.
3. Print F to the console.

Matrix D:
1 2 3
3 2 1
1 2 3

Matrix E:
4 5 6
6 5 4
4 5 6

[spoiler title=”Solution”]

var D = Matrix<Int>(rows: 3, columns: 3)
D[row: 0] = [1, 2, 3]
D[row: 1] = [3, 2, 1]
D[row: 2] = [1, 2, 3]

var E = Matrix<Int>(rows: 3, columns: 3)
E[row: 0] = [4, 5, 6]
E[row: 1] = [6, 5, 4]
E[row: 2] = [4, 5, 6]

let F = D.matrixMultiply(by: E)
printMatrix(F, name: "F")

You should see the following output:

Matrix F: 
28 30 32
28 30 32
28 30 32

[/spoiler]

On to Strassen’s algorithm!

Strassen’s Matrix Multiplication

Good job! You made it this far. Now for the fun bit: You’ll dive into Strassen’s algorithm. The basic idea behind Strassen’s algorithm is to split each of the two matrices, A and B, into four submatrices (eight in total) and then recursively compute the submatrices of C. This strategy is called divide and conquer.

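Consider the following, where A is split into the quadrants a, b, c and d, and B is split into the quadrants e, f, g and h:

A = | a  b |     B = | e  f |
    | c  d |         | g  h |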

There are eight recursive calls:

  1. a * e
  2. b * g
  3. a * f
  4. b * h
  5. c * e
  6. d * g
  7. c * f
  8. d * h

These combine to form the four quadrants of C:

C = | ae + bg   af + bh |
    | ce + dg   cf + dh |

This step alone, however, doesn’t improve the complexity. Using the Master Theorem with T(n) = 8T(n/2) + O(n²) you still get a time of O(n³).

Strassen’s insight was that you don’t actually need eight recursive calls to complete this process. You can finish the operation with seven recursive calls with a little bit of addition and subtraction.
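To see why dropping from eight recursive calls to seven matters, you can count the scalar multiplications each scheme performs. The sketch below is illustrative only: scalarMultiplications(size:branches:) is a hypothetical helper, it assumes n is a power of two, and it ignores the cost of the additions and subtractions.

```swift
// Hypothetical helper: counts the scalar multiplications performed by a
// divide-and-conquer multiply on n×n matrices that makes `branches`
// recursive calls on half-size blocks. Assumes n is a power of two.
func scalarMultiplications(size n: Int, branches: Int) -> Int {
  precondition(n > 0 && (n & (n - 1)) == 0, "n must be a power of two.")
  // Base case: a 1×1 block needs a single multiplication.
  if n == 1 { return 1 }
  return branches * scalarMultiplications(size: n / 2, branches: branches)
}

for n in [2, 8, 64] {
  let eight = scalarMultiplications(size: n, branches: 8)
  let seven = scalarMultiplications(size: n, branches: 7)
  print("n = \(n): eight calls -> \(eight), seven calls -> \(seven)")
}
```

With eight branches the count is exactly n³ (512 for n = 8), while seven branches give 7^(log₂ n) (343 for n = 8). The gap widens as n grows, which is the source of the improved bound of roughly O(n^2.81).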

Strassen’s seven calls are as follows:

  1. a * (f - h)
  2. (a + b) * h
  3. (c + d) * e
  4. d * (g - e)
  5. (a + d) * (e + h)
  6. (b - d) * (g + h)
  7. (a - c) * (e + f)

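Now, you can compute matrix C’s new quadrants, where p1 through p7 are the seven products above:

C = | p5 + p4 - p2 + p6   p1 + p2           |
    | p3 + p4             p1 + p5 - p3 - p7 |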

A great reaction right now would be !!??!?!?!!?! How does this even work?

Next, you’ll prove it!

1. First submatrix:

p5+p4-p2+p6 = (a+d)*(e+h) + d*(g-e) - (a+b)*h + (b-d)*(g+h)
            = (ae+de+ah+dh) + (dg-de) - (ah+bh) + (bg-dg+bh-dh)
            = ae+bg ✅

Exactly what you got the first time!

Now, on to proving the others.

2. Second submatrix:

p1+p2 = a*(f-h) + (a+b)*h
      = (af-ah) + (ah+bh)
      = af+bh ✅

3. Third submatrix:

p3+p4 = (c+d)*e + d*(g-e)
      = (ce+de) + (dg-de)
      = ce+dg ✅

4. Fourth submatrix:

p1+p5-p3-p7 = a*(f-h) + (a+d)*(e+h) - (c+d)*e - (a-c)*(e+f)
            = (af-ah) + (ae+de+ah+dh) - (ce+de) - (ae-ce+af-cf)
            = cf+dh ✅

Great! The math checks out!
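You can also check the identities numerically. The following sketch applies the seven products to a 2×2 matrix whose quadrants are single numbers; strassen2x2 is a hypothetical helper on nested arrays, not part of the starter project, and it only handles the 2×2 case rather than recursing.

```swift
// Hypothetical helper: applies Strassen's seven products to 2×2 matrices
// whose quadrants are single numbers. It does not recurse.
func strassen2x2(_ A: [[Int]], _ B: [[Int]]) -> [[Int]] {
  precondition(A.count == 2 && B.count == 2 &&
               A.allSatisfy { $0.count == 2 } &&
               B.allSatisfy { $0.count == 2 },
               "This sketch only handles 2×2 matrices.")
  let (a, b, c, d) = (A[0][0], A[0][1], A[1][0], A[1][1])
  let (e, f, g, h) = (B[0][0], B[0][1], B[1][0], B[1][1])

  // The seven products, in the same order as the list above.
  let p1 = a * (f - h)
  let p2 = (a + b) * h
  let p3 = (c + d) * e
  let p4 = d * (g - e)
  let p5 = (a + d) * (e + h)
  let p6 = (b - d) * (g + h)
  let p7 = (a - c) * (e + f)

  // Combine the products into the four quadrants of C.
  return [[p5 + p4 - p2 + p6, p1 + p2],
          [p3 + p4, p1 + p5 - p3 - p7]]
}

print(strassen2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]])) // [[19, 22], [43, 50]]
```

The output matches the usual row-by-column product, for example 1*5 + 2*7 = 19 in the top-left corner. Because plain numbers commute, this check does not exercise the ordering of factors, which the algebraic proof above preserves and which matters once the quadrants are themselves matrices.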