37

潜在的な階乗素数 (n!+-1 の形式の数) の約数を見つけようとしましたが、最近 Skylake-X ワークステーションを購入したので、AVX512 命令を使用して速度を上げることができると考えました。

アルゴリズムは単純で、主なステップは、同じ除数に関してモジュロを繰り返し取ることです。主なことは、広い範囲の n 値をループすることです。これはcで書かれた単純なアプローチです(Pは素数の表です):

uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
    residue = 2;
    for (n=3; n <= nmax; n++){
        residue *=  n;
        residue %= P[i];
        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1){
                report_factor(n, -1, P[i]);
            }
            if(residue == P[i]- 1){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

ここでの考え方は、n の広い範囲、たとえば 1,000,000 -> 10,000,000 を同じ除数のセットに対してチェックすることです。したがって、同じ除数を数百万回モジュロ尊重します。DIV の使用は非常に遅いため、計算の範囲に応じていくつかのアプローチが可能です。ここで、私の場合、n は 10^7 未満である可能性が高く、潜在的な除数 p は 10,000 G (< 10^13) 未満です。したがって、数値は 64 ビット未満であり、53 ビット未満でもあります!最大剰余 (p-1) x n が 64 ビットより大きい。したがって、64ビットより大きい数値からモジュロを取得しているため、モンゴメリー法の最も単純なバージョンは機能しないと思いました。

double を使用する場合、FMA を使用して最大 106 ビット (推測) の正確な積を取得する、power pc の古いコードを見つけました。そこで、このアプローチを AVX 512 アセンブラー (Intel Intrinsics) に変換しました。これは、FMA メソッドの単純なバージョンです。これは Dekker (1971) の作業に基づいています。Dekker プロダクトと TwoProduct の FMA バージョンは、この背後にある理論的根拠を見つけたり、グーグルで調べたりするときに役立つ言葉です。また、このアプローチについては、このフォーラム (例:ここ) で議論されています。

int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;

for (i = 0; i < APP_BUFLEN; i++){
    residue = 2.0;
    prime_double = (double)P[i];
    prime_double_reciprocal = 1.0 / prime_double;
    n_double = 3.0;
    for (n=3; n <= nmax; n++){
        nr =  n_double * residue;
        quotient = fma(nr, prime_double_reciprocal, rounding_constant);
        quotient -= rounding_constant;
        prime_times_quotient_high= prime_double * quotient;
        prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
        residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;

        if (residue < 0.0) residue += prime_double;
        n_double += 1.0;

        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1.0){
                report_factor(n, -1, P[i]);
            }
            if(residue == prime_double - 1.0){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

ここで私は魔法の定数を使用しました

static const double rounding_constant = 6755399441055744.0; 

つまり、倍精度の 2^51 + 2^52 マジック ナンバーです。

これを AVX512 (ループごとに 32 の潜在的な除数) に変換し、IACA を使用して結果を分析しました。スループットのボトルネック: 割り当てリソースが利用できないため、バックエンドとバックエンドの割り当てが停滞していることがわかりました。私はアセンブラーの経験があまりないので、質問は、これを高速化し、このバックエンドのボトルネックを解決するためにできることはありますか?

AVX512 コードはここにあり、 githubからも見つけることができます

uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for a factorial numbers : n! +-1
//nmin is minimum n we want to report and nmax is maximum. P is table of primes
// we process 32 primes in one loop.
// naive version of the algorithm is int he function factorial_naive
// and simple version of the FMA based approach in the function factorial_simpleFMA

const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};


uint64_t n;

__m512d zero, rounding_const, one, n_double;

__m512i prime1, prime2, prime3, prime4;

__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;

__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;

// load data and initialize cariables for loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();

// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);

// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd

// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);

// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);

// residue init
residue1 =  _mm512_set1_pd(2.0);
residue2 =  _mm512_set1_pd(2.0);
residue3 =  _mm512_set1_pd(2.0);
residue4 =  _mm512_set1_pd(2.0);

// double counter init
n_double = _mm512_set1_pd(3.0);

// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000

for (n=3; n<=nmax; n++) // main loop
{

    // timings for instructions:
    // _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
    // _mm512_load_pd = vmovapd : L 1, T 0.5
    // _mm512_set1_pd
    // _mm512_div_pd = vdivpd : L 23, T 16
    // _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0,5

    // _mm512_mul_pd = vmulpd :  L 4, T 0.5
    // _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd :  L 4, T 0.5
    // _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
    // _mm512_sub_pd = vsubpd : L 4, T 0.5
    // _mm512_cmplt_pd_mask = vcmppd : L ?, Y 1
    // _mm512_mask_add_pd = vaddpd :  L 4, T 0.5
    // _mm512_cmpeq_pd_mask = vcmppd L ?, Y 1
    // _mm512_kor = korw L 1, T 1

    // nr = residue *  n
    nr1 = _mm512_mul_pd (residue1, n_double);
    nr2 = _mm512_mul_pd (residue2, n_double);
    nr3 = _mm512_mul_pd (residue3, n_double);
    nr4 = _mm512_mul_pd (residue4, n_double);

    // quotient = nr * 1.0/ prime_double + rounding_constant
    quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
    quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
    quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
    quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);

    // quotient -= rounding_constant, now quotient is rounded to integer
    // countient should be at maximum nmax (10,000,000)
    quotient1 = _mm512_sub_pd(quotient1, rounding_const);
    quotient2 = _mm512_sub_pd(quotient2, rounding_const);
    quotient3 = _mm512_sub_pd(quotient3, rounding_const);
    quotient4 = _mm512_sub_pd(quotient4, rounding_const);

    // now we calculate high and low for prime * quotient using decker product (FMA).
    // quotient is calculated using approximation but this is accurate for given quotient
    prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
    prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
    prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
    prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);


    prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
    prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
    prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
    prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);

    // now we calculate new reminder using decker product and using original values
    // we subtract above calculated prime * quotient (quotient is aproximation)

    residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
    residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
    residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
    residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);

    residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
    residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
    residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
    residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);

    // lets check if reminder < 0
    negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
    negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
    negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
    negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);

    // we and prime back to reminder using mask if it was < 0
    residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
    residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
    residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
    residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);

    n_double = _mm512_add_pd(n_double,one);

    // if we are below nmin then we continue next iteration
    if (n < nmin) continue;

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
    { // we find factor very rarely

        double *residual_list1 = (double *) &residue1;
        double *residual_list2 = (double *) &residue2;
        double *residual_list3 = (double *) &residue3;
        double *residual_list4 = (double *) &residue4;

        double *prime_list1 = (double *) &prime_double1;
        double *prime_list2 = (double *) &prime_double2;
        double *prime_list3 = (double *) &prime_double3;
        double *prime_list4 = (double *) &prime_double4;



        for (int i=0; i <8; i++){
            if( residual_list1[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
            }
            if( residual_list2[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
            }
            if( residual_list3[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
            }
            if( residual_list4[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
            }

            if(residual_list1[i] == (prime_list1[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
            }
            if(residual_list2[i] == (prime_list2[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
            }
            if(residual_list3[i] == (prime_list3[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
            }
            if(residual_list4[i] == (prime_list4[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
            }
        }
    }

}

return EXIT_SUCCESS;
}
4

1 に答える 1

1

いくつかのコメンターが示唆しているように、「バックエンド」のボトルネックは、このコードに期待されるものです。これは、あなたが物事を十分に食べさせていることを示しています。これはあなたが望んでいることです.

レポートを見ると、このセクションに機会があるはずです。

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)

IACA 分析から:

|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r11d, k0
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw eax, k1
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw ecx, k2
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw esi, k3
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw edi, k4
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r8d, k5
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r9d, k6
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r10d, k7
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, eax
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, ecx
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, esi
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, edi
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, r8d
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, r9d
|   1*     |             |      |             |             |      |      |      |      | or r11d, r10d

プロセッサは、結果の比較マスク (k0 ~ k7) を「or」演算のために通常のレジスタに移動しています。これらの動きを排除し、6ops 対 8 で「or」ロールアップを実行できるはずです。

注: found_factor_mask 型は として定義されてい__mmask8ます__mask16。これにより、コンパイラはいくつかの最適化を行うことができます。そうでない場合は、コメント投稿者が指摘したように、アセンブリにドロップしてください。

関連: この or-mask 句を起動する反復の割合は? 別のコメンターが観察したように、これは累積的な「または」操作で展開できるはずです。展開された各反復の最後 (または N 回の反復後) に蓄積された「または」値を確認し、それが「真」の場合は、値を戻して値を再実行し、どの n 値がそれをトリガーしたかを判断します。

(そして、「ロール」内でバイナリ検索を実行して、一致する n 値を見つけることができます。これにより、ある程度の利益が得られる可能性があります)。

次に、この中間ループ チェックを取り除くことができるはずです。

    // if we are below nmin then we continue next iteration, we
    if (n < nmin) continue;

ここに表示されるもの:

|   1*     |             |      |             |             |      |      |      |      | cmp r14, 0x3e8
|   0*F    |             |      |             |             |      |      |      |      | jb 0x229

予測子が (おそらく) これを (ほとんど) 正しく取得するため、大きな利益にはならないかもしれませんが、2 つの「フェーズ」に対して 2 つの異なるループを持つことで、ある程度の利益が得られるはずです。

  • n=3 ~ n=nmin-1
  • n=nmin以上

サイクルが増えたとしても、それは 3% です。そして、それは一般に、上記の大きな「or」演算に関連しているため、そこにはさらに巧妙さが見られる可能性があります。

于 2018-09-24T22:22:32.110 に答える