Простое руководство по дистилляции BERT

МЕНЮ


Искусственный интеллект
Поиск
Регистрация на сайте
Помощь проекту

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


Если вы интересуетесь машинным обучением, то наверняка слышали про BERT и трансформеры.

BERT — это языковая модель от Google, показавшая state-of-the-art результаты с большим отрывом на целом ряде задач. BERT, и вообще трансформеры, стали совершенно новым шагом развития алгоритмов обработки естественного языка (NLP). Статью о них и «турнирную таблицу» по разным бенчмаркам можно найти на сайте Papers With Code.

С BERT есть одна проблема: её проблематично использовать в промышленных системах. BERT-base содержит 110М параметров, BERT-large — 340М. Из-за такого большого числа параметров эту модель сложно загружать на устройства с ограниченными ресурсами, например, мобильные телефоны. К тому же, большое время инференса делает эту модель непригодной там, где скорость ответа критична. Поэтому поиск путей ускорения BERT является очень горячей темой.

Нам в Авито часто приходится решать задачи текстовой классификации. Это типичная задача прикладного машинного обучения, которая хорошо изучена. Но всегда есть соблазн попробовать что-то новое. Эта статья родилась из попытки применить BERT в повседневных задачах машинного обучения. В ней я покажу, как можно значительно улучшить качество существующей модели с помощью BERT, не добавляя новых данных и не усложняя модель.

Knowledge distillation как метод ускорения нейронных сетей

Существует несколько способов ускорения/облегчения нейронных сетей. Самый подробный их обзор, который я встречал, опубликован в блоге Intento на Медиуме.

Способы можно грубо разделить на три группы:

  1. Изменение архитектуры сети.
  2. Сжатие модели (quantization, pruning).
  3. Knowledge distillation.

Если первые два способа сравнительно известны и понятны, то третий менее распространён. Впервые идею дистилляции предложил Рич Каруана в статье “Model Compression”. Её суть проста: можно обучить легковесную модель, которая будет имитировать поведение модели-учителя или даже ансамбля моделей. В нашем случае учителем будет BERT, учеником — любая легкая модель.

Задача

Давайте разберём дистилляцию на примере бинарной классификации. Возьмём открытый датасет SST-2 из стандартного набора задач, на которых тестируют модели для NLP.

Этот датасет представляет собой набор обзоров фильмов с IMDb с разбивкой на эмоциональный окрас — позитивный или негативный. В качестве метрики на этом датасете используют accuracy.

Обучение BERT-based модели или «учителя»

Прежде всего необходимо обучить «большую» BERT-based модель, которая станет учителем. Самый простой способ это сделать — взять эмбеддинги из BERT и обучить классификатор поверх них, добавив один слой в сеть.

Благодаря библиотеке tranformers сделать это довольно легко, потому что там есть готовый класс модели BertForSequenceClassification. На мой взгляд, самое подробное и понятное руководство по обучению этой модели опубликовали Towards Data Science.

Давайте представим, что мы получили обученную модель BertForSequenceClassification. В нашем случае num_labels=2, так как у нас бинарная классификация. Эту модель мы будем использовать в качестве «учителя».

Обучение «ученика»

В качестве ученика можно взять любую архитектуру: нейронную сеть, линейную модель, дерево решений. Давайте для большей наглядности попробуем обучить BiLSTM. Для начала обучим BiLSTM без BERT.

Чтобы подавать на вход нейронной сети текст, нужно представить его в виде вектора. Один из самых простых способов — это сопоставить каждому слову его индекс в словаре. Словарь будет состоять из топ-n самых популярных слов в нашем датасете плюс два служебных слова: “pad” — «слово-пустышка», чтобы все последовательности были одной длины, и “unk” — для слов за пределами словаря. Построим словарь с помощью стандартного набора инструментов из torchtext. Для простоты я не стал использовать предобученные эмбеддинги слов.  

import torch from torchtext import data  def get_vocab(X):     X_split = [t.split() for t in X]     text_field = data.Field()     text_field.build_vocab(X_split, max_size=10000)     return text_field  def pad(seq, max_len):     if len(seq) < max_len:         seq = seq + ['<pad>'] * (max_len - len(seq))     return seq[0:max_len]  def to_indexes(vocab, words):     return [vocab.stoi[w] for w in words]  def to_dataset(x, y, y_real):     torch_x = torch.tensor(x, dtype=torch.long)     torch_y = torch.tensor(y, dtype=torch.float)     torch_real_y = torch.tensor(y_real, dtype=torch.long)     return TensorDataset(torch_x, torch_y, torch_real_y)

