Назад
87

Segment Anything (SAM)

87
Рисунок 1. Новая задача, новая модель, новый датасет — всё в одной статье!

Задача статьи — сделать zero-shot и few-shot сегментацию простой и доступной. Для начала давайте вспомним базовый вариант задачи интерактивной сегментации:

  1. Прогоняем картинку через сеть и получаем маску (как при обычной сегментации);
  2. Пользователь кликами или выделением bounding box указывает модели на еще не выделенные маской части объекта (или выделенные зря);
  3. Подсказка от пользователя (prompt) кодируется и передается в сеть;
  4. Сеть прогоняется еще раз уже с промптом от пользователя во входных данных;
  5. Снова выполняем действия из 2-го пункта до тех пор, пока маска не окажется достаточно точной. На 3-м шаге добавляем к исходному изображению и подсказке маску сегментации с предыдущего цикла.

В статье же в качестве prompt также используется текст. Сама по себе идея подавать текст на вход для zero-shot сегментации появилась раньше, например, в  “Image Segmentation Using Text and Image Prompts” уже использовали эмбеддинги CLIP.

Здесь новшество именно в совмещении текстовых эмбеддингов и процесса интерактивной сегментации. Стоит отметить, что в опубликованном коде пока нет возможности использовать текст, но это обещают добавить.

А как именно мы кодируем и передаем в сеть три вида промптов: текст, клики и рамки?

Prompt encoding

Рисунок 2. Архитектура Segmentat Anything Model (SAM)

В отличие от работы, которую мы рассмотрели в прошлый раз, здесь клики и боксы кодируются не в виде матрицы, а как наборы точек. Для кликов — это координаты и бинарная метка (объект или фон (x, y, is_foreground)), для боксов — координаты углов (x1, x2, y1, y2).

При этом, к каждому набору координат применяется “positional encoding” (см. Рисунок 3). Элементы матрицы B в формуле сэмплируются из нормального распределения. По сути мы представляем точку распределением точек. Более подробное объяснение можно прочитать в статье.

Рисунок 3. Positional encoding для координат: B — матрица из случайных величин с нормальным распределением и нулевым матожиданием, v — вектор координат

После positional encoding вектор суммируется с выученным эмбеддингом для каждого типа. Для извлечения эмбеддинга из текста используется энкодер из CLIP. Все эти чудеса подаются напрямую в mask decoder (см. Рисунок 4).

Рисунок 4. Архитектура SAM без ViT image encoder (на инференсе после отклика пользователя запускается именно эта часть)

С маской из предыдущего шага все еще проще — прогоняем ее через несколько сверточных слоев и получаем тензор того же размера, что и эмбеддинг изображения. В коде это 256x64x64. Затем эмбеддинги маски и изображения поэлементно складываются. Если это первый прогон — мы данный шаг пропускаем. Идея, кстати, почти без изменений пришла еще из RITM.

Mask Decoder

Декодер делится на три части:

  1. Модифицированный декодер трансформера. Модификации включают в себя:
    1. self-attention только для промпта;
    2. cross-attention в двух направлениях (от изображения к промпту и от промта к изображению для обновления вообще всех эмбеддингов).
  2. Upscale с помощью transposed convolution.
  3. MLP для предсказания масок, как в MaskFormer.
Рисунок 5. Более подробно про mask decoder. C — количество масок, в статье равно трём, N_tokens — размер эмбеддинга трансформера, N — количество каналов в пространственном эмбединге после upscale, H и W — высота и ширина исходного изображения.

В статье предсказывают по три маски для одного изображения (Рисунок 4), каждой из них соответствует свой MLP. Каждый MLP выдаёт N весов, которые используются для взвешенной суммы upscaled feature maps, которых тоже N. Выходы из трех MLP конкатенируются в эмбеддинг размером (batch_size, 3, N). Пространственный эмбединг (upscaled feature maps) имеет размер (batch_size, N, H, W). Затем два этих представления совмещаются с помощью dot product. В результате получаем тензор размера (3, H, W) — это и есть наши маски.

Идея архитектуры идет из уже упомянутого нами MaskFormer’а. Подробнее объясняем магию в выводе маски в колабе.

Еще один MLP предсказывает confidence score для каждой из масок. Это решает следующую проблему: иногда нельзя однозначно понять, какой объект ищется на фото (см. Рисунок 6). Лосс считается по маске с минимальной ошибкой, а confidence scores учатся предсказывать IOU между предсказанной и истинной масками.

Рисунок 6. Можно предсказать и одежду, и манекен, что создает проблему двусмысленности (ambiguity)

В качестве энкодера используются ViT внушительного размера. Самая крупная модель, ViT-Huge, имеет более 600M параметров (при 4М у декодера).

