Description
Line 628 of `mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala` contains this TODO:
// TODO: Check other fields besides the information gain.
If I add a check of another field (for instance, impurity) alongside the existing check of the gain value in InformationGainStats, the test fails because that field's value differs between the saved model and the one restored from disk.
See the PR for an example.
The tests are executed with this command:
build/mvn -e -Dtest=none -DwildcardSuites=org.apache.spark.mllib.tree.DecisionTreeSuite test
Excerpts from the output of the command above:
...
- model save/load *** FAILED ***
  checkEqual failed since the two trees were not identical.
  TREE A:
  DecisionTreeModel classifier of depth 2 with 5 nodes
    If (feature 0 <= 0.5)
     Predict: 0.0
    Else (feature 0 > 0.5)
     If (feature 1 in {0.0,1.0})
      Predict: 0.0
     Else (feature 1 not in {0.0,1.0})
      Predict: 0.0
  TREE B:
  DecisionTreeModel classifier of depth 2 with 5 nodes
    If (feature 0 <= 0.5)
     Predict: 0.0
    Else (feature 0 > 0.5)
     If (feature 1 in {0.0,1.0})
      Predict: 0.0
     Else (feature 1 not in {0.0,1.0})
      Predict: 0.0
  (DecisionTreeSuite.scala:610)
...
Suppose I add some debug output to `DecisionTreeSuite.checkEqual`:
val aStats = a.stats
val bStats = b.stats
println(s"id ${a.id} ${b.id}")
println(s"impurity ${aStats.get.impurity} ${bStats.get.impurity}")
println(s"leftImpurity ${aStats.get.leftImpurity} ${bStats.get.leftImpurity}")
println(s"rightImpurity ${aStats.get.rightImpurity} ${bStats.get.rightImpurity}")
println(s"leftPredict ${aStats.get.leftPredict} ${bStats.get.leftPredict}")
println(s"rightPredict ${aStats.get.rightPredict} ${bStats.get.rightPredict}")
println(s"gain ${aStats.get.gain} ${bStats.get.gain}")
The test output then shows that only the `gain` values are equal:
id 1 1
impurity 0.2 0.5
leftImpurity 0.3 0.5
rightImpurity 0.4 0.5
leftPredict 1.0 (prob = 0.4) 0.0 (prob = 1.0)
rightPredict 0.0 (prob = 0.6) 0.0 (prob = 1.0)
gain 0.1 0.1
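To make the mismatch easier to see, here is a minimal, self-contained Scala sketch that compares every field rather than only `gain`. It does not use Spark: `GainStats`, `Pred`, and `StatsCompare.differingFields` are hypothetical stand-ins whose field names mirror the debug output above, and the sample values are the node 1 values printed for tree A and tree B.

```scala
// Hypothetical stand-in for a Predict value: "1.0 (prob = 0.4)" becomes Pred(1.0, 0.4).
final case class Pred(predict: Double, prob: Double)

// Hypothetical stand-in mirroring the InformationGainStats fields printed above.
final case class GainStats(
    gain: Double,
    impurity: Double,
    leftImpurity: Double,
    rightImpurity: Double,
    leftPredict: Pred,
    rightPredict: Pred)

object StatsCompare {
  // Node 1 values from the debug output (tree A and tree B respectively).
  val nodeA = GainStats(0.1, 0.2, 0.3, 0.4, Pred(1.0, 0.4), Pred(0.0, 0.6))
  val nodeB = GainStats(0.1, 0.5, 0.5, 0.5, Pred(0.0, 1.0), Pred(0.0, 1.0))

  // Returns the names of the fields whose values differ between a and b.
  def differingFields(a: GainStats, b: GainStats): Seq[String] = {
    val checks = Seq(
      "gain" -> (a.gain == b.gain),
      "impurity" -> (a.impurity == b.impurity),
      "leftImpurity" -> (a.leftImpurity == b.leftImpurity),
      "rightImpurity" -> (a.rightImpurity == b.rightImpurity),
      "leftPredict" -> (a.leftPredict == b.leftPredict),
      "rightPredict" -> (a.rightPredict == b.rightPredict))
    checks.collect { case (name, false) => name }
  }

  def main(args: Array[String]): Unit = {
    // A gain-only check, like the current test, sees no difference.
    println(s"gain equal: ${nodeA.gain == nodeB.gain}")
    // A full comparison lists every other field as different.
    println(s"differing: ${differingFields(nodeA, nodeB).mkString(", ")}")
  }
}
```

In the real suite the same idea would mean comparing each of these fields of `a.stats.get` and `b.stats.get` inside `checkEqual`; the sketch only illustrates why a gain-only check hides the mismatch and is not a proposed fix.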