SHORTCUT MODELS: метод обучение диффузионных моделей генерации в 1 шаг

МЕНЮ


Главная страница
Поиск
Регистрация на сайте
Помощь проекту
Архив новостей

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


Shortcut models (https://arxiv.org/pdf/2410.12557) - метод обучения диффузионных моделей, который позволяет генерировать изображения высокого качества за один или несколько шагов.

В основе shortcut models - идея обучать сеть с учетом не только текущего уровня шума, но и желаемого размера шага. Это позволяет модели "перепрыгивать" через этапы генерации.

Ключевым преимуществом данного подхода является его простота: shortcut models обучаются за один этап, используя одну сеть, в отличие от других методов ускорения выборки, которые полагаются на сложные схемы обучения с несколькими фазами, сетями или точной настройкой шедулера.

В процессе обучения shortcut models используются два типа целей loss function:

flow-matching при малом размере шага (d ? 0), аналогично стандартным диффузионным моделям.

self-consistency при больших размерах шага (d > 0), где цель формируется путем конкатенации последовательности из двух шагов размером d/2.

Совместная оптимизация этих целей дает возможность модели научиться создавать изображения, сохраняя согласованность при любом размере шага, включая генерацию за один шаг.

Метод применим к flow-matching и transformer-based типам моделей и RNN/LSTM-сетям.

Эксперименты, проведенные с DiT на наборах данных CelebA-HQ и ImageNet-256, подтверждают эффективность метода.

Shortcut models превосходят методы "end-to-end" обучения одношаговых генеративных моделей и конкурируют с двухэтапными методами дистилляции.

Практическая реализация shortcut models (https://github.com/kvfrans/shortcut-models) написана на JAX. Для локального запуска следует установить зависимости conda из файлов environment.yml и requirements.txt репозитория.

Код поддерживает --model.sharding fsdp для полностью сегментированного параллелизма данных, если обучение проводится на multi-GPU или TPU.

Чекпоинты и FID для тестовых датасетов CelebA и Imagenet доступны на Google-диске (https://drive.google.com/drive/folders/1g665i0vMxm8qqqcp5mAiexnL919-gMwW?usp=sharing).

Пример запуска обучения на DiT-B с датасетом CelebA :

python train.py —model.hidden_size 768 —model.patch_size 2 —model.depth 12 —model.num_heads 12 —model.mlp_ratio 4   

--dataset_name celebahq256 —fid_stats data/celeba256_fidstats_ours.npz —model.cfg_scale 0 —model.class_dropout_prob 1 —model.num_classes 1 —batch_size 64 —max_steps 410_000 —model.train_type shortcut

Страница проекта (https://kvfrans.com/shortcut-models/)

Arxiv (https://arxiv.org/pdf/2410.12557)

GitHub (https://github.com/kvfrans/shortcut-models)


Источник: github.com

Комментарии: