Интерпретация моделей машинного обучения в python: shap

МЕНЮ


Искусственный интеллект
Поиск
Регистрация на сайте
Помощь проекту

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


Зачем нужна интерпретация?

Почему так важно уделять внимание не только метрикам качества, но и интерпретации полученных результатов? С одной стороны, мы хотим получать максимально точные предсказания, но с другой — было бы неплохо понимать, почему получен именно такой результат.

Нельзя сказать, хорошая модель или плохая, основываясь лишь на одной метрике. Один из примеров — эксперимент, проведенный учеными из Вашингтонского университета. Классификатор, различающей фотографии хаски и волков, достигал 90% доли правильных ответов, несмотря на то что животные очень похожи друг на друга. Как же алгоритму удалось получить такой хороший результат? Всё оказалось довольно просто: модель принимала решения на основе фона картинки, а не на основе характеристик животных. На заднем плане фотографий с волками в большинстве случаев присутствовал снег, тогда как у хаски — нет. В итоге, вместо одного классификатора, получился алгоритм, определяющий снег. Это было бы сложно заметить без интерпретации модели ("Why Should I Trust You?": Explaining the Predictions of Any Classifier).

Поскольку данная модель была обучена для эксперимента, ошибочные предсказания не привели бы к появлению серьезных проблем. Но что, если мы строим модель для её дальнейшего использования в реальной жизни? Цена ошибки может быть очень велика. Иными словами, необходимость интерпретации появляется тогда, когда присутствуют определенные риски, например, финансовые или социальные.

Что такое interpretable machine learning?

В одних случаях интерпретация — довольно простая задача, но чем сложнее алгоритм, тем труднее понять, на основе чего сделано предсказание. В этом заключается своеобразный поиск баланса между интерпретируемостью алгоритма и его сложностью. Например, построив регрессию, мы можем с легкостью понять, на какие признаки модель обращает внимание и какие из них вносят наибольший вклад. Но что делать, если мы обучили нейронную сеть или любой другой сложный алгоритм? Как убедиться, что всё правильно работает? Как объяснить, что новая модель действительно лучше предыдущей? Подобные модели со сложными алгоритмами внутри относятся к «черным ящикам» («black box»), т.к. трудно объяснить и понять, почему модель пришла к тому или иному выводу.

Interpretable machine learning (и explainable AI) — область, в которой занимаются разработкой методов, позволяющих оценить влияние различных факторов на предсказания модели. Понимая, как именно работает модель, мы можем не только объяснить полученный результат и понять почему получен тот или иной ответ, но и идентифицировать возможные направления для улучшения модели, чуть больше узнать о структуре и особенностях наших данных. Например, в случае с волками и хаски мы бы заметили, что модель реагирует на наличие снега на фото.

Интерпретировать сами результаты можно как на глобальном уровне, так и на индивидуальном, т.е. для каждого наблюдения.

  1. Глобальная интерпретация: в этом случае мы объясняем работу сразу для всей модели, смотрим, какие факторы вносят наибольший вклад в предсказания: для линейных моделей к интерпретируемым частям можно отнести веса, для деревьев - разбиения, предсказания в листьях;
  2. Локальная интерпретация: смотрим на конкретные кейсы; также можно понять, для каких групп тот или иной признак действительно был важен, а для каких имел меньшее значение.

shap

Мы рассмотрим один из алгоритмов для интерпретации результатов — shap. Данный модуль позволяет получить информацию о важности тех или иных признаков, а затем довольно красиво всё визуализировать. Как следствие, предоставить понятный результат не только для аналитиков, но и для бизнеса. Звучит отлично, но как же всё это работает?

Оценка важности признаков в shap считается используя значения Шепли. Для конкретного признака его можно рассчитать по следующей формуле:

где i — конкретный признак, n — общее число признаков, p — предсказание от модели.

Если очень кратко: считаем разницу между предсказанием модели без интересующего нас признака и предсказанием с ним. Поскольку порядок добавления признаков имеет значение при подсчете их важности и может влиять на результат, в данной формуле используется S?N/i (N — все признаки), которая подразумевает перебор всех возможных подмножеств признаков из S, исключая интересующий признак i.

Так, для того, чтобы получить важность признака:

  1. Получаем все возможные сабсеты признаков S, в которых не присутствует признак i.
  2. Считаем, как добавление признака i влияет на предсказания модели для сабсетов S.
  3. Усредняем.
  4. Получаем важность фичи!

Более детально описание работа алгоритма описана в статьях:

  • Consistent Individualized Feature Attribution for Tree Ensembles
  • Consistent feature attribution for tree ensembles
  • A Unified Approach to Interpreting Model Predictions

Реализация shap в python

Теперь посмотрим, как работает shap в python!

Сначала импортируем библиотеку, не забыв строчку shap.initjs(), чтобы в дальнейшем получить интерактивные графики.

Note: Если вы работаете в Google Colab, то её нужно добавлять в каждой ячейке с таким графиком.

import shap shap.initjs() 

Немного о данных: информация о 14 характеристиках домов из различных пригородов, расположенных в Бостоне. Переменные:

  • CRIM – уровень преступности
  • ZN – доля земель под жилую застройку
  • INDUS – доля акров, не связанных с розничной торговлей в расчете на город
  • CHAS – 1, если граничит с рекой
  • NOX – концентрация оксида азота
  • RM – среднее количество комнат на дом
  • AGE – доля жилых домов, построенных до 1940 года
  • DIS – взвешенные расстояния до пяти бостонских центров занятости
  • RAD – индекс доступности магистралей
  • TAX – полная ставка налога на имущество
  • PTRATIO – соотношение учеников и учителей по городам
  • B – доля лиц афроамериканского происхождения
  • LSTAT – процент населения с более низким статусом
  • y – медианная стоимость домов (в тыс. долларов)
