机器学习2:KNN决策树探究泰坦尼克号幸存者问题
KNN决策树解决泰坦尼克:`import pandas as pd`;`from sklearn.tree import DecisionTreeClassifier, export_graphviz`;`from sklearn.metrics import classification_report`;`import graphviz`(决策树可视化);`data = pd.read_csv(r"O:\泰迪云课堂\01Py…`
·
KNN决策树探究泰坦尼克号幸存者问题

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report
import graphviz #决策树可视化
# Load the Titanic passenger table and drop the passenger id column in one
# chained expression; the last line displays the frame in the notebook.
data = pd.read_csv(r"titanic_data.csv").drop(columns="PassengerId")
data
| | Survived | Pclass | Sex | Age |
|---|---|---|---|---|
| 0 | 0 | 3 | male | 22.0 |
| 1 | 1 | 1 | female | 38.0 |
| 2 | 1 | 3 | female | 26.0 |
| 3 | 1 | 1 | female | 35.0 |
| 4 | 0 | 3 | male | 35.0 |
| ... | ... | ... | ... | ... |
| 886 | 0 | 2 | male | 27.0 |
| 887 | 1 | 1 | female | 19.0 |
| 888 | 0 | 3 | female | NaN |
| 889 | 1 | 1 | male | 26.0 |
| 890 | 0 | 3 | male | 32.0 |
891 rows × 4 columns
# Encode the "Sex" column numerically: male -> 1, female -> 0.
# Series.map builds a fresh column, so the result has a real integer dtype
# instead of an object column with ints written into string slots.
# Any label other than "male"/"female" maps to NaN, and astype(int) then
# raises, so unexpected values are reported instead of silently kept.
# Run this once on the freshly loaded table (re-running on already encoded
# values would map everything to NaN and raise).
data["Sex"] = data["Sex"].map({"male": 1, "female": 0}).astype(int)
data
| | Survived | Pclass | Sex | Age |
|---|---|---|---|---|
| 0 | 0 | 3 | 1 | 22.0 |
| 1 | 1 | 1 | 0 | 38.0 |
| 2 | 1 | 3 | 0 | 26.0 |
| 3 | 1 | 1 | 0 | 35.0 |
| 4 | 0 | 3 | 1 | 35.0 |
| ... | ... | ... | ... | ... |
| 886 | 0 | 2 | 1 | 27.0 |
| 887 | 1 | 1 | 0 | 19.0 |
| 888 | 0 | 3 | 0 | NaN |
| 889 | 1 | 1 | 1 | 26.0 |
| 890 | 0 | 3 | 1 | 32.0 |
891 rows × 4 columns
# Replace missing ages with the mean age.
# Only the "Age" column is filled: a frame-wide fillna would also write the
# mean age into missing cells of every other column.
mean_age = data["Age"].mean()
data["Age"] = data["Age"].fillna(mean_age)
data
| | Survived | Pclass | Sex | Age |
|---|---|---|---|---|
| 0 | 0 | 3 | 1 | 22.000000 |
| 1 | 1 | 1 | 0 | 38.000000 |
| 2 | 1 | 3 | 0 | 26.000000 |
| 3 | 1 | 1 | 0 | 35.000000 |
| 4 | 0 | 3 | 1 | 35.000000 |
| ... | ... | ... | ... | ... |
| 886 | 0 | 2 | 1 | 27.000000 |
| 887 | 1 | 1 | 0 | 19.000000 |
| 888 | 0 | 3 | 0 | 29.699118 |
| 889 | 1 | 1 | 1 | 26.000000 |
| 890 | 0 | 3 | 1 | 32.000000 |
891 rows × 4 columns
# Build a decision tree limited to depth 5; a fixed random_state keeps the
# fitted tree reproducible between runs.
Dtc = DecisionTreeClassifier(max_depth=5, random_state=8)

# Column 0 of the frame is the target ("Survived"); the remaining columns are
# used as features.
features = data.iloc[:, 1:]
target = data["Survived"]

Dtc.fit(features, target)  # train the model
# Predictions are made on the same rows the model was fitted on, so the
# scores below describe training-set fit, not generalisation.
pre = Dtc.predict(features)

# classification_report expects (y_true, y_pred). Passing the predictions
# first swaps the per-class precision and recall columns and reports the
# predicted-class counts as "support"; any report generated with the reversed
# order will differ from this one in exactly that way.
# This prints a per-class precision/recall/F1 report (not a confusion matrix).
print(classification_report(target, pre))
precision recall f1-score support
0 0.88 0.84 0.86 573
1 0.73 0.79 0.76 318
accuracy 0.82 891
macro avg 0.81 0.82 0.81 891
weighted avg 0.83 0.82 0.82 891
# Element-wise check of whether each predicted label equals the actual label
data["Survived"].eq(pre)
0 True
1 True
2 True
3 True
4 True
...
886 True
887 True
888 False
889 False
890 True
Name: Survived, Length: 891, dtype: bool
可视化
# Export the fitted tree as DOT text and render it with graphviz.
# class_names must be a sequence with one label per class, in ascending class
# order (here 0 then 1; assuming 0 means "did not survive"). A bare string
# such as "Survive" is indexed character by character, so the classes would
# be labelled "S" and "u" (or the call is rejected by stricter versions).
dot_data = export_graphviz(
    Dtc,
    feature_names=["Pclass", "Sex", "Age"],
    class_names=["Died", "Survived"],
)
graph = graphviz.Source(dot_data)
graph

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)