49

この質問は、新しい R パッケージにつながりました。 wrswoR

置換を使用しない R のデフォルトのサンプリングsample.intは、二次実行時間を必要とするようです。たとえば、一様分布から引き出された重みを使用する場合です。サンプルサイズが大きい場合、これは遅くなります。R 内から使用できるより高速な実装を知っている人はいますか? 2 つのオプションは、「置換による拒否サンプリング」( stats.sx のこの質問を参照) と、 Wong と Easton (1980)によるアルゴリズム ( StackOverflowの回答で Python 実装を使用) です。

sample.intが呼び出されたときに内部的に呼び出される C 関数をヒントしてくれた Ben Bolker に感謝しreplace=FますProbSampleNoReplace。実際、コードは 2 つの入れ子になったforループを示しています (の 420 行目以降random.c)。

実行時間を経験的に分析するコードは次のとおりです。

library(plyr)

sample.int.test <- function(n, p) {
    sample.int(2 * n, n, replace=F, prob=p); NULL }

times <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times

library(ggplot2)
ggplot(times, aes(x=n, y=user/n)) + geom_point() + scale_x_log10() +
  ylab('Time per unit (s)')

# Output:
       n   user
1   2048  0.008
2   4096  0.028
3   8192  0.100
4  16384  0.408
5  32768  1.645
6  65536  6.604
7 131072 26.558

プロット

編集:重み付けされていないサンプリングにはこのパフォーマンスのペナルティがないようであることを指摘してくれたArunに感謝します。

4

3 に答える 3

24

アップデート:

Efraimidis および SpirakisアルゴリズムのRcpp実装(@Hemmo、@Dinrem、@krlmlr、および@rtlgrmpfに感謝):

library(inline)
library(Rcpp)
src <- 
'
int num = as<int>(size), x = as<int>(n);
Rcpp::NumericVector vx = Rcpp::clone<Rcpp::NumericVector>(x);
Rcpp::NumericVector pr = Rcpp::clone<Rcpp::NumericVector>(prob);
Rcpp::NumericVector rnd = rexp(x) / pr;
for(int i= 0; i<vx.size(); ++i) vx[i] = i;
std::partial_sort(vx.begin(), vx.begin() + num, vx.end(), Comp(rnd));
vx = vx[seq(0, num - 1)] + 1;
return vx;
'
incl <- 
'
struct Comp{
  Comp(const Rcpp::NumericVector& v ) : _v(v) {}
  bool operator ()(int a, int b) { return _v[a] < _v[b]; }
  const Rcpp::NumericVector& _v;
};
'
funFast <- cxxfunction(signature(n = "Numeric", size = "integer", prob = "numeric"),
                       src, plugin = "Rcpp", include = incl)

# See the bottom of the answer for comparison
p <- c(995/1000, rep(1/1000, 5))
n <- 100000
system.time(print(table(replicate(funFast(6, 3, p), n = n)) / n))

      1       2       3       4       5       6 
1.00000 0.39996 0.39969 0.39973 0.40180 0.39882 
   user  system elapsed 
   3.93    0.00    3.96 
# In case of:
# Rcpp::IntegerVector vx = Rcpp::clone<Rcpp::IntegerVector>(x);
# i.e. instead of NumericVector
      1       2       3       4       5       6 
1.00000 0.40150 0.39888 0.39925 0.40057 0.39980 
   user  system elapsed 
   1.93    0.00    2.03 

古いバージョン:

いくつかの可能なアプローチを試してみましょう。

置換による簡単な不合格サンプリング。これは、@krlmlr が提供するよりもはるかに単純な関数です。sample.int.rejつまり、サンプル サイズは常に に等しくなりnます。これから見ていくように、重みの分布が均一であると仮定すると、それでも非常に高速ですが、別の状況では非常に遅くなります。

fastSampleReject <- function(all, n, w){
  out <- numeric(0)
  while(length(out) < n)
    out <- unique(c(out, sample(all, n, replace = TRUE, prob = w)))
  out[1:n]
}

Wong と Easton (1980) によるアルゴリズムこれは、このPython バージョンの実装です。安定しており、何かが欠けている可能性がありますが、他の関数に比べてはるかに遅いです。

fastSample1980 <- function(all, n, w){
  tws <- w
  for(i in (length(tws) - 1):0)
    tws[1 + i] <- sum(tws[1 + i], tws[1 + 2 * i + 1], 
                      tws[1 + 2 * i + 2], na.rm = TRUE)      
  out <- numeric(n)
  for(i in 1:n){
    gas <- tws[1] * runif(1)
    k <- 0        
    while(gas > w[1 + k]){
      gas <- gas - w[1 + k]
      k <- 2 * k + 1
      if(gas > tws[1 + k]){
        gas <- gas - tws[1 + k]
        k <- k + 1
      }
    }
    wgh <- w[1 + k]
    out[i] <- all[1 + k]        
    w[1 + k] <- 0
    while(1 + k >= 1){
      tws[1 + k] <- tws[1 + k] - wgh
      k <- floor((k - 1) / 2)
    }
  }
  out
}

