Назад
162

Self-supervised learning

162

Идея

В этом посте мы затронем хайповую сейчас тему — self-supervised learning; далее — SSL. Постараемся разобраться с тем, что это такое и зачем оно нужно.

Давайте для начала вспомним, что в типичном supervised-learning сетапе без ограничения общности у нас есть некоторые сырые данные. Их мы передаем экспертам доменной области на разметку. Затем формируется размеченный набор данных, на котором уже в свою очередь обучается модель. Цель supervised-learning’a — получить модель, способную на приемлемом уровне решать какую-то целевую задачу.

Целью же SSL является получение модели, способной хорошо решать задачу извлечения признаков из входных данных. Сетап SSL-задачи выглядит таким образом: как и в supervised аналоге у нас есть сырые данные, но мы не передаем их на разметку экспертам доменной области, а генерируем каким-то дешевым способом разметку “на лету”, прямо из этих сырых данных. На такой дешево сгенерированной разметке решается некоторая задача. А результат — модель, которая научилась извлекать хорошие признаки.

Схематически эти два сетапа представлены на рисунке 1.

Рисунок 1. Supervised и self-supervised задачи

Зачем нам нужна такая модель? Чаще всего ее веса используются в качестве точки начальной инициализации при решении уже целевой задачи (например, вместо весов с ImageNet) (см. рисунок 2). То есть модель просто дотюнивают на размеченном экспертами наборе данных под целевую задачу.

Рисунок 2. Схема использования SSL-модели для решения целевой задачи

И при обычном supervised, и при SSL-подходе мы размечаем данные при помощи экспертов, затем подгоняем модель и проводим цикл обучения. В чем же тогда их различие

Объяснить это можно так:

  1. Зачастую случается, что веса модели, предобученной на ImageNet’e, оказываются субоптимальными. Происходить это может из-за сильного отличия домена ImageNet’a от вашего целевого домена (например, вы работаете с рентгенограммами, а в ImageNet’e такого домена нет). Сеть, предобученная на ImageNet’e, извлекает признаки, но эти признаки не будут очень релевантными. В случае же использования SSL сеть учится извлекать признаки, характерные конкретно вашим данным.
  2. Согласно результатам статей, часто на бенчмарк наборах данных типа ImageNet, PASCAL VOC, MS COCO и других схема с SSL-предобучением позволяет достичь бОльшего значения целевой метрики.
  3. Еще в статьях говорится о том, что при использовании SSL-схемы часто можно добиться того же показателя целевой метрики даже при обращении к гораздо меньшему объему данных. Получается, если раньше вам надо было 10000 изображений для получения целевой метрики 0.8, то с использованием SSL схемы вам, возможно, хватит и 1000. Кроме того, наряду с меньшим объемом данных вам, весьма вероятно, понадобится и меньшее число эпох для достижения нужного показателя целевой метрики.
  4. Модель, единожды предобученная SSL-задачей, может быть пере-использована как точка начальной инициализации при решении других задач на таком же домене данных.

Тут можно остановиться и вспомнить о статье, авторы которой показывают способность сверточной сети даже со случайными весами из-за своей архитектуры вытаскивать из изображений много полезной информации без каких-либо данных. Имеется в виду, что структура самой нейронной сети сама по себе уже содержит неплохие представления об окружающем мире.

Итак, мы теперь пониманием, что такое SSL и для чего оно нам нужно. А сейчас давайте поглубже изучим терминологию.

Терминология

Нам понадобится 3 основных термина:

  • предварительная задача (pretext task) — сама задача SSL с искусственной разметкой. Именно ее решает модель, чтобы научиться извлекать хорошие признаки из данных;
  • псевдо-разметка — та самая дешевая разметка, происходящая без участия человека;
  • последующая задача (downstream task) — задача, по которой проверяют качество выученных признаков. Как правило, это простые модели (KNN, LinReg, LogReg и другие), обучающиеся на извлекаемых с помощью SSL-модели признаках. Иногда бывает и так, что модель не фиксируется и дообучается целиком.

Давайте еще раз посмотрим на схему SSL-пайплайна на примере ImageNet для визуализации теории:

Рисунок 3. Схема SSL-пайплайна

На рисунке 3 синий заштрихованный прямоугольник — решение предварительной SSL-задачи, на которой модель учится извлекать хорошие признаки; бордовый заштрихованный прямоугольник — оценка того, насколько хорошие признаки модель научилась извлекать.

Процесс проверки качества извлеченных признаков подробно представлен на схеме, предварительная задача же не так очевидна и требует дополнительного пояснения. Поэтому давайте сейчас поподробнее рассмотрим ее типы.

Виды предварительных задач

SSL появился довольно давно (первое упоминание автор нашел в статье Хинтона аж 1992 года), но основной бум в этом разделе глубокого обучения случился в условных 2020 годах.

До современных SOTA-методов в исследованиях упор делался в основном на сам процесс генерации псевдо-разметки. Для изображений были следующие варианты:

  1. Задача восстановления информации: ключевая идея в удалении какой-либо информации из изображения и тренировка сети для восстановления этой удаленной информации. Например, можно удалить цвет из изображения и превратить его в монохромное, а затем обучить сеть восстановлению цвета (статья, выступление, код). Авторы показывают, что такая предварительная задача неплохо подходит в том случае, если наша целевая задача является задачей семантической сегментации, поскольку для восстановления цвета сцены необходимо понимать семантику изображения и границы объектов на этом изображении. Еще есть идея маскирования части изображения и обучение сети восстановлению этой части (ранний вариант, современная SOTA).
  2. Изучение пространственного контекста: основной смысл в обучении сети пониманию относительных позиций и ориентации объектов на сцене. Например, в статье авторы генерировали разметку простым поворотом изображений и в качестве псевдо-разметки использовали класс поворота (0 — поворот на 0 градусов, 1 — поворот на 90 градусов, 2 — на 180 градусов и так далее). А в этой статье авторы предложили предсказывать относительную позицию двух случайных патчей из изображения.
  3. Группирование похожих изображений вместе: идея обучения сети группировать схожие изображения и, следовательно, извлекать хорошие признаки. Есть целое семейство методов, в которых для этого адаптирован алгоритм K-Means. Например, раз и два.
  4. Генеративные модели: попытка адаптировать веса предобученных автоэнкодеров и GAN’ов для их использования в процессе решения целевой задачи.
  5. Multi-view invariance методы: в основе лежит идея о том, что два вектора-признака, полученных из одного изображения путем применения различных трансформаций (чаще всего аугментаций) должны быть похожими (contrastive learning).

Подробнее про некоторые из таких “старых” методов можно прочитать в посте Дьяконова.

В современных SOTA-методах процедура генерации псевдо-разметки в основном сводится к двум вариантам:

  1. Multi-view invariance: здесь псевдо-разметка формируется по принципу contrastive learning, то есть позитивными примерами являются два по-разному аугментированных варианта одного и того же изображения, а негативными — аугментированные варианты другого изображения. Пример одного из таких методов можно посмотреть на рисунке 4 (а статью — тут).
Рисунок 4. SimCLR

2. Задача восстановления информации: из изображения удаляются некоторые патчи, а сеть учится восстанавливать эту информацию. Пример представлен на рисунке 5 (статью можно найти тут). В таком случае псевдо-разметка состоит из пар (X, Y), где Y — исходное изображение, а Х — маскированная версия этого изображения, в котором удалена часть информации.

Рисунок 5. Masked autoencoders (MAE)

У нас всего в основном две процедуры генерации псевдо-разметки. Чем же тогда различаются современные методы? Архитектурами, функциями потерь и различными хитростями, используемыми при обучении. В статье A Cookbook of Self-Supervised Learning авторы выделяют четыре больших семейства современных методов SSL:

  1. Методы, основанные на metric learning;
  2. Методы, основанные на self-distillation;
  3. Методы, основанные на каноническом корреляционном анализе;
  4. Методы, основанные на задачах восстановления информации.

Про каждый пункт можно и нужно рассказывать подробно, поэтому мы оставим этот обзор для последующих статей серии 🙂

Заключение

Эта статья обзорная, но нам хотелось, чтобы после ее прочтения вам запомнилось:

  1. Что такое SSL;
  2. Терминология;
  3. Общая схема работы SSL-пайплайна;
  4. Где SSL может пригодиться;
  5. Какими бывают предварительные задачи и методы получения псевдо-разметки.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!