Segment Anything (SAM)
Задача статьи — сделать zero-shot и few-shot сегментацию простой и доступной. Для начала давайте вспомним базовый вариант задачи интерактивной сегментации:
- Прогоняем картинку через сеть и получаем маску (как при обычной сегментации);
- Пользователь кликами или выделением bounding box указывает модели на еще не выделенные маской части объекта (или выделенные зря);
- Подсказка от пользователя (prompt) кодируется и передается в сеть;
- Сеть прогоняется еще раз уже с промптом от пользователя во входных данных;
- Снова выполняем действия из 2-го пункта до тех пор, пока маска не окажется достаточно точной. На 3-м шаге добавляем к исходному изображению и подсказке маску сегментации с предыдущего цикла.
В статье же в качестве prompt также используется текст. Сама по себе идея подавать текст на вход для zero-shot сегментации появилась раньше, например, в “Image Segmentation Using Text and Image Prompts” уже использовали эмбеддинги CLIP.
Здесь новшество именно в совмещении текстовых эмбеддингов и процесса интерактивной сегментации. Стоит отметить, что в опубликованном коде пока нет возможности использовать текст, но это обещают добавить.
А как именно мы кодируем и передаем в сеть три вида промптов: текст, клики и рамки?
Prompt encoding
В отличие от работы, которую мы рассмотрели в прошлый раз, здесь клики и боксы кодируются не в виде матрицы, а как наборы точек. Для кликов — это координаты и бинарная метка (объект или фон (x, y, is_foreground)
), для боксов — координаты углов (x1, x2, y1, y2)
.
При этом, к каждому набору координат применяется “positional encoding” (см. Рисунок 3). Элементы матрицы B
в формуле сэмплируются из нормального распределения. По сути мы представляем точку распределением точек. Более подробное объяснение можно прочитать в статье.
После positional encoding вектор суммируется с выученным эмбеддингом для каждого типа. Для извлечения эмбеддинга из текста используется энкодер из CLIP. Все эти чудеса подаются напрямую в mask decoder (см. Рисунок 4).
С маской из предыдущего шага все еще проще — прогоняем ее через несколько сверточных слоев и получаем тензор того же размера, что и эмбеддинг изображения. В коде это 256x64x64. Затем эмбеддинги маски и изображения поэлементно складываются. Если это первый прогон — мы данный шаг пропускаем. Идея, кстати, почти без изменений пришла еще из RITM.
Mask Decoder
Декодер делится на три части:
- Модифицированный декодер трансформера. Модификации включают в себя:
- self-attention только для промпта;
- cross-attention в двух направлениях (от изображения к промпту и от промта к изображению для обновления вообще всех эмбеддингов).
- Upscale с помощью transposed convolution.
- MLP для предсказания масок, как в MaskFormer.
В статье предсказывают по три маски для одного изображения (Рисунок 4), каждой из них соответствует свой MLP. Каждый MLP выдаёт N весов, которые используются для взвешенной суммы upscaled feature maps, которых тоже N. Выходы из трех MLP конкатенируются в эмбеддинг размером (batch_size, 3, N). Пространственный эмбединг (upscaled feature maps) имеет размер (batch_size, N, H, W). Затем два этих представления совмещаются с помощью dot product. В результате получаем тензор размера (3, H, W) — это и есть наши маски.
Идея архитектуры идет из уже упомянутого нами MaskFormer’а. Подробнее объясняем магию в выводе маски в колабе.
Еще один MLP предсказывает confidence score для каждой из масок. Это решает следующую проблему: иногда нельзя однозначно понять, какой объект ищется на фото (см. Рисунок 6). Лосс считается по маске с минимальной ошибкой, а confidence scores учатся предсказывать IOU между предсказанной и истинной масками.
В качестве энкодера используются ViT внушительного размера. Самая крупная модель, ViT-Huge, имеет более 600M параметров (при 4М у декодера).
Для сокращения времени запуска image encoder прогоняется только один раз. При взаимодействии с пользователем запускается заново только prompt и mask encoder, mask decoder.
Хорошо, вот мы и разобрались с тем, как передать промпт в сеть и почему ее запуск не занимает миллионы лет. Остался открытый вопрос — как сеть училась?
Training
TL;DR: очень много данных, несколько итераций псевдолейблинга и переобучения моделей.
Процесс обучения модели неразрывно связан с разметкой датасета. Он делится на три этапа:
- Assisted-manual stage. Модель предсказывает маски с помощью кликов, затем маски вручную исправляются аннотаторами. Стараются размечать вообще все: и объекты, и то, что обычно относят к фону. В процессе разметки модель обучается на новых аннотациях. Всего на этой стадии модель переучивают 6 раз. Среднее число масок на изображении удается поднять с 20 до 44, собирают 4.3M масок 120k фото.
- Semi-automatic stage. Основная цель — повышение разнообразия предсказываемых масок для успеха в zero-shot segmentation. Как это происходит? Модель прогоняют на изображении и сохраняют самые уверенные маски. Затем просят аннотаторов добавить все маски, какие еще имеются на изображении. Также обучается детектор на всех объектах с первой стадии на единственный класс “объект”. Результат предсказания детектора используется как промпт. Модель переучивают 5 раз. Среднее число масок на изображении поднимается до 77, собирают 5.9M масок 180k фото (всего 10.2М масок с двух стадий).
- Fully automatic stage. Качество модели поднялось, следовательно, теперь можно попробовать исключить из процесса человека. Но что делать с неоднозначными ситуациями (см. Рисунок 6)? Авторы предложили нам такой алгоритм:
- Разбиваем изображение на сетку 32×32. В центре каждой ячейки ставим по точке, передаем ее как промпт в модель;
- Для каждого промпта предсказываем три маски и отсекаем маски с низким confidence;
- Для оставшихся масок проверяем наличие хотя бы одной схожей маски от другого промпта. Удаляем не имеющие пары маски, применяем NMS для удаления дубликатов;
- Некоторые участки изображения дополнительно прогоняем отдельно для уточнения контуров небольших масок.
Итоговая модель переразметила все 11 миллионов фото, мы получили более 1.1 миллиарда масок. Важная деталь о масках из датасета: в итоговом наборе все маски размечены моделью. При этом, люди разметили всего 10.2M масок (около 1% от общего числа).
Однако авторы спешат убедить нас в качественной разметке. Для этого они дали аннотаторам подправить вручную 50k масок. Далее они сравнили IoU исходной маски и переразмеченной человеком. Для 94% пар IoU > 0.9, для 97% пар IoU > 0.75. Это соответствует согласованности разметки для людей-аннотаторов. Правда, стоит сделать поправку на то, что люди только поправляли маски, уже полученные из сети. Такая метрика не учитывает тех ситуаций, когда нужный объект не был выделен вовсе.
Интересные картинки
Итоги
Итак, что мы имеем:
- Аналог GPT-3 в задачах разметки для computer vision. Тоже может работать с кучей задач, но в некоторых из них проигрывает более узконаправленным инструментам.
- Пока нельзя оценить работу с текстом, без него иногда не отделяет какие-то семантически важные объекты от фона (см. Рисунок 11).
- Лицензия на модель — Apache 2.0, но датасет пока только для некоммерческого использования. Вопрос о лицензии на веса уже задали. Если разрешат коммерческое использование весов, тогда будем ждать плагины ко всем популярным инструментам разметки.
- Слишком медленные предсказания — 0.15 сек на A100 для ViT image encoder, 50ms на CPU для всего остального. Для realtime не подойдет, но сможет сильно помочь как инструмент интерактивной разметки. Особенно это актуально при близости размечаемых объектов по домену к фото в датасете (иначе см. Рисунок 11).
- Для решения проблемы со специфическими данными можно попробовать дообучить легкие компоненты (все, кроме image encoder’а). Пишите, если попробовали, интересно узнать результаты 🙂
Напоследок прикрепляем сайт с примерами, ссылками на demo, датасет и статью.