Назад
63

InstaFlow

63

Введение

Диффузионные модели на стыке физики и машинного обучения — мощный инструмент для моделирования сложных процессов генерации данных. За последние годы они получили широкое признание благодаря возможности эффективно генерировать реалистичные изображения, видео и даже звуковые сигналы.

Рисунок 1. Пример генерации видео с помощью диффузионных моделей [источник]

Несмотря на свои сильные стороны (качество и разнообразие генерации), диффузионные модели остаются медленными генераторами. Для генерации одного сэмпла часто нужно много запусков одной нейронной сети, что занимает немало времени в сравнении с GANs.

Поэтому сегодня мы рассмотрим метод ускорения диффузионных моделей через сокращение количества шагов генерации, а именно — работу InstaFlow.

Спойлер: будет математика 😟

Но не пугайтесь! В конце мы собрали саммари, чтобы легче ориентироваться во всём происходящем и возвращаться к нужным местам при необходимости 😊

Итак, наш план:

  1. Кратко вспомним, что такое диффузионные модели и как они связаны со стохастическими и обыкновенными дифференциальными уравнениями.
  2. Познакомимся со flow matching’ом.
  3. Разберёмся, что такое ректифицированные потоки, как их можно «выпрямлять».
  4. И обсудим идею авторов работы InstaFLow.

Optimal transport problem 🏎

Рисунок 2. Перенос одного распределения в другое [источник]

Transport Mapping Problem

Задача оптимального транспорта имеет непосредственное отношение к генеративным моделям и переносу стиля. Как перенести одно распределение в другое, ещё и сделать это оптимальным способом? И что значит «оптимальным»? Давайте обсудим эту задачу подробнее 🙂

Есть два эмпирических распределения \( \pi_0 \) и \( \pi_1 \in \mathbb{R}^d \). Необходимо построить такую транспортировку (транспортную карту) \( T \): \( \mathbb{R}^d \rightarrow \mathbb{R}^d \), что \( Z_1 :=T(Z_0) \sim \pi_1 \), где \( Z_0 \sim \pi_0 \).

Можно построить бесконечно много таких транспортных карт \( T \), поэтому обычно пытаются решить задачу оптимального транспорта через поиск \( T \), которое минимизирует транспортную стоимость: \( min_{T} \mathbb{E} \Big[c(T(Z_0) — Z_0)\Big] \), где \( с \) — transport cost.

Тем не менее решение этой задачи довольно сложное, особенно для многомерных задач. Важным вопросом также является выбор транспортной стоимости \( c \).

В случае генеративных моделей \( T \) чаще всего параметризуется нейронной сетью и обучается в стиле алгоритмов GAN / максимизацией правдоподобия (MLE). Но GAN страдают от мод коллапса, а VAE даёт слишком замыленные результаты.

С недавних пор широкую популярность получил неявный способ транспортировки через непрерывный по времени процесс, в частности — диффузионные модели.

Такую известность они получили, в первую очередь, из-за качества и разнообразия генерации. Однако качественные результаты возможны при генерации за большое количество шагов и требуют много времени в сравнении с другими генеративными моделями (например, GANs).

Отсюда мы получаем Generative Learning Trilemma: ни одна генеративная модель не удовлетворяет сразу трём пунктам — разнообразию, качеству и скорости.

Рисунок 3. Generative Learning Trilemma

Поэтому возникает вопрос: а как ускорить генерацию диффузии?

Диффузия 🧪

Рисунок 4. Зашумление изображения с помощью библиотеки diffusers. Изображение сгенерировано через SDXL Turbo

Перед обсуждением ускорения диффузии давайте вспомним, что такое диффузионные модели.

В этом обзоре мы предполагаем, что вы уже знакомы с диффузионными моделями (а если нет — советуем прочитать наш пост по данной теме). Поэтому сейчас сделаем небольшое введение.

Диффузионный процесс состоит из двух частей:

  • Forward process — постепенное добавление шума к распределению объектов до тех пор, пока итоговое распределение не станет просто распределением шума (чаще всего работают с гауссовским шумом, но недавно была представлена работа или работа, показывающая, как можно использовать другие типы шумов);
  • Backward process — сэмплирование объекта из латентного распределения и путём пошагового расшумления получение объекта из исходного распределения.

Прямой процесс (forward process)

