Назад
160

Mixture-of-Head Attention (MoH): новый подход к Multi-Head Attention

160

Введение

Multi-head attention

Multi-head attention — ключевая часть архитектуры transformer. Его основная идея — параллельно задействовать несколько «голов внимания» (attention heads), каждая из которых может фокусироваться на разных аспектах входных данных.

Каждая голова независимо вычисляет self-attention, используя собственные матрицы keys, queries и values. После этого результаты всех голов объединяются и проходят через линейный слой.

Почему Multi-head attention так эффективен? Важную роль играет то, что отдельные головы могут изучать уникальные паттерны:

  • одни головы могут фокусироваться на локальной структуре (например, в статье нашли голову, активную при появлении редких токенов);
  • другие — на глобальном контексте (в той же статье нашли головы, которые более, чем в 90% случаев обращают максимальное внимание на следующий или предыдущий токен, а также головы, реагирующие на связи «существительное-глагол»).
Mixture-of-Experts

Mixture of Experts — это архитектурный подход, который позволяет выбирать, какие «эксперты» (части модели) будут участвовать в обработке каждого конкретного примера. Для каждого входного примера задействуется только часть экспертов, это снижет количество вычислений нагрузку при сохранении хороших метрик.

Архитектура MoE делится на две ключевые составляющие:

  • Эксперты — это отдельные подмодули, например, feed-forward сети, которые обучаются распознавать различные признаки
  • Роутер — выбирает несколько релевантных экспертов к входным данным

В современных LLM MoE используется для того, чтобы разные эксперты специализировались на различных аспектах текста: одни на грамматических конструкциях, другие на понимании контекста, третьи на стиле предложения.

Архитектура Mixture-of-Head Attention

Рисунок 1. Схемы для стандартного Multi-Head Attention (a) и предложенного MoH Attention (b)

Multi-Head Attention обычно записывается через конкатенацию attention каждой отдельной головы, назовём такую запись concatenation form:

\begin{align*} \text{MultiHead}(X, X’) &= \text{Concat}(H^1, H^2, \dots, H^h) W_O, \\ H^i &= \text{Attention}(X W_Q^i, X’ W_K^i, X’ W_V^i), \\ \text{Attention}(Q, K, V) &= \text{Softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right)V \end{align*}

Размерности матриц в уравнении выше:

\begin{align*} Q &= X W_Q^i, & W_Q^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ K &= X’ W_K^i, & W_K^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ V &= X’ W_V^i, & W_V^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_v}{h}}, \\ X &\in \mathbb{R}^{n \times d_{\mathrm{in}}}, & H^i &\in \mathbb{R}^{n \times \frac{d_v}{h}}. \end{align*}

\( d_\text{in}: \) input feature dimension,
\( d_k: \) query/key projection dimension,
\( d_v: \) value projection dimension,
\( n: \) num input tokens,
\( h: \) num heads.

Но при этом мы можем записать матрицу \( W_O \) как конкатенацию \( h \) строк:

\( \begin{align*} \begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = W_O, \quad W_O \in \mathbb{R}^{d_v \times d_\text{out}}, \quad W_O^i \in \mathbb{R}^{\frac{d_v}{h} \times d_\text{out}} \\ \end{align*} \)

И в итоге получить новую форму для записи Multi-Head Attention через сумму, назовём её summation form:

\( \begin{align*} \text{MultiHead}(X, X’) = \text{Concat}(H^1, H^2, \dots, H^h) W_O = \\ = [H^1, H^2, \dots, H^h] W_O = [H^1, H^2, \dots, H^h]\begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = \\ = H^1W_O^1+H^2W_O^2 + \dots + H^hW_O^h = \sum_{i=1}^h H^i W_O^i \end{align*} \)

Таким образом, мы представили Multi-Head Attention как сумму матриц, а дополнительных ограничений на свойства \( W_O \) не появилось — это просто две разные записи одного и того же выражения.

Теперь давайте перейдём к Mixture-of-Head Attention. Для этого заменим сумму в выражении выше на взвешенную сумму, то есть каждую матрицу будем умножать на скаляр \( g_i \), который равен нулю, только если i-я голова не выбрана в качестве «эксперта»:

\( \begin{align*} \text{MoH}(X, X’) = \sum_{i=1}^h g_iH^i W_O^i \end{align*} \)

Кроме того, разделим головы на два типа — shared и routed. Shared будут использоваться всегда (то есть \( g_i \) ≠ 0 всегда для shared), и мы ожидаем, что они будут отвечать за какие-то «общие знания», которые нужны независимо от конкретной задачи и токена. Routed будут выбираться под каждый токен. То есть при прогоне через модель будут использоваться все shared и только часть routed heads. Для простоты пусть первые \( h_s \) по нумерации голов — это shared, а остальные от \( h_s + 1 \) до \( h \) — routed.

