Skip to content
Snippets Groups Projects
Commit bb68f477 authored by Sandeep's avatar Sandeep Committed by Reynold Xin
Browse files

[Fix #79] Replace Breakable For Loops By While Loops

Author: Sandeep <sandeep@techaddict.me>

Closes #503 from techaddict/fix-79 and squashes the following commits:

e3f6746 [Sandeep] Style changes
07a4f6b [Sandeep] for loop to While loop
0a6d8e9 [Sandeep] Breakable for loop to While loop
parent 6ab75780
No related branches found
No related tags found
No related merge requests found
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package org.apache.spark.mllib.tree package org.apache.spark.mllib.tree
import scala.util.control.Breaks._
import org.apache.spark.annotation.Experimental import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext} import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._ import org.apache.spark.SparkContext._
...@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo ...@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes. * still survived the filters of the parent nodes.
*/ */
// TODO: Convert for loop to while loop var level = 0
breakable { var break = false
for (level <- 0 until maxDepth) { while (level < maxDepth && !break) {
logDebug("#####################################") logDebug("#####################################")
logDebug("level = " + level) logDebug("level = " + level)
logDebug("#####################################") logDebug("#####################################")
// Find best split for all nodes at a level. // Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins) level, filters, splits, bins)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level. // Extract info for nodes at the current level.
extractNodeInfo(nodeSplitStats, level, index, nodes) extractNodeInfo(nodeSplitStats, level, index, nodes)
// Extract info for nodes at the next lower level. // Extract info for nodes at the next lower level.
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
filters) filters)
logDebug("final best split = " + nodeSplitStats._1) logDebug("final best split = " + nodeSplitStats._1)
} }
require(scala.math.pow(2, level) == splitsStatsForLevel.length) require(scala.math.pow(2, level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level at leaves. // Check whether all the nodes at the current level at leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf) logDebug("all leaf = " + allLeaf)
if (allLeaf) break // no more tree construction if (allLeaf) {
break = true // no more tree construction
} else {
level += 1
} }
} }
...@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo ...@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities: Array[Double], parentImpurities: Array[Double],
filters: Array[List[Filter]]): Unit = { filters: Array[List[Filter]]): Unit = {
// 0 corresponds to the left child node and 1 corresponds to the right child node. // 0 corresponds to the left child node and 1 corresponds to the right child node.
// TODO: Convert to while loop var i = 0
for (i <- 0 to 1) { while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level. // Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
if (level < maxDepth - 1) { if (level < maxDepth - 1) {
...@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo ...@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("Filter = " + filter) logDebug("Filter = " + filter)
} }
} }
i += 1
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment