Description
I'm using `predict_batch_udf` (`from pyspark.ml.functions import predict_batch_udf`) to construct a pandas UDF that runs a computer vision model to predict classification labels for images. The model takes a 4D array as input and returns a 4D array as output (Batch, Channels, Height, Width).
However, I'd like to run some additional processing inside the UDF to convert the 4D float output array into text labels. Labels are a more useful output, and we are registering these pandas UDFs ahead of time for Spark SQL users.
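A minimal sketch of that post-processing step follows. It assumes, which the report does not state, that each of the C output channels holds the score for one class and that averaging each channel over height and width is an acceptable way to get one score per class per image. `scores_to_labels` and the class names are hypothetical. The property that matters is that the result holds exactly one string per input image, i.e. it has shape `(batch,)`.

```python
from typing import Sequence

import numpy as np


def scores_to_labels(scores: np.ndarray, class_names: Sequence[str]) -> np.ndarray:
    """Reduce a (batch, channels, height, width) float array to one label per image.

    Assumption: channel i holds the score map for class_names[i]. Each score map
    is averaged over height and width, and the highest-scoring class is chosen.
    """
    if scores.ndim != 4:
        raise ValueError(f"expected a 4-D array, got shape {scores.shape}")
    if scores.shape[1] != len(class_names):
        raise ValueError("number of channels must match number of class names")
    per_class = scores.mean(axis=(2, 3))  # shape (batch, channels)
    best = per_class.argmax(axis=1)       # shape (batch,), index of the winning class
    # Look up the names: an object array of Python strings, one per input row.
    return np.asarray(class_names, dtype=object)[best]


# Two 2x2 "images", each with three class channels (hypothetical labels).
scores = np.zeros((2, 3, 2, 2), dtype=np.float32)
scores[0, 2] = 1.0  # image 0: class 2 scores highest
scores[1, 0] = 1.0  # image 1: class 0 scores highest
labels = scores_to_labels(scores, ["cat", "dog", "bird"])
print(labels.tolist())  # ['bird', 'cat']
print(labels.shape)     # (2,)
```

Because the reduction keeps the batch dimension, the output still has one entry per input row. That count is what the length check in the error message compares against.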
When I set the return type to `StringType`, however, I get the following error:
```
23/11/29 02:43:04 WARN TaskSetManager: Lost task 0.0 in stage 8.0 (TID 16) (172.18.0.2 executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/pyspark/ml/functions.py", line 809, in predict
    yield _validate_and_transform_prediction_result(
  File "/opt/spark/python/lib/pyspark.zip/pyspark/ml/functions.py", line 331, in _validate_and_transform_prediction_result
    raise ValueError("Prediction results must have same length as input data.")
ValueError: Prediction results must have same length as input data.
```
This seems like an unnecessary limitation, since it is common for ML models, especially in computer vision, to produce outputs whose shape differs from the input shape.
I think the issue comes from `_validate_and_transform_prediction_result` here: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/functions.html
Can this check, which requires the prediction results to have the same length as the input data, be removed?
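Judging by the error message, the check compares the number of predictions with the number of input rows, which means it looks at the first (batch) dimension rather than the full output shape. The sketch below uses a hypothetical stand-in, `check_prediction_length`. It is not the library's function, and the real code may check more. A 4-D output with the right batch size passes it, and so does a 1-D array holding one label per row. A single label for the whole batch fails.

```python
import numpy as np


def check_prediction_length(preds: np.ndarray, num_input_rows: int) -> None:
    """Hypothetical stand-in for the length check named in the traceback.

    Only len(preds), i.e. the size of the first (batch) dimension, is compared
    with the number of input rows; trailing dimensions are not inspected.
    """
    if len(preds) != num_input_rows:
        raise ValueError("Prediction results must have same length as input data.")


batch = 4
check_prediction_length(np.zeros((batch, 3, 8, 8)), batch)                # 4-D floats: passes
check_prediction_length(np.array(["cat"] * batch, dtype=object), batch)  # one label per row: passes

try:
    # A single label for the whole batch instead of one per row.
    check_prediction_length(np.array(["cat"], dtype=object), batch)
except ValueError as err:
    print(err)  # Prediction results must have same length as input data.
```

The `predict_batch_udf` documentation also ties the return type to the array's dimensionality: a scalar type is expected as a 1-D array, and an `ArrayType` as a 2-D array. A `StringType` result would therefore need a `(batch,)` array of strings.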
To illustrate the problem, here are my schemas. `SingleLabelClassificationRaw` works, but `SingleLabelClassificationStr` does not:
```
from collections import namedtuple

from pyspark.sql.types import (
    ArrayType, FloatType, IntegerType, StringType, StructField, StructType,
)

# A task pairs the schema of the model input with the Spark SQL type of its result.
Task = namedtuple("TaskSchema", ["inference_input", "inference_result"])

# Works: the result is an array of floats (the raw model output) per row.
SingleLabelClassificationRaw = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),  # flattened input array
        StructField("shape", ArrayType(IntegerType()), nullable=False),  # shape to restore it to
    ]),
    inference_result=ArrayType(FloatType()),
)

# Fails with the ValueError above: the result is one text label per row.
SingleLabelClassificationStr = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False),
    ]),
    inference_result=StringType(),
)
```