Чтобы проранжировать routed heads по необходимости её применения для токена, нам нужен Router, который выдаст для каждой головы скор её релевантности. Обычно для этого используется какая-то эвристика или небольшая сеть. Так и в этой работе: при добавлении MoH в предобученную модель применяли \( l_2 \) норму от query каждой головы, при обучении модели с нуля добавляли отдельную сеть. Выбирать будем k голов с самым высоким скором от Router.

Зафиксируем всё, что знаем на этом этапе, про коэффициент \( g_i \) в формуле MoH:

\( g_i = \begin{cases} \text{something}_i, & \text{if } 1 \leq i \leq h_s, \\ \text{something}_i, & \text{if } h_s + 1 \leq i \leq h \text{ and } i \text{ in Top-K}, \\ 0, & \text{otherwise}, \end{cases} \)

В формуле выше записано, что для всех shared и Top-K routed heads у нас есть какой-то коэффициент, неравный нулю, для остальных routed \( g_i = 0 \). Осталось определить, как получить эти ненулевые коэффициенты.

Для эмбеддинга каждого токена \( x_t, \ 1 <= t <= h \) их получают так:

\( g_i = \begin{cases} \alpha_1 \text{Softmax}(W_s x_t)_i, & \text{if } 1 \leq i \leq h_s, \\ \alpha_2 \text{Softmax}(W_r x_t)_i, & \text{if } (W_r x_t)_i \in \text{Top-K} \big(\{(W_r x_t)_i \,|\, h_s + 1 \leq i \leq h\}\big), \\ 0, & \text{otherwise}, \end{cases} \)

Теперь разберёмся, что написано выше 🙂. Для получения веса каждой головы мы заводим две матрицы \( W_s \in R^{h_s \times d_{in}}, W_r \in R^{(h-h_s) \times d_{in}} \) — для shared и routed соответственно, \( d_{in} \) — размер вектора \( x_t \). Тогда \( W_sx_t, W_rx_t \), — результат умножения матриц на входной токен, это вектора размером \( h_s \) и \( (h-h_s) \) соответственно. То есть для каждой головы мы получаем по одному числу на выходе. В коде такие матрицы можно записать через два независимых линейных слоя Linear(\( d_{in} \), \( h_s \), bias=False) и Linear(\( d_{in} \), \( h — h_s \), bias=False). После этого применяем softmax, чтобы сумма чисел для каждого из двух наборов голов равнялась единице.

Теперь мы получили индивидуальный вес для каждой головы \( \text{Softmax}(W_s x_t)_i \) для shared или \( \text{Softmax}(W_r x_t)i \) для routed. Кроме того, авторы ищут баланс между использованием shared и routed heads, то есть учитывают тип головы. Это делается через умножение индивидуального веса головы на скаляры \( \alpha_1, \alpha_2 \). Их мы получаем с помощью дополнительной обучаемой матрицы \( W_h \in R^{2 \times d{in}} \), \( [\alpha_1, \alpha_2] = \text{Softmax}(W_hx_t) \). Этот способ выбора весов назвали two-stage routing, где под первой стадией обозначили получение индивидуального веса, а под второй — балансирование между двумя типами голов через \( \alpha_1, \alpha_2 \).

Применение

Во-первых, можно с нуля обучать модель с MoH attention. Основным отличием в обучении будет необходимость добавить Load Balance Loss, как это делают при обучении MoE. Он нужен для того, чтобы Router более равномерно распределял задачи между экспертами, и не было экспертов, получавших слишком мало задач и потому недоучившихся до хорошего качества.

\( \mathcal{L} = \mathcal{L}_{task} + \beta\mathcal{L}_b, \quad\beta=0.01 \\ \mathcal{L}_b = \sum_{i=h_s+1}^{h} f_i P_i, \\ f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}(\text{Token } \mathbf{x}_t \text{ selects Head } i), \\ P_i = \frac{1}{T} \sum_{t=1}^{T} \text{Softmax}(W_r \mathbf{x}_t)_i \)

Итоговый лосс \( \mathcal{L} \) представлен на формуле выше. Он является взвешенной суммой лосса для конкретной задачи \( \mathcal{L}_{task} \) (например, кросс-энтропия для классификации) и Load Balance Loss \( \mathcal{L}_b \). Послений считается только по routed heads и является суммой произведений двух величин для кажой routed головы:

  • \( f_i \) — доля токенов, которые выбрали i-ю голову, чем чаще токены выбирают конкретную голову, тем выше \( f_i \);
  • \( P_i \) — средний индивидуальный вес i-й головы.

Получается, что лосс тем больше, чем чаще мы выбираем конкретную голову и чем больше при этом у неё скор. То есть для его минимизации нужно как можно больше задействованных разных голов, при этом распределение индивидуального веса должно быть как можно ближе к равномерному. Это и позволяет обучать все головы, а не часть.

Других важных изменений для обучения с нуля нет, метрики для сравнения с vanila attention представлены в следующем разделе.

Во-вторых, можно дообучать существующую модель с заменой Multi-Head Atttention на Mixture-of-Head Attention. Тогда нужно учесть следующее:

  • выбрали shared attention heads. В статье просто взяли первые 16 голов в каждом слое, но головы и так могут иметь специализацию — есть пространство для экспериментов;
  • в изначальной сети не было Router, а обучение всех сетей с рандомно инициализированным Router могло бы занять много времени и испортить предобучение остальных блоков модели, поэтому для выбора routed heads использовали \( l_2 \) норму от query каждой головы. Интересно было бы посмотреть на эксперименты с обучением Router при замороженных весах или на что-то подобное;
  • кроме выбора головы Router должен присваивать каждой из них скор \( g_i \). При этом важно учитывать, как именно весовые коэффициенты \( g_i \) влияют на выход слоя внимания. Если мы напрямую умножим выходы голов внимания на непрерывные веса, распределение активаций может существенно «поплыть». Чтобы этого избежать, в работе используют идею квантуемых (двоичных) скоров и straight-through estimator (STE):
    • при forward pass используется индикаторная функция, которая даёт 0 или 1 в зависимости от выбора головы;
    • при backward pass — STE, чтобы пропустить градиент через пороговое квантование и избежать его полного обнуления.

Это позволяет не разрушать уже обученные параметры модели и постепенно адаптироваться к новой механике маршрутизации.

Итак, с MoH attention можно и учить модели с нуля, добавив к обучению Router и новый лосс, и файнтюнить существующие модели.

Метрики

Теперь посмотрим на метрики для разных задач: классификации, генерации изображений и текста. Оценили их на Vision Transformers (ViT), Diffusion models with Transformers (DiT), Large Language Models (LLMs). Таблицы ниже отражают результаты:

Рисунок 2. Метрики на задаче классификации изображений: с 75% голов можно получить качество выше, с 50% — немного ниже
Рисунок 3. Метрики на задаче генерации изображений
Рисунок 4. Метрики на задаче генерации изображений

Видно, что для генерации изображений нужно сохранить больше голов, чем в задаче классификации (75% для классификации и 90% для генерации). Возможно, это из-за того что нужна пиксельная точность, жаль что в статье не рассмотрели задачу сегментации.

Рисунок 5. Замер качества обученной с нуля на нескольких бенчмарках для LLM

Иногда 50% голов дают результат лучше, чем 75% и 100%. Как объясняют это авторы TruthfulQA, причина в следующем: «если ложные ответы изучаются из обучающего распределения, то ожидается, что более крупные модели, лучше его изучившие, будут чаще генерировать такие ложные ответы».

Рисунок 6. Замер качества LLaMA3-8B после файнтюнинга с MoH

Можно отметить, что ~75% голов должно хватать для большинства задач. В статье есть все гиперпараметры экспериментов — очень радует, что таких статей становится всё больше 🙂

Полезные компоненты архитектуры

Авторы провели ablative analysis и проверили, какие компоненты являются самыми полезными:

Рисунок 7. Метрики классификации и генерации с shared heads, two-stage routing и без них (сверху) и влияние доли shared heads при фиксированной доле активных routed heads на метрики (снизу)

Мы видим, что добавление shared heads заметно улучшает результаты. Авторы объясняют это тем, что теперь за общие знания отвечают shared heads, значит, routed могут стать более специализированными на конкретном домене. Two-stage routing улучшает результаты — теперь баланс между использованием shared и routed динамически подбирается и может оказаться лучше фиксированного. Two-stage routing нельзя сделать, если все головы routed, потому что нет смысла подбирать параметры \( \alpha_1, \alpha_2 \), так что такого варианта в таблице нет.

При этом не так важна доля shared heads, если она не экстремально большая или низкая. Авторы рекомендуют брать 40% или больше.

Заключение

В Mixture-of-Head Attention мы рассматриваем каждую голову внимания как «эксперта» и используем Router для выбора наиболее релевантных голов под конкретный токен, а Two-Stage Routing создаёт баланс между использованием routed и shared heads. Такой подход:

  • уменьшает вычисления за счёт пропуска частей MHA без существенного падения качества;
  • улучшает специализацию отдельных голов, в том числе благодаря разделению на shared и routed;
  • может быть использован и при обучении с нуля, и при fine-tuning.

Но есть и пара направлений для развития:

  • необходимость прода: в статье нет таблицы со сравнением по скорости;
  • ограничения по размеру и задачам: не исследовали LLM больше 8B, для CV можно было бы взять задачу сегментации и детекции;
  • выбор доли активных голов: в статье тестировали фиксированные значения (50%, 75%, 90%), но динамическое определение как в Dynamic-MoE может быть более полезным в замене attention;
  • мультимодальность: пока нет полноценной проверки на данных, где требуется анализ нескольких типов сигналов — как будто здесь есть большой потенциал для сокращения числа голов за счёт специализации на разных типах.

Таким образом, MoH выглядит как удачное сочетание идей MHA и Mixture-of-Experts, упрощающее использование больших моделей и позволяющее более гибко распределять вычислительные ресурсы. Все параметры обучения открыты и чётко задокументированы, что является значимым вкладом для последующих экспериментов.

Полезные ссылки

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!