Назад
102

Consistency models

102

Введение

Диффузионные модели — один из ключевых инструментов в области генеративных моделей. Они применяются в генерации изображений, текстур, 3D-сцен, а также видео, аудио и др.

Рисунок 1. Генерация видео с помощью модели Genmo [источник]

Главное преимущество диффузионных моделей — их способность генерировать высококачественные изображения, обучаясь восстанавливать данные из зашумлённых версий. Однако, несмотря на свою эффективность, диффузионные модели имеют и недостатки, связанные с высокой вычислительной сложностью. Инференс модели обычно занимает много времени, в частности потому что требует большого количества шагов. Это ограничивает их практическое применение в реальном времени и на устройствах с ограниченными ресурсами.

Поэтому ускорение диффузионных моделей — важная задача, которой активно занимаются многие команды.

Ранее мы познакомились с некоторыми способами ускорения диффузионных моделей:

А сегодня мы изучим ещё один тип ускорения — Consistency models и Latent Consistency models. Итак, давайте начинать 🙂

Интегратор

Диффузия и ODE

Диффузионные модели можно рассматривать с точки зрения стохастических и обыкновенных дифференциальных уравнений (SDE и ODE соответственно) при переходе от дискретного шага по времени к непрерывному.

В таком случаем мы работаем с непрерывным диффузионным процессом: \( \{x(t)\}^T_{t=0} \) , \( x(0) \sim p_0 \) и \( x(T) \sim p_T \).

Мы можем записать прямое и обратное SDE для процесса, характеризующегося функцией \( f(x,t) \) и виннеровским процессом \( w \).

Рисунок 2. SDE для прямого и обратного процессов [источник]

Прямой процесс (forward SDE): данные → шум

\( dx_t = f(x_t,t)dt + \sigma(t)dw_t \) , \( t\in[0,T], T>0, f(\cdot, \cdot) \) — некоторая drift force, а \( \{w_t\}_{t \in [0,T]} \) — Броуновское движение, взятое с весовым коэффициентом \( \sigma(t) \).

При этом мы считаем, что \( p_0(x) = p_{data}(x) \), а \( p_T \) — нормальное распределение с нулевым средним и единичной дисперсией.

Обратный процесс (reverse SDE): шум → данные

Соответственно, обратное по времени уравнение выглядит следующим образом:

\( dx_t = \Big[f(x_t,t)dt — \sigma(t)^2\nabla \log p_t (x_t) \Big] dt + \sigma(t)dw_t \)

где \( \nabla \log p_t(x) \) — score функция \( p_t(x) \).

Рисунок 3. Прямое и обратное SDE [источник]

Cтохастическому дифференциальному уравнению можно сопоставить обыкновенное дифференциальное уравнение. Его решения (сэмплированные в точках t траектории) совпадают с решениями SDE:

\( dx_t = \Big[f(x_t,t)dt — \frac{1}{2}\sigma(t)^2\nabla \log p_t (x_t) \Big] dt \)

где \( \log p_t (x_t) \) — score функция \( p_t(x_t). \)

Но откуда мы знаем эту score-функцию для решения нашего диффура? По сути задача обучения диффузионной модели сводится к обучению нейронной сети, аппроксимирующей данную score-функцию. Более подробный вывод вы можете найти в презентации.

Таким образом, нам нужно обучить score-функцию:

\( \frac{d}{dx}\log p_t(x_t) \approx s_{\theta}(x_t, t) = -\mathbb{E}\Big[\frac{x_t — x}{t^2}|x_t\Big] \)

Таким образом. если мы знаем score функцию — мы знаем направление траектории в каждой точке.

Тогда, засэмплировав элемент из шума, мы можем использовать численные солверы обыкновенных дифференциальных уравнений (например, Эйлера) для получения итогового элемента.

Рисунок 4. Обратный процесс SDE [источник]

Решение SDE: \( dx_t = \Big[f(x_t,t)dt — \sigma(t)^2\nabla \log p_t (x_t) \Big] dt + \sigma(t)dw_t \)

Рисунок 5. Обратный процесс ODE [источник]

Решение ODE (probability flow): \( dx_t = \Big[f(x_t,t)dt — \frac{1}{2}\sigma(t)^2\nabla \log p_t (x_t) \Big] dt \)

То есть для сгенерированного диффузионной моделью объекта нам необходимо просто проинтегрировать полученную с помощью score-функции траекторию.

Рисунок 6. Основная идея генерации с помощью ODE при известной score-функции в каждой точке (см. на красную стрелочку)

Начинаем мы обратный процесс с какого-то нормального шума \( X_T \) в момент времени \( t \) и движемся к \( X_0 \) — сэмплу из исходного распределения изображений.

Важно отметить: интегрирование происходит до точки \( t = \epsilon \), где \( \epsilon \) — небольшое положительное число и конечный элемент \( \hat{x}_{\epsilon} \).

Обучение интегратора

А если бы мы сразу учили интегратор \( I_{\phi}(x_{t_1}, t_1, t_2) \), который получал бы сэмпл в точке \( t_1 \) и сразу предсказывал значение в точке \( t_2 \)?

В частности, можно было бы из объекта шума \( x_T \sim \mathcal{N}(0, \mathbf{I}) \) сразу сгенерировать объект исходного распределения \( \hat{x}0 = I{\phi}(x_{T}, T, 0) \).

Рисунок 7. Основная идея работы Consistency моделей [источник]

Задачу обучения такого интегратора и ставят перед собой авторы работы Consistency Models 😊 Давайте разберёмся с ней подробнее.

Consistency models

Для начала рассмотрим несколько определений, чтобы далее быть в контексте происходящего 🙂

Определения

Пусть у нас есть траектория решения ODE \( \{x_t\}{t \in [\epsilon, T]} \) генерации изображения с помощью диффузионной модели. Тогда определим consistency function \( f: (x_t, t) → x{\epsilon} \), которая на вход берёт какой-то зашумленный сэмпл с траектории в момент времени \( t \), непосредственно само время \( t \), и возвращает элемент в начале траектории.

Она должна обладать свойством self-consistency \( f(x_t, t) = f(x_{t’},t’) \) для любых точек траектории в моменты времени \( t, t’ \in [\epsilon, T] \).

Рисунок 8. Self-consistency [источник]

Такая функция обладает двумя важными особенностями:

  • консистентность — функция от объектов \( x_{t_1} \) и \( x_{t_2} \), принадлежащих одной траектории, будет возвращать одинаковые значения \( f(x_{t_1}, t_1) = f(x_{t_2}, t_2) \) для любых \( t_1, t_2 \in [\epsilon, T] \). То есть можно сказать, что производная нашей функции по времени равна нулю.
  • тождественное отображение — \( f(x_{\epsilon}, \epsilon) = x_{\epsilon} \), то есть \( f(\cdot, \epsilon) \). Это условие также можно назвать граничным условием.

Как же построить функцию, удовлетворяющую данным условиям?

Параметризация

Функцию \( f(\cdot, \cdot) \) мы планируем параметризовать через нейронную сеть, которая должна соответствовать перечисленным выше условиям. Авторы статьи предлагают два возможных варианта параметризации:

  1. \( f_{\theta}(x, t) = \begin{cases}x, \text{ }t = \epsilon \\ F_{\theta}(x,t), \text{ } t \in (\epsilon, T] \end{cases} \)
  2. \( f_{\theta}(x, t) = c_{skip}(t)x + c_{out}(t)F_{\theta}(x, t) \)

Где \( c_{skip} \) и \( c_{out} \) — две дифференцируемые функции, такие что

\( c_{skip}(\epsilon)=1 \), \( c_{out}(\epsilon) = 0 \)

и \( c_{skip} = \frac{\sigma_{data}^2}{\sigma^2 + \sigma_{data}^2} \), \( c_{out} = \frac{\sigma \sigma_{data}}{\sqrt{\sigma_{data}^2 + \sigma^2}} \)

\( \sigma_{data} = 0.5$ , $\sigma = t — \epsilon \)

Такая параметризация уже встречалась в других работах, например, в Karras et al., 2022 и Balaji et al., 2022 (там представлено более подробное описание). Автор Consistency models в дальнейших экспериментах используют именно её.

