Адаптируем Visual-Language модель для детекции аномалий
Введение
Задача детекции аномалий (anomaly detection, AD) заключается в поиске необычных примеров, а именно данных, непохожих на их основной массив.
Детекция аномалий на медицинских изображениях — довольно сложная задача:
- сами форматы данных (рентгены, разные виды КТ и МРТ, гистопатологические исследования), патологии и группы органов очень разнообразны;
- размеченных и неразмеченных медицинских снимков в открытом доступе крайне мало (намного меньше, чем обычных фото на камеру).
Аномалии — «странные» данные. Их странность может быть вызвана как проблемами с аппаратом или неудачным положением пациента, так и различиями в строении организма, которые не встречались в обучающей выборке из-за своей редкости. Аномальные снимки стараются отфильтровать до подачи в модели, определяющие заболевания, поскольку результат работы может оказаться непредсказуемым.
Иногда детекцию аномалий применяют для нахождения конкретных, но редких патологий. Например, gossypiboma — хирургическое осложнение, вызванное забытым во время операции куском хлопка. В зависимости от места поражения это состояние выглядит по-разному, что дополнительно усложняет сбор датасета, но везде оно выглядит «странно».
Больше примеров разных паталогий можно найти в Benchmarks for Medical Anomaly Detection.
Обычно для каждой модальности и анатомической области создают свой детектор аномалий. Пример: детектор аномалий на КТ (модальность, modality) грудной клетки (анатомическая область, anatomical region). Но сегодня мы рассмотрим универсальный детектор аномалий, работающий со всеми модальностями и частями тела.
В статье «Adapting Visual-Language Models for Generalizable Anomaly Detection in Medical Images» авторы решили использовать предобученный на обычных фото CLIP, умело адаптировать его под медицинские данные и добавить сегментацию аномалий для наилучшей интерпретируемости.
Пайплайн обучения
Итак, мы берём CLIP, предобученный на обычных фото (natural image domain) и задаче сопоставления текста и картинки. Но нам необходимо его применить для детекции аномалий на медицинских изображениях (medical image domain). Различия есть и в домене (domain gap), и в сути задачи (task gap).
Для решения этой проблемы авторы используют метод multi-level feature adaptation (MVFA). Адаптация проходит на разнообразной supervised-разметке: даже если считать идущие подряд 2D-срезы КТ и МРТ за отдельные семплы, то в датасете около 80’000 изображений. Это внушительное число для адаптации CLIP под задачу с данными из natural image domain. Но нам нужна адаптация под разные и значительно отличающиеся модальности, поэтому, на мой взгляд, данных всё-таки маловато 🙂.
Обратим внимание на часть схемы «Test». Она описывает, какую модель мы хотим получить с помощью multi-level feature adaptation.
Multi-level feature adaptation — метод для определения аномалии на изображении (normal или abnormal на схеме выше) и создания сегментационной маски для аномальных областей.
Модель при этом функционирует в двух режимах:
- zero-shot — предсказание вероятности для классов normal и abnormal;
- few-shot — сравнение целевого изображения с небольшим количеством нормальных примеров для повышения точности предсказаний.
Кроме того, модель должна быть универсальной, то есть способной работать на не использованных в обучении модальностях и анатомических областях (unseen modality and anatomical regions).
Train: multi-level feature adaptation
Рисунок 5. Схема обучения модели (a) и архитектура используемого адаптера (b)
Теперь давайте перейдём к сути. Anomaly detection мы решаем как две параллельные задачи:
- классификация изображений на нормальные / аномальные (anomaly classification, AC, также называют anomaly detection);
- выделение аномальной области сегментационной маской (anomaly segmantation, AS).
На вход CLIP visual encoder’а мы подаём изображение
Давайте вмешаемся в этот процесс: добавим адаптеры к каждой стадии \( S_1-S_3 \), обучим новый feature projector и учтём выход промежуточных стадий при подсчёте лосса.
Для начала разберёмся с адаптерами \( A_1-A_3 \) на примере адаптера \( A_1 \).
Применяя адаптер \( A_1 \) к тензору \( \mathcal{F}_{\text{1}} \), на выходе мы получаем три тензора:
- \( \mathcal{F}_{\text{1}}^{*} \) передаётся в стадию S2 вместо \( \mathcal{F}_{\text{1}} \);
- \( \mathcal{F}_{\text{cls,1}} \) используется для подсчёта классификационного binary cross-entropy loss;
- \( \mathcal{F}_{\text{seg,1}} \) применяется для подсчёта сегментационного focal loss.
Лосс вычисляется на уровне каждой стадии отдельно:
$$\mathcal{L}_l = \lambda_1 \text{Dice}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T), S) + \\ \lambda_2 \text{Focal}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T), S) + \\ \lambda_3 \text{BCE}(\max_{h \times w}(\text{softmax}(F_{\text{cls},l} F_{\text{text}}^T)), c)$$
На формуле выше лосс на уровне стадии, в статье веса компонентов лосса
\( \lambda_1=\lambda_2=\lambda_3=1 \), \(S\) — таргетная сегментационная маска, \(c\) — класс normal или abnormal.
Итоговый лосс получаем как сумму лоссов по всем стадиям \(\mathcal{L}_{\text{adapt}} = \sum_{l=1}^{4} \mathcal{L}_{l}\).
Пока мы говорили только про адаптеры \( A_l, l \in \{1, 2, 3\} \) . А как мы получаем \( \mathcal{L}_{4} \) для feature projector \( P \)? Давайте снова посмотрим на схему из начала главы:
Чтобы посчитать лосс для S4, нам нужно получить только \( \mathcal{F}_{\text{cls,4}} \) и \( \mathcal{F}_{\text{seg,4}} \). Для этого тензор \( \mathcal{F}_{\text{vis}} \in \mathbb{R}^{G \times d} \) умножаем на матрицы \( W_{\text{cls}} \) и \( W_{\text{seg}} \) соответственно, то есть \( \mathcal{F}_{\text{cls},4} = \mathcal{F}_{\text{vis}}^T W_{\text{cls}} \) . Аналогично и для сегментации.
Этот конструктор мы используем для получения разнообразных текстовых эмбеддингов \( \mathcal{F}_{\text{text}} \).
Таким образом, мы обсудили все изменения в архитектуре и процессе обучения. Как же происходит инференс в zero-shot и few-shot режимах?
Test: multi-level feature adaptation
Инференс состоит из двух веток — zero-shot и few-shot. Обе одновременно используются в предсказании. Zero-shot работает с текстом и изображением, для которого мы ищем аномалии (test image), а few-shot — с test image и примерами нормальных изображений (referenced normal images).
Для начала рассмотрим zero-shot ветку:
- Прогоняем изображение через image encoder с адаптерами, получаем все тензоры \( \mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{ } l \in \{1, 2, 3, 4\} \)
- Прогоняем неизменённый text encoder и получаем эмбеддинг текста \( \mathcal{F}_{\text{text}} \)
- Находим косинусное расстояние между эмбеддингами c 1 и 2 шагов
- Получаем итоговый предсказанный класс \( c_\text{zero} \) и маску аномальной области \( S_\text{zero} \), усредняя предсказания по всем 4 уровням. В формуле ниже \( \max_{G} \) — максимум по всем патчам, \( BI \) — bilinear interpolation до размера входного изображения: $$ c_{\text{zero}} = \frac{1}{4} \sum_{l=1}^{4} \max_G (\text{softmax}(F_{\text{cls},l} F_{\text{text}}^T)) \\ S_{\text{zero}} = \frac{1}{4} \sum_{l=1}^{4} \text{BI}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T)). $$
В итоге каждая «голова» участвует в предсказании. Это напоминает deep supervision, который активно применяется в сегментации, но тут мы усредняем маски до подсчёта лосса вместо вычисления взвешенной суммы лосса по разным маскам.
Теперь изучим few-shot ветку:
- Прогоняем изображение через image encoder с адаптерами, получаем все тензоры \(\mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{l} \in \{1, 2, 3, 4\}\)
- Прогоняем примеры нормальных изображений (referenced normal images) через ту же сетку, их эмбеддинги \( \mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{l} \in \{1, 2, 3, 4\} \) сохраняем в список (memory bank) \( \mathcal{G} \), держим в памяти, что все примеры нормальные, но \( \mathcal{F}_{\text{seg,l}} \) для них необязательно пустые;
- Находим косинусное расстояние между эмбеддингами с 1 и 2 шагов;
- Получаем итоговый предсказанный класс \( c_\text{few} \) и маску аномальной области \( S_\text{few} \) через нахождение ближайшего соседа из memory bank и усреднение предсказания по всем 4 уровням. В формуле ниже \( \max_{G} \) — максимум по всем патчам, \( \text{BI} \)— bilinear interpolation до размера входного изображения, \( \text{Dist} \) — косинусное расстояние между эмбеддингами, \( \mathcal{G} \) — memory bank:
$$c_{\text{few}} = \frac{1}{4} \sum_{l=1}^{4} \max_{G} \left( \min_{m \in \mathcal{G}} \text{Dist}(F_{\text{cls},l}, m) \right), \\ S_{\text{few}} = \frac{1}{4} \sum_{l=1}^{4} \text{BI} \left( \min_{m \in \mathcal{G}} \text{Dist}(F_{\text{seg},l}, m) \right).$$
Затем мы объединяем предсказания двух веток, складывая их с весами (в статье использовались \( \beta_1 = \beta_2 = 0.5 \)):
$$ c_{\text{pred}} = \beta_1 c_{\text{zero}} + \beta_2 c_{\text{few}}, \\
S_{\text{pred}} = \beta_1 S_{\text{zero}} + \beta_2 S_{\text{few}}.$$
В итоге получаем предсказание, которое учитывает и текст, и схожесть эмбеддингов на разных уровнях с эмбеддингами референсных нормальных изображений. При этом мы можем обойтись без дополнительных изображений, если прогоним только zero-shot ветку.
Результаты
Подход демонстрирует качество SOTA на бенчмарке BMAD, обходя WinCLIP и APRIL-GAN. Обе работы также используют CLIP, но не multi-level подход. Именно в нём авторы видят причину улучшения метрик как при медицинском домене, так и при out-domain валидации на датасете MVTec.
В статье также представили варианты архитектуры, не вошедшие в финальную версию, и сравнили их следующим образом:
- Single-adapter vs dual-adapter — использовать одну «голову» для предсказания маски и класса или две. Выбрали две.
2. Projectors vs adapters. Первые только предсказывают маску и класс, вторые меняют тензор, приходящий на вход следующему уровню. Выбрали adapters.
Авторы показали на примерах, что усреднение масок с разных уровней улучшает сегментацию:
Заключение
Итак, сегодня мы рассмотрели интересный подход к детекции аномалий на медицинских изображениях, который использует адаптацию предобученного CLIP. Основное отличие от прошлых работ — применение метода multi-level feature adaptation. Он позволяет модели эффективно адаптироваться к новым задачам и доменам за счёт похожего на deep supervision использования параметров с промежуточных слоёв сети. Мы подробно рассмотрели его отдельно для обучения и тестирования.
Ещё один примечательный факт — одновременное использование zero-shot и few-shot веток при предсказании.
В заключение отметим, адаптация предобученных визуально-языковых моделей для медицины упрощает разработку универсальных моделей, которых очень не хватает из-за сложности и дороговизны разметки. Развитие таких подходов ускоряет диагностику, а [ранняя диагностика спасает жизни](https://typeset.io/search?q=How many lives does early diagnosis save?).
Полезные ссылки
📌 За время написания поста выпустили SAM 2, который очень подходит для адаптации к медицинским 3D-данным, потому что уже работает с последовательностями изображений. Ждём новых статей 😎