栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

矩阵的乘运算法则_矩阵连乘算法?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

矩阵的乘运算法则_矩阵连乘算法?

传统算法

  Strassen算法将2X2矩阵的乘法次数从8次减少到了7次。在介绍strassen算法之前,先用传统的算法计算一下2*2的矩阵乘法。
A = [ 1 2 3 4 ] B = [ 5 6 7 8 ] A × B = [ 1 × 5 + 2 × 7 1 × 6 + 2 × 8 3 × 5 + 4 × 7 3 × 6 + 4 × 8 ] = [ 19 22 43 50 ] A= left[ begin{matrix} 1 & 2 \ 3 & 4 end{matrix}right]\ B= left[ begin{matrix} 5 & 6 \ 7 & 8 end{matrix}right]\ Atimes B=left[ begin{matrix} 1times5+2times7 & 1times6+2times8 \ 3times5+4times7 & 3times6+4times8 end{matrix}right]=left[ begin{matrix} 19 & 22 \ 43 & 50 end{matrix}right]\ A=[13​24​]B=[57​68​]A×B=[1×5+2×73×5+4×7​1×6+2×83×6+4×8​]=[1943​2250​]
  总共使用了8次乘法和4次加法。

Strassen算法

  Strassen算法使用了7个中间变量,巧妙地用7次乘法合18次加法,减少了1次乘法操作,提高了算法的性能。其算法如下:
  设矩阵A、B为:
A = [ A 11 A 12 A 21 A 22 ] B = [ B 11 B 12 B 21 B 22 ] A= left[ begin{matrix} A_{11} & A_{12} \ A_{21} & A_{22} end{matrix}right]\ B= left[ begin{matrix} B_{11} & B_{12} \ B_{21} & B_{22} end{matrix}right]\ A=[A11​A21​​A12​A22​​]B=[B11​B21​​B12​B22​​]
  建立7个临时变量 P 1 P_1 P1​到 P 7 P_7 P7​,每个变量使用一次乘法运算。
P 1 = ( A 11 + A 22 ) ( B 11 + B 22 ) P 2 = ( A 21 + A 22 ) B 11 P 3 = A 11 ( B 12 − B 22 ) P 4 = A 22 ( B 21 − B 11 ) P 5 = ( A 11 + A 12 ) B 22 P 6 = ( A 21 − A 11 ) ( B 11 + B 12 ) P 7 = ( A 12 − A 22 ) ( B 21 + B 22 ) C 11 = P 1 + P 4 − P 5 + P 7 C 12 = P 3 + P 5 C 21 = P 2 + P 4 C 22 = P 1 − P 2 + P 3 + P 6 A × B = [ C 11 C 12 C 21 C 22 ] P_1 = (A_{11}+A_{22})(B_{11}+B_{22})\ P_2 = (A_{21}+A_{22})B_{11}\ P_3 = A_{11}(B_{12} − B_{22})\ P_4 = A_{22}(B_{21} − B_{11})\ P_5 = (A_{11} + A_{12})B_{22}\ P_6 = (A_{21} − A_{11})(B_{11} + B_{12})\ P_7 = (A_{12} − A_{22})(B_{21 }+ B_{22})\ C_{11} = P_1 + P_4 − P_5 + P_7\ C_{12} = P_3 + P_5\ C_{21} = P_2 + P_4\ C_{22} = P_1 − P_2 + P_3 + P_6\ Atimes B=left[ begin{matrix} C_{11} & C_{12} \ C_{21} & C_{22} end{matrix}right]\ P1​=(A11​+A22​)(B11​+B22​)P2​=(A21​+A22​)B11​P3​=A11​(B12​−B22​)P4​=A22​(B21​−B11​)P5​=(A11​+A12​)B22​P6​=(A21​−A11​)(B11​+B12​)P7​=(A12​−A22​)(B21​+B22​)C11​=P1​+P4​−P5​+P7​C12​=P3​+P5​C21​=P2​+P4​C22​=P1​−P2​+P3​+P6​A×B=[C11​C21​​C12​C22​​]
  公式比较复杂,总共11个公式呢,根本记不住,所以我建议,收藏我的博文,不要去记忆,当然也可以顺便关注我一波。
  需要注意的是上面11个公式中,乘法的左右顺序特别重要,因为这个公式可以适用于任何代数环。代数环就是乘法不需要符合交换律的集合、加法与乘法运算符。这意味着什么,这意味着2X2矩阵中的元素不仅可以是数字,还可以是矩阵。也就是说可以利用分块矩阵的方法,将大矩阵拆分为2X2的矩阵再使用Strassen算法。
  不过需要注意的是因为存在 A 11 + A 22 A_{11}+A_{22} A11​+A22​这样的骚操作,所以进行矩阵分块时,行数或者列数不能是奇数,所以在为奇数的时候还是要用传统的方法啊。

python实现

  跟我以往的文章不同,这次我没有把本文的算法代码和其他博文的代码混在一起。我新写了一个python文件,只做Strassen算法,而且使用了分治以处理大矩阵,代码如下:

