基于逻辑回归模型对鸢尾花数据集进行分类

理论知识

        不做过多赘述,相关知识有:指数分布族、GLM建模(分布函数+连接函数,对于本例来说是二项分布+sigmoid函数)、最大似然函数、交叉熵函数(评估逻辑回归模型的目标函数)。该分类问题关注的是通过已知的概率结果来推算出未知参数。对未知参数做自变量的对数似然函数求导,解得极值处的方程并代入连接函数,根据分布类型,即可推导出sigmoid函数,这便是该模型的来历。由这个函数我们便能把线性模型映射到0~1来解决概率问题。

直接上代码

import numpy as np
import seaborn as sns
from pandas import read_csv
from pandas.plotting import  scatter_matrix
from matplotlib import  pyplot
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn import metrics

#导入鸢尾花数据集
filename = 'iris.csv'
names = ['separ-length', 'separ-width','petal-length','petal-width','class']
dataset = read_csv(filename, names=names)

#查看数据
def display(data):
    # 显示数据维度
    print('数据维度:行 %s,列 %s\n'%data.shape)
    # 查看数据的前十行
    print(data.head(20))
    # 统计描述数据信息
    print(data.describe())
    # 分类分布情况
    print(data.groupby('class').size())

# 可视化统计数据
def picture(data):
    # 箱线图 箱线图用来展示属性和中位值的离散程度
    data.plot(kind='box', subplots=True,layout=(2,2), sharex=False,sharey=False)
    # 直方图 直方图显示每一个特征属性的分布情况
    data.hist()
    # 散点矩阵图 散点矩阵图用来展示每个属性之间的影响关系
    scatter_matrix(data)
    
    #绘图
    pyplot.show()

#逻辑回归模型训练与预测
def LR_train(data):
    array = data.values         #提取数值
    X = array[:, 0:4]           #提取样本特征
    Y = array[:,4]              #提取标签
    validation_size = 0.8        #八成数据训练   两成数据评估
    seed = 0                    #随机数种子为零

    # 划分训练集和测试集
    X_train, X_validation, Y_train,Y_validation = \
    train_test_split(X, Y, test_size = validation_size,random_state=seed)

    #构造逻辑回归模型并调用
    LR = LogisticRegression(max_iter=10000)
    LR.fit(X_train, Y_train)
    y_pred= LR.predict(X_validation)
    print("模型精度:{:.2f}".format(np.mean(y_pred==Y_validation)))
    print("模型精度:{:.2f}".format(LR.score(X_validation,Y_validation)))

    X_new = np.array([[5.8,3.1,5.0,1.7]])    #预测目标
    prediction = LR.predict(X_new)
    print("预测的目标类别是:{}".format(prediction))

    #查看混淆矩阵(预测值和真实值的各类情况统计矩阵)
    confusion_matrix_result=metrics.confusion_matrix(y_pred,Y_validation)
    print('The confusion matrix result:\n',confusion_matrix_result)

    #利用热力图对于结果进行可视化
    pyplot.figure(figsize=(8,6))
    sns.heatmap(confusion_matrix_result,annot=True,cmap='Blues')
    pyplot.xlabel('Predictedlabels')
    pyplot.ylabel('Truelabels')
    pyplot.show()
    
if __name__ == '__main__':
    display(dataset)
    picture(dataset)
    LR_train(dataset)

一些问题

可能是数据集只有150个的原因,当我的训练集与测试集数量二八开时,模型进行三分类精度能达到100%。当数量比变成八二开时,精确度依然可以达到90%
训练集占八成,测试集占两成
训练集占两成  测试集占八成

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