Модель BiLSTM

Код для модели будет выглядеть так:

import torch from torch import nn from torch.autograd import Variable  class SimpleLSTM(nn.Module):      def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers,                  bidirectional, dropout, batch_size, device=None):         super(SimpleLSTM, self).__init__()         self.batch_size = batch_size         self.hidden_dim = hidden_dim         self.n_layers = n_layers         self.embedding = nn.Embedding(input_dim, embedding_dim)          self.rnn = nn.LSTM(embedding_dim,                            hidden_dim,                            num_layers=n_layers,                            bidirectional=bidirectional,                            dropout=dropout)          self.fc = nn.Linear(hidden_dim * 2, output_dim)         self.dropout = nn.Dropout(dropout)         self.device = self.init_device(device)         self.hidden = self.init_hidden()      @staticmethod     def init_device(device):         if device is None:             return torch.device('cuda')         return device      def init_hidden(self):         return (Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)),                 Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)))      def forward(self, text, text_lengths=None):         self.hidden = self.init_hidden()         x = self.embedding(text)         x, self.hidden = self.rnn(x, self.hidden)         hidden, cell = self.hidden         hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))         x = self.fc(hidden)          return x

Обучение

Для этой модели размерность выходного вектора будет (batch_size, output_dim). При обучении будем использовать обычный logloss. В PyTorch есть класс BCEWithLogitsLoss, который комбинирует сигмоиду и кросс-энтропию. То, что надо.

def loss(self, output, bert_prob, real_label):     criterion = torch.nn.BCEWithLogitsLoss()     return criterion(output, real_label.float())

Код для одной эпохи обучения:

def get_optimizer(model):     optimizer = torch.optim.Adam(model.parameters())     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.9)     return optimizer, scheduler  def epoch_train_func(model, dataset, loss_func, batch_size):     train_loss = 0     train_sampler = RandomSampler(dataset)     data_loader = DataLoader(dataset, sampler=train_sampler,                              batch_size=batch_size,                              drop_last=True)     model.train()     optimizer, scheduler = get_optimizer(model)     for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Train')):         text, bert_prob, real_label = to_device(text, bert_prob, real_label)         model.zero_grad()         output = model(text.t(), None).squeeze(1)         loss = loss_func(output, bert_prob, real_label)         loss.backward()         optimizer.step()         train_loss += loss.item()     scheduler.step()     return train_loss / len(data_loader)

Код для проверки после эпохи:

def epoch_evaluate_func(model, eval_dataset, loss_func, batch_size):     eval_sampler = SequentialSampler(eval_dataset)     data_loader = DataLoader(eval_dataset, sampler=eval_sampler,                              batch_size=batch_size,                              drop_last=True)      eval_loss = 0.0     model.eval()     for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Val')):         text, bert_prob, real_label = to_device(text, bert_prob, real_label)         output = model(text.t(), None).squeeze(1)         loss = loss_func(output, bert_prob, real_label)         eval_loss += loss.item()      return eval_loss / len(data_loader)

Если это всё собрать воедино, то получится такой код для обучения модели:

