Привет, Хабр, в этой статье я расскажу про библиотеку ignite, с помощью которой можно легко обучать и тестировать нейронные сети, используя фреймворк PyTorch.
С помощью ignite можно писать циклы для обучения сети буквально в несколько строк, добавлять из коробки расчет стандартных метрик, сохранять модель и т.д. Ну, а для тех кто переехал с TF на PyTorch, можно сказать, что библиотека ignite — Keras для PyTorch.
В статье будет детально разобран пример обучения нейронной сети для задачи классификации, используя ignite
Добавим еще больше огня в PyTorch
Не буду тратить время, рассказывая о том, насколько крутой фреймворк PyTorch. Тот, кто им уже пользовался, понимает, о чём я пишу. Но, при всех его достоинствах, он все же является низкоуровневым в плане написания циклов для обучения, проверки, тестирования нейронных сетей.
Если мы посмотрим официальные примеры использования фреймворка PyTorch, то увидим в коде обучения сетки как минимум два цикла итераций по эпохам и по батчам обучающей выборки:
for epoch in range(1, epochs + 1): for batch_idx, (data, target) in enumerate(train_loader): # ...
Основная идея библиотеки ignite заключается в том, чтобы факторизовать эти циклы в единый класс, при этом позволив пользователю взаимодействовать с этими циклами с помощью обработчиков событий.
В итоге, в случае стандартных задач глубокого обучения мы можем неплохо сэкономить на количестве строк кода. Меньше строк — меньше ошибок!
К примеру, для сравнения, слева код для обучения и валидации модели, используя ignite, а справа — на чистом PyTorch:
больше не нужно писать для каждой задачи циклы for epoch in range(n_epochs) и for batch in data_loader.
позволяет лучше факторизовать код
позволяет вычислять базовые метрики из коробки
предоставляет “плюшки” типа
сохранение последней и лучших моделей (также оптимизатора и learning rate scheduler) во время обучения,
ранняя остановка обучения
итд
легко интегрируется с инструментами визуализации: tensorboardX, visdom, ...
В каком-то смысле, как уже было упомянуто, библиотеку ignite можно сравнить со всем известным Keras и его API для обучения и тестирования сетей. Также, библиотека ignite с первого взгляда очень похожа на библиотеку tnt, поскольку изначально обе библиотеки преследовали единые цели и имеют схожие идеи по их реализации.
Итак, зажигаем:
pip install pytorch-ignite
или
conda install ignite -c pytorch
Далее на конкретном примере мы ознакомимся с API библиотеки ignite.
Задача классификации с ignite
В этой части статьи рассмотрим школьный пример обучения нейронной сети для задачи классификации, используя библиотеку ignite.
Итак, возьмём простой датасет с картинками фруктов с kaggle. Задача заключается в том, чтобы каждой картинке с фруктом сопоставить соответствующий класс.
Прежде чем использовать ignite, давайте определим основные компоненты:
import torch.nn as nn from torchvision.models.squeezenet import squeezenet1_1 model = squeezenet1_1(pretrained=False, num_classes=81) model.classifier[-1] = nn.AdaptiveAvgPool2d(1) model = model.to(device)
import torch.nn as nn from torch.optim import SGD optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5) criterion = nn.CrossEntropyLoss()
from ignite.engine import Engine, _prepare_batch def process_function(engine, batch): model.train() optimizer.zero_grad() x, y = _prepare_batch(batch, device=device) y_pred = model(x) loss = criterion(y_pred, y) loss.backward() optimizer.step() return loss.item() trainer = Engine(process_function)
Давайте разберемся, что означает этот код.
Движок Engine
Класс ignite.engine.Engine — каркас библиотеки, а объект этого класса trainer:
trainer = Engine(process_function)
определен со входной функцией process_function для обработки одного батча и служит для реализации проходов по обучающей выборке. Внутри класса ignite.engine.Engine происходит следующее:
while epoch < max_epochs: # run once on data for batch in data: output = process_function(batch)
Мы видим, что внутри функции мы, как обычно в случае обучения модели, вычисляем предсказания y_pred, рассчитываем функцию потерь loss и градиенты. Последние позволяют обновить веса модели: optimizer.step().
В общем случае, нет никаких ограничений на код функции process_function. Отметим только, что она принимает на вход два аргумента: объект Engine (в нашем случае trainer) и батч от загрузчика данных. Поэтому, например, для тестирования нейронной сети мы можем определить другой объект класса ignite.engine.Engine, в котором входная функция просто вычисляет предсказания, и реализовать проход по проверочной выборке один единственный раз. Об этом читайте далее.
Итак, выше приведенный код лишь только определяет необходимые объекты без запуска обучения. В принципе, в минимальном примере, можно вызвать метод:
trainer.run(train_loader, max_epochs=10)
и данного кода достаточно, чтобы "тихо" (без какого-либо вывода промежуточных результатов) обучить модель.
Заметка
Отметим также, что для задач такого типа в библиотеке есть удобный метод создания объекта trainer:
from ignite.engine import create_supervised_trainer trainer = create_supervised_trainer(model, optimizer, criterion, device)
Конечно, на практике вышеприведенный пример представляет мало интереса, поэтому давайте добавим следующие опции для "тренера":
вывод на экран значения функции потерь через каждые 50 итераций
запуск расчета метрик на обучающей выборке при фиксированной модели
запуск расчета метрик на проверочной выборке после каждой эпохи
сохранение параметров модели после каждой эпохи
сохранение трёх лучших моделей
изменение скорости обучения в зависимости от эпохи (learning rate scheduling)
ранняя остановка обучения (early-stopping)
События и обработчики событий
Чтобы добавить вышеперечисленные опции для "тренера" в библиотеке ignite предусмотрена система событий и запуск пользовательских обработчиков событий. Таким образом, пользователь может управлять объектом класса Engine на каждом этапе:
движок запустился/завершил запуск
эпоха началась/завершилась
батч итерация началась/завершилась
и запускать свой код на каждом событии.
Вывод на экран значения функции потерь
Для этого нужно просто определить функцию, в которой будет происходит вывод на экран, и добавить ее к "тренеру":
На самом деле есть два способа добавить обработчик событий: через add_event_handler, либо через декоратор on. Тоже самое, что и выше, можно сделать так:
Запуск расчета метрик на обучающей и тестовой выборках
Давайте будем вычислять следующие метрики: средняя точность, средняя полнота после каждой эпохи на части обучающей и всей тестовой выборках. Заметим, что мы будем вычислять метрики на части обучающей выборки после каждой эпохи обучения, а не во время обучения. Таким образом замер эффективности будет более точным, поскольку модель не изменяется во время вычисления.
Далее мы создадим два движка для оценки модели, используя ignite.engine.create_supervised_evaluator:
from ignite.engine import create_supervised_evaluator # Напомним, что device = “cuda” был определен выше train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
Мы создаем два движка для того, чтобы на один из них (val_evaluator) далее прицепить дополнительные обработчики событий для сохранения модели и ранней остановки обучения (обо всем этом далее).
Давайте также более детально рассмотрим, как определен движок для оценки модели, а именно, как определена входная функция process_function для обработки одного батча:
def create_supervised_evaluator(model, metrics={}, device=None): if device: model.to(device) def _inference(engine, batch): model.eval() with torch.no_grad(): x, y = _prepare_batch(batch, device=device) y_pred = model(x) return y_pred, y engine = Engine(_inference) for name, metric in metrics.items(): metric.attach(engine, name) return engine
Продолжаем далее. Выберем случайным образом часть обучающей выборки, на которой будем вычислять метрики:
import numpy as np from torch.utils.data.dataset import Subset indices = np.arange(len(train_dataset)) random_indices = np.random.permutation(indices)[:len(val_dataset)] train_subset = Subset(train_dataset, indices=random_indices) train_eval_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory="cuda" in device)
Далее, давайте определим в какой момент обучения мы будем запускать вычисление метрик и будем производить вывод на экран:
и, вероятно, был вопрос о типе объекта, полученного из функции train_evaluator.run(train_eval_loader), у которого есть атрибут metrics.
На самом деле, у класса Engine содержится структура под названием state (тип State) для того, чтобы была возможность передавать данные между обработчиками событий. Этот атрибут state содержит базовую информацию о текущей эпохе, итерации, о количестве эпох и т.д. Также его можно использовать для передачи любых пользовательских данных, в том числе и результатов расчета метрик.
state = train_evaluator.run(train_eval_loader) metrics = state.metrics # или просто train_evaluator.run(train_eval_loader) metrics = train_evaluator.state.metrics
Расчет метрик во время обучения
Если в задаче огромная обучающая выборка и расчет метрик после каждой эпохи обучения стоит дорого, а при этом все же хотелось бы видеть изменение некоторых метрик во время обучения, то можно использовать из коробки следующий обработчик событий RunningAverage. Например, мы хотим рассчитывать и выводить на экран точность классификатора:
Чтобы использовать функционал RunningAverage, то нужно установить ignite из исходников:
pip install git+https://github.com/pytorch/ignite
Изменение скорости обучение (learning rate scheduling)
Есть несколько способов изменять скорость обучения с помощью ignite. Далее рассмотрим самый простой способ, вызывая функцию lr_scheduler.step() в начале каждой эпохи.
from torch.optim.lr_scheduler import ExponentialLR lr_scheduler = ExponentialLR(optimizer, gamma=0.8) @trainer.on(Events.EPOCH_STARTED) def update_lr_scheduler(engine): lr_scheduler.step() # Вывод значений скорости обучения: if len(optimizer.param_groups) == 1: lr = float(optimizer.param_groups[0]['lr']) print("Learning rate: {}".format(lr)) else: for i, param_group in enumerate(optimizer.param_groups): lr = float(param_group['lr']) print("Learning rate (group {}): {}".format(i, lr))
Сохранение лучших моделей и других параметров во время обучения
Во время обучения было бы здорово записывать на диск веса лучшей модели, а также периодически сохранять веса модели, параметры оптимизатора и параметры изменения скорости обучения. Последнее может быть полезно для того, чтобы возобновить обучение из последнего сохраненного состояния.
В ignite для этого есть специальный класс ModelCheckpoint. Итак, давайте создадим обработчик событий ModelCheckpoint и будем сохранять лучшую модель по значению точности на проверочной выборке. В таком случае, определим score_function функцию, которая выдает значение точности в обработчик событий и он решает нужно ли сохранять модель или нет:
from ignite.handlers import ModelCheckpoint def score_function(engine): val_avg_accuracy = engine.state.metrics['avg_accuracy'] return val_avg_accuracy best_model_saver = ModelCheckpoint("best_models", filename_prefix="model", score_name="val_accuracy", score_function=score_function, n_saved=3, save_as_state_dict=True, create_dir=True) # "best_models" - Папка куда сохранять 1 или несколько лучших моделей # Имя файла -> {filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth # save_as_state_dict=True, # Сохранять как `state_dict` val_evaluator.add_event_handler(Events.COMPLETED, best_model_saver, {"best_model": model})
Теперь создадим еще один обработчик событий ModelCheckpoint для того, чтобы сохранять состояние обучения через каждые 1000 итераций:
Итак, уже почти все готово, добавим последний элемент:
Ранняя остановка обучения (early-stopping)
Давайте добавим еще один обработчик событий, который остановит обучение, если не будет происходить улучшение качества модели в течение 10 эпох. Качество модели будем снова оценивать с помощью фунцкии score_function.
Теперь проверим модели и параметры, сохраненные на диск:
ls best_models/ model_best_model_10_val_accuracy=0.8730994.pth model_best_model_8_val_accuracy=0.8712978.pth model_best_model_9_val_accuracy=0.8818188.pth
и
ls checkpoint/ checkpoint_lr_scheduler_3000.pth checkpoint_optimizer_3000.pth checkpoint_model_3000.pth
Предсказания обученной моделью
Для начала создадим загрузчик тестовых данных (для примера возьмем валидационную выборку) так, чтобы батч данных состоял из изображений и их индексов:
С помощью ignite создадим новый движок для предсказания на тестовых данных. Для этого определим функцию inference_update, которая выдает результат предсказания и индекс изображения. Для повышения точности, мы также будем использовать всем известный трюк “test time augmentation” (TTA).
import torch.nn.functional as F from ignite._utils import convert_tensor def _prepare_batch(batch): x, index = batch x = convert_tensor(x, device=device) return x, index def inference_update(engine, batch): x, indices = _prepare_batch(batch) y_pred = model(x) y_pred = F.softmax(y_pred, dim=1) return {"y_pred": convert_tensor(y_pred, device='cpu'), "indices": indices} model.eval() inferencer = Engine(inference_update)
Далее создадим обработчики событий, которые будут оповещать об этапе предсказаний и сохранять предсказания в выделенный массив:
И теперь можем посчитать еще раз точность модели по полученным предсказаниям:
from sklearn.metrics import accuracy_score y_test_true = [y for _, y in val_dataset] accuracy_score(y_test_true, y_preds) > 0.9310369676443035
Итак, в этой части мы показали, как посчитать предсказания с помощью обученной модели на валидационной выборке. На самом деле, пример очень простой, но из него должно быть понятно, каким образом изпользовать ignite для других и более сложных ситуаций.
В github репозитории библиотеки можно найти и другие примеры обучения сетей для таких задач как
fast neural transfer
reinforcement learning
dcgan
Заключение
В заключении хочу сказать, что библиотека ignite не является официальным продуктом от Facebook и в её разработке принимают участие программисты на добровольной основе (напр. автор этой статьи). На текущий момент она находится в версии 0.1.0, но основной API (Engine, State, Events, Metric, ...) будет по мере возможного оставаться без изменений и в последующих версиях. Поскольку библиотека находится в стадии активной разработки, в том числе и дополнительных модулей, то разработчики будут рады отзывам, сообщениях об ошибках и pull request-ам в репозитории github.