简介:本文深入解析决策树算法原理、核心机制及工程实践,涵盖ID3/C4.5/CART算法对比、剪枝策略优化、特征工程技巧,并通过泰坦尼克号生存预测、医疗诊断决策系统两个完整案例,提供可复用的代码实现与调优方法。
决策树是一种基于树结构进行决策的监督学习算法,其核心是通过递归划分特征空间构建树形结构。每个内部节点代表一个特征上的测试,每个分支代表测试输出,每个叶节点代表类别或值。从信息论视角看,决策树的构建本质是寻找最优特征划分,使得子节点纯度最大化。
数学上,决策树的构建可形式化为:给定训练集D={(x₁,y₁),…,(xₙ,yₙ)},特征集A={a₁,…,aₘ},算法递归选择最优特征a*∈A进行划分,使得划分后的子集Dᵥ(v为分支)的信息增益/增益率/基尼指数最优。这种递归划分直到满足停止条件(如最大深度、最小样本数等)。
决策树易过拟合,需通过剪枝控制复杂度。预剪枝在构建时提前停止(如设置最大深度、最小样本分裂数),但可能欠拟合。后剪枝(如代价复杂度剪枝)先生成完整树,再自底向上剪枝:对于每个非叶节点,计算剪枝前后的误差,若剪枝后误差增加不超过阈值则剪枝。CART算法的剪枝公式为:Cα(T)=∑(yᵢ-ŷᵀ)²+α|T|,其中α为正则化参数,|T|为叶节点数。
连续特征需离散化,常用方法包括等宽分箱、等频分箱、基于聚类的分箱。例如,年龄特征可分箱为[0,18)、[18,35)、[35,60)、[60,+∞)。分类特征若取值过多(如城市名),可合并低频类别或使用目标编码(用类别对应的目标均值替换)。
特征重要性评估可通过信息增益/基尼指数贡献,或使用permutation importance(打乱特征值后模型性能下降程度)。Scikit-learn的DecisionTreeClassifier提供了featureimportances属性直接获取。
关键参数包括:
评估指标需根据任务选择:分类任务用准确率、F1、AUC;回归任务用MSE、MAE。交叉验证(如5折)可更稳健评估模型性能。
import pandas as pdfrom sklearn.tree import DecisionTreeClassifierfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import accuracy_score# 加载数据data = pd.read_csv('titanic.csv')# 特征工程data['Age'].fillna(data['Age'].median(), inplace=True) # 填充缺失值data['FamilySize'] = data['SibSp'] + data['Parch'] + 1 # 家庭规模data['IsAlone'] = (data['FamilySize'] == 1).astype(int) # 是否独自data['FareBin'] = pd.qcut(data['Fare'], 4, labels=False) # 票价分箱# 选择特征features = ['Pclass', 'Sex', 'Age', 'FareBin', 'IsAlone']X = pd.get_dummies(data[features], columns=['Sex']) # 独热编码y = data['Survived']
# 划分数据集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 初始模型dt = DecisionTreeClassifier(random_state=42)dt.fit(X_train, y_train)print("初始准确率:", accuracy_score(y_test, dt.predict(X_test)))# 参数调优from sklearn.model_selection import GridSearchCVparam_grid = {'max_depth': [3, 5, 7, 10],'min_samples_split': [2, 5, 10],'min_samples_leaf': [1, 2, 4]}grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)grid_search.fit(X_train, y_train)best_dt = grid_search.best_estimator_print("调优后准确率:", accuracy_score(y_test, best_dt.predict(X_test)))print("最佳参数:", grid_search.best_params_)
from sklearn.tree import export_graphvizimport graphviz# 导出决策树dot_data = export_graphviz(best_dt, out_file=None,feature_names=X.columns,class_names=['Died', 'Survived'],filled=True, rounded=True,special_characters=True)graph = graphviz.Source(dot_data)graph.render("titanic_decision_tree") # 生成PDF文件
可视化可直观看到分裂规则,如首节点可能以“Sex_male=0”(女性)为分裂特征,说明性别是生存的最强预测因子。
医疗诊断需根据症状、检查结果预测疾病。数据可能包含连续指标(如血压、血糖)和分类指标(如疼痛类型)。需处理缺失值(如部分检查未做)和类别不平衡(如罕见病样本少)。
# 模拟医疗数据import numpy as npnp.random.seed(42)n_samples = 1000data = {'Age': np.random.randint(18, 80, n_samples),'BP_Sys': np.random.normal(120, 15, n_samples).clip(90, 180),'BP_Dia': np.random.normal(80, 10, n_samples).clip(60, 110),'Glucose': np.random.normal(100, 20, n_samples).clip(70, 200),'ChestPain': np.random.choice(['Typical', 'Atypical', 'Non-anginal', 'None'], n_samples),'ECG': np.random.choice(['Normal', 'ST-T', 'LVH'], n_samples),'Diagnosis': np.random.choice(['Healthy', 'Hypertension', 'Diabetes', 'CAD'], n_samples, p=[0.6, 0.2, 0.15, 0.05])}df = pd.DataFrame(data)# 特征工程df['BP_Ratio'] = df['BP_Dia'] / df['BP_Sys'] # 血压比df['Glucose_High'] = (df['Glucose'] > 125).astype(int) # 血糖是否高features = ['Age', 'BP_Sys', 'Glucose_High', 'ChestPain', 'ECG']X = pd.get_dummies(df[features], columns=['ChestPain', 'ECG'])y = df['Diagnosis']
from sklearn.preprocessing import LabelEncoderle = LabelEncoder()y_encoded = le.fit_transform(y)# 训练多分类决策树dt_multi = DecisionTreeClassifier(criterion='gini', max_depth=5, random_state=42)dt_multi.fit(X, y_encoded)# 预测新样本new_sample = pd.DataFrame({'Age': [50],'BP_Sys': [140],'Glucose_High': [1],'ChestPain': ['Typical'],'ECG': ['ST-T']})new_sample = pd.get_dummies(new_sample, columns=['ChestPain', 'ECG'])# 补全缺失列(与训练数据一致)for col in X.columns:if col not in new_sample.columns:new_sample[col] = 0pred = dt_multi.predict(new_sample)print("预测疾病:", le.inverse_transform(pred)[0])
决策树的可解释性在医疗领域至关重要。可通过feature_importances_分析关键特征:
importance = pd.DataFrame({'Feature': X.columns,'Importance': dt_multi.feature_importances_}).sort_values('Importance', ascending=False)print(importance)
若“Glucose_High”重要性最高,可建议医生重点关注血糖指标。业务落地时,可将决策树规则提取为IF-THEN规则库,集成到电子病历系统中实现实时诊断辅助。
joblib.dump(dt, 'model.pkl'))决策树因其可解释性强、训练速度快、无需特征缩放等优点,成为机器学习入门的首选算法。通过合理调参和特征工程,可在保持可解释性的同时提升模型性能。实际业务中,建议从简单决策树开始,逐步尝试集成方法,平衡模型复杂度与可维护性。