Назад
45

MobileOne

45

Введение

Долгое время модели семейства MobileNet были хорошим выбором в контексте соотношения скорости работы и качества. Еще они зарекомендовали себя как энкодеры в задачах детекции и сегментации. Подробнее о них можно прочитать в нашем посте 😊

Однако прогресс не стоит на месте: ML-ресерчеры из Apple представили свое семейство быстрых моделей MobileOne с лучшим качеством на ImageNet в сравнении с MobileNet.

Официальный код с имплементацией MobileOne: https://github.com/apple/ml-mobileone.

Термины

Перед знакомством с моделями семейства MobileOne давайте разберем несколько терминов, которые нам понадобятся:

  1. FLOPs (FLoating-point OPerations per Second) — единица измерения вычислительной мощности процессоров в операциях с плавающей точкой. Важный параметр в контексте нейросетевых моделей: он показывает количество вычислений (операций с плавающей точкой) модели для получения результата. Следовательно, чем ниже этот параметр, тем лучше (меньшими усилиями мы получаем тот же результат).
  2. NMP (Number of Model Parameters) — количество параметров модели.
  3. MAC (Memory Access Cost) — временные затраты процессора на доступ к переменной в памяти.
  4. DOP (Degree Of Parallelism) — показатель, который отражает возможность распараллелить вычисления сети (насколько хорошо это можно сделать).
  5. Latency — время работы модели. Обычно считают в миллисекундах / ms.

Пара уточнений

  1. Далее везде инференс моделей будет происходить на процессоре A14 в Apple Iphone 12. Там нет доступа к командной строке или функциональности, позволяющей зарезервировать всю вычислительную структуру только для запуска модели, поэтому авторы написали специальное приложение на Swift для iOS. Для бенчмаркинга приложение запускают много раз (по умолчанию 1000) и усредняют полученные результаты.
  2. Периодически в посте будут появляться бранчи (branches) — дополнительные ответвления линейной архитектуры сети. Классический пример: skip connections. Multi branch — несколько branches.
  3. При описании MobileOne block будут упомянуты свертки depthwise, pointwise и separable. Подробнее о них можно прочитать в этом посте про виды 2D сверток.

Mobile problem

Кроме качества работы модели для многих реальных задач есть еще два ограничения:

  1. Размер модели должен быть небольшим, чтобы ее можно было использовать на мобильных устройствах и несильно мощных CPU-ядрах.
  2. Время работы тоже должно быть недлительным, чтобы конечный пользователь видел быстрый отклик сервиса или приложения, куда встроена модель.

В случае мобильных или real-time сценариев работы эти ограничения выходят на передний план.

Что не так с FLOPs и NMP в контексте latency?

Теперь давайте поговорим о том, что влияет на latency (время работы) модели.

Прогресс развития архитектур для мобильных устройств проходил с постоянным уменьшением FLOPs (единица измерения вычислительной мощности процессоров в операциях с плавающей точкой) и NMP (количество параметров модели) при одновременном улучшении качества их работы.

Но последние исследования показали: FLOPs и NMP не так хорошо коррелируют со скоростью работы модели.

  1. FLOPs не учитывает MAC (временные затраты процессора на доступ к переменной в памяти) и DOP (показатель, который отражает возможность распараллелить вычисления сети).
  2. В качестве иллюстрации этого эффекта для NMP вспомним операцию residual skip connection: она не содержит дополнительных параметров, но увеличивает latency модели, ведь нужно больше времени на доступ к тензорам в памяти.

Как можно количественно измерить корреляцию между FLOPs / NMP и latency?

Для этого авторы взяли самые эффективные по соотношению latency / accuracy CV-модели (они не указали какие именно, скорее всего, это были модели из Рисунка 6 в разделе про MobileOne блок), сконвертировали их в ONNX-формат и посчитали latency. Затем сделали тест ранговой корреляции Спирмена и получили следующие результаты:

Рисунок 1. Коэффициент ранговой корреляции Спирмена для latency / FLOPs и latency / NMP

Из Рисунка 1 мы видим: для мобильных устройств есть небольшая корреляция между FLOPs и latency и слабая — между NMP и latency. Для CPU корреляции практически нет.

Исследование боттлнеков

1. Функции активации

Ресерчеры пошли дальше и решили измерить влияние некоторых популярных функций активации на latency. Для этого они собрали CNN модель с 30 слоями и изменили в ней только функции активации (сделали так, чтобы везде в модели использовалась одна и та же функция активации). У полученных моделей измерялось latency.

