栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

c++实现Strassen算法 与朴素算法时间复杂度对比及优化(含源码)

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

c++实现Strassen算法 与朴素算法时间复杂度对比及优化(含源码)

c++实现Strassen算法 与朴素算法对比及优化
  1. 编程实现普通的矩阵乘法;
  2. 编程实现 Strassen’s algorithm;
  3. 在不同数据规模情况下(数据规模)下,比较两种算法的运行时间各是多少;
  4. 修改 Strassen’s algorithm,使之适应矩阵规模 N 不是 2 的幂的情况;
  5. 改进后的算法与 2 中的算法在相同数据规模下进行比较。
设计实验

将给出的实验数据写入 datas.txt 文件,为控制变量,两种算法分别从文件中读取数据。记录从数据读入到计算完成所用时间,取 5 次实验得到的平均值,利用数据作出折线图。

一、普通矩阵的代码分析
#include
#include
#include
#include
#include

using namespace std;

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    int N, M;
    cin >> N >> M;
    int a[M][M];
    int b[M][M];

    ifstream fin("datas.txt");

    auto start = chrono::high_resolution_clock::now();  //计时开始

    while (N--)
    {
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M;j++)
            {
                fin >> a[i][j];
            }
        }
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M;j++)
            {
                fin >> b[i][j];
            }
        }
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M; j++)
            {
                int cij = 0;
                for (int k = 0; k < M; k++)
                {
                    cij += a[i][k] * b[k][j];
                }
                cout << cij << " n"[j == M - 1];
            }
        }
    }
    auto end = chrono::high_resolution_clock::now();    //计时结束
    chrono::duration diff = end - start;
    cout << fixed << setprecision(10) << diff.count() << endl;
    return 0;
}
二、Strassen 算法的代码分析
#include 
#include
#include
#include
#include

using namespace std;

void minusm(int l, int **m, int **n, int **ans) //两矩阵减法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] - n[i][j];
        }
    }
}

void addm(int l, int **m, int **n, int **ans) //两矩阵加法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] + n[i][j];
        }
    }
}

void multim(int l, int **m, int **n, int **ans) //两矩阵乘法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = 0;
            for (int k = 0; k < l; k++)
            {
                ans[i][j] += m[i][k] * n[k][j];
            }
        }
    }
}

