InstaFlow
Введение
Диффузионные модели на стыке физики и машинного обучения — мощный инструмент для моделирования сложных процессов генерации данных. За последние годы они получили широкое признание благодаря возможности эффективно генерировать реалистичные изображения, видео и даже звуковые сигналы.
Несмотря на свои сильные стороны (качество и разнообразие генерации), диффузионные модели остаются медленными генераторами. Для генерации одного сэмпла часто нужно много запусков одной нейронной сети, что занимает немало времени в сравнении с GANs.
Поэтому сегодня мы рассмотрим метод ускорения диффузионных моделей через сокращение количества шагов генерации, а именно — работу InstaFlow.
Спойлер: будет математика 😟
Но не пугайтесь! В конце мы собрали саммари, чтобы легче ориентироваться во всём происходящем и возвращаться к нужным местам при необходимости 😊
Итак, наш план:
- Кратко вспомним, что такое диффузионные модели и как они связаны со стохастическими и обыкновенными дифференциальными уравнениями.
- Познакомимся со flow matching’ом.
- Разберёмся, что такое ректифицированные потоки, как их можно «выпрямлять».
- И обсудим идею авторов работы InstaFLow.
Optimal transport problem 🏎
Transport Mapping Problem
Задача оптимального транспорта имеет непосредственное отношение к генеративным моделям и переносу стиля. Как перенести одно распределение в другое, ещё и сделать это оптимальным способом? И что значит «оптимальным»? Давайте обсудим эту задачу подробнее 🙂
Есть два эмпирических распределения
Можно построить бесконечно много таких транспортных карт \( T \), поэтому обычно пытаются решить задачу оптимального транспорта через поиск \( T \), которое минимизирует транспортную стоимость: \( min_{T} \mathbb{E} \Big[c(T(Z_0) — Z_0)\Big] \), где \( с \) — transport cost.
Тем не менее решение этой задачи довольно сложное, особенно для многомерных задач. Важным вопросом также является выбор транспортной стоимости \( c \).
В случае генеративных моделей \( T \) чаще всего параметризуется нейронной сетью и обучается в стиле алгоритмов GAN / максимизацией правдоподобия (MLE). Но GAN страдают от мод коллапса, а VAE даёт слишком замыленные результаты.
С недавних пор широкую популярность получил неявный способ транспортировки через непрерывный по времени процесс, в частности — диффузионные модели.
Такую известность они получили, в первую очередь, из-за качества и разнообразия генерации. Однако качественные результаты возможны при генерации за большое количество шагов и требуют много времени в сравнении с другими генеративными моделями (например, GANs).
Отсюда мы получаем Generative Learning Trilemma: ни одна генеративная модель не удовлетворяет сразу трём пунктам — разнообразию, качеству и скорости.
Поэтому возникает вопрос: а как ускорить генерацию диффузии?
Диффузия 🧪
Перед обсуждением ускорения диффузии давайте вспомним, что такое диффузионные модели.
В этом обзоре мы предполагаем, что вы уже знакомы с диффузионными моделями (а если нет — советуем прочитать наш пост по данной теме). Поэтому сейчас сделаем небольшое введение.
Диффузионный процесс состоит из двух частей:
- Forward process — постепенное добавление шума к распределению объектов до тех пор, пока итоговое распределение не станет просто распределением шума (чаще всего работают с гауссовским шумом, но недавно была представлена работа или работа, показывающая, как можно использовать другие типы шумов);
- Backward process — сэмплирование объекта из латентного распределения и путём пошагового расшумления получение объекта из исходного распределения.
Прямой процесс (forward process)
Пусть у нас есть сэмпл \( x_0 \sim q(x) \). Тогда прямой процесс диффузии — постепенное добавление небольшого количества шума, чаще всего гауссовского (далее работаем с ним), в течение \( T \) шагов, в результате которых получаем последовательность зашумленных сэмплов \( x_1, x_2 …, x_T \) и в итоге гауссовский шум. Количество добавляемого шума регулируется с помощью variance scheduler \( \{\beta_t \in (0,1)\}^T_t=1 \).
Тогда зашумленное распределение на шаге \( t \) можно получить следующим образом:
\( q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}) \)
\( q(x_{1:T}|x_0) = \prod_{t=1}^{T}q(x_t|x_{t-1}) \)
При этом с помощью трюка репараметризации из получившегося распределения можно сэмплировать:
\( x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \)
Обратный процесс (backward process)
Если мы сможем обернуть данный процесс — получим генеративную модель, которая возьмёт сэмпл из гауссовского шума с нулевым средним и единичной дисперсией \( x \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$] \) и восстановит элемент из исходного распределения.
Важная особенность диффузионного процесса: если \( \beta_t \) — достаточно маленькое значение, то \( q(x_{t-1}|x_t) \) тоже будет гауссовским. Однако мы не можем так просто оценить это распределение. Поэтому нашей задачей становится аппроксимация с помощью нейронной модели \( p_{\theta} \) с параметрами \( \theta \):
\( p_{\theta}(x_{0:T}) = p(x_T) \prod_{t=1}^{T}p_{\theta}(x_{t-1}|x_{t}) \)
\( p_{\theta}(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1}; \mathbf{\mu}{\theta}(x_t, t), \mathbf{\sum}{\theta}(x_t,t)) \)
Диффузия и SDE
Прежде, чем двигаться дальше, необходимо разобраться, как связана диффузия с SDE и ODE.
При переходе от дискретных шагов по времени к непрерывным можно связать процесс диффузионных моделей со стохастическими дифференциальными уравнениями.
Тогда мы работаем с непрерывным диффузионным процессом: \( \{x(t)\}^T_{t=0} \) , \( x(0) \sim p_0 \) и \( x(T) \sim p_T \).
Можно записать прямое и обратное SDE для процесса, характеризующегося функцией \( f(x,t) \) и виннеровским процессом \( w \).
Прямой процесс (forward SDE): данные → шум
\( dx = f(x,t)dt + g(t)dw \)
Обратный процесс (reverse SDE): шум → данные
\( dx = \Big[f(x,t) — g^2(t) \nabla_x \log p_t(x)\Big]dt + g(t)d \hat{w} \)
где \( \nabla \log p_t(x) \) — score функция \( p_t(x) \).
Представленному стохастическому дифференциальному уравнению можно сопоставить обыкновенное дифференциальное уравнение. Его решения (сэмплированные в точках t траектории) совпадают с решениями SDE:
\( dx = \Big[f(x,t) — \frac{1}{2}g^2(t) \nabla_x \log p_t(x)\Big]dt \)
Именно за счёт существования соответствующего ODE возникает возможность использовать различного рода сэмплирование для ускорения генерации диффузии, ведь есть большое количество способов решения ODE.
Вывод
Таким образом, если перейти к непрерывным шагам по времени — диффузионный процесс (как прямой, так и обратный) можно представить через SDE. Более того, SDE сопоставимо с ODE, сохраняющим маргинальные распределения на каждом шаге, что позволяет перейти к быстрому сэмплированию.
Для подробного погружения в эту тему советуем посмотреть лекции Дмитрия Ветрова:
Flow matching и ректификация 🏄♂️
Основная задача диффузионных моделей — перевод шума в распределение исходных данных. А если поставить её немного по-другому? Пусть мы хотим перегнать одно распределение данных в другое. Это может быть распределение шума в распределение картинок, как в случае диффузионных моделей, но также и распределение картинок в распределение картинок.
Например, мы хотим превратить человека в персонажа мультфильма:
Часто при работе с подобными проблемами мы имеем датасет только непарных изображений. То есть набор изображений одного (например, фото людей) и другого (например, персонажи мультфильмов) распределений.
Давайте для решения задачи будем выучивать какое-то векторное поле, которое переведёт одно распределение в другое:
Пусть есть два распределения \( \pi_0 \) и \( \pi_1 \), и мы хотим задать транспортное преобразование \( T \) через обыкновенное дифференциальное уравнение (ODE) с drift force \( v \): \( \mathbb{R}^d \times [0,1] \):
\( dZ_t = v(Z_t, t)dt \), \( t \in [0,1] \) с начальными условиями \( Z_0 \sim \pi_0 \).
Необходимо построить \( v \) так, чтобы при заданных начальных условиях во время решения ODE получилось \( Z_1 \sim \pi_1 \).
Как уже было сказано, мы хотим обучить нашу «движущую силу» \( v \), которая будет переводить одно распределение в другое. Чаще всего \( v \) параметризуется нейронной сетью.
\( v \) можно построить при решении минимизации среднеквадратичной ошибки:
\( {L} = min_v \mathbb{E}_{(Z_0, Z_1)} \Big[\int_0^1 \|\frac{d}{dt}Z_t — v(Z_t, t) \|^2dt \Big] \)
где \( Z_t = \phi(X_0, X_1, t) \) — любая дифференцируемая по времени интерполяция, а минимизация происходит по всем парам распределения (например, в случае диффузионных моделей DDIM интерполяция выражается так: \( Z_t = \alpha_t Z_0 + \beta_t Z_1 \), где \( \alpha_t \) и \( \beta_t \) — дифференцируемые по времени и специально выбранные параметры).
Давайте теперь построим случайный процесс, который точно переведёт одно распределение в другое. Авторы Rectified Flow предлагают простую линейную интерполяцию (запоминаем на будущее, когда будем говорить про выпрямление):
\( X_t = (1-t)X_0 + tX_1 \Rightarrow \frac{d}{dt} X_t = X_1 — X_0 \) — константа!
Какими свойствами обладает данный процесс?
- Перевод одного распределения в другое;
- Необходимость знания о точке \( X_0 \) и \( X_1 \) для сэмплирования траектории (на практике довольно бесполезно, поскольку обычно нужно из точки \( X_0 \) попасть в точку \( X_1 \));
- Немарковский процесс: сэмлплы на траектории зависят не только от предыдущих шагов, но и от будущих (об этом мы говорили выше).
Давайте выучим наше «хорошее» векторное поле перевода одного распределения в другое на основе данного интерполяционного процесса:
\( {L} = min_v \mathbb{E}_{(X_0, X_1)} \Big[\int_0^1 \|(X_1 — X_0) — v(X_t, t) \|^2dt \Big] \)
То есть мы:
- Берём точки \( X_0, X_1 \) из исходных распределений;
- Берём точку \( X_t \);
- Делаем так, чтобы вектор скорости в точке \( X_t \) (в момент времени \( t \)) совпал с вектором \( X_1 — X_0 \).
Теоретически искомый вектор скорости выражается следующим образом:
\( v(z, t) = \mathbb{E}[X_1 — X_0|X_t = z] \)
Это значит, что мы берём среднее всех направлений, проходящих через точку \( z \) в момент времени \( t \).
Однако на практике мы не знаем всех траекторий, проходящих через заданную точку, поэтому такую задачу решают обыкновенным градиентным спуском, параметризуя \( v \) нейронной сетью. После обучения \( v \) мы решаем ODE с начальными условиями \( X_0 \sim \pi_0 \), чтобы перевести распределение \( \pi_0 \) в \( \pi_1 \).
Особенности Rectified Flow
Совпадение маргинальных распределений
Важная особенность полученных траекторий — маргинальные распределения в точках \( t \) совпадают с маргинальными распределениями указанной выше интерполяции.
Если мы рассмотрим в момент времени \( t \) распределение \( \pi^{interpolation}_t \), полученное при интерполяции \( X_t = [(1-t)X_0 + tX_1] \sim \pi^{interpolation}t \) , и распределение \( \pi^{flow}t \), полученное в результате построения потока с drift force \( v \) за счёт диффура \( dZ{t’} = v(Z{t’}, t’)dt’ \) \( t’ \in [0,t] \) с начальными условиями \( Z_0 \sim \pi_0 \), то распределения \( \pi^{interpolation}_t \) и \( \pi^{flow}_t \) совпадут.
Ниже схематично представлена идея совпадающих распределений в каждый момент времени \( t \).
Распутанность
Ключевая особенность Rectified Flow — траектории решения ODE не должны пересекаться, поскольку существует единственное решение обыкновенного дифференциального уравнения с заданным начальным условием. То есть нет промежуточной точки \( z \in \mathbb{R}^d \) и времени \( t \in [0, 1) \), при которых две траектории решения ODE проходят через точку \( z \) в данный момент времени. При этом изначально заданная интерполяция допускает возможность пересечения этих траекторий. Таким образом, Rectified Flows «распутывают» пересечение, индуцируя новые детерминистические пары \( Z_0 \) и \( Z_1. \)
We can view the linear interpolation \( Z_t \) as building roads (or tunnels) to connect \( π_0 \) and \( π_1 \), and the rectified flow as traffics of particles passing through the roads in a myopic, memoryless, non-crossing way, which allows them to ignore the global path information of how \( Z_0 \)\( Z_1 \) are paired, and rebuild a more deterministic pairing of \( (X_0, X_1) \).
Транспортная стоимость
Таким образом, мы получаем непересекающиеся траектории, которые в результате дают нам пары из двух распределений. При этом маргинальные распределения в любой момент времени t у таких траекторий совпадут с маргинальными распределениями интерполянтов.
Авторы отмечают ещё одну важную особенность результатов: для любой выпуклой функции \( c \) транспортная стоимость \( Z_0 → Z_1 \) меньше, чем \( X_0 → X_1 \):
\( \mathbb{E}[c(Z_1 — Z_0)] \leq \mathbb{E}[c(X_1 — X_0)] \)
Итак, мы научились строить векторное поле, которое переведёт одно распределение в другое, и получать при этом непересекающиеся траектории с меньшими транспортными затратами.
Ускорение диффузии 🚀
Ура, мы вплотную подошли к ускорению диффузионки! Но сначала ещё немного поговорим про потоки, а именно — про их «выпрямление».
Как уже было сказано, после обучения векторного поля мы решаем ODE.
Обычно для этого используются различные солверы, например, схема Эйлера:
\( Z_{t + \frac{1}{N}} = Z_t + \frac{1}{N} v(Z_t, t)\text{, } \forall t \in \{0, …, N-1\} / N \)
Важно: чем больше N, тем точнее решение.
Если траектория решения прямая — даже при одном шаге можно получить достаточно точное решение.
Reflow
Reflow — итеративная процедура «выпрямления» траекторий начального и конечного распределений без изменения маргинальных распределений путём ректификации.
Обученное нами векторное поле в результате решения ODE выдаст новую пару \( (Z_0, Z_1) \), индуцированную из \( (X_0, X_1) \).
Итак, у нас есть новые пары, давайте сейчас снова построим ректифицированный поток, минимизируя функцию потерь:
\( {L} = min_v \mathbb{E}_{(Z_0, Z_1)} \Big[\int_0^1 \|(Z_1 — Z_0) — v(Z_t, t) \|^2dt \Big] \)
Получим новую пару \( (Z_0^1, Z_1^1) \), повторим построение, индуцируем новые пары \( (Z_0^2, Z_1^2) \) и так далее. Тогда на \( k \) шаге мы получим \( k \)-ый ректифицированный поток:
\( Z^{k+1} = \text{Rectflow}((Z_0^k, Z_1^k)) \)
Ключевые особенности Reflow:
- \( v_{k+1} \) переносит \( \pi_0 \) в \( \pi_1 \), если переносит \( v_k \);
- траектории решения у ODE[\( v_{k+1} \)] более прямые, чем у ODE[\( v_{k} \)];
- (\( X_0 \), ODE[\( v_{k+1}(X_0) \)]) образуют более удачные пары, чем (\( X_0 \), ODE[\( v_{k}(X_0) \)]), что даёт меньшую стоимость переноса для любых выпуклых функций стоимости \( c \).
Авторы утверждают — достаточно всего нескольких (двух-трёх) итераций Reflow для получения достаточно прямых траекторий.
Instaflow
Наконец, мы с вами переходим к основной теме обзора — статье Instaflow.
Она объединяет две идеи:
- дистилляция;
- ректификация.
Распространённый подход ускорения диффузионных моделей — дистилляция.
Дистилляция — передача знаний «учителя» (предобученной диффузионной модели) «ученику», который может генерировать за меньшее количество шагов.
В идеале вместо генерации за $n$ шагов мы обучаем с помощью дистилляции генерировать за $1$ шаг. Если сделать это напрямую — результаты окажутся неудовлетворительными. Но если «выпрямить» траектории генерации (что и предлагается в работе Instaflow) — после дистилляции в один шаг у сгенерированных изображений будет достаточно хорошее качество.
Выпрямление
Подобное «выпрямление» производится за счёт итеративной ректификации, которая с каждой последующей итерацией даёт всё более прямую траекторию. Поскольку полностью прямая траектория требует бесконечного числа итераций, авторы предлагают сделать приблизительно прямые траектории за пару итераций, а затем провести дистилляцию.
Дистилляция
После относительно прямых траекторий генерации для получения диффузионной модели, способной генерировать в один шаг, необходимо провести дистилляцию и обучить один шаг Эйлера вместо нескольких итераций решения дифференциального уравнения на k шаге ректификации:
\( \tilde{v}k = \text{arg min}v \mathbb{E}{X_0 \sim \pi_0, \Tau \sim D{\Tau}} \Big[\mathbb{D}(\text{ODE}[v_k](X_0, \Tau), X_0 + v(X_0|\Tau))\Big] \)
где \( D(\cdot, \cdot) \) — similarity loss, например, LPIPS.
Весь алгоритм InstaFlow можно представить следующим образом:
Результаты
Авторы использовали в качестве начальной модели предобученную Stable Diffusion. Затем они прошли несколько этапов ректификации и провели эксперименты с последующей дистилляцией.
Ниже будут представлены некоторые результаты из оригинальной статьи:
Результаты в картинках
Стоит отметить: если дистилляция сразу применяется к диффузионной модели, она не даёт хороших результатов. Однако двух итераций ректификации будет достаточно для получения качественной дистилляции в один шаг 🙂.
Мы видим неплохие результаты работы. Стоит отметить: одной ректификации нам не хватает, значительное улучшение вносит именно дистилляция. Однако, применив эти две идеи (причем ректификацию всего пару раз), можно получить отличные изображения за несколько шагов!
Результаты в метриках
Summary
Итак, мы с вами обсудили довольно много теории, порой не очень очевидной. Предлагаем сделать саммари разобранного материала, чтобы точно его закрепить 🙂
Задача — получить диффузионную модель, способную генерировать в 1 шаг. Прямая дистилляция предобученной диффузионной модели в диффузионную модель, генерирующую за 1 шаг, не даёт удовлетворительных результатов:
Но если прежде, чем проводить дистилляцию, добавить «выпрямление» траекторий — результат будет лучше:
За теорией «выпрямления» стоит теория flow matching’а моделей. Она показывает, как можно обучить векторное поле для перевода одного распределения в другое.
Полезные ссылки
- Введение в диффузионные модели
- Лекции Дмитрия Ветрова по диффузионным моделям
- Rectified Flow
- InstaFlow