I need to create a UDF, to be used from PySpark (Python), that uses a Java object for its internal calculations.
If it were plain Python, I would do something like this:
import pyspark.sql.functions
import pyspark.sql.types
def f(x):
    return 7
fudf = pyspark.sql.functions.udf(f, pyspark.sql.types.IntegerType())
and call it with:
df = sqlContext.range(0,5)
df2 = df.withColumn("a",fudf(df.id)).show()
However, the function I need is implemented in Java, not Python, so it needs to be wrapped somehow so that I can call it from Python in a similar way.
My first attempt was to implement the Java object, wrap it in Python on the pyspark side, and then convert that into a UDF. It failed with a serialization error.
Java code:
package com.test1.test2;

public class TestClass1 {
    Integer internalVal;

    public TestClass1(Integer val1) {
        internalVal = val1;
    }

    public Integer do_something(Integer val) {
        return internalVal;
    }
}
pyspark code:
from py4j.java_gateway import java_import
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
audf = udf(a,IntegerType())
The error:
---------------------------------------------------------------------------
Py4JError Traceback (most recent call last)
<ipython-input-2-9756772ab14f> in <module>()
4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
----> 6 audf = udf(a,IntegerType())
/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType)
1595 [Row(slen=5), Row(slen=3)]
1596 """
-> 1597 return UserDefinedFunction(f, returnType)
1598
1599 blacklist = ['map', 'since', 'ignore_unicode_prefix']
/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name)
1556 self.returnType = returnType
1557 self._broadcast = None
-> 1558 self._judf = self._create_judf(name)
1559
1560 def _create_judf(self, name):
/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name)
1565 command = (func, None, ser, ser)
1566 sc = SparkContext.getOrCreate()
-> 1567 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
1568 ctx = SQLContext.getOrCreate(sc)
1569 jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj)
2297 # the serialized command will be compressed by broadcast
2298 ser = CloudPickleSerializer()
-> 2299 pickled_command = ser.dumps(command)
2300 if len(pickled_command) > (1 << 20): # 1M
2301 # The broadcast will have same life cycle as created PythonRDD
/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj)
426
427 def dumps(self, obj):
--> 428 return cloudpickle.dumps(obj, 2)
429
430
/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
644
645 cp = CloudPickler(file,protocol)
--> 646 cp.dump(obj)
647
648 return file.getvalue()
/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj)
105 self.inject_addons()
106 try:
--> 107 return Pickler.dump(self, obj)
108 except RuntimeError as e:
109 if 'recursion' in e.args[0]:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj)
222 if self.proto >= 2:
223 self.write(PROTO + chr(self.proto))
--> 224 self.save(obj)
225 self.write(STOP)
226
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
566 write(MARK)
567 for element in obj:
--> 568 save(element)
569
570 if id(obj) in memo:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
191 if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
192 #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
--> 193 self.save_function_tuple(obj)
194 return
195 else:
/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
234 # create a skeleton function object and memoize it
235 save(_make_skel_func)
--> 236 save((code, closure, base_globals))
237 write(pickle.REDUCE)
238 self.memoize(func)
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
552 if n <= 3 and proto >= 2:
553 for element in obj:
--> 554 save(element)
555 # Subtle. Same as in the big comment below.
556 if id(obj) in memo:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj)
604
605 self.memoize(obj)
--> 606 self._batch_appends(iter(obj))
607
608 dispatch[ListType] = save_list
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items)
637 write(MARK)
638 for x in tmp:
--> 639 save(x)
640 write(APPENDS)
641 elif n:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
304 reduce = getattr(obj, "__reduce_ex__", None)
305 if reduce:
--> 306 rv = reduce(self.proto)
307 else:
308 reduce = getattr(obj, "__reduce__", None)
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
811 answer = self.gateway_client.send_command(command)
812 return_value = get_return_value(
--> 813 answer, self.gateway_client, self.target_id, self.name)
814
815 for temp_arg in temp_args:
/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
43 def deco(*a, **kw):
44 try:
---> 45 return f(*a, **kw)
46 except py4j.protocol.Py4JJavaError as e:
47 s = e.java_exception.toString()
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
310 raise Py4JError(
311 "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
--> 312 format(target_id, ".", name, value))
313 else:
314 raise Py4JError(
Py4JError: An error occurred while calling o18.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344)
at py4j.Gateway.invoke(Gateway.java:252)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:209)
at java.lang.Thread.run(Thread.java:745)
EDIT: I tried making the Java class serializable, but that did not help.
My second attempt was to define the UDF in Java in the first place, but that failed because I could not figure out how to wrap it correctly:
Java code:

package com.test1.test2;

import org.apache.spark.sql.api.java.UDF1;

public class TestClassUdf implements UDF1<Integer, Integer> {
    Integer retval;

    public TestClassUdf(Integer val) {
        retval = val;
    }

    @Override
    public Integer call(Integer arg0) throws Exception {
        return retval;
    }
}
But how do I use it? I tried:
from py4j.java_gateway import java_import
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf")
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
dfint = sqlContext.range(0,15)
df = dfint.withColumn("a",a(dfint.id))
but I get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-5-514811090b5f> in <module>()
3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a(dfint.id))
TypeError: 'JavaObject' object is not callable
and then I tried using a.call instead:
df = dfint.withColumn("a",a.call(dfint.id))
but got:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-...> in <module>()
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
796 def __call__(self, *args):
797 if self.converters is not None and len(self.converters) > 0:
--> 798 (new_args, temp_args) = self._get_args(args)
799 else:
800 new_args = args
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args)
783 for converter in self.gateway_client.converters:
784 if converter.can_convert(arg):
--> 785 temp_arg = converter.convert(arg, self.gateway_client)
786 temp_args.append(temp_arg)
787 new_args.append(temp_arg)
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client)
510 HashMap = JavaClass("java.util.HashMap", gateway_client)
511 java_map = HashMap()
--> 512 for key in object.keys():
513 java_map[key] = object[key]
514 return java_map
TypeError: 'Column' object is not callable
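For what it is worth, the direction I have been considering, but have not managed to verify, is to register the Java UDF1 with the JVM-side UDFRegistration and then call it by name through a SQL expression. A rough, untested sketch (the name test_udf is just a placeholder I made up):

from pyspark.sql.functions import expr
from pyspark.sql.types import IntegerType

# Instantiate the Java UDF1 and register it under a name on the JVM side.
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
# Same conversion udf() does internally: Python DataType -> JVM DataType.
jdt = sqlContext._ssql_ctx.parseDataType(IntegerType().json())
sqlContext._ssql_ctx.udf().register("test_udf", a, jdt)

# Call the registered function by name from a SQL expression string.
dfint = sqlContext.range(0, 15)
df = dfint.withColumn("a", expr("test_udf(id)"))
df.show()

I do not know whether this is the intended way to wrap a Java UDF from pyspark, or whether there is something cleaner.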
Any help would be appreciated.