Consistency models
Введение
Диффузионные модели — один из ключевых инструментов в области генеративных моделей. Они применяются в генерации изображений, текстур, 3D-сцен, а также видео, аудио и др.
Главное преимущество диффузионных моделей — их способность генерировать высококачественные изображения, обучаясь восстанавливать данные из зашумлённых версий. Однако, несмотря на свою эффективность, диффузионные модели имеют и недостатки, связанные с высокой вычислительной сложностью. Инференс модели обычно занимает много времени, в частности потому что требует большого количества шагов. Это ограничивает их практическое применение в реальном времени и на устройствах с ограниченными ресурсами.
Поэтому ускорение диффузионных моделей — важная задача, которой активно занимаются многие команды.
Ранее мы познакомились с некоторыми способами ускорения диффузионных моделей:
А сегодня мы изучим ещё один тип ускорения — Consistency models и Latent Consistency models. Итак, давайте начинать 🙂
Интегратор
Диффузия и ODE
Диффузионные модели можно рассматривать с точки зрения стохастических и обыкновенных дифференциальных уравнений (SDE и ODE соответственно) при переходе от дискретного шага по времени к непрерывному.
В таком случаем мы работаем с непрерывным диффузионным процессом: \( \{x(t)\}^T_{t=0} \) , \( x(0) \sim p_0 \) и \( x(T) \sim p_T \).
Мы можем записать прямое и обратное SDE для процесса, характеризующегося функцией \( f(x,t) \) и виннеровским процессом
Прямой процесс (forward SDE): данные → шум
\( dx_t = f(x_t,t)dt + \sigma(t)dw_t \) , \( t\in[0,T], T>0, f(\cdot, \cdot) \) — некоторая drift force, а \( \{w_t\}_{t \in [0,T]} \) — Броуновское движение, взятое с весовым коэффициентом \( \sigma(t) \).
При этом мы считаем, что \( p_0(x) = p_{data}(x) \), а \( p_T \) — нормальное распределение с нулевым средним и единичной дисперсией.
Обратный процесс (reverse SDE): шум → данные
Соответственно, обратное по времени уравнение выглядит следующим образом:
\( dx_t = \Big[f(x_t,t)dt — \sigma(t)^2\nabla \log p_t (x_t) \Big] dt + \sigma(t)dw_t \)
где \( \nabla \log p_t(x) \) — score функция \( p_t(x) \).
Cтохастическому дифференциальному уравнению можно сопоставить обыкновенное дифференциальное уравнение. Его решения (сэмплированные в точках t траектории) совпадают с решениями SDE:
\( dx_t = \Big[f(x_t,t)dt — \frac{1}{2}\sigma(t)^2\nabla \log p_t (x_t) \Big] dt \)
где \( \log p_t (x_t) \) — score функция \( p_t(x_t). \)
Но откуда мы знаем эту score-функцию для решения нашего диффура? По сути задача обучения диффузионной модели сводится к обучению нейронной сети, аппроксимирующей данную score-функцию. Более подробный вывод вы можете найти в презентации.
Таким образом, нам нужно обучить score-функцию:
\( \frac{d}{dx}\log p_t(x_t) \approx s_{\theta}(x_t, t) = -\mathbb{E}\Big[\frac{x_t — x}{t^2}|x_t\Big] \)
Таким образом. если мы знаем score функцию — мы знаем направление траектории в каждой точке.
Тогда, засэмплировав элемент из шума, мы можем использовать численные солверы обыкновенных дифференциальных уравнений (например, Эйлера) для получения итогового элемента.
Решение SDE: \( dx_t = \Big[f(x_t,t)dt — \sigma(t)^2\nabla \log p_t (x_t) \Big] dt + \sigma(t)dw_t \)
Решение ODE (probability flow): \( dx_t = \Big[f(x_t,t)dt — \frac{1}{2}\sigma(t)^2\nabla \log p_t (x_t) \Big] dt \)
То есть для сгенерированного диффузионной моделью объекта нам необходимо просто проинтегрировать полученную с помощью score-функции траекторию.
Начинаем мы обратный процесс с какого-то нормального шума \( X_T \) в момент времени \( t \) и движемся к \( X_0 \) — сэмплу из исходного распределения изображений.
Важно отметить: интегрирование происходит до точки \( t = \epsilon \), где \( \epsilon \) — небольшое положительное число и конечный элемент \( \hat{x}_{\epsilon} \).
Обучение интегратора
А если бы мы сразу учили интегратор \( I_{\phi}(x_{t_1}, t_1, t_2) \), который получал бы сэмпл в точке \( t_1 \) и сразу предсказывал значение в точке \( t_2 \)?
В частности, можно было бы из объекта шума \( x_T \sim \mathcal{N}(0, \mathbf{I}) \) сразу сгенерировать объект исходного распределения \( \hat{x}0 = I{\phi}(x_{T}, T, 0) \).
Задачу обучения такого интегратора и ставят перед собой авторы работы Consistency Models 😊 Давайте разберёмся с ней подробнее.
Consistency models
Для начала рассмотрим несколько определений, чтобы далее быть в контексте происходящего 🙂
Определения
Пусть у нас есть траектория решения ODE \( \{x_t\}{t \in [\epsilon, T]} \) генерации изображения с помощью диффузионной модели. Тогда определим consistency function \( f: (x_t, t) → x{\epsilon} \), которая на вход берёт какой-то зашумленный сэмпл с траектории в момент времени \( t \), непосредственно само время \( t \), и возвращает элемент в начале траектории.
Она должна обладать свойством self-consistency \( f(x_t, t) = f(x_{t’},t’) \) для любых точек траектории в моменты времени \( t, t’ \in [\epsilon, T] \).
Такая функция обладает двумя важными особенностями:
- консистентность — функция от объектов \( x_{t_1} \) и \( x_{t_2} \), принадлежащих одной траектории, будет возвращать одинаковые значения \( f(x_{t_1}, t_1) = f(x_{t_2}, t_2) \) для любых \( t_1, t_2 \in [\epsilon, T] \). То есть можно сказать, что производная нашей функции по времени равна нулю.
- тождественное отображение — \( f(x_{\epsilon}, \epsilon) = x_{\epsilon} \), то есть \( f(\cdot, \epsilon) \). Это условие также можно назвать граничным условием.
Как же построить функцию, удовлетворяющую данным условиям?
Параметризация
Функцию \( f(\cdot, \cdot) \) мы планируем параметризовать через нейронную сеть, которая должна соответствовать перечисленным выше условиям. Авторы статьи предлагают два возможных варианта параметризации:
- \( f_{\theta}(x, t) = \begin{cases}x, \text{ }t = \epsilon \\ F_{\theta}(x,t), \text{ } t \in (\epsilon, T] \end{cases} \)
- \( f_{\theta}(x, t) = c_{skip}(t)x + c_{out}(t)F_{\theta}(x, t) \)
Где \( c_{skip} \) и \( c_{out} \) — две дифференцируемые функции, такие что
\( c_{skip}(\epsilon)=1 \), \( c_{out}(\epsilon) = 0 \)
и \( c_{skip} = \frac{\sigma_{data}^2}{\sigma^2 + \sigma_{data}^2} \), \( c_{out} = \frac{\sigma \sigma_{data}}{\sqrt{\sigma_{data}^2 + \sigma^2}} \)
\( \sigma_{data} = 0.5$ , $\sigma = t — \epsilon \)
Такая параметризация уже встречалась в других работах, например, в Karras et al., 2022 и Balaji et al., 2022 (там представлено более подробное описание). Автор Consistency models в дальнейших экспериментах используют именно её.
Сэмплирование
Пусть мы обучили consistency модель \( f_{\theta}(\cdot, \cdot) \). Тогда мы берём элемент из шума \( \hat{x}T \sim \mathcal{N}(0, T^2\mathbf{I}) \), прогоняем нашу модель и получаем \( \hat{x}{\epsilon} = f_{\theta}(\hat{x}_T, T) \). Это значит, что Consistency модель генерирует сэмпл за один шаг.
Однако мы также можем использовать Multi-step генератор. Для этого мы дополнительно зашумляем предсказанные результаты и снова делаем предсказание с помощью Consistency модели.
Алгоритм представлен ниже:
Важно отметить: time points \( \{\tau_1, \tau_2, …, \tau_{N-1}\} \) найдены с помощью жадного алгоритма при оптимизации метрики FID.
Теперь возникает вопрос, а как обучить такую модель, чтобы также выполнялось условие консистентности? Рассмотрим два варианта обучения:
- есть учитель (дистилляция) — Сonsistency Distillation (CD);
- отдельное обучение модели — Consistency Training (CT).
Consistency Distillation (CD)
Пусть у нас есть предобученная score-модель \( s_{\phi}(x,t) \) и дискретизация отрезка времени \( [\epsilon, T] \) на \( N-1 \) интервал. Тогда можно оценить \( x_{t_n} \) из \( x_{t_{n+1}} \):
\( \hat{x}{t_n}^{\phi} := x{t_{n+1}} + (t_n — t_{n+1}) \Phi(x_{t_{n+1}}, t_{n+1}; \phi) \)
где \( \Phi(\cdot, \cdot; \phi) \) — некоторый ODE solver, например, можно взять Euler Solver.
Для тех, кто хочет вспомнить, что такое Euler solver
Пусть у нас есть диффур \( dy(t) = f(t, y(t))dt \). Зададим первую точку \( y=y(t_0) \) и некоторый шаг h (в идеале h → 0). Тогда \( t_n = t_0 + nh \) → \( t_{n+1} = t_n + h \) и \( y_{n+1} = y_n + hf(t_n, y_n) \)
А в случае диффузионных моделей: \( \frac{dx_t}{dt} = -t s_{\phi}(x_t, t) \) (более подробно см. работу) → Euler Solver
\( \Phi(x_{t_{n+1}}, t_{n+1}; \phi) = s_{\phi}(x_{t_{n+1}}, t_{n+1})t_{n+1} \)
например, в случае Euler Solver \( \Phi(x_{t_{n+1}}, t_{n+1}; \phi) = s_{\phi}(x_{t_{n+1}}, t_{n+1})t_{n+1} \).
Теперь у нас есть пары точек \( (\hat{x}{t_n}, x{t_{n+1}}) \), где \( x_{t_{n+1}} \) получен путём добавления шума к изначальному сэмплу. Значит, мы можем записать следующую функцию потерь:
\( \mathcal{L}{CD}^N(\mathbf{\theta}, \mathbf{\theta}^{-}; \phi) := \mathbb{E}{x \sim p_{data}, n \sim \mathcal{U}[1, N-1]} \Big[\lambda(t_n)d(f_{\theta}(x_{t_{n+1}}, t_{n+1}), f_{\theta^-}(\hat{x}{t_n}^{\phi}, t{n}))\Big] \)
где:
- \( \lambda(\cdot) \in \mathbb{R}^+ \) — положительный весовой коэффициент;
- \( \theta^- \) — скользящее среднее предыдущих оценок \( \theta \): \( \theta^{-} \leftarrow stopgrad(\mu \theta^{-} + (1-\mu)\theta) \); \( 0 \leq \mu < 1 \).
В качеcтве \( d \) можно использовать различные функции. Авторы предлагают три варианта:
- \( d(x, y) = \|x — y\|^2_2 \) — \( l_2 \) distance;
- \( d(x, y) = \|x — y\|_1 \) — \( l_1 \) distance;
- LPIPS.
В результатах чуть ниже мы приведём сравнение функций.
Таким образом, обучение с учителем можно представить следующим образом:
Важно отметить: в данной схеме обучения у нас есть изначальная модель с параметрами \( \theta \), а также модель с обновляемыми параметрами \( \theta^{-} \). В работе авторы называют \( f_{\theta^{-}} \) как «target network», \( f_{\theta} \) — «online network».
Мы схематично отобразили основную идею:
Их можно инициализировать предобученой моделью с известными параметрами \( \phi \). С её помощью мы и делаем шаг солвером (так как она даёт нам знания о score-функции).
Consistency Training (CT)
Также модели могут обучаться и без предобученной диффузионной модели, то есть без предобученной score-модели \( s_{\phi}(x, t) \) \( \approx \nabla \log p_t(x) \). Как же тогда аппроксимировать score-функцию?
Напомним, что score-функция выглядит следующим образом:
\( \nabla \log p_t(x) = -\mathbb{E}\Big[\frac{x_t — x}{t^2}|x_t\Big] \)
Авторы приводят теорему (мы рассмотрим её без доказательства), которая показывает, как можно убрать из обучения предобученного учителя.
Let \( ∆t := max_{n \in [1, N-1]} \{|t_{n+1} — t_n \} \), \( d \) and \( f_{\theta} \) — are both twice continuously differentiable with bounded second derivatives, the weighting function \( \lambda (\cdot) \) is bounded, and \( \mathbb{E}[\| \nabla \log p_{t_n}(x_{t_n})\|^2_2] < \infty \).
Assume further that we use the Euler ODE solver, and the pre-trained score model matches the ground truth, i.e., \( \forall t \in [\epsilon, T] : s_{\phi}(x,t) \equiv \nabla \log p_t(x) \). Then,
\( \mathcal{L}{CD}^{N}(\theta, \theta^{*} ; \phi) = \mathcal{L}{CT}^{N}(\theta, \theta^{}) + o(\Delta t) \) , where the expectation is taken with respect to \( x \sim p_{\text{data}} \) , \( n \sim \mathcal{U}[1, N-1] \), and \( x_{t_{n+1}} \sim \mathcal{N}(x ; t_n^2 \mathbf{I} + t_n \mathbf{I}) \). The consistency training objective, denoted by \( \mathcal{L}_{CT}^{N}(\theta, \theta^{-} \) ), is defined as
\( \mathbb{E}[\lambda(t_n)d(f_{\theta}(x + t_{n+1}z, t_{n+1}) — f_{\theta^{-}}(x + t_n z, t_n)]^2 \) where \( z \sim \mathcal{N}(0, \mathbf{I}) \). Moreover, \( \mathcal{L}_{CT}^{N}(\theta, \theta^{-}) > O(\Delta t) \) if \( \mathcal{L}_{CD}^{N}(\theta, \theta^{-}; \phi) > 0 \).
Выглядит страшновато 👹 Но давайте посмотрим на основную мысль. Авторы вводят понятие consistency train loss:
\( \mathcal{L}{CT}^{N}(\theta, \theta^{-}) = \mathbb{E}[\lambda(t_n)d(f{\theta}(x + t_{n+1}z, t_{n+1}) — f_{\theta^{-}}(x + t_n z, t_n)]^2 \)
Здесь уже нет шага солвером, а есть зашумление сэмплов на соседние шаги по времени. Далее смотрится consistency loss. При хороших функциях (см. условие теоремы) можно показать, что consistency train и consistency distillation лоссы отличаются на o-малое от \( t \).
Следовательно, теперь мы можем не делать шаг солвером, а рассмотреть наше решение без дополнительно предобученной модели! 🙂
Таким образом, обучение можно представить следующим образом:
Для любопытных
На самом деле похожая идея обучения двух моделей часто используется в задачах Self-Supervised learning. Например, что-то похожее можно найти в работе Mean teachers are better role models, Bootstrap Your Own Latent или Dino.
Однако в зависимости от задачи необходимо корректно выбирать лосс для обучения. И в случае consistency models основной принцип обучения строится на свойстве консистентности модели, которую мы обучаем.
Гиперпараметры
Таким образом, авторы рассмотрели два основных типа обучения Consistency моделей, а также объяснили, как проводить генерацию. При описании они рассказывали про возможность использования различных гиперпараметров, таких как солверы, \( d \)-функции и шаг дискретизации.
Для подбора лучших гиперпараметров они проводят ряд экспериментов на датасете CIFAR-10:
- Во-первых, они сравнивают различные \( d \)-функции, а именно \( l_2 \), \( l_1 \) и LPIPS. Как можно заметить, LPIPS даёт минимальное (самое хорошее) значение метрики FID.
- Во-вторых, они тестируют различные ODE-солверы для consistency distillation обучения: Euler и Heun.
Как мы видим на графиках, наилучшие результаты дали:
- Consistency distillation: LPIPS loss, Heun solver и количество шагов дискретизации \( N = 18 \).
- Consistency training: LPIPS, Euler is ok, \( N \) — некий трейдоф между скоростью сходимости и качеством.
Результаты в метриках
Авторы провели много экспериментов и сравнений 🙂
Они сопоставили Progressive distillation (PD) и Consistency distillation (CD) на нескольких датасетах с различным количеством шагов сэмплирования и разными дистанс (\( d \)) функциями:
- CIFAR-10;
- ImageNet 64×64;
- Bedroom 256×256;
- Cat 256×256.
Также авторы сравнили на нескольких датасетах различные генеративные метрики и определили, что Сonsistency модели превосходят non-adversarial генеративные модели на CIFAR-10. Также они дают качественно хороший результат без дистилляции.
Результаты в картинках
Ниже представлено сравнение в картинках:
- верхний ряд — EDM;
- средний ряд — CT в 1 шаг;
- нижний ряд — CT в 2 шага.
Также можно посмотреть на то, как модель справляется с другими задачами, например:
- superresolution;
- editing;
- colorization.
Latent Сonsistency models
Работа Consistency models была направлена в первую очередь на ускорение диффузионных моделей в пиксельном пространстве, что позволяло генерировать изображения лишь небольшого разрешения. Однако большая часть современных диффузионных моделей работают в латентном пространстве. Поэтому логичным продолжением работы стала статья Latent consistency models, авторы которой применили идею CT и CD, но уже к латентным моделям. Давайте разберёмся, какие там были новые идеи 🙂
Генерация
Так же, как и в Consistency моделях, мы будем обучать нашу функцию на свойство консистентности и получать одношаговый генератор из шума в изображение.
При этом обученный одношаговый генератор можно также использовать как многошаговый:
Latent Consistency Distillation (LCD)
А как же обучать такую модель?
На самом деле основной алгоритм дистилляции несильно отличается от алгоритма Consistency моделей за исключением трёх основных моментов:
- Мы теперь работаем в латентном пространстве, поэтому всё зашумление и расшумление происходит именно в нём.
- Появляется так называемый skipping интервал \( k \). В работе Consistency models для обучения консистентности брались соседние шаги траектории генерации ( \( t_n \) и \( t_{n+1} \) ). Однако в латентном пространстве разница между латентами на этих шагах очень маленькая, и предсказания уже в целом довольно близки, что ведёт к более плохой сходимости и качеству обучения. Для решения проблемы авторы предлагают брать не соседние шаги, а шаги на каком-то расстоянии \( k \) друг от друга ( \( t_n \) и \( t_{n+k} \)).
Авторы проводят ряд экспериментов для вычисления этого оптимального k:
- Мы работаем не с unconditional диффузионным моделям, а с conditional моделями. Следовательно, мы теперь должны как-то учитывать Classifier-free guidance. Guidance scale для conditional генерации является некоторым trade-off регуляризатором между качеством и разнообразием генерации. При обучении guidance scale случайным образом сэмлпируется из отрезка \( [w_{min}, w_{max}] \).
Для инференса авторы сравнивают влияние guidance scale для генерации по двум основным метрикам — FID и CLIP-Score. Как и предполагается, при больших значениях guidance scale улучшается CLIP, однако ухудшается FID.
Также различие в качестве для guidance scale можно увидеть на изображениях. При совсем низких значениях картинки получаются немного размытыми, а также на них присутствуют хорошо заметные артефакты.
Таким образом, весь алгоритм обучения можно записать следующим образом (дополнительно на рисунке 22 показан алгоритм обучения Consistency models):
Latent Consistency Fine-Tuning (LCF)
Когда мы обсуждали Consistency models, мы говорили не только про дистилляцию, но и про Consistency Training (CT). В случае латентных моделей авторы останавливаются на Latent Consistency Fine-Tuning (LCF) для некоторого датасета пользователя:
Результаты в метриках
Сравнение проводилось относительно следующих метрик:
Как можно заметить, LCM модели значительно превосходят предложенные бейзлайны по метрикам на генерации за 1, 2 и 4 шага.
Результаты в картинках
При визуальном сравнении мы видим, что двух шагов модели не хватает для генерации изображений высокого качества.
Также авторы показывают результаты работы предложенного алгоритма LCF для разного количества шагов finetuning:
Заключение
Итак, сегодня мы познакомились с ещё одним способом ускорения диффузионных моделей за счёт сокращения количества шагов генерации при обучении модели на свойство self-consistency.
Такую идею можно применять как для моделей, работающих в пиксельном пространстве, так и для латентных диффузионных моделей.
Этот метод хорошо себя показывает и используется для ускорения новых генераторов, например, SDXL.
Однако сейчас всё большую популярность набирают дистилляционные методы, включающие adversarial training. О них мы подробнее рассказали в нашем предыдущем посте — Diffusion Distillation 🙂