The Archive Base

Editorial Team
Asked: June 12, 2026

I’m trying to implement Strassen matrix multiplication in Python, and I’ve got it partly working. Here’s my code:

a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]

def new_m(p, q): # create a p x q matrix (p rows, q columns) filled with 0s
    matrix = [[0 for col in range(q)] for row in range(p)]
    return matrix

def straight(a, b): # multiply the two matrices
    if len(a[0]) != len(b): # if # of col != # of rows:
        return "Matrices are not m*n and n*p"
    else:
        p_matrix = new_m(len(a), len(b[0]))
        for i in range(len(a)):
            for j in range(len(b[0])):
                for k in range(len(b)):
                    p_matrix[i][j] += a[i][k]*b[k][j]
    return p_matrix

def split(matrix): # split matrix into quarters 
    a = matrix
    b = matrix
    c = matrix
    d = matrix
    while(len(a) > len(matrix)/2):
        a = a[:len(a)//2]
        b = b[:len(b)//2]
        c = c[len(c)//2:]
        d = d[len(d)//2:]
    while(len(a[0]) > len(matrix[0])/2):
        for i in range(len(a[0])//2):
            a[i] = a[i][:len(a[i])//2]
            b[i] = b[i][len(b[i])//2:]
            c[i] = c[i][:len(c[i])//2]
            d[i] = d[i][len(d[i])//2:]
    return a,b,c,d

def add_m(a, b):
    if type(a) == int:
        d = a + b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] + b[i][j])
            d.append(c)
    return d

def sub_m(a, b):
    if type(a) == int:
        d = a - b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] - b[i][j])
            d.append(c)
    return d


def strassen(a, b, q):
    # base case: 1x1 matrix
    if q == 1:
        d = [[0]]
        d[0][0] = a[0][0] * b[0][0]
        return d
    else:
        #split matrices into quarters
        a11, a12, a21, a22 = split(a)
        b11, b12, b21, b22 = split(b)

        # p1 = (a11+a22) * (b11+b22)
        p1 = strassen(add_m(a11,a22), add_m(b11,b22), q/2)

        # p2 = (a21+a22) * b11
        p2 = strassen(add_m(a21,a22), b11, q/2)

        # p3 = a11 * (b12-b22)
        p3 = strassen(a11, sub_m(b12,b22), q/2)

        # p4 = a22 * (b12-b11)
        p4 = strassen(a22, sub_m(b12,b11), q/2)

        # p5 = (a11+a12) * b22
        p5 = strassen(add_m(a11,a12), b22, q/2)

        # p6 = (a21-a11) * (b11+b12)
        p6 = strassen(sub_m(a21,a11), add_m(b11,b12), q/2)

        # p7 = (a12-a22) * (b21+b22)
        p7 = strassen(sub_m(a12,a22), add_m(b21,b22), q/2)


        # c11 = p1 + p4 - p5 + p7
        c11 = add_m(sub_m(add_m(p1, p4), p5), p7)

        # c12 = p3 + p5
        c12 = add_m(p3, p5)

        # c21 = p2 + p4
        c21 = add_m(p2, p4)

        # c22 = p1 + p3 - p2 + p6
        c22 = add_m(sub_m(add_m(p1, p3), p2), p6)

        c = new_m(len(c11)*2,len(c11)*2)
        for i in range(len(c11)):
            for j in range(len(c11)):
                c[i][j]                   = c11[i][j]
                c[i][j+len(c11)]          = c12[i][j]
                c[i+len(c11)][j]          = c21[i][j]
                c[i+len(c11)][j+len(c11)] = c22[i][j]

        return c

print "Strassen Outputs:"
print strassen(a, b, 4)
print "Should be:"
print straight(a, b)

I included straightforward (naive) matrix multiplication as a reference for the expected output. This is what happens:

Strassen Outputs:

[[10, 14, 22, 26], [32, 36, 48, 52], [58, 66, 70, 78], [80, 88, 96, 104]]

Should be:

[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]

I’m not sure what the source of the problem is, which means I can’t solve it!
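
One way to start narrowing this down is to diff the two results element by element. The following is a minimal Python 3 sketch using the two outputs shown above; the helper name `mismatches` is hypothetical and not part of the original code.

```python
# Outputs copied from the run above.
got = [[10, 14, 22, 26], [32, 36, 48, 52], [58, 66, 70, 78], [80, 88, 96, 104]]
want = [[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]


def mismatches(actual, expected):
    """Return (row, col, actual, expected) for every differing entry."""
    return [(i, j, actual[i][j], expected[i][j])
            for i in range(len(expected))
            for j in range(len(expected[0]))
            if actual[i][j] != expected[i][j]]


bad = mismatches(got, want)
bad_cols = sorted({j for _, j, _, _ in bad})
print(len(bad), bad_cols)
```

For these two outputs, every mismatch falls in columns 0 to 2, while the last column already agrees with the reference.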

1 Answer

  1. Editorial Team
    Added an answer on June 12, 2026 at 8:09 pm

    Shouldn’t this:

    # p4 = a22 * (b12-b11)
    p4 = strassen(a22, sub_m(b12,b11), q/2)
    

    be:

    # p4 = a22 * (b21-b11)
    p4 = strassen(a22, sub_m(b21,b11), q/2)
    

    instead?
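
The effect of that one index can be seen on plain 2x2 matrices of numbers, where the seven products are ordinary multiplications. This is a Python 3 sketch, not the code from the question: the function names and the `fixed` flag are assumptions introduced only for illustration.

```python
def strassen_2x2(a, b, fixed=True):
    """Multiply two 2x2 matrices of numbers using Strassen's seven products."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b

    p1 = (a11 + a22) * (b11 + b22)
    p2 = (a21 + a22) * b11
    p3 = a11 * (b12 - b22)
    # The corrected product uses b21; fixed=False reproduces the b12 typo.
    p4 = a22 * ((b21 if fixed else b12) - b11)
    p5 = (a11 + a12) * b22
    p6 = (a21 - a11) * (b11 + b12)
    p7 = (a12 - a22) * (b21 + b22)

    return [[p1 + p4 - p5 + p7, p3 + p5],
            [p2 + p4, p1 + p3 - p2 + p6]]


def naive_2x2(a, b):
    """Reference result from the textbook row-by-column definition."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(naive_2x2(a, b))                  # [[19, 22], [43, 50]]
print(strassen_2x2(a, b))               # [[19, 22], [43, 50]]
print(strassen_2x2(a, b, fixed=False))  # [[15, 22], [39, 50]]
```

Only the two entries that depend on p4 (c11 and c21) change, which is why a single wrong index is enough to corrupt the result.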

    ~/coding$ python -i strass.py
    Strassen Outputs:
    [[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
    Should be:
    [[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
    >>> import numpy
    >>> def check():
    ...     for i in range(100):
    ...         a = numpy.random.randint(0, 10,size=(4,4)).tolist()
    ...         b = numpy.random.randint(0, 10,size=(4,4)).tolist()
    ...         assert strassen(a,b,4) == straight(a,b)
    ...         assert (numpy.array(strassen(a,b,4)) == numpy.dot(a,b)).all()
    ...     print 'hooray!'
    ... 
    >>> check()
    hooray!
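
For anyone on Python 3 without numpy, the same randomized check can be sketched with only the standard library. This is a compact re-implementation under stated assumptions, not the asker's code: it assumes square matrices whose size is a power of two, and the helper names (`add`, `sub`, `naive`, `quarters`, `check`) are made up for this example.

```python
import random


def add(x, y):
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def sub(x, y):
    return [[p - q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def naive(x, y):
    """Textbook triple-loop product, used as the reference."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]


def quarters(m):
    """Split a square matrix into (top-left, top-right, bottom-left, bottom-right)."""
    h = len(m) // 2
    return ([r[:h] for r in m[:h]], [r[h:] for r in m[:h]],
            [r[:h] for r in m[h:]], [r[h:] for r in m[h:]])


def strassen(x, y):
    """Strassen product; assumes n x n inputs with n a power of two."""
    if len(x) == 1:
        return [[x[0][0] * y[0][0]]]
    a11, a12, a21, a22 = quarters(x)
    b11, b12, b21, b22 = quarters(y)
    p1 = strassen(add(a11, a22), add(b11, b22))
    p2 = strassen(add(a21, a22), b11)
    p3 = strassen(a11, sub(b12, b22))
    p4 = strassen(a22, sub(b21, b11))  # b21 here, as in the fix above
    p5 = strassen(add(a11, a12), b22)
    p6 = strassen(sub(a21, a11), add(b11, b12))
    p7 = strassen(sub(a12, a22), add(b21, b22))
    c11 = add(sub(add(p1, p4), p5), p7)
    c12 = add(p3, p5)
    c21 = add(p2, p4)
    c22 = add(sub(add(p1, p3), p2), p6)
    # Glue the quadrants back together row by row.
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom


def check(trials=100, n=4, seed=0):
    """Compare strassen() with naive() on random integer matrices."""
    rng = random.Random(seed)
    for _ in range(trials):
        x = [[rng.randint(0, 9) for _ in range(n)] for _ in range(n)]
        y = [[rng.randint(0, 9) for _ in range(n)] for _ in range(n)]
        assert strassen(x, y) == naive(x, y)
    return True
```

Calling `check()` returns `True` when all trials agree. The size is taken from `len(x)` rather than passed as a separate `q` argument, which sidesteps the question of whether `q/2` is integer or float division.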
    