Сэмплирование

Пусть мы обучили consistency модель \( f_{\theta}(\cdot, \cdot) \). Тогда мы берём элемент из шума \( \hat{x}T \sim \mathcal{N}(0, T^2\mathbf{I}) \), прогоняем нашу модель и получаем \( \hat{x}{\epsilon} = f_{\theta}(\hat{x}_T, T) \). Это значит, что Consistency модель генерирует сэмпл за один шаг.

Однако мы также можем использовать Multi-step генератор. Для этого мы дополнительно зашумляем предсказанные результаты и снова делаем предсказание с помощью Consistency модели.

Рисунок 9. Consistency модель как многошаговый генератор [источник]

Алгоритм представлен ниже:

Рисунок 10. Алгоритм многошагового сэмплирования Consistency модели [источник]

Важно отметить: time points \( \{\tau_1, \tau_2, …, \tau_{N-1}\} \) найдены с помощью жадного алгоритма при оптимизации метрики FID.

Теперь возникает вопрос, а как обучить такую модель, чтобы также выполнялось условие консистентности? Рассмотрим два варианта обучения:

  • есть учитель (дистилляция) — Сonsistency Distillation (CD);
  • отдельное обучение модели — Consistency Training (CT).

Consistency Distillation (CD)

Пусть у нас есть предобученная score-модель \( s_{\phi}(x,t) \) и дискретизация отрезка времени \( [\epsilon, T] \) на \( N-1 \) интервал. Тогда можно оценить \( x_{t_n} \) из \( x_{t_{n+1}} \):

\( \hat{x}{t_n}^{\phi} := x{t_{n+1}} + (t_n — t_{n+1}) \Phi(x_{t_{n+1}}, t_{n+1}; \phi) \)

где \( \Phi(\cdot, \cdot; \phi) \) — некоторый ODE solver, например, можно взять Euler Solver.

Для тех, кто хочет вспомнить, что такое Euler solver

Пусть у нас есть диффур \( dy(t) = f(t, y(t))dt \). Зададим первую точку \( y=y(t_0) \) и некоторый шаг h (в идеале h → 0). Тогда \( t_n = t_0 + nh \) → \( t_{n+1} = t_n + h \) и \( y_{n+1} = y_n + hf(t_n, y_n) \)

А в случае диффузионных моделей: \( \frac{dx_t}{dt} = -t s_{\phi}(x_t, t) \) (более подробно см. работу) → Euler Solver

\( \Phi(x_{t_{n+1}}, t_{n+1}; \phi) = s_{\phi}(x_{t_{n+1}}, t_{n+1})t_{n+1} \)

например, в случае Euler Solver \( \Phi(x_{t_{n+1}}, t_{n+1}; \phi) = s_{\phi}(x_{t_{n+1}}, t_{n+1})t_{n+1} \).

Теперь у нас есть пары точек \( (\hat{x}{t_n}, x{t_{n+1}}) \), где \( x_{t_{n+1}} \) получен путём добавления шума к изначальному сэмплу. Значит, мы можем записать следующую функцию потерь:

\( \mathcal{L}{CD}^N(\mathbf{\theta}, \mathbf{\theta}^{-}; \phi) := \mathbb{E}{x \sim p_{data}, n \sim \mathcal{U}[1, N-1]} \Big[\lambda(t_n)d(f_{\theta}(x_{t_{n+1}}, t_{n+1}), f_{\theta^-}(\hat{x}{t_n}^{\phi}, t{n}))\Big] \)

где:

  • \( \lambda(\cdot) \in \mathbb{R}^+ \) — положительный весовой коэффициент;
  • \( \theta^- \) — скользящее среднее предыдущих оценок \( \theta \): \( \theta^{-} \leftarrow stopgrad(\mu \theta^{-} + (1-\mu)\theta) \); \( 0 \leq \mu < 1 \).

В качеcтве \( d \) можно использовать различные функции. Авторы предлагают три варианта:

  • \( d(x, y) = \|x — y\|^2_2 \) — \( l_2 \) distance;
  • \( d(x, y) = \|x — y\|_1 \) — \( l_1 \) distance;
  • LPIPS.

