Назад
332

DETR

332

Пререквизиты

Что можно прочитать, чтобы лучше понять эту статью:

  1. Мы будем часто упоминать архитектуру трансформера в контексте CV. Есть пост с детальным разбором ViT: там раскрываются основные моменты работы трансформеров в целом, и в CV в частности.
  2. Дополнительно к пункту 1 — разбор работы Attention.
  3. Для понимания некоторых деталей в работе CNN детекторов, а именно — анкоров, рекомендуем познакомиться с постом про YOLOv2.

Введение

Задача детекции в Computer Vision — одна из самых распространенных, при этом непростых на практике задач: нужно одновременно и искать объекты, и определять их классы.

За относительно небольшое время применения нейросетевых подходов в детекции появилось много подходов и фишек для эффективного решения этой задачи. Так, детекторы часто подразделяются на one-stage и two-stage виды, которые, в свою очередь, делятся на anchor-based (используют анкоры) и anchor-free (не используют анкоры).

Анкоры — предварительно выбранные вручную боксы.

Также есть огромное количество методов для улучшения качества детекции путем “умного” взвешивания в лоссе, улучшения точности пулинга фичей и так далее.

Авторы DETR решили полностью пересмотреть подход к построению архитектуры детектора. При этом они руководствовались следующими идеями:

  1. Упрощение пайплайна детекции за счет удаления вспомогательных техник, таких как NMS или анкоров;
  2. Использование Encoder-Decoder трансформера поверх CNN фичей для учета глобального контекста и формирования на его основе финального предсказания боксов.

“Классические” CNN подходы в детекции

Этот небольшой топик не претендует на охват всех подходов и деталей из обширной области Object Detection. Но мы постараемся поговорить о том, как в целом работает детекция на CNN.

Если говорить тезисно, большинство CNN детекторов делится на два больших класса: two-stage и one-stage.

  • two-stage детекция состоит из двух стадий (как видно из названия):
    • генерация так называемых region proposals (области интересов) как обычными алгоритмами, так и нейросетями (в 2024 встречается только второй способ; у этой подсети в детекторе есть свое название “Region Proposal Network”, или RPN). Основная задача здесь — сужение количества потенциальных областей для поиска объектов.
    • классификация для каждого region proposal и регрессия координат боксов.
  • one-stage детекция из одной стадии:
    • классификация и локализация объектов происходит напрямую, без предварительно сгенерированных region proposals.
Рисунок 1. Сравнение two-stage детектора Faster R-CNN и one-stage детектора RetinaNet (источник)

Как правило, детектор предсказывает детекции с дубликатами. Чтобы эффективно их убирать, используется специальный алгоритм — NMS.

Алгоритм NMS

Изначально имеем набор боксов на изображении. Далее:

  1. Сортируем bbox-ы, опираясь на их confidence scores.
  2. Выбираем bbox c максимальным confidence score, сохраняем его как итоговое предсказание детектора и удаляем из первоначального набора. Таких боксов будет один или несколько (об этом расскажем чуть позже).
  3. Удаляем из набора все боксы, которые имеют пересечение с боксом / боксами итогового предсказания выше заданного порога. Пересечение считается, например, как Intersection over Union.
  4. Повторяем шаги 2 и 3, пока в исходном наборе не останется боксов.
Рисунок X1. Использование NMS на практике. Красный бокс имеет самый большой confidence score — 0.92, поэтому станет итоговым предсказанием. Синие же боксы будут отфильтрованы алгоритмом, так как имеют значительное пересечение с красным
from typing import List


def non_max_suppression(
		boxes: List[List[int]],
		scores: List[float],
		threshold: float
):
    """
    Алгоритм NMS для отбора детекций из набора боксов с соответсвующими им
    confidence scores.
   
    :param boxes: набор bbox-ов формата [xmin, ymin, xmax, ymax]
    :param scores: список confidence scores
    :param threshold: IoU (intersection-over-union) threshold
    :return: список индексов боксов из исходного набора, которые останутся
    после NMS.
    """
    # 1. Сортируем bbox-ы, опираясь на их confidence scores
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
		    # 2. Выбираем bbox c максимальным confidence score, сохраняем его как
		    # итоговое предсказание детектора и удаляем из первоначального набора.
        i = order.pop(0)
        keep.append(i)
        for j in order:
            # 3. Считаем IoU между итоговым предсказанием и остальными боксами.
            intersection = max(0, min(boxes[i][2], boxes[j][2]) - max(boxes[i][0], boxes[j][0])) * \
								           max(0, min(boxes[i][3], boxes[j][3]) - max(boxes[i][1], boxes[j][1]))
            union = (boxes[i][2] - boxes[i][0]) * (boxes[i][3] - boxes[i][1]) + \
                    (boxes[j][2] - boxes[j][0]) * (boxes[j][3] - boxes[j][1]) - intersection
            iou = intersection / union

            # Удаляем все боксы из набора, которые имеют пересечение
            # выше заданного порога.
            if iou > threshold:
                order.remove(j)
    return keep

Предсказывать боксы “с нуля” — довольно сложная задача. Для ее упрощения были придуманы анкоры — специальные предварительно выбранные вручную боксы. Следовательно, наша задача сводится к тому, чтобы просто “подправить” положение и размеры анкора, “поймав” таким образом нужный нам объект. В этой статье можно подробнее прочитать про выбор анкоров и детекцию, основанную на них.

Стоит отметить: у обозначенных выше техник есть недостатки. NMS требует дополнительных затрат на post-processing, а еще случаются ошибки. Анкоры, в свою очередь, имеют явный prior knowledge — предварительное знание о том, какие боксы встретятся в обучении и тестировании. Это просто перестает работать, когда у боксов появляется отличная от анкоров форма.

Именно эти недостатки попытались обойти авторы DETR.

Архитектура DETR

DETR состоит из 3-х частей:

  • CNN-бэкбон для формирования feature map;
  • Encoder-Decoder трансформер;
  • Feed-Forward Network (FFN) для формирования финального предсказания детекции.

Давайте разберем их все в деталях 🙂

Рисунок 2. Архитектура DETR. В качестве CNN бэкбона — ResNet-50

CNN-бэкбон

CNN-бэкбон (авторы использовали модели семейства ResNet, а в качестве baseline backbone — ResNet-50), как уже было сказано выше, формирует карту признаков исходного изображения. На входе принимается изображение размером \( x \in \mathbb{R}^{3 * H_{0} * W_{0}} \), а на выходе — тензор размером \( 2048 * H * W; H = \frac{H_{0}}{32}, W = \frac{W_{0}}{32} \).

Transformer part 1: Encoder

Далее мы уменьшаем канальную размерность в feature map с помощью 1D свертки (\( 2048 → d, d < 2048 \)) и разворачиваем ее в 1D размерность, сворачивая пространственную размерность \( H * W \). Таким образом, мы получаем вектор размерности \( d×HW \). Конвертация в 1D вектор нужна потому, что энкодер принимает на вход последовательность. Поскольку MHSA (MultiHead Self-Attention) энкодера инвариантен к перестановкам во входной последовательности — необходимо добавить фиксированное позиционное кодирование (spatial positional encoding), которое поможет сети учитывать порядок фич в карте признаков.

Важный момент: как и в ViT, энкодер представляет собой набор \( N \) последовательных блоков, состоящих из Multi-Head Self-Attention и FFN, где размерность входной последовательности совпадает с размером выходной.

Еще один важный момент: в self-attention \( Query == Key \), а позиционное кодирование добавляется только к ним, ведь мы хотим дать информацию от других объектов только content-части. Также, в отличие от классического трансформера и ViT, мы добавляем позиционное кодирование в каждый слой. Это же верно и для декодера!

Код EncoderLayer
import torch
import torch.nn as nn


class TransformerEncoderLayer(nn.Module):
		"""
		Имплементация слоя энкодера трансформера.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Слои для FFN
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
				
				# Функция для выбора активации из списка: [ReLU, GELU, GLU]
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        # Forward для случая, когда нормализация идёт после MHSA и MLP.
        # Добавляем позиционное кодирование только к Query и Key.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        # Forward для случая, когда нормализация идёт до MHSA и MLP.
        src2 = self.norm1(src)
        # Добавляем позиционное кодирование только к Query и Key.
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
Код Encoder
import torch
import torch.nn as nn


class TransformerEncoder(nn.Module):
		"""
		Имплементация энкодера трансформера.
    """
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src
				
				# Стакаем N раз энкодер слои.
        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output

Transformer part 2: Decoder

Decoder в DETR устроен так же, как и в оригинальной работе 2017 года. Однако есть небольшое отличие: DETR использует параллельный декодинг вместо авторегрессионного.

  • авторегрессионный декодинг: целевое предложение генерируется последовательно токен за токеном, отправляя частичный результат в качестве входных данных для следующей итерации авторегрессии, вплоть до длины \( m \) целевого предложения.
  • параллельный декодинг: этот метод изменяет только алгоритм декодирования и может использоваться поверх любой модели авторегрессии без изменений. Алгоритмы параллельного декодирования обрабатывают все предложение или блок из \( b \) токенов параллельно: исходные токены (PAD токены) постепенно уточняются с помощью \( k \) шагов, пока не будет достигнуто условие остановки. Важно отметить: \( k \leq m \) гарантирует качество и общее ускорение декодирования.
Рисунок 3. Сравнение авторегрессионного (слева) и параллельного (справа) декодинга на примере задачи машинного перевода. Оранжевый блок — алгоритм декодирования

Важный момент: зачем в DETR нужен декодер? Дело в том, что мы предсказываем последовательность, которая, вообще говоря, состоит из множества токенов. Обратная ситуация у нас в ViT (там только энкодер), где нам просто нужно предсказать один токен — класс изображения.

В декодере также применяется позиционное кодирование, причем целых два:

  1. Spatial positional encoding — такое же кодирование, как и в энкодере. Как мы уже говорили выше, авторы хотели полностью отказаться от prior knowledge, который несет в себе классические анкоры. По сути их роль на себя и взяло позиционное кодирование. При этом у него нет явного геометрического смысла, и он полностью учится с нуля на данных в обучении. Таким образом, мы дали дополнительную информацию детектору о взаимном расположении токенов и не внесли prior knowledge, поскольку positional encoding — обучаемый параметр.
Рисунок 4. Усредненные центры предсказанных DETR-ом боксов на COCO val датасете (источник: доклад авторов DETR)
  1. Object queries постепенно формируются в ходе декодинга и отвечают за сбор визуальной информации о текущем объекте интереса с помощью скрытого состояния энкодера. Перед первым слоем инициализируются как нулевые векторы.

Важный момент: максимальное количество детекций равно числу object queries.

Стоит отметить: в каждом слое декодера используются два вида механизма внимания:

  1. Self-attention служит обмену информацией между object queries. Как и в энкодере, для V они не добавляются.
  2. Cross-attention. В этой части object queries смотрят на результат работы энкодера и поглощают визуальную информацию. В качестве \( Q \) в данном случае выступает сумма object queries с самой собой после MHSA, а вот \( K \) и \( V \) здесь другие — это выход энкодера с positional embedding и без него соответственно. Таким образом, каждый object query производит некий SoftPooling релевантных визуальных фичей из тех или иных частей изображения. В какой-то степени этот модуль заменяет традиционный RoIPooling, только объекты могут считывать информацию со всего изображения, а не только из ограниченной области.
Код DecoderLayer
import torch
import torch.nn as nn


class TransformerDecoderLayer(nn.Module):
		"""
		Имплементация слоя декодера трансформера.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Слои для FFN
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
				
				# Функция для выбора активации из списка: [ReLU, GELU, GLU]
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # Forward для случая, когда нормализация идёт после MHSA и MLP.
        # Добавляем позиционное кодирование только к Query и Key.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Обратите внимание на то, что подаётся как Q, K, V в attention.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        # Forward для случая, когда нормализация идёт до MHSA и MLP.
        tgt2 = self.norm1(tgt)
        # Добавляем позиционное кодирование только к Query и Key.
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        # Обратите внимание на то, что подаётся как Q, K, V в attention.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
Код Decoder
import torch
import torch.nn as nn


