plot_tree의 시각화

tree
Author

강신성

Published

November 17, 2023

의사결정나무의 plot_tree를 시각화해보자

해당 포스트는 전북대학교 통계학과 최규빈 교수님의 강의내용을 토대로 재구성되었음을 알립니다.

1. 라이브러리 imports

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.tree
import graphviz

#-#
import warnings
warnings.filterwarnings('ignore')

2. 데이터 적합

먼저 데이터를 트리로 적합해놓은 뒤 해당 데이터를 통해서 시각화를 해보자.

df_train = pd.read_csv('https://raw.githubusercontent.com/guebin/MP2023/main/posts/insurance.csv')
df_train
age sex bmi children smoker region charges
0 19 female 27.900 0 yes southwest 16884.92400
1 18 male 33.770 1 no southeast 1725.55230
2 28 male 33.000 3 no southeast 4449.46200
3 33 male 22.705 0 no northwest 21984.47061
4 32 male 28.880 0 no northwest 3866.85520
... ... ... ... ... ... ... ...
1333 50 male 30.970 3 no northwest 10600.54830
1334 18 female 31.920 0 no northeast 2205.98080
1335 18 female 36.850 0 no southeast 1629.83350
1336 21 female 25.800 0 no southwest 2007.94500
1337 61 female 29.070 0 yes northwest 29141.36030

1338 rows × 7 columns

## step 1
X = pd.get_dummies(df_train.drop('charges', axis = 1))
y = df_train['charges']

## step 2
predictr = sklearn.tree.DecisionTreeRegressor(max_depth = 3)

## step 3
predictr.fit(X, y)

## step 4 -- pass
DecisionTreeRegressor(max_depth=3)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

3. matplotlib 기반 시각화

A. plot_tree 기본 시각화


sklearn.tree.plot_tree(predictr);  ## ;을 통해 계산과정 제거 가능

잘 안보여…

### B. max_depth 조정

sklearn.tree.plot_tree(
    predictr,
    max_depth = 0
);

일단 보이기는 하는데, 위에 하나만 보이겠지…

C. 변수이름 추가 | feature_names = list


sklearn.tree.plot_tree(
    predictr,
    max_depth = 0,
    feature_names = X.columns.to_list()  ## 교수님은 잘 되시던데 나는 왜 리스트로 넣어야만 할까...
);

x[5]같은 식으로 순서만 표시되던 게, 이름이 표기되었다.

### D. fig 오브젝트

- plt.gcf()를 통해 fig오브젝트로 추출

sklearn.tree.plot_tree(
    predictr,
    max_depth = 1,
    feature_names = X.columns.to_list()
);

fig = plt.gcf()

이제 이녀석은 matplotlib에서 다룰 수 있다.

fig.suptitle("Can we setting title?")
Text(0.5, 0.98, 'Can we setting title?')
fig

- dpi(해상도) 조정

fig.set_dpi(250)
fig

아마 해상도를 무쟈게 올리면 되기야 하겠지… 근데 그럼 사진이 엄청나게 커지겠지…

E. matplotlibax에 그리기


- tree로 적합한 값의 차이 정도와 plot_tree를 위아래로 표기하기

fig = plt.figure()
ax = fig.subplots(2,1)
ax[0].plot(y, y, '--')
ax[0].plot(y, predictr.predict(X), 'o', alpha = 0.1)
sklearn.tree.plot_tree(
    predictr,
    max_depth = 1,
    feature_names = X.columns.to_list(),
    ax = ax[1]  ## 해당 옵션으로 ax에 삽입이 가능하다.
);

4. GraphViz

딱봐도 보기 불편한데, 뭔가 개선을 해놨지 않았을까???

그래서 준비했습니다!!

g = sklearn.tree.export_graphviz(
    predictr,
    feature_names = X.columns.to_list()
)
graphviz.Source(g)

위아래로 타이트하게 나오면서 스크롤로 볼 수 있게 된 모습, 애초에 sklearn에서 해당 사항을 우려해서 이렇게 만들어놓았다.

- 파일로 추출해서 저장하려면?

graphviz.Source(g).render('tree', format = 'pdf')  ## 파일명, 포맷
'tree.pdf'

작업공간에 파일이 추가된 것을 볼 수 있다.