不変のアルゴリズム
Taylor Leeseによって示される最初のアルゴリズム は二次ですが、線形平均を持っています。ただし、それはピボットの選択によって異なります。そこで、ここでは、プラグ可能なピボット選択と、ランダムピボットと中央値ピボットの中央値(線形時間を保証する)の両方を備えたバージョンを提供します。
import scala.annotation.tailrec
@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partition (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty) {
val (s, b) = arr partition (a ==)
if (s.size > k) a
else findKMedian(b, k - s.size)
} else if (s.size < k) findKMedian(b, k - s.size)
else findKMedian(s, k)
}
def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
ランダムピボット(二次、線形平均)、不変
これはランダムなピボット選択です。ランダムな要因を使用したアルゴリズムの分析は、主に確率と統計を扱うため、通常よりも注意が必要です。
def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
中央値の中央値(線形)、不変
上記のアルゴリズムで使用した場合に線形時間を保証する中央値の中央値法。まず、中央値アルゴリズムの中央値の基礎となる、最大5つの数値の中央値を計算するアルゴリズム。これは、この回答でRexKerrによって提供されました。アルゴリズムは速度に大きく依存します。
def medianUpTo5(five: Array[Double]): Double = {
def order2(a: Array[Double], i: Int, j: Int) = {
if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
}
def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
if (a(i)<a(k)) { order2(a,j,k); a(j) }
else { order2(a,i,l); a(i) }
}
if (five.length < 2) return five(0)
order2(five,0,1)
if (five.length < 4) return (
if (five.length==2 || five(2) < five(0)) five(0)
else if (five(2) > five(1)) five(1)
else five(2)
)
order2(five,2,3)
if (five.length < 5) pairs(five,0,1,2,3)
else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
else { order2(five,3,4); pairs(five,0,1,3,4) }
}
そして、中央値アルゴリズム自体の中央値。基本的に、選択されたピボットがリストの他の30%よりも大きく、小さくなることが保証されます。これは、前のアルゴリズムの線形性を保証するのに十分です。詳細については、別の回答で提供されているウィキペディアのリンクを参照してください。
def medianOfMedians(arr: Array[Double]): Double = {
val medians = arr grouped 5 map medianUpTo5 toArray;
if (medians.size <= 5) medianUpTo5 (medians)
else medianOfMedians(medians)
}
インプレースアルゴリズム
それで、これがアルゴリズムのインプレースバージョンです。アルゴリズムへの変更が最小限になるように、バッキング配列を使用してパーティションをインプレースで実装するクラスを使用しています。
case class ArrayView(arr: Array[Double], from: Int, until: Int) {
def apply(n: Int) =
if (from + n < until) arr(from + n)
else throw new ArrayIndexOutOfBoundsException(n)
def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
var upper = until - 1
var lower = from
while (lower < upper) {
while (lower < until && p(arr(lower))) lower += 1
while (upper >= from && !p(arr(upper))) upper -= 1
if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
}
(copy(until = lower), copy(from = lower))
}
def size = until - from
def isEmpty = size <= 0
override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}
@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partitionInPlace (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty) {
val (s, b) = arr partitionInPlace (a ==)
if (s.size > k) a
else findKMedianInPlace(b, k - s.size)
} else if (s.size < k) findKMedianInPlace(b, k - s.size)
else findKMedianInPlace(s, k)
}
def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
ランダムピボット、インプレース
ArrayView
中央値の中央値は、私が定義したクラスによって現在提供されているものよりも多くのサポートを必要とするため、インプレースアルゴリズムのラドムピボットのみを実装しています。
def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
ヒストグラムアルゴリズム(O(log(n))メモリ)、不変
だから、ストリームについて。O(n)
文字列の長さがわからない限り、一度しかトラバースできないストリームに対してメモリ以外のことを行うことは不可能です(その場合、私の本ではストリームではなくなります)。
バケットの使用も少し問題がありますが、バケットを複数回トラバースできる場合は、バケットのサイズ、最大値、最小値を把握し、そこから作業できます。例えば:
def findMedianHistogram(s: Traversable[Double]) = {
def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
// The buckets
def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
val buckets = new Array[Int](numberOfBuckets)
// The upper limit of each bucket
val max = s.max
val min = s.min
val increment = (max - min) / numberOfBuckets
val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)
// Return the bucket a number is supposed to be in
def bucketIndex(d: Double) = indices indexWhere (d <=)
// Compute how many in each bucket
s foreach { d => buckets(bucketIndex(d)) += 1 }
// Now make the buckets cumulative
val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)
// The bucket where our target is at
val medianBucket = partialTotals indexWhere (medianIndex <)
// Keep track of how many numbers there are that are less
// than the median bucket
val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)
// Test whether a number is in the median bucket
def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket
// Get a view of the target bucket
val view = s.view filter insideMedianBucket
// If all numbers in the bucket are equal, return that
if (view forall (view.head ==)) view.head
// Otherwise, recurse on that bucket
else medianHistogram(view, newDiscarded, medianIndex)
}
medianHistogram(s, 0, (s.size - 1) / 2)
}
テストとベンチマーク
アルゴリズムをテストするために、私はScalacheckを使用しており、各アルゴリズムの出力を、並べ替えを使用した簡単な実装の出力と比較しています。もちろん、これはソートバージョンが正しいことを前提としています。
上記の各アルゴリズムのベンチマークを、提供されているすべてのピボット選択に加えて、固定ピボット選択(配列の途中、切り捨て)でベンチマークしています。各アルゴリズムは、3つの異なる入力配列サイズで、それぞれに対して3回テストされます。
テストコードは次のとおりです。
import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._
def test(algorithm: Array[Double] => Double,
reference: Array[Double] => Double): String = {
def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
val resultEqualsReference = forAll { (arr: Array[Double]) =>
arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
}
Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}
import java.lang.System.currentTimeMillis
def bench[A](n: Int)(body: => A): Long = {
val start = currentTimeMillis()
1 to n foreach { _ => body }
currentTimeMillis() - start
}
import scala.util.Random.nextDouble
def benchmark(algorithm: Array[Double] => Double,
arraySizes: List[Int]): List[Iterable[Long]] =
for (size <- arraySizes)
yield for (iteration <- 1 to 3)
yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))
def testAndBenchmark: String = {
val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
"Random Pivot" -> chooseRandomPivot,
"Median of Medians" -> medianOfMedians,
"Midpoint" -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
)
val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
"Random Pivot (in-place)" -> chooseRandomPivotInPlace,
"Midpoint (in-place)" -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
)
val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
yield name -> (findMedian(_: Array[Double])(pivotSelection))
val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms
val formattingString = "%%-%ds %%s" format (algorithms map (_._1.length) max)
// Tests
val testResults = for ((name, algorithm) <- algorithms)
yield formattingString format (name, test(algorithm, sortingAlgorithm._2))
// Benchmarks
val arraySizes = List(100, 500, 1000)
def formatResults(results: List[Long]) = results map ("%8d" format _) mkString
val benchmarkResults: List[String] = for {
(name, algorithm) <- algorithms
results <- benchmark(algorithm, arraySizes).transpose
} yield formattingString format (name, formatResults(results))
val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))
"Tests" :: "*****" :: testResults :::
("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
結果
テスト:
Tests
*****
Sorting OK, passed 100 tests.
Histogram OK, passed 100 tests.
Random Pivot OK, passed 100 tests.
Median of Medians OK, passed 100 tests.
Midpoint OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place) OK, passed 100 tests.
ベンチマーク:
Benchmark
*********
Algorithm 100 500 1000
Sorting 1038 6230 14034
Sorting 1037 6223 13777
Sorting 1039 6220 13785
Histogram 2918 11065 21590
Histogram 2596 11046 21486
Histogram 2592 11044 21606
Random Pivot 904 4330 8622
Random Pivot 902 4323 8815
Random Pivot 896 4348 8767
Median of Medians 3591 16857 33307
Median of Medians 3530 16872 33321
Median of Medians 3517 16793 33358
Midpoint 1003 4672 9236
Midpoint 1010 4755 9157
Midpoint 1017 4663 9166
Random Pivot (in-place) 392 1746 3430
Random Pivot (in-place) 386 1747 3424
Random Pivot (in-place) 386 1751 3431
Midpoint (in-place) 378 1735 3405
Midpoint (in-place) 377 1740 3408
Midpoint (in-place) 375 1736 3408
分析
すべてのアルゴリズム(ソートバージョンを除く)には、平均線形時間計算量と互換性のある結果があります。
中央値の中央値は、最悪の場合に線形時間計算量を保証し、ランダムピボットよりもはるかに遅くなります。
固定ピボットの選択はランダムピボットよりもわずかに劣りますが、非ランダム入力ではパフォーマンスが大幅に低下する可能性があります。
インプレースバージョンは約230%〜250%高速ですが、さらにテストを行うと(図には示されていません)、この利点はアレイのサイズとともに大きくなることが示されているようです。
ヒストグラムアルゴリズムにはとても驚きました。線形の時間計算量の平均を表示し、中央値の中央値よりも33%高速です。ただし、入力はランダムです。最悪のケースは2次式です。コードのデバッグ中にいくつかの例を見ました。