В результатах чуть ниже мы приведём сравнение функций.

Таким образом, обучение с учителем можно представить следующим образом:

Рисунок 11. Основной алгоритм дистилляции Consistency models [источник]

Важно отметить: в данной схеме обучения у нас есть изначальная модель с параметрами \( \theta \), а также модель с обновляемыми параметрами \( \theta^{-} \). В работе авторы называют \( f_{\theta^{-}} \) как «target network»‎, \( f_{\theta} \) — «online network»‎.

Мы схематично отобразили основную идею:

Рисунок 12. Основная идея Consistency distillation

Их можно инициализировать предобученой моделью с известными параметрами \( \phi \). С её помощью мы и делаем шаг солвером (так как она даёт нам знания о score-функции).

Consistency Training (CT)

Также модели могут обучаться и без предобученной диффузионной модели, то есть без предобученной score-модели \( s_{\phi}(x, t) \) \( \approx \nabla \log p_t(x) \). Как же тогда аппроксимировать score-функцию?

Напомним, что score-функция выглядит следующим образом:

\( \nabla \log p_t(x) = -\mathbb{E}\Big[\frac{x_t — x}{t^2}|x_t\Big] \)

Авторы приводят теорему (мы рассмотрим её без доказательства), которая показывает, как можно убрать из обучения предобученного учителя.

Let \( ∆t := max_{n \in [1, N-1]} \{|t_{n+1} — t_n \} \), \( d \) and \( f_{\theta} \) — are both twice continuously differentiable with bounded second derivatives, the weighting function \( \lambda (\cdot) \) is bounded, and \( \mathbb{E}[\| \nabla \log p_{t_n}(x_{t_n})\|^2_2] < \infty \).

Assume further that we use the Euler ODE solver, and the pre-trained score model matches the ground truth, i.e., \( \forall t \in [\epsilon, T] : s_{\phi}(x,t) \equiv \nabla \log p_t(x) \). Then,

\( \mathcal{L}{CD}^{N}(\theta, \theta^{*} ; \phi) = \mathcal{L}{CT}^{N}(\theta, \theta^{}) + o(\Delta t) \) , where the expectation is taken with respect to \( x \sim p_{\text{data}} \) , \( n \sim \mathcal{U}[1, N-1] \), and \( x_{t_{n+1}} \sim \mathcal{N}(x ; t_n^2 \mathbf{I} + t_n \mathbf{I}) \). The consistency training objective, denoted by \( \mathcal{L}_{CT}^{N}(\theta, \theta^{-} \) ), is defined as