class TransformerDecoder(nn.Module):
		"""
		Имплементация декодера трансформера.
    """
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []
				
				# Стакаем M раз декодер слои и записываем их в список intermediate.
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

Трансформер целиком

Рисунок 5. Архитектура Encoder-Decoder Transformer в DETR
Код всего Transformer
import torch
import torch.nn as nn


class Transformer(nn.Module):
		"""
		Имплементация трансформера.
    """
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()
				
				# Инициализируем энкодер.
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
				
				# Инициализируем декодер.
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # Разворачиваем входной тензор размера NxCxHxW в тензор размера HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

Затем выходные эмбеддинги декодера поступают на вход FFN для предсказания итоговых значений боксов и классов.

FFN

FFN состоит из двух независимых частей:

  1. MLP (MultiLayer Perceptron) блок — последовательность из 3-х линейных слоев с функцией активации ReLU. Предсказывает нормализованные значения центра бокса, его высоты и ширины.
  2. Линейный слой, предсказывающий класс бокса с помощью функции softmax.

Поскольку в конце мы предсказываем \( N \) (фиксированное число; обычно гораздо больше, чем количество искомых объектов на изображении) bbox-ов, нужно добавить специальный класс “no object” — ∅. Он играет такую же роль, что и класс “background” в обычных CNN детекторах.

Код MLP
import torch
import torch.nn as nn


class MLP(nn.Module):
    """
		Имплементация MLP.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        # output dim = 4, так как bbox = [x, y, w, h]
        # num_layers = 3
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # если input_dim = hidden_dim = 512, то размерности 3-х Linear слоёв - 
	      # (512, 512), (512, 512), (512, 4)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

Собираем DETR

Рисунок 6. Сравнение качества работы DETR (и различных энкодеров) c моделями Faster RCNN в задаче Object Detection по разным метрикам детекции на датасете COCO val2017. DETR показывает сопоставимые результаты с различными версиями Faster RCNN в задаче Object Detection
Код DETR
import torch
import torch.nn as nn


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
	  """
	  Вспомогательная функция.
	  """
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # оптимизация для ONNX
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


class DETR(nn.Module):
    """
		Имплементация DETR.
    """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """
        Инициализация модели.
        Parameters:
            backbone: CNN бэкбон.
            transformer: Encoder-Decoder transformer.
            num_queries: количество object queries. это максимальное число
	            детекций DETR на одно изображение. Для COCO авторы
	            рекомендуют брать число 100.
            
            aux_loss: True если используется вспомогательный лосс для декодинга.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # num_classes + 1, потому что мы не забываем про no_objects
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """
        На вход forward принимает вложенный тензор, который состоит из:
        The forward expects a NestedTensor, which consists of:
             - samples.tensor: батч картинок размера [batch_size x 3 x H x W].
             - samples.mask: бинарная маска размера [batch_size x H x W],
             содержащая 1 на padded пикселях.
						
			Forward возвращает dict со следующими элементами:
             - "pred_logits": классификационные логиты, включая no-object,
             для всех queries.
              Размер = [batch_size x num_queries x (num_classes + 1)]
              
             - "pred_boxes": нормализованные координаты боксов 
             значения от 0 до 1) для всех queries размера
              (center_x, center_y, height, width).
              
             - "aux_outputs": если aux_loss == True, возвращает их значения.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
				
				# классификационная голова
        outputs_class = self.class_embed(hs)
        # регрессионная голова
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

Разбор лосса

Мы смогли прогнать наш обучающий тензор через DETR и получить набор боксов и соответствующих классов. Осталось только как-то соотнести их с ground truth боксами и метками. На первый взгляд может показаться, что это несложно. К сожалению, не все так просто: порядок предсказаний не совпадает с порядком ground truth.

Как же их тогда можно сматчить? Использовать IoU для поиска ближайших боксов? Но такой матчинг точно не будет всегда оптимальным. И здесь к нам на помощь приходит комбинаторная оптимизация — она обеспечит нахождение лучшего one-to-one матчинга, который, в свою очередь, даст минимально возможный суммарный лосс. Для этого используется “Венгерский алгоритм” (Hungarian algorithm), работающий за полиномиальное \( O(n^{3}) \) время. Ниже о нем будет рассказано подробнее. Также стоит отметить, что в функции linear_sum_assignment из scipy используется не сам венгерский алгоритм, а его более быстрая модификация.

Вернемся к лоссу. Он, как и в обычных детекторах, складывается из суммы лоссов классификации и локализации. В данном случае используются кросс-энтропия, L1 и Generalized IoU. Количество предсказаний (равное количеству object queries) почти всегда будет больше, чем количество реальных ground truth объектов на картинке, поэтому “лишние” предсказания отправляются в класс “no object” и не передаются в итоговой набор боксов.

Венгерский алгоритм (Hungarian algorithm)

Любителям алгоритмов предлагаю немного отвлечься от трансформеров и более детально познакомиться с Венгерским алгоритмом 🙂

Начнем с общей формулировки “задачи о назначениях”, которую он решает:

Имеется некоторое число работ и некоторое число исполнителей. Любой исполнитель может быть назначен на выполнение любой (но только одной) работы, но с неодинаковыми затратами. Нужно распределить работы так, чтобы выполнить работы с минимальными затратами.

Такая постановка задачи нередко встречается в CV. Например, сопоставление боксов объектов на разных кадрах в Object Tracking есть не что иное, как задача о назначениях.

Пример “Назначение сотрудников на работы”
Рисунок Y1. Исходная матрица стоимости

Предположим, есть следующая задача:

У нас есть 3 сотрудника, которые должны доехать до 3-х разных клиентов. Мы знаем стоимость поездки каждого сотрудника до каждого клиента в вечно зеленых [долларах]. Как оптимально (с наименьшими затратами) распределить сотрудников по клиентам или, что то же самое, распределить поездки (jobs) по сотрудникам?

Важный момент: все последующие операции Венгерского алгоритма не меняют исходное назначение!

В самом начале нам нужно посчитать матрицу стоимости. Именно с ней и будет работать алгоритм.

После того, как матрица найдена, Венгерский алгоритм предлагает нахождение решения за 5 шагов:

  • Вычитаем самое маленькое значение в строке из всех ее элементов для каждой строки (самое маленькое значение в каждой строке при этом станет равно 0).
Рисунок Y2. Демонстрация шага 1
  • Вычитаем самое маленькое значение в столбце из всех его элементов для каждого столбца (самое маленькое значение в каждом столбце при этом станет равно 0).
Рисунок Y3. Демонстрация шага 2
  • Зачеркиваем все нули минимальным количеством линий. Если количество линий получилось меньше n, где n — количество сотрудников / работ, то переходим к шагу 4. В противном случае — переходим к шагу 5.
Рисунок Y4. Демонстрация шага 3
  • Находим наименьший элемент, не охваченный ни одной линией, и вычитаем его из всей матрицы. Если элемент был охвачен какой-либо линией дважды — добавляем в то место, где он дважды перечеркнут. Затем возвращаемся к шагу 3.
Рисунок Y5. Демонстрация шага 4
  • Назначаем работы сотрудникам, начиная со строки только с одним нулем. Каждый раз, когда мы сопоставляем одну работу с сотрудником, мы пересекаем его строку и столбец, чтобы сделать его недоступным.
Рисунок Y6. Финальное состояние матрицы
Рисунок Y7. Итоговое назначение работ по сотрудникам

Таким образом, в нашей задаче минимальное количество долларов, которое нам потребуется — это (всего лишь 😉) 26.

Если же вернуться к проблеме матчинга для лосса в DETR, то нам нужно придумать, как посчитать матрицу стоимости. Все очень просто: она состоит из взвешенной суммы (по дефолту все веса = 1) трех перечисленных выше лоссов.

Downstream tasks

Для задачи panoptic (мы еще хотим разделять объекты одного и того же класса) сегментации авторы добавили сегментационную голову к DETR, так же, как Faster R-CNN был расширен до Mask R-CNN. Давайте разберем это более подробно.

  • Для начала получаем боксы, как и раньше;
  • Далее мы получаем attention map из MHSA;
  • Мы хотим маски, поэтому накидываем сегментационную голову, которая будет предсказывать бинарные маски. При этом мы объединяем feature map (с бэкбона) c масками боксов, полученных на предыдущем этапе. Размеры итогового тензора будут зависеть от количества боксов;
  • Для определения финального предсказания мы формируем FPN-like архитектуру из полученных выше feature map.

Важный момент: Feature Pyramid Net (FPN), или пирамида признаков — свёрточная нейронная сеть, построенная в виде пирамиды и служащая для объединения достоинств карт признаков нижних и верхних уровней сети; первые имеют высокое разрешение, но низкую семантическую, обобщающую способность; вторые наоборот.

Рисунок 7. Строение FPN. По сути это UNet-like сеть, которая имеет выходы с каждого слоя на стадии top-down (часть справа сверху, где стрелочки идут сверху-вниз). Свёртки 1×1 нужны для изменения канальной размерности, а 2x up — upsample свёртки для увеличения пространственной размерности (чтобы в итоге мы складывали объекты одной и той же размерности)

Итоговая модель может быть получена двумя путями: обучаем детекционную часть и сегментационную голову вместе или обучаем их последовательно (сначала обучаем детектор, замораживаем его веса и обучаем только голову). На практике два этих подхода показали похожие результаты, поэтому авторы выбрали второй способ как менее затратный по времени.

Рисунок 8. Строение сегментационной головы в DETRsegm. Бинарные маски генерируются параллельно для каждого bounding bbox

Для обучения был выбран расширенный вариант датасета COCO:

  • 83K изображений в обучающей выборке;
  • 41K изображений в валидационной выборке;
  • 80 things классов;
  • 53 дополнительных stuff классов.
Рисунок 9. Сравнение качества работы DETR (и различных энкодеров) c SOTA моделями в задаче Panoptic Segmentation по разным метрикам сегментации на датасете COCO val2017. Segmentation DETR показывает сопоставимые результаты с другими SOTA моделями в задаче Panoptic Segmentation. Для корректного сравнения модели PanopticFPN++ и UPSnet были переобучены с аугментациями DETR. PQ — panoptic quality, SQ — segmentation quality (средний IoU score), RQ — recognition quality (F1 score на масках), пометки “th” и “st” обозначают метрики в разрезе разных групп классов
Код MHAttentionMap, отвечающего только за подсчет attention softmax (без умножения на V)
import torch
import torch.nn as nn


class MHAttentionMap(nn.Module):
    """
		Имплементация MHAttentionMap.
    """
    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
				
				# Инициализация весов.
        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)
        return weights
Код MaskHeadSmallConv — FPN-like CNN головы
import torch
import torch.nn as nn


class MaskHeadSmallConv(nn.Module):
		"""
		Имплементация MaskHeadSmallConv. Upsampling делается с помощью FPN подхода.
    """
    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()
				
				# В качестве нормализации активно используем GroupNorm.
				# Идея в том, чтобы построить feature map в разных масштабах: 
				# от большего к меньшему, которые затем мы будем складывать друг
				# с другом.
        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(8, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
        self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
				
				# Инициализация весов.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
		    """
		    fpns - 3 feature map разного размера с бэкбона, которые ниже
		    мы будем склеивать с feature maps, полученными свёртками в mask head.
		    """
		    
		    # Мы объединяем x (feature map с backbone) и масками боксов,
		    # полученных на этапе MHAttentionMap. Размеры итогового тензора
		    # будут зависеть от количества боксов.
        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)
				
				# Строим пирамиду из признаков.
        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
		        # expand - resize feature map cur_fpn до размеров feature map x.
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        x = self.out_lay(x)
        return x
Код Segmentation DETR
import torch
import torch.nn as nn


class DETRsegm(nn.Module):
		"""
		Имплементация сегментационного DETR.
    """
    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr
				
				# Как было сказано выше, мы можем просто заморозить веса детектора.
				# Это не повлияет на итоговое качество, но существенно упростит 
				# обучение.
        if freeze_detr:
            for p in self.parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        bs = features[-1].tensors.shape[0]

        src, mask = features[-1].decompose()
        assert mask is not None
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord)

        # MHSA формирует attention maps для боксов.
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
				
				# А mask head, которая и учит маски объектов.
        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        return out


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

Заключение

В заключение нашей статьи мы бы хотели поговорить о достоинствах и недостатках Detection Transformer’а.

Достоинства DETR

  • Взаимодействие object queries через self-attention декодера вместе с использованием matching loss теоретически приводят к отсутствию дубликатов предсказаний. Однако на практике дубликаты предсказаний все-таки встречаются, поэтому накинуть NMS стоит.
  • Как и в ViT, self-attention отлично справляется с задачей учета глобального контекста и моделированием отношений между далекими друг от друга токенами (патчами из изображения), превосходя обычные CNN детекторы.
  • Слой MultiHeadAttention выполняет похожий трюк в CV, что и в NLP: каждая голова обучается независимо от других и берет на себя (или вместе с другой частью голов) какую-либо подзадачу. Например, в NLP такие задачи могут брать на себя головы:
    • позиционная: как токены расположены друг относительно друга; что идет до / после;
    • синтаксическая: отслеживание некоторых основных синтаксических отношений в предложении (подлежащее-глагол, глагол-объект);
    • частотность токенов: отслеживание наименее частых токенов.
    В случае детекции можно выделить подзадачи локализации и классификации:
    • для предсказания координат нужны границы объекта;
    • для классификации — фокусировка на семантически важных частях.

Недостатки DETR

  • Плохое качество на маленьких объектах. DETR использует только один scale из бэкбона, который имеет слишком маленькое разрешение для точной детекции небольших объектов. Почему при этом нельзя добавить FPN и использовать разрешение повыше / всю пирамиду фичей? Ответ простой и грустный: операция self-attention в энкодере и cross-attention в декодере очень чувствительны к размерности фичей, потому что attention имеет квадратичную зависимость от них 😞.
  • Проблемы обучения. Для достижения адекватных метрик DETR-у, как и многим трансформерам, нужно на порядок больше эпох, чем аналогичным классическим детекторам. Ну и в целом на практике очень тяжело обучать большие трансформерные архитектуры: нужно много данных, критически важен learning rate и scheduling, чтобы лосс не улетел в NaN и так далее.
  • Проблемы инференса. Проблемы здесь из-за того, что DETR — трансформер. Это означает большие затраты по времени и памяти на инференсе вследствие квадратичной сложности Attention (мы должны посчитать скор попарно между всеми токенами). К тому же, не будем забывать, что Query, Key и Value — обучаемые параметры, которые вносят свой вклад в увеличение latency через MAC.

Источник

  • detr — официальный репозиторий DETR

Телеграм-канал

DeepSchool

Короткие посты по теории ML/DL, полезные
библиотеки и фреймворки, вопросы с собеседований
и советы, которые помогут в работе

Открыть Телеграм

Увидели ошибку?

Напишите нам в Telegram!