import os import torch from torch.utils.data import (TensorDataset, random_split,                               RandomSampler, DataLoader,                               SequentialSampler) from torchtext import data from tqdm import tqdm  def device():     return torch.device("cuda" if torch.cuda.is_available() else "cpu")  def to_device(text, bert_prob, real_label):     text = text.to(device())     bert_prob = bert_prob.to(device())     real_label = real_label.to(device())     return text, bert_prob, real_label  class LSTMBaseline(object):     vocab_name = 'text_vocab.pt'     weights_name = 'simple_lstm.pt'      def __init__(self, settings):         self.settings = settings         self.criterion = torch.nn.BCEWithLogitsLoss().to(device())      def loss(self, output, bert_prob, real_label):         return self.criterion(output, real_label.float())      def model(self, text_field):         model = SimpleLSTM(             input_dim=len(text_field.vocab),             embedding_dim=64,             hidden_dim=128,             output_dim=1,             n_layers=1,             bidirectional=True,             dropout=0.5,             batch_size=self.settings['train_batch_size'])         return model      def train(self, X, y, y_real, output_dir):         max_len = self.settings['max_seq_length']         text_field = get_vocab(X)          X_split = [t.split() for t in X]         X_pad = [pad(s, max_len) for s in tqdm(X_split, desc='pad')]         X_index = [to_indexes(text_field.vocab, s) for s in tqdm(X_pad, desc='to index')]          dataset = to_dataset(X_index, y, y_real)         val_len = int(len(dataset) * 0.1)         train_dataset, val_dataset = random_split(dataset, (len(dataset) - val_len, val_len))          model = self.model(text_field)         model.to(device())          self.full_train(model, train_dataset, val_dataset, output_dir)         torch.save(text_field, os.path.join(output_dir, self.vocab_name))      def full_train(self, model, train_dataset, val_dataset, output_dir):         train_settings = self.settings         num_train_epochs = train_settings['num_train_epochs']         best_eval_loss = 100000         for epoch in range(num_train_epochs):             train_loss = epoch_train_func(model, train_dataset, self.loss, self.settings['train_batch_size'])             eval_loss = epoch_evaluate_func(model, val_dataset, self.loss, self.settings['eval_batch_size'])              if eval_loss < best_eval_loss:                 best_eval_loss = eval_loss                 torch.save(model.state_dict(), os.path.join(output_dir, self.weights_name))

Дистилляция

Идея этого способа дистилляции взята из статьи исследователей из Университета Ватерлоо. Как я говорил выше, «ученик» должен научиться имитировать поведение «учителя». Что именно является поведением? В нашем случае это предсказания модели-учителя на обучающей выборке. Причём ключевая идея — использовать выход сети до применения функции активации. Предполагается, что так модель сможет лучше выучить внутреннее представление, чем в случае с финальными вероятностями.

В оригинальной статье предлагается в функцию потерь добавить слагаемое, которое будет отвечать за ошибку «подражания» — MSE между логитами моделей.

Для этих целей сделаем два небольших изменения: изменим количество выходов сети с 1 до 2 и поправим функцию потерь.

def loss(self, output, bert_prob, real_label):     a = 0.5     criterion_mse = torch.nn.MSELoss()     criterion_ce = torch.nn.CrossEntropyLoss()     return a*criterion_ce(output, real_label) + (1-a)*criterion_mse(output, bert_prob)

Можно переиспользовать весь код, который мы написали, переопределив только модель и loss:

 class LSTMDistilled(LSTMBaseline):     vocab_name = 'distil_text_vocab.pt'     weights_name = 'distil_lstm.pt'      def __init__(self, settings):         super(LSTMDistilled, self).__init__(settings)         self.criterion_mse = torch.nn.MSELoss()         self.criterion_ce = torch.nn.CrossEntropyLoss()         self.a = 0.5      def loss(self, output, bert_prob, real_label):         return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, bert_prob)      def model(self, text_field):         model = SimpleLSTM(             input_dim=len(text_field.vocab),             embedding_dim=64,             hidden_dim=128,             output_dim=2,             n_layers=1,             bidirectional=True,             dropout=0.5,             batch_size=self.settings['train_batch_size'])         return model

Вот и всё, теперь наша модель учится «подражать».

Сравнение моделей

В оригинальной статье наилучшие результаты классификации на SST-2 получаются при a=0, когда модель учится только подражать, не учитывая реальные лейблы. Accuracy всё ещё меньше, чем у BERT, но значительно лучше обычной BiLSTM.

Я старался повторить результаты из статьи, но в моих экспериментах лучший результат получался при a=0,5.

Так выглядят графики loss и accuracy при обучении LSTM обычным способом. Судя по поведению loss, модель быстро обучилась, а где-то после шестой эпохи пошло переобучение.

Графики при дистилляции:

Дистиллированная BiLSTM стабильно лучше обычной. Важно, что по архитектуре они абсолютно идентичны, разница только в способе обучения. Полный код обучения я выложил на ГитХаб.

Заключение

В этом руководстве я постарался объяснить базовую идею подхода дистилляции. Конкретная архитектура ученика будет зависеть от решаемой задачи. Но в целом этот подход применим в любой практической задаче. За счёт усложнения на этапе обучения модели, можно получить значительный прирост её качества, сохранив изначальную простоту архитектуры.


Источник: habr.com

Комментарии: