0

pyspark を使用して ALS モデル オブジェクトを作成した後。

サンプルコード例:

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row

lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd
parts = lines.map(lambda row: row.value.split("::"))
ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
                                     rating=float(p[2]), timestamp=long(p[3])))
ratings = spark.createDataFrame(ratingsRDD)
(rating_data, test) = ratings.randomSplit([0.8, 0.2])

# Build the recommendation model using ALS on the training data
# Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating",
          coldStartStrategy="drop")

    als_model = als_spec.fit(rating_data)

ここでは、ALS モデルを作成し、cloudepickel を作成しています。フィットを使用している場合、変換も行う必要がありますか?

以下のコードを使用して、私の als_model オブジェクトをピッケルしようとしています:

with open(os.path.join(model_path, 'als-als-model.pkl'), 'w') as out:
                cloudpickle.dump(als_model, out)

以下のようなエラーが発生します。

  File "/usr/local/spark/python/lib/py4j-0.10.6-src.zip/py4j/protocol.py", line 324, in get_return_value
    format(target_id, ".", name, value))
Py4JError: An error occurred while calling o224.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
#011at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
#011at 

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-124-8c94f4ee0de9> in <module>()
      1 
----> 2 tree.fit(data_location)

~/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name)
    152         self.latest_training_job = _TrainingJob.start_new(self, inputs)
    153         if wait:
--> 154             self.latest_training_job.wait(logs=logs)
    155         else:
    156             raise NotImplemented('Asynchronous fit not available')
4

0 に答える 0