Перейти к содержанию

Самая простая ML модель — решающее дерево

Обложка

Представьте себе идеальную логику машины. Алгоритм, который просчитывает все возможные ходы и создает систему четких правил для решения любой задачи. При этом его теория состоит всего из двух формул.

\[ G_{L,R} = 1 - p^2 - q^2, \]
\[ G = \frac{m_L}{m} G_L + \frac{m_R}{m} G_R. \]

А одну из версий этого алгоритма хоть раз использовал каждый Data Scientist. Я говорю о решающих деревьях. Сегодня разберемся в их работе и поговорим о том, как использовать их на практике.

Найти все отличия

Старые банкоматы определяли номинал монеты по небольшим отличиям в диаметре, толщине, весе и магнитных свойствах. Правила для классификации придумывали эксперты: они знали состав сплава, стандарты и допуски.

Теперь представьте, что систему сортировки монет нужно придумать вам, но у вас есть только мешок монет, весы и штангенциркуль. Вы будете измерять монеты, фиксировать отличия и на их основе придумывать правила, а потом проверять эти правила на всех монетах в мешке.

Так работает решающее дерево: оно находит все возможные отличия и делит целое на группы, пока не получится порядок.

Разберем каждый шаг алгоритма на примере. Нужно предсказать результаты зачета по количеству посещенных лекций и семинаров. Результаты одной группы студентов мы знаем — используем их как обучающую выборку:

Первая группа студентов

Решающее дерево работает с признаками по очереди и сравнивает каждое значение признака со всеми остальными. Начнем с количества лекций и пройдем по всем строкам. Первый студент — Нурлан, он был на 7 лекциях. Сформируем первое правило: отнесем всех студентов, которые посетили больше 7 лекций вправо, а Нурлана и остальных оставим слева:

Деление по правилу 7 лекций

Справа все студенты сдали, но слева порядка не получилось. Продолжим идти по строкам: следующее значение — 8 лекций. Никто не был больше чем на 8 лекциях, поэтому в группу справа отнести будет некого. Пропускаем эту строку и попробуем следующее значение — 3 лекции:

Деление по правилу 3 лекций

Это деление снова приводит к беспорядку. Следующие значения повторяются, их можно пропустить. Так будет до тех пор, пока мы не дойдем до Виктории. Отнесем студентов, посетили больше 4 лекций вправо, остальных — влево:

Деление по правилу 4 лекций

Это правило приводит к полному порядку — достаточно было посетить больше 4 лекций, чтобы сдать зачет:

Решающее дерево для первой группы студентов

Мы получили первое решающее дерево. Правило деления на группы называют корнем, а группы слева и справа — листьями. Их заменяют общим результатом группы — зачет или незачет. После обучения дерева данные отдельных студентов уже неважны.

Если бы мы не нашли подходящее правило в первом признаке, пришлось бы перебирать значения второго. Сделаем это, с помощью небольшой программы, чтобы закрепить материал перед следующими шагами.

Создадим список с количеством семинаров и второй список с результатами зачета:

seminars = [8, 5, 3, 7, 3, 5, 3, 8, 7, 7, 3, 5]
results = [1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1]

Значения идут в том же порядке, что и в таблице. Например, второе значение относится к Тимуру — он был на 5 семинарах и сдал зачет.

Пройдем циклом по уникальным значениям количества семинаров. Разделим студентов на две группы — с меньшим и с большим количеством посещений. Если обе группы оказались не пустыми, выведем количество сдавших и не сдавших студентов в каждой из них:

seminars = [8, 5, 3, 7, 3, 5, 3, 8, 7, 7, 3, 5]
results = [1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1]

for threshold in set(seminars):
    group1 = [x for x, y in zip(results, seminars) if y <= threshold]
    group2 = [x for x, y in zip(results, seminars) if y > threshold]
    if not group1 or not group2:
        continue
    print(f"<={threshold}: {sum(group1)}/{len(group1)} зачет")
    print(f" >{threshold}: {sum(group2)}/{len(group2)} зачет")
<=3: 1/4 зачет
 >3: 6/8 зачет
<=5: 3/7 зачет
 >5: 4/5 зачет
<=7: 5/10 зачет
 >7: 2/2 зачет

Все группы оказались перемешанными, единственным удачным остается деление по количеству лекций.

Взгляд сверху

В нашем случае те же действия можно было выполнить графически. Отложим количество лекций по горизонтали, а количество семинаров — по вертикали. Обозначим зачет синим, а незачет — красным:

Данные на плоскости

На этой плоскости наши правила выглядят как вертикальные и горизонтальные линии. Вертикальные линии разделяют студентов по количеству лекций, а горизонтальные — по количеству семинаров:

Правила на плоскости

Точки на линии, слева или снизу от нее попадают в одну группу, а остальные — в другую. Линии можно провести только в точках — другие линии не дадут новых результатов.

По распределению точек можно понять, что разделить студентов горизонтальной линией не получится. А среди вертикальных линий можно увидеть наше решение: оно проходит через 4 лекции и идеально отделяет зачеты от незачетов:

Решение на плоскости

Измерить беспорядок

На курсе газодинамики студентам обычно показывают эксперимент Рейнольдса. Поток жидкости в трубе подкрашивают струйкой чернил. Если жидкость течет медленно и стабильно, чернила остаются в центре потока. Но если ускорить поток или ударить по трубе, ровная струйка чернил распадется на хаотичные вихри.

Хаос есть в любом процессе. Его вносят неучтенные факторы, погрешности измерений и случайные события. Хаотичность мешает увидеть в данных зависимости и добиться полного порядка.

Возьмем другую группу студентов:

Вторая группа студентов

Попробуйте построить для них дерево графически или с помощью программы. Теперь это уже не получится, хотя зависимости в данных все еще очевидны.

В случае, когда беспорядок нельзя убрать полностью, решающее дерево ищет путь к его максимальному снижению. Для этого беспорядок нужно измерить. Здесь нам понадобятся две формулы, которые мы видели в начале.

Первая формула — критерий Джини. Его используют как меру беспорядка:

\[ G = 1 - p^2 - q^2. \]

Буквами \(p\) и \(q\) обозначают доли студентов, которые сдали и не сдали зачет.

Посчитаем критерий Джини всего набора данных. В нем 12 студентов, из них 4 сдали зачет, 8 не сдали:

\[ G = 1 - \left(\frac{8}{12}\right)^2 - \left(\frac{4}{12}\right)^2 \approx 0.44. \]

Это значение близко к максимальному значению критерия Джини — 0.5. Оно будет у группы, где поровну сдавших и не сдавших студентов.

Теперь посчитаем беспорядок после разделения студентов на две группы. Начнем с правила 4 лекций — оно оказалось самым удачным в прошлый раз:

Деление по правилу 4 лекций

Посчитаем критерий Джини группы слева и группы справа:

\[ \begin{align} G_L =& 1 - \left(\frac{6}{6}\right)^2 - \left(\frac{0}{6}\right)^2 = 0,\\ G_R =& 1 - \left(\frac{2}{6}\right)^2 - \left(\frac{4}{6}\right)^2 \approx 0.44. \end{align} \]

Критерий Джини обеих групп равен среднему критериев слева и справа с учетом размера каждой группы:

\[ G = \frac{m_L}{m} G_L + \frac{m_R}{m} G_R = \frac{6}{12} \cdot 0 + \frac{6}{12} 0.44 = 0.22. \]

Здесь буквами \(m\) обозначено количество студентов в каждой группе. После разделения критерий Джини снизился в два раза — с 0.44 до 0.22.

Проверим все оставшиеся способы разделения с помощью программы. Создадим списки для количества лекций, семинаров и результатов зачета:

lectures = [8, 6, 8, 3, 3, 8, 4, 8, 4, 6, 4, 3]
seminars = [3, 3, 8, 3, 8, 6, 8, 4, 6, 8, 3, 4]
results = [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0]

Пройдем циклом по всем уникальным значениям количества лекций и семинаров. Как и в прошлый раз, выделим результаты группы студентов слева и справа и пропустим пустые группы. Для остальных посчитаем беспорядок с помощью функции get_impurity и выведем его значение на экран:

lectures = [8, 6, 8, 3, 3, 8, 4, 8, 4, 6, 4, 3]
seminars = [3, 3, 8, 3, 8, 6, 8, 4, 6, 8, 3, 4]
results = [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0]

for name, feature in {"лекции(й)": lectures, "семинара(ов)": seminars}.items():
    for threshold in set(feature):
        group1 = [y for x, y in zip(feature, results) if x <= threshold]
        group2 = [y for x, y in zip(feature, results) if x > threshold]
        if not group1 or not group2:
            continue
        impurity = get_impurity(group1, group2)
        print(f"{threshold} {name}: {impurity:.2f}")

Теперь напишем эту функцию: она должна ожидать результаты группы слева и группы справа и возвращать их средний беспорядок. Вычисляем его как взвешенное среднее критериев Джини каждой группы:

def get_impurity(group1: list[int], group2: list[int]) -> float:
    m1 = len(group1)
    m2 = len(group2)
    m = m1 + m2
    return m1 / m * get_gini_impurity(group1) + m2 / m * get_gini_impurity(group2)

lectures = [8, 6, 8, 3, 3, 8, 4, 8, 4, 6, 4, 3]
seminars = [3, 3, 8, 3, 8, 6, 8, 4, 6, 8, 3, 4]
results = [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0]

for name, feature in {"лекции(й)": lectures, "семинара(ов)": seminars}.items():
    for threshold in set(feature):
        group1 = [y for x, y in zip(feature, results) if x <= threshold]
        group2 = [y for x, y in zip(feature, results) if x > threshold]
        if not group1 or not group2:
            continue
        impurity = get_impurity(group1, group2)
        print(f"{threshold} {name}: {impurity:.2f}")

Критерий Джини равен единице минус квадрат доли сдавших студентов и не сдавших студентов. Запускаем поиск:

def get_gini_impurity(group: list[int]) -> float:
    m = len(group)
    p = sum(x == 0 for x in group) / m
    q = sum(x == 1 for x in group) / m
    return 1 - p ** 2 - q ** 2

def get_impurity(group1: list[int], group2: list[int]) -> float:
    m1 = len(group1)
    m2 = len(group2)
    m = m1 + m2
    return m1 / m * get_gini_impurity(group1) + m2 / m * get_gini_impurity(group2)

lectures = [8, 6, 8, 3, 3, 8, 4, 8, 4, 6, 4, 3]
seminars = [3, 3, 8, 3, 8, 6, 8, 4, 6, 8, 3, 4]
results = [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0]

for name, feature in {"лекции(й)": lectures, "семинара(ов)": seminars}.items():
    for threshold in set(feature):
        group1 = [y for x, y in zip(feature, results) if x <= threshold]
        group2 = [y for x, y in zip(feature, results) if x > threshold]
        if not group1 or not group2:
            continue
        impurity = get_impurity(group1, group2)
        print(f"{threshold} {name}: {impurity:.2f}")
3 лекции(й): 0.37
4 лекции(й): 0.22
6 лекции(й): 0.27
3 семинара(ов): 0.33
4 семинара(ов): 0.39
6 семинара(ов): 0.42

Правило 4 лекций оказалось самым сильным, и мы смогли доказать это математически.

Правый лист решающего дерева в этом случае не сможет привести к четкому ответу. Его предсказание будет равно доле сдавших студентов:

Решающее дерево для второй группы студентов

То есть, если студент посетил 4 лекции или меньше, его шансы сдать зачет равны 0%, иначе — 67%.

Вырастить дерево

Игра "Жизнь" Джона Конвея состоит всего из двух правил. Но эти правила создают устойчивые структуры, бесконечные вихри и плавающие скопления клеток. Сам Конвей не был уверен, существует ли в игре генератор бесконечной жизни, пока его не нашел математик Билл Госпер — спустя несколько месяцев после публикации правил игры. Это пример, когда формулировка алгоритма проще, чем его поведение.