Пусть у нас есть сэмпл \( x_0 \sim q(x) \). Тогда прямой процесс диффузии — постепенное добавление небольшого количества шума, чаще всего гауссовского (далее работаем с ним), в течение \( T \) шагов, в результате которых получаем последовательность зашумленных сэмплов \( x_1, x_2 …, x_T \) и в итоге гауссовский шум. Количество добавляемого шума регулируется с помощью variance scheduler \( \{\beta_t \in (0,1)\}^T_t=1 \).

Тогда зашумленное распределение на шаге \( t \) можно получить следующим образом:

\( q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}) \)

\( q(x_{1:T}|x_0) = \prod_{t=1}^{T}q(x_t|x_{t-1}) \)

Рисунок 5. Зашумление изображения с помощью библиотеки diffusers. Изображение сгенерировано через SDXL Turbo

При этом с помощью трюка репараметризации из получившегося распределения можно сэмплировать:

\( x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \)

Обратный процесс (backward process)

Если мы сможем обернуть данный процесс — получим генеративную модель, которая возьмёт сэмпл из гауссовского шума с нулевым средним и единичной дисперсией \( x \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$] \) и восстановит элемент из исходного распределения.

Важная особенность диффузионного процесса: если \( \beta_t \) — достаточно маленькое значение, то \( q(x_{t-1}|x_t) \) тоже будет гауссовским. Однако мы не можем так просто оценить это распределение. Поэтому нашей задачей становится аппроксимация с помощью нейронной модели \( p_{\theta} \) с параметрами \( \theta \):

\( p_{\theta}(x_{0:T}) = p(x_T) \prod_{t=1}^{T}p_{\theta}(x_{t-1}|x_{t}) \)

\( p_{\theta}(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1}; \mathbf{\mu}{\theta}(x_t, t), \mathbf{\sum}{\theta}(x_t,t)) \)

Рисунок 6. Схема работы диффузионной модели с моделированием обратного процесса [источник]

Диффузия и SDE

Прежде, чем двигаться дальше, необходимо разобраться, как связана диффузия с SDE и ODE.

При переходе от дискретных шагов по времени к непрерывным можно связать процесс диффузионных моделей со стохастическими дифференциальными уравнениями.

Тогда мы работаем с непрерывным диффузионным процессом: \( \{x(t)\}^T_{t=0} \) , \( x(0) \sim p_0 \) и \( x(T) \sim p_T \).

Можно записать прямое и обратное SDE для процесса, характеризующегося функцией \( f(x,t) \) и виннеровским процессом \( w \).

Прямой процесс (forward SDE): данные → шум

\( dx = f(x,t)dt + g(t)dw \)

Рисунок 7. Зашумление данных с помощью непрерывного по времени стохастического процесса [источник]

Обратный процесс (reverse SDE): шум → данные

\( dx = \Big[f(x,t) — g^2(t) \nabla_x \log p_t(x)\Big]dt + g(t)d \hat{w} \)

где \( \nabla \log p_t(x) \) — score функция \( p_t(x) \).

Рисунок 8. Процесс генерации данных из шума путём обращения SDE по времени [источник]
Рисунок 9. Прямое и обратное SDE
Рисунок 10. SDE для прямого и обратного процессов [источник]

Представленному стохастическому дифференциальному уравнению можно сопоставить обыкновенное дифференциальное уравнение. Его решения (сэмплированные в точках t траектории) совпадают с решениями SDE:

\( dx = \Big[f(x,t) — \frac{1}{2}g^2(t) \nabla_x \log p_t(x)\Big]dt \)

Именно за счёт существования соответствующего ODE возникает возможность использовать различного рода сэмплирование для ускорения генерации диффузии, ведь есть большое количество способов решения ODE.

Вывод

Таким образом, если перейти к непрерывным шагам по времени — диффузионный процесс (как прямой, так и обратный) можно представить через SDE. Более того, SDE сопоставимо с ODE, сохраняющим маргинальные распределения на каждом шаге, что позволяет перейти к быстрому сэмплированию.

Для подробного погружения в эту тему советуем посмотреть лекции Дмитрия Ветрова:

Flow matching и ректификация 🏄‍♂️

Основная задача диффузионных моделей — перевод шума в распределение исходных данных. А если поставить её немного по-другому? Пусть мы хотим перегнать одно распределение данных в другое. Это может быть распределение шума в распределение картинок, как в случае диффузионных моделей, но также и распределение картинок в распределение картинок.

Например, мы хотим превратить человека в персонажа мультфильма:

Рисунок 11. Стилизация изображений

Часто при работе с подобными проблемами мы имеем датасет только непарных изображений. То есть набор изображений одного (например, фото людей) и другого (например, персонажи мультфильмов) распределений.

Давайте для решения задачи будем выучивать какое-то векторное поле, которое переведёт одно распределение в другое:

Рисунок 12. Перевод одного распределения в другое

Пусть есть два распределения \( \pi_0 \) и \( \pi_1 \), и мы хотим задать транспортное преобразование \( T \) через обыкновенное дифференциальное уравнение (ODE) с drift force \( v \): \( \mathbb{R}^d \times [0,1] \):

\( dZ_t = v(Z_t, t)dt \), \( t \in [0,1] \) с начальными условиями \( Z_0 \sim \pi_0 \).

Необходимо построить \( v \) так, чтобы при заданных начальных условиях во время решения ODE получилось \( Z_1 \sim \pi_1 \).

Как уже было сказано, мы хотим обучить нашу «движущую силу» \( v \), которая будет переводить одно распределение в другое. Чаще всего \( v \) параметризуется нейронной сетью.

\( v \) можно построить при решении минимизации среднеквадратичной ошибки:

\( {L} = min_v \mathbb{E}_{(Z_0, Z_1)} \Big[\int_0^1 \|\frac{d}{dt}Z_t — v(Z_t, t) \|^2dt \Big] \)

где \( Z_t = \phi(X_0, X_1, t) \) — любая дифференцируемая по времени интерполяция, а минимизация происходит по всем парам распределения (например, в случае диффузионных моделей DDIM интерполяция выражается так: \( Z_t = \alpha_t Z_0 + \beta_t Z_1 \), где \( \alpha_t \) и \( \beta_t \) — дифференцируемые по времени и специально выбранные параметры).

Давайте теперь построим случайный процесс, который точно переведёт одно распределение в другое. Авторы Rectified Flow предлагают простую линейную интерполяцию (запоминаем на будущее, когда будем говорить про выпрямление):

\( X_t = (1-t)X_0 + tX_1 \Rightarrow \frac{d}{dt} X_t = X_1 — X_0 \) — константа!

Рисунок 13. Линейная интерполяция сэмплов из двух распределений
Рисунок 14. Анимация подобной интерполяции [источник]

Какими свойствами обладает данный процесс?

  1. Перевод одного распределения в другое;
  2. Необходимость знания о точке \( X_0 \) и \( X_1 \) для сэмплирования траектории (на практике довольно бесполезно, поскольку обычно нужно из точки \( X_0 \) попасть в точку \( X_1 \));
  3. Немарковский процесс: сэмлплы на траектории зависят не только от предыдущих шагов, но и от будущих (об этом мы говорили выше).

Давайте выучим наше «хорошее» векторное поле перевода одного распределения в другое на основе данного интерполяционного процесса:

\( {L} = min_v \mathbb{E}_{(X_0, X_1)} \Big[\int_0^1 \|(X_1 — X_0) — v(X_t, t) \|^2dt \Big] \)

То есть мы:

  1. Берём точки \( X_0, X_1 \) из исходных распределений;
  2. Берём точку \( X_t \);
  3. Делаем так, чтобы вектор скорости в точке \( X_t \) (в момент времени \( t \)) совпал с вектором \( X_1 — X_0 \).
Рисунок 15. Обучение векторного поля в каждой точке траектории

Теоретически искомый вектор скорости выражается следующим образом:

\( v(z, t) = \mathbb{E}[X_1 — X_0|X_t = z] \)

Это значит, что мы берём среднее всех направлений, проходящих через точку \( z \) в момент времени \( t \).

Однако на практике мы не знаем всех траекторий, проходящих через заданную точку, поэтому такую задачу решают обыкновенным градиентным спуском, параметризуя \( v \) нейронной сетью. После обучения \( v \) мы решаем ODE с начальными условиями \( X_0 \sim \pi_0 \), чтобы перевести распределение \( \pi_0 \) в \( \pi_1 \).

Рисунок 16. Алгоритм ректификации [Instaflow]

Особенности Rectified Flow

Совпадение маргинальных распределений

Важная особенность полученных траекторий — маргинальные распределения в точках \( t \) совпадают с маргинальными распределениями указанной выше интерполяции.

Если мы рассмотрим в момент времени \( t \) распределение \( \pi^{interpolation}_t \), полученное при интерполяции \( X_t = [(1-t)X_0 + tX_1] \sim \pi^{interpolation}t \) , и распределение \( \pi^{flow}t \), полученное в результате построения потока с drift force \( v \) за счёт диффура \( dZ{t’} = v(Z{t’}, t’)dt’ \) \( t’ \in [0,t] \) с начальными условиями \( Z_0 \sim \pi_0 \), то распределения \( \pi^{interpolation}_t \) и \( \pi^{flow}_t \) совпадут.

Ниже схематично представлена идея совпадающих распределений в каждый момент времени \( t \).

Рисунок 17. Сохранение маргинальных распределений в промежуточных точках t

Распутанность

Ключевая особенность Rectified Flow — траектории решения ODE не должны пересекаться, поскольку существует единственное решение обыкновенного дифференциального уравнения с заданным начальным условием. То есть нет промежуточной точки \( z \in \mathbb{R}^d \) и времени \( t \in [0, 1) \), при которых две траектории решения ODE проходят через точку \( z \) в данный момент времени. При этом изначально заданная интерполяция допускает возможность пересечения этих траекторий. Таким образом, Rectified Flows «‎распутывают» пересечение, индуцируя новые детерминистические пары \( Z_0 \) и \( Z_1. \)

Рисунок 18. Ректифицированные потоки [источник]

We can view the linear interpolation \( Z_t \) as building roads (or tunnels) to connect \( π_0 \) and \( π_1 \), and the rectified flow as traffics of particles passing through the roads in a myopic, memoryless, non-crossing way, which allows them to ignore the global path information of how \( Z_0 \)\( Z_1 \) are paired, and rebuild a more deterministic pairing of \( (X_0, X_1) \).

Транспортная стоимость

Таким образом, мы получаем непересекающиеся траектории, которые в результате дают нам пары из двух распределений. При этом маргинальные распределения в любой момент времени t у таких траекторий совпадут с маргинальными распределениями интерполянтов.

Авторы отмечают ещё одну важную особенность результатов: для любой выпуклой функции \( c \) транспортная стоимость \( Z_0 → Z_1 \) меньше, чем \( X_0 → X_1 \):

\( \mathbb{E}[c(Z_1 — Z_0)] \leq \mathbb{E}[c(X_1 — X_0)] \)

Итак, мы научились строить векторное поле, которое переведёт одно распределение в другое, и получать при этом непересекающиеся траектории с меньшими транспортными затратами.

Ускорение диффузии 🚀

Ура, мы вплотную подошли к ускорению диффузионки! Но сначала ещё немного поговорим про потоки, а именно — про их «‎выпрямление»‎.

Как уже было сказано, после обучения векторного поля мы решаем ODE.

Обычно для этого используются различные солверы, например, схема Эйлера:

\( Z_{t + \frac{1}{N}} = Z_t + \frac{1}{N} v(Z_t, t)\text{, } \forall t \in \{0, …, N-1\} / N \)

Важно: чем больше N, тем точнее решение.

Если траектория решения прямая — даже при одном шаге можно получить достаточно точное решение.

Рисунок 19. Решение ODE с непрямой и прямой траекториями [источник]

Reflow

Reflow — итеративная процедура «‎выпрямления» траекторий начального и конечного распределений без изменения маргинальных распределений путём ректификации.

Обученное нами векторное поле в результате решения ODE выдаст новую пару \( (Z_0, Z_1) \), индуцированную из \( (X_0, X_1) \).

Итак, у нас есть новые пары, давайте сейчас снова построим ректифицированный поток, минимизируя функцию потерь:

\( {L} = min_v \mathbb{E}_{(Z_0, Z_1)} \Big[\int_0^1 \|(Z_1 — Z_0) — v(Z_t, t) \|^2dt \Big] \)

Получим новую пару \( (Z_0^1, Z_1^1) \), повторим построение, индуцируем новые пары \( (Z_0^2, Z_1^2) \) и так далее. Тогда на \( k \) шаге мы получим \( k \)-ый ректифицированный поток:

\( Z^{k+1} = \text{Rectflow}((Z_0^k, Z_1^k)) \)

Ключевые особенности Reflow:

  • \( v_{k+1} \) переносит \( \pi_0 \) в \( \pi_1 \), если переносит \( v_k \);
  • траектории решения у ODE[\( v_{k+1} \)] более прямые, чем у ODE[\( v_{k} \)];
  • (\( X_0 \), ODE[\( v_{k+1}(X_0) \)]) образуют более удачные пары, чем (\( X_0 \), ODE[\( v_{k}(X_0) \)]), что даёт меньшую стоимость переноса для любых выпуклых функций стоимости \( c \).
Рисунок 20. Процесс ректификации на примере двух распределений [источник]

Авторы утверждают — достаточно всего нескольких (двух-трёх) итераций Reflow для получения достаточно прямых траекторий.

Instaflow

Наконец, мы с вами переходим к основной теме обзора — статье Instaflow.

Она объединяет две идеи:

  • дистилляция;
  • ректификация.

Распространённый подход ускорения диффузионных моделей — дистилляция.

Дистилляция — передача знаний «учителя» (предобученной диффузионной модели) «ученику», который может генерировать за меньшее количество шагов.

В идеале вместо генерации за $n$ шагов мы обучаем с помощью дистилляции генерировать за $1$ шаг. Если сделать это напрямую — результаты окажутся неудовлетворительными. Но если «выпрямить» траектории генерации (что и предлагается в работе Instaflow) — после дистилляции в один шаг у сгенерированных изображений будет достаточно хорошее качество.

Рисунок 21. Основная идея ускорения модели Instaflow через последовательную ректификацию и дистилляцию [источник]

Выпрямление

Подобное «выпрямление» производится за счёт итеративной ректификации, которая с каждой последующей итерацией даёт всё более прямую траекторию. Поскольку полностью прямая траектория требует бесконечного числа итераций, авторы предлагают сделать приблизительно прямые траектории за пару итераций, а затем провести дистилляцию.

Рисунок 22. Решение ODE с непрямой и прямой траекториями [источник]

Дистилляция

После относительно прямых траекторий генерации для получения диффузионной модели, способной генерировать в один шаг, необходимо провести дистилляцию и обучить один шаг Эйлера вместо нескольких итераций решения дифференциального уравнения на k шаге ректификации:

\( \tilde{v}k = \text{arg min}v \mathbb{E}{X_0 \sim \pi_0, \Tau \sim D{\Tau}} \Big[\mathbb{D}(\text{ODE}[v_k](X_0, \Tau), X_0 + v(X_0|\Tau))\Big] \)

где \( D(\cdot, \cdot) \) — similarity loss, например, LPIPS.

Весь алгоритм InstaFlow можно представить следующим образом:

Рисунок 23. Основной алгоритм ускорения диффузионной модели в работе Instaflow [источник]

Результаты

Авторы использовали в качестве начальной модели предобученную Stable Diffusion. Затем они прошли несколько этапов ректификации и провели эксперименты с последующей дистилляцией.

Ниже будут представлены некоторые результаты из оригинальной статьи:

Рисунок 24. Сравнение результатов и скорости генерации Stable Diffusion, её дистилляции, ректификации и последующей дистилляции [источник]

Результаты в картинках

Стоит отметить: если дистилляция сразу применяется к диффузионной модели, она не даёт хороших результатов. Однако двух итераций ректификации будет достаточно для получения качественной дистилляции в один шаг 🙂.

Рисунок 25. Качество генерации диффузионной модели за 25 шагов (первый ряд), генерации в 1 шаг без дистилляции (второй ряд) и с дистилляцией (третий ряд). Столбцы: количество шагов ректификации [источник]
Рисунок 26. Качество генерации при уменьшении количества шагов с ректификацией (второй ряд) и без неё (первый ряд) [источник]

Мы видим неплохие результаты работы. Стоит отметить: одной ректификации нам не хватает, значительное улучшение вносит именно дистилляция. Однако, применив эти две идеи (причем ректификацию всего пару раз), можно получить отличные изображения за несколько шагов!

Результаты в метриках

Рисунок 27. Различные генеративные модели и их результаты в метриках. Сравнение числа параметров, скорости генерации и метрики FID [источник]
Рисунок 28. Сравнение метрик FID и CLIP в зависимости от количества шагов у модели SD1.5 и модели SD1.5 после двух операций Reflow [источник]

Summary

Итак, мы с вами обсудили довольно много теории, порой не очень очевидной. Предлагаем сделать саммари разобранного материала, чтобы точно его закрепить 🙂

Задача — получить диффузионную модель, способную генерировать в 1 шаг. Прямая дистилляция предобученной диффузионной модели в диффузионную модель, генерирующую за 1 шаг, не даёт удовлетворительных результатов:

Рисунок 29. Результат дистилляции диффузионной модели в модель, генерирующую за 1 шаг [источник]

Но если прежде, чем проводить дистилляцию, добавить «выпрямление» траекторий — результат будет лучше:

Рисунок 30. Результат дистилляции диффузионной модели в модель, генерирующую за 1 шаг; reflow «выпрямляет» траекторию генерации [источник]

За теорией «выпрямления» стоит теория flow matching’а моделей. Она показывает, как можно обучить векторное поле для перевода одного распределения в другое.

Полезные ссылки

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!