5

大きな正方行列から非対角要素を削除する必要があるアプリケーションで、小さなパフォーマンスのボトルネックがあります。したがって、マトリックスx

17    24     1     8    15
23     5     7    14    16
 4     6    13    20    22
10    12    19    21     3
11    18    25     2     9

になる

17     0     0     0     0
 0     5     0     0     0
 0     0    13     0     0
 0     0     0    21     0
 0     0     0     0     9

質問:以下の bsxfun と diag のソリューションは、これまでのところ最速のソリューションであり、コードを Matlab に保持したまま改善できるとは思えませんが、より高速な方法はありますか?

ソリューション

ここまでで考えたこと。

単位行列で要素ごとの乗算を実行します。これが最も簡単な解決策です。

y = x .* eye(n);

bsxfunと の使用diag:

y = bsxfun(@times, diag(x), eye(n));

下/上三角行列:

y = x - tril(x, -1) - triu(x, 1);

ループを使用したさまざまなソリューション:

y = x;
for ix=1:n
    for jx=1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end

y = x;
for ix=1:n
    for jx=1:ix-1
        y(ix, jx) = 0;
    end
    for jx=ix+1:n
        y(ix, jx) = 0;
    end
end

タイミング

ソリューションは実際にはbsxfun最速です。これは私のタイミングコードです:

function timing()
clear all

n = 5000;
x = rand(n, n);

f1 = @() tf1(x, n);
f2 = @() tf2(x, n);
f3 = @() tf3(x);
f4 = @() tf4(x, n);
f5 = @() tf5(x, n);

t1 = timeit(f1);
t2 = timeit(f2);
t3 = timeit(f3);
t4 = timeit(f4);
t5 = timeit(f5);

fprintf('t1: %f s\n', t1)
fprintf('t2: %f s\n', t2)
fprintf('t3: %f s\n', t3)
fprintf('t4: %f s\n', t4)
fprintf('t5: %f s\n', t5)
end

function y = tf1(x, n)
y = x .* eye(n);
end


function y = tf2(x, n)
y = bsxfun(@times, diag(x), eye(n));
end


function y = tf3(x)
y = x - tril(x, -1) - triu(x, 1);
end


function y = tf4(x, n)
y = x;
for ix=1:n
    for jx=1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end
end


function y = tf5(x, n)
y = x;
for ix=1:n
    for jx=1:ix-1
        y(ix, jx) = 0;
    end
    for jx=ix+1:n
        y(ix, jx) = 0;
    end
end
end

返す

t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
4

2 に答える 2

9

見つけた:

diag(diag(x))

よりも高速ですbsxfun。同様に:

diag(x(1:size(x,1)+1:end))

ほぼ同じ量だけ高速です。timeitforで遊んで、私はあなたよりも〜20倍x=rand(5000)速くなりました。bsxfun

編集:

これは以下と同等diag(diag(...です:

x2(n,n)=0;
x2(1:n+1:end)=x(1:n+1:end);

事前に割り当てる方法x2が重要であることに注意してください。使用するだけではx2=zeros(n)、ソリューションが遅くなります。詳細については、このディスカッションを参照してください...

于 2013-10-04T17:56:18.533 に答える
8

さまざまなループ関数は、実装がはるかに遅いため、わざわざテストしませんでしたが、他のものと、以前に使用した別の方法をテストしました。

y = diag(diag(x));

スポイラーは次のとおりです。

c1: 193.18 milliseconds  // multiply by identity
c2: 102.16 milliseconds  // bsxfun
c3: 342.24 milliseconds  // tril and triu
c4:   6.03 milliseconds  // call diag twice

私のマシンでは、 2回の呼び出しがdiag断然最速のようです。

完全なタイミング コードは次のとおりです。ではなく、独自のベンチマーク機能を使用timeitしましたが、結果は同等である必要があります(自分で確認できます)。

>> x = randn(5000);

>> c1 = @() x .* eye(5000);
>> c2 = @() bsxfun(@times, diag(x), eye(5000));
>> c3 = @() x - tril(x,-1) - triu(x,1);
>> c4 = @() diag(diag(x));


>> benchmark.bench(c1)

Benchmarking @()x.*eye(5000)
   Mean: 193.18 milliseconds, lb 191.94 milliseconds, ub 194.25 milliseconds, ci 95%
  Stdev: 6.01 milliseconds, lb 3.27 milliseconds, ub 8.58 milliseconds, ci 95%

>> benchmark.bench(c2)

Benchmarking @()bsxfun(@times,diag(x),eye(5000))
   Mean: 102.16 milliseconds, lb 100.83 milliseconds, ub 103.44 milliseconds, ci 95%
  Stdev: 6.61 milliseconds, lb 6.04 milliseconds, ub 7.07 milliseconds, ci 95%

>> benchmark.bench(c3)

Benchmarking @()x-tril(x,-1)-triu(x,1)
   Mean: 342.24 milliseconds, lb 340.28 milliseconds, ub 344.20 milliseconds, ci 95%
  Stdev: 10.06 milliseconds, lb 8.85 milliseconds, ub 11.17 milliseconds, ci 95%

>> benchmark.bench(c4)

Benchmarking @()diag(diag(x))
   Mean: 6.03 milliseconds, lb 5.96 milliseconds, ub 6.09 milliseconds, ci 95%
  Stdev: 0.34 milliseconds, lb 0.27 milliseconds, ub 0.40 milliseconds, ci 95%
于 2013-10-04T17:57:32.430 に答える