void Strassen(int N, int **A, int **B, int **C) //strassen算法
{
    if(N<=4)
    {
        multim(N, A, B, C);
    }else
    {
        int **A11 = new int *[N / 2];
    int **A12 = new int *[N / 2];
    int **A21 = new int *[N / 2];
    int **A22 = new int *[N / 2];

    int **B11 = new int *[N / 2];
    int **B12 = new int *[N / 2];
    int **B21 = new int *[N / 2];
    int **B22 = new int *[N / 2];

    int **C11 = new int *[N / 2];
    int **C12 = new int *[N / 2];
    int **C21 = new int *[N / 2];
    int **C22 = new int *[N / 2];

    int **P1 = new int *[N / 2];
    int **P2 = new int *[N / 2];
    int **P3 = new int *[N / 2];
    int **P4 = new int *[N / 2];
    int **P5 = new int *[N / 2];
    int **P6 = new int *[N / 2];
    int **P7 = new int *[N / 2];

    int **AR = new int *[N / 2];
    int **BR = new int *[N / 2];

    for (int i = 0; i < N / 2; i++)
    {
        A11[i] = new int[N / 2];
        A12[i] = new int[N / 2];
        A21[i] = new int[N / 2];
        A22[i] = new int[N / 2];
        B11[i] = new int[N / 2];
        B12[i] = new int[N / 2];
        B21[i] = new int[N / 2];
        B22[i] = new int[N / 2];
        C11[i] = new int[N / 2];
        C12[i] = new int[N / 2];
        C21[i] = new int[N / 2];
        C22[i] = new int[N / 2];

        P1[i] = new int[N / 2];
        P2[i] = new int[N / 2];
        P3[i] = new int[N / 2];
        P4[i] = new int[N / 2];
        P5[i] = new int[N / 2];
        P6[i] = new int[N / 2];
        P7[i] = new int[N / 2];

        AR[i] = new int[N / 2];
        BR[i] = new int[N / 2];
    }

    for (int i = 0; i < N / 2; i++)
    {
        for (int j = 0; j < N / 2; j++)
        {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + N / 2];
            A21[i][j] = A[i + N / 2][j];
            A22[i][j] = A[i + N / 2][j + N / 2];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + N / 2];
            B21[i][j] = B[i + N / 2][j];
            B22[i][j] = B[i + N / 2][j + N / 2];
        }
    }

    addm(N / 2, A11, A22, AR);
    addm(N / 2, B11, B22, BR);
    Strassen(N / 2, AR, BR, P1);

    addm(N / 2, A21, A22, AR);
    Strassen(N / 2, AR, B11, P2);

    minusm(N / 2, B12, B22, BR);
    Strassen(N / 2, A11, BR, P3);

    minusm(N / 2, B21, B11, BR);
    Strassen(N / 2, A22, BR, P4);

    addm(N / 2, A11, A12, AR);
    Strassen(N / 2, AR, B22, P5);

    minusm(N / 2, A21, A11, AR);
    addm(N / 2, B11, B12, BR);
    Strassen(N / 2, AR, BR, P6);

    minusm(N / 2, A12, A22, AR);
    addm(N / 2, B21, B22, BR);
    Strassen(N / 2, AR, BR, P7);

    addm(N / 2, P1, P4, AR);
    minusm(N / 2, P7, P5, BR);
    addm(N / 2, AR, BR, C11);

    addm(N / 2, P3, P5, C12);

    addm(N / 2, P2, P4, C21);

    addm(N / 2, P1, P3, AR);
    minusm(N / 2, P6, P2, BR);
    addm(N / 2, AR, BR, C22);

    for (int i = 0; i < N / 2; i++)
    {
        for (int j = 0; j < N / 2; j++)
        {
            C[i][j] = C11[i][j];
            C[i][j + N / 2] = C12[i][j];
            C[i + N / 2][j] = C21[i][j];
            C[i + N / 2][j + N / 2] = C22[i][j];
        }
    }

    for (int i = 0; i < N/2; i++)
			{
				delete[] A11[i];delete[] A12[i];delete[] A21[i];
				delete[] A22[i];
 
				delete[] B11[i];delete[] B12[i];delete[] B21[i];
				delete[] B22[i];
				delete[] C11[i];delete[] C12[i];delete[] C21[i];
				delete[] C22[i];
				delete[] P1[i];delete[] P2[i];delete[] P3[i];delete[] P4[i];
				delete[] P5[i];delete[] P6[i];delete[] P7[i];
				delete[] AR[i];delete[] BR[i] ;
			}
				delete[] A11;delete[] A12;delete[] A21;delete[] A22;
				delete[] B11;delete[] B12;delete[] B21;delete[] B22;
				delete[] C11;delete[] C12;delete[] C21;delete[] C22;
				delete[] P1;delete[] P2;delete[] P3;delete[] P4;delete[] P5;
				delete[] P6;delete[] P7;
				delete[] AR;
				delete[] BR ;
 

    }
    
    
}

int main() 
{
    std::ios::sync_with_stdio(false);   
    std::cin.tie(0);

    int N, M;
    cin >> N >> M;
    int **A = new int *[M]; //分配空间
    int **B = new int *[M];
    int **C = new int *[M];

    for (int i = 0; i < M; i++)
    {
        A[i] = new int[M];
        B[i] = new int[M];
        C[i] = new int[M];
    }

    ifstream fin("datas.txt");

    auto start = chrono::high_resolution_clock::now();  //计时开始

    while (N--) //N对矩阵
    {
        for (int i = 0; i < M; i++) //输入A矩阵
        {
            for (int j = 0; j < M; j++)
                fin >> A[i][j];
        }
        for (int i = 0; i < M; i++) //输入B矩阵
        {
            for (int j = 0; j < M; j++)
                fin >> B[i][j];
        }
        Strassen(M, A, B, C);       //使用strassen算法取得矩阵C
        for (int i = 0; i < M; i++) //输出C矩阵
        {
            for (int j = 0; j < M; j++)
            {
                cout << C[i][j] << " n"[j == M - 1];
            }
        }
    }
    auto end = chrono::high_resolution_clock::now();    //ji'shi'jie's
    chrono::duration diff = end - start;

    cout << fixed << setprecision(10) << diff.count() << endl;
    return 0;
}

三、两种算法的优化细节

在代码上交水杉的过程中,最初多次超时。经过求助助教与同学,学会了几处优化的细节,最终顺利通过。

cout << cij << " n"[j == M - 1];

利用 " n"[j == M - 1]字符串数组判断是否为末行,比 if else 语句速度快。

std::ios::sync_with_stdio(false);
std::cin.tie(0);

