W3docs

Дерево решений

Как работают деревья решений, как строить деревья классификации и регрессии в Python с scikit-learn, настраивать гиперпараметры и визуализировать их.

Дерево решений — это алгоритм обучения с учителем, который делает прогнозы, изучая иерархию правил if-then-else на обучающих данных. Каждый внутренний узел проверяет признак, каждая ветвь представляет результат этой проверки, а каждый листовой узел содержит прогноз (метку класса для классификации или числовое значение для регрессии).

В этой главе рассматриваются:

  • Как деревья решений разбивают данные с помощью мер неоднородности (Gini и энтропия)
  • Построение дерева классификации и дерева регрессии в Python с scikit-learn
  • Управление глубиной дерева и предотвращение переобучения с помощью гиперпараметров
  • Визуализация и анализ обученного дерева
  • Преимущества, ограничения и случаи применения деревьев решений

Как дерево решений разбивает данные

При обучении алгоритм перебирает каждый признак и каждый возможный порог, чтобы найти разбиение, максимально снижающее неоднородность — меру того, насколько смешаны классы в узле.

В scikit-learn распространены две меры неоднородности:

Неоднородность Gini

Неоднородность Gini измеряет вероятность неправильной классификации случайно выбранного образца, если бы ему назначалась метка согласно распределению классов в узле.

Gini(node) = 1 - Σ pᵢ²

Чистый узел (все образцы принадлежат одному классу) имеет Gini = 0. Максимально смешанный узел имеет Gini, приближающееся к 0,5 для бинарной классификации.

Энтропия и информационный выигрыш

Энтропия взята из теории информации. Она максимальна, когда классы распределены равномерно, и равна нулю, когда узел чистый.

Entropy(node) = -Σ pᵢ log₂(pᵢ)

Информационный выигрыш — это снижение энтропии после разбиения. Алгоритм выбирает разбиение, дающее наибольший информационный выигрыш. В scikit-learn выбор между двумя мерами осуществляется через параметр criterion (по умолчанию "gini").

Рекурсивное разбиение

Разбиение рекурсивно повторяется для каждого дочернего узла до тех пор, пока не будет выполнено условие остановки: узел чист, ни один признак не улучшает неоднородность, либо достигнут предел глубины или размера. Это создаёт структуру бинарного дерева.

Дерево классификации в Python

Датасет Iris содержит 150 образцов и 4 числовых признака. Задача — предсказать один из трёх видов цветка.

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# Load dataset
data = load_iris()
X, y = data.data, data.target

# Split: 80 % train, 20 % test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train — limit depth to 3 to keep the tree readable
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")
print(classification_report(y_test, y_pred, target_names=data.target_names))

Ожидаемый вывод:

Accuracy: 1.00
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

Датасет Iris линейно разделим при глубине 3, поэтому дерево достигает идеальной точности на тестовых данных. Реальные датасеты будут более сложными.

Прогнозирование новых образцов

После обучения вызовите predict() для классификации новых наблюдений и predict_proba() для получения вероятностей классов:

import numpy as np

# A new flower: sepal length 5.1, sepal width 3.5, petal length 1.4, petal width 0.2
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])

predicted_class = clf.predict(new_sample)
predicted_proba = clf.predict_proba(new_sample)

print("Predicted class:", data.target_names[predicted_class[0]])
print("Class probabilities:", predicted_proba)

Ожидаемый вывод:

Predicted class: setosa
Class probabilities: [[1. 0. 0.]]

Дерево регрессии в Python

Деревья решений также справляются с непрерывными целевыми переменными. Используйте DecisionTreeRegressor вместо DecisionTreeClassifier.

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np

# Synthetic regression dataset
X_reg, y_reg = make_regression(
    n_samples=300, n_features=5, noise=20, random_state=42
)

X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

reg = DecisionTreeRegressor(max_depth=5, random_state=42)
reg.fit(X_train_r, y_train_r)

y_pred_r = reg.predict(X_test_r)

mse = mean_squared_error(y_test_r, y_pred_r)
r2 = r2_score(y_test_r, y_pred_r)
print(f"MSE : {mse:.2f}")
print(f"R²  : {r2:.2f}")

Дерево регрессии разбивает данные, минимизируя среднеквадратичную ошибку (MSE) в каждом узле, и предсказывает среднее значение целевой переменной для всех обучающих образцов, попавших в листовой узел.

Настройка гиперпараметров

Без ограничений дерево решений будет расти, пока каждый лист не станет чистым, полностью запоминая обучающую выборку (переобучение). Гиперпараметры управляют сложностью дерева:

ПараметрПо умолчаниюЭффект
max_depthNoneМаксимальное количество уровней. Меньше = проще дерево.
min_samples_split2Минимальное количество образцов для разбиения узла. Больше = меньше разбиений.
min_samples_leaf1Минимальное количество образцов в листе. Больше = более гладкие границы.
max_featuresNoneКоличество признаков для рассмотрения при каждом разбиении (полезно для отбора признаков).
criterion"gini"Мера неоднородности: "gini" или "entropy" для классификаторов; "squared_error" для регрессоров.

Используйте перекрёстную проверку и поиск по сетке, чтобы найти лучшую комбинацию:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV

data = load_iris()
X, y = data.data, data.target

param_grid = {
    "max_depth": [2, 3, 4, 5, None],
    "min_samples_split": [2, 5, 10],
    "criterion": ["gini", "entropy"],
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
)
grid_search.fit(X, y)