Решающие деревья обладают тем же свойством. До сих пор мы рассматривали простейшую версию дерева с двумя листьями. Но мы можем использовать алгоритм отдельно для левого и правого листа, как если бы мы начинали с начала.

Вернемся к дереву, на котором мы остановились. Левый лист уже содержит только не сдавших студентов, поэтому делить его дальше нет смысла. А правый лист можно попробовать разделить:

Данные правого листа после деления по правилу 4 лекций

Внесем данные оставшихся студентов в программу:

def get_gini_impurity(group: list[int]) -> float:
    m = len(group)
    p = sum(x == 0 for x in group) / m
    q = sum(x == 1 for x in group) / m
    return 1 - p ** 2 - q ** 2

def get_impurity(group1: list[int], group2: list[int]) -> float:
    m1 = len(group1)
    m2 = len(group2)
    m = m1 + m2
    return m1 / m * get_gini_impurity(group1) + m2 / m * get_gini_impurity(group2)

lectures = [8, 6, 8, 8, 8, 6]
seminars = [3, 3, 8, 6, 4, 8]
results = [0, 0, 1, 1, 1, 1]

for name, feature in {"лекции(й)": lectures, "семинара(ов)": seminars}.items():
    for threshold in set(feature):
        group1 = [y for x, y in zip(feature, results) if x <= threshold]
        group2 = [y for x, y in zip(feature, results) if x > threshold]
        if not group1 or not group2:
            continue
        impurity = get_impurity(group1, group2)
        print(f"{threshold} {name}: {impurity:.2f}")
6 лекции(й): 0.42
3 семинара(ов): 0.00
4 семинара(ов): 0.22
6 семинара(ов): 0.33

Правило 3 семинаров для этой группы снижает критерий Джини до нуля. Используем его для продолжения дерева — теперь оно состоит из 2 вершин и 3 листьев:

Решающее дерево для второй группы студентов, итоговый вариант

Правила можно прочитать так: зачеты получили только студенты, которые посетили больше 4 лекций и 3 семинаров.

Построим то же дерево графически:

Правила решающего дерева для второй группы студентов на плоскости

Сначала мы разделили плоскость вертикальной чертой в точке 4 лекций. Левая часть оказалась упорядоченной, ее мы оставили как есть, а правую разделили горизонтальной чертой в точке 3 семинаров. Левую часть это правило не затрагивает, поэтому линия проходит только справа.

Мы разобрали каждый шаг вручную. Теперь посмотрим, как обучают деревья в реальных проектах. Для этого установим зависимости: Pandas — для работы с данными и Scikit-learn — для работы с деревьями:

uv add scikit-learn pandas

Данные нужно записать в файл. Создадим файл train_data.csv и перепечатаем в него данные, разделяя значения запятыми:

train_data.csv
Имя,Лекции,Семинары,Зачет
Влад Г.,8,3,0
Артем К.,6,3,0
Валерия М.,8,8,1
Павел Н.,3,3,0
Елизавета Р.,3,8,0
Кира С.,8,6,1
Амир Т.,4,8,0
Айгуль Х.,8,4,1
Лилия Х.,4,6,0
Илья Ч.,6,8,1
Мария Ш.,4,3,0
Алишер Э.,3,4,0

CSV — это простой формат для хранения таблиц. С небольшими CSV файлами можно работать прямо в редакторе кода. Многие редакторы подсвечивают столбцы разными цветами или отображают содержимое как таблицу.

Прочитаем этот файл с помощью Pandas. Вызовем функцию read_csv и передадим ей путь к файлу:

import pandas as pd

train_data = pd.read_csv("train_data.csv")
print(train_data)
             Имя  Лекции  Семинары  Зачет
0        Влад Г.       8         3      0
1       Артем К.       6         3      0
2     Валерия М.       8         8      1
3       Павел Н.       3         3      0
4   Елизавета Р.       3         8      0
5        Кира С.       8         6      1
6        Амир Т.       4         8      0
7      Айгуль Х.       8         4      1
8       Лилия Х.       4         6      0
9        Илья Ч.       6         8      1
10      Мария Ш.       4         3      0
11     Алишер Э.       3         4      0

Данные легко читаются, слева проставлены номера строк, а к колонкам можно обращаться как к ключам в словаре.

Создадим второй CSV файл с тестовыми данными. Назовем его test_data.csv и добавим несколько студентов с разным уровнем подготовки:

test_data.csv
Имя,Лекции,Семинары
Андрей В.,1,2
Дарья П.,1,7
Михаил С.,7,1
Елена Ш.,7,6

На них мы проверим дерево после обучения. Колонка с результатами здесь не нужна.

Прочитаем файл test_data.csv и перейдем к обучению. Импортируем решающее дерево из Scikit-learn: это класс DecisionTreeClassifier из пакета sklearn.tree. Для обучения дерева вызываем метод fit и передаем ему два аргумента: данные двух колонок с признаками и данные колонки с результатами зачета. Предсказания дерева на тестовой выборке записываем в колонку result. Для получения предсказаний вызываем метод predict:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

train_data = pd.read_csv("train_data.csv")
test_data = pd.read_csv("test_data.csv")

model = DecisionTreeClassifier()
model.fit(train_data[["Лекции", "Семинары"]], train_data["Зачет"])
test_data["Зачет"] = model.predict(test_data[["Лекции", "Семинары"]])
print(test_data)
         Имя  Лекции  Семинары  Зачет
0  Андрей В.       1         2      0
1   Дарья П.       1         7      0
2  Михаил С.       7         1      0
3   Елена Ш.       7         6      1

В тестовой группе зачет сдала только Елена: только у нее достаточное количество посещений.

Вовремя остановиться

Сколько историй вы знаете про сошедший с ума искусственный интеллект? HAL9000 из "Космической Одиссеи", SKYNET из "Терминатора", Дэвид из "Прометея", GLaDOS из "Portal" — во всех этих случаях ИИ дали свободу действий, но плохо сформулировали основную директиву. У решающего дерева директива сейчас звучит так: "Убрать весь беспорядок". В предыдущем примере это получилось сделать за два шага. Но если дерево не добьется результата так же быстро, у него нет причин останавливаться.

Посмотрим на последнюю группу студентов:

Третья группа студентов

Здесь шум практически уничтожил зависимости. Интересно, как с этой задачей справится дерево. Я повторил обучение с помощью Scikit-learn как в предыдущем примере и получил это дерево:

Решающее дерево для третьей группы студентов

Прочитать его целиком будет сложно, поэтому начнем с правила, которое ведет к одному из листьев. Для примера возьмем лист в последнем ряду справа:

(семинары <= 4) и (лекции > 3) и (лекции <= 6) и (семинары > 3) и (лекции > 4)

Сгруппируем правила по каждому признаку:

(4 < лекции <= 6) и (3 < семинары <= 4)

Если учесть, что в данных встречаются только числа 3, 4, 6 и 8, правило можно записать проще:

(лекции = 6) и (семинары = 4)

Этому правилу соответствует только один студент из выборки — Эльвира. Она сдала зачет, и лист основан только на ее результате. Соседний лист соответствует правилу:

(лекции = 4) и (семинары = 4)

Это данные Кирилла, он не сдал зачет. Остальные листья тоже опираются на данные одного или двух студентов.

Поэтому правила звучат так странно. От общих вопросов дерево переходит к слишком узким. Оно не выявляет общую закономерность, а сводит каждого студента к частному случаю и подгоняет данные под ответы. В машинном обучении это называют переобучением.

Из-за переобучения предсказания для новых студентов будут ненадежными. Они будут зависеть от того, на кого из обучающей выборки новый студент окажется похожим больше всего. Если он готовился примерно как Эльвира, модель предскажет ему результат Эльвиры.

Это можно увидеть и на плоскости:

Решающее дерево для третьей группы студентов на плоскости

Каждая точка образует вокруг себя небольшую область. Эти области не складываются в общую картину.

На схеме дерева видно, что ветвление началось с общих правил, которые сильнее всего снижали беспорядок. Но ни в одной из групп беспорядок не удалось довести до нуля, поэтому ветвление продолжилось. Позже признаки в правилах начали повторяться — дерево уточняло условия и стягивало диапазон значений к отдельным точкам.

