Назад
482

ViT: Vision transformer

482

ViT — сокращение от Vision Transformer. Это полноценный трансформер, способный работать с изображениями. Авторы ставили себе цель внести как можно меньше изменений в архитектуру. Поэтому в модели практически полностью сохранились особенности из NLP.

Модель состоит из 3 основных блоков:

  • Linear Projection — переводим картинку в вектора для трансформера;
  • Transformer Encoder — энкодер трансформера (только энкодер, декодера нет!);
  • MLP Head — полносвязный слой классификатора.

Давайте рассмотрим все блоки по порядку.

Linear Projection

Трансформер работает с векторами, а у нас картинки размерности W x H x C. Чтобы преобразовать картинку в векторы для трансформера, в ViT используется блок Linear Projection.

Перед Linear Projection нам нужно выполнить такие шаги:

  1. Разрезаем картинку размерности (W x H x С) на патчи размерности (P x P x С);
  2. Патчи размерности (P x P x С) решейпим в вектора размерности (1 x (P*P*С)).

После того как мы получили векторы, передаем их в Linear Projection:

  1. Вектора размерности (1 x (P*P*С)) переводим в вектора размерности (1 x D) с помощью матрицы \( E \) — это обычный линейный слой, только без функции активации;
  2. Складываем вектора размерности (1 x D) с векторами некоторой матрицы \( E_{pos} \)(обсудим ее чуть позже).

На рисунке 1 показано, как именно мы режем картинку на патчи и решейпим в вектор. В итоге мы получим набор векторов \( [x^1_p;x^2_p;…;x^N_p] \), каждый из которых имеет размерность (1 x (P*P*С)).

Рисунок 1. Преобразование входного изображения в векторы патчей

Затем вектора \( [x^1_p;x^2_p;…;x^N_p] \) размерности (1 x (P*P*С)) нужно преобразовать в вектора \( [z^1_0;z^2_0;…;z^N_0] \) размерности (1 x D). На рисунке 2 изображен этот процесс на примере вектора \( x^1_p \):

Рисунок 2. Преобразование вектора патча изображения в его эмбеддинг

Матрицы \( E \) и \( E_{pos} \) являются обучаемыми параметрами модели.

Вы, возможно, заметили, что на рисунке 2 у матрицы \( E_{pos} \) добавилась одна строка. Эта строка соответствует некоторому вектору [cls]-токена. Откуда берется [cls]-токен и что это вообще такое? Авторы ViT старались максимально сохранить архитектуру трансформера. Поэтому этот артефакт из NLP перекочевал в ViT. В модели BERT [cls]-токен используется для классификации текстов. Это служебный дополнительный токен для классификатора, который вставляют в начало предложения. Выходной эмбединг именно этого токена подается в Classification Head для классификации. В процессе обучения модель учится складывать в эмбединг этого токена знания о всем предложении. Ровно то же самое происходит и в ViT (с поправкой на то, что у нас картинка, а не текст). В модели ViT от авторов статьи [cls]-вектор изначально инициализируется нулями и имеет размер (1 x D) и учится также как и вся модель, т.е. это еще один обучаемый вектор-параметр нашей модели.

Осталось разобраться, зачем нам нужна матрица \( E_{pos} \). Дело в том, что в трансформер-энкодере не передается информация о 2D структуре изображения. Таким образом мы теряем inductive bias о том, что у патчей из картинки есть патчи-соседи сверху и снизу, слева и справа. Эти локальные отношения между патчами могут быть очень полезными. Для того, чтобы добавить этот inductive bias, мы каждый эмбеддинг патча и [cls]-вектор будем складывать с определённым вектором из матрицы позиций \( E_{pos} \). Точно так же в NLP кодируются позиции токенов. Это по сути является единственным inductive bias, который мы передаем в модель. В модели ViT от авторов статьи \( E_{pos} \) была инициализирована из нормального распределения со стандартным отклонением 0.02.

Давайте теперь реализуем это в коде. Условимся, что за получение эмбеддингов патчей входного изображения будет отвечать некоторый класс PatchEmbedder. Реализовать этот класс можно двумя способами:

  1. С использованием сверточного слоя;
  2. С использованием операции space-to-depth и линейного слоя (если мы хотим отказаться от сверток в нашем трансформере).
import typing as ty

import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked


patch_typeguard()

BATCH_SIZE = 1
IN_CHANNELS = 3
IMAGE_SIZE = 224

PATCH_SIZE = 16
EMBED_DIM = 768


ImageTensor = TensorType[
    "batch_size", "channels", "height", "width", float
]
UnfoldedImageTensor = TensorType[
    "batch_size", "num_patches", "num_patch_pixels", float
]
EmbeddedPatches = TensorType[
    "batch_size", "num_patches", "embedding_size", float
]


class LinearPatchEmbedder(torch.nn.Module):
    def __init__(
        self,
        patch_size: int,
        in_channels: int,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self._embedder = torch.nn.Linear(
            in_features=in_channels * self.patch_size ** 2,
            out_features=self.embed_dim,
        )

    @typechecked
    def _space_to_depth(self, tensor: ImageTensor) -> UnfoldedImageTensor:
        bs, channels, image_height, image_width = tensor.size()
        num_patches = (image_height * image_width) // (self.patch_size ** 2)
        num_pixels_in_patch = channels * self.patch_size ** 2

        # В einops нотации это выглядит следующим образом:
		    # (bs c (h p1) (w p2)) -> (bs (h w) (c p1 p2)).
        tensor = tensor.view(
            bs,
            channels,
            image_height//self.patch_size,
            self.patch_size,
            image_width//self.patch_size,
            self.patch_size,
        )
        tensor = tensor.permute(0, 2, 4, 1, 3, 5).contiguous()
        return tensor.view(bs, num_patches, num_pixels_in_patch)

    @typechecked
    def forward(self, tensor: ImageTensor) -> EmbeddedPatches:
        unfolded_tensor = self._space_to_depth(tensor)
        return self._embedder(unfolded_tensor)


class Conv2dPatchEmbedder(torch.nn.Module):
    def __init__(
        self,
        patch_size: int,
        in_channels: int,
        embed_dim: int = 64,
	) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self._embedder = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
    
    @typechecked
    def forward(self, tensor: ImageTensor) -> EmbeddedPatches:
        bs = tensor.size(0)
        return self._embedder(tensor).view(bs, self.embed_dim, -1).permute(
					(0, 2, 1),
				)


image_batch = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
assert tuple(image_batch.shape) == (1, 3, 224, 224)

linear_pe = LinearPatchEmbedder(
    patch_size=PATCH_SIZE, in_channels=IN_CHANNELS, embed_dim=EMBED_DIM,
)
assert (
    tuple(linear_pe(image_batch).shape)
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2), EMBED_DIM)
)
conv_pe = Conv2dPatchEmbedder(
    patch_size=PATCH_SIZE, in_channels=IN_CHANNELS, embed_dim=EMBED_DIM,
)
assert (
    tuple(conv_pe(image_batch).shape)
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2), EMBED_DIM)
)

❗️Нюанс❗️: в оригинальной версии авторов класс реализован с помощью сверточного слоя.

Отлично! Осталось добавить матрицу позиционных эмбеддингов и [cls]-вектор. Давайте реализуем полный класс LinearProjection.

class LinearProjection(torch.nn.Module):
    def __init__(
        self,
        image_size: int,
        patch_size: int,
        in_channels: int,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.patch_embedder = Conv2dPatchEmbedder(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        num_patches = (image_size ** 2) // (patch_size ** 2)
        self.pos_embeddings = torch.nn.Parameter(
            torch.normal(
                mean=0,
                std=0.02,

                # Не забываем +1, так как у нас есть еще
                # [cls]-токен!
                size=(1, num_patches + 1, embed_dim),
            ),
        )
        self.cls_token = torch.nn.Parameter(
            torch.zeros(1, 1, embed_dim),
        )

    @typechecked
    def forward(self, tensor: ImageTensor) -> EmbeddedPatches:
        bs = tensor.size(0)
        cls_tokens = self.cls_token.repeat((bs, 1, 1))
        embedded_patches = self.patch_embedder(tensor)
        embedded_patches = torch.cat((cls_tokens, embedded_patches), dim=1)
        return embedded_patches + self.pos_embeddings


lp = LinearProjection(IMAGE_SIZE, PATCH_SIZE, IN_CHANNELS, EMBED_DIM)
assert (
    tuple(lp(image_batch).shape) 
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2) + 1, EMBED_DIM)
)

Transformer Encoder

Рассмотрим основную часть модели ViT — энкодер трансформера. Он состоит из L трансформер-блоков. Каждый из блоков получает на вход и возвращает на выход эмбеддинги патчей размерности (N + 1) x D, где N — это количество патчей, а +1, потому что не забываем про [cls]-вектор, то есть число эмбеддингов и их размерность не меняется при прохождении через все энкодер-блоки сети.