print("Best params :", grid_search.best_params_)
print(f"Best CV score: {grid_search.best_score_:.3f}")

Ожидаемый вывод (значения могут незначительно варьироваться в зависимости от версии scikit-learn):

Best params : {'criterion': 'gini', 'max_depth': 3, 'min_samples_split': 2}
Best CV score: 0.973

Работа с категориальными признаками

Деревья решений scikit-learn требуют числового ввода. Перед обучением закодируйте категориальные столбцы:

  • Порядковые категории (например, размер: small < medium < large): используйте OrdinalEncoder.
  • Номинальные категории (например, цвет: red, green, blue): используйте OneHotEncoder, чтобы не вносить подразумеваемый порядок.
from sklearn.preprocessing import OrdinalEncoder
import numpy as np

# Encode only the categorical column; keep the numeric column as-is
sizes = np.array([["small"], ["large"], ["medium"], ["large"]])
weights = np.array([1.2, 3.4, 2.1, 4.0])

# Explicit category order: large=0, medium=1, small=2
enc = OrdinalEncoder(categories=[["large", "medium", "small"]])
sizes_encoded = enc.fit_transform(sizes)

X_encoded = np.column_stack([sizes_encoded, weights])
print(X_encoded)

Ожидаемый вывод:

[[2.  1.2]
 [0.  3.4]
 [1.  2.1]
 [0.  4. ]]

Подробное руководство см. в главе Категориальные данные.

Визуализация дерева решений

Анализ структуры дерева показывает, какие признаки обеспечивают наибольшее количество разбиений, и делает модель прозрачной для аудита.

Текстовое представление

from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.datasets import load_iris

data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(data.data, data.target)

print(export_text(clf, feature_names=list(data.feature_names)))

Ожидаемый вывод:

|--- petal length (cm) <= 2.45
|   |--- class: 0
|--- petal length (cm) >  2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

Графическое отображение

import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris

data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(data.data, data.target)

plt.figure(figsize=(10, 5))
plot_tree(
    clf,
    feature_names=data.feature_names,
    class_names=data.target_names,
    filled=True,
    rounded=True,
)
plt.title("Iris Decision Tree (max_depth=2)")
plt.tight_layout()
plt.savefig("iris_tree.png", dpi=150)
plt.show()

filled=True окрашивает каждый узел в цвет преобладающего класса; более тёмные оттенки означают более высокую чистоту класса.

Важность признаков

После обучения feature_importances_ присваивает каждому признаку оценку от 0 до 1, где более высокое значение означает, что признак в большей степени способствовал снижению неоднородности на всех разбиениях:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
import numpy as np

data = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(data.data, data.target)

importances = clf.feature_importances_
for name, imp in sorted(
    zip(data.feature_names, importances), key=lambda x: x[1], reverse=True
):
    print(f"{name:30s}: {imp:.4f}")

Ожидаемый вывод:

petal length (cm)             : 0.5856
petal width (cm)              : 0.4144
sepal length (cm)             : 0.0000
sepal width (cm)              : 0.0000

Признаки с важностью 0 не использовались ни в одном разбиении и могут быть исключены для упрощения модели.

Преимущества и ограничения

Когда использовать деревья решений

  • Вам нужна интерпретируемая модель — правила можно вывести в виде обычного текста.
  • Датасет содержит смесь числовых и категориальных признаков (после кодирования).
  • Вы хотите получить быстрый базовый результат перед применением ансамблевых методов.
  • Связь между признаками и целевой переменной нелинейна или включает взаимодействия.

Ограничения

ОграничениеСпособ устранения
Легко переобучается без настройкиОграничьте max_depth, min_samples_leaf; используйте перекрёстную проверку
Высокая дисперсия (небольшие изменения данных → другое дерево)Используйте ансамблевые методы: Random Forest / Bootstrap Aggregation
Предвзятость к признакам с большим количеством уникальных значенийИспользуйте max_features или нормализуйте критерии разбиения
Плохо экстраполирует за пределы диапазона обучающих данныхДля задач экстраполяции предпочтительнее линейные модели
Только осесимметричные разбиенияНаклонные деревья существуют, но не входят в scikit-learn

Деревья решений и смежные алгоритмы

АлгоритмКлючевое отличие
Логистическая регрессияЛинейная граница; лучше подходит для линейно разделимых данных; не обрабатывает взаимодействия автоматически
K-ближайших соседейОснован на примерах; нет явной модели; требует масштабирования признаков
Дерево решенийНелинейное; масштабирование не требуется; высокая интерпретируемость
Random Forest (см. Bootstrap Aggregation)Ансамбль из множества деревьев; значительно меньшая дисперсия; менее интерпретируем

Ключевые выводы

  • Деревья решений разбивают данные, максимизируя информационный выигрыш (или минимизируя неоднородность Gini) в каждом узле; процесс повторяется рекурсивно.
  • DecisionTreeClassifier и DecisionTreeRegressor в scikit-learn используют один и тот же API и названия гиперпараметров.
  • Всегда устанавливайте max_depth или min_samples_leaf для предотвращения переобучения; настраивайте их с помощью поиска по сетке и перекрёстной проверки.
  • feature_importances_ показывает, на какие признаки дерево опирается больше всего — это полезно для отбора признаков.
  • Одиночные деревья — хороший интерпретируемый базовый уровень, но ансамблевые методы, такие как Random Forest, почти всегда превосходят их на реальных данных.
Was this page helpful?