Self-supervised learning
Идея
В этом посте мы затронем хайповую сейчас тему — self-supervised learning
; далее — SSL
. Постараемся разобраться с тем, что это такое и зачем оно нужно.
Давайте для начала вспомним, что в типичном supervised-learning сетапе без ограничения общности у нас есть некоторые сырые данные. Их мы передаем экспертам доменной области на разметку. Затем формируется размеченный набор данных, на котором уже в свою очередь обучается модель. Цель supervised-learning’a — получить модель, способную на приемлемом уровне решать какую-то целевую задачу.
Целью же SSL является получение модели, способной хорошо решать задачу извлечения признаков из входных данных. Сетап SSL-задачи выглядит таким образом: как и в supervised аналоге у нас есть сырые данные, но мы не передаем их на разметку экспертам доменной области, а генерируем каким-то дешевым способом разметку “на лету”, прямо из этих сырых данных. На такой дешево сгенерированной разметке решается некоторая задача. А результат — модель, которая научилась извлекать хорошие признаки.
Схематически эти два сетапа представлены на рисунке 1.
Зачем нам нужна такая модель? Чаще всего ее веса используются в качестве точки начальной инициализации при решении уже целевой задачи (например, вместо весов с ImageNet) (см. рисунок 2). То есть модель просто дотюнивают на размеченном экспертами наборе данных под целевую задачу.
И при обычном supervised, и при SSL-подходе мы размечаем данные при помощи экспертов, затем подгоняем модель и проводим цикл обучения. В чем же тогда их различие
Объяснить это можно так:
- Зачастую случается, что веса модели, предобученной на ImageNet’e, оказываются субоптимальными. Происходить это может из-за сильного отличия домена ImageNet’a от вашего целевого домена (например, вы работаете с рентгенограммами, а в ImageNet’e такого домена нет). Сеть, предобученная на ImageNet’e, извлекает признаки, но эти признаки не будут очень релевантными. В случае же использования SSL сеть учится извлекать признаки, характерные конкретно вашим данным.
- Согласно результатам статей, часто на бенчмарк наборах данных типа ImageNet, PASCAL VOC, MS COCO и других схема с SSL-предобучением позволяет достичь бОльшего значения целевой метрики.
- Еще в статьях говорится о том, что при использовании SSL-схемы часто можно добиться того же показателя целевой метрики даже при обращении к гораздо меньшему объему данных. Получается, если раньше вам надо было 10000 изображений для получения целевой метрики 0.8, то с использованием SSL схемы вам, возможно, хватит и 1000. Кроме того, наряду с меньшим объемом данных вам, весьма вероятно, понадобится и меньшее число эпох для достижения нужного показателя целевой метрики.
- Модель, единожды предобученная SSL-задачей, может быть пере-использована как точка начальной инициализации при решении других задач на таком же домене данных.
Тут можно остановиться и вспомнить о статье, авторы которой показывают способность сверточной сети даже со случайными весами из-за своей архитектуры вытаскивать из изображений много полезной информации без каких-либо данных. Имеется в виду, что структура самой нейронной сети сама по себе уже содержит неплохие представления об окружающем мире.
Итак, мы теперь пониманием, что такое SSL и для чего оно нам нужно. А сейчас давайте поглубже изучим терминологию.
Терминология
Нам понадобится 3 основных термина:
- предварительная задача (pretext task) — сама задача SSL с искусственной разметкой. Именно ее решает модель, чтобы научиться извлекать хорошие признаки из данных;
- псевдо-разметка — та самая дешевая разметка, происходящая без участия человека;
- последующая задача (downstream task) — задача, по которой проверяют качество выученных признаков. Как правило, это простые модели (KNN, LinReg, LogReg и другие), обучающиеся на извлекаемых с помощью SSL-модели признаках. Иногда бывает и так, что модель не фиксируется и дообучается целиком.
Давайте еще раз посмотрим на схему SSL-пайплайна на примере ImageNet для визуализации теории:
На рисунке 3 синий заштрихованный прямоугольник — решение предварительной SSL-задачи, на которой модель учится извлекать хорошие признаки; бордовый заштрихованный прямоугольник — оценка того, насколько хорошие признаки модель научилась извлекать.
Процесс проверки качества извлеченных признаков подробно представлен на схеме, предварительная задача же не так очевидна и требует дополнительного пояснения. Поэтому давайте сейчас поподробнее рассмотрим ее типы.
Виды предварительных задач
SSL появился довольно давно (первое упоминание автор нашел в статье Хинтона аж 1992 года), но основной бум в этом разделе глубокого обучения случился в условных 2020 годах.
До современных SOTA-методов в исследованиях упор делался в основном на сам процесс генерации псевдо-разметки. Для изображений были следующие варианты:
- Задача восстановления информации: ключевая идея в удалении какой-либо информации из изображения и тренировка сети для восстановления этой удаленной информации. Например, можно удалить цвет из изображения и превратить его в монохромное, а затем обучить сеть восстановлению цвета (статья, выступление, код). Авторы показывают, что такая предварительная задача неплохо подходит в том случае, если наша целевая задача является задачей семантической сегментации, поскольку для восстановления цвета сцены необходимо понимать семантику изображения и границы объектов на этом изображении. Еще есть идея маскирования части изображения и обучение сети восстановлению этой части (ранний вариант, современная SOTA).
- Изучение пространственного контекста: основной смысл в обучении сети пониманию относительных позиций и ориентации объектов на сцене. Например, в статье авторы генерировали разметку простым поворотом изображений и в качестве псевдо-разметки использовали класс поворота (0 — поворот на 0 градусов, 1 — поворот на 90 градусов, 2 — на 180 градусов и так далее). А в этой статье авторы предложили предсказывать относительную позицию двух случайных патчей из изображения.
- Группирование похожих изображений вместе: идея обучения сети группировать схожие изображения и, следовательно, извлекать хорошие признаки. Есть целое семейство методов, в которых для этого адаптирован алгоритм K-Means. Например, раз и два.
- Генеративные модели: попытка адаптировать веса предобученных автоэнкодеров и GAN’ов для их использования в процессе решения целевой задачи.
- Multi-view invariance методы: в основе лежит идея о том, что два вектора-признака, полученных из одного изображения путем применения различных трансформаций (чаще всего аугментаций) должны быть похожими (contrastive learning).
Подробнее про некоторые из таких “старых” методов можно прочитать в посте Дьяконова.
В современных SOTA-методах процедура генерации псевдо-разметки в основном сводится к двум вариантам:
- Multi-view invariance: здесь псевдо-разметка формируется по принципу contrastive learning, то есть позитивными примерами являются два по-разному аугментированных варианта одного и того же изображения, а негативными — аугментированные варианты другого изображения. Пример одного из таких методов можно посмотреть на рисунке 4 (а статью — тут).
2. Задача восстановления информации: из изображения удаляются некоторые патчи, а сеть учится восстанавливать эту информацию. Пример представлен на рисунке 5 (статью можно найти тут). В таком случае псевдо-разметка состоит из пар (X, Y), где Y — исходное изображение, а Х — маскированная версия этого изображения, в котором удалена часть информации.
У нас всего в основном две процедуры генерации псевдо-разметки. Чем же тогда различаются современные методы? Архитектурами, функциями потерь и различными хитростями, используемыми при обучении. В статье A Cookbook of Self-Supervised Learning авторы выделяют четыре больших семейства современных методов SSL:
- Методы, основанные на metric learning;
- Методы, основанные на self-distillation;
- Методы, основанные на каноническом корреляционном анализе;
- Методы, основанные на задачах восстановления информации.
Про каждый пункт можно и нужно рассказывать подробно, поэтому мы оставим этот обзор для последующих статей серии 🙂
Заключение
Эта статья обзорная, но нам хотелось, чтобы после ее прочтения вам запомнилось:
- Что такое SSL;
- Терминология;
- Общая схема работы SSL-пайплайна;
- Где SSL может пригодиться;
- Какими бывают предварительные задачи и методы получения псевдо-разметки.