Рисунок 3. Схема энкодер-блока трансформера

На рисунке 3 видно, что в трансформер-блоке используются:

  • Layer Normalisation;
  • Multi-Head Self-Attention;
  • Multi-Layer Perception.

Рассмотрим каждый слой по-отдельности.

❗️Нюанс❗️: помните, авторы ViT ставили себе цель внести как можно меньше изменений в стандартный трансформер? Если вы читали статью Attention is all you need, которая описывает стандартный трансформер, и на которую ссылаются авторы ViT, вы можете заметить, что здесь все слои перепутаны! Вот картинка для сравнения:

Рисунок 4. Сравнение энкодер-блоков трансформера

Мы задали вопрос в репозитории авторов ViT, в чем дело. Они ответили, что pre-LN — подход (слева) работает лучше, чем post-LN (справа). И фактически pre-LN стал стандартом.

Layer Normalisation

LayerNorm использовался в RNN-моделях для решения NLP-задач. Он хорошо подходит для рекуррентных архитектур, уменьшает время обучения и увеличивает качество предсказаний RNN. Дело в том, что в NLP, как правило, длина последовательности изменчива. Из-за этого сложно применять BatchNorm, так как одна из осей все время изменяется.

В классическом трансформер-блоке используется LayerNorm, а не BatchNorm. А так как авторы ViT хотели внести как можно меньше изменений в трансформер энкодер, они оставили LayerNorm.

LayerNorm очень похож на BatchNorm. Например, на рисунке 5 изображена схема работы BatchNorm для feature-map размерности (N+1) x D x BatchSize:

Рисунок 5. Схема работы BatchNorm

А на рисунке 6 — LayerNorm для той же feature-map:

Рисунок 6. Схема работы LayerNorm

В виде формулы эти слои можно было записать так:

\( \frac{x-E(x)}{\sqrt(Var(x)+\varepsilon)}*\gamma+\beta=y \)

Как видно, формула одна и та же, но нормализация происходит по разным осям. Более подробно об этом можно прочитать здесь.

Multi-Head Self-Attention

Self-Attention мы уже описали в этой статье. Здесь он точно такой же, но с одним отличием: перед применением Softmax мы делим произведение \( QK^T \) на \( \sqrt(M) \) — размерность эмбеддингов внутри Attention. Теперь рассмотрим, что будет, если использовать несколько Self-Attention слоев.

Зачем нам понадобилось несколько Attention-слоев параллельно? Во-первых, мы помогаем модели обращать внимание на несколько элементов в рамках одного слоя. Во-вторых, мы даем Attention-слою возможность проецировать эмбеддинги в разные скрытые пространства. Таким образом, больше голов дает модели больше вариативности. Можно сравнить этот механизм с увеличением количества фильтров в сверточном слое. При большем количестве фильтров мы извлечем больше разных паттернов.

На рисунке 7 показано, как преобразуются тензоры в слое Multi-Head Self-Attention.

Рисунок 7. Схема работы Multi-Head Self-Attention

В виде формулы Multi-Head Self-Attention можно записать так:

\( \begin{cases}\text{SelfAttention}_0(z_l)=r^0_l \\ \text{SelfAttention}_1(z_l)=r^1_l \\ … \\ \text{SelfAttention}_{Heads}(z_l)=r^{Heads}_l \end{cases} \\ r_l = \text{Concat}(r^0_l;r^1_l;…;r^{Heads}_l) \\ \text{MultiHeadSelfAttention}(z_l)=r_lW^{output}_l=z’_l \)

Давайте теперь реализуем Multi-head self-attention в коде. Для простоты упражнения будем реализовывать данный блок в виде композиции двух блоков:

  1. ScaledDotProductSelfAttention — будет отвечать за подсчет self-attention по входной последовательности эмбеддингов;
  2. MultiheadSelfAttention — будет отвечать за реализацию N голов self-attention’a.

❗️Нюанс❗️: авторы оригинальной статьи по ViT щедро приправляли self-attention Dropout слоями, и мы тоже последуем их примеру:

WeightedEmbeddings = TensorType[
    "batch_size", "num_patches", "qkv_dim", float,
]

N_HEADS = 12
QKV_DIM = 64


class ScaledDotProductAttention(torch.nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        qkv_dim: int = 64,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()

        # Это как раз наши обучаемые матрицы W_q, W_k, W_v.
        self.wq = torch.nn.Linear(
            in_features=embed_dim,
            out_features=qkv_dim,
            bias=False,
        )
        self.wk = torch.nn.Linear(
            in_features=embed_dim,
            out_features=qkv_dim,
            bias=False,
        )
        self.wv = torch.nn.Linear(
            in_features=embed_dim,
            out_features=qkv_dim,
            bias=False,
        )

        self.scale = qkv_dim ** -0.5
        self.dropout = torch.nn.Dropout(dropout_rate)

    @typechecked
    def forward(self, tensor: EmbeddedPatches) -> WeightedEmbeddings:
        q = self.wq(tensor)
        k = self.wk(tensor)
        v = self.wv(tensor)
        return self.dropout(((q @ k.mT) * self.scale).softmax(dim=-1)) @ v


class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(
        self,
        n_heads: int = 12,
        embed_dim: int = 768,
        qkv_dim: int = 64,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()

				# Несколько голов self-attention'a.
        self.sdpas = torch.nn.ModuleList(
            [ScaledDotProductAttention(embed_dim, qkv_dim) for _ in range(n_heads)]
        )

				# Вернем эмбеддинги патчей назад в их исходную размерность.
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=n_heads * qkv_dim,
                out_features=embed_dim,
            ),
            torch.nn.Dropout(dropout_rate),
        )

    @typechecked
    def forward(self, tensor: EmbeddedPatches) -> EmbeddedPatches:
        qkvs = torch.cat([sdpa(tensor) for sdpa in self.sdpas], dim=-1)
        return self.projection(qkvs)


mhsa = MultiHeadSelfAttention(
    n_heads=N_HEADS,
    embed_dim=EMBED_DIM,
    qkv_dim=QKV_DIM,
    dropout_rate=0.1,
)
assert (
    tuple(mhsa(lp(image_batch)).shape) 
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2) + 1, EMBED_DIM)
)

❗️Нюанс❗️: заметим, что ScaledDotProductAttention реализован с помощью аж 3-х линейных слоев. Можно ли реализовать его с помощью одного слоя, сохраняя при этом каждую из соответствующих матриц \( W^{q}, W^{k}, W^{v} \)? Да, можно! Можно даже оптимизировать подсчет Multi-Head Self-Attention, сильно снизив количество матричных умножений. Посмотреть можно тут.

Multi-Layer Perceptron

MLP содержит два полносвязных слоя с функцией активации GELU. И это, собственно, все в этом блоке :). Код:

MLP_HIDDEN_SIZE = 3072


class MLP(torch.nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        mlp_hidden_size: int = 3072,
        dropout_rate: float = 0.1,
    ) -> None:
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=embed_dim,
                out_features=mlp_hidden_size,
            ),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(
                in_features=mlp_hidden_size,
                out_features=embed_dim,
            ),
            torch.nn.Dropout(dropout_rate),
        )
    
    @typechecked
    def forward(self, tensor: EmbeddedPatches) -> EmbeddedPatches:
        return self.mlp(tensor)


mlp = MLP(EMBED_DIM, MLP_HIDDEN_SIZE)
assert (
    tuple(mlp(mhsa(lp(image_batch))).shape)
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2) + 1, EMBED_DIM)
)

Отлично! У нас есть теперь все составляющие блока энкодера. Давайте его тоже реализуем.

N_HEADS = 12


class EncoderBlock(torch.nn.Module):
    def __init__(
        self,
        n_heads: int = 12,
        qkv_dim: int = 64,
        embed_dim: int = 768,
        mlp_hidden_size: int = 3072,
        attention_dropout_rate: float = 0.1,
        mlp_dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.ln1 = torch.nn.LayerNorm(normalized_shape=embed_dim)
        self.mhsa = MultiHeadSelfAttention(
            n_heads, embed_dim, qkv_dim, attention_dropout_rate
        )
        self.ln2 = torch.nn.LayerNorm(normalized_shape=embed_dim)
        self.mlp = MLP(embed_dim, mlp_hidden_size, mlp_dropout_rate)

    @typechecked
    def forward(self, tensor: EmbeddedPatches) -> EmbeddedPatches:
        tensor += self.mhsa(self.ln1(tensor))
        tensor += self.mlp(self.ln2(tensor))
        return tensor


encoder_block = EncoderBlock(N_HEADS, QKV_DIM, EMBED_DIM, MLP_HIDDEN_SIZE)
assert (
    tuple(encoder_block(lp(image_batch)).shape)
    == (BATCH_SIZE, (IMAGE_SIZE ** 2) // (PATCH_SIZE ** 2) + 1, EMBED_DIM)
)

Осталось теперь аккуратно все собрать в один большой трансформер. Давайте же сделаем это:

N_LAYERS = 12
N_CLASSES = 1_000

Logits = TensorType["batch_size", "n_classes", float]


class ViT(torch.nn.Module):
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        qkv_dim: int = 64,
        mlp_hidden_size: int = 3072,
        n_layers: int = 12,
        n_heads: int = 12,
        n_classes: int = 1_000,
    ):
        super().__init__()
        self.encoder = torch.nn.Sequential(

						# Превращаем исходную картинку в последовательность
						# эмбеддингов патчей.
            LinearProjection(image_size, patch_size, in_channels, embed_dim),

						# Это наш трансформер.
            *[EncoderBlock(n_heads, qkv_dim, embed_dim, mlp_hidden_size) for _ in range(n_layers)]
        )

				# Классификационная голова.
        self.classifier = torch.nn.Sequential(
            torch.nn.LayerNorm(embed_dim),
            torch.nn.Linear(embed_dim, n_classes)
        )

    @typechecked
    def forward(self, tensor: ImageTensor) -> Logits:
        features = self.encoder(tensor)

				# Помним, что классфикация проводится только по
				# эмбеддингу [cls]-токена.
        return self.classifier(features[:, 0, :])


vit = ViT(
    IMAGE_SIZE,
    PATCH_SIZE,
    IN_CHANNELS,
    EMBED_DIM,
    QKV_DIM,
    MLP_HIDDEN_SIZE,
    N_LAYERS,
    N_HEADS,
    N_CLASSES,
)
assert tuple(vit(image_batch).shape) == (BATCH_SIZE, N_CLASSES)

Ура, мы собрали полноценный трансформер! Схематически наш ViT изображен на рисунке ниже.

Рисунок 8. Схема архитектуры ViT

Если создать модель с параметрами, хранящимися в наших константах EMBED_DIM, PATCH_SIZE, IMAGE_SIZE, ..., то получим модель ViT-Base из исходной статьи. Если мы хотим получить модели ViT-Large или ViT-Huge, то достаточно просто поменять аргументы на соответствующие следующей таблице:

Конечно же, в реальных проектах лучше взять предобученную модель из timm:

import timm

# Выдаст список поддерживаемых ViT-based
# архитектур в timm.
timm.list_models("*vit*")

# Создадим базовую ViT модель.
awesome_vit = timm.create_model("vit_base_patch16_224")
awesome_vit

Как обучали

Авторы ViT обнаружили, что сильная регуляризация является ключевым фактором при предобучении моделей на ImageNet. Dropout применяется после каждого полносвязного слоя и после сложения с векторами из матрицы \( E_{pos} \). Для Multi-Head Self-Attention dropout не используется. Гибридные модели (когда вместо Linear Projection используется ResNet) обучаются с теми же гиперпараметрами, что и их ViT-аналоги. Наконец, все обучение проводится с разрешением 224 х 224.

Они обучали все модели, включая ResNet, используя оптимизатор Adam с β1 = 0.9, β2 = 0.999, размер батча 4096. Применяется L2 регуляризация weight decay = 0.1. Используется линейный learning rate warmup и decay. Для файнтюна берут оптимизатор SGD with momentum, размер батча 512.

Вообще в статье про ViT достаточно подробно описан сетап предобучения и файнтюна. Поэтому если вам нужны детали, рекомендуем вам прочесть статью и приложения к ней в конце.

Замечание о количестве патчей N. Обычно ViT предобучают на больших датасетах, а затем делают файнтюн на сравнительно небольших наборах данных. Часто бывает полезно файнтюнить на большем разрешении по сравнению с разрешением при предобучении.

При подаче картинок с более высоким разрешением мы сохраняем размер патчей P одинаковым, что приводит к увеличению “длины последовательности” или количества патчей. Трансформер может работать с произвольной длиной последовательности, однако предобученная матрица \( E_{pos} \) потеряет смысл. Поэтому мы выполняем 2D-интерполяцию строк предобученной матрицы \( E_{pos} \) для того, чтобы привести ее в соответствие с новой длиной последовательности, и потом дотюниваем вместе с остальной моделью.

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!