Градиентный бустинг с CatBoost. (часть 1/3) |
||
МЕНЮ Главная страница Поиск Регистрация на сайте Помощь проекту Архив новостей ТЕМЫ Новости ИИ Голосовой помощник Разработка ИИГородские сумасшедшие ИИ в медицине ИИ проекты Искусственные нейросети Искусственный интеллект Слежка за людьми Угроза ИИ ИИ теория Внедрение ИИКомпьютерные науки Машинное обуч. (Ошибки) Машинное обучение Машинный перевод Нейронные сети начинающим Психология ИИ Реализация ИИ Реализация нейросетей Создание беспилотных авто Трезво про ИИ Философия ИИ Big data Работа разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика
Генетические алгоритмы Капсульные нейросети Основы нейронных сетей Распознавание лиц Распознавание образов Распознавание речи Творчество ИИ Техническое зрение Чат-боты Авторизация |
2021-12-07 07:01 CatBoost – библиотека, которая была разработана Яндексом в 2017 году, представляет разновидность семейства алгоритмов Boosting и является усовершенствованной реализацией Gradient Boosting Decision Trees (GBDT). CatBoost имеет поддержку категориальных переменных и обеспечивает высокую точность. Стоит сказать, что CatBoost решает проблему смещения градиента (Gradient Bias) и смещения предсказания (Prediction Shift), это позволяет уменьшить вероятность переобучения и повысить точность алгоритма. Открываем Jupyter Notebook и начинаем работать с CatBoost. Импортируем нужные нам библиотеки: Загружаем набор данных: Посмотрим на нашу выборку: ACTION – это метка, дали сотруднику доступ или нет. Так же в нашем наборе данных мы имеем 9 признаков, все они числовые, но на самом деле все эти строки хешированные и сравнивать их не имеет смысла. Отделим таргет и фичи. В X кладем фичи, а в y отправляется таргет. Catboost необходимо сказать, какие признаки категориальные, для этого необходимо передать массив с индексами категориальных фичей: Посмотрим на соотношение классов в нашем датасете, для этого посчитаем количество нулей и единиц: Нулей в выборке 1897, а единиц 30872, это свидетельствует о дисбалансе классов, на это надо обращать внимание. Прежде чем обучить модель, необходимо подготовить данные, данный кусок кода позволяет это сделать: Посмотрим, как записался наш датасет: Чтобы Catboost правильно считал наши данные, ему надо понимать, что и в какой колонке лежит, напишем код, который генерирует column description file, где будет описано какая колонка чем является: Посмотрим на сгенерированный файл: Перед нами три колонки: первая колонка – это индексы колонок в файле с обучающей выборкой, вторая колонка – тип, третья колонка – фичи. Теперь создадим объекты выборки, родной формат для Catboost является Pool, это такой класс, в котором содержатся данные, в конструктор он принимает разные параметры, создаем такие Pool’ы: Следующим шагом мы разобьем нашу выборку на тестовую и тренировочную, сделаем это с помощью train_test_split из библиотеки Scikit Learn: Теперь переходим к обучению, здесь у нас будет два параметра – количество итераций и скорость обучения. Напомню, градиентный бустинг – это композиция решающих деревьев, каждое дерево строится последовательно, каждое последующее дерево компенсирует ошибки предсказания предыдущего дерева. По сути, число итераций – это количество деревьев. Если выставить скорость обучения высокой, то мы получим быстро переобучение, если слишком маленькой, то будем долго идти до некого оптимума. Далее переходим к функции fit, которая запускает обучение нашей модели: Здесь мы выставили параметр Verbose равный False, чтобы в stdout не выводилась никакая информация во время обучения. После обучения вызовем метод is_fitted(), он показывает обучилась ли модели, второй параметр get_params(), он покажет нам с какими параметрами происходило обучение модели. Выполним данный блок кода: Обучим модель вновь, но на этот раз будет показан прогресс нашего обучения, здесь параметр Verbose будет иметь значение равное 35, это значит, что каждые 35 итераций будет выводиться прогресс обучения нашей модели: В выводе мы видим время обучения, лучшую итерацию и лучшее значение на тестовой выборке. Catboost имеет параметр custom_loss, чтобы посмотреть обучение модели на других метриках. Передадим список метрик в этот параметр и эти самые метрики будут считаться на каждой итерации. Надо отметить, что по умолчанию в Catboost стоит метрика LogLoss и нельзя не обратить внимание на параметр Plot, который установлен с флагом True, это встроенный визуализатор, он показывает в Real-Time ход обучения модели. Запустим данную часть кода: Визуализатор показывает ход обучения по трем метрикам:
2. AUC 3. Accuracy На графике рисуется точка, которая дает понять, в какой момент наша модель переобучилась, по сути, все то, что идет после этой точки нам не нужно, оно только ухудшает наши предсказания, все деревья, после итерации, на которой мы получили переобучение, просто отбрасываются. Бывает так, что у нас есть модель с разными наборами параметров, в данном случае будет отличаться скорость обучения и мы хотим посмотреть, в какой момент мы получим переобучение модели, это так же можно визуализировать: Сразу же видно, что модель с learning_rate в 0.7 моментально получила переобучение, а модель с learning_rate в 0.01 получила переобучение на последних итерациях, исходя из этого можно сказать, что последняя модель более качественная и будет нам давать наилучшие результаты при ее дальнейшем использовании. Чтобы отбросить ненужные деревья, в ходе обучения модели, в Catboost присутствует параметр use_best_model, по умолчанию он включен и в итоге у нас останется дерево до того момента, когда качество начнет ухудшаться, давайте посмотрим, что это действительно так: Как видим, после 125 итерации модель ловит переобучение. Чтобы посмотреть количество деревьев, которые содержаться в нашей модели, мы выполним такую строчку: Да, у нас осталось дерево ровно до того момента, пока мы не получили переобучение нашей модели. На этом закончим первую часть статьи про градиентный бустинг с использованием CatBoost. В следующей части поговорим про Cross Validation, Overfitting Detector, ROC-AUC, SnapShot и Predict. До скорого! Источник: newtechaudit.ru Комментарии: |
|