16

DataFrameのようなものがあります:

userID, category, frequency
1,cat1,1
1,cat2,3
1,cat9,5
2,cat4,6
2,cat9,2
2,cat10,1
3,cat1,5
3,cat7,16
3,cat8,2

userID個別のカテゴリの数は 10 です。それぞれの特徴ベクトルを作成し、不足しているカテゴリをゼロで埋めたいと思います。

したがって、出力は次のようになります。

userID,feature
1,[1,3,0,0,0,0,0,0,5,0]
2,[0,0,0,6,0,0,0,0,2,1]
3,[5,0,0,0,0,0,16,2,0,0]

これは単なる例です。実際には、約 200,000 の一意のユーザー ID と 300 の一意のカテゴリがあります。

機能を作成する最も効率的な方法は何DataFrameですか?

4

3 に答える 3

14

もう少しDataFrame中心的な解決策:

import org.apache.spark.ml.feature.VectorAssembler

val df = sc.parallelize(Seq(
  (1, "cat1", 1), (1, "cat2", 3), (1, "cat9", 5), (2, "cat4", 6),
  (2, "cat9", 2), (2, "cat10", 1), (3, "cat1", 5), (3, "cat7", 16),
  (3, "cat8", 2))).toDF("userID", "category", "frequency")

// Create a sorted array of categories
val categories = df
  .select($"category")
  .distinct.map(_.getString(0))
  .collect
  .sorted

// Prepare vector assemble
val assembler =  new VectorAssembler()
  .setInputCols(categories)
  .setOutputCol("features")

// Aggregation expressions
val exprs = categories.map(
   c => sum(when($"category" === c, $"frequency").otherwise(lit(0))).alias(c))

val transformed = assembler.transform(
    df.groupBy($"userID").agg(exprs.head, exprs.tail: _*))
  .select($"userID", $"features")

および UDAF の代替:

import org.apache.spark.sql.expressions.{
  MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.types.{
  StructType, ArrayType, DoubleType, IntegerType}
import scala.collection.mutable.WrappedArray

class VectorAggregate (n: Int) extends UserDefinedAggregateFunction {
    def inputSchema = new StructType()
      .add("i", IntegerType)
      .add("v", DoubleType)
    def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
    def dataType = new VectorUDT()
    def deterministic = true 

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, Array.fill(n)(0.0))
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) {
        val i = input.getInt(0)
        val v = input.getDouble(1)
        val buff = buffer.getAs[WrappedArray[Double]](0) 
        buff(i) += v
        buffer.update(0, buff)
      }
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val buff1 = buffer1.getAs[WrappedArray[Double]](0) 
      val buff2 = buffer2.getAs[WrappedArray[Double]](0) 
      for ((x, i) <- buff2.zipWithIndex) {
        buff1(i) += x
      }
      buffer1.update(0, buff1)
    }

    def evaluate(buffer: Row) =  Vectors.dense(
      buffer.getAs[Seq[Double]](0).toArray)
}

使用例:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("category_idx")
  .fit(df)

val indexed = indexer.transform(df)
  .withColumn("category_idx", $"category_idx".cast("integer"))
  .withColumn("frequency", $"frequency".cast("double"))

val n = indexer.labels.size + 1

val transformed = indexed
  .groupBy($"userID")
  .agg(new VectorAggregate(n)($"category_idx", $"frequency").as("vec"))

transformed.show

// +------+--------------------+
// |userID|                 vec|
// +------+--------------------+
// |     1|[1.0,5.0,0.0,3.0,...|
// |     2|[0.0,2.0,0.0,0.0,...|
// |     3|[5.0,0.0,16.0,0.0...|
// +------+--------------------+

この場合、値の順序は次のように定義されindexer.labelsます。

indexer.labels
// Array[String] = Array(cat1, cat9, cat7, cat2, cat8, cat4, cat10)

実際には、Odomontoisによる解決策を好むので、これらは主に参照用に提供されています。

于 2015-11-23T14:41:31.647 に答える
13

仮定する:

val cs: SparkContext
val sc: SQLContext
val cats: DataFrame

どこuserIdfrequencybigintに対応する列ですscala.Long

中間マッピングを作成していますRDD:

val catMaps = cats.rdd
  .groupBy(_.getAs[Long]("userId"))
  .map { case (id, rows) => id -> rows
    .map { row => row.getAs[String]("category") -> row.getAs[Long]("frequency") }
    .toMap
  }

次に、提示されたすべてのカテゴリを辞書順に収集します

val catNames = cs.broadcast(catMaps.map(_._2.keySet).reduce(_ union _).toArray.sorted)

または手動で作成する

val catNames = cs.broadcast(1 to 10 map {n => s"cat$n"} toArray)

最後に、マップを、存在しない値の値が 0 の配列に変換します。

import sc.implicits._
val catArrays = catMaps
      .map { case (id, catMap) => id -> catNames.value.map(catMap.getOrElse(_, 0L)) }
      .toDF("userId", "feature")

catArrays.show()のようなものを印刷します

+------+--------------------+
|userId|             feature|
+------+--------------------+
|     2|[0, 1, 0, 6, 0, 0...|
|     1|[1, 0, 3, 0, 0, 0...|
|     3|[5, 0, 0, 0, 16, ...|
+------+--------------------+

私はこの領域のスパークにほとんど慣れていないため、これはデータフレームの最もエレガントなソリューションではない可能性があります。

catNames手動で作成して、欠落しているcat3cat5、 ...にゼロを追加できることに注意してください。

catMapsまた、それ以外の場合は RDD が 2 回操作されることに注意してください.persist()

于 2015-11-23T10:08:08.520 に答える