疎行列計算(UJMP,疎行列,Java8,ラムダ式)
今回は私的な行列計算のコードをアップします.なので,説明を書こうとは思っていないのですいません.
後にもしかしたら入れるかもしれません.
後にもしかしたら入れるかもしれません.
ラムダ式を用いた行列計算
long start = System.currentTimeMillis();// 開始時間取得 long rowline=0;// 行列の行番号の変数初期化 ListrowValue = new ArrayList ();// 任意の行にある非ゼロ要素の列番号リスト // 行列計算 for(long[] m : m1.availableCoordinates()){ if(rowline == m[0]){ rowValue.add(BigDecimal.valueOf(m1.getAsDouble(m[0], m[1]) * a.getAsDouble(m[1], 0))); } else { // リストrowValueの値を合算し,行列にセット a_tmp.setAsDouble(rowValue.parallelStream().reduce((value1, value2) -> value1.add(value2)).get().doubleValue(), rowline, 0); rowline = m[0];// 行番号を更新 rowValue.clear();// リストを初期化 rowValue.add(BigDecimal.valueOf(m1.getAsDouble(m[0], m[1]) * a.getAsDouble(m[1], 0))); } } a_tmp.setAsDouble(rowValue.parallelStream().reduce((value1, value2) -> value1.add(value2)).get().doubleValue(), rowline, 0); long end = System.currentTimeMillis();// 終了時間取得 System.out.println("time:" + (end-start) + "ms");// 実行時間表示
非ゼロ要素のみ対象にした行列計算
long start = System.currentTimeMillis();// 開始時間取得 long rowline = 0L;// 行列の行番号の変数初期化 Double sum =0.0; for(long[] m : m1.availableCoordinates()){ if(rowline == m[0]){ sum += m1.getAsDouble(m[0], m[1]) * a.getAsDouble(m[1], 0); } else { a_tmp.setAsDouble(sum, rowline, 0); rowline = m[0]; sum = m1.getAsDouble(m[0], m[1]) * a.getAsDouble(m[1], 0); } } a_tmp.setAsDouble(sum, rowline, 0); long end = System.currentTimeMillis();// 終了時間取得 System.out.println("time:" + (end-start) + "ms"); System.out.print("a: \n" + a_tmp );
行のリストからの行列計算
long start = System.currentTimeMillis();// 開始時間取得 ListrowList = m1.getColumnList();// 行のリスト取得 for(int i=0; i < rowList.size(); i++){ Matrix row = rowList.get(i); double listSum =0.0; for(int j=0; j< row.getColumnCount(); j++){ listSum += rowList.get(i).getAsDouble(0, j) * a.getAsDouble(j, 0); } a_tmp.setAsDouble(listSum, i, 0); } long end = System.currentTimeMillis();// 終了時間取得 System.out.println("time:" + (end-start) + "ms");// 実行時間表示
ラムダ式を用いた並列行列計算
public class Parallel { public static void main(String[] args) { long start = System.currentTimeMillis(); m1.getRowList().parallelStream() .map(matrix -> Parallel.getTimesValue(matrix, a)); .forEachOrdered(value -> Parallel.aList.add(value)); long end = System.currentTimeMillis(); System.out.println("time:" + (end-start) + "ms"); } static double getTimesValue(Matrix matrix, Matrix a){ double listSum =0.0; for(int i=0; i< matrix.getColumnCount(); i++){ listSum += matrix.getAsDouble(0, i) * a.getAsDouble(i, 0); } return listSum; }