Spark / SPARK-46175

Can't return a prediction with a different length than the ML input


Details

    • Type: Improvement
    • Status: Open
    • Priority: Major
    • Resolution: Unresolved
    • Affects Version/s: 3.4.1
    • Fix Version/s: None
    • Component/s: MLlib, PySpark
    • Labels: None
    • Environment: I'm on spark 3.4

    Description

I'm using `pyspark.ml.functions.predict_batch_udf` to construct a pandas UDF that runs a computer vision model to predict classification labels for images. The model takes a 4D array as input and returns a 4D array as output (Batch, Channels, Height, Width).

However, I'd like to run some additional processing in the pandas UDF to convert the 4D output array (floats) to text labels, since this is a more useful output and we are trying to register pandas UDFs ahead of time for spark.sql users.
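The post-processing I have in mind is roughly this (a numpy-only sketch; the label set and the mean/argmax decoding are placeholders for my actual model's logic):

```python
import numpy as np

# Hypothetical label set, standing in for the real classes.
LABELS = ["cat", "dog", "bird"]

def decode_predictions(preds):
    """Collapse a (Batch, Channels, Height, Width) float array into one
    label string per image: average over the spatial dims, then take the
    argmax over the channel (class) axis."""
    scores = preds.mean(axis=(2, 3))  # shape (Batch, Channels)
    return [LABELS[i] for i in scores.argmax(axis=1)]
```

The point is that the UDF's useful return value is one string per input row, not the raw 4D float array.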
       
       
When I set the return type to StringType, though, I get an error:
       
```
23/11/29 02:43:04 WARN TaskSetManager: Lost task 0.0 in stage 8.0 (TID 16) (172.18.0.2 executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/pyspark/ml/functions.py", line 809, in predict
    yield _validate_and_transform_prediction_result(
  File "/opt/spark/python/lib/pyspark.zip/pyspark/ml/functions.py", line 331, in _validate_and_transform_prediction_result
    raise ValueError("Prediction results must have same length as input data.")
ValueError: Prediction results must have same length as input data.
```
       
This seems like an unnecessary limitation, since it is common for ML models, especially in computer vision, to produce output shapes that differ from their input shapes.
       
      I think the issue is here: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/functions.html
       
Can this check that enforces the same length be removed?
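Paraphrasing from the traceback, the check I'm hitting behaves like this (my own restatement of the observed behavior, not the actual Spark source):

```python
def validate_prediction_length(preds, num_input_rows):
    # Mirrors the observed behavior: reject any prediction result whose
    # length differs from the number of input rows in the batch.
    if len(preds) != num_input_rows:
        raise ValueError("Prediction results must have same length as input data.")
    return preds
```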
       
       
To illustrate the problem, here are my StructTypes. The Raw one works, but the Str one does not:
       
```
from collections import namedtuple
from pyspark.sql.types import (
    ArrayType, FloatType, IntegerType, StringType, StructField, StructType,
)

Task = namedtuple('TaskSchema', ['inference_input', 'inference_result'])

SingleLabelClassificationRaw = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False),
    ]),
    inference_result=ArrayType(FloatType()),
)

SingleLabelClassificationStr = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False),
    ]),
    inference_result=StringType(),
)
```
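For context on the input struct: each row carries the flattened pixels plus their shape, and inside the predict function the model's 4D batch can be rebuilt roughly like this (a minimal sketch of decoding such a struct; my actual code differs in details):

```python
import numpy as np

def restore_batch(array_col, shape_col):
    """Rebuild each row's tensor from the (array, shape) struct fields,
    then stack the rows into the (Batch, C, H, W) array the model expects."""
    tensors = [np.asarray(a).reshape(s) for a, s in zip(array_col, shape_col)]
    return np.stack(tensors)
```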

People

    Assignee: Unassigned
    Reporter: Ryan Avery (rbavery)
    Votes: 0
    Watchers: 1