関数が呼び出されるたびに、指定された一連の引数値の結果がまだメモ化されていない場合、結果をメモリ内テーブルに入れたいと思います。1 つの列は結果を格納するためのもので、他の列は引数値を格納するためのものです。
これをどのように実装するのが最善ですか?引数には、いくつかの列挙型を含むさまざまなタイプがあります。
C# では、通常 DataTable を使用します。Scalaに同等のものはありますか?
関数が呼び出されるたびに、指定された一連の引数値の結果がまだメモ化されていない場合、結果をメモリ内テーブルに入れたいと思います。1 つの列は結果を格納するためのもので、他の列は引数値を格納するためのものです。
これをどのように実装するのが最善ですか?引数には、いくつかの列挙型を含むさまざまなタイプがあります。
C# では、通常 DataTable を使用します。Scalaに同等のものはありますか?
を使用するmutable.Map[TupleN[A1, A2, ..., AN], R]
か、メモリが問題になる場合は WeakHashMap [1] を使用できます。以下の定義 ( micid のブログのメモ化コードに基づいて作成) を使用すると、複数の引数を持つ関数を簡単にメモ化できます。例えば:
import Memoize._
def reallySlowFn(i: Int, s: String): Int = {
Thread.sleep(3000)
i + s.length
}
val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly
定義:
/**
* A memoized unary function.
*
* @param f A unary function to memoize
* @param [T] the argument type
* @param [R] the return type
*/
class Memoize1[-T, +R](f: T => R) extends (T => R) {
import scala.collection.mutable
// map that stores (argument, result) pairs
private[this] val vals = mutable.Map.empty[T, R]
// Given an argument x,
// If vals contains x return vals(x).
// Otherwise, update vals so that vals(x) == f(x) and return f(x).
def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}
object Memoize {
/**
* Memoize a unary (single-argument) function.
*
* @param f the unary function to memoize
*/
def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)
/**
* Memoize a binary (two-argument) function.
*
* @param f the binary function to memoize
*
* This works by turning a function that takes two arguments of type
* T1 and T2 into a function that takes a single argument of type
* (T1, T2), memoizing that "tupled" function, then "untupling" the
* memoized function.
*/
def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) =
Function.untupled(memoize(f.tupled))
/**
* Memoize a ternary (three-argument) function.
*
* @param f the ternary function to memoize
*/
def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
Function.untupled(memoize(f.tupled))
// ... more memoize methods for higher-arity functions ...
/**
* Fixed-point combinator (for memoizing recursive functions).
*/
def Y[T, R](f: (T => R) => T => R): (T => R) = {
lazy val yf: (T => R) = memoize(f(yf)(_))
yf
}
}
固定小数点コンビネータ ( Memoize.Y
) を使用すると、再帰関数をメモ化できます。
val fib: BigInt => BigInt = {
def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
if (n == 0) 1
else if (n == 1) 1
else (f(n-1) + f(n-2))
}
Memoize.Y(fibRec)
}
[1] WeakHashMap はキャッシュとしてうまく機能しません。http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.htmlおよびこの関連する質問を参照してください。
可変 Map を使用する anovstrup によって提案されたバージョンは、基本的に C# と同じであるため、使いやすいです。
ただし、必要に応じて、より機能的なスタイルも使用できます。一種のアキュムレータとして機能する不変のマップを使用します。タプル (例では Int の代わりに) をキーとして使用すると、可変の場合とまったく同じように機能します。
def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1
def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
case Some(f) => (f, m)
case None => val (f_1,m1) = fibM(n-1,m)
val (f_2,m2) = fibM(n-2,m1)
val f = f_1+f_2
(f, m2 + (n -> f))
}
もちろん、これはもう少し複雑ですが、知っておくと便利なテクニックです (上記のコードは速度ではなく、わかりやすくすることを目的としていることに注意してください)。
このテーマの初心者である私は、与えられた例のどれも完全には理解できませんでした (とにかく感謝したいと思います)。敬意を表して、誰かがここに来て同じレベルで同じ問題を抱えている場合に備えて、私自身の解決策を提示します。私のコードは、非常に基本的な Scala の知識を持っている人なら誰でも明確に理解できると思います。
def MyFunction(dt : DateTime, param : Int) : Double
{
val argsTuple = (dt, param)
if(Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}
def MyRawFunction(dt : DateTime, param : Int) : Double
{
1.0 // A heavy calculation/querying here
}
def Memoize(dt : DateTime, param : Int, result : Double) : Double
{
Memo += (dt, param) -> result
result
}
val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double]
完璧に動作します。私が何かを逃した場合、私は批評をいただければ幸いです。
メモ化に変更可能なマップを使用する場合、書き込みがまだ完了していないときに get を実行するなど、典型的な並行性の問題が発生することに注意してください。ただし、メモ化のスレッドセーフな試みは、そうすることを示唆しています。
次のスレッドセーフなコードは、メモ化fibonacci
された関数を作成し、それを呼び出すいくつかのスレッド ('a' から 'd' までの名前) を開始します。f(2) set
コードを (REPL で) 数回試してみてください。複数回印刷されることが簡単にわかります。これは、スレッド A が計算を開始したことを意味しますがf(2)
、スレッド B はそれをまったく認識せず、独自の計算のコピーを開始します。このような無知は、キャッシュの構築段階で非常に蔓延していelse
ます.
object ScalaMemoizationMultithread {
// do not use case class as there is a mutable member here
class Memo[-T, +R](f: T => R) extends (T => R) {
// don't even know what would happen if immutable.Map used in a multithreading context
private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]
def apply(x: T): R =
// no synchronized needed as there is no removal during memoization
if (cache containsKey x) {
Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
cache.get(x)
} else {
val res = f(x)
Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
cache.putIfAbsent(x, res) // atomic
res
}
}
object Memo {
def apply[T, R](f: T => R): T => R = new Memo(f)
def Y[T, R](F: (T => R) => T => R): T => R = {
lazy val yf: T => R = Memo(F(yf)(_))
yf
}
}
val fibonacci: Int => BigInt = {
def fiboF(f: Int => BigInt)(n: Int): BigInt = {
if (n <= 0) 1
else if (n == 1) 1
else f(n - 1) + f(n - 2)
}
Memo.Y(fiboF)
}
def main(args: Array[String]) = {
('a' to 'd').foreach(ch =>
new Thread(new Runnable() {
def run() {
import scala.util.Random
val rand = new Random
(1 to 2).foreach(_ => {
Thread.currentThread().setName("Thread " + ch)
fibonacci(5)
})
}
}).start)
}
}
Landeiの答えに加えて、ScalaでDPを実行するボトムアップ(非メモ化)の方法も提案したいと思います.コアアイデアはfoldLeft
(s)を使用することです.
フィボナッチ数の計算例
def fibo(n: Int) = (1 to n).foldLeft((0, 1)) {
(acc, i) => (acc._2, acc._1 + acc._2)
}._1
最長増加サブシーケンスの例
def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]) = {
xs.foldLeft(List[(Int, List[T])]()) {
(memo, x) =>
if (memo.isEmpty) List((1, List(x)))
else {
val resultIfEndsAtCurr = (memo, xs).zipped map {
(tp, y) =>
val len = tp._1
val seq = tp._2
if (ord.lteq(y, x)) { // current is greater than the previous end
(len + 1, x :: seq) // reversely recorded to avoid O(n)
} else {
(1, List(x)) // start over
}
}
memo :+ resultIfEndsAtCurr.maxBy(_._1)
}
}.maxBy(_._1)._2.reverse
}