Дерево действовало правильно, но не смогло вовремя остановиться — как раджа из "Золотой Антилопы". Исправить это можно с помощью критериев остановки. Их несколько.

Можно запретить дереву расти ниже определенного уровня. Дерево изображают перевернутым, поэтому максимальный уровень называют глубиной. Глубина корня равна 0, значит глубина нашего дерева равна 5. В этом примере логичнее было бы остановиться на глубине 1 или 2. Для этого конструктору дерева передают аргумент max_depth с нужным значением:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

train_data = pd.read_csv("train_data.csv")
model = DecisionTreeClassifier(max_depth=2)
model.fit(train_data[["Лекции", "Семинары"]], train_data["Зачет"])

Дерево остановилось на втором уровне и 4 листьях:

Решающее дерево с ограничением по глубине

Другой способ — задать минимальное количество студентов на одном листе. Это не позволит дереву изолировать частные случаи: каждый ответ должен опираться на минимальное число примеров. За это ограничение отвечает аргумент min_samples_leaf:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

train_data = pd.read_csv("train_data.csv")
model = DecisionTreeClassifier(min_samples_leaf=4)
model.fit(train_data[["Лекции", "Семинары"]], train_data["Зачет"])

При значении 4 дерево остановится уже на первом уровне и двух листьях с самыми общими правилами. Любой следующий шаг нарушил бы ограничение:

Решающее дерево с ограничением по количеству студентов на одном листе

На практике разглядывать каждый вариант дерева не получится — некоторые модели содержат сотни деревьев. Поэтому в своих проектах используйте кросс-валидацию и экспериментируйте с параметрами, чтобы добиться лучших результатов.

Ловушка жадности

В каждом примере мы сравнивали все варианты ветвления дерева. Поэтому могло показаться, что мы каждый раз приходим к лучшему варианту дерева из всех возможных. Но это не так — важно, насколько далеко мы оцениваем последствия каждого шага.

Представьте внимательного шахматиста, который перебирает все варианты своего следующего хода, но смотрит только на один ход вперед. Его легко обмануть: можно пожертвовать фигурой, чтобы через несколько ходов забрать две или поставить мат. Кроме того, он не сможет разыграть комбинацию. Его цель — нанести оппоненту максимальный урон одним ходом. Такие алгоритмы называют жадными.

Решающее дерево тоже относится к жадным алгоритмам. Посмотрите на эти точки:

Пример проблематичных данных

Подумайте, как отделить красные точки от синих только вертикальными и горизонтальными линиями. Лучший способ — разделить их крест-накрест:

Лучшее решение

Но дерево решит эту задачу по-другому. Сначала оно отделит одну из крайних точек слева или справа, потом — вторую, а центр оставит нетронутым. Центр можно разделить только в два хода, поэтому с точки зрения дерева ни один вариант ветвления не имеет смысла:

Решение дерева

Зачем использовать алгоритм, который не может посмотреть хотя бы на несколько шагов вперед? Причина в том, что жадные алгоритмы работают очень быстро и дают достаточно хорошее решение. Их используют там, где лучшее решение можно найти только полным перебором. Это как раз задачи поиска оптимального дерева, кратчайшего маршрута или лучшей комбинации предметов.

Число возможных конфигураций дерева приблизительно равно:

\[ 2^{2^n}, \]

где \(n\) — число признаков.

Предположим, что создание дерева занимает одну наносекунду. Тогда поиск оптимального дерева в наборе из 5 признаков займет чуть больше 4 секунд, а на набор из 6 признаков понадобится 600 лет:

\[ \begin{align} n=5:&\quad 2^{2^5} = 4{,}294{,}967{,}296\ \text{нс} \approx 4.3\ \text{с},\\ n=6:&\quad 2^{2^6} = 18{,}446{,}744{,}073{,}709{,}551{,}616\ \text{нс} \approx 584.9\ \text{лет}. \end{align} \]

Мы не узнаем идеальный ответ. И нам придется смириться с тем, что мы знаем хотя бы достаточно хороший.