Spark / SPARK-29952

Pandas UDFs do not support vectors as input


Details

    • Type: Improvement
    • Status: Open
    • Priority: Minor
    • Resolution: Unresolved
    • Affects Version/s: 3.1.0
    • Fix Version/s: None
    • Component/s: PySpark, SQL
    • Labels: None

    Description

      Currently, pandas UDFs do not support columns of vectors as input, only columns of arrays. This means that feature columns containing DenseVector or SparseVector values, such as those produced by CountVectorizer, are not supported by pandas UDFs out of the box; the vectors have to be converted into arrays first. This limitation is not documented anywhere and I had to find it out by trial and error. Below is an example.

       

      from pyspark.sql.functions import udf, pandas_udf
      import pyspark.sql.functions as F
      from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
      from pyspark.sql.types import *
      import numpy as np
      
      columns = ['features','id']
      vals = [
           (DenseVector([1, 2, 1, 3]),1),
           (DenseVector([2, 2, 1, 3]),2)
      ]
      
      sdf = spark.createDataFrame(vals,columns)
      sdf.show()
      
      +-----------------+---+
      |         features| id|
      +-----------------+---+
      |[1.0,2.0,1.0,3.0]|  1|
      |[2.0,2.0,1.0,3.0]|  2|
      +-----------------+---+
      
      @udf(returnType=ArrayType(FloatType()))
      def vector_to_array(v):
          # convert a column of ML vectors into a column of arrays;
          # toArray() handles both DenseVector and SparseVector
          return v.toArray().tolist()
      
      sdf = sdf.withColumn('features_array',vector_to_array('features'))
      sdf.show()
      sdf.dtypes
      
      +-----------------+---+--------------------+
      |         features| id|      features_array|
      +-----------------+---+--------------------+
      |[1.0,2.0,1.0,3.0]|  1|[1.0, 2.0, 1.0, 3.0]|
      |[2.0,2.0,1.0,3.0]|  2|[2.0, 2.0, 1.0, 3.0]|
      +-----------------+---+--------------------+
      
      [('features', 'vector'), ('id', 'bigint'), ('features_array', 'array<float>')]
      
      import pandas as pd
      
      @pandas_udf(LongType())
      def _pandas_udf(v):
          # v is a pandas Series whose elements are numpy arrays;
          # the mean of each array is truncated to LongType on return
          res = []
          for i in v:
              res.append(i.mean())
          return pd.Series(res)
      
      sdf.select(_pandas_udf('features_array')).show()
      
      +---------------------------+
      |_pandas_udf(features_array)|
      +---------------------------+
      |                          1|
      |                          2|
      +---------------------------+
      

      But if I use the vector column instead, I get the following error.

      sdf.select(_pandas_udf('features')).show()
      
      ---------------------------------------------------------------------------
      Py4JJavaError                             Traceback (most recent call last)
      <ipython-input-74-d93e4117f661> in <module>
           13 
           14 
      ---> 15 sdf.select(_pandas_udf('features')).show()
      
      ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
          376         """
          377         if isinstance(truncate, bool) and truncate:
      --> 378             print(self._jdf.showString(n, 20, vertical))
          379         else:
          380             print(self._jdf.showString(n, int(truncate), vertical))
      
      ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
         1255         answer = self.gateway_client.send_command(command)
         1256         return_value = get_return_value(
      -> 1257             answer, self.gateway_client, self.target_id, self.name)
         1258 
         1259         for temp_arg in temp_args:
      
      ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
           61     def deco(*a, **kw):
           62         try:
      ---> 63             return f(*a, **kw)
           64         except py4j.protocol.Py4JJavaError as e:
           65             s = e.java_exception.toString()
      
      ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
          326                 raise Py4JJavaError(
          327                     "An error occurred while calling {0}{1}{2}.\n".
      --> 328                     format(target_id, ".", name), value)
          329             else:
          330                 raise Py4JError(
      
      Py4JJavaError: An error occurred while calling o2635.showString.
      : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 156.0 failed 1 times, most recent failure: Lost task 0.0 in stage 156.0 (TID 606, localhost, executor driver): java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115)
      	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
      	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
      	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
      	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
      	at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
      	at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
      	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
      	at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115)
      	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71)
      	at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345)
      	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
      	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194)
      
      Driver stacktrace:
      	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
      	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
      	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
      	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
      	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
      	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
      	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
      	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
      	at scala.Option.foreach(Option.scala:257)
      	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
      	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
      	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
      	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
      	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
      	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
      	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
      	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
      	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
      	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
      	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
      	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
      	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
      	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
      	at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
      	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
      	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
      	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
      	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
      	at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
      	at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
      	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
      	at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
      	at sun.reflect.GeneratedMethodAccessor81.invoke(Unknown Source)
      	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
      	at java.lang.reflect.Method.invoke(Method.java:498)
      	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
      	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
      	at py4j.Gateway.invoke(Gateway.java:282)
      	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
      	at py4j.commands.CallCommand.execute(CallCommand.java:79)
      	at py4j.GatewayConnection.run(GatewayConnection.java:238)
      	at java.lang.Thread.run(Thread.java:748)
      Caused by: java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115)
      	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
      	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
      	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
      	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
      	at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
      	at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
      	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
      	at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
      	at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115)
      	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71)
      	at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345)
      	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
      	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194)
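
      The root cause is visible in the trace: a VectorUDT column is stored as struct<type:tinyint,size:int,indices:array<int>,values:array<double>>, which ArrowUtils.toArrowType rejects, so the Arrow serialization feeding the pandas UDF fails. As a workaround sketch (assuming Spark 3.0+, where pyspark.ml.functions.vector_to_array is available), the vector column can be converted to an array column with the built-in helper instead of the hand-written Python UDF above:

      from pyspark.ml.functions import vector_to_array as ml_vector_to_array

      # convert the VectorUDT column into an array<double> column with the
      # built-in helper, then feed the array column to the pandas UDF
      sdf2 = sdf.withColumn('features_arr', ml_vector_to_array('features'))
      sdf2.select(_pandas_udf('features_arr')).show()

      This handles both DenseVector and SparseVector inputs, but it is still an explicit conversion step; the pandas UDF itself still cannot take the vector column directly.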
      

       

       

       


          People

            Assignee: Unassigned
            Reporter: koba (kobakhit)
            Votes: 1
            Watchers: 6
