{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 06. Decision Tree\n",
    "\n",
    "## Learning Objectives\n",
    "- Understand how decision trees work\n",
    "- Information gain and Gini impurity\n",
    "- Preventing overfitting (pruning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree\n",
    "from sklearn.model_selection import train_test_split, cross_val_score\n",
    "from sklearn.metrics import accuracy_score, classification_report, mean_squared_error\n",
    "from sklearn.datasets import load_iris, make_classification\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams['font.family'] = 'DejaVu Sans'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Decision Tree Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iris dataset\n",
    "iris = load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "\n",
    "# Use only 2 features (easier to visualize)\n",
    "X_2d = X[:, [2, 3]]  # petal length, petal width\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X_2d, y, test_size=0.3, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a decision tree\n",
    "tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)\n",
    "tree_clf.fit(X_train, y_train)\n",
    "\n",
    "y_pred = tree_clf.predict(X_test)\n",
    "print(f\"Accuracy: {accuracy_score(y_test, y_pred):.4f}\")\n",
    "print(\"\\nClassification Report:\")\n",
    "print(classification_report(y_test, y_pred, target_names=iris.target_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Tree Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the decision tree\n",
    "plt.figure(figsize=(20, 10))\n",
    "plot_tree(tree_clf,\n",
    "          feature_names=['petal length', 'petal width'],\n",
    "          class_names=iris.target_names,\n",
    "          filled=True,\n",
    "          rounded=True,\n",
    "          fontsize=12)\n",
    "plt.title('Decision Tree - Iris Dataset')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
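  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the plot above to the learning objectives, the sketch below (not part of the original lesson) recomputes Gini impurity and entropy by hand. The `gini`/`entropy` helpers and the candidate threshold 2.45 are illustrative assumptions; compare the printed values with the impurities shown in the tree plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gini impurity and entropy computed by hand\n",
    "def gini(labels):\n",
    "    _, counts = np.unique(labels, return_counts=True)\n",
    "    p = counts / counts.sum()\n",
    "    return 1.0 - np.sum(p ** 2)\n",
    "\n",
    "def entropy(labels):\n",
    "    _, counts = np.unique(labels, return_counts=True)\n",
    "    p = counts / counts.sum()\n",
    "    return -np.sum(p * np.log2(p))\n",
    "\n",
    "print(f\"Root Gini:    {gini(y_train):.4f}\")\n",
    "print(f\"Root entropy: {entropy(y_train):.4f}\")\n",
    "\n",
    "# Impurity decrease for a candidate split: petal length <= 2.45\n",
    "mask = X_train[:, 0] <= 2.45\n",
    "n, n_left = len(y_train), mask.sum()\n",
    "weighted = (n_left / n) * gini(y_train[mask]) + ((n - n_left) / n) * gini(y_train[~mask])\n",
    "print(f\"Weighted Gini after split: {weighted:.4f}\")\n",
    "print(f\"Impurity decrease:         {gini(y_train) - weighted:.4f}\")"
   ]
  },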
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the decision boundary\n",
    "def plot_decision_boundary_tree(model, X, y, feature_names, class_names):\n",
    "    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
    "    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n",
    "    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),\n",
    "                         np.linspace(y_min, y_max, 200))\n",
    "\n",
    "    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "    Z = Z.reshape(xx.shape)\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    plt.contourf(xx, yy, Z, alpha=0.3, cmap='RdYlBu')\n",
    "\n",
    "    colors = ['blue', 'green', 'red']\n",
    "    for i, (color, name) in enumerate(zip(colors, class_names)):\n",
    "        idx = y == i\n",
    "        plt.scatter(X[idx, 0], X[idx, 1], c=color, label=name,\n",
    "                    edgecolors='black', alpha=0.7)\n",
    "\n",
    "    plt.xlabel(feature_names[0])\n",
    "    plt.ylabel(feature_names[1])\n",
    "    plt.title('Decision Tree Decision Boundary')\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "plot_decision_boundary_tree(tree_clf, X_2d, y,\n",
    "                            ['petal length', 'petal width'],\n",
    "                            iris.target_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Overfitting Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare performance by tree depth\n",
    "depths = range(1, 20)\n",
    "train_scores = []\n",
    "test_scores = []\n",
    "\n",
    "for depth in depths:\n",
    "    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)\n",
    "    tree.fit(X_train, y_train)\n",
    "    train_scores.append(tree.score(X_train, y_train))\n",
    "    test_scores.append(tree.score(X_test, y_test))\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(depths, train_scores, 'b-o', label='Train Score')\n",
    "plt.plot(depths, test_scores, 'r-o', label='Test Score')\n",
    "plt.xlabel('Max Depth')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.title('Decision Tree: Train vs Test Score by Depth')\n",
    "plt.legend()\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.show()"
   ]
  },
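  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Limiting `max_depth` is pre-pruning. scikit-learn also supports post-pruning via minimal cost-complexity pruning (`ccp_alpha`); the sketch below (an assumed addition, not part of the original lesson) evaluates each candidate alpha on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-pruning via minimal cost-complexity pruning\n",
    "path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)\n",
    "ccp_alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)  # drop the trivial single-node tree, clip round-off\n",
    "\n",
    "pruned_test_scores = []\n",
    "for alpha in ccp_alphas:\n",
    "    pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)\n",
    "    pruned.fit(X_train, y_train)\n",
    "    pruned_test_scores.append(pruned.score(X_test, y_test))\n",
    "\n",
    "best_alpha = ccp_alphas[int(np.argmax(pruned_test_scores))]\n",
    "print(f\"Best ccp_alpha: {best_alpha:.4f}, test accuracy: {max(pruned_test_scores):.4f}\")"
   ]
  },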
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Feature Importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model using all four features\n",
    "X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(\n",
    "    X, y, test_size=0.3, random_state=42\n",
    ")\n",
    "\n",
    "tree_full = DecisionTreeClassifier(max_depth=4, random_state=42)\n",
    "tree_full.fit(X_train_full, y_train_full)\n",
    "\n",
    "# Feature importances\n",
    "importance = pd.DataFrame({\n",
    "    'Feature': iris.feature_names,\n",
    "    'Importance': tree_full.feature_importances_\n",
    "}).sort_values('Importance', ascending=True)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.barh(importance['Feature'], importance['Importance'])\n",
    "plt.xlabel('Feature Importance')\n",
    "plt.title('Decision Tree Feature Importance')\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Hyperparameter Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "param_grid = {\n",
    "    'max_depth': [2, 3, 4, 5, 6, 7, 8],\n",
    "    'min_samples_split': [2, 5, 10, 20],\n",
    "    'min_samples_leaf': [1, 2, 4, 8]\n",
    "}\n",
    "\n",
    "grid_search = GridSearchCV(\n",
    "    DecisionTreeClassifier(random_state=42),\n",
    "    param_grid,\n",
    "    cv=5,\n",
    "    scoring='accuracy',\n",
    "    return_train_score=True\n",
    ")\n",
    "\n",
    "grid_search.fit(X_train_full, y_train_full)\n",
    "\n",
    "print(f\"Best Parameters: {grid_search.best_params_}\")\n",
    "print(f\"Best CV Score: {grid_search.best_score_:.4f}\")\n",
    "print(f\"Test Score: {grid_search.score(X_test_full, y_test_full):.4f}\")"
   ]
  },
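  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cross_val_score` was imported at the top but never used. As a quick sanity check (an assumed addition, not in the original lesson), we can cross-validate the tuned tree on the training split and compare with the grid search's reported CV score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-validate the best estimator found by the grid search\n",
    "cv_scores = cross_val_score(grid_search.best_estimator_, X_train_full, y_train_full, cv=5)\n",
    "print(f\"CV scores: {np.round(cv_scores, 4)}\")\n",
    "print(f\"Mean CV accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})\")"
   ]
  },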
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Decision Tree Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate regression data\n",
    "np.random.seed(42)\n",
    "X_reg = np.sort(5 * np.random.rand(200, 1), axis=0)\n",
    "y_reg = np.sin(X_reg).ravel() + np.random.randn(200) * 0.1\n",
    "\n",
    "# Fit the model\n",
    "tree_reg = DecisionTreeRegressor(max_depth=4, random_state=42)\n",
    "tree_reg.fit(X_reg, y_reg)\n",
    "\n",
    "# Predict\n",
    "X_test_reg = np.linspace(0, 5, 500).reshape(-1, 1)\n",
    "y_pred_reg = tree_reg.predict(X_test_reg)\n",
    "\n",
    "plt.figure(figsize=(12, 5))\n",
    "plt.scatter(X_reg, y_reg, alpha=0.5, label='Data')\n",
    "plt.plot(X_test_reg, y_pred_reg, 'r-', linewidth=2, label='Prediction')\n",
    "plt.xlabel('X')\n",
    "plt.ylabel('y')\n",
    "plt.title('Decision Tree Regression')\n",
    "plt.legend()\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.show()"
   ]
  },
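  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`mean_squared_error` was imported but not used above. This assumed extension (not part of the original lesson) compares training MSE at several depths, showing how deeper regression trees chase the noise:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training MSE at different depths (deeper trees fit the noise)\n",
    "for depth in [2, 4, 8, None]:\n",
    "    reg = DecisionTreeRegressor(max_depth=depth, random_state=42)\n",
    "    reg.fit(X_reg, y_reg)\n",
    "    mse = mean_squared_error(y_reg, reg.predict(X_reg))\n",
    "    print(f\"max_depth={str(depth):>4}: train MSE = {mse:.4f}\")"
   ]
  },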
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "### Key Concepts\n",
    "- **Split criteria**: information gain (entropy) or Gini impurity\n",
    "- **Pruning**: max_depth, min_samples_split, min_samples_leaf\n",
    "- **Pros**: easy to interpret, little preprocessing required\n",
    "- **Cons**: prone to overfitting, sensitive to small changes in the data\n",
    "\n",
    "### Next Steps\n",
    "- Ensemble learning (Random Forest, Gradient Boosting)\n",
    "- Hyperparameter optimization with cross-validation"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}