16

Spark の StringIndexer は非常に便利ですが、生成されたインデックス値と元の文字列の間の対応を取得する必要があるのはよくあることであり、これを実現するための組み込みの方法が必要なようです。Spark のドキュメントにある次の簡単な例を使用して説明します。

from pyspark.ml.feature import StringIndexer

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed_df = indexer.fit(df).transform(df)

この単純化されたケースでは、次のことがわかります。

+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

すべてうまくいきますが、多くのユースケースで、元の文字列とインデックス ラベルの間のマッピングを知りたいと思っています。私が考えることができる最も簡単な方法は、次のようなものです。

   In [8]: indexed.select('category','categoryIndex').distinct().show()
+--------+-------------+
|category|categoryIndex|
+--------+-------------+
|       b|          2.0|
|       c|          1.0|
|       a|          0.0|
+--------+-------------+

必要に応じて、結果を辞書などに保存できます。

In [12]: mapping = {row.categoryIndex:row.category for row in
           indexed.select('category','categoryIndex').distinct().collect()}

In [13]: mapping
Out[13]: {0.0: u'a', 1.0: u'c', 2.0: u'b'}

私の質問は次のとおりです。これは非常に一般的なタスクであり、文字列インデクサーが何らかの方法でこのマッピングを保存していると推測していますが (もちろん間違っている可能性もあります)、上記のタスクをより簡単に達成する方法はありますか?

私の解決策は多かれ少なかれ簡単ですが、大規模なデータ構造の場合、(おそらく) 回避できる余分な計算が必要になります。アイデア?

4

1 に答える 1

15

ラベル マッピングは、列のメタデータから抽出できます。

meta = [
    f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex"
]
meta[0]
## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}}

ここでml_attr.vals、位置とラベルの間のマッピングを提供します。

dict(enumerate(meta[0]["ml_attr"]["vals"]))
## {0: 'a', 1: 'c', 2: 'b'}

スパーク 1.6+

を使用して、数値をラベルに変換できますIndexToString。これは、上記のように列のメタデータを使用します。

from pyspark.ml.feature import IndexToString

idx_to_string = IndexToString(
    inputCol="categoryIndex", outputCol="categoryValue")

idx_to_string.transform(indexed_df).drop("id").distinct().show()
## +--------+-------------+-------------+
## |category|categoryIndex|categoryValue|
## +--------+-------------+-------------+
## |       b|          2.0|            b|
## |       a|          0.0|            a|
## |       c|          1.0|            c|
## +--------+-------------+-------------+

スパーク <= 1.5

これは汚いハックですが、次のように Java インデクサーからラベルを簡単に抽出できます。

from pyspark.ml.feature import StringIndexerModel

# A simple monkey patch so we don't have to _call_java later 
def labels(self):
    return self._call_java("labels")

StringIndexerModel.labels = labels

# Fit indexer model
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df)

# Extract mapping
mapping = dict(enumerate(indexer.labels()))
mapping
## {0: 'a', 1: 'c', 2: 'b'}
于 2015-11-24T21:10:19.070 に答える