# импортируем данные X, y = shap.datasets.boston()  

Используем градиентный бустинг (но можно и другие алгоритмы):

# создаем и обучаем model = xgb.train(params = {"learning_rate": 0.01}, dtrain = xgb.DMatrix(X, label=y), num_boost_round = 100)  # обучаем модель 

Наконец, переходим к интерпретации! Создаем тот самый объясняющий объект explainer, передав ему уже созданную и обученную ранее модель model:

explainer = shap.TreeExplainer(model) 

Считаем значения Шепли либо для всех доступных наблюдений (датасет – X), либо для их части, т.к. при большом числе строк вычисления могут занять довольно много времени.

shap_values = explainer.shap_values(X) 

force plot

А теперь визуализируем! В shap есть несколько возможных вариантов. Для того, чтобы посмотреть на конкретные наблюдения, т.е. получить локальные объяснения, нужно воспользоваться функцией shap.force_plot(), передав ей на вход базовое значение и значение Шепли для интересующего нас кейса. Например, для первого (0) наблюдения в датасете:

shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:]) 

Полученный график показывает, как разные признаки влияют на итоговое предсказание модели.

  • Base value (базовое значение) – среднее значение, полученное при обучении;
  • Жирным выделено полученное значение;

В случае классификации, какие-то переменные сдвигают его к классу 0, а какие-то к 1. Так, если значение Шепли положительное (выделено розовым цветом), то оно смещает предсказание в сторону положительного класса (1, вправо), если негативное (выделено голубым) - отрицательного (0, влево). Для регрессии – либо увеличивают, либо уменьшают.

Можно оценить всё немного более глобально и построить такой график для каждого наблюдения, передав сразу все shap_values. Т.е. мы буквально переворачиваем график, полученный выше, на 90 градусов, и повторяем операцию для каждого наблюдения, а затем соединяем всё в один. Это позволяет увидеть интерпретации сразу для каждого наблюдения в датасете. Интерактивность также дает возможность посмотреть на эффект конкретного признака. Здорово, не правда ли?

# для всех: shap.force_plot(explainer.expected_value, shap_values, features=X)  

Поскольку построение графика для больших датасетов занимает много времени, можно построить его для части наблюдений:

# для 1000 наблюдений shap.force_plot(explainer.expected_value, shap_values[:1000,:], X.iloc[:1000,:]) 

summary plot

Посмотреть на всё и сразу можно с помощью shap.summary_plot().

  • Каждая точка – отдельное наблюдение;
  • Цветом обозначены значения соответствующего признака: высокие – красным , низкие – синим;
  • Признаки расположены на оси y по мере уменьшения их важности;
  • По оси x находятся значения Шепли, которые влияют на отнесение к классу 1 (в случае классификации) либо положительно, либо отрицательно .

Доступно несколько видов графиков: scatter (по умолчанию), violin, bar, которые можно выбрать используя параметр plot_type="выбранный_тип_графика". Размер графика устанавливается автоматически, поэтому чтобы его изменить, необходимо указать plot_size=(20,9) (ширину, высоту).

Согласно полученному графику, наиболее важный признак – LSTAT, и чем выше его значение, тем ниже будет предсказанная цена дома.

# для совсем всех, еще изменяем размер графика shap.summary_plot(shap_values, X, plot_size=(20,9)) 

В виде барчарта:

shap.summary_plot(shap_values, X, plot_type="bar", plot_size=(10,6)) 

dependence plot

Еще один вариант графика — dependence plot (график зависимости), показывающий как выбранные признаки влияют на shap values. Опять же, если значения положительные - модель будет предсказывать положительный класс, и наоборот.

  • y – значение Шепли
  • x – значение признака для каждого наблюдения в датасете
  • Цвет — значение по другому признаку

Например, график ниже показывает то, как изменяется цена дома в зависимости от значений среднего количества комнат на дом (RM). Для того, чтобы увидеть взаимодействия между признаками, автоматически выбирается еще одна переменная из датасета, но можно выбрать и вручную, указав параметр interaction_index.

Благодаря этому мы можем заметить, что среднее число комнат в доме имеет чуть меньшее влияние на цену дома при высоких значениях процента населения с более низким статусом (LSTAT).

shap.dependence_plot("RM", shap_values, X, interaction_index='LSTAT') 

А построить графики сразу для всех фичей можно с помощью цикла:

for feature_name in X.columns:     shap.dependence_plot(feature_name, shap_values, X, display_features=X) 

Что почитать?

Напоследок немного материалов, где можно подробнее почитать и послушать о том, какие еще существуют алгоритмы, как они работают, с какими проблемами сталкиваются, и какие новые решения появляются:

  • Что такое интерпретируемая модель?
  • Материалы с воркшопа Human Interpretability in Machine Learning (WHI 2018) с конференции ICML 2018
  • Книга об интерпретации моделей
  • Interpretable ML Symposium, NIPS 2017

Репозитории с подборками ресурсов по интерпретации моделей:

  • Awesome interpretable machine learning
  • Awesome machine learning interpretability
  • Awesome production machine learning — репозиторий с материалами по разным темам, в т.ч. с ссылками на пакеты для интерпретации моделей

Пакеты для интерпретации и статьи:

  • LIME + статья
  • Статья о Grad-CAM
  • Interpreting Convolutional Neural Network (CNN) Results
  • Grad-CAM++
  • https://arxiv.org/pdf/1806.05337.pdf — тоже про нейронные сети

Источник: m.vk.com

Комментарии: