Назад
66

DISTIL-WHISPER: Robust knowledge distillation via large-scale pseudo labeling

66

Введение

В 2022 году вышла модель от OpenAI, которая, по их заявлениям, решила задачу Speech2Text. Хоть у этого заявления и были противники, модель хорошо себя зарекомендовала. Я ей пользовался, когда нужно было сделать s2t «нестандартных» текстов (обзор новостей из мира игр).

И спустя год Hugging Face выкатили более лёгкую версию, которую назвали Distil-Whisper. Авторы говорят, что им удалось ускорить модель в 6 раз, подрезав параметры на 50% и потеряв всего 1% WER на out of distribution (OOD) тестовых данных. В этой статье мы разберёмся, как у них получился такой результат.

Distil-Whisper: основные составляющие

В качестве основных компонентов Distil-Whisper можно выделить следующие:

  1. Грамотная инициализация модели-ученика;
  2. Дистилляция по словам (авторы это называют word-level KD);
  3. Дистилляция не только по словам, но и по наиболее вероятной последовательности, сгенерированной учителем (авторы называют это sequence-level KD и псевдо-лейбеллинг);
  4. Фильтрация предсказаний, полученных при помощи псевдо-лейбеллинга;
  5. Ускорение самой большой модели и только декодера.

Кратко про архитектуру

Whisper — это Seq2seq модель, енкодер и декодер которой являются трансформерами. На вход в енкодер ей подаются мел спектрограммы, а токены предсказываются уже в авторегрессионном режиме в декодере.

Вся вычислительная сложность сосредоточена именно в декодере, который гоняется в “цикле for”. Его и оптимизируют авторы. Отсюда и получается ускорение x6 при выбрасывании 50% параметров.

Рисунок 1. Архитектура сети Whisper

Грамотная инициализация

Авторы инициализируют модель-ученика весами учителя, максимально разнесенными по архитектуре.

В качестве примера рассмотрим такую ситуацию: есть модель из 6-и слоев, мы хотим натренировать модель из 3-х слоев. Для этого берем 0, 3, 5 (первый и последний включены, а по середине опционален). Если из 2-х слоев, то 0 и 5 (максимально разнесенные по архитектуре). Если надо взять всего 1 слой, то берем 0, не задумываясь. Это исследовалось в одной из предыдущих работ авторов. При этом саму конфигурацию слоёв мы не меняем.

KD по словам

Модель учится предсказывать текст токен за токеном. Поэтому вместо использования оригинальной разметки мы применяем предсказания учителя, которые имеют более сглаженную природу.

\( {L}_{\text{WORD-KD}} = — \sum_{j=1}^J \sum_{k=1}^{|\nu|} q(t_j = k | \mathbf{s}, \mathbf{t}_{<j}) \times {\log p(t_j = k | \mathbf{s}, \mathbf{t}_{<j})} \)

где q — предсказания учителя, p — предсказания ученика, \( |\nu| \) — количество токенов в словаре, J — количество токенов в целевом тексте. Жирным выделены векторные величины.

Обозначение \( p(t_j=k|\mathbf{s}, \mathbf{t}_j) \) следует читать как «вероятность того, что \( j \)–й токен равен \( k \) при условии того, что мы сгенерировали предыдущие \( j-1 \) токенов». То же самое и для вероятностей \( q \) учителя.

Sequence-level KD

Предсказывая по словам, мы не улавливаем «временну́ю» составляющую того, как работает учитель. Но её можно уловить, если обучать ученика предсказывать последовательности также, как это делает учитель.

\( {L}_{\text{SEQ-KD}} = — \sum_{\mathbf{t} \in {T}} q(\mathbf{t} | s) \log p(\mathbf{t} | s) \)

При этом суммирование тут идет по ВСЕМ последовательностям токенов \( \mathbf{t} \), поэтому оно отличается от \( {L}_{WORD-KD} \).

Пройтись по всем последовательностям токенов невозможно. По этой причине авторы предлагают считать только ту последовательность у учителя, которая имеет наибольшую вероятность. Но и её они не могут нормально посчитать, поэтому говорят: давайте выберем последовательность, которая сгенерирована бимсёрчем (от англ. beam search, поиск луча).

В своей предыдущей статье авторы объясняют это так: ну бимсерч же работает, ну чё вы)))

Фильтрация предсказаний

Давайте брать для Sequence-Level KD только те последовательности, которые максимально близки к оригинальной по WER (WER — количество ошибок в предложении). Действительно, зачем нам учиться на мусоре? В итоге получаем подход Sequence-Level Interpolation. Вот иллюстрация из их предыдущей статьи, на которой видны отличия всех 3 подходов.

Рисунок 2. Отличие подходов World-Level KD, Sequence-Level KD и Sequence-Level Interpolation

При World-Level KD минимизируется кросс-энтропия между распределениями учителя и ученика (на рисунке выше обозначено жёлтым) для каждого слова в последовательности. Помимо этого, минимизируется кросс-энтропия между распределением ученика истинными метками (обозначено как Ground Truth, чёрным цветом). При Sequence-Level KD сеть ученика обучается на beamsearch-выходе учителя с наибольшим скором. При Sequence-Level Interpolation ученик обучается на beamsearch-выходе учителя, который ближе всего к GT.

Вообще, автор Alexander Rush поступил интересно. Он много работал с обычными Seq2Seq моделями в плане языков, а потом применил свои наработки в этой статье. Забавно, что в NLP подход по фильтрации они мотивируют тем, что данных становится больше, считать долго, а тут на 20к часах учились и ничего, пойдёт 🙂

Ускорение самой большой модели и только декодера

На этом пункте мы остановимся подробнее. Каждый, кто прочитал по ускорению моделей хотя бы пару статей, знает: если взять VGG-19, то можно получить ускорение 90% на любом алгоритме (квантование, прунинг, дистилляция) с потерей 0% точности.

На мой взгляд, здесь похожая ситуация. Авторы взяли большую модель, у которой 20 голов и размер ембеддингов 1280. А ещё енкодер авторы не ускоряют, что тоже сыграет свою роль (мы в этом убедимся, когда будем смотреть ablation studies). В итоге архитектура Distil-Whisper выглядит так:

Рисунок 3. Арихтектура Distil-Whisper

Обработка длинных последовательностей

В самой Whisper была сложная схема с обработкой длинных последовательностей: там обрабатывались 30-секундные последовательности, в которых сама Whisper предсказывала отступ до следующего фрагмента. Тут авторы решили просто предсказывать 30-секундные фреймы с перекрытием, а потом мёржить текст.

Speculative Decoding

Интересная техника, которую используют авторы. Суть в том, чтобы большая модель корректировала предсказания маленькой модели, если она их предсказала «не так». Ниже показано, как это выглядит. Зелёный цвет — это правильно предсказанные токены, красный — неправильно, а синий — корректировки. Количество запусков процедуры равняется количеству строк.

Рисунок 4. Алгоритм фильтрации

Если подробнее: у вас есть две модели seq2seq — большая и маленькая, \( M_{p}, M_{q} \). Пусть у вас есть некое стартовое предложение (prefix). Берём маленькую модельку, предсказываем \( \gamma \) токенов авторегрессионным способом с вероятносятми \( q_{1}, q_{2}, q_{3}… q_{\gamma} \). Обозначаем их как \( x_{1}, x_{2}, x_{3}… x_{\gamma} \). Затем берём большую модель и предсказываем в ПАРАЛЛЕЛЬ (по сути только один символ) для каждой из последовательностей \( [prefix,x_{1}]; [prefix,x_{1:2}], [prefix,x_{1:3}], [prefix,x_{1:\gamma-1}] \) с вероятностями \( p_1, p_2 … p_\gamma \). Далее на основе процедуры ниже находим первый символ, на котором ваша маленькая моделька «ошиблась» (если простым языком, то это первый, у которого \( q_{i} > p_{i} \)). И затем корректируем предсказания данного символа при помощи вероятностного алгоритма ниже:

Рисунок 5. Алгоритмическое описание метода Speculative Decoding

В чём же профит? В том, что большая модель работает не в авторегрессионном режиме, а значит, предсказание можно загнать в батч, и все будет работать быстрее.

Результаты

Ниже представлена таблица, где авторы показали результаты работы:

Рисунок 6. Сравнение Distil-Whisper и обычного Whisper

Они также проверили, что speculative decoding добрасывает в скорости x2 по отношению к large-v2. Табличка по ускорениям представлена ниже.

Рисунок 7. Влияние Speculative Decoding

Ablation studies

Фильтрация по WER

Тут авторы наглядно продемонстрировали принцип GIGO (Garbage in, garbage out). Если ничего не фильтровать — результаты получаются так себе. При этом, если фильтровать слишком много — видимо, dark knowledge не будет хватать. Порога для WER, равного 40, оказалось достаточно (что, кстати, немного удивительно, ведь WER=40 это много).

Рисунок 8. Тестирование важности выбора порога по WER для фильтрации данных

Обучающая выборка

Окей, а сколько нам надо данных? Авторы тоже это проверили.

Рисунок 9. Влияние размера датасета на качество модели

Выяснили следующее: при In Domain распознавании на выборке до 1700 качество падает значительно, а вот для Out of Domain после 3к часов WER практически не падает. Видимо, не хватает, всё-таки, обобщающей способности.

Размеры модели

А здесь подтверждение того, что непрунящийся енкодер решает многое. Если подрезать его в 2 раза — итоговый WER сразу подскакивает на 3%.

Рисунок 10. Влияние размера енкодера на итоговый WER

Выводы

В целом, статья классная. Авторы пожали достаточно сильно свою сетку, не потеряв в качестве, и провели хороший анализ результатов. Лично я для себя вынес технику speculative decoding в качестве нестандартного способа ускорения. Особенно понравилось, что все вопросы, которые у меня возникли во время чтения, авторы проверили при анализе результатов.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!