栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

java实现任意矩阵Strassen算法

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

java实现任意矩阵Strassen算法

本例输入为两个任意尺寸的矩阵(尺寸分别为 $m \times n$ 和 $n \times m$),输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了 Strassen 算法。程序为自编,经过测试,请放心使用。基本算法是:
1. 对于方阵(正方形矩阵,边长为 $m$),找到最大的 $l$,使得 $l = 2^k$($k$ 为整数)且 $l \le m$。左上角边长为 $l$ 的方阵采用 Strassen 算法计算,其余部分以及该方阵中遗漏的部分用蛮力法计算。
2.对于非方阵,依照行列相应添加0使其成为方阵。
StrassenMethodTest.java

package matrixalgorithm;
 
import java.util.Scanner;
 
public class StrassenMethodTest {

  /** Strassen kernel used for the power-of-two core of every square product. */
  private final StrassenMethod strassenMultiply;

  StrassenMethodTest() {
    strassenMultiply = new StrassenMethod();
  }

  /**
   * Reads two matrices from standard input, multiplies them and prints every entry of the
   * product.
   *
   * <p>The dimensions are validated before any matrix data is read, so an incompatible pair
   * ({@code acol != brow}) or a negative size is reported instead of failing later.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    try (Scanner input = new Scanner(System.in)) {
      System.out.println("Input row size of the first matrix: ");
      int arow = input.nextInt();
      System.out.println("Input column size of the first matrix: ");
      int acol = input.nextInt();
      System.out.println("Input row size of the second matrix: ");
      int brow = input.nextInt();
      System.out.println("Input column size of the second matrix: ");
      int bcol = input.nextInt();

      if (arow < 0 || acol < 0 || brow < 0 || bcol < 0 || acol != brow) {
        System.out.printf("Cannot multiply a %dx%d matrix by a %dx%d matrix.%n",
            arow, acol, brow, bcol);
        return;
      }

      double[][] A = readMatrix(input, "A", arow, acol);
      double[][] B = readMatrix(input, "B", brow, bcol);

      StrassenMethodTest algorithm = new StrassenMethodTest();
      double[][] C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);

      // Display the calculation result:
      System.out.println("Result from matrix C: ");
      for (int r = 0; r < arow; r++) {
        for (int c = 0; c < bcol; c++) {
          // "%n" restores the line break that the former "%fn" format had lost.
          System.out.printf("Data of C[%d][%d]: %f%n", r, c, C[r][c]);
        }
      }
    }
  }

  /**
   * Prompts for and reads a {@code rows x cols} matrix entry by entry.
   *
   * @param input scanner to read from
   * @param name  matrix name used in the prompts (e.g. "A" or "B")
   * @param rows  number of rows
   * @param cols  number of columns
   * @return the matrix that was read
   */
  private static double[][] readMatrix(Scanner input, String name, int rows, int cols) {
    System.out.println("Input data for matrix " + name + ": ");
    double[][] m = new double[rows][cols];
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        System.out.printf("Data of %s[%d][%d]: ", name, r, c);
        m[r][c] = input.nextDouble();
      }
    }
    return m;
  }

  /**
   * Multiplies an {@code arow x acol} matrix by a {@code brow x bcol} matrix.
   *
   * <p>Both operands are zero-padded to a common square size
   * {@code n = max(arow, acol, bcol)}, multiplied with {@link #multiplySquareMatrix}, and the
   * top-left {@code arow x bcol} block of the square product is returned. Zero padding does not
   * change the entries of that block.
   *
   * @param A    left operand, at least {@code arow x acol}
   * @param B    right operand, at least {@code brow x bcol}
   * @param arow rows of A
   * @param acol columns of A
   * @param brow rows of B
   * @param bcol columns of B
   * @return a new {@code arow x bcol} matrix holding {@code A * B}
   * @throws IllegalArgumentException if {@code acol != brow} (the product is undefined)
   */
  public double[][] multiplyRectMatrix(double[][] A, double[][] B,
      int arow, int acol, int brow, int bcol) {
    // The inner dimensions must agree. The former check (arow != bcol) rejected valid products
    // such as 2x3 * 3x4 and silently returned {{0}} for invalid ones.
    if (acol != brow) {
      throw new IllegalArgumentException(
          "Cannot multiply " + arow + "x" + acol + " by " + brow + "x" + bcol);
    }

    double[][] C = new double[arow][bcol];
    if (arow == 0 || bcol == 0) {
      return C;
    }

    int n = Math.max(arow, Math.max(acol, bcol));
    double[][] C2 = multiplySquareMatrix(
        padToSquare(A, arow, acol, n), padToSquare(B, brow, bcol, n), n);

    for (int r = 0; r < arow; r++) {
      System.arraycopy(C2[r], 0, C[r], 0, bcol);
    }
    return C;
  }

  /** Copies the top-left {@code rows x cols} part of {@code m} into a zeroed n x n matrix. */
  private static double[][] padToSquare(double[][] m, int rows, int cols, int n) {
    double[][] padded = new double[n][n];
    for (int r = 0; r < rows; r++) {
      System.arraycopy(m[r], 0, padded[r], 0, cols);
    }
    return padded;
  }

  /**
   * Multiplies the top-left {@code n x n} blocks of two square matrices.
   *
   * <p>Let {@code l} be the largest power of two not exceeding {@code n}. The top-left
   * {@code l x l} sub-product is computed with Strassen's algorithm and then completed with the
   * contributions of the inner indices {@code l..n-1}. Every entry outside that block is
   * computed by the ordinary row-by-column dot product.
   *
   * @param A2 left operand, at least {@code n x n}
   * @param B2 right operand, at least {@code n x n}
   * @param n  size of the product
   * @return a new {@code n x n} matrix holding {@code A2 * B2}; empty when {@code n <= 0}
   */
  public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n) {
    if (n <= 0) {
      return new double[0][0];
    }
    if (n == 1) {
      return new double[][] {{A2[0][0] * B2[0][0]}};
    }

    // Largest power of two <= n (n >= 2 here).
    int core = Integer.highestOneBit(n);
    if (core == n) {
      return strassenMultiply.strassenMultiplyMatrix(A2, B2, n);
    }

    double[][] coreA = new double[core][core];
    double[][] coreB = new double[core][core];
    for (int r = 0; r < core; r++) {
      System.arraycopy(A2[r], 0, coreA[r], 0, core);
      System.arraycopy(B2[r], 0, coreB[r], 0, core);
    }
    double[][] coreC = strassenMultiply.strassenMultiplyMatrix(coreA, coreB, core);

    double[][] C2 = new double[n][n];
    for (int row = 0; row < n; row++) {
      for (int col = 0; col < n; col++) {
        boolean inCore = row < core && col < core;
        // Inside the Strassen block only the inner indices k >= core are still missing;
        // everywhere else the full dot product is needed.
        double sum = inCore ? coreC[row][col] : 0.0;
        for (int k = inCore ? core : 0; k < n; k++) {
          sum += A2[row][k] * B2[k][col];
        }
        C2[row][col] = sum;
      }
    }
    return C2;
  }

}//end class

