10

Scala の非末尾再帰関数は末尾再帰関数よりも遅いと予想していたことを友人に説明していたので、検証することにしました。古き良き階乗関数を両方の方法で作成し、結果を比較しようとしました。コードは次のとおりです。

def main(args: Array[String]): Unit = {
  val N = 2000 // not too much or else stackoverflows
  var spent1: Long = 0
  var spent2: Long = 0
  for ( i <- 1 to 100 ) { // repeat to average the results
    val t0 = System.nanoTime
    factorial(N)
    val t1 = System.nanoTime
    tailRecFact(N)
    val t2 = System.nanoTime
    spent1 += t1 - t0
    spent2 += t2 - t1
  }
  println(spent1/1000000f) // get milliseconds
  println(spent2/1000000f)
}

@tailrec
def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)

結果は私を混乱させます、私はこの種の出力を得ます:

578.2985

870.22125

つまり、非末尾再帰関数は末尾再帰関数よりも 30% 高速であり、操作の数は同じです!

それらの結果を説明するものは何ですか?

4

2 に答える 2

10

実際には、最初に見た場所ではありません。その理由は、末尾再帰メソッドにあり、その乗算でより多くの作業を行っています。再帰呼び出しでパラメーター n と s の順序を入れ替えてみると、均等になります。

def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)

さらに、このサンプルのほとんどの時間は、再帰呼び出しの時間を小さくする BigInt 操作に費やされています。これらを Ints (Java プリミティブにコンパイル) に切り替えると、末尾再帰 (goto) がメソッド呼び出しとどのように比較されるかがわかります。

object Test extends App {

  val N = 2000

  val t0 = System.nanoTime()
  for ( i <- 1 to 1000 ) {
    factorial(N)
  }
  val t1 = System.nanoTime
  for ( i <- 1 to 1000 ) {
    tailRecFact(N, 1)
  }
  val t2 = System.nanoTime

  println((t1 - t0) / 1000000f) // get milliseconds
  println((t2 - t1) / 1000000f)

  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}

95.16733
3.987605

興味深いことに、逆コンパイルされた出力

  public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          13
      11: aload_2       
      12: areturn       
      13: aload_1       
      14: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      17: iconst_1      
      18: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      21: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      24: aload_1       
      25: aload_2       
      26: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      29: astore_2      
      30: astore_1      
      31: goto          0

  public scala.math.BigInt factorial(scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          21
      11: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      14: iconst_1      
      15: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      18: goto          40
      21: aload_1       
      22: aload_0       
      23: aload_1       
      24: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      27: iconst_1      
      28: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      31: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      34: invokevirtual #47                 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
      37: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      40: areturn   
于 2013-10-09T09:40:52.197 に答える
9

@monkjack によって示された問題 (つまり、small * big の乗算は big * small よりも高速であり、違いの大きな部分を占めています) に加えて、アルゴリズムはケースごとに異なるため、実際には比較できません。

末尾再帰バージョンでは、大きなものから小さなものへと乗算しています:

n * n-1 * n-2 * ... * 2 * 1

非末尾再帰バージョンでは、小さいものから大きいものへと乗算します:

n * (n-1 * (n-2 * (... * (2 * 1))))

末尾再帰バージョンを変更して、小さいものから大きいものへ乗算する場合:

def tailRecFact2(n: BigInt) = {
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}

この場合、末尾再帰は通常の再帰よりも約 20% 速くなります。モンクジャックの修正を行った場合のように 10% 遅くなるのではありません。これは、大きな BigInt を掛け合わせるよりも、小さな BigInt を掛け合わせる方が速いためです。

于 2013-10-09T12:34:55.477 に答える