Weight Averaging для улучшения качества нейросети
Stohastic Weight Averaging
Давайте для начала разберёмся с методом FGE (Fast Geometric Ensembling). Его суть заключается в использовании циклического learning rate scheduler на последних эпохах и сохранении веса каждый раз, когда learning rate достигает минимума.
Затем мы создаем ансамбль из моделей с сохраненными весами. Все эти модели будут иметь различные веса, но примерно один loss. Такая методика позволяет нам быстро построить ансамбль, метрики у которого лучше, чем у отдельной модели (как показывает статья). У данного подхода есть и недостаток: время предикта у ансамбля кратно больше, чем у отдельной нейросети.
Более быстрым методом являетcя SWA (Stoсhastic Weight Averaging). Он предлагает сделать то же самое, что и FGE, но вместо создания ансамбля — использует модель, в которой веса будут усреднены. Важно отметить: мы должны также, как и в FGE, использовать либо циклический learning rate scheduler, либо любой другой, который не просто двигается в сторону локального минимума, а позволяет найти разные точки около него. Этот метод дает примерно такое же качество, как и FGE, только предсказания происходят кратно быстрее (так как у нас будет не ансамбль из
Как использовать?
Этот алгоритм очень легко добавить к вашему циклу обучения, если вы используете pytorch lightning. Вам нужно лишь передать callback в trainer.
Важные параметры:
swa_lrs
— какой learning rate использовать. Если поставить None, то будет использоваться learning rate оптимизатора;
swa_epoch_start
— с какой эпохи начинать сохранять веса.
swa_callback = StochasticWeightAveraging(swa_lrs=5e-4, swa_epoch_start=1)
trainer = pl.Trainer(callbacks=[swa_callback, ...]
На обычном pytorch использовать SWA тоже не сложно, он поддерживается с версии 1.6:
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
for epoch in range(100):
for input, target in loader:
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
if epoch > swa_start:
swa_model.update_parameters(model)
swa_scheduler.step()
else:
scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
Exponential Weight Averaging
Для начала давайте вспомним, что из себя представляет Exponentially Weighted Moving Average:
\( EWMA_t=\alpha*x_t+(1-\alpha)EWMA_{t-1}, \)
где \( \alpha \) — степень сглаживания, а \( x_t \) — значение величины в момент времени \( t \).
Exponential Weight Averaging (EWA) —метод, который использует данную формулу для сглаживания весов при обучении. Он часто дает лучший результат в сравнении с простым обучением нейросети. Также он используется при обучении MobileNet-V3 и EfficientNet.
Как использовать?
Для использования EWA можно обратиться к библиотеке pytorch_ema:
from torch_ema import ExponentialMovingAverage
...
model = ...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
... # Обучаем модель несколько эпох
model.train()
for _ in range(20):
logits = model(x_train)
loss = F.cross_entropy(logits, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
ema.update()
И к pytorch_lightning:
from torch_ema import ExponentialMovingAverage
...
class SomeModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.criterion = SomeLoss()
self.encoder = encoder()
self.head = nn.Sequential(...)
self.ema = ExponentialMovingAverage(self.encoder.parameters(), decay=0.9)
def on_before_zero_grad(self, *args, **kwargs):
if self.current_epoch > n: # Начинаем обновлять после n эпох
self.ema.update(model.parameters())