StrassenMethod.java

package matrixalgorithm;
 
import java.util.Scanner;
 
public class StrassenMethod {

  /**
   * Multiplies the top-left {@code n x n} blocks of two matrices using Strassen's algorithm.
   *
   * <p>Each operand is split into four quadrants, and the product is assembled from seven
   * recursive sub-products instead of eight. When {@code n} is odd, both operands are padded
   * with one zero row and column, multiplied at size {@code n + 1}, and the result is cropped
   * back. This makes every positive size correct, not only powers of two.
   *
   * @param A2 left operand; must have at least {@code n} rows and {@code n} columns
   * @param B2 right operand; must have at least {@code n} rows and {@code n} columns
   * @param n  size of the square blocks to multiply
   * @return a new {@code n x n} matrix holding {@code A2 * B2}; empty when {@code n <= 0}
   */
  public double[][] strassenMultiplyMatrix(double[][] A2, double[][] B2, int n) {
    if (n <= 0) {
      // Without this guard n == 0 recursed forever (n / 2 == 0 at every level).
      return new double[0][0];
    }
    if (n == 1) {
      return new double[][] {{A2[0][0] * B2[0][0]}};
    }
    if (n % 2 != 0) {
      // Halving an odd size would silently drop the last row and column.
      return multiplyOddSize(A2, B2, n);
    }

    int half = n / 2;
    double[][] a11 = quadrant(A2, 0, 0, half);
    double[][] a12 = quadrant(A2, 0, half, half);
    double[][] a21 = quadrant(A2, half, 0, half);
    double[][] a22 = quadrant(A2, half, half, half);
    double[][] b11 = quadrant(B2, 0, 0, half);
    double[][] b12 = quadrant(B2, 0, half, half);
    double[][] b21 = quadrant(B2, half, 0, half);
    double[][] b22 = quadrant(B2, half, half, half);

    // The seven Strassen products (CLRS naming P1..P7).
    double[][] p1 = strassenMultiplyMatrix(a11, minusMatrix(b12, b22, half), half);
    double[][] p2 = strassenMultiplyMatrix(addMatrix(a11, a12, half), b22, half);
    double[][] p3 = strassenMultiplyMatrix(addMatrix(a21, a22, half), b11, half);
    double[][] p4 = strassenMultiplyMatrix(a22, minusMatrix(b21, b11, half), half);
    double[][] p5 = strassenMultiplyMatrix(
        addMatrix(a11, a22, half), addMatrix(b11, b22, half), half);
    double[][] p6 = strassenMultiplyMatrix(
        minusMatrix(a12, a22, half), addMatrix(b21, b22, half), half);
    double[][] p7 = strassenMultiplyMatrix(
        minusMatrix(a11, a21, half), addMatrix(b11, b12, half), half);

    double[][] c11 = addMatrix(minusMatrix(addMatrix(p5, p4, half), p2, half), p6, half);
    double[][] c12 = addMatrix(p1, p2, half);
    double[][] c21 = addMatrix(p3, p4, half);
    double[][] c22 = minusMatrix(minusMatrix(addMatrix(p5, p1, half), p3, half), p7, half);

    double[][] C2 = new double[n][n];
    place(C2, c11, 0, 0);
    place(C2, c12, 0, half);
    place(C2, c21, half, 0);
    place(C2, c22, half, half);
    return C2;
  }

