Mixture-of-Head Attention (MoH): новый подход к Multi-Head Attention
Введение
Multi-head attention
Multi-head attention — ключевая часть архитектуры transformer. Его основная идея — параллельно задействовать несколько «голов внимания» (attention heads), каждая из которых может фокусироваться на разных аспектах входных данных.
Каждая голова независимо вычисляет self-attention, используя собственные матрицы keys, queries и values. После этого результаты всех голов объединяются и проходят через линейный слой.
Почему Multi-head attention так эффективен? Важную роль играет то, что отдельные головы могут изучать уникальные паттерны:
- одни головы могут фокусироваться на локальной структуре (например, в статье нашли голову, активную при появлении редких токенов);
- другие — на глобальном контексте (в той же статье нашли головы, которые более, чем в 90% случаев обращают максимальное внимание на следующий или предыдущий токен, а также головы, реагирующие на связи «существительное-глагол»).
Mixture-of-Experts
Mixture of Experts — это архитектурный подход, который позволяет выбирать, какие «эксперты» (части модели) будут участвовать в обработке каждого конкретного примера. Для каждого входного примера задействуется только часть экспертов, это снижет количество вычислений нагрузку при сохранении хороших метрик.
Архитектура MoE делится на две ключевые составляющие:
- Эксперты — это отдельные подмодули, например, feed-forward сети, которые обучаются распознавать различные признаки
- Роутер — выбирает несколько релевантных экспертов к входным данным
В современных LLM MoE используется для того, чтобы разные эксперты специализировались на различных аспектах текста: одни на грамматических конструкциях, другие на понимании контекста, третьи на стиле предложения.
Архитектура Mixture-of-Head Attention
Multi-Head Attention обычно записывается через конкатенацию attention каждой отдельной головы, назовём такую запись concatenation form:
Размерности матриц в уравнении выше:
\begin{align*} Q &= X W_Q^i, & W_Q^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ K &= X’ W_K^i, & W_K^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ V &= X’ W_V^i, & W_V^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_v}{h}}, \\ X &\in \mathbb{R}^{n \times d_{\mathrm{in}}}, & H^i &\in \mathbb{R}^{n \times \frac{d_v}{h}}. \end{align*}
\( d_\text{in}: \) input feature dimension,
\( d_k: \) query/key projection dimension,
\( d_v: \) value projection dimension,
\( n: \) num input tokens,
\( h: \) num heads.
Но при этом мы можем записать матрицу \( W_O \) как конкатенацию \( h \) строк:
\( \begin{align*} \begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = W_O, \quad W_O \in \mathbb{R}^{d_v \times d_\text{out}}, \quad W_O^i \in \mathbb{R}^{\frac{d_v}{h} \times d_\text{out}} \\ \end{align*} \)
И в итоге получить новую форму для записи Multi-Head Attention через сумму, назовём её summation form:
\( \begin{align*} \text{MultiHead}(X, X’) = \text{Concat}(H^1, H^2, \dots, H^h) W_O = \\ = [H^1, H^2, \dots, H^h] W_O = [H^1, H^2, \dots, H^h]\begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = \\ = H^1W_O^1+H^2W_O^2 + \dots + H^hW_O^h = \sum_{i=1}^h H^i W_O^i \end{align*} \)
Таким образом, мы представили Multi-Head Attention как сумму матриц, а дополнительных ограничений на свойства \( W_O \) не появилось — это просто две разные записи одного и того же выражения.
Теперь давайте перейдём к Mixture-of-Head Attention. Для этого заменим сумму в выражении выше на взвешенную сумму, то есть каждую матрицу будем умножать на скаляр \( g_i \), который равен нулю, только если i-я голова не выбрана в качестве «эксперта»:
\( \begin{align*} \text{MoH}(X, X’) = \sum_{i=1}^h g_iH^i W_O^i \end{align*} \)
Кроме того, разделим головы на два типа — shared и routed. Shared будут использоваться всегда (то есть \( g_i \) ≠ 0 всегда для shared), и мы ожидаем, что они будут отвечать за какие-то «общие знания», которые нужны независимо от конкретной задачи и токена. Routed будут выбираться под каждый токен. То есть при прогоне через модель будут использоваться все shared и только часть routed heads. Для простоты пусть первые \( h_s \) по нумерации голов — это shared, а остальные от \( h_s + 1 \) до \( h \) — routed.
Чтобы проранжировать routed heads по необходимости её применения для токена, нам нужен Router, который выдаст для каждой головы скор её релевантности. Обычно для этого используется какая-то эвристика или небольшая сеть. Так и в этой работе: при добавлении MoH в предобученную модель применяли \( l_2 \) норму от query каждой головы, при обучении модели с нуля добавляли отдельную сеть. Выбирать будем k голов с самым высоким скором от Router.
Зафиксируем всё, что знаем на этом этапе, про коэффициент \( g_i \) в формуле MoH:
\( g_i = \begin{cases} \text{something}_i, & \text{if } 1 \leq i \leq h_s, \\ \text{something}_i, & \text{if } h_s + 1 \leq i \leq h \text{ and } i \text{ in Top-K}, \\ 0, & \text{otherwise}, \end{cases} \)
В формуле выше записано, что для всех shared и Top-K routed heads у нас есть какой-то коэффициент, неравный нулю, для остальных routed \( g_i = 0 \). Осталось определить, как получить эти ненулевые коэффициенты.
Для эмбеддинга каждого токена \( x_t, \ 1 <= t <= h \) их получают так:
\( g_i = \begin{cases} \alpha_1 \text{Softmax}(W_s x_t)_i, & \text{if } 1 \leq i \leq h_s, \\ \alpha_2 \text{Softmax}(W_r x_t)_i, & \text{if } (W_r x_t)_i \in \text{Top-K} \big(\{(W_r x_t)_i \,|\, h_s + 1 \leq i \leq h\}\big), \\ 0, & \text{otherwise}, \end{cases} \)
Теперь разберёмся, что написано выше 🙂. Для получения веса каждой головы мы заводим две матрицы \( W_s \in R^{h_s \times d_{in}}, W_r \in R^{(h-h_s) \times d_{in}} \) — для shared и routed соответственно, \( d_{in} \) — размер вектора \( x_t \). Тогда \( W_sx_t, W_rx_t \), — результат умножения матриц на входной токен, это вектора размером \( h_s \) и \( (h-h_s) \) соответственно. То есть для каждой головы мы получаем по одному числу на выходе. В коде такие матрицы можно записать через два независимых линейных слоя Linear(\( d_{in} \), \( h_s \), bias=False) и Linear(\( d_{in} \), \( h — h_s \), bias=False). После этого применяем softmax, чтобы сумма чисел для каждого из двух наборов голов равнялась единице.
Теперь мы получили индивидуальный вес для каждой головы \( \text{Softmax}(W_s x_t)_i \) для shared или \( \text{Softmax}(W_r x_t)i \) для routed. Кроме того, авторы ищут баланс между использованием shared и routed heads, то есть учитывают тип головы. Это делается через умножение индивидуального веса головы на скаляры \( \alpha_1, \alpha_2 \). Их мы получаем с помощью дополнительной обучаемой матрицы \( W_h \in R^{2 \times d{in}} \), \( [\alpha_1, \alpha_2] = \text{Softmax}(W_hx_t) \). Этот способ выбора весов назвали two-stage routing, где под первой стадией обозначили получение индивидуального веса, а под второй — балансирование между двумя типами голов через \( \alpha_1, \alpha_2 \).
Применение
Во-первых, можно с нуля обучать модель с MoH attention. Основным отличием в обучении будет необходимость добавить Load Balance Loss, как это делают при обучении MoE. Он нужен для того, чтобы Router более равномерно распределял задачи между экспертами, и не было экспертов, получавших слишком мало задач и потому недоучившихся до хорошего качества.
\( \mathcal{L} = \mathcal{L}_{task} + \beta\mathcal{L}_b, \quad\beta=0.01 \\ \mathcal{L}_b = \sum_{i=h_s+1}^{h} f_i P_i, \\ f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}(\text{Token } \mathbf{x}_t \text{ selects Head } i), \\ P_i = \frac{1}{T} \sum_{t=1}^{T} \text{Softmax}(W_r \mathbf{x}_t)_i \)
Итоговый лосс \( \mathcal{L} \) представлен на формуле выше. Он является взвешенной суммой лосса для конкретной задачи \( \mathcal{L}_{task} \) (например, кросс-энтропия для классификации) и Load Balance Loss \( \mathcal{L}_b \). Послений считается только по routed heads и является суммой произведений двух величин для кажой routed головы:
- \( f_i \) — доля токенов, которые выбрали i-ю голову, чем чаще токены выбирают конкретную голову, тем выше \( f_i \);
- \( P_i \) — средний индивидуальный вес i-й головы.
Получается, что лосс тем больше, чем чаще мы выбираем конкретную голову и чем больше при этом у неё скор. То есть для его минимизации нужно как можно больше задействованных разных голов, при этом распределение индивидуального веса должно быть как можно ближе к равномерному. Это и позволяет обучать все головы, а не часть.
Других важных изменений для обучения с нуля нет, метрики для сравнения с vanila attention представлены в следующем разделе.
Во-вторых, можно дообучать существующую модель с заменой Multi-Head Atttention на Mixture-of-Head Attention. Тогда нужно учесть следующее:
- выбрали shared attention heads. В статье просто взяли первые 16 голов в каждом слое, но головы и так могут иметь специализацию — есть пространство для экспериментов;
- в изначальной сети не было Router, а обучение всех сетей с рандомно инициализированным Router могло бы занять много времени и испортить предобучение остальных блоков модели, поэтому для выбора routed heads использовали \( l_2 \) норму от query каждой головы. Интересно было бы посмотреть на эксперименты с обучением Router при замороженных весах или на что-то подобное;
- кроме выбора головы Router должен присваивать каждой из них скор \( g_i \). При этом важно учитывать, как именно весовые коэффициенты \( g_i \) влияют на выход слоя внимания. Если мы напрямую умножим выходы голов внимания на непрерывные веса, распределение активаций может существенно «поплыть». Чтобы этого избежать, в работе используют идею квантуемых (двоичных) скоров и straight-through estimator (STE):
- при forward pass используется индикаторная функция, которая даёт 0 или 1 в зависимости от выбора головы;
- при backward pass — STE, чтобы пропустить градиент через пороговое квантование и избежать его полного обнуления.
Это позволяет не разрушать уже обученные параметры модели и постепенно адаптироваться к новой механике маршрутизации.
Итак, с MoH attention можно и учить модели с нуля, добавив к обучению Router и новый лосс, и файнтюнить существующие модели.
Метрики
Теперь посмотрим на метрики для разных задач: классификации, генерации изображений и текста. Оценили их на Vision Transformers (ViT), Diffusion models with Transformers (DiT), Large Language Models (LLMs). Таблицы ниже отражают результаты:
Видно, что для генерации изображений нужно сохранить больше голов, чем в задаче классификации (75% для классификации и 90% для генерации). Возможно, это из-за того что нужна пиксельная точность, жаль что в статье не рассмотрели задачу сегментации.
Иногда 50% голов дают результат лучше, чем 75% и 100%. Как объясняют это авторы TruthfulQA, причина в следующем: «если ложные ответы изучаются из обучающего распределения, то ожидается, что более крупные модели, лучше его изучившие, будут чаще генерировать такие ложные ответы».
Можно отметить, что ~75% голов должно хватать для большинства задач. В статье есть все гиперпараметры экспериментов — очень радует, что таких статей становится всё больше 🙂
Полезные компоненты архитектуры
Авторы провели ablative analysis и проверили, какие компоненты являются самыми полезными:
Мы видим, что добавление shared heads заметно улучшает результаты. Авторы объясняют это тем, что теперь за общие знания отвечают shared heads, значит, routed могут стать более специализированными на конкретном домене. Two-stage routing улучшает результаты — теперь баланс между использованием shared и routed динамически подбирается и может оказаться лучше фиксированного. Two-stage routing нельзя сделать, если все головы routed, потому что нет смысла подбирать параметры \( \alpha_1, \alpha_2 \), так что такого варианта в таблице нет.
При этом не так важна доля shared heads, если она не экстремально большая или низкая. Авторы рекомендуют брать 40% или больше.
Заключение
В Mixture-of-Head Attention мы рассматриваем каждую голову внимания как «эксперта» и используем Router для выбора наиболее релевантных голов под конкретный токен, а Two-Stage Routing создаёт баланс между использованием routed и shared heads. Такой подход:
- уменьшает вычисления за счёт пропуска частей MHA без существенного падения качества;
- улучшает специализацию отдельных голов, в том числе благодаря разделению на shared и routed;
- может быть использован и при обучении с нуля, и при fine-tuning.
Но есть и пара направлений для развития:
- необходимость прода: в статье нет таблицы со сравнением по скорости;
- ограничения по размеру и задачам: не исследовали LLM больше 8B, для CV можно было бы взять задачу сегментации и детекции;
- выбор доли активных голов: в статье тестировали фиксированные значения (50%, 75%, 90%), но динамическое определение как в Dynamic-MoE может быть более полезным в замене attention;
- мультимодальность: пока нет полноценной проверки на данных, где требуется анализ нескольких типов сигналов — как будто здесь есть большой потенциал для сокращения числа голов за счёт специализации на разных типах.
Таким образом, MoH выглядит как удачное сочетание идей MHA и Mixture-of-Experts, упрощающее использование больших моделей и позволяющее более гибко распределять вычислительные ресурсы. Все параметры обучения открыты и чётко задокументированы, что является значимым вкладом для последующих экспериментов.
Полезные ссылки
- MoH: Multi-Head Attention as Mixture-of-Head Attention — статья
- Mixture of Attention Heads: Selecting Attention Heads Per To… — похожая статья 2022 года, но без two-stage routing, исследования применения к предобученным моделям и только на задаче машинного перевода