DISTIL-WHISPER: Robust knowledge distillation via large-scale pseudo labeling
- Введение
- Distil-Whisper: основные составляющие
- Кратко про архитектуру
- Грамотная инициализация
- KD по словам
- Sequence-level KD
- Фильтрация предсказаний
- Ускорение самой большой модели и только декодера
- Обработка длинных последовательностей
- Speculative Decoding
- Результаты
- Ablation studies
- Фильтрация по WER
- Обучающая выборка
- Размеры модели
- Выводы
Введение
В 2022 году вышла модель от OpenAI, которая, по их заявлениям, решила задачу Speech2Text. Хоть у этого заявления и были противники, модель хорошо себя зарекомендовала. Я ей пользовался, когда нужно было сделать s2t «нестандартных» текстов (обзор новостей из мира игр).
И спустя год Hugging Face выкатили более лёгкую версию, которую назвали Distil-Whisper. Авторы говорят, что им удалось ускорить модель в 6 раз, подрезав параметры на 50% и потеряв всего 1% WER на out of distribution (OOD) тестовых данных. В этой статье мы разберёмся, как у них получился такой результат.
Distil-Whisper: основные составляющие
В качестве основных компонентов Distil-Whisper можно выделить следующие:
- Грамотная инициализация модели-ученика;
- Дистилляция по словам (авторы это называют word-level KD);
- Дистилляция не только по словам, но и по наиболее вероятной последовательности, сгенерированной учителем (авторы называют это sequence-level KD и псевдо-лейбеллинг);
- Фильтрация предсказаний, полученных при помощи псевдо-лейбеллинга;
- Ускорение самой большой модели и только декодера.
Кратко про архитектуру
Whisper — это Seq2seq модель, енкодер и декодер которой являются трансформерами. На вход в енкодер ей подаются мел спектрограммы, а токены предсказываются уже в авторегрессионном режиме в декодере.
Вся вычислительная сложность сосредоточена именно в декодере, который гоняется в “цикле for”. Его и оптимизируют авторы. Отсюда и получается ускорение x6 при выбрасывании 50% параметров.
Грамотная инициализация
Авторы инициализируют модель-ученика весами учителя, максимально разнесенными по архитектуре.
В качестве примера рассмотрим такую ситуацию: есть модель из 6-и слоев, мы хотим натренировать модель из 3-х слоев. Для этого берем 0, 3, 5 (первый и последний включены, а по середине опционален). Если из 2-х слоев, то 0 и 5 (максимально разнесенные по архитектуре). Если надо взять всего 1 слой, то берем 0, не задумываясь. Это исследовалось в одной из предыдущих работ авторов. При этом саму конфигурацию слоёв мы не меняем.
KD по словам
Модель учится предсказывать текст токен за токеном. Поэтому вместо использования оригинальной разметки мы применяем предсказания учителя, которые имеют более сглаженную природу.
где q — предсказания учителя, p — предсказания ученика, \( |\nu| \) — количество токенов в словаре, J — количество токенов в целевом тексте. Жирным выделены векторные величины.
Обозначение \( p(t_j=k|\mathbf{s}, \mathbf{t}_j) \) следует читать как «вероятность того, что \( j \)–й токен равен \( k \) при условии того, что мы сгенерировали предыдущие \( j-1 \) токенов». То же самое и для вероятностей \( q \) учителя.
Sequence-level KD
Предсказывая по словам, мы не улавливаем «временну́ю» составляющую того, как работает учитель. Но её можно уловить, если обучать ученика предсказывать последовательности также, как это делает учитель.
\( {L}_{\text{SEQ-KD}} = — \sum_{\mathbf{t} \in {T}} q(\mathbf{t} | s) \log p(\mathbf{t} | s) \)
При этом суммирование тут идет по ВСЕМ последовательностям токенов \( \mathbf{t} \), поэтому оно отличается от \( {L}_{WORD-KD} \).
Пройтись по всем последовательностям токенов невозможно. По этой причине авторы предлагают считать только ту последовательность у учителя, которая имеет наибольшую вероятность. Но и её они не могут нормально посчитать, поэтому говорят: давайте выберем последовательность, которая сгенерирована бимсёрчем (от англ. beam search, поиск луча).
В своей предыдущей статье авторы объясняют это так: ну бимсерч же работает, ну чё вы)))
Фильтрация предсказаний
Давайте брать для Sequence-Level KD только те последовательности, которые максимально близки к оригинальной по WER (WER — количество ошибок в предложении). Действительно, зачем нам учиться на мусоре? В итоге получаем подход Sequence-Level Interpolation. Вот иллюстрация из их предыдущей статьи, на которой видны отличия всех 3 подходов.
При World-Level KD минимизируется кросс-энтропия между распределениями учителя и ученика (на рисунке выше обозначено жёлтым) для каждого слова в последовательности. Помимо этого, минимизируется кросс-энтропия между распределением ученика истинными метками (обозначено как Ground Truth, чёрным цветом). При Sequence-Level KD сеть ученика обучается на beamsearch-выходе учителя с наибольшим скором. При Sequence-Level Interpolation ученик обучается на beamsearch-выходе учителя, который ближе всего к GT.
Вообще, автор Alexander Rush поступил интересно. Он много работал с обычными Seq2Seq моделями в плане языков, а потом применил свои наработки в этой статье. Забавно, что в NLP подход по фильтрации они мотивируют тем, что данных становится больше, считать долго, а тут на 20к часах учились и ничего, пойдёт 🙂
Ускорение самой большой модели и только декодера
На этом пункте мы остановимся подробнее. Каждый, кто прочитал по ускорению моделей хотя бы пару статей, знает: если взять VGG-19, то можно получить ускорение 90% на любом алгоритме (квантование, прунинг, дистилляция) с потерей 0% точности.
На мой взгляд, здесь похожая ситуация. Авторы взяли большую модель, у которой 20 голов и размер ембеддингов 1280. А ещё енкодер авторы не ускоряют, что тоже сыграет свою роль (мы в этом убедимся, когда будем смотреть ablation studies). В итоге архитектура Distil-Whisper выглядит так:
Обработка длинных последовательностей
В самой Whisper была сложная схема с обработкой длинных последовательностей: там обрабатывались 30-секундные последовательности, в которых сама Whisper предсказывала отступ до следующего фрагмента. Тут авторы решили просто предсказывать 30-секундные фреймы с перекрытием, а потом мёржить текст.
Speculative Decoding
Интересная техника, которую используют авторы. Суть в том, чтобы большая модель корректировала предсказания маленькой модели, если она их предсказала «не так». Ниже показано, как это выглядит. Зелёный цвет — это правильно предсказанные токены, красный — неправильно, а синий — корректировки. Количество запусков процедуры равняется количеству строк.
Если подробнее: у вас есть две модели seq2seq — большая и маленькая, \( M_{p}, M_{q} \). Пусть у вас есть некое стартовое предложение (prefix). Берём маленькую модельку, предсказываем \( \gamma \) токенов авторегрессионным способом с вероятносятми \( q_{1}, q_{2}, q_{3}… q_{\gamma} \). Обозначаем их как \( x_{1}, x_{2}, x_{3}… x_{\gamma} \). Затем берём большую модель и предсказываем в ПАРАЛЛЕЛЬ (по сути только один символ) для каждой из последовательностей \( [prefix,x_{1}]; [prefix,x_{1:2}], [prefix,x_{1:3}], [prefix,x_{1:\gamma-1}] \) с вероятностями \( p_1, p_2 … p_\gamma \). Далее на основе процедуры ниже находим первый символ, на котором ваша маленькая моделька «ошиблась» (если простым языком, то это первый, у которого \( q_{i} > p_{i} \)). И затем корректируем предсказания данного символа при помощи вероятностного алгоритма ниже:
В чём же профит? В том, что большая модель работает не в авторегрессионном режиме, а значит, предсказание можно загнать в батч, и все будет работать быстрее.
Результаты
Ниже представлена таблица, где авторы показали результаты работы:
Они также проверили, что speculative decoding добрасывает в скорости x2 по отношению к large-v2. Табличка по ускорениям представлена ниже.
Ablation studies
Фильтрация по WER
Тут авторы наглядно продемонстрировали принцип GIGO (Garbage in, garbage out). Если ничего не фильтровать — результаты получаются так себе. При этом, если фильтровать слишком много — видимо, dark knowledge не будет хватать. Порога для WER, равного 40, оказалось достаточно (что, кстати, немного удивительно, ведь WER=40 это много).
Обучающая выборка
Окей, а сколько нам надо данных? Авторы тоже это проверили.
Выяснили следующее: при In Domain распознавании на выборке до 1700 качество падает значительно, а вот для Out of Domain после 3к часов WER практически не падает. Видимо, не хватает, всё-таки, обобщающей способности.
Размеры модели
А здесь подтверждение того, что непрунящийся енкодер решает многое. Если подрезать его в 2 раза — итоговый WER сразу подскакивает на 3%.
Выводы
В целом, статья классная. Авторы пожали достаточно сильно свою сетку, не потеряв в качестве, и провели хороший анализ результатов. Лично я для себя вынес технику speculative decoding в качестве нестандартного способа ускорения. Особенно понравилось, что все вопросы, которые у меня возникли во время чтения, авторы проверили при анализе результатов.