4

これを達成するためのより効率的な方法はありますか:Aサイズの配列nと2つの正の整数aとが与えられた場合、すべてのペアで取られbた合計を求めます。ここで。floor(abs(A[i]-A[j])*a/b)(i, j)0 <= i < j < n

int A[n];
int a, b; // assigned some positive integer values
...
int total = 0;
for (int i = 0; i < n; i++) {
    for (int j = i+1; j < n; j++) {
        total += abs(A[i]-A[j])*a/b; // want integer division here
    }
}

これを少し最適化するために、配列()を並べ替えてから、関数O(nlogn)を使用しませんでした。absまた、内側のforループの前に値をキャッシュしたので、順番a[i]にデータを読み取ることができました。A私はそれを事前に計算a/bしてフロートに保存することを検討していましたが、余分なキャストはそれを遅くします(特に結果の床を取りたいので)。

私はより良い解決策を思い付くことができませんでしたO(n^2)

4

1 に答える 1

2

はい、より効率的なアルゴリズムがあります。O(n * log n)で実行できます。漸近的に高速な方法があるとは思いませんが、証明のアイデアにはほど遠いです。

アルゴリズム

まず、O(n * log n)時間で配列を並べ替えます。

それでは、用語を見てみましょう

floor((A[j]-A[i])*a/b) = floor ((A[j]*a - A[i]*a)/b)

のために0 <= i < j < n。それぞれについて0 <= k < n、で記述A[k]*a = q[k]*b + r[k]0 <= r[k] < bます。

の場合A[k] >= 0、整数除算を使用q[k] = (A[k]*a)/bします。の場合、除算を使用します。除算を使用しない場合は、とを使用します。r[k] = (A[k]*a)%bA[k] < 0q[k] = (A[k]*a)/b - 1r[k] = b + (A[k]*a)%bbA[k]*aq[k] = (A[k]*a)/br[k] = 0

次に、用語を書き直します。

floor((A[j]*a - A[i]*a)/b) = floor(q[j] - q[i] + (r[j] - r[i])/b)
                           = q[j] - q[i] + floor((r[j] - r[i])/b)

それぞれq[k]k正の符号(の場合i = 0, 1, .. , k-1)とn-1-k負の符号(の場合j = k+1, k+2, ..., n-1)で表示されるため、合計への寄与の合計は次のようになります。

(k - (n-1-k))*q[k] = (2*k+1-n)*q[k]

残りはまだ説明する必要があります。さて、以来0 <= r[k] < b

-b < r[j] - r[i] < b

。の場合はfloor((r[j]-r[i])/b)0です。それでr[j] >= r[i]-1r[j] < r[i]

                            n-1
 ∑ floor((A[j]-A[i])*a/b) =  ∑ (2*k+1-n)*q[k] - inversions(r)
i<j                         k=0

ここで、反転はと(i,j)のインデックスのペアです。0 <= i < j < nr[j] < r[i]

andの計算q[k]r[k]合計は(2*k+1-n)*q[k]、O(n)時間で行われます。

r[k]配列の反転を効率的にカウントすることは残っています。

各インデックスについて、次のような数0 <= k < n、つまり、大きい方のインデックスとして表示される反転の数とします。c(k)i < kr[k] < r[i]k

その場合、明らかに反転の数はです∑ c(k)

一方、安定ソートでc(k)後ろに移動する要素の数です(ここでは安定性が重要です)。r[k]

これらの動きを数えることで、配列の反転はマージソート中に簡単に実行できます。

したがって、反転はO(n * log n)でもカウントでき、O(n * log n)の全体的な複雑さを示します。

コード

単純な非科学的ベンチマークを使用したサンプル実装(ただし、単純な2次アルゴリズムと上記のアルゴリズムの違いは非常に大きいため、非科学的ベンチマークで十分に決定的です)。

#include <stdlib.h>
#include <stdio.h>
#include <time.h>

long long mergesort(int *arr, unsigned elems);
long long merge(int *arr, unsigned elems, int *scratch);
long long nosort(int *arr, unsigned elems, long long a, long long b);
long long withsort(int *arr, unsigned elems, long long a, long long b);

int main(int argc, char *argv[]) {
    unsigned count = (argc > 1) ? strtoul(argv[1],NULL,0) : 1000;
    srand(time(NULL)+count);
    long long a, b;
    b = 1000 + 9000.0*rand()/(RAND_MAX+1.0);
    a = b/3 + (b-b/3)*1.0*rand()/(RAND_MAX + 1.0);
    int *arr1, *arr2;
    arr1 = malloc(count*sizeof *arr1);
    arr2 = malloc(count*sizeof *arr2);
    if (!arr1 || !arr2) {
        fprintf(stderr,"Allocation failed\n");
        exit(EXIT_FAILURE);
    }
    unsigned i;
    for(i = 0; i < count; ++i) {
        arr1[i] = 20000.0*rand()/(RAND_MAX + 1.0) - 2000;
    }
    for(i = 0; i < count; ++i) {
        arr2[i] = arr1[i];
    }
    long long res1, res2;
    double start = clock();
    res1 = nosort(arr1,count,a,b);
    double stop = clock();
    printf("Naive:   %lld in %.3fs\n",res1,(stop-start)/CLOCKS_PER_SEC);
    start = clock();
    res2 = withsort(arr2,count,a,b);
    stop = clock();
    printf("Sorting: %lld in %.3fs\n",res2,(stop-start)/CLOCKS_PER_SEC);
    return EXIT_SUCCESS;
}

long long nosort(int *arr, unsigned elems, long long a, long long b) {
    long long total = 0;
    unsigned i, j;
    long long m;
    for(i = 0; i < elems-1; ++i) {
        m = arr[i];
        for(j = i+1; j < elems; ++j) {
            long long d = (arr[j] < m) ? (m-arr[j]) : (arr[j]-m);
            total += (d*a)/b;
        }
    }
    return total;
}

long long withsort(int *arr, unsigned elems, long long a, long long b) {
    long long total = 0;
    unsigned i;
    mergesort(arr,elems);
    for(i = 0; i < elems; ++i) {
        long long q, r;
        q = (arr[i]*a)/b;
        r = (arr[i]*a)%b;
        if (r < 0) {
            r += b;
            q -= 1;
        }
        total += (2*i+1LL-elems)*q;
        arr[i] = (int)r;
    }
    total -= mergesort(arr,elems);
    return total;
}

long long mergesort(int *arr, unsigned elems) {
    if (elems < 2) return 0;
    int *scratch = malloc((elems + 1)/2*sizeof *scratch);
    if (!scratch) {
        fprintf(stderr,"Alloc failure\n");
        exit(EXIT_FAILURE);
    }
    return merge(arr, elems, scratch);
}

long long merge(int *arr, unsigned elems, int *scratch) {
    if (elems < 2) return 0;
    unsigned left = (elems + 1)/2, right = elems-left, i, j, k;
    long long inversions = 0;
    inversions += merge(arr, left, scratch);
    inversions += merge(arr+left,right,scratch);
    if (arr[left] < arr[left-1]) {
        for(i = 0; i < left; ++i) {
            scratch[i] = arr[i];
        }
        i = 0; j = 0; k = 0;
        int *lptr = scratch, *rptr = arr+left;
        while(i < left && j < right) {
            if (rptr[j] < lptr[i]) {
                arr[k++] = rptr[j++];
                inversions += (left-i);
            } else {
                arr[k++] = lptr[i++];
            }
        }
        while(i < left) arr[k++] = lptr[i++];
    }
    return inversions;
}
于 2012-04-27T12:52:54.083 に答える