Назад
239

Дистилляция диффузии. Часть 1

239

Введение

За последние пару лет диффузионные модели прочно закрепились в мире генеративных (и не только) задач.

Рисунок 1. Изображение, сгенерированное через SDXL-Turbo

С их помощью мы получаем качественный и разнообразный материал при генерации:

  • изображений;
Рисунок 2. Генерация изображений с помощью модели Midjourney
  • видео;
Рисунок 3. Генерация видео с помощью модели SORA
  • музыки.

[Рисунок 4. Генерация музыки [источник]](https://prod-files-secure.s3.us-west-2.amazonaws.com/fb875fd6-d46b-4f75-8a43-7beca7a54a5e/21dcdec8-56e4-4da7-a935-05d26a57f505/wavegen_19.wav)

И это ещё не весь список 🙂

К сожалению, такой процесс остаётся достаточно медленным. Для генерации одного сэмпла нам зачастую нужно много запусков одной немаленькой нейронной сети, что занимает много времени в сравнении с теми же GANs.

Именно поэтому в данной статье мы с вами рассмотрим способ ускорения диффузионных моделей, но для начала давайте вспомним, а что же вообще представляют из себя диффузионки?

Диффузионные модели

Основная идея диффузионных моделей — генерация элементов распределения из шума путём постепенного расшумления.

Рисунок 5. Генерация изображений из шума [источник]

Диффузионные модели можно рассматривать со стороны score matching’а, стохастических дифференциальных уравнений (SDE) или как модели со скрытыми латентными переменными (идейно близко к VAE). Подробнее об этом можно почитать в нашем посте.

Диффузионный процесс состоит из двух частей:

  • forward process — постепенное добавление шума к распределению объектов до тех пор, пока итоговое распределение не получит латентное распределение шума (чаще всего работают с гауссовским шумом, но также была представлена работа о других типах шумов). Очень важный фактор — марковский процесс (каждый следующий этап зашумления зависит только от предыдущего).
  • backward process — сэмплирование объекта из латентного распределения и получение объекта из исходного распределения путём пошагового расшумления.

Прямой процесс (forward process)

Пусть нас есть сэмпл \( x_0 \sim q(x) \). Тогда прямой процесс диффузии — постепенное добавление небольшого количества шума, чаще всего гауссовского (далее работаем с ним), в течение \( T \) шагов. В результате имеем последовательность зашумленных сэмплов \( x_1, x_2 …, x_T \) и гауссовский шум. Количество добавляемого шума регулируется с помощью variance scheduler \( \{\beta_t \in (0,1)\}^T_t=1 \).

Тогда зашумлённое распределение на шаге \( t \) можно записать следующим образом:

\( q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}) \)

\( q(x_{1:T}|x_0) = \prod_{t=1}^{T}q(x_t|x_{t-1}) \)

Рисунок 6. Зашумление изображения с помощью библиотеки diffusers. Само изображение сгенерировано через SDXL Turbo

Благодаря трюку репараметризации мы можем сэмплировать из получившегося распределения:

\( x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \)

Рисунок 7. Схема работы диффузионной модели [источник]

Обратный процесс (backward process)

Если мы сможем обернуть данный процесс — получим генеративную модель, которая берёт сэмпл из гауссовского шума с нулевым средним и единичной дисперсией \( x \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \) и восстанавливает элемент из исходного распределения.

Важно отметить: если \( \beta_t \) — маленькое значение, то \( q(x_{t-1}|x_t) \) тоже будет гауссовским. Но мы не можем просто оценивать такое распределение: для этого нужен весь исходный датасет. Следовательно, нашей задачей становится аппроксимация с помощью какой-либо модели \( p_{\theta} \):

\( p_{\theta}(x_{0:T}) = p(x_T) \prod_{t=1}^{T}p_{\theta}(x_{t-1}|x_{t}) \)

\( p_{\theta}(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1}; \mathbf{\mu}{\theta}(x_t, t), \mathbf{\sum}{\theta}(x_t,t)) \)

Рисунок 8. Пример обучения диффузионной модели для 2D-данных [источник]

Ускорение

Как мы отмечали выше, диффузионные модели — не самые быстрые генераторы. Отсюда возникает вопрос: как нам добиться высокой скорости?

Ускорять их можно разными способами. Первый метод: снижать время выполнения одного шага генерации, например, переходить в латентное пространство (а не работать с пикселями), как это было предложено в LDM.