Рисунок 2. Сравнение latency 30 layer CNN модели на мобильных устройствах в зависимости от используемых функций активации

По результатам экспериментов авторы решили использовать в MobileOne только функцию активации ReLU.

2. Архитектурные блоки

Теперь давайте детально разберем негативное влияние MAC и DOP на latency.

  1. MAC, multi-branch блок: при сложении бранчей (арифметическое сложение / конкатенация) мы должны обратиться к памяти, следовательно, MAC — боттлнек для latency. Его можно обойти путем уменьшения числа бранчей.
  2. DOP, squeeze-excitation block: такие операции глобального объединения, как average pooling в squeeze-excitation блоке, увеличивают latency модели. Необходимо синхронизировать параллельные вычисления для перехода на следующий этап — это и ограничивает DOP.

Для показа негативного влияния MAC и DOP авторы использовали skip connections и squeeze-excitation блок соответственно в полученной выше baseline модели с ReLU в качестве функции активации. Результаты представлены в таблице ниже:

Рисунок 3. Влияние squeeze-excitation блока и skip connections на latency baseline модели

Согласно результатам, авторы решили отказаться от дополнительных бранчей с помощью операции репараметризации (о ней подробно ниже) во время инференса (но не при обучении, так как бранчи дают прирост по качеству), а также использовать squeeze-excitation блок только в самой большой модели MobileOne-S4 для улучшения точности.

Ключевая идея

Таким образом, задача авторов — уменьшить latency и одновременно увеличить качество работы модели напрямую, а не только через FLOPs и NMP.

Для этого они предложили следующий трюк: использовать дорогие по MAC операции skip connections только во время обучения модели, а на инференсе делать репараметризацию, то есть заменять multi branch на одну свертку (что делает архитектуру сети линейной).

Таблица ниже иллюстрирует идею авторов на примере сравнения MobileOne-S1 с MobileNetV2 — x1.0.

Model NameNMPlatency, msImageNet 1k accuracy
MobileOne-S14.8 M0.8975.9
MobileNetV2 — x1.03.4 M0.9872.0

Репараметризация

Процедуру репараметризации авторы взяли из статьи RepVGG: Making VGG-style ConvNets Great Again. Ее суть — преобразовать веса бранчей в веса одной свертки путем простой арифметики.

Рисунок 4. Структурная репараметризация блока RepVGG. Перед сложением есть слой Batch Normalization у каждого бранча

Допустим, у нас есть три бранча: 3×3 Conv, 1×1 Conv и Identity со слоем Batch Normalization в конце каждого бранча. Свертку 1×1 Conv западдим нулями до размера 3×3. Identity + Batch Normalization рассмотрим сначала как упрощенную свертку 1×1 (где kernel — это просто единичная матрица), а после паддинга нулями — как свертку 3×3. Сложив отдельно kernels и biases полученных сверток, получим kernel и bias итоговой свертки.

Пусть:

  • \( M^{(1)}, M^{(2)} \) — вход и выход блока RepVGG
  • \( W \) — ядро свертки
  • \( bn \) — Batch Normalization слой
  • \( \mu, \sigma, \gamma, \beta \) — аккумулированные матожидание, стандартное отклонение, scaling фактор и bias блоков
  • \( (3), (1), (0) \) — индексы указывают принадлежность объекта к 3×3 conv, 1×1 conv и identity соответственно
  • \( C_{1}, C_{2} \) — количество входных и выходных каналов

Математически это можно записать следующим образом:

\( M^{(2)} = bn(M^{(1)} ∗ W^{(3)}, µ^{(3)}, σ^{(3)},γ^{(3)},β^{(3)}) \\ + bn(M^{(1)} ∗ W^{(1)}, µ^{(1)},σ^{(1)}, γ^{(1)},β^{(1)}) \\ + bn(M^{(1)}, µ^{(0)},σ^{(0)}, γ^{(0)}, β^{(0)}). \)

Слой Batch Normalization можно переписать в следующем виде, \( 1 \leq i \leq C_{2} \):

\( bn(M, µ,σ, γ, β)_{:,i,:,:} = (M_{:,i,:,:} − µ_{i})\frac{γ_{i}}{σ_{i}} + β_{i}. \)

Пусть \( (W’, b’) \) — kernel и bias, полученные из набора \( (W, µ,σ, γ,β) \). Тогда, \( 1 \leq i \leq C_{2} \):

