sklearn.datasets.load_digits()

load_digits()sklearn.datasets 提供的 手写数字数据集,用于 多分类任务,适用于 机器学习模型测试计算机视觉入门


1. load_digits() 数据集简介

属性 说明
样本数 1797
特征数 64(8×8 灰度像素)
类别数 10(数字 0-9
任务类型 多分类问题
数据类型 每个样本是 8×8 图片,像素值 0-16

2. load_digits() 代码示例

(1) 加载数据集

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()

# 获取特征矩阵和目标变量
X, y = digits.data, digits.target

print("特征矩阵形状:", X.shape)
print("目标变量形状:", y.shape)
print("类别名称:", digits.target_names)

输出

特征矩阵形状: (1797, 64)
目标变量形状: (1797,)
类别名称: [0 1 2 3 4 5 6 7 8 9]

解释

  • X.shape = (1797, 64)17978×8 图片,每张图片有 64 个像素点。
  • y.shape = (1797,)1797 个标签,对应 0-9 这 10 个数字。

(2) 数据集格式

print(type(digits))

输出

<class 'sklearn.utils._bunch.Bunch'>

解释

  • load_digits() 返回的是 Bunch 对象,类似于字典,可通过 .data.target.images 访问数据

(3) 可视化手写数字

import matplotlib.pyplot as plt

# 显示前 10 个数字
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap="gray")
    ax.set_title(f"Label: {digits.target[i]}")
    ax.axis("off")

plt.show()

解释

  • digits.images[i] 提供 8×8 图片数据
  • 使用 imshow() 可视化手写数字

3. load_digits() 数据分析

(1) 样本数据

import pandas as pd

# 转换为 DataFrame
df = pd.DataFrame(digits.data)
df["target"] = digits.target

print(df.head())

输出

   0   1   2   3   4   5   6   7  ...  56  57  58  59  60  61  62  63  target
0  0   0   5  13   9   1   0   0  ...   0   0   0   0   0   0   0   0       0
1  0   0   0   12  13   5   0   0  ...   0   0   0   0   0   0   0   0       1
2  0   0   0   4  15  12   0   0  ...   0   0   0   0   0   0   0   0       2
3  0   0   7  15   1   0   0   0  ...   0   0   0   0   0   0   0   0       3
4  0   0   0   1  11   0   0   0  ...   0   0   0   0   0   0   0   0       4

解释

  • 64 个像素值作为特征,每个样本是一个 8×8 灰度图片
  • target 是手写数字的真实标签(0-9)

(2) 类别分布

import seaborn as sns

sns.countplot(x=df["target"])
plt.title("Digits 数据集类别分布")
plt.show()

解释

  • 查看每个数字的数量,验证数据集是否均衡

4. 适用场景

  • 手写数字识别任务(计算机视觉)。
  • 分类问题(多分类任务)
  • 机器学习算法测试(如 KNNSVM随机森林

5. load_digits() vs. 其他数据集

数据集 任务类型 样本数 特征数 适用场景
load_iris() 多分类 150 4 经典分类问题
load_wine() 多分类 178 13 葡萄酒分类
load_digits() 多分类 1797 64 手写数字识别
fetch_openml("mnist_784") 多分类 70000 784 MNIST 手写数字识别

6. 结论

  • load_digits() 提供了 1797 张 8×8 手写数字图片,用于多分类任务,适用于 机器学习和计算机视觉入门
  • 可以 转换为 Pandas DataFrame 进行数据分析,也可以 使用可视化方法查看手写数字
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