Дистилляция диффузии. Часть 1
- Введение
- Диффузионные модели
- Прямой процесс (forward process)
- Обратный процесс (backward process)
- Ускорение
- Результаты в картинках
- Результаты в метриках
- Вывод
- Результаты в картинках
- Результаты в метриках
- Вывод
- Classifier-free guidance
- 1 ступень обучения
- 2 ступень обучения
- Результаты в картинках
- Результаты в метриках
- Вывод
- Заключение
- Полезные ссылки
Введение
За последние пару лет диффузионные модели прочно закрепились в мире генеративных (и не только) задач.
С их помощью мы получаем качественный и разнообразный материал при генерации:
- изображений;
- видео;
- музыки.
[Рисунок 4. Генерация музыки [источник]](https://prod-files-secure.s3.us-west-2.amazonaws.com/fb875fd6-d46b-4f75-8a43-7beca7a54a5e/21dcdec8-56e4-4da7-a935-05d26a57f505/wavegen_19.wav)
И это ещё не весь список 🙂
К сожалению, такой процесс остаётся достаточно медленным. Для генерации одного сэмпла нам зачастую нужно много запусков одной немаленькой нейронной сети, что занимает много времени в сравнении с теми же GANs.
Именно поэтому в данной статье мы с вами рассмотрим способ ускорения диффузионных моделей, но для начала давайте вспомним, а что же вообще представляют из себя диффузионки?
Диффузионные модели
Основная идея диффузионных моделей — генерация элементов распределения из шума путём постепенного расшумления.
Диффузионные модели можно рассматривать со стороны score matching’а, стохастических дифференциальных уравнений (SDE) или как модели со скрытыми латентными переменными (идейно близко к VAE). Подробнее об этом можно почитать в нашем посте.
Диффузионный процесс состоит из двух частей:
- forward process — постепенное добавление шума к распределению объектов до тех пор, пока итоговое распределение не получит латентное распределение шума (чаще всего работают с гауссовским шумом, но также была представлена работа о других типах шумов). Очень важный фактор — марковский процесс (каждый следующий этап зашумления зависит только от предыдущего).
- backward process — сэмплирование объекта из латентного распределения и получение объекта из исходного распределения путём пошагового расшумления.
Прямой процесс (forward process)
Пусть нас есть сэмпл
Тогда зашумлённое распределение на шаге \( t \) можно записать следующим образом:
\( q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}) \)
\( q(x_{1:T}|x_0) = \prod_{t=1}^{T}q(x_t|x_{t-1}) \)
Благодаря трюку репараметризации мы можем сэмплировать из получившегося распределения:
\( x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \)
Обратный процесс (backward process)
Если мы сможем обернуть данный процесс — получим генеративную модель, которая берёт сэмпл из гауссовского шума с нулевым средним и единичной дисперсией \( x \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \) и восстанавливает элемент из исходного распределения.
Важно отметить: если \( \beta_t \) — маленькое значение, то \( q(x_{t-1}|x_t) \) тоже будет гауссовским. Но мы не можем просто оценивать такое распределение: для этого нужен весь исходный датасет. Следовательно, нашей задачей становится аппроксимация с помощью какой-либо модели \( p_{\theta} \):
\( p_{\theta}(x_{0:T}) = p(x_T) \prod_{t=1}^{T}p_{\theta}(x_{t-1}|x_{t}) \)
\( p_{\theta}(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1}; \mathbf{\mu}{\theta}(x_t, t), \mathbf{\sum}{\theta}(x_t,t)) \)
Ускорение
Как мы отмечали выше, диффузионные модели — не самые быстрые генераторы. Отсюда возникает вопрос: как нам добиться высокой скорости?
Ускорять их можно разными способами. Первый метод: снижать время выполнения одного шага генерации, например, переходить в латентное пространство (а не работать с пикселями), как это было предложено в LDM.
Второй метод: уменьшать количество шагов. Один из таких способов — дистилляция «учителя», предобученной диффузионной модели, в «ученика», способного генерировать за несколько шагов. То есть мы передаём знания, зашитые в нашу большую и хорошо работающую модель, во что-то более компактное и быстрее работающее.
В этом обзоре мы разберём несколько ключевых работ по дистилляции диффузионных моделей с помощью различных подходов.
Diffusion Distillation
📄 Paper
📎 GitHub
Важное условие дистилляции моделей — обучаемая функция должна быть детерминистической. Но стандартная диффузионная модель DDPM является стохастической и не может быть «учителем» при таком сценарии.
Однако появилась идея Denoising Diffusion implicit models (DDIM) — здесь процесс генерации детерминированный, следовательно, модели могут использоваться как «учителя».
Как же нам применить её для ускорения сэмплирования?
Ключевая идея работы — обучить «ученика», генератор, способный генерировать изображение из шума в один шаг. То есть из пространства шума мы хотим сразу оказаться в пространстве изображений. В качестве «учителя» выбирается предобученная диффузионная DDIM-модель, которая умеет генерировать качественные сэмплы, но за большее количество шагов.
То есть мы минимизируем KL-дивергенцию между конечным распределением, которое предсказывает «учитель», и распределением обучаемого одношагового генератора:
\( \mathcal{L}{student} = \mathbb{E}{x_T} \Big[D_{KL}(p_{teacher}(x_0|x_T) \| p_{student}(x_0|x_T))\Big] \)
Параметризуем нашего студента как простую гауссовскую модель с обучаемым средним \( \mathcal{F}{student} \) и единичной дисперсией. В качестве «учителя» выступает DDIM-модель \( \mathcal{F}{teacher} \), что приводит нас к функции потерь:
\( \mathcal{L}_ {student} = \mathbb{E}{x_T} \Big[\|\mathcal{F}{student}(x_T) — \mathcal{F}{teacher}(x_T)\|^2_2\Big] + C \)
То есть мы сэмплируем \( x_T \) из шума, рассчитываем \( \mathcal{F}_{teacher}(x_T) \) с помощью DDIM и минимизируем записанный выше лосс.
Результаты в картинках
Важно отметить: эксперименты проводились на датасетах с картинками очень невысокого разрешения.
Результаты в метриках
Отметим, что результат генерации хуже, чем у диффузионных моделей с большим количеством шагов и одношаговых генеративных (как GANs).
Вывод
Таким образом, дистилляция модели в один шаг не даёт достаточно качественных результатов генерации. Поэтому мы переходим к следующей идее 😉
Progressive Distillation
📄 Paper
📎 GitHub
Ключевая идея работы — постепенная дистилляция с уменьшением количества шагов вдвое.
В качестве «учителя» также выступает DDIM, и весь алгоритм обучения выглядит следующим образом:
В работе также вводится новый метод параметризации, который используется сейчас в разных диффузионных моделях, например, в StableDiffusion 2.0.
До этого многие применяют \( \epsilon- \)параметризацию, где именно шум \( \epsilon \) предсказывается с помощью нейронной сети \( \hat{\epsilon}_{\theta}(z_t) \). Для неё функция потерь записывается следующим образом:
\( \mathcal{L}{\theta} = \|\epsilon — \hat{\epsilon}{\theta}(z_t)\|^2_2 \)
Поскольку сэмпл распределения вычисляется через этот шум \( \hat{x}{\theta}(z_t) = \frac{1}{\alpha_t}(z_t — \sigma_t \hat{\epsilon}{\theta}(z_t)) \), функцию потерь можно переписать так:
\( \mathcal{L}{\theta} = \|\epsilon — \hat{\epsilon}{\theta}(z_t)\|^2_2 = \|\frac{1}{\sigma_t}(z_t — \alpha_tx) — \frac{1}{\sigma_t}(z_t — \alpha_t \hat{x}_{\theta}(z_t))\|2^2 = \frac{\alpha_t^2}{\sigma_t^2}\|x — \hat{x}{\theta}(z_t)\|^2_2 \)
Почему такая параметризация не подходит для дистилляции? Отметим, что изначальная модель работает с широким спектром отношения сигнала к шуму (signal-to-noise ratios SNR) — \( \alpha_t^2 / \sigma_t^2 \). Но при постепенной дистилляции он всё больше уменьшается. Когда SNR приближается к нулю, маленькие изменения выхода нейронной сети \( \hat{\epsilon}{\theta}(z_t) \) дают существенные трансформации самого изображения, поскольку \( \hat{x}{\theta}(z_t) = \frac{1}{\alpha_t}(z_t — \sigma_t \hat{\epsilon}_{\theta}(z_t)) \) (\( \alpha_t \rightarrow 0 \)). Если у нас много шагов генерации — ошибки начальных шагов постепенно корректируются. Если шагов мало — такая параметризация становится проблемой, особенно при задаче генерации в один шаг.
Таким образом, для качественной дистилляции важна устойчивость модели к изменению SNR.
Поэтому авторы предлагают использовать альтернативные способы параметризации:
- предсказывать \( x \);
- предсказывать \( x \) и \( \epsilon \);
- предсказывать v = \( \alpha_t \epsilon — \sigma_t x \) что даёт сэмпл из распределения \( \hat{x} = \alpha_tz_t — \sigma_t \hat{v}_{\theta}(z_t) \).
Для новой параметризации также нужно перестроить функцию потерь. Авторы представляют два варианта:
- \( \mathcal{L}_{\theta} = \text{max}(\|x — \hat{x}_t\|^2_2, \|\epsilon — \hat{\epsilon}_t\|^2_2) = max(\frac{\alpha_t^2}{\sigma_t^2}, 1)\|x — \hat{x}_t\|_2^2 \) (truncated SNR weighting);
- \( \mathcal{L}_{\theta} = \|v_t — \hat{v}_t\|^2_2 = (1 + \frac{\alpha_t^2}{\sigma_t^2})\|x — \hat{x}_t\|^2_2 \) (SNR+1 weighting).
Они сравнивают различные параметризации и веса для функции потерь, а также два сэмплера: DDIM (Denoising Diffusion Implicit Models, детерминистический сэмплер) и DDPM (Denoising Diffusion Probabilistic Model, стохастический сэмплер), считая метрики FID и IS для датасета CIFAR10.
Отметим, что результаты несильно отличаются для разных параметризаций и весов функции потерь. Таким образом, различные параметризации могут использоваться в зависимости от поставленной задачи.
Результаты в картинках
На изображениях ниже видно, как сильно падает качество за счёт уменьшения количества шагов (особенно до 1 шага). Собака не везде действительно похожа на собаку 🙂
При unconditional генерации результаты лучше даже на небольшом количестве шагов.
Результаты в метриках
Авторы протестировали дистиллированную модель на следующих датасетах:
- CIFAR-10;
- ImageNet;
- LSUN Bedrooms;
- LSUN Church-Outdoor.
Как видно на рисунке ниже, качество дистиллированной модели сильно падает при сокращении шагов до 4-х. При этом у DDIM и DDPM отмечается значительное ухудшение качества на 128 шагах (что заметно больше).
На большем количестве шагов все модели дают приблизительно одинаковое качество генерации.
Вывод
Хотя метод показал улучшение в сравнении с обычной дистилляцией, он всё ещё не сократил шаги генерации до небольшого количества (до 10, и уж тем более до 1-2). Значит, нужно придумывать что-то ещё 🙂
On Distillation of Guided Diffusion Models
Предложенные выше работы дистиллировали диффузионные модели, неспособные кондишниться (эксперименты на unconditional генерации) — это генерация различных сэмплов из шума. Однако сейчас чаще используются модели, которые выдают результат на основе текста, картинки, маски или чего-то другого.
Для решения этого вопроса авторы предложили двухэтапный метод дистилляции — classifier-free guided diffusion models.
Classifier-free guidance
Прежде чем погружаться в идею conditional генерации, советуем вам прочитать пост про диффузию и её связь со score matching.
Основная идея conditional генерации — генерация объектов не просто из распределения \( p_0(x) \), а из условного распределения \( p_0(x|y) \).
Что такое условное распределение?
\( p_0(x|y) = \frac{p(y|x)p_0(x)}{p(y)} \) (из теоремы Байеса).
Рассмотрим его score-функцию для времени \( t=0 \).
\( \frac{\partial \log p_0(x|y)}{\partial x} = \frac{\partial}{\partial x} \log p(y|x) + \frac{\partial}{\partial x} \log p_0(x) — \underbrace{\frac{\partial}{\partial x} \log p(y)}_{=0} \)
И для произвольного момента времени:
\( \frac{\partial \log p_t(x|y)}{\partial x} = \underbrace{\frac{\partial}{\partial x} \log p(y|x_t)}_{\text{Classifier guidance}} + \frac{\partial}{\partial x} \log p_t(x) \)
Слагаемое \( s_{\theta}(x_t, t) \approx \frac{\partial}{\partial x} \log p_t(x) = -\frac{\epsilon_{\theta}(x,t)}{\sqrt{1 — \bar{\alpha}t}} \) аппроксимируется score-функцией, которую мы обучали. То есть \( \epsilon{\theta}(x,t) \approx — \sqrt{1 — \bar{\alpha}t}s{\theta}(x_t, t) \).
Если мы знаем первое слагаемое — мы можем использовать score-функцию для генерации условного распределения.
То есть нам нужно обучить модель классификатора, которая умеет работать на различных уровнях шума.
Тогда новое classifier-guided предсказание будет выглядеть так:
\( \bar{\epsilon}{\theta}(x_t,t) = \epsilon{\theta}(x_t,t) — \sqrt{1 — \bar{\alpha}t}w\nabla{x_t} \log p_{\phi}(y|x_t) \)
где \( w \) — весовой коэффициент (guidance strength), контролирующий силу нашего classifier-guidence.
Ключевая особенность classifier-free guidance — возможность trade-off между качеством и вариативностью сэмплов. Это контролируется параметром guidance strength.
1 ступень обучения
На первом этапе вводится модель студента \( \hat{x}_{\nu_1}(z_t, w) \) c обучаемыми параметрами \( \nu_1 \), которая учится предсказывать распределения учителя на любом шаге \( t \in [0, 1] \).
Но мы используем диффузионные модели с classifier-free guidance и выбираем различную guidance strength для предсказания в зависимости от желания, значит, получаем в результате разные предсказания. Как быть?
Пусть у нас есть отрезок возможных guidance strength \( [w_{min}, w_{max}] \) .
В этом случае авторы предлагают использовать следующую функцию потерь:
Для \( w \sim p_w, t \sim U[0,1], x \sim p_{data}(x) \)
\( \mathbb{E}{w,t,x} \Big[\omega(\lambda_t)\|\hat{x}{\nu_1}(z_t, w) — \hat{x}_{\theta}^w(z_t)\|^2_2\Big] \)
где
- \( \hat{x}{\theta}(z_t) = (1 + w)\hat{x}{c,\theta}(z_t) — w\hat{x}_{\theta}(z_t) \),
- \( z_t \sim q(z_t|x) \),
- \( p_w(w) = U[w_{min}, w_{max}] \)
То есть мы добавляем в функцию потерь результаты, полученные при разных guidance strength, и усредняем их.
Таким образом, весь алгоритм выглядит следующим образом:
2 ступень обучения
На втором этапе мы используем идею progressive distillation и дистиллируем нашу модель, выученную на первом шаге \( \hat{x}{\nu_1}(z_t, w) \), в новую \( \hat{x}{\nu_2}(z_t, w), \) которая умеет генерировать результат за несколько шагов.
Как мы говорили ранее, для дистилляции используют детерминистическую модель DDIM, поэтому весь алгоритм можно записать следующим образом:
Cтохастическое сэмплирование
До этого момента мы обсуждали дистилляцию в контексте применения DDIM, или детерминистической модели с инициализацией \( z_1^w \), где \( w \in [w_{min}, w_{max}] \) — заданная guidance strenght.
Но авторы утверждают: можно проводить также N-шаговое стохастическое сэмплирование. Схематически оно выглядит таким образом:
Его можно записать в виде следующего алгоритма:
Таким образом, авторы предложили дистиллировать диффузионные модели, которые чем-то гайдятся дополнительно (текстом, картинкой, меткой класса). Для этого используется двухэтапный процесс дистилляции: на первой ступени мы дистиллируем предобученную диффузию с различными guidance strength в студента, на второй — используем progressive distillation для обучения генератора за несколько шагов.
Также они предложили дополнительную технику сэмплирования с использованием стохастики.
И в итоге решили проблему гайданса 🙂.
Результаты в картинках
Авторы экспериментировали с разными моделями и исследовали:
- кондишнинг на класс на ImageNet 256×256;
- text2image generation диффузию на LAION-5B 512×512;
- text2image translation c использованием SDEdit;
- inpainting.
Давайте рассмотрим результаты каждой задачи.
text2image generation
Стоит отметить: при уменьшении количества шагов до 2-4 качество значительно ухудшается, и генерация становится не такой точной.
text2image translation
Ниже представлены примеры решения задачи image-to-image translation, в частности смены стиля.
Inpainting
Задачу inpainting также можно решать с помощью дистиллированной диффузионной модели.
Авторы сравнивают результаты генерации дистиллированной модели (предложенный пайплайн) и DDIM модели с сокращением количества шагов. Стоит отметить: наивное уменьшение шагов не даёт хорошего качества, изображения получаются очень размытыми и с заметными артефактами.
Результаты в метриках
Для проверки качества генерации авторы провели эксперимент на датасете LAION 512×512 для дистиллированной модели и моделей DPM и DPM++ (улучшенные солверы для быстрого сэмплирования). Стоит отметить: ускорение с помощью дистилляции значительно выигрывает при уменьшении количества шагов до 2-4.
На датасете ImageNet 64×64 авторы сравнили диффузионные модели, работающие в пиксельном пространстве. Они были дистиллированны на:
- различных giudance strength с помощью алгоритма выше;
- одном giudance strength (single-w);
- модели DDIM.
Стоит отметить: несмотря на хорошее качество на 8-м и дальнейших шагах, генерация за меньшее количество шагов даёт плохие метрики.
Также авторы построили графики зависимости метрик FID и IS для дистиллированной модели (со стохастическим и детерминистическим сэмплированием), модели DDPM, DDIM. Стоит отметить: предложенная авторами модель значительно превосходит DDPM и DDIM и достигает качества своего учителя даже на 8-ми шагах.
Вывод
В этой работе основное внимание уделяется моделям, которые на что-то кондишнятся — именно они сейчас широко используются на практике. Статья очень важна, поскольку помогла улучшить качество генерации. Однако несмотря на хороший результат на 8-м и дальнейших шагах, их уменьшение до 1-2 даёт очень плохое качество.
Заключение
Таким образом, мы начали погружение в мир дистилляции диффузионных моделей, чтобы научиться их ускорять. В данном обзоре мы узнали, с чего всё начиналось и как затем развивалось. Однако это были одни из первых работ, которые, к сожалению, не достигли хорошего качества генерации за несколько шагов.
В следующей части мы рассмотрим более современные работы в этой сфере, которые как раз применяются для ускорения больших и качественных генеративных моделей, таких как, например, SDXL. Продолжение следует 🙂