!pip install graphviz
!pip install dtreeviz

import numpy as np
import pandas as pd 
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the Iris dataset and split it into training and test sets.
iris = datasets.load_iris()
# Separate the feature matrix from the class labels.
data, target = iris.data, iris.target
# Hold out 30% of the samples for testing; random_state pins the shuffle so
# the split is reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(
    data, target, test_size=0.3, random_state=0
)

print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)  # check the dtypes
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)  # check the sample counts

float64 float64 int64 int64 (105, 4) (45, 4) (105,) (45,)


from sklearn.tree import DecisionTreeClassifier

# Build a decision-tree classifier with default hyperparameters, train it on
# the training split, then predict and score on both splits.
tree = DecisionTreeClassifier()  # classification model
tree.fit(x_train, y_train)  # training; must run before predict()/score()
y_pred = tree.predict(x_test)  # predicted labels for the test data

# The first label means "training score", the second "validation score".
print('学習時スコア:', tree.score(x_train, y_train), '検証スコア', tree.score(x_test, y_test))

{'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_impurity_split': None, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'presort': 'deprecated', 'random_state': None, 'splitter': 'best'} [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2 1 1 2 0 2 0 0] 学習時スコア: 1.0 検証スコア 0.9777777777777777


import matplotlib.pyplot as plt

# Horizontal bar chart of the fitted tree's per-feature importance.
y = tree.feature_importances_  # importance score of each feature
# Feature names:
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
x = iris.feature_names

# barh() takes the bar labels as `y` and the bar lengths as `width`.
plt.barh(y=x, width=y)

[0.02150464 0.02150464 0.90006666 0.05692405]
<BarContainer object of 4 artists>
<Figure size 432x288 with 1 Axes>

4. Graphvizによる木構造可視化

import graphviz
from sklearn.tree import export_graphviz

# Export the decision tree as DOT text and render it with Graphviz.
dot = export_graphviz(decision_tree=tree)  # DOT representation of the tree
graph = graphviz.Source(source=dot)  # render the DOT text
# print(dot)  # uncomment to print the raw DOT text
graph  # notebook cell result: shows the rendered graph
import graphviz
from sklearn.tree import export_graphviz

# Export the tree again with styling and human-readable feature/class names,
# then render it with Graphviz.
dot = export_graphviz(
    tree,
    filled=True,
    rounded=True,
    class_names=['setosa', 'versicolor', 'virginica'],
    feature_names=[
        'sepal length (cm)',
        'sepal width (cm)',
        'petal length (cm)',
        'petal width (cm)',
    ],
)

graph = graphviz.Source(dot)  # render the DOT text
graph  # notebook cell result: shows the rendered graph


from dtreeviz.trees import dtreeviz

# Visualize the decision tree with dtreeviz.
# NOTE(review): the data and label arguments were missing from the source;
# they are restored here as the training split (x_train, y_train) -- confirm
# this matches the data the visualization is meant to show.
viz = dtreeviz(
    tree,  # decision-tree model
    x_train,  # feature data
    y_train,  # labels for the data
    target_name='variety',  # name of the target variable
    feature_names=iris.feature_names,  # feature names
    # Class names: ['setosa', 'versicolor', 'virginica']
    class_names=[str(i) for i in iris.target_names],
)

# viz.view()  # uncomment to display in the browser
from dtreeviz.trees import dtreeviz

# Visualize the decision tree with dtreeviz, additionally passing one sample
# through the X argument.
# NOTE(review): the data and label arguments were missing from the source;
# they are restored here as the training split (x_train, y_train) -- confirm
# this matches the data the visualization is meant to show.
viz = dtreeviz(
    tree,  # decision-tree model
    x_train,  # feature data
    y_train,  # labels for the data
    target_name='variety',  # name of the target variable
    feature_names=iris.feature_names,  # feature names
    # Class names: ['setosa', 'versicolor', 'virginica']
    class_names=[str(i) for i in iris.target_names],
    X=[1, 2, 3, 4],  # arbitrary sample, one value per feature
)

# viz.view()  # uncomment to display in the browser

