Назад
37

Weight Averaging для улучшения качества нейросети

37

Stohastic Weight Averaging

Давайте для начала разберёмся с методом FGE (Fast Geometric Ensembling). Его суть заключается в использовании циклического learning rate scheduler на последних эпохах и сохранении веса каждый раз, когда learning rate достигает минимума.

Затем мы создаем ансамбль из моделей с сохраненными весами. Все эти модели будут иметь различные веса, но примерно один loss. Такая методика позволяет нам быстро построить ансамбль, метрики у которого лучше, чем у отдельной модели (как показывает статья). У данного подхода есть и недостаток: время предикта у ансамбля кратно больше, чем у отдельной нейросети.

Более быстрым методом являетcя SWA (Stoсhastic Weight Averaging). Он предлагает сделать то же самое, что и FGE, но вместо создания ансамбля — использует модель, в которой веса будут усреднены. Важно отметить: мы должны также, как и в FGE, использовать либо циклический learning rate scheduler, либо любой другой, который не просто двигается в сторону локального минимума, а позволяет найти разные точки около него. Этот метод дает примерно такое же качество, как и FGE, только предсказания происходят кратно быстрее (так как у нас будет не ансамбль из \( n \) моделей, а одна модель, соответственно и быстрее в \( n \) раз). На рисунке ниже изображены веса, полученные при помощи FGE(\( w_1,w_2,w_3 \)) и SWA (\( w_{swa} \)).

Как использовать?

Этот алгоритм очень легко добавить к вашему циклу обучения, если вы используете pytorch lightning. Вам нужно лишь передать callback в trainer.

Важные параметры:

swa_lrs — какой learning rate использовать. Если поставить None, то будет использоваться learning rate оптимизатора;

swa_epoch_start — с какой эпохи начинать сохранять веса.

swa_callback = StochasticWeightAveraging(swa_lrs=5e-4, swa_epoch_start=1)
trainer = pl.Trainer(callbacks=[swa_callback, ...]

На обычном pytorch использовать SWA тоже не сложно, он поддерживается с версии 1.6:

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

Exponential Weight Averaging

Для начала давайте вспомним, что из себя представляет Exponentially Weighted Moving Average:

\( EWMA_t=\alpha*x_t+(1-\alpha)EWMA_{t-1}, \)

где \( \alpha \) — степень сглаживания, а \( x_t \) — значение величины в момент времени \( t \).

Exponential Weight Averaging (EWA) —метод, который использует данную формулу для сглаживания весов при обучении. Он часто дает лучший результат в сравнении с простым обучением нейросети. Также он используется при обучении MobileNet-V3 и EfficientNet.

Как использовать?

Для использования EWA можно обратиться к библиотеке pytorch_ema:

from torch_ema import ExponentialMovingAverage
...
model = ...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9)

... # Обучаем модель несколько эпох
model.train()
for _ in range(20):
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()

И к pytorch_lightning:

from torch_ema import ExponentialMovingAverage
...
class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.criterion = SomeLoss()
        self.encoder = encoder()
        self.head = nn.Sequential(...)
        self.ema = ExponentialMovingAverage(self.encoder.parameters(), decay=0.9)

def on_before_zero_grad(self, *args, **kwargs):
    if self.current_epoch > n:  # Начинаем обновлять после n эпох
        self.ema.update(model.parameters())

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!