
Source Code Analysis of the Decision Tree in Spark

2025-01-16 update. From: SLTechnology News&Howtos, Shulou (Shulou.com), 06/03 report.

1. Example

The following example uses the decision tree classifier API of Spark MLlib from Python (PySpark). It trains a decision tree model, evaluates it on a test set, and then saves and reloads it.

```python
"""
DecisionTreeClassification Example.
"""
from __future__ import print_function

from pyspark import SparkContext
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

if __name__ == "__main__":
    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")

    # Load and parse the LIBSVM data file as an RDD of LabeledPoint
    dataPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/sample_libsvm_data.txt"
    print(dataPath)
    data = MLUtils.loadLibSVMFile(sc, dataPath)

    # Split the dataset into a training set (70%) and a test set (30%)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])
    print("train data count: " + str(trainingData.count()))
    print("test data count: " + str(testData.count()))

    # Train a decision tree classifier.
    # An empty categoricalFeaturesInfo means that all features are continuous.
    model = DecisionTree.trainClassifier(trainingData, numClasses=2,
                                         categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)

    # Predict on the test dataset
    predictions = model.predict(testData.map(lambda x: x.features))
    # Pair each true label with its predicted value
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # Fraction of misclassified test samples (tuple unpacking in lambdas is not valid in Python 3)
    testErr = labelsAndPredictions.filter(
        lambda vp: vp[0] != vp[1]).count() / float(testData.count())
    print('DecisionTree Test Error = %5.3f%%' % (testErr * 100))
    print("DecisionTree Learned classification tree model:")
    print(model.toDebugString())

    # Save and reload the trained model
    modelPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/myDecisionTreeClassificationModel"
    model.save(sc, modelPath)
    sameModel = DecisionTreeModel.load(sc, modelPath)
```

2. Source code analysis of decision tree

The entry point of the decision tree classifier API is DecisionTree.trainClassifier, so the source code analysis starts there.

The source file is located at spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala.

```scala
@Since("1.1.0")
def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    impurity: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {
  val impurityType = Impurities.fromString(impurity)
  train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
    categoricalFeaturesInfo)
}
```

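trainClassifier converts the impurity name into an Impurity object with Impurities.fromString. It then calls the train method, passing Classification as the algorithm and Sort as the quantile calculation strategy.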
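As a rough illustration of what the impurity string selects, the minimal Python sketch below computes the two classification measures MLlib accepts, Gini and entropy (base-2 logarithm). Both are computed from per-class label counts. The helper names `gini`, `entropy` and `impurity_from_string` are hypothetical. They only loosely mimic the idea of Impurities.fromString and are not Spark APIs.

```python
import math


def gini(counts):
    """Gini impurity 1 - sum(p_i^2), computed from per-class label counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)


def entropy(counts):
    """Entropy -sum(p_i * log2(p_i)), computed from per-class label counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)


# Hypothetical lookup table standing in for Impurities.fromString (classification only)
_IMPURITIES = {"gini": gini, "entropy": entropy}


def impurity_from_string(name):
    """Map an impurity name to its function, rejecting unknown names."""
    if name not in _IMPURITIES:
        raise ValueError("Did not recognize impurity: %s" % name)
    return _IMPURITIES[name]


if __name__ == "__main__":
    print(impurity_from_string("gini")([50, 50]))  # 0.5
    print(entropy([50, 50]))                       # 1.0
```

A perfectly mixed two-class node therefore has Gini impurity 0.5 and entropy 1.0. A pure node has impurity 0 under both measures.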

```scala
@Since("1.0.0")
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClasses: Int,
    maxBins: Int,
    quantileCalculationStrategy: QuantileStrategy,
    categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
  val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
    quantileCalculationStrategy, categoricalFeaturesInfo)
  new DecisionTree(strategy).run(input)
}
```

The train method first wraps its parameters into a Strategy object. These are the algorithm type (classification or regression), the impurity measure, the maximum tree depth, the number of classes, the maximum number of bins, the quantile calculation strategy, and the categorical feature information. It then creates a new DecisionTree object and calls its run method.
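To make the parameter bundle concrete, the Python dataclass below mirrors the fields that train passes to the Strategy constructor. It is a hypothetical stand-in, not Spark's Strategy class. The checks in `assert_valid` are simplified assumptions inspired by `strategy.assertValid()` and are not a copy of its actual rules.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Strategy:
    """Hypothetical mirror of the parameters train() bundles into a Strategy."""
    algo: str                       # "Classification" or "Regression"
    impurity: str                   # e.g. "gini", "entropy", "variance"
    max_depth: int
    num_classes: int = 2
    max_bins: int = 32
    quantile_strategy: str = "Sort"
    # feature index -> number of categories; an empty dict means all features are continuous
    categorical_features_info: Dict[int, int] = field(default_factory=dict)

    def assert_valid(self):
        """Simplified sanity checks (assumptions, not Spark's exact rules)."""
        if self.algo == "Classification":
            if self.num_classes < 2:
                raise ValueError("classification needs num_classes >= 2")
            if self.impurity not in ("gini", "entropy"):
                raise ValueError("invalid impurity for classification: " + self.impurity)
        elif self.algo == "Regression":
            if self.impurity != "variance":
                raise ValueError("invalid impurity for regression: " + self.impurity)
        else:
            raise ValueError("unknown algo: " + self.algo)
        if self.max_depth < 0:
            raise ValueError("max_depth must be >= 0")
        if self.max_bins < 2:
            raise ValueError("max_bins must be >= 2")


# Roughly the configuration built for the Python example above
strategy = Strategy("Classification", "gini", max_depth=5, num_classes=2, max_bins=32)
strategy.assert_valid()
```

Bundling the settings into one object lets later stages read a single configuration instead of receiving a long parameter list.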

```scala
@Since("1.0.0")
class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
  extends Serializable with Logging {

  /**
   * @param strategy The configuration parameters for the tree algorithm which specify the type
   *                 of decision tree (classification or regression), feature type (continuous,
   *                 categorical), depth of the tree, quantile calculation strategy, etc.
   */
  @Since("1.0.0")
  def this(strategy: Strategy) = this(strategy, seed = 0)

  strategy.assertValid()

  /**
   * Method to train a decision tree model over an RDD
   *
   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return DecisionTreeModel that can be used for prediction.
   */
  @Since("1.2.0")
  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
    val rfModel = rf.run(input)
    rfModel.trees(0)
  }
}
```

The run method creates a new RandomForest object with the given strategy, the number of trees set to 1, and the feature subset strategy set to "all". It then calls RandomForest's run method and returns the first tree of the resulting random forest model.

In other words, a decision tree is trained as a random forest containing a single tree that considers all features at every node. That single tree is returned as the decision tree model.
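The consequences of `numTrees = 1` and `featureSubsetStrategy = "all"` can be sketched in Python. The two functions below are hypothetical re-implementations, not Spark code. They approximate how the number of features examined per node depends on the subset strategy, and how the bagging mode is chosen. Two behaviours of Spark's tree implementation are assumed: `"auto"` resolves to `"all"` for a single tree, and the caller passes `withReplacement = numTrees > 1`. Treat them as an approximation.

```python
import math


def num_features_per_node(strategy, num_trees, num_features, algo="Classification"):
    """Hypothetical sketch: number of features considered at each node."""
    if strategy == "auto":
        # A single tree uses every feature; a forest subsamples features
        if num_trees == 1:
            strategy = "all"
        elif algo == "Classification":
            strategy = "sqrt"
        else:
            strategy = "onethird"
    if strategy == "all":
        return num_features
    if strategy == "sqrt":
        return math.ceil(math.sqrt(num_features))
    if strategy == "log2":
        return max(1, math.ceil(math.log2(num_features)))
    if strategy == "onethird":
        return math.ceil(num_features / 3.0)
    raise ValueError("unsupported strategy: " + strategy)


def bagging_mode(num_subsamples, subsampling_rate, with_replacement):
    """Hypothetical sketch of which BaggedPoint conversion would be chosen."""
    if with_replacement:
        return "SamplingWithReplacement"
    if num_subsamples == 1 and subsampling_rate == 1.0:
        return "WithoutSampling"
    return "SamplingWithoutReplacement"


if __name__ == "__main__":
    num_trees = 1  # what DecisionTree.run passes to RandomForest
    print(num_features_per_node("all", num_trees, 100))              # 100
    print(bagging_mode(num_trees, 1.0, with_replacement=num_trees > 1))  # WithoutSampling
```

Under these assumptions, a single tree with the default subsampling rate of 1.0 sees every training point exactly once and every feature at every node. That is why the forest-of-one behaves like a plain decision tree.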

3. Random forest source code analysis

The random forest source file is located at spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala.

```scala
private class RandomForest(
    private val strategy: Strategy,
    private val numTrees: Int,
    featureSubsetStrategy: String,
    private val seed: Int)
  extends Serializable with Logging {

  strategy.assertValid()
  require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
  require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
    || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
    || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ …
```

```
--> 1. val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)  # build metadata for the input data
--> 2. val splits = findSplits(retaggedInput, metadata, seed)  # compute candidate splits for the features in the metadata
    --> 2.1 compute the sampling rate and sample the input data
    --> 2.2 findSplitsBySorting(sampledInput, metadata, continuousFeatures)  # find splits for the features of the sampled data
        --> 2.2.1 val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)  # continuous features
        --> 2.2.2 val categories = extractMultiClassCategories(splitIndex + 1, featureArity)  # unordered categorical features
        --> 2.2.3 Array.empty[Split]  # ordered categorical features: splits are constructed on the fly during training
--> 3. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)  # convert the input data into tree data
    --> 3.1 input.map { x => TreePoint.labeledPointToTreePoint(x, thresholds, featureArity) }  # convert LabeledPoint into TreePoint
    --> 3.2 arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))  # find the bin index (discretized value) of each feature of the LabeledPoint
--> 4. val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)  # sample (bag) the input data
    --> 4.1 convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)  # sampling with replacement
    --> 4.2 convertToBaggedRDDWithoutSampling(input)  # one subsample with a sampling rate of 100%
    --> 4.3 convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)  # sampling without replacement
--> 5. val (nodesForGroup, treeToNodeToIndexInfo) = RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)  # select the nodes of each tree to split next
    --> 5.1 val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) } else { None }  # select a feature subset if feature subsampling is needed
    --> 5.2 val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L  # check whether enough memory remains after adding this node
--> 6. RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)  # find the best splits
    --> 6.1 val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))  # find the best split for each node
--> 7. new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, strategy.getNumClasses)  # return the decision tree classification model
```
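Steps 2, 3 and 6 above can be illustrated with a small single-machine Python sketch. These steps choose candidate thresholds for a continuous feature, map each value to a bin, and pick the best split from per-bin label statistics. The sketch makes several simplifying assumptions:

- Thresholds are midpoints between evenly spaced distinct sorted values. This is a simple heuristic, not Spark's actual findSplitsForContinuousFeature logic.
- Only one continuous feature and Gini impurity are handled.
- The aggregation runs over a Python list, whereas Spark's findBestSplits aggregates over the bagged RDD.

All function names are hypothetical.

```python
import bisect
from collections import Counter


def find_thresholds(values, max_bins):
    """Candidate thresholds: at most max_bins - 1 cut points between sorted distinct values."""
    distinct = sorted(set(values))
    if len(distinct) <= 1:
        return []
    num_splits = min(max_bins - 1, len(distinct) - 1)
    step = (len(distinct) - 1) / num_splits  # >= 1, so gap indices stay in range
    thresholds = []
    for k in range(1, num_splits + 1):
        i = int(round(k * step)) - 1  # index of the left value of the chosen gap
        thresholds.append((distinct[i] + distinct[i + 1]) / 2.0)
    return sorted(set(thresholds))


def find_bin(value, thresholds):
    """Bin index = number of thresholds strictly below value (value <= t goes left)."""
    return bisect.bisect_left(thresholds, value)


def gini(counts, total):
    """Gini impurity of a Counter of label counts."""
    return 1.0 - sum((c / total) ** 2 for c in counts.values()) if total else 0.0


def best_split(values, labels, max_bins):
    """Aggregate label counts per bin, then scan cut points for the largest Gini gain."""
    thresholds = find_thresholds(values, max_bins)
    bin_stats = [Counter() for _ in range(len(thresholds) + 1)]
    for v, y in zip(values, labels):
        bin_stats[find_bin(v, thresholds)][y] += 1
    parent, n = Counter(labels), len(labels)
    parent_impurity = gini(parent, n)
    best = (None, 0.0)  # (threshold, gain)
    left = Counter()
    for split_idx, t in enumerate(thresholds):
        left += bin_stats[split_idx]  # bins 0..split_idx go to the left child
        right = parent - left
        n_left, n_right = sum(left.values()), sum(right.values())
        if n_left == 0 or n_right == 0:
            continue
        gain = (parent_impurity
                - (n_left / n) * gini(left, n_left)
                - (n_right / n) * gini(right, n_right))
        if gain > best[1]:
            best = (t, gain)
    return best


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
    ys = [0, 0, 0, 1, 1, 1]
    print(best_split(xs, ys, max_bins=32))  # (6.5, 0.5)
```

Binning first means the split search only scans `max_bins - 1` cut points per feature instead of every raw value. This is the motivation behind the maxBins parameter.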
