欢迎访问我的博客首页


1. scikit-learn


  scikit-learn 也称为 sklearn,是一个开源的机器学习库。

2. 鸢尾花数据集


1. 鸢尾花数据集简介

  鸢尾花数据集是一个很简单的数据集,样本就是一些数字。鸢尾花数据集包含 3 类鸢尾花亚属的 4 个特征。亚属包括山鸢尾、变色鸢尾和维吉尼亚鸢尾。特征包括花萼长度、花萼宽度、花瓣长度、花瓣宽度,单位是厘米。
  鸢尾花数据集中包含每个亚属的 50 个样本,共 150 个样本。每个样本包括 4 个特征向量和 1 个类别标签。

2. 鸢尾花数据集的读取

from sklearn.datasets import load_iris

if __name__ == '__main__':
    iris = load_iris()
    data, target = iris.data, iris.target
    print(data.shape, target.shape)     # (150, 4) (150,)
    print(data[0], target[0])           # [5.1 3.5 1.4 0.2] 0
    print(data[50], target[50])         # [7.  3.2 4.7 1.4] 1
    print(data[100], target[100])       # [6.3 3.3 6.  2.5] 2

  代码:data 是 150 × 4 150\times4 150×4 的特征,target 是鸢尾花亚属标签。每个特征的四个维度分别表示花萼长度、花萼宽度、花瓣长度、花瓣宽度。前 50 个特征属于类别为 0 的鸢尾花亚属,中间 50 个特征属于类别为 1 的鸢尾花亚属,后 50 个特征属于类别为 2 的鸢尾花亚属。

3. 决策树


  使用 sklearn 中的决策树分类鸢尾花。

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def train():
    # 1.数据。
    iris = datasets.load_iris()
    data, target = iris.data, iris.target
    train_data, test_data, train_target, test_target = \
        train_test_split(data, target, test_size=0.33, random_state=42)
    # 2.训练。
    model = DecisionTreeClassifier()
    model.fit(train_data, train_target)
    # 3.评估。
    print(model.score(test_data, test_target))

4. 支持向量机


  使用 scikit-learn 中的支持向量机分类鸢尾花。

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn import svm
import joblib


def train():
    # 1.数据。
    iris = load_iris()
    data, target = iris.data, iris.target
    train_data, test_data, train_target, test_target = \
        train_test_split(data, target, train_size=0.6, test_size=0.4, random_state=1, shuffle=True, stratify=target)
    # 2.训练。
    model = svm.SVC(C=2, degree=3, kernel='rbf', gamma=10, decision_function_shape='ovo', probability=True)
    model.fit(train_data, train_target.ravel())
    # 3.评估。
    print(model.score(train_data, train_target))
    print(model.score(test_data, test_target))
    # 4.保存。
    joblib.dump(model, 'svm_model.m')


def predict():
    model = joblib.load("svm_model.m")
    data = np.array([[4.8, 3.4, 1.6, 0.2], [5.7, 2.8, 4.5, 1.3], [6.3, 2.7, 4.9, 1.8]])
    print(model.predict(data))
    print(model.predict_proba(data))
    print(model.decision_function(data))


if __name__ == '__main__':
    train()
    predict()

5. 参考


  1. 鸢尾花数据集与逻辑回归分类
  2. svm/决策树/随机森林/knn分类鸢尾花数据集
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