决策树算法的主要步骤是先进行特征选择,根据这个特征对数据集进行分割,将其划分为两个子集,以递归的方式重复上述操作,直到满足某个停止条件。在生成树的过程中和结束后对决策树进行适当的剪枝操作以提高其泛化性。最后用生成好的决策树模型进行预测。
决策树的构建始于选择最佳的特征来分割数据。这通常是通过计算每个特征的信息增益(或其他类似的度量)来完成的。详见决策树的构造。
一旦我们选择了最佳的特征,我们就根据该特征的每个可能的取值来分割数据。每个取值都会生成一个子节点。然后,我们在每个子节点上重复特征选择和分割的过程,直到满足某个停止条件。
我们会继续在每个子节点上重复特征选择和分割的过程,直到满足某个停止条件。常见的停止条件有:
当满足停止条件时,我们将当前节点标记为叶节点,叶节点代表一个类别。
通过以上步骤,我们就生成了一个决策树。决策树的每个内部节点代表一个特征,每个分支代表一个可能的取值,每个叶节点代表一个类别。
当我们得到一个新的样本时,我们可以使用决策树进行分类。从根节点开始,根据样本在每个节点的特征取值来决定走哪个分支,直到达到一个叶节点。叶节点的类别就是我们对样本的预测结果。
1from sklearn.datasets import load_iris
2from sklearn.model_selection import train_test_split
3from sklearn.tree import DecisionTreeClassifier
4from sklearn import tree
5import matplotlib.pyplot as plt
6
7# 加载鸢尾花数据集
8iris = load_iris()
9X = iris.data
10y = iris.target
11
12# 划分训练集和测试集
13X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
14
15# 创建决策树分类器并训练
16clf = DecisionTreeClassifier(max_depth=3)
17clf.fit(X_train, y_train)
18
19# 预测测试集
20y_pred = clf.predict(X_test)
21
22# 可视化决策树
23fig, ax = plt.subplots(figsize=(12, 12))
24tree.plot_tree(clf, filled=True)
25plt.show()
1from sklearn.datasets import load_iris
2from sklearn.model_selection import train_test_split
3from sklearn.tree import DecisionTreeClassifier
4from sklearn import tree
5import matplotlib.pyplot as plt
6
7# 加载鸢尾花数据集
8iris = load_iris()
9X = iris.data
10y = iris.target
11
12# 划分训练集和测试集
13X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
14
15# 创建决策树分类器并训练
16clf = DecisionTreeClassifier(max_depth=3)
17clf.fit(X_train, y_train)
18
19# 预测测试集
20y_pred = clf.predict(X_test)
21
22# 可视化决策树
23fig, ax = plt.subplots(figsize=(12, 12))
24tree.plot_tree(clf, filled=True)
25plt.show()
这段代码首先加载了鸢尾花数据集,然后将数据集划分为训练集和测试集。接着,创建一个决策树分类器,并使用训练集对其进行训练。然后,使用训练好的决策树分类器对测试集进行预测。最后,将训练好的决策树进行可视化展示。
plot_tree
函数生成的是决策树的图形表示。决策树是一种树形结构,其中每个内部节点表示一个特征(或属性),每个分支代表一个决策规则,每个叶节点代表一个输出(或结果)。
在plot_tree
生成的图中:
每个节点都显示了一个判断条件,用来决定下一步的方向。例如,X[3] <= 0.8
表示如果第4个特征(索引从0开始)小于或等于0.8,那么就走左边的分支,否则走右边的分支。
gini
表示该节点的不纯度,如果所有的样本都属于同一个类别,那么gini系数就是0,表示纯度最高。
samples
表示进入这个节点的样本数量。
value
表示这个节点中每个类别的样本数量。
class
表示这个节点最终的预测结果,它是基于这个节点中样本数量最多的类别来确定的。
每个节点下面的分支代表满足该节点条件后的下一步决策。
叶子节点(没有子节点的节点)表示决策结果,即预测的类别。
运行结果: