Назад
225

Дистилляция диффузии. Часть 2

225

Введение

В этом обзоре мы изучим способы дистилляции диффузионных моделей в один или несколько шагов. Перед этим советуем вам изучить наш предыдущий обзор по дистилляции диффузии, чтобы лучше понимать контекст 🙂

Диффузионные модели сейчас — одни из самых распространённых генеративных моделей, моделирующих изображения, видно, звук и даже 3D-модели.

Однако важным их недостатком является скорость: необходимо прогонять большую тяжеловесную модель много раз (за несколько шагов) для генерации одного сэмпла.

Чтобы нивелировать данную проблему, обычно предлагают использовать дистилляцию диффузионной модели «учителя» в «ученика» — одношаговый (или в несколько шагов) генератор, архитектура которого часто совпадает с архитектурой «учителя».

Для лучшего понимания материала рекомендуем прочитать наши посты:

Adversarial Diffusion Distillation

Итак, давайте перейдём к более продвинутым работам для дистилляции моделей Stability AI.

В статье нам уже встречались картинки, сгенерированные с помощью Stable Diffusion XL Turbo. Но что это за модель? Как авторы ускорили SDXL для получения SDXL Turbo? Разберёмся со всем по порядку 🙂

В техническом отчёте на модель авторы говорят об использовании Adversarial Diffusion Distillation (ADD). Поэтому в данном разделе мы наченём обсуждение именно с неё.

Описание метода

Adversarial Diffusion Distillation (ADD) — дистилляция, где учителем выступает обычная диффузионная модель. Однако авторы предлагают использовать дискриминатор, который умеет отличать реальные картинки от сгенерированных нашим учеником.

Рисунок 1. Пайплайн adversarial distillation [источник]

Модели

Обучение включает использование трёх моделей на реальных изображениях:

  • ADD-student — обучаемая модель, которая инициализируется из предобученной диффузионной модели UNet-DM;
  • обучаемый дискриминатор из предобученной нейросети для извлечения фичей \( F \) (авторы предлагают, например, ViT) и «голов» \( \mathcal{D}{\phi, k} \), для улучшения работы мы можем кондишнить его на различные фичи (текст, эмбеддинги, картинки). Отметим, что “голова” дискриминатора \( \mathcal{D}{\phi, k} \) применяется к соответствующему выходу \( F_k \). Дополнительной особенностью такого дискриминатора является то, что он может в качестве кондишининга брать исходное изображение, например, в случае, если было шагов зашумления <1000. Для этого авторы используют довольнительную нейронную сеть для извлечения эмбеддингов \( c_{img} \);
  • diffusion model (DM) — учитель, который не обучается.

При обучении студент генерирует сэмплы из зашумленных сэмплов реального датасета \( \hat{x}_{\theta}(x_s, s) \), где \( x_s = \alpha_s x_0 + \sigma_s \epsilon \), \( x_0 \) — его элемент.

Функции потерь

Есть две важные идеи обучения:

  • дистилляция и соответствующая ей функция потерь;
  • использование дискриминатора и adversarial loss.

Обсудим каждый из пунктов подробнее.

Adversarial Loss

Идея похожа на GANs: мы хотим добавить дискриминатор, который будет различать сгенерированное ( \( \hat{x}_{\theta} \) ) и реальное ( \( x_0 \) ) изображения.

Функция потерь генератора выглядит так:

\( \mathcal{L}{adv}^G(\hat{x}{\theta}(x_s, s), \phi) = -\mathbb{E}{s, \epsilon, x_0} \Big[\sum_k \mathcal{D}{\phi, k}\big(F_k(\hat{x}_{\theta}(x_s,s))\big)\Big] \)

Дискриминатор же обучается минимизировать:

\( \mathcal{L}{adv}^D(\hat{x}{\theta}(x_s, s), \phi) = -\mathbb{E}{x_0}\Big[\sum_k \text{max}(0, 1 — \mathcal{D}{\phi, k}(F_k(x_0))) + \gamma R1(\phi) \Big] + \mathbb{E}{\hat{x}{\theta}}\Big[\sum_k \text{max}(0, 1+ \mathcal{D}{\phi, k}(F_k(\hat{x}{\theta})))\Big] \)

Где R1 — gradient penalty.

Score Distillation Loss

Обозначим предсказание ADD-student как \( \hat{x}{\theta}(x_s,s) \). Добавим к результату немного шума и прогоним это через учителя, получим предсказание \( \hat{x}{\psi}(\hat{x}_{\theta, t}, t) \). Важная особенность здесь — добавление шума, поскольку сэмпл нашего студента оказывается вне распределения диффузионной модели, а прямое применение не приводит ни к чему хорошему 😟.

Тогда функция потерь должна минимизировать разницу между предсказанием ученика и учителя: \( \mathcal{L}{distill}(\hat{x}{\theta}(x_s, s), \psi) = \mathbb{E}{t, \epsilon’}\Big[c(t)d(\hat{x}{\theta}, \hat{x}{\psi}(sg(\hat{x}{\theta, t});t))\Big] \)

Где sg — оператор стоп-градиента.

В качестве distance-функции авторы используют \( d(x,y) — \|x — y\|^2_2 \), а в качестве \( c(t) = \alpha_t \) (выше шум — меньше вклад) или SDS weighting.

Важный факт: SDS функция потерь — одна из основных функций потерь для генерации в 3D по текстовому промпту или картинке с помощью генеративных праеров от 2D-диффузии. Авторы показывают, что предлагаемая ими функция потерь при определённом подборе \( c(t) \)вырождается в SDS loss. Подробнее об этом можно почитать в приложении статьи.

Результаты в картинках

Давайте посмотрим, удалось ли достичь высокого качества генерации при таком виде дистилляции.

Изображения ниже сгенерированы при разрешении 512×512 за один шаг.

Рисунок 2. Генерация изображений размером 512×512 с помощью ADD-XL модели [источник]

Авторы сравнивают свои результаты с результатами следующих работ:

На картинках мы видим, что предложенная модель значительно превосходит GANs и InstaFlow.

Рисунок 3. Сравнение качества генерации различных text2image моделей [источник]

Они показывают, что генерация в 1 шаг уже даёт неплохой результат, но можно увеличить их количество и получить качество ещё лучше.

Рисунок 4. Сравнение сгенерированных в 1, 2 и 4 шага изображений по заданному текстовому промпту. Сиды по столбцам фиксированы [источник]

Результаты в метриках

Помимо основных метрик (FID или CLIP-score) авторы анализировали выбор опрашиваемых людей. Они проверяли качество изображения, а также его соответствие текстовому промпту.

Для проверки промпта задавали следующий вопрос:

«Which image looks more representative of the text shown above and faithfully follows it?» («Какое изображение лучше отображает представленный текст?»)

Для проверки изображения:

«Which image is of higher quality and aesthetically more pleasing?» («Какое изображение лучшего качества и эстетически более приятное?»)

Рисунок 5. User study для разных моделей по качеству изображений и соответствию промпту [источник]
Рисунок 6. User study для разных моделей по качеству изображений и соответствию промпту [источник]

Также авторы посчитали метрики FID и CLIP-score. На таблице ниже можно заметить, как Progressive Distillation и InstaFlow проигрывают данному методу.

Рисунок 7. Сравнение метрик FID и CLIP-score [источник]

Кроме того, они сравнили скорость генерации и ELO Scores. Мы не будем сейчас подробно останавливаться на этой метрике, вы можете подробнее почитать о ней в оригинальной статье. Однако стоит отметить: чем больше её показатель, тем лучше.

Рисунок 8. Сравнение ELO Scores в зависимости от скорости генерации [источник]

Вывод

Таким образом, авторы предложили совместить две идеи: дистилляцию и использование adversarial функции потерь.

У такого подхода может возникнуть проблема, схожая с GANs, — недостаточно разнообразные сэмплы за счёт adversarial loss. Есть вопросы и к SDS функции потерь, поскольку обычно она требует высокого значения guidance strength. Это, в свою очередь, уменьшает разнообразие генерируемых сэмплов.

Однако в целом результаты хорошие и быстрые, поиграться с моделью можно тут.

Fast High-Resolution Image Synthesis with Latent Adversarial Diffusion Distillation

В этом году Stability AI представили миру Stable Diffusion 3, которая заметно улучшила качество генерации. И почти сразу появилась работа о способе дистилляции, помогающем уменьшить количество шагов. Идея базируется на статье, рассмотренной выше, однако имеет ряд значительных изменений.

Описание метода

Давайте разберёмся, чем же отличается Latent Adversarial Diffusion Distillation (LADD) от Adversarial Diffusion Distillation (ADD).

Рисунок 9. Сравнение пайплайнов ADD и LADD [источник]

ADD: пусть у нас есть картинка из подготовленного датасета, которую мы зашумляем и генерируем с помощью «‎ученика» и изображения «учителя». Далее применяем дистилляционный лосс и adversarial loss.

Основными отличиями LADD являются:

  • использование синтетических данных для обучения;
  • объединение дискриминатора и «учителя».

Unifying teacher and discriminator

Во-первых, дискриминатор работает в латентном пространстве, а не в пиксельном. Для этого мы снова зашумляем латентные вектора до уровня \( \hat{t} \), засэмплированного из логитнормального распределения (авторы SD3 сравнивают различные распределения t, см. рисунок ниже). Затем используем модель «учителя» для расшумления и получаем последовательность токенов после каждого attention-блока. Далее ко всем последовательностям применяем независимую «голову»‎ дискриминатора. Дополнительно каждый дискриминатор («голова» дискриминатора) кондишнится на уровень зашумления и pooled CLIP embeddings.

Рисунок 10. Сравнение качества генерации при разных распределениях сэмплирования по времени для «учителя» [источник]

Важная особенность здесь — в качестве feature extruction net для получения adversarial loss является генеративная модель «учителя». Авторы отдельно отмечают использование не дискриминативного дискриминатора, а генеративного.

Преимущества подхода:

  • эффективность: без работы в пиксельном пространстве мы уменьшаем потребление ресурсов, а ещё используем фичи с разных слоёв сетки;
  • учёт уровня шума: с помощью такого дискриминатора учитываем варьирование фичей на разных уровнях зашумления;
  • Multi-Aspect Ratio (MAR): работаем не только с квадратными картинками;
  • Alignment with Human Perception: texture bias дискриминативных моделей — приоритет текстуры над формой (а люди обычно обращают внимание на форму), поэтому в работе отмечается, что генеративные модели ближе к человеческому восприятию, значит, используем их в качестве дискриминатора.

Использование синтетических данных

В отличие от ADD, где дистилляция проводится на данных из датасета, авторы предлагают работать с синтетическими данными, сгенерированными «учителем».

Text-alignment значительно варьируется в зависимости от датасетов, которые могут отличаться по метрике CLIP-score. Например, в датасете COCO CLIP она достигает всего 0.29, однако современные диффузионные модели добиваются более высокого качества (Stable Diffusion 3 генерирует картинки по промтам из COCO c 0.35). Поэтому авторы предлагают не работать с датасетами, а генерировать изображения через «учителя».

На рисунке ниже представлено влияние синтетических данных на соответствие тексту:

Рисунок 11. Влияние синтетических данных на image-text alignment [источник]

Direct preference optimization

Чтобы лучше соответствовать ощущениям людей, авторы дополнительно файнтьюнят модель с помощью Diffusion DPO — специальный метод оптимизации, который подстраивает генерацию к человеческому восприятию. Для этого используется Low-Rank Adaptation (LoRA).

Рисунок 12. Результат генерации при добавлении DPO-LoRA к обучаемому студенту. Видно, как появляется больше деталей, улучшаются руки и фиксируются дублированные объекты, например, колёса машин [источник]

Результаты в картинках

Итак, что мы имеем: решение за 1 шаг значительно лучше Latent Consistency Models даже за 4 шага.

Рисунок 13. Сравнение LCM и LADD [источник]

Качество генерации за 1 и 4 шага для обученной модели, по моему мнению, примерно одинаковое:

Рисунок 14. Сравнение качества генерации за 1 и 4 шага [источник]

Авторы также проверили работу модели на разных задачах, в частности на Image editing.

Стоит отметить: «‎ученик» SD3-Turbo несильно отстаёт от «учителя» SD3 и значительно превосходит другие модели, предложенные для editing: InstructPix2Pix, MagicBrush, Hive.

Рисунок 15. Сравнение Image Editing [источник]

Авторы проверили работу модели на задачах inpainting, сравнив её с LaMa, SD1.5-inpainting, SD3-inpainting.

Рисунок 16. Сравнение Image Editing [источник]

Результаты в метриках

Исследователи сравнили результаты с различными бейзлайнами при постановке User Preference studyприложении они привели текстовые промпты для экспериментов!). Как можно заметить, LADD везде выигрывают (но проигрывают себе же там, где большее количество шагов генерации).

Рисунок 17. User Preference study [источник]

Авторы также сравнили генерацию в 4 шага с генерацией различных моделей в несколько шагов:

Рисунок 18. User Preference study [источник]

Вывод

Таким образом, при ускорении у нас остаётся лосс от дискриминатора и убирается любой другой дистилляционный лосс. При этом результаты выходят достаточно качественными. Очень ждём релизнутую модельку 🙂

One-step Diffusion with Distribution Matching Distillation

Статья 2023 года предлагает другой подход — Distribution Matching Distillation (DMD).

Distribution Matching Distillation (DMD) — метод, который позволяет получить генератор изображения в 1 шаг из предобученной диффузионной модели.

Основная идея подхода: минимизация KL-дивергенции междура спределением генератора и реальным распределением.

Обратите внимание: KL-дивергенция не симметрична. Ранее мы говорили про минимизацию KL(teacher|student), однако в этой работе предлагается минимизировать именно KL(student|teacher).

Её градиент выражается разностью двух score-функций: одной из целевого (реального) распределения, другой — из синтетического, которое получается из одношагового генератора. Давайте разбираться с этим подробнее 🙂

Описание метода

Итак, основная задача — дистиллировать предобученную модель \( \mu_{base} \) в одношаговый генератор \( G_{\theta} \).

Авторы предлагают инициализировать его с помощью той же архитектуры, как у базовой модели — с теми же весами, но без кондишнинга на шаг по времени. То есть мы хотим генерировать сэмплы из того же распределения, но не обязательно точно такие же.

Как и в GANs, имеем здесь fake samples — сэмплы от генератора \( G_{\theta} \), а также real samples — сэмплы из обучающей выборки, сгенерированные предобученной \( \mu_{base} \).

Рисунок 19. Основной пайплайн DMD [источник]

Оптимизация проводится за счёт минимизации двух функций потерь:

  • distribution matching loss;
  • regression loss.

Distribution Matching Loss

Для генерации изображений, похожих на оригинальные, мы можем минимизировтаь KL-дивергенцию между распределением обучающей выборки и распределением сгенерированных изображений:

\( D_{KL}(p_{fake} \| p_{real}) = \mathbb{E}{x \sim p{fake}} \Big( \log \frac{p_{fake}(x)}{p_{real}(x)}\Big) = \mathbb{E}{z \sim \mathcal{N}(0;\mathbf{I}), x=G{\theta}(z)} (\log p_{fake}(x) — \log p_{real}(x)) \)

Мы не посчитаем плотность вероятности для оценки функции потерь, но нам это и не нужно — важно определить её градиент:

\( \nabla_{\theta}D_{KL} = \mathbb{E}{z \sim \mathcal{N}(0;\mathbf{I}), x=G{\theta}(z)} \Big[-\big(s_{real}(x) — s_{fake}(x)\big)\nabla_{\theta}G_{\theta}(z)\Big] \)

где \( s_{real} = \nabla_x \log p_{real}(x) \), \( s_{fake} = \nabla_x \log p_{fake}(x) \)

\( s_{real} \) сдвигает в сторону мод реального распределения, а \( s_{fake} \) — от них.

Рисунок 20. Оптимизация только real score (a), real + fake scores (b), real + fake scores c регрессионным лоссом (с) [источник]

Недостатки оценки score-функции:

  • \( p_{real} \) имеет слишком маленькие значения в окрестности сгенерированных (фейковых) распределений, что даёт очень затухающий градиент;
  • для оценки используется диффузионная модель, которая не работает с незашумлёнными распределениями.

Для решения этих проблем предлагается добавить гауссовский шум в реальное и фейковое распределения, чтобы получить размытые и пересекающиеся распределения:

Рисунок 21. Без дополнительного шума реальное и фейковое распределения не пересекаются, поэтому для реальных данных валиден градиент real scores, а для фейковых — fake scores (a); с добавлением шума градиент правильно определён во всех случаях (b) [источник]

Для моделирования score-функции реального распределения берётся предобученная диффузионная модель \( \mu_{base} \):

\( s_{real}(x_t,t) = -\frac{x_t — \alpha_t\mu_{base}(x_t, t)}{\sigma_t^2} \)

Аналогично определяется \( s_{fake}(x_t, t) = — \frac{x_t — \alpha_t \mu_{fake}^{\phi}(x_t, t)}{\sigma_t^2} \)

Так как в процессе генерации у нас меняется распределение сгенерированных изображений мы добавляем \( \mu_{fake}^{\phi} \) — фейковую диффузионную модель для трекинга этих изменений.

Эта диффузионка инициализируется из предобученной диффузионной модели \( \mu_{base} \) и обновляется в процессе обучения через минимизацию функции потерь: \( \mathcal{L}{\text{denoise}}^{\phi} = \|\mu{fake}^{\phi}(x_t, t) — x_0 \|^2_2 \), где \( x_0 \) — изображение, сгенерированное нашим одношаговым генератором (см. подробнее рис. 22-24 с псевдокодом алгоритма).

Distribution matching gradient

Итоговый градиент KL-дивергенции со score-функцией выглядит так:

\( \nabla_{\theta}D_{KL} \approx \mathbb{E}{z,t,x,x_t} \Big[w_t \big(s{fake}(x_t, t) — s_{real}(x_t,t)\big)\frac{d x_t}{d \theta}\Big] = \mathbb{E}{z,t,x,x_t} \Big[w_t \big(s{fake}(x_t, t) — s_{real}(x_t,t)\big)\frac{d x_t}{dG_{\theta}(z)}\frac{dG_{\theta}(z)}{d \theta}\Big] = \\ \mathbb{E}{z,t,x,x_t} \Big[w_t \big(s{fake}(x_t, t) — s_{real}(x_t,t)\big)\frac{d x_t}{dx}\frac{dG_{\theta}(z)}{d \theta}\Big] = \mathbb{E}{z,t,x,x_t} \Big[w_t \big(s{fake}(x_t, t) — s_{real}(x_t,t)\big)\alpha_t\nabla_{\theta}G_{\theta}\Big] \)

где \( z \sim \mathcal{N}(0;\mathbf{I}), x = G_{\theta}(z), t \sim \mathcal{U}(T_{min}, T_{max}) \)

Авторы акцентируют внимание и на выбор \( w_t \):

\( w_t = \frac{\sigma_t^2}{\alpha_t}\frac{CS}{\|\mu_{base}(x_t,t) — x\|_1} \), где S — пространственная размерность, C — количество каналов.

Regression loss

Авторы утверждают, что Distribution matching loss хорошо определяется в области большого зашумления, но при малом добавлении шума возникает проблема с валидностью градиентов и неустойчивостью обучения.

Для её решения предлагается использовать дополнительный регрессионный лосс.

Комментарий: на самом деле регрессионный лосс — это по сути дистилляционный лосс, какой использовался в первых работах про дистилляцию.

Чтобы построить такую функцию потерь, авторы создают парный датасет \( D = \{z,y\} \) , где \( z \) — шум, \( y \) — изображения, сгенерированные с помощью предобученной диффузионной модели и сэмплера DDIM. Регрессионная функция потерь сравнивает результат генерации нашего генератора в 1 шаг и картинки, сгенерированные «учителем»:

\( \mathcal{L}{reg} = \mathbb{E}{(z,y) \sim D}l(G_{\theta}(z),y) \)

В качестве функции \( l \) применяется LPIPS loss.

Итоговая функция потерь для оптимизации выглядит следующим образом:

\( \mathcal{L}{final} = D{KL} + \lambda_{reg} \mathcal{L}_{reg} \)

Для подсчёта градиента KL-дивергенции используется данная формула:

\( \nabla_{\theta}D_{KL} \approx \mathbb{E}{z,t,x,x_t} \Big[w_t \big(s{fake}(x_t, t) — s_{real}(x_t,t)\big)\alpha_t\nabla_{\theta}G_{\theta}\Big] \)

Градиент регрессионного лосса вычисляется путём автоматического дифференцирования в Pytorch.

Таким образом, финальный алгоритм обучения можно записать следующим образом:

Рисунок 22. Алгоритм обучения [источник]
Рисунок 23. Алгоритм подсчёта distribution marching loss [источник]
Рисунок 24. Алгоритм подсчёта denoising loss [источник]

Результаты в картинках

Авторы провели несколько экспериментов с различными моделями. Рассмотрим самые интересные для нас результаты — результаты со Stable Diffusion.

Ниже представлено сравнение качества генерации Consistency модели, InstaFlow и DMD для генерации за меньшее число шагов. Стоит отметить: предложенная модель не уступает по качеству обычной диффузионной модели, которая генерирует изображения за 50 шагов, но при этом значительно выигрывает по скорости.

Рисунок 25. Сравнение работы DMD и других методов ускорения за 1 и 4 шага [источник]

Рассмотрим ещё немного картинок 🙂

Авторы добиваются качества, сопоставимого с качеством SD, однако их модель работает примерно в 30 раз быстрее.

Рисунок 26. Результаты генерации предложенной модели [источник]

Результаты в метриках

Также авторы приводят результаты генерации не только диффузионных моделей, но и GANs для сравнения качества и скорости различных разрешений изображений:

Рисунок 27. Сравнение DMD и других генераторов на разных разрешениях по метрике FID [источник]
Рисунок 28. Сравнение DMD и других способов дистилляции генератора SD1.5 за несколько шагов [источник]

Вывод

Итак, закрепим:

Основная задача — дистиллировать знания из предобученной модели в одношаговый генератор.

Для этого используем две функции потерь:

  • LPIPS — объект, который генерируется одношаговым генератором и становится близким объекту, генерируемому оригинальной функцией потерь;
  • Distribution Matching Loss — идея, аналогичная ProlificDreamer, где минимизируется KL-дивергенция между реальным и фейковым распределениями. Её невозможно расписать честно, но вполне реально аппроксимировать градиент через score-функции диффузионных моделей.

Также важная особенность работы — отсутствие реального датасета изображений для обучения.

Improved Distribution Matching Distillation for Fast Image Synthesis

У статьи выше недавно вышло продолжение, которое предлагает открытый код (в отличие от предшественника). Давайте узнаем, какие были внесены изменения, и получилось ли добиться лучшего результата.

Описание метода

Хотя регрессионный лосс из предыдущей работы помогает достигнуть некоторой стабильности обучения, он в то же время сильно ограничивает качество генерации качеством «учителя».

Но если его убрать — качество генерации значительно ухудшается:

Рисунок 29. Качество генерации с регрессионным лоссом и без него [источник]

Также авторы заметили, что средняя яркость генерируемых сэмплов сильно флуктуирует и не сходится в процессе обучения:

Рисунок 30. Флуктуации яркости при обучении [источник]

Авторы предполагают, что проблема кроется в \( \mu_{fake} \), которая не может качественно предсказывать fake scores. Это даёт ошибку в подсчёте градиента для оптимизации генератора.

Для её решения они обращаются к обучению в стиле GANs: обновляем генератор каждые 5 шагов оптимизации \( \mu_{fake} \). Ниже представлены результаты FID для такого обучения с различным количеством шагов при обновлении генератора:

Рисунок 31. Сравнение метрики FID при обучении с регрессионным лоссом и без него, а также с различным количеством шагов обновления генератора [источник]

Несмотря на улучшение стабильности обучения, генератор не достигает качества «учителя». Возможно, это связано с тем, что он не видел реальных изображений. Для нивелирования проблемы авторы предлагают использовать adversarial loss, который разделит реальные изображения и изображения, сгенерированные нашим генератором.

Чтобы построить такой дискриминатор, исследователи добавляют классификационную ветвь поверх fake diffusion denoiser \( \mu_{fake} \):

Рисунок 32. Основной пайплайн обучения [источник]

Он обучается с помощью функции потерь GANs:

\( \mathcal{L}{GAN} = \mathbb{E}{x \sim p_{real}, t\sim [0,T]}[\log D(F(x, t))] + \mathbb{E}{z \sim p{noise}, t\sim [0,T]} [-\log(D(F(G_{\theta}(z),t)))] \)

Так авторы достигли хорошего качества на ImageNet и COCO датасетах. Однако дистилляция больших текстовых моделей типа SDXL в одношаговый генератор не даёт похожих результатов. Поэтому авторы обращаются к multi-step сэмплированию.

Multi-step generator

Они фиксируют шаги генератора, которые будут использоваться во время обучения и инференса \( \{t_1 t_2,…t_n\} \). В частности для модели в 4 шага применялись 999, 749, 499 и 249 шаги «учителя», обучавшегося на 1000 шагов.

Предыдущие Multi-step генераторы часто обучались расшумлять реальные изображения. Тем не менее без первого шага, чистого шума, на вход генератора приходит прошлый шаг сэмплирования. Таким образом, появляется разница между обучением и инференсом, что приводит к ухудшению генерации на инференсе.

Для решения проблемы авторы заменили при обучении реальные зашумлённые изображения на синтетические зашумлённые картинки, полученные нашим обучаемым генератором (похожая идея предложена в работе Imagine Flash).

Рисунок 33. Отличие процессов обучения при зашумлении реальных и сгенерированных изображений [источник]

Результаты в картинках

Ниже представлено качество генерации нескольких диффузионных моделей, ускоренных за 4 шага. Стоит отметить: модель, предложенная авторами, меньше всего страдает от различных артефактов, особенно при работе с людьми и сложными длинными промптами.

Рисунок 34. Сравнение генерации изображений моделью и другими методами ускорения, а также оригинальной диффузионной моделью «учителя» [источник]

Авторы также проводят ablation study, сравнивая качество генерации с использованием distribution matching’а, GANs loss, backward simulation и без них:

Рисунок 35. Сравнение генерации изображений моделью без её различных составляющих [источник]

Результаты в метриках

Авторы везде выигрывают модель SDXL по качеству генерации и соответствию промпту в User study:

Рисунок 36. User study [источник]

Также они сравнивают работу модели на задаче class-conditional generation и text2image генерации по метрикам FID и CLIP (в случае t2i задачи):

Рисунок 37. Метрики в задачах class-conditional и text2image [источник]

Вывод

Таким образом, авторы добиваются высокого качества генерации, однако для больших текстовых моделей их метод всё ещё работает за несколько шагов (4 шага). Также во время обучения всегда применяется один и тот же guidance scale, а при использовании различных guidance scale возможно улучшение качества генерации.

Заключение

В нашей статье мы разобрали дистилляцию диффузии, направленную, в первую очередь, на уменьшение количества шагов генерации. Познакомились с рядом работ в порядке их появления и развития. Рассмотрели SOTA-решения, которые включают различные подходы: например, матчинг распределений и adversarial обучение.

Ещё год назад попытки уменьшить количество шагов базировались в основном на изменении сэмплеров генерации. Однако сейчас существует множество других способов, позволяющих достичь хорошего качества моделей. К сожалению, не ко всем методам прилагается код в открытом доступе, тем не менее всегда есть надежда на аналогичный open source 🙂

Полезные ссылки

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!