Назад
69

FastVit, Apple

69

💙 Небольшое обращение к читателям В посте есть тогглы с кодом на случай, если захочется разобраться в деталях. Их можно пропускать — понимание общей картины не ухудшится. Вопрос к аудитории: поделитесь, пожалуйста, удобен ли вам формат со вставками кода и полезны ли они? Мы открыты к идеям по улучшению подачи материала. Вы также можете предложить свои темы для статей в комментариях под нашим постом в телеграме. Так мы сможем учитывать ваши пожелания по материалу, который было бы интересно изучать 😊

Пререквизиты

Что можно прочитать, чтобы лучше и проще понять эту статью:

  1. FastVit — прямое продолжение идей MobileOne, поэтому крайне рекомендуем сначала ознакомиться с этим постом. Важно разобраться со следующими моментами (далее они будут использоваться без дополнительных пояснений):
    1. Термины: FLOPs, NMP, MAC, DOP, latency;
    2. Преимущества и недостатки multi-branch компонент в архитектурных блоках;
    3. Как репараметризация нивелирует недостатки multi-branch.
  2. В статье будет часто упоминаться архитектура трансформера в контексте CV. Есть пост с детальным разбором ViT: там раскрываются основные моменты работы трансформеров в целом, и в CV в частности.
  3. Дополнительно к пункту 2 — разбор работы Attention.

Введение

Архитектуры на базе трансформеров уже достигли SOTA-качества не только в родном для них NLP, но и в Computer Vision. Они равны или даже превосходят по качеству CNN модели в задачах классификации, детекции и сегментации. При этом традиционно такие модели требуют много вычислительных ресурсов как во время обучения, так и во время инференса. Авторы статьи FastViT ставили задачу совместить сильные стороны свёрточных и трансформерных архитектур для решения широкого круга CV задач в real-time и мобильных сценариях работы.

ViT: напоминание

Для начала вспомним, как устроена ViT. Модель состоит из 3-х основных блоков:

  • Linear Projection / Stem переводит картинку в эмбеддинг для трансформера;
  • Transformer Encoder — энкодер трансформера;
  • MLP Head — полносвязный слой классификатора.
Рисунок 1. Архитектура FastViT. Слева: архитектура ViT целиком, справа: схема трансформер энкодера

Рассмотрим основную часть модели ViT — энкодер трансформера. Он состоит из L трансформер-блоков. Каждый из блоков получает на вход и возвращает на выход эмбеддинги патчей, при этом число эмбеддингов и их размерность не меняется при прохождении через все энкодер-блоки сети.

Multi-Head Self-Attention отвечает за учет локальной и глобальной информации вместо сверток — это и есть концептуальное отличие трансформера от CNN архитектур. Делает он это путем перевзвешивания эмбеддингов скорами, которые меняются в зависимости от входных данных. Другими словами, Self-Attention осуществляет обмен информации между пространственными токенами. Поэтому его еще называют token mixer. Это важный момент, который пригодится в дальнейшем.

Также стоит немного поговорить о вычислительной сложности слоя Self-Attention. Она квадратная или O(n^2), где N — количество токенов в последовательности (количество патчей), так как мы должны посчитать скор попарно между всеми токенами. К тому же, не будем забывать, что Query, Key и Value — обучаемые параметры, которые вносят свой вклад в увеличение latency через MAC.

Таким образом, вычислительная сложность Self-Attention — это боттлнек, если мы хотим спроектировать модель, которая сможет работать на мобильных устройствах и несильно мощных CPU-ядрах.

MobileOne block: напоминание

MobileOne блок — основной блок в архитектуре модели MobileOne. Он представляет собой multi-branch из:

  1. трех веток, если kernel size = 3
    1. 1×1 depth wise свертка + слой batch normalization;
    2. 3×3 depth wise свертка + слой batch normalization;
    3. слой batch normalization.
  2. двух веток, если kernel size = 1
    1. 1×1 depth wise свертка + слой batch normalization;
    2. слой batch normalization.