Wong と Easton によるアルゴリズムの Rcpp 実装。これは私の最初の使用可能な関数であるため、さらに最適化できる可能性がありますRcppが、とにかくうまく機能します。

library(inline)
library(Rcpp)

src <-
'
Rcpp::NumericVector weights = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector tws = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector x = Rcpp::NumericVector(all);
int k, num = as<int>(n);
Rcpp::NumericVector out(num);
double gas, wgh;

if((weights.size() - 1) % 2 == 0){
  tws[((weights.size()-1)/2)] += tws[weights.size()-1] + tws[weights.size()-2];
}
else
{
  tws[floor((weights.size() - 1)/2)] += tws[weights.size() - 1];
}

for (int i = (floor((weights.size() - 1)/2) - 1); i >= 0; i--){
  tws[i] += (tws[2 * i + 1]) + (tws[2 * i + 2]);
}
for(int i = 0; i < num; i++){
  gas = as<double>(runif(1)) * tws[0];
  k = 0;
  while(gas > weights[k]){
    gas -= weights[k];
    k = 2 * k + 1;
    if(gas > tws[k]){
      gas -= tws[k];
      k += 1;
    }
  }
  wgh = weights[k];
  out[i] = x[k];
  weights[k] = 0;
  while(k > 0){
    tws[k] -= wgh;
    k = floor((k - 1) / 2);
  }
  tws[0] -= wgh;
}
return out;
'

fun <- cxxfunction(signature(all = "numeric", n = "integer", w = "numeric"),
                   src, plugin = "Rcpp")

今いくつかの結果:

times1 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n) # Uniform distribution
    p <- p/sum(p)
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)


times2 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n - 1)
    p <- p/sum(p) 
    p <- c(0.999, 0.001 * p) # Special case
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)

ここに画像の説明を入力

ここに画像の説明を入力

arrange(times1, id)
       n  user               id
1   2048  0.53             1980
2   4096  0.94             1980
3   8192  2.00             1980
4  16384  4.32             1980
5  32768  9.10             1980
6  65536 21.32             1980
7   2048  0.02             Base
8   4096  0.05             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  2.99             Base
12 65536 12.23             Base
13  2048  0.00             Rcpp
14  4096  0.01             Rcpp
15  8192  0.03             Rcpp
16 16384  0.07             Rcpp
17 32768  0.14             Rcpp
18 65536  0.31             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.00        Rejection
22 16384  0.02        Rejection
23 32768  0.02        Rejection
24 65536  0.03        Rejection
25  2048  0.00 Rejection simple
26  4096  0.01 Rejection simple
27  8192  0.00 Rejection simple
28 16384  0.01 Rejection simple
29 32768  0.00 Rejection simple
30 65536  0.05 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.00        Reservoir
34 16384  0.02        Reservoir
35 32768  0.03        Reservoir
36 65536  0.05        Reservoir

arrange(times2, id)
       n  user               id
1   2048  0.43             1980
2   4096  0.93             1980
3   8192  2.00             1980
4  16384  4.36             1980
5  32768  9.08             1980
6  65536 19.34             1980
7   2048  0.01             Base
8   4096  0.04             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  3.11             Base
12 65536 12.04             Base
13  2048  0.01             Rcpp
14  4096  0.02             Rcpp
15  8192  0.03             Rcpp
16 16384  0.08             Rcpp
17 32768  0.15             Rcpp
18 65536  0.33             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.02        Rejection
22 16384  0.02        Rejection
23 32768  0.05        Rejection
24 65536  0.08        Rejection
25  2048  1.43 Rejection simple
26  4096  2.87 Rejection simple
27  8192  6.17 Rejection simple
28 16384 13.68 Rejection simple
29 32768 29.74 Rejection simple
30 65536 73.32 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.02        Reservoir
34 16384  0.02        Reservoir
35 32768  0.02        Reservoir
36 65536  0.04        Reservoir

1980どちらの場合よりも遅いため、明らかに関数を拒否できますBaseRejection simple2 番目のケースで 0.999 の確率が 1 つしかない場合にも問題が発生します。

だから残ってRejectionいるRcpp、、Reservoir。最後のステップは、値自体が正しいかどうかを確認することです。それらについて確実にするために、ベンチマークとして使用します(また、置換なしのサンプリングのためsampleに一致する必要がない確率に関する混乱を排除するため)。p

p <- c(995/1000, rep(1/1000, 5))
n <- 100000