Второй метод: уменьшать количество шагов. Один из таких способов — дистилляция «‎учителя»‎, предобученной диффузионной модели, в «‎ученика», способного генерировать за несколько шагов. То есть мы передаём знания, зашитые в нашу большую и хорошо работающую модель, во что-то более компактное и быстрее работающее.

В этом обзоре мы разберём несколько ключевых работ по дистилляции диффузионных моделей с помощью различных подходов.

Diffusion Distillation

📄 Paper

📎 GitHub

Важное условие дистилляции моделей — обучаемая функция должна быть детерминистической. Но стандартная диффузионная модель DDPM является стохастической и не может быть «‎учителем» при таком сценарии.

Однако появилась идея Denoising Diffusion implicit models (DDIM) — здесь процесс генерации детерминированный, следовательно, модели могут использоваться как «‎учителя».

Как же нам применить её для ускорения сэмплирования?

Ключевая идея работы — обучить «‎ученика», генератор, способный генерировать изображение из шума в один шаг. То есть из пространства шума мы хотим сразу оказаться в пространстве изображений. В качестве «‎учителя» выбирается предобученная диффузионная DDIM-модель, которая умеет генерировать качественные сэмплы, но за большее количество шагов.

Рисунок 9. Схема дистилляции диффузионной модели [источник]

То есть мы минимизируем KL-дивергенцию между конечным распределением, которое предсказывает «‎учитель», и распределением обучаемого одношагового генератора:

\( \mathcal{L}{student} = \mathbb{E}{x_T} \Big[D_{KL}(p_{teacher}(x_0|x_T) \| p_{student}(x_0|x_T))\Big] \)

Параметризуем нашего студента как простую гауссовскую модель с обучаемым средним \( \mathcal{F}{student} \) и единичной дисперсией. В качестве «‎учителя» выступает DDIM-модель \( \mathcal{F}{teacher} \), что приводит нас к функции потерь:

\( \mathcal{L}_ {student} = \mathbb{E}{x_T} \Big[\|\mathcal{F}{student}(x_T) — \mathcal{F}{teacher}(x_T)\|^2_2\Big] + C \)

То есть мы сэмплируем \( x_T \) из шума, рассчитываем \( \mathcal{F}_{teacher}(x_T) \) с помощью DDIM и минимизируем записанный выше лосс.

Результаты в картинках

Важно отметить: эксперименты проводились на датасетах с картинками очень невысокого разрешения.

Рисунок 10. Unconditional генерация CIFAR10 и CelebA [источник]
Рисунок 11. Unconditional генерация LSUN (Bedroom, Church-Outdoor) [источник]

Результаты в метриках

Отметим, что результат генерации хуже, чем у диффузионных моделей с большим количеством шагов и одношаговых генеративных (как GANs).

Рисунок 12. Метрики FID и IS на датасете CIFAR10 [источник]

Вывод

Таким образом, дистилляция модели в один шаг не даёт достаточно качественных результатов генерации. Поэтому мы переходим к следующей идее 😉

Progressive Distillation

📄 Paper

📎 GitHub

Ключевая идея работыпостепенная дистилляция с уменьшением количества шагов вдвое.

Рисунок 13. Схема алгоритма Progressive Distillation [источник]

В качестве «‎учителя» также выступает DDIM, и весь алгоритм обучения выглядит следующим образом:

Рисунок 14. Алгоритм Progressive Distillation [источник]

В работе также вводится новый метод параметризации, который используется сейчас в разных диффузионных моделях, например, в StableDiffusion 2.0.

До этого многие применяют \( \epsilon- \)параметризацию, где именно шум \( \epsilon \) предсказывается с помощью нейронной сети \( \hat{\epsilon}_{\theta}(z_t) \). Для неё функция потерь записывается следующим образом:

\( \mathcal{L}{\theta} = \|\epsilon — \hat{\epsilon}{\theta}(z_t)\|^2_2 \)

Поскольку сэмпл распределения вычисляется через этот шум \( \hat{x}{\theta}(z_t) = \frac{1}{\alpha_t}(z_t — \sigma_t \hat{\epsilon}{\theta}(z_t)) \), функцию потерь можно переписать так:

\( \mathcal{L}{\theta} = \|\epsilon — \hat{\epsilon}{\theta}(z_t)\|^2_2 = \|\frac{1}{\sigma_t}(z_t — \alpha_tx) — \frac{1}{\sigma_t}(z_t — \alpha_t \hat{x}_{\theta}(z_t))\|2^2 = \frac{\alpha_t^2}{\sigma_t^2}\|x — \hat{x}{\theta}(z_t)\|^2_2 \)

Почему такая параметризация не подходит для дистилляции? Отметим, что изначальная модель работает с широким спектром отношения сигнала к шуму (signal-to-noise ratios SNR) — \( \alpha_t^2 / \sigma_t^2 \). Но при постепенной дистилляции он всё больше уменьшается. Когда SNR приближается к нулю, маленькие изменения выхода нейронной сети \( \hat{\epsilon}{\theta}(z_t) \) дают существенные трансформации самого изображения, поскольку \( \hat{x}{\theta}(z_t) = \frac{1}{\alpha_t}(z_t — \sigma_t \hat{\epsilon}_{\theta}(z_t)) \) (\( \alpha_t \rightarrow 0 \)). Если у нас много шагов генерации — ошибки начальных шагов постепенно корректируются. Если шагов мало — такая параметризация становится проблемой, особенно при задаче генерации в один шаг.

Таким образом, для качественной дистилляции важна устойчивость модели к изменению SNR.

Поэтому авторы предлагают использовать альтернативные способы параметризации:

  • предсказывать \( x \);
  • предсказывать \( x \) и \( \epsilon \);
  • предсказывать v = \( \alpha_t \epsilon — \sigma_t x \) что даёт сэмпл из распределения \( \hat{x} = \alpha_tz_t — \sigma_t \hat{v}_{\theta}(z_t) \).

Для новой параметризации также нужно перестроить функцию потерь. Авторы представляют два варианта:

  • \( \mathcal{L}_{\theta} = \text{max}(\|x — \hat{x}_t\|^2_2, \|\epsilon — \hat{\epsilon}_t\|^2_2) = max(\frac{\alpha_t^2}{\sigma_t^2}, 1)\|x — \hat{x}_t\|_2^2 \) (truncated SNR weighting);
  • \( \mathcal{L}_{\theta} = \|v_t — \hat{v}_t\|^2_2 = (1 + \frac{\alpha_t^2}{\sigma_t^2})\|x — \hat{x}_t\|^2_2 \) (SNR+1 weighting).

Они сравнивают различные параметризации и веса для функции потерь, а также два сэмплера: DDIM (Denoising Diffusion Implicit Models, детерминистический сэмплер) и DDPM (Denoising Diffusion Probabilistic Model, стохастический сэмплер), считая метрики FID и IS для датасета CIFAR10.

Рисунок 15. Сравнение FID и IS метрик для различных параметризаций и weighting функций на unconditional генерации CIFAR-10 [источник]

Отметим, что результаты несильно отличаются для разных параметризаций и весов функции потерь. Таким образом, различные параметризации могут использоваться в зависимости от поставленной задачи.

Результаты в картинках

На изображениях ниже видно, как сильно падает качество за счёт уменьшения количества шагов (особенно до 1 шага). Собака не везде действительно похожа на собаку 🙂

Рисунок 16. Class-conditional генерация изображения с кондишнингом на ‘malamute’ класс в зависимости от количества шагов [источник]

При unconditional генерации результаты лучше даже на небольшом количестве шагов.

Рисунок 17. Unconditional генерация на датасете LSUN Church-Outdoor [источник]
Рисунок 18. Unconditional генерация на датасете LSUN Bedrooms [источник]

Результаты в метриках

Авторы протестировали дистиллированную модель на следующих датасетах:

  • CIFAR-10;
  • ImageNet;
  • LSUN Bedrooms;
  • LSUN Church-Outdoor.

Как видно на рисунке ниже, качество дистиллированной модели сильно падает при сокращении шагов до 4-х. При этом у DDIM и DDPM отмечается значительное ухудшение качества на 128 шагах (что заметно больше).

На большем количестве шагов все модели дают приблизительно одинаковое качество генерации.

Рисунок 19. Сравнение FID-метрик для DDIM, DDPM и Distilled модели на различных датасетах в зависимости от количества шагов генерации [источник]
Рисунок 20. Сравнение FID-метрики на CIFAR-10 диффузионных моделей с указанием количества шагов генерации [источник]

Вывод

Хотя метод показал улучшение в сравнении с обычной дистилляцией, он всё ещё не сократил шаги генерации до небольшого количества (до 10, и уж тем более до 1-2). Значит, нужно придумывать что-то ещё 🙂

On Distillation of Guided Diffusion Models

Предложенные выше работы дистиллировали диффузионные модели, неспособные кондишниться (эксперименты на unconditional генерации) — это генерация различных сэмплов из шума. Однако сейчас чаще используются модели, которые выдают результат на основе текста, картинки, маски или чего-то другого.

Для решения этого вопроса авторы предложили двухэтапный метод дистилляции — classifier-free guided diffusion models.

Classifier-free guidance

Прежде чем погружаться в идею conditional генерации, советуем вам прочитать пост про диффузию и её связь со score matching.

Основная идея conditional генерации — генерация объектов не просто из распределения \( p_0(x) \), а из условного распределения \( p_0(x|y) \).

Что такое условное распределение?

\( p_0(x|y) = \frac{p(y|x)p_0(x)}{p(y)} \) (из теоремы Байеса).

Рассмотрим его score-функцию для времени \( t=0 \).

\( \frac{\partial \log p_0(x|y)}{\partial x} = \frac{\partial}{\partial x} \log p(y|x) + \frac{\partial}{\partial x} \log p_0(x) — \underbrace{\frac{\partial}{\partial x} \log p(y)}_{=0} \)

И для произвольного момента времени:

\( \frac{\partial \log p_t(x|y)}{\partial x} = \underbrace{\frac{\partial}{\partial x} \log p(y|x_t)}_{\text{Classifier guidance}} + \frac{\partial}{\partial x} \log p_t(x) \)

Слагаемое \( s_{\theta}(x_t, t) \approx \frac{\partial}{\partial x} \log p_t(x) = -\frac{\epsilon_{\theta}(x,t)}{\sqrt{1 — \bar{\alpha}t}} \) аппроксимируется score-функцией, которую мы обучали. То есть \( \epsilon{\theta}(x,t) \approx — \sqrt{1 — \bar{\alpha}t}s{\theta}(x_t, t) \).

Если мы знаем первое слагаемое — мы можем использовать score-функцию для генерации условного распределения.

То есть нам нужно обучить модель классификатора, которая умеет работать на различных уровнях шума.

Тогда новое classifier-guided предсказание будет выглядеть так:

\( \bar{\epsilon}{\theta}(x_t,t) = \epsilon{\theta}(x_t,t) — \sqrt{1 — \bar{\alpha}t}w\nabla{x_t} \log p_{\phi}(y|x_t) \)

где \( w \) — весовой коэффициент (guidance strength), контролирующий силу нашего classifier-guidence.

Рисунок 21. Алгоритм classifier-guided генерации для DDPM модели [источник]

Ключевая особенность classifier-free guidance — возможность trade-off между качеством и вариативностью сэмплов. Это контролируется параметром guidance strength.

1 ступень обучения

На первом этапе вводится модель студента \( \hat{x}_{\nu_1}(z_t, w) \) c обучаемыми параметрами \( \nu_1 \), которая учится предсказывать распределения учителя на любом шаге \( t \in [0, 1] \).

Но мы используем диффузионные модели с classifier-free guidance и выбираем различную guidance strength для предсказания в зависимости от желания, значит, получаем в результате разные предсказания. Как быть?

Пусть у нас есть отрезок возможных guidance strength \( [w_{min}, w_{max}] \) .

В этом случае авторы предлагают использовать следующую функцию потерь:

Для \( w \sim p_w, t \sim U[0,1], x \sim p_{data}(x) \)

\( \mathbb{E}{w,t,x} \Big[\omega(\lambda_t)\|\hat{x}{\nu_1}(z_t, w) — \hat{x}_{\theta}^w(z_t)\|^2_2\Big] \)

где

  • \( \hat{x}{\theta}(z_t) = (1 + w)\hat{x}{c,\theta}(z_t) — w\hat{x}_{\theta}(z_t) \),
  • \( z_t \sim q(z_t|x) \),
  • \( p_w(w) = U[w_{min}, w_{max}] \)

То есть мы добавляем в функцию потерь результаты, полученные при разных guidance strength, и усредняем их.

Таким образом, весь алгоритм выглядит следующим образом:

Рисунок 22. Алгоритм первого этапа дистилляции диффузионной модели [источник]

2 ступень обучения

На втором этапе мы используем идею progressive distillation и дистиллируем нашу модель, выученную на первом шаге \( \hat{x}{\nu_1}(z_t, w) \), в новую \( \hat{x}{\nu_2}(z_t, w), \) которая умеет генерировать результат за несколько шагов.

Как мы говорили ранее, для дистилляции используют детерминистическую модель DDIM, поэтому весь алгоритм можно записать следующим образом:

Рисунок 23. Алгоритм второго этапа дистилляции диффузионной модели [источник]

Cтохастическое сэмплирование

