Boost then Convolve: Gradient Boosting Meets Graph Neural Networks

МЕНЮ


Главная страница
Поиск
Регистрация на сайте
Помощь проекту
Архив новостей

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


Рассматривается следующая задача: пусть есть какой-то граф, каждая вершина в этом графе описывается каким-то набором характеристик (признаков), целевая переменная также определена на уровне вершины. Используя структуру графа и признаки вершин, будем пытаться предсказать целевую переменную (один из рассматриваемых примеров - предсказание возраста пользователя ВК).

Graph Neural Networks (GNN) в целом умеют решать эту задачу, вопрос только в том, насколько эффективно они смогут использовать признаки пользователя, которые представляют собой табличные данные.

Отсюда возникает идея простого бейзлайна как подружить нейронки и бустинг. Можем обучать бустинг только на признаках вершины и передавать в GNN предсказание бустинга как фичу. Такой бейзлайн не очень похож на end-to-end и примерно понятно почему он не оптимален (бустинг ничего не будет знать о структуре графа). Авторы называют такой бейзлайн Res-GNN.

End-to-end training

Идея на самом деле достаточно проста. Раз мы передаем предсказание бустинга как фичу в GNN, наверное, мы хотели бы чтобы эта фича “уменьшала лосс GNN” (вернее помогала GNN точнее предсказывать целевую переменную). Другими словами, мы хотим чтобы бустинг выдавал такие предикты, использование которых в GNN уменьшит итоговый лосс. А для того, чтобы найти такие предикты бустинга, которые уменьшат лосс, нужно шагнуть в сторону антиградиента лосса GNN по предиктам бустинга.

И здесь-то становится понятно, что по сути бустинг именно так и обучается, только в обычном случае бустинг шагает в направлении антиградиента своего лосса по предиктам, а здесь надо шагать в направлении антиградиента лосса нейронки по предиктам бустинга.

То есть по сути все сводится к модификации функции потерь для бустинга.

Так выглядит этот алгоритм обучения (Boost GNN):

Здесь стоит обратить внимание что во время “train l steps of GNN” мы оптимизируем не только параметры модели (theta), но и X’ - output бустинга. Если взять l=1, то во второй с конца строке мы получим в точности антиградиент лосса GNN по X’, умноженный на learning rate GNN. Если брать l больше - получим сумму антиградиентов с l шагов градиентного спуска.

Датасеты и результаты

Авторы рассматривают задачи регрессии и классификации на уровне вершины графа на 5 датасетах.

На датасетах с homogenious признаками вершин (например усреднение эмбеддингов) использование бустинга не дает профита по сравнению с GNN. В датасетах с heterogenious фичами (табличные данные в привычном виде с категориальными фичами) BGNN и ResGNN в среднем работают лучше.

Также BGNN и Res-GNN сравниваются с бустингом в чистом виде и в большинтве случаев побеждают такой бейзлайн.

Помимо этого, авторы экспериментируют с разными архитектурами GNN и убеждаются что каждая из них в варианте BGNN и Res-GNN работает лучше, чем в чистом виде (то есть использование бустинга улучшает метрики).

Мое мнение

Мне понравился красивый хак, с помощью которого заставили бустинг и нейронку обучаться end-to-end. В целом метод выглядит очень перспективным, однако несколько спорных моментов, на мой взгляд, все же есть:

  1. Переобучение, out-of-fold и вот это все. Кто участвовал в соревнованиях по ml на табличных данных наверняка знает, что стэкинг так не работает. Если вы хотите фитить одну модель поверх предиктов другой, обязательно делать это out-of-fold. Здесь же, по крайней мере на первой итерации, обе модели предсказывают один и тот же таргет. Однако даже если эта проблема и есть, то ее фикс может только улучшить результаты предложенного метода
  2. Выбор бейзлайнов. То, что BGNN побеждает GNN и бустинг без фичей о структуре графа не говорит о том, что BGNN победит production-решения подобных задач, основанные на бустинге с кучей графовых фичей. Это ни в коем случае не в упрек статье (делать для каждого датасета мини-каггл в рамках написания статьи это все же overkill). Если у кого-то на работе есть такие задачи и вы попробуете этот метод - будет очень интересно послушать ваш фидбэк
  3. Масштабирование. На практике обычно приходится иметь дело с достаточно большими графами (миллионы и десятки миллионов вершин). Градиентный бустинг работает с такими датасетами без проблем, а с GNN все может быть чуть сложнее

Ну и конечно было бы круто расширить предложенный метод до предсказания на уровне двух вершин. Тогда к этому методу можно будет свести не только очевидно подходящие сюда рекомендации друзей, но и вообще более-менее любые рекомендации.


Источник: m.vk.com

Комментарии: