Назад
86

Адаптируем Visual-Language модель для детекции аномалий

86

Введение

Задача детекции аномалий (anomaly detection, AD) заключается в поиске необычных примеров, а именно данных, непохожих на их основной массив.

Детекция аномалий на медицинских изображениях — довольно сложная задача:

  • сами форматы данных (рентгены, разные виды КТ и МРТ, гистопатологические исследования), патологии и группы органов очень разнообразны;
  • размеченных и неразмеченных медицинских снимков в открытом доступе крайне мало (намного меньше, чем обычных фото на камеру).

Аномалии — «‎странные» данные. Их странность может быть вызвана как проблемами с аппаратом или неудачным положением пациента, так и различиями в строении организма, которые не встречались в обучающей выборке из-за своей редкости. Аномальные снимки стараются отфильтровать до подачи в модели, определяющие заболевания, поскольку результат работы может оказаться непредсказуемым.

Иногда детекцию аномалий применяют для нахождения конкретных, но редких патологий. Например, gossypiboma — хирургическое осложнение, вызванное забытым во время операции куском хлопка. В зависимости от места поражения это состояние выглядит по-разному, что дополнительно усложняет сбор датасета, но везде оно выглядит «‎странно».

Больше примеров разных паталогий можно найти в Benchmarks for Medical Anomaly Detection.

Рисунок 1. Примеры аномалий: образование в мозгу (сверху слева), ложка (сверху справа), сросшийся перелом пальца (снизу слева), gossypiboma (снизу справа)
Рисунок 2. Примеры аномалий сетчатки глаза на Retinal OCT
Рисунок 3. Примеры аномалий на тканях под микроскопом при гистопатологическом исследовании

Обычно для каждой модальности и анатомической области создают свой детектор аномалий. Пример: детектор аномалий на КТ (модальность, modality) грудной клетки (анатомическая область, anatomical region). Но сегодня мы рассмотрим универсальный детектор аномалий, работающий со всеми модальностями и частями тела.

В статье «Adapting Visual-Language Models for Generalizable Anomaly Detection in Medical Images» авторы решили использовать предобученный на обычных фото CLIP, умело адаптировать его под медицинские данные и добавить сегментацию аномалий для наилучшей интерпретируемости.

Пайплайн обучения

Итак, мы берём CLIP, предобученный на обычных фото (natural image domain) и задаче сопоставления текста и картинки. Но нам необходимо его применить для детекции аномалий на медицинских изображениях (medical image domain). Различия есть и в домене (domain gap), и в сути задачи (task gap).

Для решения этой проблемы авторы используют метод multi-level feature adaptation (MVFA). Адаптация проходит на разнообразной supervised-разметке: даже если считать идущие подряд 2D-срезы КТ и МРТ за отдельные семплы, то в датасете около 80’000 изображений. Это внушительное число для адаптации CLIP под задачу с данными из natural image domain. Но нам нужна адаптация под разные и значительно отличающиеся модальности, поэтому, на мой взгляд, данных всё-таки маловато 🙂.

Рисунок 4. Схема пайплайнов обучения и теста

Обратим внимание на часть схемы «‎Test». Она описывает, какую модель мы хотим получить с помощью multi-level feature adaptation.

Multi-level feature adaptation — метод для определения аномалии на изображении (normal или abnormal на схеме выше) и создания сегментационной маски для аномальных областей.

Модель при этом функционирует в двух режимах:

  • zero-shot — предсказание вероятности для классов normal и abnormal;
  • few-shot — сравнение целевого изображения с небольшим количеством нормальных примеров для повышения точности предсказаний.

Кроме того, модель должна быть универсальной, то есть способной работать на не использованных в обучении модальностях и анатомических областях (unseen modality and anatomical regions).

Train: multi-level feature adaptation

Рисунок 5. Схема обучения модели (a) и архитектура используемого адаптера (b)

Теперь давайте перейдём к сути. Anomaly detection мы решаем как две параллельные задачи:

  • классификация изображений на нормальные / аномальные (anomaly classification, AC, также называют anomaly detection);
  • выделение аномальной области сегментационной маской (anomaly segmantation, AS).

На вход CLIP visual encoder’а мы подаём изображение \( x \in \mathbb{R}^{h \times w \times 3} \) . Без добавления адаптеров (\( A_1-A_3, P \)) оно проходит через стадии \( S_1-S_4 \). В результате работы последнего этапа \( S_4 \) мы получаем тензор \( \mathcal{F}_{\text{vis}} \in \mathbb{R}^{G \times d} \), где \( G \) — количество патчей, на которые мы делим изображение, а \( d \) — размерность эмбеддинга соответствующего патча. Промежуточные между стадиями тензоры тоже имеют размерность \( \mathcal{F}_{\text{1,2,3}} \in \mathbb{R}^{G \times d} \). После этого применяем feature projector \( P \), который переводит \( \mathcal{F}_{\text{vis}} \) в вектор с размерностью, необходимой для подсчета лосса \( P(\mathcal{F}_{\text{vis}}) = I \). И только затем мы сравниваем полученный из visual encoder’а эмбеддинг \( I \) с эмбеддингом из CLIP text encoder’а.

Давайте вмешаемся в этот процесс: добавим адаптеры к каждой стадии \( S_1-S_3 \), обучим новый feature projector и учтём выход промежуточных стадий при подсчёте лосса.

Для начала разберёмся с адаптерами \( A_1-A_3 \) на примере адаптера \( A_1 \).

Рисунок 6. Реализация адаптера из официального репозитория. Вектор \( y \) передаётся в следующую стадию, вектор \( x \) используется для подсчёта лосса текущей стадии
Рисунок 7. Архитектура адаптера \( A_l, l \in \{1, 2, 3\} \) . Применяем к \( S_l \) два разных слоя ClipAdapter’а и получаем два различных тензора одного размера \( \mathcal{F}{\text{cls,l}}, \mathcal{F}{\text{seg,l}} \)

Применяя адаптер \( A_1 \) к тензору \( \mathcal{F}_{\text{1}} \), на выходе мы получаем три тензора:

  • \( \mathcal{F}_{\text{1}}^{*} \) передаётся в стадию S2 вместо \( \mathcal{F}_{\text{1}} \);
  • \( \mathcal{F}_{\text{cls,1}} \) используется для подсчёта классификационного binary cross-entropy loss;
  • \( \mathcal{F}_{\text{seg,1}} \) применяется для подсчёта сегментационного focal loss.

Лосс вычисляется на уровне каждой стадии отдельно:

$$\mathcal{L}_l = \lambda_1 \text{Dice}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T), S) + \\ \lambda_2 \text{Focal}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T), S) + \\ \lambda_3 \text{BCE}(\max_{h \times w}(\text{softmax}(F_{\text{cls},l} F_{\text{text}}^T)), c)$$

На формуле выше лосс на уровне стадии, в статье веса компонентов лосса

\( \lambda_1=\lambda_2=\lambda_3=1 \), \(S\) — таргетная сегментационная маска, \(c\) — класс normal или abnormal.

Итоговый лосс получаем как сумму лоссов по всем стадиям \(\mathcal{L}_{\text{adapt}} = \sum_{l=1}^{4} \mathcal{L}_{l}\).

Пока мы говорили только про адаптеры \( A_l, l \in \{1, 2, 3\} \) . А как мы получаем \( \mathcal{L}_{4} \) для feature projector \( P \)? Давайте снова посмотрим на схему из начала главы:

Рисунок 8. Схема обучения модели MVFA

Чтобы посчитать лосс для S4, нам нужно получить только \( \mathcal{F}_{\text{cls,4}} \) и \( \mathcal{F}_{\text{seg,4}} \). Для этого тензор \( \mathcal{F}_{\text{vis}} \in \mathbb{R}^{G \times d} \) умножаем на матрицы \( W_{\text{cls}} \) и \( W_{\text{seg}} \) соответственно, то есть \( \mathcal{F}_{\text{cls},4} = \mathcal{F}_{\text{vis}}^T W_{\text{cls}} \) . Аналогично и для сегментации.

Рисунок 9. Конструктор текстового промпта, где [o] — название органа или части тела, для которых мы ищем аномалии

Этот конструктор мы используем для получения разнообразных текстовых эмбеддингов \( \mathcal{F}_{\text{text}} \).

Таким образом, мы обсудили все изменения в архитектуре и процессе обучения. Как же происходит инференс в zero-shot и few-shot режимах?

Test: multi-level feature adaptation

Инференс состоит из двух веток — zero-shot и few-shot. Обе одновременно используются в предсказании. Zero-shot работает с текстом и изображением, для которого мы ищем аномалии (test image), а few-shot — с test image и примерами нормальных изображений (referenced normal images).

Рисунок 10. Схема валидации модели MVFA

Для начала рассмотрим zero-shot ветку:

  1. Прогоняем изображение через image encoder с адаптерами, получаем все тензоры \( \mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{ } l \in \{1, 2, 3, 4\} \)
  2. Прогоняем неизменённый text encoder и получаем эмбеддинг текста \( \mathcal{F}_{\text{text}} \)
  3. Находим косинусное расстояние между эмбеддингами c 1 и 2 шагов
  4. Получаем итоговый предсказанный класс \( c_\text{zero} \) и маску аномальной области \( S_\text{zero} \), усредняя предсказания по всем 4 уровням. В формуле ниже \( \max_{G} \) — максимум по всем патчам, \( BI \) — bilinear interpolation до размера входного изображения: $$ c_{\text{zero}} = \frac{1}{4} \sum_{l=1}^{4} \max_G (\text{softmax}(F_{\text{cls},l} F_{\text{text}}^T)) \\ S_{\text{zero}} = \frac{1}{4} \sum_{l=1}^{4} \text{BI}(\text{softmax}(F_{\text{seg},l} F_{\text{text}}^T)). $$