Для сокращения времени запуска image encoder прогоняется только один раз. При взаимодействии с пользователем запускается заново только prompt и mask encoder, mask decoder.

Хорошо, вот мы и разобрались с тем, как передать промпт в сеть и почему ее запуск не занимает миллионы лет. Остался открытый вопрос — как сеть училась?

Training

TL;DR: очень много данных, несколько итераций псевдолейблинга и переобучения моделей.

Процесс обучения модели неразрывно связан с разметкой датасета. Он делится на три этапа:

  1. Assisted-manual stage. Модель предсказывает маски с помощью кликов, затем маски вручную исправляются аннотаторами. Стараются размечать вообще все: и объекты, и то, что обычно относят к фону. В процессе разметки модель обучается на новых аннотациях. Всего на этой стадии модель переучивают 6 раз. Среднее число масок на изображении удается поднять с 20 до 44, собирают 4.3M масок 120k фото.
  2. Semi-automatic stage. Основная цель — повышение разнообразия предсказываемых масок для успеха в zero-shot segmentation. Как это происходит? Модель прогоняют на изображении и сохраняют самые уверенные маски. Затем просят аннотаторов добавить все маски, какие еще имеются на изображении. Также обучается детектор на всех объектах с первой стадии на единственный класс “объект”. Результат предсказания детектора используется как промпт. Модель переучивают 5 раз. Среднее число масок на изображении поднимается до 77, собирают 5.9M масок 180k фото (всего 10.2М масок с двух стадий).
  3. Fully automatic stage. Качество модели поднялось, следовательно, теперь можно попробовать исключить из процесса человека. Но что делать с неоднозначными ситуациями (см. Рисунок 6)? Авторы предложили нам такой алгоритм:
    1. Разбиваем изображение на сетку 32×32. В центре каждой ячейки ставим по точке, передаем ее как промпт в модель;
    2. Для каждого промпта предсказываем три маски и отсекаем маски с низким confidence;
    3. Для оставшихся масок проверяем наличие хотя бы одной схожей маски от другого промпта. Удаляем не имеющие пары маски, применяем NMS для удаления дубликатов;
    4. Некоторые участки изображения дополнительно прогоняем отдельно для уточнения контуров небольших масок.

Итоговая модель переразметила все 11 миллионов фото, мы получили более 1.1 миллиарда масок. Важная деталь о масках из датасета: в итоговом наборе все маски размечены моделью. При этом, люди разметили всего 10.2M масок (около 1% от общего числа).

Однако авторы спешат убедить нас в качественной разметке. Для этого они дали аннотаторам подправить вручную 50k масок. Далее они сравнили IoU исходной маски и переразмеченной человеком. Для 94% пар IoU > 0.9, для 97% пар IoU > 0.75. Это соответствует согласованности разметки для людей-аннотаторов. Правда, стоит сделать поправку на то, что люди только поправляли маски, уже полученные из сети. Такая метрика не учитывает тех ситуаций, когда нужный объект не был выделен вовсе.

Интересные картинки

Рисунок 7. Еще интересная деталь — очень много фото из России. Толока?
Рисунок 8. Примеры фото с более чем 300 масок (всего на рисунке более 4800 масок)
Рисунок 9. Высокое разрешение играет важную роль: без него получается не так детально 🙂
Рисунок 10. Модель не видела ground truth во время обучения
Рисунок 11. Но в специфических задачах нужен fine-tuning. Слева — попытка выделить текст на популярном меме в режиме интерактивной сегментации, по центру — только автоматическая аннотация, справа — выделение с помощью bounding box

Итоги

Итак, что мы имеем:

  1. Аналог GPT-3 в задачах разметки для computer vision. Тоже может работать с кучей задач, но в некоторых из них проигрывает более узконаправленным инструментам.
  2. Пока нельзя оценить работу с текстом, без него иногда не отделяет какие-то семантически важные объекты от фона (см. Рисунок 11).
  3. Лицензия на модель — Apache 2.0, но датасет пока только для некоммерческого использования. Вопрос о лицензии на веса уже задали. Если разрешат коммерческое использование весов, тогда будем ждать плагины ко всем популярным инструментам разметки.
  4. Слишком медленные предсказания — 0.15 сек на A100 для ViT image encoder, 50ms на CPU для всего остального. Для realtime не подойдет, но сможет сильно помочь как инструмент интерактивной разметки. Особенно это актуально при близости размечаемых объектов по домену к фото в датасете (иначе см. Рисунок 11).
  5. Для решения проблемы со специфическими данными можно попробовать дообучить легкие компоненты (все, кроме image encoder’а). Пишите, если попробовали, интересно узнать результаты 🙂

Напоследок прикрепляем сайт с примерами, ссылками на demo, датасет и статью.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!