1

過去数時間、行列の乗算に Strassen のアルゴリズムを実装しようとしてきましたが、正しい積を取得するのに苦労しました。ヘルパー関数 (helpSub、createProd、helpProduct) の 1 つが、strass2 関数の問題または形式 (コマンドの順序など) である可能性があると思います。私は完全に困惑しているので、どんなヒントでも大歓迎です。テスター マトリックスとして 2 つの 4 x 4 マトリックスを使用しています。インターネットで見た p1-p7 と c1-c4 のバリエーションをたくさん試しましたが、どれもうまくいかないようです。以下は私が作成したクラスです。

 /* @author williamnewman

public class strassen2 {

//Main Strassen multiplication function
//BASE CASE:
int [][] strass2(int[][] x, int[][]y){
    if(x.length == 1 && y.length == 1){
        System.out.println("Donezo");
        int [][] nu = new int[1][1];
        nu[0][0] = x[0][0] * y[0][0];
        return nu;

    }
    else{
   int[][] a,b,c,d,e,f,g,h;
   int dim = x.length/2;

//Dividing two matrices into 8 sub matrices
  System.out.println("A<B<C");
   a = helpSub(0,0,x);
   C(a);
   b = helpSub(0,dim,x);

   C(b);
   c = helpSub(dim,0,x);
   C(c);
   d = helpSub(dim,dim,x);
   C(d);
   e = helpSub(0,0,y);
   C(e);
   f = helpSub(0,dim,y);
   C(f);
   g = helpSub(dim,0,y);
   C(g);
   h = helpSub(dim,dim,y);
   C(h);

   int[][] p1,p2,p3,p4,p5,p6,p7;


//Creating p1-p7
   /
   p1 = strass2(a,subtract(f,h));
   p2 = strass2(h, add(a,b));
   p3 = strass2(e,add(c,d));
   p4 = strass2(d,subtract(g,e));
   p5 = strass2(add(a,d),add(e,h));
   p6 = strass2(subtract(b,d),add(g,h));
   p7 = strass2(subtract(a,c),add(e,f));
   int [][] prod;
   int [][] c1,c2,c3,c4;

//Creating c1-c4
   c1 = subtract(add(p6,p5),subtract(p4,p2));
   c2 = add(p1,p2);
   c3 = add(p3,p4);
   c4 = subtract(add(p1,p5),subtract(p3,p7));
   C(c1);
   System.out.println("C1::");
   C(c2);
   System.out.println("C2::");
   C(c3);
   System.out.println("C3::");
   C(c4);
   System.out.println("C4::");
//CREATES PRODUCT MATRIX
   prod = createProd(c1,c2,c3,c4);
   return prod;

    }




}

//Creates product matrix from c1-c4
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
    int[][] product = new int[c1.length*2][c1.length*2];
    int mid = c1.length;
    int fin = c1.length * 2;
    helpProduct(0,0,mid,mid,product,c1);
    helpProduct(0,mid,mid,fin,product,c2);
    helpProduct(mid,0,fin,mid,product,c3);
    helpProduct(mid,mid,fin,fin,product,c4);

     System.out.println();
    System.out.println("PRODUCT::!:");
    C(product);
    return product;



}

    //Helper function to create larger matrix from submatrices
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
    int indR = 0;
    int indC = 0;
    for(int i = x; i < z1; i++){
        indC = 0;
        for(int j = y; j < z2; j++){
            product[i][j] = a1[indR][indC];
            indC++;
        }
        indR++;
    }
}


    int[][] helpSub(int x, int y, int[][] mat){
    int[][] sub = new int[mat.length/2][mat.length/2];
    for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
    for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
    {
            sub[i1][j1] = mat[i2][j2];
                           // System.out.println(sub[i1][j1]);
    }
    return sub;
}



//Normal Matrix Multiplication Function
int[][] multiply(int[][]a,int[][]b){
    MM nu = new MM(a,b);
    return nu.product;
}

    //Adds one matrix to the next
int[][] add(int[][]a, int[][]b){
    int [][] nu = new int[a.length][a[0].length];
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            nu[i][j] = a[i][j] + b[i][j];
        }
    }
    return nu;
}

//Subtracts second matrix from the first
int[][] subtract(int[][] a, int[][] b){
    int [][] sub = new int[a.length][a.length];
    //System.out.println("made it");
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            sub[i][j] = a[i][j] - b[i][j];
        }
    }
    return sub;
}
//Prints the matrix
 void C(int[][] product){
    for(int i = 0; i <product.length; i++){
        for(int j = 0; j < product[i].length; j++){
            System.out.print(product[j][i]  + " ");

        }
        System.out.println();
    }
}
}

わかりにくい点があればお知らせください。質問を更新します。

主な機能は次のとおりです::

      public static void main(String[] args) {
        int [][]a = {{1,2,3,4},
            {4,3,2,1},
            {1,2,3,4},
            {4,3,2,1}};

        int [][]b = {{3,4,5,6},
            {3,4,5,6},
            {5,4,3,2},
            {5,4,3,2}
        };
        MM a1 = new MM(a,b);
        a1.C();
        int[][] prod;
        System.out.println("----");
        strassen2 a2 = new strassen2();
        prod = a2.strass2(a,b);
        a2.C(prod);
    }

}

これまでの結果は次のとおりです (期待される結果は最初に表示された 4x4 マトリックスであり、実際の結果は最後に表示された 4x4 マトリックスです)。

EXPECTED:

44 40 36 32 
36 40 44 48 
44 40 36 32 
36 40 44 48 
----


ACTUAL::
70 78 50 42 
86 86 34 34 
30 38 30 38 
38 54 38 54 

私の helpSub() 関数は、修正された ah を生成したため、機能すると確信しています。ただし、strass2 の再帰呼び出しで使用するパラメーターに問題がある可能性があります。それが十分に具体的でない場合は申し訳ありませんが、私は少し燃え尽きており、誰かが明白な問題を見た場合に興味がありました.

4

1 に答える 1