跟着Java源码学习稀疏矩阵
import java.util.HashMap;
import java.util.Map;
/**
 * A fixed-size integer matrix that stores only its non-zero cells.
 *
 * <p>Cells are kept in a nested map: row index -> (column index -> value).
 * A cell that is absent from the map is treated as zero, and zero values are
 * never stored, so empty row maps are removed as soon as they become empty.
 */
public class SparseMatrix {
    // Non-zero cells only: row index -> (column index -> value).
    private final Map<Integer, Map<Integer, Integer>> matrix;
    private final int rows;
    private final int cols;

    /**
     * Creates an all-zero matrix with the given dimensions.
     *
     * @param rows number of rows; must not be negative
     * @param cols number of columns; must not be negative
     * @throws IllegalArgumentException if either dimension is negative
     */
    public SparseMatrix(int rows, int cols) {
        if (rows < 0 || cols < 0) {
            throw new IllegalArgumentException("Matrix dimensions must not be negative.");
        }
        this.rows = rows;
        this.cols = cols;
        this.matrix = new HashMap<>();
    }

    /**
     * Sets the value of a cell, replacing any previous value.
     *
     * <p>Writing {@code 0} removes the stored cell (if any) instead of storing
     * an explicit zero, which keeps the representation sparse.
     *
     * @param row   zero-based row index, {@code 0 <= row < rows}
     * @param col   zero-based column index, {@code 0 <= col < cols}
     * @param value the new cell value
     * @throws IndexOutOfBoundsException if {@code row} or {@code col} is out of range
     */
    public void addElement(int row, int col, int value) {
        checkIndex(row, col);
        if (value != 0) {
            matrix.computeIfAbsent(row, k -> new HashMap<>()).put(col, value);
            return;
        }
        Map<Integer, Integer> rowCells = matrix.get(row);
        if (rowCells != null) {
            rowCells.remove(col);
            // Drop the row map once its last cell is gone so empty rows do not linger.
            if (rowCells.isEmpty()) {
                matrix.remove(row);
            }
        }
    }

    /**
     * Returns the value of a cell.
     *
     * @param row zero-based row index, {@code 0 <= row < rows}
     * @param col zero-based column index, {@code 0 <= col < cols}
     * @return the stored value, or {@code 0} if the cell is not stored
     * @throws IndexOutOfBoundsException if {@code row} or {@code col} is out of range
     */
    public int getElement(int row, int col) {
        checkIndex(row, col);
        Map<Integer, Integer> rowCells = matrix.get(row);
        return rowCells == null ? 0 : rowCells.getOrDefault(col, 0);
    }

    /**
     * Returns a new matrix holding the element-wise sum of this matrix and
     * {@code other}. Neither operand is modified. Cells whose sum is zero are
     * not stored in the result.
     *
     * @param other the matrix to add; must have the same dimensions as this one
     * @return a new matrix containing the sum
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SparseMatrix add(SparseMatrix other) {
        if (this.rows != other.rows || this.cols != other.cols) {
            throw new IllegalArgumentException("Matrix dimensions must match for addition.");
        }
        SparseMatrix result = new SparseMatrix(rows, cols);
        // Start from a copy of this matrix's non-zero cells ...
        for (var rowEntry : matrix.entrySet()) {
            int row = rowEntry.getKey();
            for (var colEntry : rowEntry.getValue().entrySet()) {
                result.addElement(row, colEntry.getKey(), colEntry.getValue());
            }
        }
        // ... then fold in the other matrix's non-zero cells. addElement drops
        // any cell whose sum cancels out to zero.
        for (var rowEntry : other.matrix.entrySet()) {
            int row = rowEntry.getKey();
            for (var colEntry : rowEntry.getValue().entrySet()) {
                int col = colEntry.getKey();
                result.addElement(row, col, result.getElement(row, col) + colEntry.getValue());
            }
        }
        return result;
    }

    /**
     * Prints every stored (non-zero) cell to standard output, one per line.
     * Lines appear in the iteration order of the underlying maps.
     */
    public void print() {
        for (var rowEntry : matrix.entrySet()) {
            int row = rowEntry.getKey();
            for (var colEntry : rowEntry.getValue().entrySet()) {
                int col = colEntry.getKey();
                int value = colEntry.getValue();
                System.out.println("Element at (" + row + ", " + col + ") = " + value);
            }
        }
    }

    // Rejects indices outside [0, rows) x [0, cols), including negative ones.
    private void checkIndex(int row, int col) {
        if (row < 0 || row >= rows || col < 0 || col >= cols) {
            throw new IndexOutOfBoundsException("Invalid row or column index.");
        }
    }
}