\( W’_{i,:,:,:} = \frac{γ_{i}}{σ_{i}}W_{i,:,:,:}, \qquad b’_{i} = − \frac{µ_{i}γ_{i}}{σ_{i}} + β_{i}. \\ bn(M ∗ W, µ,σ, γ, β)_{:,i,:,:} = (M ∗ W’)_{:,i,:,:} + b’_{i}. \)

Рисунок 5. Математическая схема репараметризации. Обозначения отличаются от обозначений в RepVGG, но общий принцип тот же: собираем по отдельности kernels и biases в одну свертку

MobileOne block

Теперь разберем MobileOne блок. За его основу авторы взяли MobileNetV1 блок: сначала 3×3 depthwise свертка, затем 1×1 pointwise свертка. Но они дополнили его skip connections: использовали слои с маленьким количеством параметров — 1×1 Conv и Batch Norm.

При инференсе у модели уже нет никаких бранчей благодаря репараметризации. По сути в это время блок превращается в обычный MobileNetV1 блок.

Рисунок 6. Архитектура MobileOne блока имеет разную структуру во время обучения и инференса. Слева: блок с дополнительными бранчами во время обучения. Справа: блок без них во время инференса. k — гиперпараметр, который отдельно подбирался к каждому варианту MobileOne

Вводится гиперпараметр k, который принимает значения от 1 до 5 и отвечает за количество блоков с depthwise и pointwise свертками внутри MobileOne блока. Отметим, что при репараметризации мы просто суммируем по отдельности kernels и biases с каждого блока. Влияние гиперпараметра k на точность модели показано на Рисунке 7:

Рисунок 7. Сравнение качества работы версий S0 и S1 MobileOne в зависимости от значения гиперпараметра k

Как видно из Рисунка 8, модели семейства MobileOne превосходят остальные в соотношении latency / accuracy. Выигрывают они при этом не во FLOPs и NMP, а за счет архитектурных решений.

Рисунок 8. Сравнение версий MobileOne с другими популярными моделями классификации на базе сверток и трансформеров на ImageNet-1k validation set. Mobile latency — A14, GPU latency — RTX-2080Ti + TensorRT, CPU latency — Intel Xeon Gold 5118

Делаем семейство моделей

MobileOne наследует идеи по масштабированию в ширину от семейств MobileNet и EfficientNet: с помощью мультипликативного параметра \( \alpha \) мы можем сузить (\( \alpha \in [0, 1] \)) / расширить (\( \alpha > 1 \)) размер сети, уменьшив / увеличив число каналов в каждом слое.

Для MobileOne авторы решили увеличить число каналов, чтобы увеличить accuracy модели. Они использовали специальные наборы масштабов ширины — width multipliers. MobileOne не имеет многоразветвлённой архитектуры при инференсе, что снижает MAC. Следовательно, ее архитектуру можно масштабировать более агрессивно, чем архитектуры конкурентов. Всего авторы предложили 5 различных width multipliers.

Downstream tasks

Для задачи детекции авторы использовали MobileOne как feature extractor backbone в одностадийном детекторе SSD. При этом они заменили стандартные свертки в голове SSD на separable, получив таким образом модификацию SSD — SSDLite. Итоговый детектор обучался и валидировался на датасете MS COCO.

Для задачи сегментации MobileOne backbone встроили в сегментационную сеть DeepLabV3. Полученную модель обучали и валидировали по отдельности на датасетах Pascal VOC и ADE 20k, но с одинаковыми гиперпараметрами и аугментациями.

Как видно из Рисунка 9, MobileOne уверенно опережает другие популярные мобильные архитектуры.

Рисунок 9. Количественное сравнение версий MobileOne c другими моделями: в задаче детекции (a) и в задаче сегментации (b)

Выводы

  1. FLOPs и NMP влияют заметно слабее на latency модели, в отличие от дополнительных бранчей (например, skip connections) и операций глобального объединения в архитектуре, что обусловлено MAC и DOP.
  2. Тем не менее дополнительные бранчи дают прирост в качестве работы модели, поэтому имеет смысл использовать их в обучении и убирать с помощью репараметризации во время инференса.
  3. Такой подход не только позволяет сохранить latency без падения качества, но и хорошо генерализуется для задач детекции и сегментации.

Ссылки

Как можно воспользоваться в своих проектах:

  • mmclassification — интерфейс через фреймворк mmclassification, здесь также воспроизведено обучение на ImageNet;
  • timm — интерфейс через фреймворк timm.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!