cin,cout 之所以效率低,是因为先把要输出的东西存入缓冲区,再输出,导致效率降低,而这段语句可以来打消 iostream 的输入 输出缓存,可以节省许多时间.

四、不同数据规模的情况下,两种算法的运行时间

如下运行并记录每一次数据(M 为 2 次幂时)


数据规模普通算法strassen 算法
2 3 2^3 230.00058200.0003488
2 5 2^5 250.00466450.0048068
2 7 2^7 270.06588370.1096698
2 9 2^9 291.46563823.5054571
五、修改 Strassen’s algorithm,使之适应矩阵规模 N 不是 2 的幂的情况

为了继续使用 Strassen 算法计算,若 M 为偶数则无需更改,若 M 为奇数则+1 补零,相当于将矩阵放在左上角,右侧边与下侧边补零,仍可使用 Strassen 算法得到正确结果。
(考虑到 Strassen 算法递归时间过长,只分割了一次,若要使用递归计算 N 不是 2 的幂的情况,只需右下角补 0,使规模为不小于 N 的最小 2 的幂,即可递归,思路相同)

#include 
using namespace std;

void subm(int l, int **m, int **n, int **ans)
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] - n[i][j];
        }
    }
}

void addm(int l, int **m, int **n, int **ans) //两矩阵加法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] + n[i][j];
        }
    }
}

void multim(int l, int **m, int **n, int **ans)
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = 0;
            for (int k = 0; k < l; k++)
            {
                ans[i][j] += m[i][k] * n[k][j];
            }
        }
    }
}

void Strassen(int M, int **A, int **B, int **C)
{
    int len = M / 2;
    int **A11 = new int *[len];
    int **A12 = new int *[len];
    int **A21 = new int *[len];
    int **A22 = new int *[len];
    int **B11 = new int *[len];
    int **B12 = new int *[len];
    int **B21 = new int *[len];
    int **B22 = new int *[len];
    int **C11 = new int *[len];
    int **C12 = new int *[len];
    int **C21 = new int *[len];
    int **C22 = new int *[len];

    int **P1 = new int *[len];
    int **P2 = new int *[len];
    int **P3 = new int *[len];
    int **P4 = new int *[len];
    int **P5 = new int *[len];
    int **P6 = new int *[len];
    int **P7 = new int *[len];

    int **AR = new int *[len];
    int **BR = new int *[len];

    for (int i = 0; i < len; i++)
    {
        A11[i] = new int[len];
        A12[i] = new int[len];
        A21[i] = new int[len];
        A22[i] = new int[len];
        B11[i] = new int[len];
        B12[i] = new int[len];
        B21[i] = new int[len];
        B22[i] = new int[len];
        C11[i] = new int[len];
        C12[i] = new int[len];
        C21[i] = new int[len];
        C22[i] = new int[len];
        P1[i] = new int[len];
        P2[i] = new int[len];
        P3[i] = new int[len];
        P4[i] = new int[len];
        P5[i] = new int[len];
        P6[i] = new int[len];
        P7[i] = new int[len];
        AR[i] = new int[len];
        BR[i] = new int[len];
    }

    for (int i = 0; i < len; i++)
    {
        for (int j = 0; j < len; j++)
        {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + len];
            A21[i][j] = A[i + len][j];
            A22[i][j] = A[i + len][j + len];

            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + len];
            B21[i][j] = B[i + len][j];
            B22[i][j] = B[i + len][j + len];
        }
    }
    addm(len, A11, A22, AR);
    addm(len, B11, B22, BR);
    multim(len, AR, BR, P1);

    addm(len, A21, A22, AR);
    multim(len, AR, B11, P2);

    subm(len, B12, B22, BR);
    multim(len, A11, BR, P3);

    subm(len, B21, B11, BR);
    multim(len, A22, BR, P4);

    addm(len, A11, A12, AR);
    multim(len, AR, B22, P5);

    subm(len, A21, A11, AR);
    addm(len, B11, B12, BR);
    multim(len, AR, BR, P6);

    subm(len, A12, A22, AR);
    addm(len, B21, B22, BR);
    multim(len, AR, BR, P7);

    addm(len, P1, P4, AR);
    subm(len, P7, P5, BR);
    addm(len, AR, BR, C11);

    addm(len, P3, P5, C12);

    addm(len, P2, P4, C21);

    addm(len, P1, P3, AR);
    subm(len, P6, P2, BR);
    addm(len, AR, BR, C22);

    for (int i = 0; i < len; i++)
    {
        for (int j = 0; j < len; j++)
        {
            C[i][j] = C11[i][j];
            C[i][j + len] = C12[i][j];
            C[i + len][j] = C21[i][j];
            C[i + len][j + len] = C22[i][j];
        }
    }
}

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    int N, M;
    cin >> N >> M;

    int length = M;

    if (M % 2 != 0) //若M为奇数,则补零
    {
        length++;
    }

    int **A = new int *[length];
    int **B = new int *[length];
    int **C = new int *[length];

    for (int i = 0; i < length; i++)
    {
        A[i] = new int[length];
        B[i] = new int[length];
        C[i] = new int[length];
    }

    while (N--)
    {
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M; j++)
                cin >> A[i][j];
        }
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M; j++)
            {
                cin >> B[i][j];
            }
        }

        if (length > M)
        {
            for (int i = 0; i < length; i++)
            {
                A[i][M] = 0;
                A[M][i] = 0;
                B[i][M] = 0;
                B[M][i] = 0;
            }
        }
        Strassen(length, A, B, C);
        for (int i = 0; i < M; i++)
        {
            for (int j = 0; j < M; j++)
            {
                cout << C[i][j] << " n"[j == M - 1];
            }
        }
    }
    return 0;
}
六、改进后的算法与 2 中的算法在相同数据规模下进行比较

