Decision Tree Regression (using DecisionTree.jl)

Adapted from http://scikit-learn.org/stable/auto_examples/tree/plot_tree_regression.html

A 1D regression with a decision tree.

The decision tree is used to fit a sine curve with additional noisy observations. As a result, it learns local linear regressions approximating the sine curve.

We can see that if the maximum depth of the tree (controlled by the max_depth parameter) is set too high, the decision tree learns the fine details of the training data, i.e. it fits the noise and overfits.

using DecisionTree
using ScikitLearn
using PyPlot
using Random

# Create a random dataset: noisy samples of a sine curve
Random.seed!(42)                        # srand was removed in Julia 1.0
X = sort(5 * rand(80))
XX = reshape(X, 80, 1)                  # fit! expects a feature matrix
y = sin.(X)                             # element-wise sin needs the broadcast dot
y[1:5:end] += 3 .* (0.5 .- rand(16))    # perturb every fifth observation

# Fit regression model
regr_1 = DecisionTreeRegressor()
regr_2 = DecisionTreeRegressor(pruning_purity_threshold=0.05)
regr_3 = RandomForestRegressor(n_trees=20)   # keyword renamed from ntrees in newer DecisionTree.jl
fit!(regr_1, XX, y)
fit!(regr_2, XX, y)
fit!(regr_3, XX, y)

# Predict on a fine grid
X_test = 0:0.01:5.0
y_1 = predict(regr_1, hcat(X_test))   # hcat turns the range into a 501x1 matrix
y_2 = predict(regr_2, hcat(X_test))
y_3 = predict(regr_3, hcat(X_test))

# Plot the results
scatter(X, y, c="k", label="data")
plot(X_test, y_1, c="g", label="no pruning", linewidth=2)
plot(X_test, y_2, c="r", label="pruning_purity_threshold=0.05", linewidth=2)
plot(X_test, y_3, c="b", label="RandomForestRegressor", linewidth=2)
xlabel("data")
ylabel("target")
title("Decision Tree Regression")
legend(prop=Dict("size"=>10));
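The overfitting effect of max_depth described above can also be shown directly, since DecisionTree.jl's DecisionTreeRegressor accepts a max_depth keyword. The sketch below (names mse, shallow, and deep are illustrative, not from the original post) compares training error for a shallow and a deep tree on the same kind of noisy sine data: the deep tree's lower training error comes from chasing the noise.

```julia
using DecisionTree, Random

Random.seed!(42)
X = sort(5 * rand(80))
XX = reshape(X, 80, 1)
y = sin.(X) .+ 0.1 .* randn(80)

# A shallow tree underfits; a deep tree has enough leaves to fit the noise.
shallow = DecisionTreeRegressor(max_depth=2)
deep    = DecisionTreeRegressor(max_depth=12)
fit!(shallow, XX, y)
fit!(deep, XX, y)

# Mean squared error on the training data itself
mse(m) = sum((predict(m, XX) .- y) .^ 2) / length(y)
println("depth 2  training MSE: ", mse(shallow))
println("depth 12 training MSE: ", mse(deep))
```

The deep tree's near-zero training error is exactly the symptom the text warns about: low training error paired with a jagged prediction curve that would generalize poorly to new points.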


Author: Monad Kai
Link: onlookerliu.github.io/2017/12/29/Decision-Tree-Regression/
Copyright: Unless otherwise noted, all posts on this blog are licensed under CC BY-NC-SA 4.0. Please credit Code@浮生记 when reposting.