class Matrix:
    # 矩阵
    @staticmethod
    def create_by_lines(lines):
        # 为了支持分块,设置四个属性
        return Matrix(lines, 0, len(lines), 0, len(lines[0]))

    def __init__(self, lines, row_start, row_end, column_start, column_end):
        self.__lines = lines
        # 为了支持分块,设置四个属性
        self.__column_start = column_start
        self.__column_end = column_end
        self.__row_start = row_start
        self.__row_end = row_end

    def __mul__(self, other):
        # 首先判断能不能相乘
        if self.column_len() != other.row_len():
            raise Exception("矩阵A列数%d != 矩阵B的行数%d" % (len(self.__lines[0]), len(other.__lines)))
        # 然后判断是不是2X2矩阵
        # 这里场景比较多:
        # 1 1 x n n x 1
        # 2 n x 1 1 x n
        # 3 2 x 2 2 x 2 strassen 数值运算
        # 4 其他,进行分块 strassen 矩阵运算
        if self.row_len() == 1 or self.column_len() == 1:
            return self.plain_mul(other)

        # 奇数不能分块
        if self.row_len() & 1 == 1 or self.column_len() & 1 == 1 or other.row_len() & 1 == 1:
            return self.plain_mul(other)

        # 这个时候就可以使用strassen算法了

        a11, a12, a21, a22 = self.sub()
        b11, b12, b21, b22 = other.sub()

        p1 = (a11 + a22) * (b11 + b22)
        p2 = (a21 + a22) * b11
        p3 = a11 * (b12 - b22)
        p4 = a22 * (b21 - b11)
        p5 = (a11 + a12) * b22
        p6 = (a21 - a11) * (b11 + b12)
        p7 = (a12 - a22) * (b21 + b22)

        return Matrix.create(p1 + p4 - p5 + p7, p3 + p5, p2 + p4, p1 - p2 + p3 + p6)

    def __add__(self, other):
        arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
        # 里面不能是同一个数组
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = other.__lines[other.__row_start + i]
            for j in range(0, self.column_len()):
                arr[i][j] = self_row[self.__column_start + j] + other_row[other.__column_start + j]
        return Matrix.create_by_lines(arr)

    def __sub__(self, other):
        arr = [[0] * self.column_len() for _ in range(0, self.row_len())]
        # 里面不能是同一个数组
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = other.__lines[other.__row_start + i]
            for j in range(0, self.column_len()):
                arr[i][j] = self_row[self.__column_start + j] - other_row[other.__column_start + j]
        return Matrix.create_by_lines(arr)

    def plain_mul(self, other):
        # 弄一个m行n列的新矩阵
        m = self.row_len()
        n = other.column_len()
        p = other.row_len()

        result = [[0] * n for _ in range(0, m)]
        # i 代表 A矩阵的行
        for i in range(self.__row_start, self.__row_end):
            # j 代表 B 矩阵的列
            for j in range(other.__column_start, other.__column_end):
                # 第一个矩阵的行 与第二个矩阵列的乘积和
                # k 代表 A矩阵的列和B矩阵的行
                for k in range(0, p):
                    self_line = self.__lines[i]
                    other_line = other.__lines[other.__row_start + k]
                    a = self_line[self.__column_start + k]
                    b = other_line[j]
                    mul = a * b
                    result[i - self.__row_start][j - other.__column_start] += mul
        return Matrix.create_by_lines(result)

    def row_len(self):
        return self.__row_end - self.__row_start

    def column_len(self):
        return self.__column_end - self.__column_start

    def sub(self):
        a_middle_row = (self.__row_end + self.__row_start) // 2
        a_middle_column = (self.__column_end + self.__column_start) // 2
        a11 = Matrix(self.__lines, self.__row_start, a_middle_row, self.__column_start, a_middle_column)
        a12 = Matrix(self.__lines, self.__row_start, a_middle_row, a_middle_column, self.__column_end)
        a21 = Matrix(self.__lines, a_middle_row, self.__row_end, self.__column_start, a_middle_column)
        a22 = Matrix(self.__lines, a_middle_row, self.__row_end, a_middle_column, self.__column_end)
        return a11, a12, a21, a22

    @staticmethod
    def create(a11, a12, a21, a22):
        len_rows = a11.row_len() + a21.row_len()
        len_columns = a11.column_len() + a12.column_len()
        lines = [[0] * len_columns for _ in range(0, len_rows)]
        # 拷贝进去
        a11.copy_to(lines, 0, 0)
        a12.copy_to(lines, 0, a11.column_len())
        a21.copy_to(lines, a11.row_len(), 0)
        a22.copy_to(lines, a12.row_len(), a21.column_len())
        return Matrix.create_by_lines(lines)

    def copy_to(self, lines, row_start, column_start):
        for i in range(0, self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = lines[row_start + i]
            for j in range(0, self.column_len()):
                other_row[column_start + j] = self_row[self.__column_start + j]

    @property
    def lines(self):
        return self.__lines
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/786613.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号