重温经典:统计学习方法第3章 K近邻法

作者:谁偷走了我的奶酪2024.02.16 22:39浏览量:4

简介:通过详细解释和代码实现,带你重新理解统计学习方法中的 K 近邻法。从理论到实践,全面掌握这一经典机器学习算法。

在《统计学习方法》一书的第3章中,作者详细介绍了 K 近邻法(K-Nearest Neighbor,KNN)这一经典机器学习算法。KNN 是一种基于实例的学习,它不需要明确的训练阶段和测试阶段。在分类问题中,KNN 通过计算待分类样本与已知类别样本之间的距离,选择距离最近的 K 个样本,根据这 K 个样本的类别进行投票,以多数投票结果作为待分类样本的类别。

一、KNN 算法的基本思想

KNN 算法的基本思想是:在特征空间中,如果一个样本的大部分近邻都属于某个类别,则该样本也属于这个类别。算法的核心在于计算待分类样本与已知类别样本之间的距离或相似度。常用的距离度量方式有欧氏距离、曼哈顿距离等。

二、KNN 算法的优缺点

优点:

  1. 简单易懂,易于实现。
  2. 对异常值和噪声具有较强的鲁棒性。
  3. 当数据集较大时,KNN 算法具有较好的分类性能。

缺点:

  1. 当数据集规模较大时,计算量大,时间复杂度高。
  2. K 的选择对分类结果有影响,不同的 K 值可能导致不同的分类结果。
  3. 对于非线性问题,KNN 算法可能表现不佳。

三、KNN 算法的实践应用

下面是一个简单的 Python 代码实现,用于演示 KNN 算法在分类问题中的应用:

  1. import numpy as np
  2. from sklearn.neighbors import KNeighborsClassifier
  3. from sklearn.datasets import make_classification
  4. from sklearn.model_selection import train_test_split
  5. from sklearn.metrics import classification_report, confusion_matrix
  6. # 生成模拟数据集
  7. X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
  8. # 划分训练集和测试集
  9. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  10. # 创建 KNN 分类器对象,指定邻居数量为 3
  11. knn = KNeighborsClassifier(n_neighbors=3)
  12. # 使用训练数据训练分类器对象
  13. knn.fit(X_train, y_train)
  14. # 使用测试数据进行预测
  15. y_pred = knn.predict(X_test)
  16. # 输出分类结果评估报告和混淆矩阵
  17. print(classification_report(y_test, y_pred))
  18. print(confusion_matrix(y_test, y_pred))

在上面的代码中,我们首先使用 make_classification 函数生成一个模拟数据集,然后将其划分为训练集和测试集。接下来,我们创建一个 KNeighborsClassifier 对象,并指定邻居数量为 3。然后使用训练数据训练分类器对象,并使用测试数据进行预测。最后,我们输出分类结果评估报告和混淆矩阵,以评估分类器的性能。在实际应用中,我们还可以通过调整参数、交叉验证等方法来优化分类器的性能。