До этого момента мы обсуждали дистилляцию в контексте применения DDIM, или детерминистической модели с инициализацией \( z_1^w \), где \( w \in [w_{min}, w_{max}] \) — заданная guidance strenght.

Но авторы утверждают: можно проводить также N-шаговое стохастическое сэмплирование. Схематически оно выглядит таким образом:

Рисунок 24. Схема стохастического сэмплирования [источник]

Его можно записать в виде следующего алгоритма:

Рисунок 25. Алгоритм стохастического сэмплирования [источник]

Таким образом, авторы предложили дистиллировать диффузионные модели, которые чем-то гайдятся дополнительно (текстом, картинкой, меткой класса). Для этого используется двухэтапный процесс дистилляции: на первой ступени мы дистиллируем предобученную диффузию с различными guidance strength в студента, на второй — используем progressive distillation для обучения генератора за несколько шагов.

Также они предложили дополнительную технику сэмплирования с использованием стохастики.

И в итоге решили проблему гайданса 🙂.

Результаты в картинках

Авторы экспериментировали с разными моделями и исследовали:

  • кондишнинг на класс на ImageNet 256×256;
  • text2image generation диффузию на LAION-5B 512×512;
  • text2image translation c использованием SDEdit;
  • inpainting.

Давайте рассмотрим результаты каждой задачи.

Рисунок 26. Генерация картинок по тексту, inpainting, image-to-image translation [источник]

text2image generation

Стоит отметить: при уменьшении количества шагов до 2-4 качество значительно ухудшается, и генерация становится не такой точной.

Рисунок 27. Генерация картинок по тексту для разного количества шагов [источник]

text2image translation

Ниже представлены примеры решения задачи image-to-image translation, в частности смены стиля.

Рисунок 28. Image-to-image генерация [источник]

Inpainting

Задачу inpainting также можно решать с помощью дистиллированной диффузионной модели.

Рисунок 29. Inpainting [источник]

Авторы сравнивают результаты генерации дистиллированной модели (предложенный пайплайн) и DDIM модели с сокращением количества шагов. Стоит отметить: наивное уменьшение шагов не даёт хорошего качества, изображения получаются очень размытыми и с заметными артефактами.

Рисунок 30. Сравнение дистиллированной модели и DDIM модели на 4-х и 8-ми шагах [источник]

Результаты в метриках

Для проверки качества генерации авторы провели эксперимент на датасете LAION 512×512 для дистиллированной модели и моделей DPM и DPM++ (улучшенные солверы для быстрого сэмплирования). Стоит отметить: ускорение с помощью дистилляции значительно выигрывает при уменьшении количества шагов до 2-4.

Рисунок 31. Метрики FID и CLIP для DPM, DPM++ и дистиллированной модели [источник]

На датасете ImageNet 64×64 авторы сравнили диффузионные модели, работающие в пиксельном пространстве. Они были дистиллированны на:

  • различных giudance strength с помощью алгоритма выше;
  • одном giudance strength (single-w);
  • модели DDIM.

Стоит отметить: несмотря на хорошее качество на 8-м и дальнейших шагах, генерация за меньшее количество шагов даёт плохие метрики.

Рисунок 32. Сравнение генерации диффузионных моделей, работающих в пиксельном пространстве на ImageNet 64×64 [источник]

Также авторы построили графики зависимости метрик FID и IS для дистиллированной модели (со стохастическим и детерминистическим сэмплированием), модели DDPM, DDIM. Стоит отметить: предложенная авторами модель значительно превосходит DDPM и DDIM и достигает качества своего учителя даже на 8-ми шагах.

Рисунок 33. Качество генерации метрик FID и IS на ImageNet 64×64 в зависимости от количества шагов для различных guidance strength [источник]

Вывод

В этой работе основное внимание уделяется моделям, которые на что-то кондишнятся — именно они сейчас широко используются на практике. Статья очень важна, поскольку помогла улучшить качество генерации. Однако несмотря на хороший результат на 8-м и дальнейших шагах, их уменьшение до 1-2 даёт очень плохое качество.

Заключение

Таким образом, мы начали погружение в мир дистилляции диффузионных моделей, чтобы научиться их ускорять. В данном обзоре мы узнали, с чего всё начиналось и как затем развивалось. Однако это были одни из первых работ, которые, к сожалению, не достигли хорошего качества генерации за несколько шагов.

В следующей части мы рассмотрим более современные работы в этой сфере, которые как раз применяются для ускорения больших и качественных генеративных моделей, таких как, например, SDXL. Продолжение следует 🙂

Полезные ссылки

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!