"""Train a depth-limited decision tree on the Iris dataset and visualise it.

The data is split into train/test sets, a ``DecisionTreeClassifier`` is
fitted on the training portion, its test-set accuracy is printed, and the
fitted tree is drawn with ``sklearn.tree.plot_tree``.
"""
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree

iris = load_iris()
data = iris.data
target = iris.target

# Default split proportions; fixed seed so the split is reproducible.
train_x, test_x, train_y, test_y = train_test_split(data, target, random_state=1)

# sklearn randomly permutes features at each split, so ties between equally
# good splits may resolve differently between runs unless random_state is set.
model = tree.DecisionTreeClassifier(max_depth=5, random_state=1)
model.fit(train_x, train_y)

# Bare expressions are discarded when run as a script; keep and report results.
predictions = model.predict(test_x)
accuracy = model.score(test_x, test_y)
print(f"Test accuracy: {accuracy:.3f}")

plt.figure(figsize=(15, 10))
# Pass the already-fitted model; calling fit() again here would just retrain it.
plot_tree(
    model,
    filled=True,
    rounded=True,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
)
plt.show()