Адаптация SAM под 3D медицинские данные
Введение
Segment Anything Model (SAM) — это модель для интерактивной сегментации изображений. Она позволяет быстро и точно выделять на них объекты при помощи подсказок пользователя.
Обычно подсказки представляют собой выделение искомого объекта рамкой или нажатие на места с объектом / без него.
В предыдущих постах мы уже делали обзоры SAM и медицинских данных — советуем с ними тоже познакомиться 🙂:
- Задача интерактивной сегментации и что было до SAM
- Обзор архитектуры, датасета и ограничений SAM
- Введение в медицинские данные
Давайте для начала поймем, какие есть проблемы применения оригинального SAM к трехмерным медицинским данным.
Проблема 1: медицинский домен
В обзоре на SAM мы как раз рассмотрели ограничения модели. Поскольку датасет для ее обучения содержал в основном фото людей и предметов, на других доменах просело качество. В статье «When SAM Meets Medical Images: An Investigation of Segment Anything Model (SAM) on Multi-phase Liver Tumor Segmentation» авторы исследуют использование SAM для упрощения разметки образований в печени. Хороших результатов за пару кликов не получить — их нужно скорее 10-20. Такие проблемы наблюдаются в отношении многих объектов, слабо контрастирующих с окружающим пространством.
Эту проблему решает дообучение на медицинских данных. Во многих работах (1, 2) берут только mask decoder и prompt encoder, не трогая image encoder (ViT), поскольку его дообучение требует очень много ресурсов.
Решаем проблему 1: медицинский домен
Давайте доучим модель на медицинских данных. Mask decoder и prompt encoder можно доучить целиком — они сильно меньше image encoder.
Для обучения image encoder используем технику, идейно похожую на очень популярную LoRA.
Как работает LoRA?
Мы хотим дообучить большую модель с полносвязными линейными слоями, но при этом обучить минимальное число параметров.
Возьмем полносвязный линейный слой без функции активации:
\( h = Wx;\ W \in R^{d \times k} \)
где \( x \) — входной вектор, \( W \) — уже предобученная матрица весов.
Дообучить модель нам нужно без изменения веса W, поэтому новый линейный слой h’ записываем как:
\( h’ = W’x = (W + \Delta W)x = h + \Delta Wx;\ \Delta W \in R^{d \times k} \)
где \( ΔW \) — новая обучаемая матрица весов, \( W \) — уже предобученная и не обучаемая.
При этом новый \( h’ \) отличается от старого \( h \) на \( ΔWx \). Это можно интерпретировать как результат работы еще одного отдельного линейного слоя, который предсказывает разницу между таргетом и \( h \).
Но в чем же выгода? Для сложения \( W \) и \( ΔW \) нам нужно, чтобы эти матрицы имели одинаковый шейп. Это значит, что если мы просто будем учить \( ΔW \) вместо \( W \) — число обучаемых параметров не поменяется.
Тогда давайте представим матрицу \( ΔW \) как произведение двух матриц более низкого ранга. Следовательно, пространство возможных значений матрицы \( ΔW \) уменьшится, как и число обучаемых параметров.
\( h’ = h + \Delta Wx = h + BAx;\ A \in R^{r \times k},\ B \in R^{d \times r},\ r \ll min(d, k) \)
Все операции между \( A \), \( B \) и \( W \) линейные, поэтому перед инференсом их можно объединить в одну матрицу того же размера, что и изначальная \( W \in R^{d \times k} \). Тогда во время самого инференса мы не потеряем ни секунды из-за дополнительных матриц!
Adapter из статьи «Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation» отличается от LoRA только тем, что между двумя обучаемыми матрицами в LoRA ничего нет, а в Adapter добавили ReLU и skip connection. Мы не объединим Up и Down матрицы на инференсе в одну, но в статье такой момент никак не упоминается. Возможно, это решилось в результате экспериментов. Может быть, на это повлияла следующая проблема, о которой мы будем рассказывать. Если вы встречали другие low-rank adaptaters с нелинейностью между матрицами — кидайте в комменты к посту в телегe 🙂
К каждому блоку ViT мы добавим Adapter и будем учить только его. Это поможет извлекать более полезные эмбеддинги из медицинских данных, но не потребует много вычислений.
Проблема 2: трехмерные связи
Поскольку оригинальный SAM учился на 2D данных, первые попытки его использовать тоже работали с 2D. Для этого 3D изображение нарезалось на 2D срезы, каждый из которых отдельно прогонялся через модель, результаты конкатенировались. При этом прогоны модели на соседних слайсах не обменивались информацией.
Ниже показана разница между SAM, SAM-Med2D (SAM, доученный на срезах КТ и МРТ) и SAM-Med3D (не является адаптацией SAM на трехмерные данные, это с нуля обученная на них модель с 3D свертками). В столбце 3D View особенно хорошо видно — даже если модель неплохо справляется с большинством срезов, ее неспособность “заглядывать” в соседние срезы приводит к ошибкам.
Здесь авторы SAM-Med3D:
- потратили много сил на сбор огромного и разнообразного закрытого датасета с КТ и МРТ (в 20 раз больше датасета, на котором учился небезызвестный Totalsegmentator, но меньше датасета оригинального SAM);
- обучили модель на 8 NVIDIA Tesla A100 GPU (8 x 80GB) — меньше, чем у SAM, но все равно достаточно много.
У этого подхода есть два очевидных минуса: долго и дорого. Отсюда возникает вопрос: можно ли как-то заставить SAM учитывать трехмерные связи, но сделать это проще и не сильно хуже?
Решаем проблему 2: трехмерные связи
Для решения проблемы авторы разделили внимание на две ветки:
- Space branch — отвечает за пространственные связи на срезе;
- Depth branch — отвечает за связи между разными срезами.
Space branch получает на вход тензор размером (D × N × L), где D — число срезов в семпле, N — число эмбеддингов, L — длина эмбеддинга. Операции идут по первому измерению (см. реализацию Adapter выше), поэтому Adapter после замороженного multi-head attention учится искать пространственные связи в эмбеддингах с размером (N × L).
Depth branch получает на вход тензор размером (N × D × L), то есть транспонированный вход Space branch. Он проходит через тот же замороженный multi-head attention, но теперь внимание применяется к (D × L) эмбеддингам, чтобы выучить связи между разными срезами. Результаты из Depth branch транспонируются обратно в (D × N × L) и складываются с результатами Space branch.
На рисунке выше мы видим forward метод, включающий обе ветки. Функция rearrange из библиотеки einops позволяет более читаемо применять серии операций над осями. Space_Adapter и Depth_Adapter — объекты класса Adapter.
Результаты
Мы видим, что результаты Med-SA (его мы сейчас и рассматриваем, в какой-то момент закончились незанятные имена из “med” и “SAM” 🙂) лучше, чем у стандартного SAM в режиме интерактивной сегментации. Тот же результат показан и в таблице снизу.
В результате сравнения на нескольких датасетах Med-SA превосходит как оригинальный SAM, так и доученный на медицинских данных MedSAM. Он оказывается лучше популярных сегментационных моделей (например, Swin-UNetr, TransUNet и nnUnet). Но это не совсем честное сравнение, поскольку такие модели не поддерживают интерактивную сегментацию и не получают “подсказок” точками или боксами. Зато наличие обычных сегментационных моделей в таблице позволяет оценить разницу между SAM, MedSAM и Med-SA.
Ссылки
- Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation — рассматриваемая статья. В ней есть не разобранная нами техника HypAdpt, которая улучшает prompt encoder, но не относится к адаптации модели для работы в 3D.
- Customized Segment Anything Model for Medical Image Segmentation — схожая статья, засабмиченная на arxiv с разницей в один день. Работает на 2D слайсах + с LoRA вместо Adapter.
- Reinventing 2D Convolutions for 3D Images — схожая по духу работа, где предобученную сверточную 2D модель переводят в 3D.