Drawing a decision tree with Python sklearn and saving it as a PDF

2022-07-14 18:00:43

Drawing a decision tree with sklearn and saving it as a PDF

Download Graphviz

Download and install it from the official site:

https://graphviz.gitlab.io/_pages/Download/Download_windows.html

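Then add the following directory (the bin folder of the Graphviz installation) to the PATH environment variable:

  • D:\software\Graphviz\bin

Test it in cmd; the command below should print the installed Graphviz version:

  • dot -V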
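
The same check can be done from Python before any rendering code runs. This is a minimal sketch using only the standard library; the helper name find_dot is made up for illustration.

```python
import shutil


def find_dot():
    """Return the full path of the Graphviz `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


dot_path = find_dot()
if dot_path is None:
    print("dot not found: check that the Graphviz bin directory is on PATH")
else:
    print("dot found at", dot_path)
```

If this reports that dot is missing, the graphviz Python package will fail at render time, because it only wraps the dot executable.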

Python code

import numpy as np
import pandas as pd
from sklearn import tree
import graphviz
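# x and y are the data to fit; exam_train and classes_train are the author's own feature and label data, defined elsewhere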
x = np.array(exam_train)
y = np.array(classes_train)
clf = tree.DecisionTreeClassifier(criterion='entropy', class_weight='balanced', max_depth=25)
clf = clf.fit(x, y)
dot_data = tree.export_graphviz(clf, out_file=None, feature_names=None, filled=True, rounded=True)  # the key parameters can be customized
graph = graphviz.Source(dot_data)
graph.render(view=True, format="pdf", filename="decisiontree_pdf")

This generates a good-looking decision tree PDF, saved as decisiontree_pdf.pdf and opened in the default viewer.
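
The listing above depends on the author's own exam_train and classes_train data. As a self-contained sketch of the same export step, the version below swaps in scikit-learn's bundled iris dataset (an assumption made only so the example runs anywhere) and stops at the DOT string, so it does not need the Graphviz binaries.

```python
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data: the iris dataset replaces the article's exam_train / classes_train
iris = load_iris()

clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
clf.fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
)

# Show the beginning of the DOT source
print(dot_data[:80])
```

Passing dot_data to graphviz.Source(dot_data).render(...) as in the article then produces the PDF, provided Graphviz is installed and on PATH.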

Using decision trees with Python sklearn

Data format (tree.csv)

age look income orderly target
older ugly low yes no
young ugly high no no
young handsome low no no
young handsome high yes yes
young handsome medium yes yes
young handsome medium no no
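
A decision tree in sklearn needs numbers, so categorical values like these have to be converted into 0/1 columns first. The sketch below shows the idea with the standard library only. It assumes the rows are whitespace-separated as displayed here and treats every column except target as a feature; it illustrates the encoding and is not sklearn's own implementation.

```python
import csv
import io

# The sample rows from the article, assumed to be separated by single spaces
SAMPLE = """age look income orderly target
older ugly low yes no
young ugly high no no
young handsome low no no
young handsome high yes yes
young handsome medium yes yes
young handsome medium no no
"""

reader = csv.reader(io.StringIO(SAMPLE), delimiter=" ")
header = next(reader)
rows = list(reader)

# One dict per row for the features, plus the label from the last column
feature_dicts = [dict(zip(header[:-1], row[:-1])) for row in rows]
labels = [row[-1] for row in rows]

# One 0/1 column per "feature=value" pair, in sorted order
columns = sorted({f"{k}={v}" for d in feature_dicts for k, v in d.items()})
matrix = [
    [1 if col in {f"{k}={v}" for k, v in d.items()} else 0 for col in columns]
    for d in feature_dicts
]

print(columns)
for encoded_row in matrix:
    print(encoded_row)
```

Each row ends up with exactly one 1 per original feature, which is the same kind of matrix that DictVectorizer builds from a list of dicts.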

Python source code:

# -*- coding: utf-8 -*-
# DictVectorizer converts a list of dicts into the numeric matrix that sklearn expects
from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import preprocessing
from sklearn import tree

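allElectronicsData = open('c:/pic/data/tree.csv', 'r', newline='')  # text mode, as the Python 3 csv module requires
reader = csv.reader(allElectronicsData)
header = next(reader)  # Python 3: next(reader) replaces reader.next()
# print(header)
## Data preprocessing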
featureList = []
labelList = []
for row in reader:
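    # print(row[-1])
    labelList.append(row[-1])
    # Store each row's feature values as a dict so that sklearn's DictVectorizer can turn the categorical values directly into 0/1 columns
    rowDict = {}
    # Note: the range starts at 1, so the first column is skipped; use range(len(row) - 1) to include it
    for i in range(1, len(row) - 1):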
        rowDict[header[i]] = row[i]
    featureList.append(rowDict)

for each in featureList:
    print(each)

# Vectorize features
vec = DictVectorizer()
dummyX = vec.fit_transform(featureList).toarray()
print("dummyX:"+str(dummyX))
print(vec.get_feature_names_out())  # get_feature_names() was removed in scikit-learn 1.2

# Convert the labels with preprocessing's LabelBinarizer
lb = preprocessing.LabelBinarizer()
dummyY = lb.fit_transform(labelList)
print("dummyY:"+str(dummyY))
print("labelList:"+str(labelList))

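# criterion is the measure used to choose the split at each node. 'entropy' uses information gain, as in ID3; the default 'gini' uses the Gini index, as in CART.
clf = tree.DecisionTreeClassifier(criterion='entropy')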
clf = clf.fit(dummyX,dummyY)
print("clf:"+str(clf))
# Visualize the decision tree
# export_graphviz writes a dot file; Graphviz must be installed to convert it to PDF or PNG
# Add the bin directory of the Graphviz installation to the Path environment variable,
# then run the following command in cmd:
# dot -Tpdf c:/tree.dot -o c:/tree.pdf
# Download: http://www.graphviz.org/Download_windows.php
# Generate the dot file
with open("c:/tree.dot", 'w') as f:
    tree.export_graphviz(clf, feature_names=list(vec.get_feature_names_out()), out_file=f)
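
To make the entropy criterion mentioned in the comments more concrete, here is a small hand calculation on the sample table, using only the standard library. The function names entropy and information_gain are chosen for illustration, and the sketch follows the textbook definitions rather than sklearn's internal code.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy, in bits, of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(feature_values, labels):
    """Entropy reduction obtained by splitting `labels` on the values of one feature."""
    total = len(labels)
    remainder = 0.0
    for value in set(feature_values):
        subset = [lab for fv, lab in zip(feature_values, labels) if fv == value]
        remainder += len(subset) / total * entropy(subset)
    return entropy(labels) - remainder


# Columns copied from the sample table (tree.csv)
target = ["no", "no", "no", "yes", "yes", "no"]
orderly = ["yes", "no", "no", "yes", "yes", "no"]
look = ["ugly", "ugly", "handsome", "handsome", "handsome", "handsome"]

print("entropy(target):", round(entropy(target), 4))
print("gain(orderly):", round(information_gain(orderly, target), 4))
print("gain(look):", round(information_gain(look, target), 4))
```

On this table, splitting on orderly gives a larger gain than splitting on look, so an entropy-based tree would prefer it of the two.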

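The dot command quoted in the comments can also be run from Python instead of cmd. This is a sketch under the assumption that the file paths are the ones used in the script; build_dot_command is a hypothetical helper, and the conversion is skipped when dot or the input file is missing.

```python
import os
import shutil
import subprocess


def build_dot_command(dot_file, pdf_file):
    """Build the argument list for: dot -Tpdf <dot_file> -o <pdf_file>."""
    return ["dot", "-Tpdf", dot_file, "-o", pdf_file]


cmd = build_dot_command("c:/tree.dot", "c:/tree.pdf")

# Only run the conversion when Graphviz is installed and the dot file exists
if shutil.which("dot") and os.path.exists("c:/tree.dot"):
    subprocess.run(cmd, check=True)
else:
    print("Skipping conversion; the command would be:", " ".join(cmd))
```

Passing the arguments as a list avoids shell quoting problems when a path contains spaces.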