system.time(print(table(replicate(sample(1:6, 3, repl = FALSE, prob = p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39992 0.39886 0.40088 0.39711 0.40323  # Benchmark
   user  system elapsed 
   1.90    0.00    2.03 

system.time(print(table(replicate(sample.int.rej(2*3, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.40007 0.40099 0.39962 0.40153 0.39779 
   user  system elapsed 
  76.02    0.03   77.49 # Slow

system.time(print(table(replicate(weighted_Random_Sample(1:6, p, 3), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.49535 0.41484 0.36432 0.36338 0.36211  # Incorrect
   user  system elapsed 
   3.64    0.01    3.67 

system.time(print(table(replicate(fun(1:6, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39876 0.40031 0.40219 0.40039 0.39835 
   user  system elapsed 
   4.41    0.02    4.47 

ここでいくつかのことに注意してください。何らかの理由weighted_Random_Sampleで間違った値が返されます (まったく調べていませんが、一様分布を仮定すると正しく動作します)。sample.int.rej繰り返しサンプリングでは非常に遅いです。

Rcpp結論として、繰り返しサンプリングの場合は最適な選択であると思われますが、sample.int.rejそれ以外の場合は少し速く、使いやすいです。

于 2013-03-04T20:45:20.987 に答える
20

いくつかのコメントを掘り下げてみることにしましたが、Efraimidis と Spirakisの論文が魅力的であることがわかりました (参考文献を見つけてくれた @Hemmo に感謝します)。この論文の一般的な考え方は次のとおりです。一様乱数を生成し、それを各アイテムの重みで 1 乗してキーを作成します。次に、最も高いキー値をサンプルとして取得します。これは見事にうまくいきます!

weighted_Random_Sample <- function(
    .data,
    .weights,
    .n
    ){

    key <- runif(length(.data)) ^ (1 / .weights)
    return(.data[order(key, decreasing=TRUE)][1:.n])
}

'.n' を '.data' の長さ (常に '.weights' の長さである必要があります) に設定すると、これは実際には加重リザーバー順列になりますが、この方法はサンプリングと順列の両方でうまく機能します。

更新:上記の関数は、重みがゼロより大きいことを期待していることに言及する必要があります。そうしkey <- runif(length(.data)) ^ (1 / .weights)ないと、正しく注文できません。


念のため、OP のテスト シナリオを使用して両方の機能を比較しました。

set.seed(1)

times_WRS <- ldply(
1:7,
function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    n_Set <- 1:(2 * n)
    data.frame(
      n=n,
      user=system.time(weighted_Random_Sample(n_Set, p, n), gcFirst=T)['user.self'])
  },
  .progress='text'
)

sample.int.test <- function(n, p) {
sample.int(2 * n, n, replace=F, prob=p); NULL }

times_sample.int <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times_WRS$group <- "WRS"
times_sample.int$group <- "sample.int"
library(ggplot2)

ggplot(rbind(times_WRS, times_sample.int) , aes(x=n, y=user/n, col=group)) + geom_point() + scale_x_log10() +  ylab('Time per unit (s)')

そして、ここに時間があります:

times_WRS
#        n user
# 1   2048 0.00
# 2   4096 0.01
# 3   8192 0.00
# 4  16384 0.01
# 5  32768 0.03
# 6  65536 0.06
# 7 131072 0.16

times_sample.int
#        n  user
# 1   2048  0.02
# 2   4096  0.05
# 3   8192  0.14
# 4  16384  0.58
# 5  32768  2.33
# 6  65536  9.23
# 7 131072 37.79

性能比較

于 2013-03-04T15:30:55.250 に答える
3

replacement を使用した拒否サンプリングに基づく、より高速なアプローチの独自の実装を投入しましょう。アイデアは次のとおりです。

  • 要求されたサイズより「やや」大きい置換でサンプルを生成します

  • 重複した値を捨てる

  • 十分な値が描画されていない場合は、調整されたnsizeおよびprobパラメーターを使用して同じプロシージャを再帰的に呼び出します

  • 返されたインデックスを元のインデックスに再マップします

どのくらいの大きさのサンプルを描く必要がありますか? 一様分布を仮定すると、結果はN 個の合計値のうち x 個の一意の値を確認するための試行回数の期待値になります。2 つの調和数(H_n と H_{n - size}) の差です。最初のいくつかの高調波数が表にされます。それ以外の場合は、自然対数を使用した近似が使用されます。(これは大まかな数字であり、ここであまり正確である必要はありません。) ここで、不均一な分布の場合、描画されるアイテムの予想数はより多くなる可能性があるため、あまり多くのサンプルを描画することはありません。さらに、抽出されるサンプルの数は母集団のサイズの 2 倍に制限されています。O(n ln n) 個のアイテムをサンプリングするよりも、いくつかの再帰呼び出しを行う方が速いと思います。

このコードは、R パッケージwrswoRsample.int.rejルーチンで利用できますsample_int_rej.R。インストール:

library(devtools)
install_github('wrswoR', 'muelleki')

「十分に高速」に動作するようですが、正式な実行時テストはまだ実行されていません。また、パッケージは Ubuntu でのみテストされています。フィードバックをお待ちしております。

于 2013-03-04T15:02:43.920 に答える