Java实现矩阵相乘

时间:2022-12-26 14:09:26

事情原委:
大一刚学Java的时候,当时也学了线性代数这门课,在方程组基础向矩阵转换的思想方法上,感觉矩阵计算比较麻烦,所以当时一直想写一个小程序来通过输入两个矩阵输出相应的相乘之后的结果矩阵。鉴于当时刚刚接手Java,对于基本的输入输出都没有搞清楚,对于矩阵也运用不灵活,方法也不怎么会声明,出现一个数组越界或者空指针错误就写不下去了,后来MATLAB的应用也把它抛在一边,现在花点小时间来实现一下当时的小愿望。

M*N 矩阵 与 N * P 矩阵相乘,得到 M * P 矩阵。

M*N 矩阵是指矩阵有M行,N列,相应的矩阵Matrix用二维数组来实现。
Matrix.length =M,Matrix[0].length=N 。

为了便于操作,for循环的嵌套,在计算的时候把第二个矩阵转置,在操作上会简单很多。

闲话少说了,直接上代码:

import java.util.Scanner;

public class Matrix {
    public int[][]inputMatrix(int m,int n){
        int matrix1[][]=new int[m][n];
        Scanner sc= new Scanner(System.in);
        for(int i=0;i<m;i++){
            for(int j=0;j<n;j++){
                matrix1[i][j]=sc.nextInt();
            }
        }
        return matrix1;
    }
    public int [][]transMatirx(int [][]matrix){
        if(matrix.length==0)
            return null;
        int [][]result= new int[matrix[0].length][matrix.length];
        for(int i=0;i<matrix.length;i++){
            for(int j=0;j<matrix[0].length;j++){
                result[j][i]=matrix[i][j];
            }
        }
        return result;
    }

    public void print(int [][]matrix){
        if(matrix.length!=0){
            for(int i=0;i<matrix.length;i++){
                for(int j=0;j<matrix[0].length;j++){
                    System.out.print(matrix[i][j]+" ");
                }
                System.out.println();
            }
        }
    }
    public int [][]mulMatrix(int [][]matrix1,int[][]matrix2){
        if(matrix1.length==0||matrix2.length==0||matrix1[0].length!=matrix2.length){
            return null;
        }

        int m=matrix1.length;       //matrix1 行数
        int n=matrix1[0].length;//matrix1列数(每一行的元素数目)
        int [][]matrix3= transMatirx(matrix2);//将矩阵matrix2转置
        int p=matrix3.length;
        int [][]result= new int [m][p];


        for(int t=0;t<m;t++){
            for(int s=0;s<p;s++){
                int result1=0;
                for(int q=0;q<n;q++){
                    result1+=matrix1[t][q]*matrix3[s][q];
                }
                result[t][s]=result1;
            }
        }
        return result;
    }

}

下边是main方法:

public static void main(String []args){
        Matrix mm= new Matrix();
        int m,n,p;
        Scanner scan =new Scanner(System.in);
        System.out.println("input the size of the matrix1:");
        m=scan.nextInt();
        System.out.println("input the col of the matrix1:");
        n=scan.nextInt();
        System.out.println("input the size of the matrix2:");
        p=scan.nextInt();

        int matrix1[][]=new int[m][n];//m*n Matrix
        System.out.println("input the matrix1:");
        matrix1=mm.inputMatrix(m, n);

        int matrix2[][]=new int[n][p];// n*p Matrix
        System.out.println("input the matrix2:");
        matrix2=mm.inputMatrix(n, p);

        int result[][]=mm.mulMatrix(matrix1, matrix2);
        if(result==null){
            System.out.println("the two matrix can't multply with each other.");
        }
        else{
            System.out.println("the result is:");
            mm.print(result);
        }           
    }