  /** Multiplies odd-sized blocks by zero-padding them to {@code n + 1} and cropping the result. */
  private double[][] multiplyOddSize(double[][] A2, double[][] B2, int n) {
    int padded = n + 1;
    double[][] paddedA = new double[padded][padded];
    double[][] paddedB = new double[padded][padded];
    for (int r = 0; r < n; r++) {
      System.arraycopy(A2[r], 0, paddedA[r], 0, n);
      System.arraycopy(B2[r], 0, paddedB[r], 0, n);
    }
    return quadrant(strassenMultiplyMatrix(paddedA, paddedB, padded), 0, 0, n);
  }

  /** Copies the {@code size x size} sub-matrix of {@code m} starting at the given offsets. */
  private static double[][] quadrant(double[][] m, int rowOffset, int colOffset, int size) {
    double[][] q = new double[size][size];
    for (int r = 0; r < size; r++) {
      System.arraycopy(m[rowOffset + r], colOffset, q[r], 0, size);
    }
    return q;
  }

  /** Writes {@code block} into {@code target} with its top-left corner at the given offsets. */
  private static void place(double[][] target, double[][] block, int rowOffset, int colOffset) {
    for (int r = 0; r < block.length; r++) {
      System.arraycopy(block[r], 0, target[rowOffset + r], colOffset, block[r].length);
    }
  }

  /** Returns the element-wise sum of the top-left {@code n x n} blocks of A and B. */
  private static double[][] addMatrix(double[][] A, double[][] B, int n) {
    double[][] C = new double[n][n];
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        C[r][c] = A[r][c] + B[r][c];
      }
    }
    return C;
  }

  /** Returns the element-wise difference {@code A - B} of the top-left {@code n x n} blocks. */
  private static double[][] minusMatrix(double[][] A, double[][] B, int n) {
    double[][] C = new double[n][n];
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        C[r][c] = A[r][c] - B[r][c];
      }
    }
    return C;
  }

}//end class

希望本文所述对大家学习java程序设计有所帮助。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/150576.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号