
I need to create a UDF to be used from pyspark Python that uses a Java object for its internal calculations.

If it were simple Python, I would do something like:

import pyspark.sql.functions
import pyspark.sql.types

def f(x):
    return 7
fudf = pyspark.sql.functions.udf(f, pyspark.sql.types.IntegerType())

and call it using:

df = sqlContext.range(0, 5)
df2 = df.withColumn("a", fudf(df.id))
df2.show()

However, the implementation of the function I need is in Java, not Python. I need to wrap it somehow so that I can call it in a similar way from Python.

My first attempt was to implement the Java object, then wrap it in Python in pyspark and convert that into a UDF. That failed with a serialization error.

Java code:

package com.test1.test2;

public class TestClass1 {
    Integer internalVal;
    public TestClass1(Integer val1) {
        internalVal = val1;
    }
    public Integer do_something(Integer val) {
        return internalVal;
    }    
}

pyspark code:

from py4j.java_gateway import java_import
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
audf = udf(a,IntegerType())

The error:

---------------------------------------------------------------------------
Py4JError                                 Traceback (most recent call last)
<ipython-input-2-9756772ab14f> in <module>()
      4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
      5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
----> 6 audf = udf(a,IntegerType())

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType)
   1595     [Row(slen=5), Row(slen=3)]
   1596     """
-> 1597     return UserDefinedFunction(f, returnType)
   1598 
   1599 blacklist = ['map', 'since', 'ignore_unicode_prefix']

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name)
   1556         self.returnType = returnType
   1557         self._broadcast = None
-> 1558         self._judf = self._create_judf(name)
   1559 
   1560     def _create_judf(self, name):

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name)
   1565         command = (func, None, ser, ser)
   1566         sc = SparkContext.getOrCreate()
-> 1567         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
   1568         ctx = SQLContext.getOrCreate(sc)
   1569         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj)
   2297     # the serialized command will be compressed by broadcast
   2298     ser = CloudPickleSerializer()
-> 2299     pickled_command = ser.dumps(command)
   2300     if len(pickled_command) > (1 << 20):  # 1M
   2301         # The broadcast will have same life cycle as created PythonRDD

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj)
    426 
    427     def dumps(self, obj):
--> 428         return cloudpickle.dumps(obj, 2)
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
    644 
    645     cp = CloudPickler(file,protocol)
--> 646     cp.dump(obj)
    647 
    648     return file.getvalue()

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj)
    105         self.inject_addons()
    106         try:
--> 107             return Pickler.dump(self, obj)
    108         except RuntimeError as e:
    109             if 'recursion' in e.args[0]:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj)
    222         if self.proto >= 2:
    223             self.write(PROTO + chr(self.proto))
--> 224         self.save(obj)
    225         self.write(STOP)
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
    566         write(MARK)
    567         for element in obj:
--> 568             save(element)
    569 
    570         if id(obj) in memo:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
    191         if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
    192             #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
--> 193             self.save_function_tuple(obj)
    194             return
    195         else:

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
    234         # create a skeleton function object and memoize it
    235         save(_make_skel_func)
--> 236         save((code, closure, base_globals))
    237         write(pickle.REDUCE)
    238         self.memoize(func)

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
    552         if n <= 3 and proto >= 2:
    553             for element in obj:
--> 554                 save(element)
    555             # Subtle.  Same as in the big comment below.
    556             if id(obj) in memo:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj)
    604 
    605         self.memoize(obj)
--> 606         self._batch_appends(iter(obj))
    607 
    608     dispatch[ListType] = save_list

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items)
    637                 write(MARK)
    638                 for x in tmp:
--> 639                     save(x)
    640                 write(APPENDS)
    641             elif n:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    304             reduce = getattr(obj, "__reduce_ex__", None)
    305             if reduce:
--> 306                 rv = reduce(self.proto)
    307             else:
    308                 reduce = getattr(obj, "__reduce__", None)

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814 
    815         for temp_arg in temp_args:

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     43     def deco(*a, **kw):
     44         try:
---> 45             return f(*a, **kw)
     46         except py4j.protocol.Py4JJavaError as e:
     47             s = e.java_exception.toString()

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    310                 raise Py4JError(
    311                     "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
--> 312                     format(target_id, ".", name, value))
    313         else:
    314             raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344)
    at py4j.Gateway.invoke(Gateway.java:252)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:209)
    at java.lang.Thread.run(Thread.java:745)

EDIT: I tried making the Java class Serializable, but to no avail.
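For what it's worth, the failure does not seem specific to udf(): judging by the trace, cloudpickle cannot serialize any py4j JavaObject, since the Python side only holds a proxy into the driver JVM. A minimal sketch that reproduces the same error:

from pyspark.serializers import CloudPickleSerializer

# The proxy object exists only as a reference into the driver JVM;
# pickle probes it with __getnewargs__, which the JVM object cannot answer.
ser = CloudPickleSerializer()
obj = sc._gateway.jvm.com.test1.test2.TestClass1(7)
# ser.dumps(obj)  # raises Py4JError: Method __getnewargs__([]) does not exist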

My second attempt was to define the UDF in Java to begin with, but that failed because I don't know how to wrap it correctly:

Java code:

package com.test1.test2;

import org.apache.spark.sql.api.java.UDF1;

public class TestClassUdf implements UDF1<Integer, Integer> {

    Integer retval;

    public TestClassUdf(Integer val) {
        retval = val;
    }

    @Override
    public Integer call(Integer arg0) throws Exception {
        return retval;
    }   
}

But how do I use it? I tried:

from py4j.java_gateway import java_import
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf")
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
dfint = sqlContext.range(0,15)
df = dfint.withColumn("a",a(dfint.id))

but I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-514811090b5f> in <module>()
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a(dfint.id))

TypeError: 'JavaObject' object is not callable

I also tried using a.call instead:

df = dfint.withColumn("a",a.call(dfint.id))

but got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    796     def __call__(self, *args):
    797         if self.converters is not None and len(self.converters) > 0:
--> 798             (new_args, temp_args) = self._get_args(args)
    799         else:
    800             new_args = args

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args)
    783                 for converter in self.gateway_client.converters:
    784                     if converter.can_convert(arg):
--> 785                         temp_arg = converter.convert(arg, self.gateway_client)
    786                         temp_args.append(temp_arg)
    787                         new_args.append(temp_arg)

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client)
    510         HashMap = JavaClass("java.util.HashMap", gateway_client)
    511         java_map = HashMap()
--> 512         for key in object.keys():
    513             java_map[key] = object[key]
    514         return java_map

TypeError: 'Column' object is not callable

Any help would be appreciated.


2 Answers


With the help of your other question (and answer) about UDAFs, I got this working.

Spark provides a udf() method for wrapping Scala FunctionN, so we can wrap the Java function in Scala and use that. The Java method needs to be static, or live on a class that implements Serializable.

package com.example

import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.functions.udf

class MyUdf extends Serializable {
  def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod())
}

Usage in PySpark:

def my_udf():
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    pcls = "com.example.MyUdf"
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    return Column(jc(_to_seq(sc, [], _to_java_column)))

rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}])
df1 = rdd1.toDF()
df2 = df1.withColumn('mycol', my_udf())

As with the UDAF in your other question and answer, you can pass columns through to it: return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column)))
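For instance, a hypothetical one-column variant of the wrapper above could be called like this (a sketch; it assumes the Scala side's getUdf wraps a one-argument function instead of the zero-argument one shown earlier):

from pyspark.sql.column import Column, _to_java_column, _to_seq

def my_udf_on(col):
    # Assumes com.example.MyUdf.getUdf returns a UserDefinedFunction
    # wrapping a one-argument function, e.g. udf((s: String) => ...)
    pcls = "com.example.MyUdf"
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    return Column(jc(_to_seq(sc, [col], _to_java_column)))

df2 = df1.withColumn('mycol', my_udf_on('c1'))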

Answered 2016-07-05T02:14:59.400

Following https://dzone.com/articles/pyspark-java-udf-integration-1, you can define a UDF1 in Java:

public class AddNumber implements UDF1<Long, Long> {

    @Override
    public Long call(Long num) throws Exception {
        return (num + 5);
    }
}

And then, after adding the jar to pyspark (e.g. pyspark --jars <your-jar>, or --packages with a Maven coordinate),

you can use it in pyspark as follows:

from pyspark.sql import functions as F
from pyspark.sql.types import LongType, FloatType


>>> df = spark.createDataFrame([float(i) for i in range(100)], FloatType()).toDF("a")
>>> spark.udf.registerJavaFunction("addNumber", "com.example.spark.AddNumber", LongType())
>>> df.withColumn("b", F.expr("addNumber(a)")).show(5)
+---+---+
|  a|  b|
+---+---+
|0.0|  5|
|1.0|  6|
|2.0|  7|
|3.0|  8|
|4.0|  9|
+---+---+
only showing top 5 rows
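
Because registerJavaFunction registers the UDF as a SQL function, the same call also works through selectExpr or a SQL query (same jar and class assumptions as above):

>>> # Inline SQL expression on the DataFrame
>>> df.selectExpr("a", "addNumber(a) AS b").show(5)
>>> # Or via a temp view and a full SQL query
>>> df.createOrReplaceTempView("t")
>>> spark.sql("SELECT a, addNumber(a) AS b FROM t").show(5)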
Answered 2020-05-23T13:09:45.123