简介:通过详细解释和代码实现,带你重新理解统计学习方法中的 K 近邻法。从理论到实践,全面掌握这一经典机器学习算法。
在《统计学习方法》一书的第3章中,作者详细介绍了 K 近邻法(K-Nearest Neighbor,KNN)这一经典机器学习算法。KNN 是一种基于实例的学习,它不需要明确的训练阶段和测试阶段。在分类问题中,KNN 通过计算待分类样本与已知类别样本之间的距离,选择距离最近的 K 个样本,根据这 K 个样本的类别进行投票,以多数投票结果作为待分类样本的类别。
一、KNN 算法的基本思想
KNN 算法的基本思想是:在特征空间中,如果一个样本的大部分近邻都属于某个类别,则该样本也属于这个类别。算法的核心在于计算待分类样本与已知类别样本之间的距离或相似度。常用的距离度量方式有欧氏距离、曼哈顿距离等。
二、KNN 算法的优缺点
优点:
缺点:
三、KNN 算法的实践应用
下面是一个简单的 Python 代码实现,用于演示 KNN 算法在分类问题中的应用:
import numpy as npfrom sklearn.neighbors import KNeighborsClassifierfrom sklearn.datasets import make_classificationfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_report, confusion_matrix# 生成模拟数据集X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建 KNN 分类器对象,指定邻居数量为 3knn = KNeighborsClassifier(n_neighbors=3)# 使用训练数据训练分类器对象knn.fit(X_train, y_train)# 使用测试数据进行预测y_pred = knn.predict(X_test)# 输出分类结果评估报告和混淆矩阵print(classification_report(y_test, y_pred))print(confusion_matrix(y_test, y_pred))
在上面的代码中,我们首先使用 make_classification 函数生成一个模拟数据集,然后将其划分为训练集和测试集。接下来,我们创建一个 KNeighborsClassifier 对象,并指定邻居数量为 3。然后使用训练数据训练分类器对象,并使用测试数据进行预测。最后,我们输出分类结果评估报告和混淆矩阵,以评估分类器的性能。在实际应用中,我们还可以通过调整参数、交叉验证等方法来优化分类器的性能。