Тренируем генеративно-состязательную нейросеть раскрашивать эскизы персонажей аниме. Пошагово объясняем алгоритм Sketch2Color, пишем код на Python и реализуем проект внутри фреймворка глубокого обучения TensorFlow.
Ранее с помощью нейросетей мы генерировали лица персонажей манги и аниме. Теперь научим генеративно-состязательную сеть раскрашивать черно-белые контурные наброски.
Введение
Генеративно-состязательные сети (GAN) представляют собой результат переноса идей парадигмы генеративного моделирования на методы глубокого обучения.
Генеративное моделирование представляет пример задачи машинного обучения «без учителя», Изучение шаблонов во входных данных происходит таким образом, что модель может создавать новые примеры, схожие по характеристикам с экземплярами оригинального датасета.
Архитектура модели GAN включает две подмодели:
Генератор, создающий новые примеры.
Дискриминатор для классификации того, являются ли сгенерированные примеры «реальными» или поддельными, созданными генератором (фактически задача состоит в том, чтобы обучить генератор создавать примеры, которые дискриминатор не сможет отличить от оригинальных).
В терминах теории игр идея генеративно-состязательной нейросети базируется на идее игры с нулевой суммой: если один выигрывает, соперник проигрывает. В теории игр модель GAN сходится, когда дискриминатор и генератор достигают равновесия Нэша. Это оптимальная точка для следующего минимаксного уравнения:
Генеративно-состязательные сети используются в задачах перевода изображения в изображение (летние фотографии в зимние, день в ночь) и для создания неотличимых от реальности изображений никогда не существовавших объектов, сцен и людей. В этой статье мы тоже преобразуем одно изображение, черно-белое, в цветное, как если бы набросок раскрашивал художник.
1. Получение и предварительная обработка данных
Набор данных для раскрашивания аниме-эскизов, используемый для обучения генеративно-состязательной нейросети, можно загрузить с Kaggle. После загрузки и распаковки набора данных его нужно предварительно обработать, так как эскиз и цветное изображение были на одном изображении.
После сохранения эскизов и цветных изображений в отдельных каталогах мы нормализуем их так, чтобы все значения, что находились в диапазоне [0, 255] перешли в диапазон [- 1, 1].
Нормализация цветовой шкалы необходима для эффективной работы функции активации выходного слоя. Пока лучшей функцией активации для GAN считается гиперболический тангенс.
2. Архитектура генератора
Архитектура генератора, который используется для раскраски эскиза представляет вариант архитектуры U-Net – полносвязной сверточной нейросети, разработанной в 2015 году для сегментации биомедицинских изображений.
Вместо использования полносвязных слоев в кодирующих-декодирующих блоках здесь чтобы не потерять информации используется свертка и деконволюция. В сравнении с другими задачами трансформации изображения Sketch2Color критически важно сохранить информацию о ребрах, образующих набросок. U-net архитектура используется для объединения слоев и декодера.
Как показано на изображении выше, на каждом слое декодирования (синие блоки) соответствующие слои кодера (желтые блоки) объединяются с текущим слоем для декодирования следующего слоя.
3. Архитектура дискриминатора
В отличие от генератора дискриминатор имеет только блоки кодера. Дискриминатор предназначен для классификации того, является ли входная пара «эскиз-цветное изображение» «реальными» или «поддельными». То есть получено ли цветное изображение из фактических данных или от генератора.
На вход дискриминатора поступает либо пара эскиза (на рисунке выше обозначен желтым цветом) и реального целевого изображения (красный), либо пара эскиза (желтый) и сгенерированного изображение (синий). Сеть дискриминатора обучена максимизировать точность классификации.
Выходные данные дискриминатора представляют собой матрицу вероятностей формы 30x30x1. В этой матрице каждый элемент соответствует вероятности того, что пара изображений это реальная пара эскиза и цветного аниме-изображения.
Здесь мы также избегаем использования полносвязных слоев здесь, чтобы избежать потери информации. Для получения единственного значения используем агрегацию посредством пулинга по среднему значению (global average pooling). Сверточные слои между входом и выходом извлекают высокоуровневые характеристики.
4. Функции потерь генератора и дискриминатора
Создать цветное изображение из черно-белого эскиза сложнее, чем из изображения в градациях серого, ведь такой рисунок содержит меньше полезной информации. Для создания дополнительных ограничений мы будем использовать условные порождающие состязательные сети (conditional GAN). Функция потерь для условных GAN в общем случае выглядит следующим образом:
В этом выражении x – входной эскиз, y – цель (цветное мультипликационное изображение), G(x, z) – сгенерированное цветное изображение.
Условные GAN изучают отображение из вектора случайных чисел z в выходное изображение y по условиям, которые заданы эскизами x. В то время как генератор пытается минимизировать потери, дискриминатор пытается их максимизировать. Действуя совместно, они достигают равновесия.
Генератор минимизирует потери во время обучения, так что получаются правдоподобные цветные изображения. Вот его функция потерь:
Одновременное обучение дискриминатора стимулирует вариативность в генерации цветных изображений. Однако получить реалистичные картинки удается, если сочетать потери GAN с традиционными функциями потерь.
Первая из таких функций потерь – PixelLevel, т.е. L1-расстояние между каждым пикселем целевого цветного изображения и сгенерированного аналога:
Вторая функция потерь FeatureLevel – L2-расстояние между активацией ?j 4-го слоя 16-слойной VGG-сети, предварительно обученной на наборе данных ImageNet. Предобучение используется для сохранения высокоуровневых функций, таких как цвет и форма объектов. Соответствующая функция потерь:
Последняя из используемых функций потерь – TotalVariation. Эта функция необходима для того, чтобы GAN использовала палитры цветов, аналогичные данным обучающей выборки. Это действует как форма регуляризации и способствует плавности контуров изображений.
Функция потерь GAN представляет собой взвешенную комбинацию всех вышеуказанных потерь:
Веса Wp, Wf, Wg и Wtv учитывают важность каждого из видов потерь. Минимизируя функцию потерь L, GAN находит лучшие образцы пар эскиза и цветного изображения.
5. Обучение генератора и дискриминатора
Генеративно-состязательная сеть была обучена за 43 эпохи с размером батча 8. В процессе обучения использовалось сглаживание меток, т.е. более «мягкие» метки (0,9 вместо 1).
Дискриминатор обучался меньше в случае четных батчей, так как когда генератор узнает больше о реальных и сгенерированных цветных изображениях, он начинает доминировать над генератором. Тогда генератор становится слабым и не обучается, лишь фиксируя распределение реальных целевых изображений.
Для обучения использовался оптимизатор Adam с learning rate = 0.0002 и beta_1 = 0.5.
Дискриминатор и генератор обучались поочередно в цикле: сначала дискриминатор, затем генератор, затем снова дискриминатор и т. д.
6. Работа с TensorBoard
В каждом батче после обучения рассчитывались потери дискриминатора (или генератора) на парах изображений. Средние потери дискриминатора и генератора регистрировались после каждой эпохи, чтобы отслеживать их прогресс с помощью обратных вызовов TensorBoard.
В соответствии с ожиданиями потери дискриминатора колеблются между 0.23 и 0.35, а потери генератора стабильно уменьшаются. То есть генератор фиксирует распределение реальных целевых изображений.
7. Результаты обучения
После каждой эпохи генератор проверялся на прогнозирование цветов для фиксированных эскизов, чтобы проверить, узнает ли он корреляции между эскизом и связанным с ним цветным изображением.
Ниже представлены примеры результатов для 30-й, 40-й и 43-й эпох обучения.
По мере обучения цвета становятся более правдоподобными, не выходят за границы и распределяются немонотонно.
8. Вывод результатов
Наконец, обучение завершено. Сделаем вывод генератора для эскизов образцов.
Генератор выдает разумные цвета для заданных простых набросков. Генеративно-состязательная сеть обучилась на 13 тысячах пар эскизов и цветных изображений. Результат можно улучшить, собирая и добавляя дополнительные пары изображений в обучающие данные.