采用 Strassen 算法作递归运算,需要创建大量的动态二维数组,分配内存空间将占用大量计算时间,改进后设定一个界限。当 n<界限 32 时,使用普通法计算矩阵,而不继续分治递归,n>32 时,再使用 Strassen 递归。

改进代码如下,利用递归:

#include 
#include
#include
#include
#include

using namespace std;

void minusm(int l, int **m, int **n, int **ans) //两矩阵减法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] - n[i][j];
        }
    }
}

void addm(int l, int **m, int **n, int **ans) //两矩阵加法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = m[i][j] + n[i][j];
        }
    }
}

void multim(int l, int **m, int **n, int **ans) //两矩阵乘法
{
    for (int i = 0; i < l; i++)
    {
        for (int j = 0; j < l; j++)
        {
            ans[i][j] = 0;
            for (int k = 0; k < l; k++)
            {
                ans[i][j] += m[i][k] * n[k][j];
            }
        }
    }
}

void Strassen(int N, int **A, int **B, int **C) //strassen算法
{
    if(N<=32)       //设立界限
    {
        multim(N, A, B, C);
    }else
    {
        int **A11 = new int *[N / 2];
    int **A12 = new int *[N / 2];
    int **A21 = new int *[N / 2];
    int **A22 = new int *[N / 2];

    int **B11 = new int *[N / 2];
    int **B12 = new int *[N / 2];
    int **B21 = new int *[N / 2];
    int **B22 = new int *[N / 2];

    int **C11 = new int *[N / 2];
    int **C12 = new int *[N / 2];
    int **C21 = new int *[N / 2];
    int **C22 = new int *[N / 2];

    int **P1 = new int *[N / 2];
    int **P2 = new int *[N / 2];
    int **P3 = new int *[N / 2];
    int **P4 = new int *[N / 2];
    int **P5 = new int *[N / 2];
    int **P6 = new int *[N / 2];
    int **P7 = new int *[N / 2];

    int **AR = new int *[N / 2];
    int **BR = new int *[N / 2];

    for (int i = 0; i < N / 2; i++)
    {
        A11[i] = new int[N / 2];
        A12[i] = new int[N / 2];
        A21[i] = new int[N / 2];
        A22[i] = new int[N / 2];
        B11[i] = new int[N / 2];
        B12[i] = new int[N / 2];
        B21[i] = new int[N / 2];
        B22[i] = new int[N / 2];
        C11[i] = new int[N / 2];
        C12[i] = new int[N / 2];
        C21[i] = new int[N / 2];
        C22[i] = new int[N / 2];

        P1[i] = new int[N / 2];
        P2[i] = new int[N / 2];
        P3[i] = new int[N / 2];
        P4[i] = new int[N / 2];
        P5[i] = new int[N / 2];
        P6[i] = new int[N / 2];
        P7[i] = new int[N / 2];

        AR[i] = new int[N / 2];
        BR[i] = new int[N / 2];
    }

    for (int i = 0; i < N / 2; i++)
    {
        for (int j = 0; j < N / 2; j++)
        {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + N / 2];
            A21[i][j] = A[i + N / 2][j];
            A22[i][j] = A[i + N / 2][j + N / 2];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + N / 2];
            B21[i][j] = B[i + N / 2][j];
            B22[i][j] = B[i + N / 2][j + N / 2];
        }
    }

    addm(N / 2, A11, A22, AR);
    addm(N / 2, B11, B22, BR);
    Strassen(N / 2, AR, BR, P1);        //递归

    addm(N / 2, A21, A22, AR);
    Strassen(N / 2, AR, B11, P2);       //递归

    minusm(N / 2, B12, B22, BR);
    Strassen(N / 2, A11, BR, P3);       //递归

    minusm(N / 2, B21, B11, BR);
    Strassen(N / 2, A22, BR, P4);       //递归

    addm(N / 2, A11, A12, AR);
    Strassen(N / 2, AR, B22, P5);       //递归

    minusm(N / 2, A21, A11, AR);
    addm(N / 2, B11, B12, BR);
    Strassen(N / 2, AR, BR, P6);        //递归

    minusm(N / 2, A12, A22, AR);
    addm(N / 2, B21, B22, BR);
    Strassen(N / 2, AR, BR, P7);        //递归

    addm(N / 2, P1, P4, AR);
    minusm(N / 2, P7, P5, BR);
    addm(N / 2, AR, BR, C11);

    addm(N / 2, P3, P5, C12);

    addm(N / 2, P2, P4, C21);

    addm(N / 2, P1, P3, AR);
    minusm(N / 2, P6, P2, BR);
    addm(N / 2, AR, BR, C22);

    for (int i = 0; i < N / 2; i++)
    {
        for (int j = 0; j < N / 2; j++)
        {
            C[i][j] = C11[i][j];
            C[i][j + N / 2] = C12[i][j];
            C[i + N / 2][j] = C21[i][j];
            C[i + N / 2][j + N / 2] = C22[i][j];
        }
    }

    for (int i = 0; i < N/2; i++)
   {
        delete[] A11[i];delete[] A12[i];delete[] A21[i];
        delete[] A22[i];

        delete[] B11[i];delete[] B12[i];delete[] B21[i];
        delete[] B22[i];
        delete[] C11[i];delete[] C12[i];delete[] C21[i];
        delete[] C22[i];
        delete[] P1[i];delete[] P2[i];delete[] P3[i];delete[] P4[i];
        delete[] P5[i];delete[] P6[i];delete[] P7[i];
        delete[] AR[i];delete[] BR[i] ;
   }
        delete[] A11;delete[] A12;delete[] A21;delete[] A22;
        delete[] B11;delete[] B12;delete[] B21;delete[] B22;
        delete[] C11;delete[] C12;delete[] C21;delete[] C22;
        delete[] P1;delete[] P2;delete[] P3;delete[] P4;delete[] P5;
        delete[] P6;delete[] P7;
        delete[] AR;
        delete[] BR ;
    }

}

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    int N, M;
    cin >> N >> M;
    int **A = new int *[M]; //分配空间
    int **B = new int *[M];
    int **C = new int *[M];

    for (int i = 0; i < M; i++)
    {
        A[i] = new int[M];
        B[i] = new int[M];
        C[i] = new int[M];
    }

    ifstream fin("datas.txt");

    auto start = chrono::high_resolution_clock::now();  //计时开始

    while (N--) //N对矩阵
    {
        for (int i = 0; i < M; i++) //输入A矩阵
        {
            for (int j = 0; j < M; j++)
                fin >> A[i][j];
        }
        for (int i = 0; i < M; i++) //输入B矩阵
        {
            for (int j = 0; j < M; j++)
                fin >> B[i][j];
        }
        Strassen(M, A, B, C);       //使用strassen算法取得矩阵C
        for (int i = 0; i < M; i++) //输出C矩阵
        {
            for (int j = 0; j < M; j++)
            {
                cout << C[i][j] << " n"[j == M - 1];
            }
        }
    }
    auto end = chrono::high_resolution_clock::now();    //ji'shi'jie's
    chrono::duration diff = end - start;

    cout << fixed << setprecision(10) << diff.count() << endl;
    return 0;
}

总结

根据时间复杂度分析,普通递归与暴力求解的复杂度相同,利用 Strassen 算法减少相乘次数可减小复杂度。可是在数据规模不大的情况下,由于开辟内存多,Strassen 算法反而比普通算法速度慢。
综合两种算法各自的优缺点改进代码,在数据规模小于界限时利用普通矩阵乘法,大于时再采用 Strassen 算法递归,明显减少了运行速度。
心得:各种算法各有其不同的适用性(如数据规模等),最好可以综合它们的优点进行优化,才能得到最优的算法。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/311897.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号