После конкатенации бранчей идет функция активации. В исходной модели MobileOne это была ReLU (авторы проводили целое исследование о зависимости latency от функции активации, где и победила ReLU).

Главной особенностью блока является возможность его репараметризации — преобразования весов бранчей в веса одной свертки путем простой арифметики. Подробнее про это можно прочитать в разборе MobileOne. После репараметризации имеем:

  1. multi-branch из трех веток переходит в 3×3 свертку + batch norm;
  2. multi-branch из двух веток переходит в 1×1 свертку + batch norm.

После репараметризации каждого блока функция активации сохраняется.

В конце напоминания отметим ключевые идеи блока MobileOne, которые затем будут использоваться в FastViT:

  • multi-branch структура дает прирост в качестве;
  • depthwise свертки в связке с pointwise, так как они позволяют экономить в вычислениях и рассматривать пространственную информацию отдельно от информации в каналах;
  • репараметризация multi-branch структуры позволяет сохранить прирост в качестве, не увеличивая latency во время инференса.
Рисунок 2. Архитектура MobileOne stage, состоящая из двух блоков с разной структурой во время обучения и инференса. Слева: блоки с дополнительными бранчами во время обучения. Справа: блоки без них во время инференса. k — гиперпараметр, который отдельно подбирался к каждому варианту MobileOne. Act — функция активации ReLU
Код MobileOne блока
import torch
import torch.nn as nn


class MobileOneBlock(nn.Module):
		"""
    Имплементация MobileOne блока.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 inference_mode: bool = False,
                 use_se: bool = False,
                 num_conv_branches: int = 1) -> None:
        """
        :param in_channels: Number of channels in the input.
        :param out_channels: Number of channels produced by the block.
        :param kernel_size: Size of the convolution kernel.
        :param stride: Stride size.
        :param padding: Zero-padding size.
        :param dilation: Kernel dilation factor.
        :param groups: Group number.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU activations.
        :param num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Можно использовать ещё squeeze-excitation блок
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()
        self.activation = nn.ReLU()
				
				# Во время инференса используем просто одну свёртку, которая заменяет
				# остальные слои.
        if inference_mode:
            self.reparam_conv = nn.Conv2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=True)
        else:
            # Репараметризуемый skip connection.
            self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None

            # Репараметризуемая ветка со свёртками.
            rbr_conv = list()
            for _ in range(self.num_conv_branches):
                rbr_conv.append(self._conv_bn(kernel_size=kernel_size,
                                              padding=padding))
            self.rbr_conv = nn.ModuleList(rbr_conv)

            # Репараметризуемая ветка со scale. Обращу внимание, что если изначально
						# kernel size = 1, то этой ветки не будет в блоке.
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1,
                                               padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply forward pass. """
        # Если инференс, то просто возвращаем сразу ответ.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Мульти-бранч train-time forward pass.
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """
				Репараметризуем в один свёрточный слой для эффективного инференса.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels,
                                      out_channels=self.rbr_conv[0].conv.out_channels,
                                      kernel_size=self.rbr_conv[0].conv.kernel_size,
                                      stride=self.rbr_conv[0].conv.stride,
                                      padding=self.rbr_conv[0].conv.padding,
                                      dilation=self.rbr_conv[0].conv.dilation,
                                      groups=self.rbr_conv[0].conv.groups,
                                      bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Удаляем ненужные слои после репараметризации.
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_conv')
        self.__delattr__('rbr_scale')
        if hasattr(self, 'rbr_skip'):
            self.__delattr__('rbr_skip')

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ 
				
				По отдельности суммируем kernel, bias с каждого бранча.

        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale,
                                                   [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """

				Получение kernel, bias результирующей свёртки из весов исходной свёртки 
				и Batch Norm.

        :param branch:
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros((self.in_channels,
                                            input_dim,
                                            self.kernel_size,
                                            self.kernel_size),
                                           dtype=branch.weight.dtype,
                                           device=branch.weight.device)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim,
                                 self.kernel_size // 2,
                                 self.kernel_size // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self,
                 kernel_size: int,
                 padding: int) -> nn.Sequential:
        """
	
				Вспомогательный блок свёртка + Batch Norm.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=self.out_channels,
                                              kernel_size=kernel_size,
                                              stride=self.stride,
                                              padding=padding,
                                              groups=self.groups,
                                              bias=False))
        mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list