В итоге каждая «‎голова» участвует в предсказании. Это напоминает deep supervision, который активно применяется в сегментации, но тут мы усредняем маски до подсчёта лосса вместо вычисления взвешенной суммы лосса по разным маскам.

Рисунок 11. Схема валидации модели MVFA

Теперь изучим few-shot ветку:

  1. Прогоняем изображение через image encoder с адаптерами, получаем все тензоры \(\mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{l} \in \{1, 2, 3, 4\}\)
  2. Прогоняем примеры нормальных изображений (referenced normal images) через ту же сетку, их эмбеддинги \( \mathcal{F}_{\text{cls,l}}, \mathcal{F}_{\text{seg,l}}, \text{l} \in \{1, 2, 3, 4\} \) сохраняем в список (memory bank) \( \mathcal{G} \), держим в памяти, что все примеры нормальные, но \( \mathcal{F}_{\text{seg,l}} \) для них необязательно пустые;
  3. Находим косинусное расстояние между эмбеддингами с 1 и 2 шагов;
  4. Получаем итоговый предсказанный класс \( c_\text{few} \) и маску аномальной области \( S_\text{few} \) через нахождение ближайшего соседа из memory bank и усреднение предсказания по всем 4 уровням. В формуле ниже \( \max_{G} \) — максимум по всем патчам, \( \text{BI} \)— bilinear interpolation до размера входного изображения, \( \text{Dist} \) — косинусное расстояние между эмбеддингами, \( \mathcal{G} \) — memory bank:

$$c_{\text{few}} = \frac{1}{4} \sum_{l=1}^{4} \max_{G} \left( \min_{m \in \mathcal{G}} \text{Dist}(F_{\text{cls},l}, m) \right), \\ S_{\text{few}} = \frac{1}{4} \sum_{l=1}^{4} \text{BI} \left( \min_{m \in \mathcal{G}} \text{Dist}(F_{\text{seg},l}, m) \right).$$

Затем мы объединяем предсказания двух веток, складывая их с весами (в статье использовались \( \beta_1 = \beta_2 = 0.5 \)):

$$ c_{\text{pred}} = \beta_1 c_{\text{zero}} + \beta_2 c_{\text{few}}, \\
S_{\text{pred}} = \beta_1 S_{\text{zero}} + \beta_2 S_{\text{few}}.$$

В итоге получаем предсказание, которое учитывает и текст, и схожесть эмбеддингов на разных уровнях с эмбеддингами референсных нормальных изображений. При этом мы можем обойтись без дополнительных изображений, если прогоним только zero-shot ветку.

Результаты

Подход демонстрирует качество SOTA на бенчмарке BMAD, обходя WinCLIP и APRIL-GAN. Обе работы также используют CLIP, но не multi-level подход. Именно в нём авторы видят причину улучшения метрик как при медицинском домене, так и при out-domain валидации на датасете MVTec.

Рисунок 12. Сравнение разных zero-shot подходов на BMAD и few-shot на MVTec

В статье также представили варианты архитектуры, не вошедшие в финальную версию, и сравнили их следующим образом:

  1. Single-adapter vs dual-adapter — использовать одну «‎голову» для предсказания маски и класса или две. Выбрали две.
Рисунок 13. Сравнение архитектуры (слева) и результатов для разных модальностей (справа)

2. Projectors vs adapters. Первые только предсказывают маску и класс, вторые меняют тензор, приходящий на вход следующему уровню. Выбрали adapters.

Рисунок 14. Сравнение архитектуры (сверху) и результатов для разных модальностей (снизу)

Авторы показали на примерах, что усреднение масок с разных уровней улучшает сегментацию:

Рисунок 15. Усреднение предсказаний по уровням, улучшающее сегментацию

Заключение

Итак, сегодня мы рассмотрели интересный подход к детекции аномалий на медицинских изображениях, который использует адаптацию предобученного CLIP. Основное отличие от прошлых работ — применение метода multi-level feature adaptation. Он позволяет модели эффективно адаптироваться к новым задачам и доменам за счёт похожего на deep supervision использования параметров с промежуточных слоёв сети. Мы подробно рассмотрели его отдельно для обучения и тестирования.

Ещё один примечательный факт — одновременное использование zero-shot и few-shot веток при предсказании.

В заключение отметим, адаптация предобученных визуально-языковых моделей для медицины упрощает разработку универсальных моделей, которых очень не хватает из-за сложности и дороговизны разметки. Развитие таких подходов ускоряет диагностику, а [ранняя диагностика спасает жизни](https://typeset.io/search?q=How many lives does early diagnosis save?).

Полезные ссылки

📌 За время написания поста выпустили SAM 2, который очень подходит для адаптации к медицинским 3D-данным, потому что уже работает с последовательностями изображений. Ждём новых статей 😎

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!