Sharpness-Aware Minimization: как улучшить обобщающую способность моделей
Давайте для начала посмотрим на минимумы лосс функции:
Минимумы находятся на одном уровне, но левый лежит в “резкой” яме.
Давайте представим, что мы обучили нейросеть. Хорошая генерализующая способность будет достигнута при таких весах
Мы знаем, что \( L_{train}(w) \) и \( L_{test}(w) \) будут отличаться. Но мы ожидаем, что локальные минимумы этих функций будут находиться рядом, а функции станут похожими.
Картинка ниже изображает вышесказанное. Если мы на нее посмотрим, то мы поймем, что в случае с “резкой” ямой \( L_{train} \) и \( L_{test} \) будут отличаться гораздо сильнее, чем в случае с “нерезкой” ямой, а следовательно, и обобщающая способность будет ниже.
Авторы статьи Sharpness-Aware Minimization предлагают не просто искать минимум функции, а искать его еще и в “нерезкой” яме. Давайте разбираться, что это все значит.
Введем термин “неровность” (в оригинале — sharpness). Интуитивно кажется, что функцию можно считать идеально “ровной” в точке, при которой в ее окрестности все значения функции равны. Соответственно, чем сильнее различие значений функции в окрестности, тем “неровнее” эта функция. Давайте формализуем:
\( S_L(w) = \underset{(||\epsilon||_2\le p)}{\max}L(w+\epsilon) -L(w) \)
То есть мы находим разницу между максимальным значением функции в окрестности (окружности с радиусом \( p \)) и значением функции в данной точке.
Теперь мы хотим минимизировать не только значение лосс функции, но и неровность. Если мы их сложим, то получим новый лосс, который и будем минимизировать:
\( L_{sam}=S_L(w)+L(w)=\underset{(||\epsilon||_2\le p)}{\max}L(w+\epsilon) \)
Можно доказать, что для \( p>0 \) с большой долей вероятности выполняется следующее неравенство (достаточно объемное его доказательство можно найти в приложении к оригинальной статье):
\( L_D(w)\le \underset{(||\epsilon||_2\le p)}{\max}L_S(w+\epsilon) + \alpha ||w||_2^2 \)
где \( L_S \) — лосс на тренировочной выборке , а \( L_D \) — лосс на генеральной совокупности.
Из этой формулы видно, что минимизация \( L_{sam} \) ведет к уменьшению лосс функции на генеральной совокупности (конечно, если мы не забудем про регуляризацию).
Основная проблема в том, что \( w+\epsilon \) — веса с наибольшим лоссом в окрестности, а они нам неизвестны.
Очевидно, что \( \epsilon \) — функция от весов. Соответственно, нам нужно решить оптимизационную проблему:
\( \epsilon(w)=\underset{\epsilon}\argmax(L(w+\epsilon)) \)
Давайте попробуем аппроксимировать \( L(w+\epsilon) \) при помощи ряда Тейлора.
Мы будем раскладывать нашу функцию в окрестностях точки \( w \) и использовать только первые два члена разложения, так как добавление еще одного сделает вычисления гораздо сложнее, ведь нам нужно будет знать третью производную.
Напомним, как выглядит разложение в ряд Тейлора:
\( f(x)=\underset{n=0}\sum \frac{f^{(n)}(x-a)^n}{n!} \)
где \( f^{(n)} \) — \( n \)-ая производная; \( a \) — точка, в которой мы ищем разложение. Соответственно, для первых двух точек формула определяется следующим образом:
\( f(x)=f(a)+f'(a)(x-a) \)
Итак, раскладываем \( L(w+\epsilon) \) и получаем:
\( L(w+\epsilon) =L(w)+L'(w)(w+\epsilon -w ) = L(w)+L'(w)*\epsilon \)
Тогда
\( \epsilon(w)=\underset{\epsilon}\argmax(L(w)+L'(w)*\epsilon)=\underset{\epsilon}\argmax(L'(w)*\epsilon) \)
Переводим в векторную форму:
\( \epsilon(w)=\underset{\epsilon}\argmax(\epsilon^T ∇L(w)) \)
Затем мы вводим обозначение \( g=∇L(w) \).
Следовательно, решением оптимизационной проблемы будет:
\( \epsilon(w)=psign(g)\frac{|g|^{q-1}}{(|g|_q^q)^\frac{1}{k}} \)
\( \frac{1}{q}+\frac{1}{k}=1 \)
Подробнее про него можно прочитать тут. Авторы предлагают использовать \( q=k=2 \), при котором получается лучший результат.
Итак, решив эту проблему, мы можем посчитать градиент нашей новой лосс функции:
\( ∇L_{sam}=∇L(w+\epsilon (w))=\frac {d(w+\epsilon(w))}{dw}∇L(w)|_{w+\epsilon(w)} \) \( ∇L_{sam} =∇L(w)|{w+\epsilon(w)} +∇L(w) \frac {d(\epsilon(w))}{dw}|{w+\epsilon(w)} \)
Современные фреймворки умеют считать выражения такого рода, но, к сожалению, это очень дорогостоящие операции, поэтому авторы предлагают убирать второй член и использовать следующий лосс:
\( ∇L_{sam} =∇L(w)|_{w+\epsilon(w)} \)
Получается, алгоритм SAM будет выглядеть таким образом:
- Считаем градиент исходной лосс функции.
- С его помощью считаем \( \epsilon(w) \).
- Считаем \( ∇L_{sam}(w)+ 2\lambda\sum w \).
- Обновляем веса.
На картинке ниже приведены минимумы, найденные при обучении Resnet при помощи SGD (слева) и SAM (справа).
Большое преимущество подхода — возможность использовать его с любым из алгоритмов градиентного спуска, а это дает улучшение качества предсказаний, как нам показывают авторы статьи. Более того, они добавляют также, что данный метод делает модель крайне устойчивой к шумным данным.
Если вы используете pytorch, то можете попробовать данный метод с помощью этого репозитория. А для pytorch-lightning есть реализация SAM через callback.