Spark / SPARK-28434

Decision Tree model isn't equal after save and load

    Details

    • Type: Improvement
    • Status: Resolved
    • Priority: Minor
    • Resolution: Fixed
    • Affects Version/s: 2.4.3
    • Fix Version/s: 3.0.0
    • Component/s: MLlib
    • Labels:
      None
    • Environment:

      Spark built from master

    Description

      At line 628 of `mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala` there is a TODO:

      // TODO: Check other fields besides the information gain.

      If, in addition to the existing check on the `gain` value of `InformationGainStats`, I add another check (for instance, on `impurity`), the test fails because the values differ between the saved model and the model restored from disk.
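The TODO asks for a stricter comparison than `gain` alone. Below is a minimal sketch of one, assuming simplified stand-in types: `Predict` and `GainStats` are hypothetical mirrors of Spark's mllib `Predict` and `InformationGainStats` (only the field names come from this report), and `StatsCheck.mismatchedFields` is an illustrative helper, not part of Spark's test suite. It uses exact equality, which suits a save/load round trip, and reports every field that differs, so a failure names the changed statistics instead of only saying the trees are not identical.

```scala
// Hypothetical stand-in for Spark's mllib Predict: a predicted label and its probability.
final case class Predict(predict: Double, prob: Double)

// Hypothetical stand-in for Spark's mllib InformationGainStats (same field names).
final case class GainStats(
    gain: Double,
    impurity: Double,
    leftImpurity: Double,
    rightImpurity: Double,
    leftPredict: Predict,
    rightPredict: Predict)

object StatsCheck {
  // Returns the names of all fields whose values differ between a and b,
  // in declaration order; an empty result means the stats are identical.
  def mismatchedFields(a: GainStats, b: GainStats): Seq[String] =
    Seq(
      "gain"          -> (a.gain == b.gain),
      "impurity"      -> (a.impurity == b.impurity),
      "leftImpurity"  -> (a.leftImpurity == b.leftImpurity),
      "rightImpurity" -> (a.rightImpurity == b.rightImpurity),
      "leftPredict"   -> (a.leftPredict == b.leftPredict),
      "rightPredict"  -> (a.rightPredict == b.rightPredict)
    ).collect { case (name, false) => name } // keep only the fields that did not match
}
```

Fed the node 1 values from the debug output in this report, the helper would flag every field except `gain`.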

      See the PR for an example.

      The tests are executed with this command:

      build/mvn -e -Dtest=none -DwildcardSuites=org.apache.spark.mllib.tree.DecisionTreeSuite test

      Excerpts from the output of the command above:

      ...
      
      - model save/load *** FAILED ***
      checkEqual failed since the two trees were not identical.
      TREE A:
      DecisionTreeModel classifier of depth 2 with 5 nodes
      If (feature 0 <= 0.5)
      Predict: 0.0
      Else (feature 0 > 0.5)
      If (feature 1 in {0.0,1.0})
      Predict: 0.0
      Else (feature 1 not in {0.0,1.0})
      Predict: 0.0
      
      TREE B:
      DecisionTreeModel classifier of depth 2 with 5 nodes
      If (feature 0 <= 0.5)
      Predict: 0.0
      Else (feature 0 > 0.5)
      If (feature 1 in {0.0,1.0})
      Predict: 0.0
      Else (feature 1 not in {0.0,1.0})
      Predict: 0.0 (DecisionTreeSuite.scala:610)
      
      ...

      If I add some debug output to `DecisionTreeSuite.checkEqual`:

      // Debug output only: Option.get throws if either node has no stats (None).
      val aStats = a.stats
      val bStats = b.stats

      println(s"id ${a.id} ${b.id}")
      println(s"impurity ${aStats.get.impurity} ${bStats.get.impurity}")
      println(s"leftImpurity ${aStats.get.leftImpurity} ${bStats.get.leftImpurity}")
      println(s"rightImpurity ${aStats.get.rightImpurity} ${bStats.get.rightImpurity}")
      println(s"leftPredict ${aStats.get.leftPredict} ${bStats.get.leftPredict}")
      println(s"rightPredict ${aStats.get.rightPredict} ${bStats.get.rightPredict}")
      println(s"gain ${aStats.get.gain} ${bStats.get.gain}")
      

      The test output then shows that only the `gain` values are equal:

      id 1 1
      impurity 0.2 0.5
      leftImpurity 0.3 0.5
      rightImpurity 0.4 0.5
      leftPredict 1.0 (prob = 0.4) 0.0 (prob = 1.0)
      rightPredict 0.0 (prob = 0.6) 0.0 (prob = 1.0)
      gain 0.1 0.1
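Because `stats` is an `Option`, calling `.get` as the debug snippet does would throw on a node without stats. A minimal Option-safe sketch of the recursive comparison follows, assuming a simplified, hypothetical `Node`/`Stats` pair rather than Spark's mllib `Node` (the `Stats` type is trimmed to two fields for brevity, and `TreeDiff.differingNodes` is an illustrative name). It walks both trees in lockstep and collects the ids of nodes whose whole stats value differs, not just `gain`, or whose children do not line up.

```scala
// Hypothetical, simplified node statistics (trimmed to two fields for brevity).
final case class Stats(gain: Double, impurity: Double)

// Hypothetical, simplified tree node; stats is an Option, as in the report.
final case class Node(
    id: Int,
    stats: Option[Stats],
    left: Option[Node],
    right: Option[Node])

object TreeDiff {
  // Walks both trees in lockstep and returns the ids of nodes whose stats
  // differ, or whose children do not line up structurally.
  def differingNodes(a: Node, b: Node): Seq[Int] = {
    // Case-class equality compares every field; None == None holds for nodes without stats.
    val here: Seq[Int] = if (a.stats != b.stats) Seq(a.id) else Seq.empty

    def children(x: Option[Node], y: Option[Node]): Seq[Int] = (x, y) match {
      case (Some(l), Some(r)) => differingNodes(l, r)
      case (None, None)       => Seq.empty
      case _                  => Seq(a.id) // one side has a child, the other does not
    }

    (here ++ children(a.left, b.left) ++ children(a.right, b.right)).distinct
  }
}
```

Compared with the debug `println`s, this reports where the trees diverge without crashing on nodes that carry no stats; with the node 1 values above and identical leaves, it would report node 1 only.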
      

    People

    • Assignee: eugenzyx Ievgen Prokhorenko
    • Reporter: eugenzyx Ievgen Prokhorenko