Назад
60

Sharpness-Aware Minimization: как улучшить обобщающую способность моделей

60

Давайте для начала посмотрим на минимумы лосс функции:

Минимумы находятся на одном уровне, но левый лежит в “резкой” яме.

Давайте представим, что мы обучили нейросеть. Хорошая генерализующая способность будет достигнута при таких весах \( w \), когда на тренировочной выборке ошибка \( L_{train} \) не сильно отличается от ошибки на тестовой \( L_{test} \).

Мы знаем, что \( L_{train}(w) \) и \( L_{test}(w) \) будут отличаться. Но мы ожидаем, что локальные минимумы этих функций будут находиться рядом, а функции станут похожими.

Картинка ниже изображает вышесказанное. Если мы на нее посмотрим, то мы поймем, что в случае с “резкой” ямой \( L_{train} \) и \( L_{test} \) будут отличаться гораздо сильнее, чем в случае с “нерезкой” ямой, а следовательно, и обобщающая способность будет ниже.

Авторы статьи Sharpness-Aware Minimization предлагают не просто искать минимум функции, а искать его еще и в “нерезкой” яме. Давайте разбираться, что это все значит.

Введем термин “неровность” (в оригинале — sharpness). Интуитивно кажется, что функцию можно считать идеально “ровной” в точке, при которой в ее окрестности все значения функции равны. Соответственно, чем сильнее различие значений функции в окрестности, тем “неровнее” эта функция. Давайте формализуем:

\( S_L(w) = \underset{(||\epsilon||_2\le p)}{\max}L(w+\epsilon) -L(w) \)

То есть мы находим разницу между максимальным значением функции в окрестности (окружности с радиусом \( p \)) и значением функции в данной точке.

Теперь мы хотим минимизировать не только значение лосс функции, но и неровность. Если мы их сложим, то получим новый лосс, который и будем минимизировать:

\( L_{sam}=S_L(w)+L(w)=\underset{(||\epsilon||_2\le p)}{\max}L(w+\epsilon) \)

Можно доказать, что для \( p>0 \) с большой долей вероятности выполняется следующее неравенство (достаточно объемное его доказательство можно найти в приложении к оригинальной статье):

\( L_D(w)\le \underset{(||\epsilon||_2\le p)}{\max}L_S(w+\epsilon) + \alpha ||w||_2^2 \)

где \( L_S \) — лосс на тренировочной выборке , а \( L_D \) — лосс на генеральной совокупности.

Из этой формулы видно, что минимизация \( L_{sam} \) ведет к уменьшению лосс функции на генеральной совокупности (конечно, если мы не забудем про регуляризацию).

Основная проблема в том, что \( w+\epsilon \) — веса с наибольшим лоссом в окрестности, а они нам неизвестны.

Очевидно, что \( \epsilon \) — функция от весов. Соответственно, нам нужно решить оптимизационную проблему:

\( \epsilon(w)=\underset{\epsilon}\argmax(L(w+\epsilon)) \)

Давайте попробуем аппроксимировать \( L(w+\epsilon) \) при помощи ряда Тейлора.

Мы будем раскладывать нашу функцию в окрестностях точки \( w \) и использовать только первые два члена разложения, так как добавление еще одного сделает вычисления гораздо сложнее, ведь нам нужно будет знать третью производную.

Напомним, как выглядит разложение в ряд Тейлора:

\( f(x)=\underset{n=0}\sum \frac{f^{(n)}(x-a)^n}{n!} \)

где \( f^{(n)} \) — \( n \)-ая производная; \( a \) — точка, в которой мы ищем разложение. Соответственно, для первых двух точек формула определяется следующим образом:

\( f(x)=f(a)+f'(a)(x-a) \)

Итак, раскладываем \( L(w+\epsilon) \) и получаем:

\( L(w+\epsilon) =L(w)+L'(w)(w+\epsilon -w ) = L(w)+L'(w)*\epsilon \)

Тогда

\( \epsilon(w)=\underset{\epsilon}\argmax(L(w)+L'(w)*\epsilon)=\underset{\epsilon}\argmax(L'(w)*\epsilon) \)

Переводим в векторную форму:

\( \epsilon(w)=\underset{\epsilon}\argmax(\epsilon^T ∇L(w)) \)

Затем мы вводим обозначение \( g=∇L(w) \).

Следовательно, решением оптимизационной проблемы будет:

\( \epsilon(w)=psign(g)\frac{|g|^{q-1}}{(|g|_q^q)^\frac{1}{k}} \)

\( \frac{1}{q}+\frac{1}{k}=1 \)

Подробнее про него можно прочитать тут. Авторы предлагают использовать \( q=k=2 \), при котором получается лучший результат.

Итак, решив эту проблему, мы можем посчитать градиент нашей новой лосс функции:

\( ∇L_{sam}=∇L(w+\epsilon (w))=\frac {d(w+\epsilon(w))}{dw}∇L(w)|_{w+\epsilon(w)} \) \( ∇L_{sam} =∇L(w)|{w+\epsilon(w)} +∇L(w) \frac {d(\epsilon(w))}{dw}|{w+\epsilon(w)} \)

Современные фреймворки умеют считать выражения такого рода, но, к сожалению, это очень дорогостоящие операции, поэтому авторы предлагают убирать второй член и использовать следующий лосс:

\( ∇L_{sam} =∇L(w)|_{w+\epsilon(w)} \)

Получается, алгоритм SAM будет выглядеть таким образом:

  1. Считаем градиент исходной лосс функции.
  2. С его помощью считаем \( \epsilon(w) \).
  3. Считаем \( ∇L_{sam}(w)+ 2\lambda\sum w \).
  4. Обновляем веса.

На картинке ниже приведены минимумы, найденные при обучении Resnet при помощи SGD (слева) и SAM (справа).

Большое преимущество подхода — возможность использовать его с любым из алгоритмов градиентного спуска, а это дает улучшение качества предсказаний, как нам показывают авторы статьи. Более того, они добавляют также, что данный метод делает модель крайне устойчивой к шумным данным.

Если вы используете pytorch, то можете попробовать данный метод с помощью этого репозитория. А для pytorch-lightning есть реализация SAM через callback.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!