Если вы интересуетесь машинным обучением, то наверняка слышали про BERT и трансформеры.
BERT — это языковая модель от Google, показавшая state-of-the-art результаты с большим отрывом на целом ряде задач. BERT, и вообще трансформеры, стали совершенно новым шагом развития алгоритмов обработки естественного языка (NLP). Статью о них и «турнирную таблицу» по разным бенчмаркам можно найти на сайте Papers With Code.
С BERT есть одна проблема: её проблематично использовать в промышленных системах. BERT-base содержит 110М параметров, BERT-large — 340М. Из-за такого большого числа параметров эту модель сложно загружать на устройства с ограниченными ресурсами, например, мобильные телефоны. К тому же, большое время инференса делает эту модель непригодной там, где скорость ответа критична. Поэтому поиск путей ускорения BERT является очень горячей темой.
Нам в Авито часто приходится решать задачи текстовой классификации. Это типичная задача прикладного машинного обучения, которая хорошо изучена. Но всегда есть соблазн попробовать что-то новое. Эта статья родилась из попытки применить BERT в повседневных задачах машинного обучения. В ней я покажу, как можно значительно улучшить качество существующей модели с помощью BERT, не добавляя новых данных и не усложняя модель.
Knowledge distillation как метод ускорения нейронных сетей
Существует несколько способов ускорения/облегчения нейронных сетей. Самый подробный их обзор, который я встречал, опубликован в блоге Intento на Медиуме.
Способы можно грубо разделить на три группы:
- Изменение архитектуры сети.
- Сжатие модели (quantization, pruning).
- Knowledge distillation.
Если первые два способа сравнительно известны и понятны, то третий менее распространён. Впервые идею дистилляции предложил Рич Каруана в статье “Model Compression”. Её суть проста: можно обучить легковесную модель, которая будет имитировать поведение модели-учителя или даже ансамбля моделей. В нашем случае учителем будет BERT, учеником — любая легкая модель.
Задача
Давайте разберём дистилляцию на примере бинарной классификации. Возьмём открытый датасет SST-2 из стандартного набора задач, на которых тестируют модели для NLP.
Этот датасет представляет собой набор обзоров фильмов с IMDb с разбивкой на эмоциональный окрас — позитивный или негативный. В качестве метрики на этом датасете используют accuracy.
Обучение BERT-based модели или «учителя»
Прежде всего необходимо обучить «большую» BERT-based модель, которая станет учителем. Самый простой способ это сделать — взять эмбеддинги из BERT и обучить классификатор поверх них, добавив один слой в сеть.
Благодаря библиотеке tranformers сделать это довольно легко, потому что там есть готовый класс модели BertForSequenceClassification. На мой взгляд, самое подробное и понятное руководство по обучению этой модели опубликовали Towards Data Science.
Давайте представим, что мы получили обученную модель BertForSequenceClassification. В нашем случае num_labels=2, так как у нас бинарная классификация. Эту модель мы будем использовать в качестве «учителя».
Обучение «ученика»
В качестве ученика можно взять любую архитектуру: нейронную сеть, линейную модель, дерево решений. Давайте для большей наглядности попробуем обучить BiLSTM. Для начала обучим BiLSTM без BERT.
Чтобы подавать на вход нейронной сети текст, нужно представить его в виде вектора. Один из самых простых способов — это сопоставить каждому слову его индекс в словаре. Словарь будет состоять из топ-n самых популярных слов в нашем датасете плюс два служебных слова: “pad” — «слово-пустышка», чтобы все последовательности были одной длины, и “unk” — для слов за пределами словаря. Построим словарь с помощью стандартного набора инструментов из torchtext. Для простоты я не стал использовать предобученные эмбеддинги слов.
import torch from torchtext import data def get_vocab(X): X_split = [t.split() for t in X] text_field = data.Field() text_field.build_vocab(X_split, max_size=10000) return text_field def pad(seq, max_len): if len(seq) < max_len: seq = seq + ['<pad>'] * (max_len - len(seq)) return seq[0:max_len] def to_indexes(vocab, words): return [vocab.stoi[w] for w in words] def to_dataset(x, y, y_real): torch_x = torch.tensor(x, dtype=torch.long) torch_y = torch.tensor(y, dtype=torch.float) torch_real_y = torch.tensor(y_real, dtype=torch.long) return TensorDataset(torch_x, torch_y, torch_real_y)
Модель BiLSTM
Код для модели будет выглядеть так:
import torch from torch import nn from torch.autograd import Variable class SimpleLSTM(nn.Module): def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, batch_size, device=None): super(SimpleLSTM, self).__init__() self.batch_size = batch_size self.hidden_dim = hidden_dim self.n_layers = n_layers self.embedding = nn.Embedding(input_dim, embedding_dim) self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout) self.fc = nn.Linear(hidden_dim * 2, output_dim) self.dropout = nn.Dropout(dropout) self.device = self.init_device(device) self.hidden = self.init_hidden() @staticmethod def init_device(device): if device is None: return torch.device('cuda') return device def init_hidden(self): return (Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)), Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device))) def forward(self, text, text_lengths=None): self.hidden = self.init_hidden() x = self.embedding(text) x, self.hidden = self.rnn(x, self.hidden) hidden, cell = self.hidden hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)) x = self.fc(hidden) return x
Обучение
Для этой модели размерность выходного вектора будет (batch_size, output_dim). При обучении будем использовать обычный logloss. В PyTorch есть класс BCEWithLogitsLoss, который комбинирует сигмоиду и кросс-энтропию. То, что надо.
def loss(self, output, bert_prob, real_label): criterion = torch.nn.BCEWithLogitsLoss() return criterion(output, real_label.float())
Код для одной эпохи обучения:
def get_optimizer(model): optimizer = torch.optim.Adam(model.parameters()) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.9) return optimizer, scheduler def epoch_train_func(model, dataset, loss_func, batch_size): train_loss = 0 train_sampler = RandomSampler(dataset) data_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size, drop_last=True) model.train() optimizer, scheduler = get_optimizer(model) for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Train')): text, bert_prob, real_label = to_device(text, bert_prob, real_label) model.zero_grad() output = model(text.t(), None).squeeze(1) loss = loss_func(output, bert_prob, real_label) loss.backward() optimizer.step() train_loss += loss.item() scheduler.step() return train_loss / len(data_loader)
Код для проверки после эпохи:
def epoch_evaluate_func(model, eval_dataset, loss_func, batch_size): eval_sampler = SequentialSampler(eval_dataset) data_loader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size, drop_last=True) eval_loss = 0.0 model.eval() for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Val')): text, bert_prob, real_label = to_device(text, bert_prob, real_label) output = model(text.t(), None).squeeze(1) loss = loss_func(output, bert_prob, real_label) eval_loss += loss.item() return eval_loss / len(data_loader)
Если это всё собрать воедино, то получится такой код для обучения модели:
import os import torch from torch.utils.data import (TensorDataset, random_split, RandomSampler, DataLoader, SequentialSampler) from torchtext import data from tqdm import tqdm def device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def to_device(text, bert_prob, real_label): text = text.to(device()) bert_prob = bert_prob.to(device()) real_label = real_label.to(device()) return text, bert_prob, real_label class LSTMBaseline(object): vocab_name = 'text_vocab.pt' weights_name = 'simple_lstm.pt' def __init__(self, settings): self.settings = settings self.criterion = torch.nn.BCEWithLogitsLoss().to(device()) def loss(self, output, bert_prob, real_label): return self.criterion(output, real_label.float()) def model(self, text_field): model = SimpleLSTM( input_dim=len(text_field.vocab), embedding_dim=64, hidden_dim=128, output_dim=1, n_layers=1, bidirectional=True, dropout=0.5, batch_size=self.settings['train_batch_size']) return model def train(self, X, y, y_real, output_dir): max_len = self.settings['max_seq_length'] text_field = get_vocab(X) X_split = [t.split() for t in X] X_pad = [pad(s, max_len) for s in tqdm(X_split, desc='pad')] X_index = [to_indexes(text_field.vocab, s) for s in tqdm(X_pad, desc='to index')] dataset = to_dataset(X_index, y, y_real) val_len = int(len(dataset) * 0.1) train_dataset, val_dataset = random_split(dataset, (len(dataset) - val_len, val_len)) model = self.model(text_field) model.to(device()) self.full_train(model, train_dataset, val_dataset, output_dir) torch.save(text_field, os.path.join(output_dir, self.vocab_name)) def full_train(self, model, train_dataset, val_dataset, output_dir): train_settings = self.settings num_train_epochs = train_settings['num_train_epochs'] best_eval_loss = 100000 for epoch in range(num_train_epochs): train_loss = epoch_train_func(model, train_dataset, self.loss, self.settings['train_batch_size']) eval_loss = epoch_evaluate_func(model, val_dataset, self.loss, self.settings['eval_batch_size']) if eval_loss < best_eval_loss: best_eval_loss = eval_loss torch.save(model.state_dict(), os.path.join(output_dir, self.weights_name))
Дистилляция
Идея этого способа дистилляции взята из статьи исследователей из Университета Ватерлоо. Как я говорил выше, «ученик» должен научиться имитировать поведение «учителя». Что именно является поведением? В нашем случае это предсказания модели-учителя на обучающей выборке. Причём ключевая идея — использовать выход сети до применения функции активации. Предполагается, что так модель сможет лучше выучить внутреннее представление, чем в случае с финальными вероятностями.
В оригинальной статье предлагается в функцию потерь добавить слагаемое, которое будет отвечать за ошибку «подражания» — MSE между логитами моделей.
Для этих целей сделаем два небольших изменения: изменим количество выходов сети с 1 до 2 и поправим функцию потерь.
def loss(self, output, bert_prob, real_label): a = 0.5 criterion_mse = torch.nn.MSELoss() criterion_ce = torch.nn.CrossEntropyLoss() return a*criterion_ce(output, real_label) + (1-a)*criterion_mse(output, bert_prob)
Можно переиспользовать весь код, который мы написали, переопределив только модель и loss:
class LSTMDistilled(LSTMBaseline): vocab_name = 'distil_text_vocab.pt' weights_name = 'distil_lstm.pt' def __init__(self, settings): super(LSTMDistilled, self).__init__(settings) self.criterion_mse = torch.nn.MSELoss() self.criterion_ce = torch.nn.CrossEntropyLoss() self.a = 0.5 def loss(self, output, bert_prob, real_label): return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, bert_prob) def model(self, text_field): model = SimpleLSTM( input_dim=len(text_field.vocab), embedding_dim=64, hidden_dim=128, output_dim=2, n_layers=1, bidirectional=True, dropout=0.5, batch_size=self.settings['train_batch_size']) return model
Вот и всё, теперь наша модель учится «подражать».
Сравнение моделей
В оригинальной статье наилучшие результаты классификации на SST-2 получаются при a=0, когда модель учится только подражать, не учитывая реальные лейблы. Accuracy всё ещё меньше, чем у BERT, но значительно лучше обычной BiLSTM.
Я старался повторить результаты из статьи, но в моих экспериментах лучший результат получался при a=0,5.
Так выглядят графики loss и accuracy при обучении LSTM обычным способом. Судя по поведению loss, модель быстро обучилась, а где-то после шестой эпохи пошло переобучение.
Графики при дистилляции:
Дистиллированная BiLSTM стабильно лучше обычной. Важно, что по архитектуре они абсолютно идентичны, разница только в способе обучения. Полный код обучения я выложил на ГитХаб.
Заключение
В этом руководстве я постарался объяснить базовую идею подхода дистилляции. Конкретная архитектура ученика будет зависеть от решаемой задачи. Но в целом этот подход применим в любой практической задаче. За счёт усложнения на этапе обучения модели, можно получить значительный прирост её качества, сохранив изначальную простоту архитектуры.