Flash Attention-2
Сегодня мы продолжим разбор оптимизаций механизма attention и обсудим FlashAttention-2. Эта версия даёт ускорение до 1.5–2× относительно FlashAttention-1 и в несколько раз превосходит стандартную реализацию attention в PyTorch. Разберём, за счёт каких изменений это достигается.
Ограничения FlashAttention-1
Как мы обсуждали в прошлом посте, ключевые улучшения первой версии — тайлинг матрицы внимания и онлайн-расчёт softmax — радикально снизили количество обращений к глобальной памяти GPU.
Однако ряд ограничений сохранялся. Значительная часть вычислений внутри CUDA-ядра была недостаточно распараллелена. Данные переиспользовались через shared memory, однако ядра загружались неравномерно. Несмотря на отсутствие необходимости материализовывать всю матрицу благодаря механизму тайлинга, появилась проблема постоянной синхронизации потоков (подробнее обсудим ниже при разборе архитектуры GPU), которая могла выполняться последовательно внутри блока потоков. Это создавало эффект «бутылочного горлышка».
Вспомним архитектуру GPU
Чтобы разобраться с улучшениями, нужно представить иерархию вычислений на видеокарте:
- Thread (поток) — минимальная единица выполнения. В задачах attention поток обычно участвует в вычислении фрагмента результата (например, блока матрицы), а сами операции выполняются с использованием векторизованных инструкций и tensor cores.
- Warp (варп) — группа из 32 потоков, которые выполняются синхронно. Если один поток в варпе ждёт данные из памяти, весь варп простаивает.
- Thread Block (блок потоков) — группа варпов, которые могут использовать общую быструю память (Shared Memory) и синхронизироваться друг с другом.
Во FlashAttention именно на уровне блоков происходит загрузка тайлов в Shared Memory. В первой версии внутри блока потоки часто ожидали друг друга, что снижало загрузку GPU и приводило к замедлению работы. Как именно вторая версия обошла это ограничение — разберём в следующих пунктах 🙂
Ключевая идея FlashAttention-2
Основным направлением улучшений стало увеличение параллелизма вычислений при сохранении тайлинга вычислений и без изменений самого механизма attention.
Улучшение tiling’а и появление work partitioning
Вторая версия алгоритма переработала схему разбивки работы между потоками. Теперь один блок потоков обрабатывает подматрицы, а не целые строки, как было в первой версии. Вычисления распределяются между несколькими варпами блока, где каждый warp отвечает за свою часть результата.
Такая параллелизация сокращает количество синхронизаций и уменьшает время ожидания результата между потоками. В результате утилизация GPU происходит более полно.
Благодаря новой схеме заполнения FA2 достигает 80-90% утилизации ****потоковых мультипроцессоров вместо 40-60% в FA1 на тех же задачах.
Параллельный softmax
Алгоритм softmax был адаптирован под новую схему параллелизма. Вместо того, чтобы одному потоку обрабатывать всю строку, вычисления разбиваются между несколькими потоками. Частичные максимумы и суммы считаются параллельно, а потом агрегируются на warp-уровне. Улучшение параллельности работы softmax также привело к снижению задержек в работе моделей.
Стоит отметить: авторы не выделяют улучшение softmax в отдельную метрику, так как он выполняется внутри единого CUDA-ядра. Однако они приводят данные о значительном снижении non-matmul FLOPs (операций, не являющихся матричным умножением). Если в первой версии FlashAttention такие вычисления становились «узким местом» из-за неоптимального параллелизма, то в FA-2 их вклад в общее время работы минимизировали.
Снижение накладных расходов
FlashAttention-2 также оптимизирует и другие операции для более эффективной обработки на GPU:
- Оптимизация регистров и борьба с Register Spilling. Регистры — самая быстрая память внутри видеокарты. Когда данных слишком много, они «выплёскиваются» (spilling) в более медленную память, что резко замедляет расчёты. В FA-2 перераспределили хранение данных так, чтобы всё нужное помещалось в регистры, сократив лишние обращения к памяти.
- Layout данных. Авторы изменили порядок расположения данных в памяти так, чтобы GPU мог считывать их максимально широкими «порциями» за один такт.
Также была улучшена поддержка различных размеров head dimension (теперь до 256 вместо 128) и batch size, что делает алгоритм универсальным для современных LLM.
Улучшение backward pass
Авторы переработали обратное распространение ошибки. Идея схожа с tiling в forward pass: вместо того, чтобы хранить полную матрицу весов внимания целиком, операции выполняются блочно. Это позволяет избежать материализации полной матрицы внимания в памяти.
Кроме того, за счёт лучшего распределения работы между потоками (аналогично forward-проходу), backward pass в FlashAttention-2 стал в 2–2.5 раза быстрее, чем в первой версии, практически сравнявшись по эффективности с прямым проходом.
Вывод
За счёт эффективного распараллеливания, уменьшения синхронизаций и улучшения tiling’а вторая версия достигает существенно большей производительности. Часто удаётся получить дополнительное ускорение относительно первой версии в 1.5-2 раза (рисунок). При этом сохраняются точность и линейная по памяти сложность. В итоге FlashAttention-2 продолжает подход совместного проектирования алгоритмов и архитектуры, показывая, что максимальная эффективность достигается только при совместной оптимизации алгоритма и архитектуры железа.