Hybrid Vision Transformers

Собственно отличие обычного vision transformer от hybrid (гибридный, т.е. совмещающий в себе признаки нескольких объектов) — это совмещение дизайна свёрточных сетей и трансформера. Например:

  1. Использование CNN для разбиения картинки на патчи — это было в самом ViT;
  2. Использование блоков со сверточными слоями перед блоками с self-attention — “CoAtNet: Marrying Convolution and Attention for All Data Sizes”;
  3. Использование MLP как token mixer вместо Self-Attention — “MetaFormer Is Actually What You Need for Vision”.

Последнюю статью мы разберем подробно, так как она предложила новое понимание того, как можно улучшать трансформерные модели, и собственно саму базовую архитектуру для улучшения.

MetaFormer

Авторы статьи “MetaFormer Is Actually What You Need for Vision” решили разобраться с тем, что обеспечивает высокую точность vision трансформеров.

Вместо того, чтобы фокусироваться на экспериментах с механизмом внимания в Self-Attention, они предложили обратиться к самой структуре энкодера трансформера. Сделали они это довольно эффектно — заменили его на обычный Pooling (слой без параметров, который выполняет простейшие операции усреднения или поиск максимума). И такая модель не уступает по точности энкодеру с Self-Attention!

Таким образом, авторы пришли к идее MetaFormer: реализации не конкретной архитектуры, а структуре целого класса архитектур, которые отличаются друг от друга token mixer’ом. Например, если token mixer — это:

  • Self-Attention, то получится ViT, DeiT;
  • MLP-like, то получится ResMLP;
  • Pooling, то получится PoolFormer (его предлагают авторы статьи).
Рисунок 3. (а) схемы Transformer, MLP-like и PoolFormer, а также их обобщение — MetaFormer; (b) сравнение моделей-представителей схем из (a) в разрезе количества параметров, MAC и ImageNet Top-1 Accuracy (%). Из (b) видно: PoolFormer превосходит остальные модели по всем характеристикам

Все, что мы обсудили, не говорит об отсутствии важности token mixer: он по-прежнему занимает своё место в структуре MetaFormer. Важно то, что его выбор не ограничивается вариантами механизма внимания.

PoolFormer: архитектура

Поговорим подробнее о PoolFormer, так как ресерчеры из Apple взяли его как baseline для дальнейших улучшений.

PoolFormer включает 4 стадии, каждая из которых состоит из Patch Embedding блока и набора PoolFormer блоков.

  1. Patch Embedding блок разбивает изображение на патчи. Также именно он отвечает за уменьшение размера изображения в 4, 8, 16, 32 раза в зависимости от стадии;
Код Patch Embedding блока
import torch.nn as nn