\( \mathbb{E}[\lambda(t_n)d(f_{\theta}(x + t_{n+1}z, t_{n+1}) — f_{\theta^{-}}(x + t_n z, t_n)]^2 \) where \( z \sim \mathcal{N}(0, \mathbf{I}) \). Moreover, \( \mathcal{L}_{CT}^{N}(\theta, \theta^{-}) > O(\Delta t) \) if \( \mathcal{L}_{CD}^{N}(\theta, \theta^{-}; \phi) > 0 \).

Выглядит страшновато 👹 Но давайте посмотрим на основную мысль. Авторы вводят понятие consistency train loss:

\( \mathcal{L}{CT}^{N}(\theta, \theta^{-}) = \mathbb{E}[\lambda(t_n)d(f{\theta}(x + t_{n+1}z, t_{n+1}) — f_{\theta^{-}}(x + t_n z, t_n)]^2 \)

Здесь уже нет шага солвером, а есть зашумление сэмплов на соседние шаги по времени. Далее смотрится consistency loss. При хороших функциях (см. условие теоремы) можно показать, что consistency train и consistency distillation лоссы отличаются на o-малое от \( t \).

Следовательно, теперь мы можем не делать шаг солвером, а рассмотреть наше решение без дополнительно предобученной модели! 🙂

Таким образом, обучение можно представить следующим образом:

Рисунок 13. Основной алгоритм consistency train [источник]
Для любопытных

На самом деле похожая идея обучения двух моделей часто используется в задачах Self-Supervised learning. Например, что-то похожее можно найти в работе Mean teachers are better role models, Bootstrap Your Own Latent или Dino.

Однако в зависимости от задачи необходимо корректно выбирать лосс для обучения. И в случае consistency models основной принцип обучения строится на свойстве консистентности модели, которую мы обучаем.

Гиперпараметры

Таким образом, авторы рассмотрели два основных типа обучения Consistency моделей, а также объяснили, как проводить генерацию. При описании они рассказывали про возможность использования различных гиперпараметров, таких как солверы, \( d \)-функции и шаг дискретизации.

Для подбора лучших гиперпараметров они проводят ряд экспериментов на датасете CIFAR-10:

  1. Во-первых, они сравнивают различные \( d \)-функции, а именно \( l_2 \), \( l_1 \) и LPIPS. Как можно заметить, LPIPS даёт минимальное (самое хорошее) значение метрики FID.
  2. Во-вторых, они тестируют различные ODE-солверы для consistency distillation обучения: Euler и Heun.
Рисунок 14. Подбор гиперпараметров: 1) d-функция (а), различные солверы и размер шага дискретизации (b) и (c) для CT, а также аналогичные эксперименты для CT (d) [источник]

Как мы видим на графиках, наилучшие результаты дали:

  • Consistency distillation: LPIPS loss, Heun solver и количество шагов дискретизации \( N = 18 \).
  • Consistency training: LPIPS, Euler is ok, \( N \) — некий трейдоф между скоростью сходимости и качеством.

Результаты в метриках

Авторы провели много экспериментов и сравнений 🙂

Они сопоставили Progressive distillation (PD) и Consistency distillation (CD) на нескольких датасетах с различным количеством шагов сэмплирования и разными дистанс (\( d \)) функциями:

  • CIFAR-10;
  • ImageNet 64×64;
  • Bedroom 256×256;
  • Cat 256×256.
Рисунок 15. Сравнение Progressive distillation (PD) и Consistency distillation (CD) (l2, LPIPS) на различных датасетах с 1, 2, 3 и 4 шагами генерации [источник]

Также авторы сравнили на нескольких датасетах различные генеративные метрики и определили, что Сonsistency модели превосходят non-adversarial генеративные модели на CIFAR-10. Также они дают качественно хороший результат без дистилляции.

Рисунок 16. Сравнение по скорости и метрикам FID, IS для различных генераторов при Progressive distillation (PD) и Consistency distillation (CD) [источник]

Результаты в картинках

Ниже представлено сравнение в картинках:

  • верхний ряд — EDM;
  • средний ряд — CT в 1 шаг;
  • нижний ряд — CT в 2 шага.
Рисунок 17. Сравнение качества генерации на картинках [источник]

Также можно посмотреть на то, как модель справляется с другими задачами, например:

  • superresolution;
  • editing;
  • colorization.
Рисунок 18. Результаты работы для задач superresolution, editing и colorization [источник]

Latent Сonsistency models

Работа Consistency models была направлена в первую очередь на ускорение диффузионных моделей в пиксельном пространстве, что позволяло генерировать изображения лишь небольшого разрешения. Однако большая часть современных диффузионных моделей работают в латентном пространстве. Поэтому логичным продолжением работы стала статья Latent consistency models, авторы которой применили идею CT и CD, но уже к латентным моделям. Давайте разберёмся, какие там были новые идеи 🙂

Генерация

Так же, как и в Consistency моделях, мы будем обучать нашу функцию на свойство консистентности и получать одношаговый генератор из шума в изображение.

Рисунок 19. Self-consistency [источник]

При этом обученный одношаговый генератор можно также использовать как многошаговый:

Рисунок 20. Алгоритм multi-step генерации [источник]

Latent Consistency Distillation (LCD)

А как же обучать такую модель?

На самом деле основной алгоритм дистилляции несильно отличается от алгоритма Consistency моделей за исключением трёх основных моментов:

  1. Мы теперь работаем в латентном пространстве, поэтому всё зашумление и расшумление происходит именно в нём.
  2. Появляется так называемый skipping интервал \( k \). В работе Consistency models для обучения консистентности брались соседние шаги траектории генерации ( \( t_n \) и \( t_{n+1} \) ). Однако в латентном пространстве разница между латентами на этих шагах очень маленькая, и предсказания уже в целом довольно близки, что ведёт к более плохой сходимости и качеству обучения. Для решения проблемы авторы предлагают брать не соседние шаги, а шаги на каком-то расстоянии \( k \) друг от друга ( \( t_n \) и \( t_{n+k} \)).

Авторы проводят ряд экспериментов для вычисления этого оптимального k:

Рисунок 21. Сравнение различных skip k шагов для трёх диффузионных сэмплеров: DDIM, DPM, DPM++ [источник]
  1. Мы работаем не с unconditional диффузионным моделям, а с conditional моделями. Следовательно, мы теперь должны как-то учитывать Classifier-free guidance. Guidance scale для conditional генерации является некоторым trade-off регуляризатором между качеством и разнообразием генерации. При обучении guidance scale случайным образом сэмлпируется из отрезка \( [w_{min}, w_{max}] \).

Для инференса авторы сравнивают влияние guidance scale для генерации по двум основным метрикам — FID и CLIP-Score. Как и предполагается, при больших значениях guidance scale улучшается CLIP, однако ухудшается FID.

Рисунок 22. Влияние значений guidance scale при инференсе модели на метрики CLIP-Score и FID [источник]

Также различие в качестве для guidance scale можно увидеть на изображениях. При совсем низких значениях картинки получаются немного размытыми, а также на них присутствуют хорошо заметные артефакты.

Рисунок 23. Влияние значений guidance scale при инференсе модели на генерацию [источник]

Таким образом, весь алгоритм обучения можно записать следующим образом (дополнительно на рисунке 22 показан алгоритм обучения Consistency models):

Рисунок 24. Основной алгоритм Latent Consistency Distillation [источник]
Рисунок 25. Основной алгоритм Consistency Distillation [источник]

Latent Consistency Fine-Tuning (LCF)

Когда мы обсуждали Consistency models, мы говорили не только про дистилляцию, но и про Consistency Training (CT). В случае латентных моделей авторы останавливаются на Latent Consistency Fine-Tuning (LCF) для некоторого датасета пользователя:

Рисунок 26. Основной алгоритм LCF [источник]

Результаты в метриках

Сравнение проводилось относительно следующих метрик:

Как можно заметить, LCM модели значительно превосходят предложенные бейзлайны по метрикам на генерации за 1, 2 и 4 шага.

Рисунок 27. Сравнение по FID и CLIP-Score метрикам для text-to-image генерации \( 512 \times 512 \) на датасете LAION-Aesthetic-6+ с фиксированным \( w=8 \) [источник]
Рисунок 28. Сравнение по FID и CLIP-Score метрикам для text-to-image генерации \( 768 \times 768 \) на датасете LAION-Aesthetic-6+ с фиксированным \( w=8 \) [источник]

Результаты в картинках

При визуальном сравнении мы видим, что двух шагов модели не хватает для генерации изображений высокого качества.

Рисунок 29. Визуальное сравнение для text-to-image генерации на датасете LAION-Aesthetic-6+ с фиксированным \( w=8 \) [источник]

Также авторы показывают результаты работы предложенного алгоритма LCF для разного количества шагов finetuning:

Рисунок 30. Результаты LCF на двух датасетах: Pokemon Dataset (слева), Simpsons Dataset (справа) [источник]

Заключение

Итак, сегодня мы познакомились с ещё одним способом ускорения диффузионных моделей за счёт сокращения количества шагов генерации при обучении модели на свойство self-consistency.

Такую идею можно применять как для моделей, работающих в пиксельном пространстве, так и для латентных диффузионных моделей.

Этот метод хорошо себя показывает и используется для ускорения новых генераторов, например, SDXL.

Однако сейчас всё большую популярность набирают дистилляционные методы, включающие adversarial training. О них мы подробнее рассказали в нашем предыдущем посте — Diffusion Distillation 🙂

Полезные ссылки

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!