class PatchEmbed(nn.Module):
    """
    Patch Embedding that is implemented by a layer of conv. 
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """
    def __init__(self, patch_size=16, stride=16, padding=0, 
                 in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, 
                              stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x
  1. PoolFormer блок имеет структуру, как у трансформер-энкодера с заменой Self-Attention на Average Pooling. В качестве слоев нормализации используется Modified Layer Normalization (GroupNorm с group number = 1, или просто Instance Normalization). Авторы часто используют слои DropOut и DropPath — все также, как и в оригинальном ViT.
Рисунок 4. Различные виды нормализации. Они отличаются друг от друга выделением частей тензора для сбора статистик и последующей нормализации
Код сравнения Layer Normalization c Modified Layer Normalization (Group Normalization с group number = 1)
import torch
import torch.nn as nn


class LayerNormChannel(nn.Module):
		"""
		Vanilla Layer Normalization нормализует тензоры только по канальной размерности.
		Input: tensor in shape [B, C, H, W].
		"""
		def __init__(self, num_channels, eps=1e-05):
				super().__init__()
				# Размер обучаемых параметров - [num_channels, ].
				self.weight = nn.Parameter(torch.ones(num_channels))
				self.bias = nn.Parameter(torch.zeros(num_channels))
				self.eps = eps

		def forward(self, x):
				u = x.mean(1, keepdim=True) # Считаем средние по 1 размерности, то есть по C.
				s = (x - u).pow(2).mean(1, keepdim=True) # Считаем дисперсии тоже по C.
				x = (x - u) / torch.sqrt(s + self.eps)
				x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
				+ self.bias.unsqueeze(-1).unsqueeze(-1)
				return x


class ModifiedLayerNorm(nn.Module):
		"""
		Modified Layer Normalization нормализует тензоры по канальной и пространственной размерностям.
		Input: tensor in shape [B, C, H, W]
		"""
		def __init__(self, num_channels, eps=1e-05):
				super().__init__()
				# Размер обучаемых параметров также [num_channels, ].
				Normalization.
				self.weight = nn.Parameter(torch.ones(num_channels))
				self.bias = nn.Parameter(torch.zeros(num_channels))
				self.eps = eps

		def forward(self, x):
				u = x.mean([1, 2, 3], keepdim=True) # Считаем средние по 1, 2 и 3 размерности [C, H, W].
				s = (x - u).pow(2).mean([1, 2, 3], keepdim=True) # Считаем дисперсии так же.
				spatial dimensions.
				x = (x - u) / torch.sqrt(s + self.eps)
				x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
				+ self.bias.unsqueeze(-1).unsqueeze(-1)
				return x
Код MLP
import torch
import torch.nn as nn


class Mlp(nn.Module):
    """
	  Имплементация MLP со свёртками 1*1, которые заменяют линейные слои.
		Как всегда dropout важен !
    Input: тензор размерности [B, C, H, W]
    """
    def __init__(self, in_features, hidden_features=None, 
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
Код PoolFormerBlock. В Pooling происходит вычитание входного тензора, так как за этим слоем идет skip connection.
import torch.nn as nn
import torch.nn.GroupNorm as GroupNorm

from modules import Mlp


class Pooling(nn.Module):
    """
    Имплементация Pooling слоя.
    --pool_size: pooling size
    """
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size//2, count_include_pad=False)

    def forward(self, x):
				# Здесь нужно вычесть входной тензор.
        return self.pool(x) - x


class PoolFormerBlock(nn.Module):
    """
    Имплементация PoolFormer блока.
    --dim: embedding dim
    --pool_size: pooling size
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth, 
        refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale, 
        refer to https://arxiv.org/abs/2103.17239
    """
    def __init__(self, dim, pool_size=3, mlp_ratio=4., 
                 act_layer=nn.GELU, norm_layer=GroupNorm, 
                 drop=0., drop_path=0., 
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
                       act_layer=act_layer, drop=drop)

        # Следующие две техники полезны для обучения более глубоких вариантов PoolFormer.
				# DropPath нужен для регуляризации.
				# Layer Scale помогает менять масштаб.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Важный момент: в отличие от Self-Attention, Pooling и свертки работают с полями изображения k x k, поэтому чтобы формировать не только локальные признаки, но и глобальные, нам нужно воспользоваться довольно известным приемом из Computer Vision. А именно — будем уменьшать размеры изображения, так как рецептивное поле получающихся пикселей будет приближаться к размерам всей картинки.

Также поэтому выгодно использовать больше каналов ближе к концу архитектуры, на третьей стадии: L/6 + L/6 + L/2 + L/6, где L — общее число блоков.

Рисунок 5. Архитектура PoolFormer. Слева: архитектура целиком, справа: схема PoolFormer блока с Pooling token mixer

FastViT = PoolFormer с улучшениями

Теперь, после разбора корней FastVit можно наконец-то перейти к рассказу о нем самом 😊

В качестве baseline ресерчеры из Apple взяли PoolFormer-S12 — самую небольшую модель из семейства. Затем они стали делать следующие последовательные улучшения, то есть предыдущие улучшения сохраняются и подсчет NMP, FLOPs, latency и ImageNet Top-1 Accuracy % происходит у новых версий:

  1. Подняли входное разрешение изображений с 224 до 256, что увеличило и точность, и latency;
  2. Pooling, все-таки, слишком простой слой для агрегации, поэтому его заменили на блок со свертками;
  3. Факторизовали стандартные свертки: заменили их на связку depthwise + pointwise сверток. Также они добавили overparametrization во время обучения — дополнительные бранчи, которые репараметризуются во время инференса. Это напрямую перешло к FastViT из MobileOne;
  4. Использовали большие свертки (7×7) в блоках ConvFFN и Patch Embeddings. Подробнее о том, что это такое и зачем нужны большие свертки будет рассказано ниже.
Рисунок 6. Таблица с последовательными улучшениями PoolFormer-S12

Далее подробно расскажем об отдельных частях FastViT.

Важно: ниже в блоках используется такая же репараметризация во время инференса, как и в MobileOne.

Stem

Для улучшение качества и скорости работы Stem’a авторы собрали его из 3-х MobileOne блоков, при этом в первом блоке они использовали стандартные свертки, а последние два блока представляют ее факторизацию — разложение на depthwise + pointwise свертки.

Рисунок 7. Архитектура Stem блока. Слева: блок с дополнительными бранчами во время обучения. Справа: блок без них во время инференса. Activation — функция активации GELU

Как такая структура соотносится с тем, что было в MobileOne? По сути связка последних двух блоков в Stem является stage’м в MobileOne при гиперпараметре k = 1. Однако есть одно отличие — в FastViT используется GELU как функция активации вместо ReLU. Авторы путем экспериментов выяснили, что это улучшает итоговое качество работы FastViT.

Код Stem блока
import torch
import torch.nn as nn

from modules import MobileOneBlock


def convolutional_stem(
    in_channels: int, out_channels: int, inference_mode: bool = False
) -> nn.Sequential:
    """
		Имплементация Stem на основе MobileOne блоков.
    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``

    Returns:
        nn.Sequential object with stem elements.
    """
    return nn.Sequential(
        MobileOneBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
				# Depthwise свёртки.
        MobileOneBlock(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
				# Pointwise свёртка.
        MobileOneBlock(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
    )

RepMixer и ConvFFN

По своей сути RepMixer — это первая из двух частей в общей структуре MetaFormer. В качестве token mixer’а используется MobileOne блок.

Рисунок 8. RepMixer блок. Слева: общая структура MetaFormer. Справа: архитектура RepMixer блока
Код RepMixer блока
import torch
import torch.nn as nn

from modules import MobileOneBlock


class RepMixer(nn.Module):
    """
		Имплементация репараметризуемого token mixer.
    """
    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """
        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode
				
				# Во время инференса используем просто одну свёртку, которая заменяет
				# остальные слои.
        if inference_mode:
						# 
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
						# Можно убрать основные свёрточные слои из MobileOne блока,
						# оставив только scale и Batch Norm.
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
            )
						# Трюк для обучения более глубоких версий: будем учить scale.
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """
				Репараметризуем mixer в один свёрточный слой для эффективного инференса.
        """
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()
				
				# Получение kernel, bias результирующей свёртки из весов исходной свёртки 
				# и Batch Norm.
        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b
				
				# Удаляем ненужные слои после репараметризации.
        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")

Вторая же часть — это ConvFFN вместо обычного MLP.

Как мы уже обсуждали выше, свертки могут учитывать только локальную информацию, в отличие от Self-Attention. Вычислительно эффективным способом увеличения receptive field для сверток является использование depthwise сверток с большим kernel size. Путем экспериментов авторы пришли к тому, что свертки 7×7 дают лучший баланс в соотношении FLOPs, latency и NMP итоговой модели.

Код ConvFFN блока
import torch
import torch.nn as nn

from torch.nn.init import trunc_normal_


class ConvFFN(nn.Module):
    """
		Имплементация свёрточного Feed-Forward блока.
		"""
    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """

        Args:
            in_channels: Number of input channels.
            hidden_channels: Number of channels after expansion. Default: None
            out_channels: Number of output channels. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = nn.Sequential()
        self.conv.add_module(
            "conv",
						# Снова Depthwise свёртка.
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
Рисунок 9. Архитектура ConvFFN блока. Activation — функция активации GELU

Stage 1 — 3

Всего FastViT состоит из 4-х стадий — также, как и PoolFormer. 1, 2 и 3 стадии имеют одинаковую структуру, состоящую из пары RepMixer + ConvFFN. Они представляют собой вариацию архитектуры MetaFormer.

Рисунок 10. Архитектура Stage 1, 2, 3 блоков. Как и всегда, слева: train вид, справа: инференс

Patch Embedding

В отличие от стандартного ViT, полного перехода к эмбеддингам в FastViT нет, поэтому патчи здесь — это фрагменты входного изображения, определяемые ядрами сверток. Суть Patch Embedding слоев заключается в уменьшении пространственных размеров тензора и увеличении receptive field: свертки 7×7 выполняют здесь ту же роль, что и в ConvFFN.

BatchNorm используется вместо привычного для трансформеров LayerNorm, так как его можно легко репараметризовать вместе со свертками, в отличие от LayerNorm.

Код Large Kernel Conv
import torch
import torch.nn as nn


class ReparamLargeKernelConv(nn.Module):
    """
		Имплементация репараметризуемого Large Kernel свёрточного блока.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: int,
        inference_mode: bool = False,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size. Default: 1
            groups: Group number. Default: 1
            small_kernel: Kernel size of small kernel conv branch.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            activation: Activation module. Default: ``nn.GELU``
        """
        super(ReparamLargeKernelConv, self).__init__()

        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2
        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)

        self.activation(out)
        return out

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
				
				По отдельности суммируем kernel, bias с каждого бранча.

        Returns:
            (kernel, bias) после слияния бранчей.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Репараметризуем в один свёрточный слой для эффективного инференса.
        """
        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: torch.Tensor, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """

				Получение kernel, bias результирующей свёртки из весов исходной свёртки 
				и Batch Norm.

        Args:
            conv: Convolutional kernel weights.
            bn: Batchnorm 2d layer.

        Returns:
            (kernel, bias) после слияния свёртки и Batch Norm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """

				Вспомогательный блок свёртка + Batch Norm.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            A nn.Sequential Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
Код Patch Embedding
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """
		Имплементация свёрточного patch embedding слоя.
		"""
    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
        inference_mode: bool = False,
    ) -> None:
        """

        Args:
            patch_size: Patch size for embedding computation.
            stride: Stride for convolutional embedding layer.
            in_channels: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
        """
        super().__init__()
        block = list()
        block.append(
            ReparamLargeKernelConv(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=stride,
                groups=in_channels,
                small_kernel=3,
                inference_mode=inference_mode,
            )
        )
        block.append(
            MobileOneBlock(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                inference_mode=inference_mode,
                use_se=False,
                num_conv_branches=1,
            )
        )
        self.proj = nn.Sequential(*block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x
Рисунок 11. Архитектура Patch Embeddings. Слева: блок с дополнительными бранчами во время обучения. Справа: блок без них во время инференса

Stage 4

Последняя 4-я стадия отличается от первых 3-х наличием условного позиционного кодирования и использования Self-Attention, как token mixer в MetaFormer.

Условное позиционное кодирование (Conditional Positional Encoding, CPE) отличается от фиксированных или обучаемых позиционных кодировок тем, что генерируется динамически и зависит от локальной близости входных токенов. Таким образом, CPE стремится обобщить входные последовательности, которые длиннее тех, что модель когда-либо видела во время обучения. Как и остальные блоки, CPE репараметризуется во время инференса.

Использование Self-Attention на последней стадии обусловлено повышением итогового качества и незначительным повышением latency, поэтому было решено его оставить.

Рисунок 12. Архитектура 4-ой стадии. Слева: блок с дополнительными бранчами во время обучения. Справа: блок без них во время инференса. Обратите внимание, репараметризуется только CPE
Код Conditional Positional Encodings
import torch
import torch.nn as nn


class RepCPE(nn.Module):
    """

		Имлементация репараметризуемого conditional positional encoding слоя.

		"""

    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode=False,
    ) -> None:
        """
        Args:
            in_channels: Number of input channels.
            embed_dim: Number of embedding dimensions. Default: 768
            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super(RepCPE, self).__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        assert isinstance(spatial_shape, Tuple), (
            f'"spatial_shape" must by a sequence or int, '
            f"get {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.groups = embed_dim

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.embed_dim,
                kernel_size=self.spatial_shape,
                stride=1,
                padding=int(self.spatial_shape[0] // 2),
                groups=self.embed_dim,
                bias=True,
            )
        else:
            self.pe = nn.Conv2d(
                in_channels,
                embed_dim,
                spatial_shape,
                1,
                int(spatial_shape[0] // 2),
                bias=True,
                groups=embed_dim,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

    def reparameterize(self) -> None:
        # Строим эквивалентный Id тензор.
        input_dim = self.in_channels // self.groups
        kernel_value = torch.zeros(
            (
                self.in_channels,
                input_dim,
                self.spatial_shape[0],
                self.spatial_shape[1],
            ),
            dtype=self.pe.weight.dtype,
            device=self.pe.weight.device,
        )
        for i in range(self.in_channels):
            kernel_value[
                i,
                i % input_dim,
                self.spatial_shape[0] // 2,
                self.spatial_shape[1] // 2,
            ] = 1
        id_tensor = kernel_value

        # Репараметризуем Id тнезор и свёртку.
        w_final = id_tensor + self.pe.weight
        b_final = self.pe.bias

        # Строим репараметризованый свёрточный слой.
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=int(self.spatial_shape[0] // 2),
            groups=self.embed_dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for para in self.parameters():
            para.detach_()
        self.__delattr__("pe")
Код Multi-Head Self-Attention
import torch
import torch.nn as nn


class MHSA(nn.Module):
    """

		Имплементация Multi-headed Self Attention блока.

    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """

        Args:
            dim: Number of embedding dimensions.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Use bias or not. Default: ``False``
            attn_drop: Dropout rate for attention tensor.
            proj_drop: Dropout rate for projection tensor.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        B, C, H, W = shape
        N = H * W
				# Сжимаем пространственную размерность H x W в одну.
        if len(shape) == 4:
            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # трюк, чтобы сделать вычисление q@k.t более стабильным
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
				# Возвращаем тензор к исходной размерности.
        if len(shape) == 4:
            x = x.transpose(-2, -1).reshape(B, C, H, W)

        return x
Код Attention блока
import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    """
		
		Имлементация MetaFormer блока с MHSA в качестве token mixer.

    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm2d,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        """

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
        """

        super().__init__()

        self.norm = norm_layer(dim)
        self.token_mixer = MHSA(dim=dim)

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm(x)))
            x = x + self.drop_path(self.convffn(x))
        return x

FastViT: архитектура

Мы разобрали все части по отдельности и теперь можем собрать их в единое целое!

Рисунок 13. Схема всей архитектуры FastViT с разделением структуры блоков во время обучения и инференса. Стадии 1, 2 и 3 полностью совпадают

Различные варианты FastViT были получены путем:

  • уменьшения MLP expansion ratio с 4 до 3 — префикс “T”;
  • изменения размеров эмбеддингов — у подсемейства с префиксом “S” они меньше, чем у подсемейства с префиксом “M”;
  • добавления Self-Attention как token mixer’а на последней стадии — префикс “A”.
Рисунок 14. Варианты FastViT. Справа сверху указаны сами варианты. Снизу представлены значения NMP (millions) и Giga-FLOPs у всех вариантов

Image classification task

Рисунок 15. Сравнение версий FastViT с другими
популярными моделями классификации на базе сверток
и трансформеров на ImageNet-1k validation set.
Mobile latency — A14, GPU latency — RTX-2080Ti + TensorRT

Видно, что модели семейства FastViT превосходят своих конкурентов в соотношении latency / accuracy. Также, как и в случае с MobileOne, они выигрывают по скорости за счет архитектурных решений, а не за счет меньших FLOPs и NMP.

Отдельно авторы проводили hard knowledge distillation с моделью-учителем RegNetY-16GF, следуя похожему пути в DeiT, но без использования дополнительной классификационной головы для дистилляции. Это позволило сильнее поднять качество всех моделей семейства FastViT.

Downstream tasks

Как и в случае с MobileOne, были проведены эксперименты по применению FastViT для других задач CV, помимо классификации.

Для начала авторы рассмотрели задачу real-time 3D hand mesh estimation: оценка 3D положения кисти в реальном времени (Apple Vision Pro передает привет из 2024 года 👓👀). Современные работы в этой области вводят сложные сеточные слои регрессии поверх CNN backbone для решения данной задачи. При этом GPU хорошо оптимизированы под извлечение фичей CNN-частью и плохо — под сложные сеточные слои регрессии.

Авторы FastViT предположили: если улучшить backbone для извлечения признаков, тем самым получив хорошие эмбеддинги изображений, то можно упростить структуру регрессионных слоев — MANO. Для объективного сравнения авторы выбирали модели-конкуренты, которые обучались только на FreiHand и не использовали другие датасеты с позами для pre-train, train и fine-tune. ImageNet-1k использовался только как pre-train к FastViT.

Рисунок 16. Схема архитектуры FastViT-MA36 + Mano для задачи 3D hand mesh estimation

Немного информации о FreiHand:

  1. 32 уникальных пар рук;
  2. 130 240 изображений в тренировочной выборке, уникальных — 32 560, а остальные получены путем аугментаций;
  3. 3 960 уникальных изображений в валидационной выборке.
Рисунок 17. Сравнение качества работы различных связок «метод + энкодер» по метрикам hand estimation на датасете FreiHAND test. FPS замерялся на видеокарте NVIDIA RTX-2080Ti

Для задачи семантической сегментации авторы встроили различные backbones, в том числе FastViT в semantic FPN как энкодеры. Полученные модели затем обучали и валидировали на ADE20k:

  • 20k изображений в обучающей выборке;
  • 2k изображений в валидационной выборке;
  • всего 150 классов.
Рисунок 18. Сравнение качества работы Semantic FPN с различными энкодерами по метрике mIoU (%) на датасете ADE20k. FLOPs и latency замерялись на изображениях размером 512 x 512

В задаче детекции поверх энкодеров была прикреплена голова Mask-RCNN. Полученные модели затем обучали и валидировали на MS-COCO:

  • 118k изображений в обучающей выборке;
  • 5k изображений в валидационной выборке;
  • всего 80 классов.
Рисунок 19. Сравнение качества работы Mask-RCNN с различными энкодерами по разным метрикам детекции на датасете MS-COCO val2017. FLOPs и latency замерялись на изображениях размером 512 x 512. В таблице присутствуют опечатки — в 3-ем блоке жирным выделены не всегда самые лучшие результаты

Как видно из рисунков выше, FastViT не уступает по качеству другим трансформерам, при этом имеет существенно более низкие значения CPU/GPU latency.

Интересные идеи из статьи

  1. Как таковой token mixer занимает важное место в структуре MetaFormer. Однако его выбор не ограничивается вариантами механизма внимания.
  2. Как и в MobileOne, операция репараметризации позволяет успешно нивелировать недостатки multi-branch архитектур.
  3. Большие depthwise свёртки, например, с kernel size = 7, позволяют эффективно увеличивать receptive field, тем самым заменяя вычислительно дорогие операции Self-Attention в архитектуре трансформера.
  4. Модели из семейства FastViT могут быть использованы как бэкбоны для решения широкого круга CV задач, особенно там, где важно соотношение latency/качество работы модели.

Ссылки

Как можно воспользоваться в своих проектах:

  • timm — интерфейс через фреймворк timm для всего семейства FastViT + дистиллированные модели.
Рисунок 20. Список вариантов FastViT на сайте huggingface.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!