RT-DETR
- Пререквизиты
- Введение
- Недостатки DETR и их решения
- Влияние NMS на скорость и качество YOLO-детекторов
- 1. Количество боксов и confidence threshold
- 2. IoU threshold, confidence threshold vs AP, NMS
- Архитектура RT-DETR
- CNN-бэкбон
- Efficient Hybrid Encoder
- Transformer encoder
- CCFF (CNN-based Cross-scale Feature Fusion)
- Собираем всё вместе
- Decoder
- Deformable Attention
- Transformer Decoder
- Query Selection
- Собираем всё вместе
- Сравнение с другими детекторами
- Заключение
- Полезные ссылки
Пререквизиты
Что можно прочитать, чтобы лучше понять эту статью:
- Архитектура RT-DETR — продолжение идей DETR. Поэтому рекомендуем ознакомиться с нашим постом о нем.
- Часто в посте будут упоминаться модели из семейства YOLO. Вот они (
мои любимые слева направо): YOLOv1, YOLOv2, YOLOv3, YOLOv4, YOLOv5. Для понимания этой статьи достаточно ознакомиться только с последней версией.
Давайте разберём словосочетания на английском, которые тоже нам встретятся (автору так самому будет проще 🙂):
- Confidence threshold — порог вероятности для детекции;
- IoU threshold — порог IoU для NMS;
- Receptive field — рецептивное поле выходного пикселя, то есть контекст, который захватила свёртка.
Введение
В своё время авторы DETR полностью пересмотрели подход к построению архитектуры детектора:
- упростили пайплайн детекции, удалив вспомогательные техники — NMS или анкоры;
- применили Encoder-Decoder трансформер поверх CNN-фичей для учёта глобального контекста и формирования на его основе финального предсказания боксов.
К сожалению, у DETR’а есть недостатки (подробнее о них ниже), которые не позволяют использовать его на практике.
С другой стороны, модели семейства YOLO — наиболее популярные real-time детекторы. Они имеют хороший trade-off между скоростью и качеством. К их недостаткам можно отнести NMS, который требует дополнительных затрат на post-processing и вводит гиперпараметр threshold.
Авторы статьи оценили идею оригинального DETR’а, рассмотрели фишки YOLO-моделей и решили создать модель, которая не уступила бы по качеству и скорости моделям семейства YOLO (X-/ L-версиям) — то есть необходимо было сделать GPU real-time.
Недостатки DETR и их решения
- Сходимость обучения. На практике очень тяжело обучать большие трансформерные архитектуры: нужно много данных, критически важен learning rate и scheduling, чтобы лосс не улетел в NaN и так далее.
- Решение: использовать multi-scale фичи (Deformable-DETR).
- Инициализация decoder queries. В оригинальном DETR’е decoder queries (обучаемые якоря) инициализировались нулевыми векторами, что существенно увеличивало время обучения — они должны были научиться улавливать визуальную информацию об объектах без предварительных знаний.
- Решение: инициализировать их в виде сетки на картах признаков и отбирать top k наиболее вероятных. Подробнее об этом расскажем в блоке о Query Selection.
- Одно разрешение. DETR использует только один scale из бэкбона, который может иметь слишком маленькое разрешение для точной детекции небольших объектов. Введение в архитектуру multi-scale фич решает эту проблему. Да и attention будет спокойно работать с более длинной последовательностью. Но операции self-attention в энкодере и cross-attention в декодере очень чувствительны к размерности фичей, потому что attention имеет квадратичную зависимость от них 😞. Что делать?
- Решение: надо полностью переработать DETR vanilla энкодер 😎, чтобы эффективно посчитать attention внутри каждой фичи и затем агрегировать полученные результаты.
- Вычислительные затраты. DETR — трансформер. Значит, мы имеем большие затраты по времени и памяти на инференсе вследствие квадратичной сложности Attention (мы должны посчитать скор попарно между всеми токенами). А ещё есть Query, Key и Value — обучаемые параметры, которые вносят свой вклад в увеличение latency через MAC. Вычисление MHSA по каждой карте признаков — всё равно очень трудоёмкий процесс. Может, есть способы его обойти?
- Решение: и снова нам надо переработать DETR vanilla энкодер 😎. Интуитивно понятно, что высокоуровневые фичи, содержащие богатую семантическую информацию об объектах, извлекаются из низкоуровневых объектов. Следовательно, для учёта глобального контекста мы можем использовать только самую верхнеуровневую фичу, к которой затем уже добавим информацию с помощью дополнительных блоков от низкоуровневых.
Влияние NMS на скорость и качество YOLO-детекторов
1. Количество боксов и confidence threshold
В рамках первого мини-исследования ресёрчеры посчитали количество боксов, которое остаётся после фильтрации по confidence threshold. Они выяснили, что при увеличении порога количество боксов существенно снижается. Однако важно не забывать о том, что для NMS это существенный плюс — на каждой итерации алгоритма мы должны посчитать IoU одного бокса со всеми оставшимися.
Ключевой момент: если говорить про детекторы семейства YOLO, то anchor-free детекторы превосходят по скорости anchor-base детекторы. Это связано с тем, что первые предсказывают значительно меньше боксов, чем вторые.
2. IoU threshold, confidence threshold vs AP, NMS
Во втором мини-исследовании производились замеры AP (%) и NMS (ms) при фиксированных IoU threshold и confidence threshold. Выяснилось, что время выполнения NMS прямо пропорционально величине IoU threshold и обратно пропорционально confidence threshold. Причины следующие:
- высокий confidence threshold фильтрует больше боксов, чем низкий ⇒ NMS отрабатывает на меньшем количестве боксов;
- высокий IoU threshold убирает меньше боксов при одной итерации NMS, чем низкий.
Ключевой момент: при изменении IoU threshold и confidence threshold незначительно меняется и качество (AP) работы детектора. Поэтому на практике тюнинг этих параметров — отдельная важная задача по настройке детектора (минус в копилку NMS).
Архитектура RT-DETR
RT-DETR можно разделить на 3 части:
- CNN-бэкбон для формирования feature maps;
- Efficient Hybrid Encoder вместо обычного трансформерного энкодера;
\( S_{3}, S_{4}, S_{5} \) — feature maps с бэкбона в порядке убывания разрешения (в порядке возрастания семантики)
- Transformer Decoder с хаками.
CNN-бэкбон
CNN-бэкбон, как уже было сказано выше, формирует карты признаков исходного изображения. Авторы использовали модели семейства ResNet и PResNet, в качестве baseline backbone — ResNet-50.
Ключевой момент: именно карты признаков, а не карту, ведь мы хотим работать с разными разрешениями исходного изображения. Для этого в коде backbone возвращает список из feature map.
Efficient Hybrid Encoder
Давайте закрепим решения по устранению недостатков в DETR:
- Intra-scale (внутреннее) трансформерное преобразование только на \( S_{5} \). Почему только на \( S_{5} \)? Во-первых, это самая высокоуровневая фича, а значит, самая информативная с точки зрения семантики. Во-вторых, мы существенно сокращаем вычисления, так как работаем только с одной картой признаков.
- А вот cross-scale взаимодействие подсчитываем между \( F_{5}, S_{3}, S_{4} \).
В виде формул это можно переписать так:
\( Q = K = V = Flatten(S_{5}) \)
\( F_{5} = Reshape(AIFI(Q, K, V)) \)
\( O = CCFF({S_{3},S_{4}, F_{5}}) \)
Рассмотрим теперь подробнее все компоненты нового энкодера.
Transformer encoder
Здесь всё по классике: сворачиваем пространственную размерность \( H * W \) во входной карте признаков \( S_{5} \), получаем последовательность токенов и добавляем фиксированное позиционное кодирование (spatial positional encoding), которое поможет сети учитывать порядок фич в карте признаков.
Код Transformer Encoder
# transformer
class TransformerEncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False):
"""
Имплементация слоя энкодера трансформера. Всё, как в DETR.
"""
super().__init__()
self.normalize_before = normalize_before
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = get_activation(activation)
@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
residual = src
if self.normalize_before:
src = self.norm1(src)
# Добавляем позиционное кодирование только к Query и Key.
q = k = self.with_pos_embed(src, pos_embed)
src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
"""
Имплементация энкодера трансформера.
Тут тоже всё, как обычно - собираем Encoder слои для последовательного
forward + нормализация после.
"""
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm
def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
output = src
for layer in self.layers:
output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
if self.norm is not None:
output = self.norm(output)
return output
CCFF (CNN-based Cross-scale Feature Fusion)
Итак, мы получили карты признаков \( F_{5}, S_{3}, S_{4} \). Теперь нужно сагрегировать их: обменять информацию между ними так, чтобы получить ещё более сильные фичи для дальнейшей детекции. Часто это делают с помощью FPN-блока или его модификаций. Смысл следующий:
- Bottom-up стадия. Формируем карты признаков в нескольких разрешениях. В нашей задаче за это ответственен CNN-бэкбон.
- Top-down стадия. Чем выше карта, тем лучше семантика, чем ниже — детализация. Агрегируя соседние карты признаков сверху вниз, мы улучшаем детектирование объектов, поскольку обмениваем информацию между картами разных разрешений. Условно карты низкого разрешения хорошо улавливают крупные объекты, высокого разрешения — мелкие. Следовательно, это лучше использования одного разрешения.
Блок CNN-based Cross-scale Feature Fusion — серьёзная модификация top-down стадии FPN-блока. Помимо агрегации соседних карт мы ещё создаём новые фичи, которые тоже затем агрегируем. Получается каскад агрегаций фичей. Он ещё сильнее улучшает детекцию объектов.
За создание новых фичей отвечает отдельный блок — Fusion (или CSPRepLayer). В нём мы:
- конкатим фичи;
- затем разделяем их на два бранча: в первом применяем свёртку 1×1, а во-втором — комбинацию свёртки 1×1 и \( N \) последовательных RepBlock;
Ключевой момент: в чём смысл этих двух бранчей? В первом мы применяем свёртку 1×1 и обмениваем информацию между каналами. Во втором обмениваем информацию и между каналами, и по пространственной размерности \( HW \). Всё это сделано с помощью свёрток 1×1 и 3×3, что сильно сокращает вычисления.
- конкатим результаты двух веток;
- распрямляем полученный тензор.
Также в CSPRepLayer активно используется слияние conv + norm слоёв. Об этом можно почитать подробно здесь в разделе про репараметризацию.
Код ConvNormLayer
class ConvNormLayer(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
"""
Имплементация слоя Conv + Norm.
"""
super().__init__()
self.conv = nn.Conv2d(
ch_in,
ch_out,
kernel_size,
stride,
padding=(kernel_size-1)//2 if padding is None else padding,
bias=bias)
self.norm = nn.BatchNorm2d(ch_out)
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
# наша любимая связка conv + norm
return self.act(self.norm(self.conv(x)))
Код RepVggBlock
class RepVggBlock(nn.Module):
def __init__(self, ch_in, ch_out, act='relu'):
"""
Имплементация слоя Rep VGG.
"""
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
"""
Тоже стандартно: свёртка 3x3 с бранчем (свёртка 1x1), то есть добавляем
ещё один scale. Очень эффективно с точки зрения ресурсов.
"""
if hasattr(self, 'conv'):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
return self.act(y)
def convert_to_deploy(self):
if not hasattr(self, 'conv'):
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.data = kernel
self.conv.bias.data = bias
# self.__delattr__('conv1')
# self.__delattr__('conv2')
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return F.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch: ConvNormLayer):
"""
Фьюзим слои: превращаем conv + bn в просто conv
"""
if branch is None:
return 0, 0
kernel = branch.conv.weight
running_mean = branch.norm.running_mean
running_var = branch.norm.running_var
gamma = branch.norm.weight
beta = branch.norm.bias
eps = branch.norm.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
Код CSPRepLayer
# блок для слияния фичей из разных размерностей
class CSPRepLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_blocks=3,
expansion=1.0,
bias=None,
act="silu"):
"""
Имплементация слоя Cross-scale Rep.
"""
super(CSPRepLayer, self).__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.bottlenecks = nn.Sequential(*[
RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
])
if hidden_channels != out_channels:
self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
else:
self.conv3 = nn.Identity()
def forward(self, x):
x_1 = self.conv1(x)
x_1 = self.bottlenecks(x_1)
x_2 = self.conv2(x)
return self.conv3(x_1 + x_2)
Собираем всё вместе
По сути Efficient Hybrid Encoder — это комбинация Transformer Encoder’а и CCFF-блока.
Код Efficient Hybrid Encoder
class HybridEncoder(nn.Module):
def __init__(self,
in_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
hidden_dim=256,
nhead=8,
dim_feedforward = 1024,
dropout=0.0,
enc_act='gelu',
use_encoder_idx=[2],
num_encoder_layers=1,
pe_temperature=10000,
expansion=1.0,
depth_mult=1.0,
act='silu',
eval_spatial_size=None):
"""
Имплементация Hybrid Encoder.
"""
super().__init__()
self.in_channels = in_channels
self.feat_strides = feat_strides
self.hidden_dim = hidden_dim
self.use_encoder_idx = use_encoder_idx
self.num_encoder_layers = num_encoder_layers
self.pe_temperature = pe_temperature
self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim for _ in range(len(in_channels))]
self.out_strides = feat_strides
# channel projection
self.input_proj = nn.ModuleList()
for in_channel in in_channels:
self.input_proj.append(
nn.Sequential(
nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(hidden_dim)
)
)
# encoder transformer
encoder_layer = TransformerEncoderLayer(
hidden_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=enc_act)
self.encoder = nn.ModuleList([
TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
])
# top-down fpn
self.lateral_convs = nn.ModuleList()
self.fpn_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
self.fpn_blocks.append(
CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
)
# bottom-up pan
self.downsample_convs = nn.ModuleList()
self.pan_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1):
self.downsample_convs.append(
ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act)
)
self.pan_blocks.append(
CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
)
self._reset_parameters()
def _reset_parameters(self):
if self.eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pos_embed = self.build_2d_sincos_position_embedding(
self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride,
self.hidden_dim, self.pe_temperature)
setattr(self, f'pos_embed{idx}', pos_embed)
# self.register_buffer(f'pos_embed{idx}', pos_embed)
@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
'''
Sin-cos позиционное кодирование в 2D.
'''
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
assert embed_dim % 4 == 0, \
'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1. / (temperature ** omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
def forward(self, feats):
assert len(feats) == len(self.in_channels)
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
# encoder
if self.num_encoder_layers > 0:
for i, enc_ind in enumerate(self.use_encoder_idx):
h, w = proj_feats[enc_ind].shape[2:]
# flatten [B, C, H, W] to [B, HxW, C]
src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1)
if self.training or self.eval_spatial_size is None:
pos_embed = self.build_2d_sincos_position_embedding(
w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device)
else:
pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device)
memory = self.encoder[i](src_flatten, pos_embed=pos_embed)
proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
# print([x.is_contiguous() for x in proj_feats ])
# broadcasting and fusion
inner_outs = [proj_feats[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = proj_feats[idx - 1]
feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)
inner_outs[0] = feat_high
upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
inner_outs.insert(0, inner_out)
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_high = inner_outs[idx + 1]
downsample_feat = self.downsample_convs[idx](feat_low)
out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))
outs.append(out)
return outs
Decoder
В самой статье мало уделено внимания декодеру — есть подробная информация только про выбор object queries. Поэтому давайте самостоятельно разберём его по частям (они не последовательные, а смысловые):
- MSDeformableAttention
- Transformer Decoder
- Query Selection
Deformable Attention
Deformable Attention есть не что иное, как применение идеи deformable convolution к механизму внимания. Давайте вспомним, что это такое 🙂.
Классические 2D-свёртки — фундаментальная операция для всего CV. Однако они содержат ряд важных ограничений. Одно из самых главных — учёт только локального фиксированного контекста. На практике это может серьёзно сказываться на качестве модели. Например, рецептивные поля для маленьких и больших объектов на изображении должны быть различны, а именно — как-то пропорциональны их размерам, чтобы мы могли лучше захватывать сам объект.
Deformable свёртки (способные менять форму) были предложены для решения этой проблемы. Здесь мы будем для каждой свёртки учить ещё и offsets (отступы) к ней. Таким образом, мы сможем выйти за рамки обычного рецептивного поля и собрать дополнительную информацию из контекста.
Ключевой момент: отступы не перемещают веса свёрток на другие пиксели, а делают это более гибко — могут затрагиваться границы сразу нескольких пикселей. Для подсчёта значения обычно используется интерполяция.
Вернёмся к attention. Вместо подсчёта скоров attention по всем токенам посчитаем их по различным подмножествам около опорной точки. Локации этих опорных точек мы предсказываем (или как-то семплируем) по выходным query. Так как подмножества состоят из фиксированного количества элементов, решается проблема сходимости и вычислительной сложности.
Немного интуиции: обычный Attention как правило глобален (каждый токен из \( K \) несёт хоть какой-то вклад в результат вычисления Attention к конкретному токену из \( Q \)). Для задачи классификации этого достаточно, но для задачи локализации было бы хорошо, если бы токен пикселя объекта смотрел на рядом стоящие токены этого объекта. Deformable как раз позволяет ограничить Attention до некоторого локального рецептивного поля.
Код Deformable Attention
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=4, num_points=4,):
"""
Multi-Scale Deformable Attention модуль.
"""
super(MSDeformableAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.total_points = num_heads * num_levels * num_points
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2,)
self.attention_weights = nn.Linear(embed_dim, self.total_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.ms_deformable_attn_core = deformable_attention_core_func
self._reset_parameters()
def _reset_parameters(self):
# sampling_offsets
init.constant_(self.sampling_offsets.weight, 0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).tile([1, self.num_levels, self.num_points, 1])
scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1)
grid_init *= scaling
self.sampling_offsets.bias.data[...] = grid_init.flatten()
# attention_weights
init.constant_(self.attention_weights.weight, 0)
init.constant_(self.attention_weights.bias, 0)
# proj
init.xavier_uniform_(self.value_proj.weight)
init.constant_(self.value_proj.bias, 0)
init.xavier_uniform_(self.output_proj.weight)
init.constant_(self.output_proj.bias, 0)
def forward(self,
query,
reference_points,
value,
value_spatial_shapes,
value_mask=None):
"""
Args:
query (Tensor): [bs, query_length, C]
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, Len_q = query.shape[:2]
Len_v = value.shape[1]
value = self.value_proj(value)
if value_mask is not None:
value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
value *= value_mask
value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
sampling_offsets = self.sampling_offsets(query).reshape(
bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).reshape(
bs, Len_q, self.num_heads, self.num_levels * self.num_points)
attention_weights = F.softmax(attention_weights, dim=-1).reshape(
bs, Len_q, self.num_heads, self.num_levels, self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.tensor(value_spatial_shapes)
offset_normalizer = offset_normalizer.flip([1]).reshape(
1, 1, 1, self.num_levels, 1, 2)
sampling_locations = reference_points.reshape(
bs, Len_q, 1, self.num_levels, 1, 2
) + sampling_offsets / offset_normalizer
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] + sampling_offsets /
self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".
format(reference_points.shape[-1]))
output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output
Transformer Decoder
Декодер RT-DETR в целом повторяет структуру декодера оригинального DETR с единственным значимым отличием — вместо обычного Multi-Head Self-Attention используется Deformable Attention в качестве cross-attention [здесь можно почитать о том, что такое cross-attention].
Код DecoderLayer
class TransformerDecoderLayer(nn.Module):
def __init__(self,
d_model=256,
n_head=8,
dim_feedforward=1024,
dropout=0.,
activation="relu",
n_levels=4,
n_points=4,):
"""
Имплементация слоя декодера трансформера с DeformableAttention.
"""
super(TransformerDecoderLayer, self).__init__()
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# cross attention
self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = getattr(F, activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
def forward(self,
tgt,
reference_points,
memory,
memory_spatial_shapes,
memory_level_start_index,
attn_mask=None,
memory_mask=None,
query_pos_embed=None):
# self attention
q = k = self.with_pos_embed(tgt, query_pos_embed)
tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# cross attention
tgt2 = self.cross_attn(\
self.with_pos_embed(tgt, query_pos_embed),
reference_points,
memory,
memory_spatial_shapes,
memory_mask)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# ffn
tgt2 = self.forward_ffn(tgt)
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
Код Decoder
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
"""
Имплементация декодера трансформера.
"""
super(TransformerDecoder, self).__init__()
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
def forward(self,
tgt,
ref_points_unact,
memory,
memory_spatial_shapes,
memory_level_start_index,
bbox_head,
score_head,
query_pos_head,
attn_mask=None,
memory_mask=None):
output = tgt
dec_out_bboxes = []
dec_out_logits = []
ref_points_detach = F.sigmoid(ref_points_unact)
for i, layer in enumerate(self.layers):
ref_points_input = ref_points_detach.unsqueeze(2)
query_pos_embed = query_pos_head(ref_points_detach)
output = layer(output, ref_points_input, memory,
memory_spatial_shapes, memory_level_start_index,
attn_mask, memory_mask, query_pos_embed)
inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
if self.training:
dec_out_logits.append(score_head[i](output))
if i == 0:
dec_out_bboxes.append(inter_ref_bbox)
else:
dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
elif i == self.eval_idx:
dec_out_logits.append(score_head[i](output))
dec_out_bboxes.append(inter_ref_bbox)
break
ref_points = inter_ref_bbox
ref_points_detach = inter_ref_bbox.detach(
) if self.training else inter_ref_bbox
return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
Query Selection
Среди недостатков DETR мы уже упоминали проблему с обучением, вызванную инициализацией object queries и последующим их обучением.
В RT-DETR она решается следующим образом: после энкодера мы имеем набор из токенов (при стандартных параметрах их 8400 штук, каждый размерностью 256). Далее у нас есть две головы: классификационная (определение класса) и локализационная (предсказание боксов). Передаём в них наш набор токенов, не забываем добавить анкоры к выходу локализационной головы. Далее выбираем top k токенов с самым большим скором (в плане вероятности для классов) и соответствующие им якори. Фильтруем таким образом якори, следовательно, экономим очень много вычислительных ресурсов.
Полученные боксы и ground truth боксы подаются в лосс. Его подсчёт проходит в две стадии:
- Делаем матчинг между предсказанными и ground truth боксами с помощью венгерского алгоритма (подробнее о нём можно почитать в посте про DETR).
- Подсчитываем корректность в каждой паре (сходимость боксов и классов).
Собираем всё вместе
Код MLP
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'):
"""
Имплементация слоя MLP.
"""
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
Код RT-DETR decoder (init)
class RTDETRTransformer(nn.Module):
__share__ = ['num_classes']
def __init__(self,
num_classes=80,
hidden_dim=256,
num_queries=300,
position_embed_type='sine',
feat_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
num_levels=3,
num_decoder_points=4,
nhead=8,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.,
activation="relu",
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0,
learnt_init_query=False,
eval_spatial_size=None,
eval_idx=-1,
eps=1e-2,
aux_loss=True):
"""
Имплементация RT-DETR декодера.
"""
super(RTDETRTransformer, self).__init__()
assert position_embed_type in ['sine', 'learned'], \
f'ValueError: position_embed_type not supported {position_embed_type}!'
assert len(feat_channels) <= num_levels
assert len(feat_strides) == len(feat_channels)
for _ in range(num_levels - len(feat_strides)):
feat_strides.append(feat_strides[-1] * 2)
self.hidden_dim = hidden_dim
self.nhead = nhead
self.feat_strides = feat_strides
self.num_levels = num_levels
self.num_classes = num_classes
self.num_queries = num_queries
self.eps = eps
self.num_decoder_layers = num_decoder_layers
self.eval_spatial_size = eval_spatial_size
self.aux_loss = aux_loss
# backbone feature projection
self._build_input_proj_layer(feat_channels)
# Transformer module
decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, num_decoder_points)
self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
# denoising part
if num_denoising > 0:
# self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim, padding_idx=num_classes-1) # TODO for load paddle weights
self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
# decoder embedding
self.learnt_init_query = learnt_init_query
if learnt_init_query:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
# encoder head
self.enc_output = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim,)
)
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
# decoder head
self.dec_score_head = nn.ModuleList([
nn.Linear(hidden_dim, num_classes)
for _ in range(num_decoder_layers)
])
self.dec_bbox_head = nn.ModuleList([
MLP(hidden_dim, hidden_dim, 4, num_layers=3)
for _ in range(num_decoder_layers)
])
# init encoder output anchors and valid_mask
if self.eval_spatial_size:
self.anchors, self.valid_mask = self._generate_anchors()
self._reset_parameters()
def _reset_parameters(self):
"""
Инициализация весов.
"""
bias = bias_init_with_prob(0.01)
init.constant_(self.enc_score_head.bias, bias)
init.constant_(self.enc_bbox_head.layers[-1].weight, 0)
init.constant_(self.enc_bbox_head.layers[-1].bias, 0)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
init.constant_(cls_.bias, bias)
init.constant_(reg_.layers[-1].weight, 0)
init.constant_(reg_.layers[-1].bias, 0)
# linear_init_(self.enc_output[0])
init.xavier_uniform_(self.enc_output[0].weight)
if self.learnt_init_query:
init.xavier_uniform_(self.tgt_embed.weight)
init.xavier_uniform_(self.query_pos_head.layers[0].weight)
init.xavier_uniform_(self.query_pos_head.layers[1].weight)
RT-DETR decoder (подготовка выхода с энкодера и анкоров)
def _build_input_proj_layer(self, feat_channels):
"""
Формируем блоки для проекции фичей.
"""
self.input_proj = nn.ModuleList()
for in_channels in feat_channels:
self.input_proj.append(
nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
('norm', nn.BatchNorm2d(self.hidden_dim,))])
)
)
in_channels = feat_channels[-1]
for _ in range(self.num_levels - len(feat_channels)):
self.input_proj.append(
nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)),
('norm', nn.BatchNorm2d(self.hidden_dim))])
)
)
in_channels = self.hidden_dim
def _get_encoder_input(self, feats):
"""
На вход подаются фичи после энкодера, которые мы проецируем в
размерность hidden dim.
"""
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
if self.num_levels > len(proj_feats):
len_srcs = len(proj_feats)
for i in range(len_srcs, self.num_levels):
if i == len_srcs:
proj_feats.append(self.input_proj[i](feats[-1]))
else:
proj_feats.append(self.input_proj[i](proj_feats[-1]))
# распрямляем фичи
feat_flatten = []
spatial_shapes = []
level_start_index = [0, ]
for i, feat in enumerate(proj_feats):
_, _, h, w = feat.shape
# [b, c, h, w] -> [b, h*w, c]
feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
# [num_levels, 2]
spatial_shapes.append([h, w])
# [l], start index of each level
level_start_index.append(h * w + level_start_index[-1])
# [b, l, c]
feat_flatten = torch.concat(feat_flatten, 1)
level_start_index.pop()
return (feat_flatten, spatial_shapes, level_start_index)
def _generate_anchors(self,
spatial_shapes=None,
grid_size=0.05,
dtype=torch.float32,
device='cpu'):
"""
Генерим анкоры. По сути получим набор из h * w якорей, которые будут
сеткой покрывать всю фиче мапу (так для каждой фичи мапы).
"""
if spatial_shapes is None:
spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)]
for s in self.feat_strides
]
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
# делаем сетку из x_top и y_top
grid_y, grid_x = torch.meshgrid(\
torch.arange(end=h, dtype=dtype), \
torch.arange(end=w, dtype=dtype), indexing='ij')
grid_xy = torch.stack([grid_x, grid_y], -1)
valid_WH = torch.tensor([w, h]).to(dtype)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
# теперь задаём высоту и ширину для каждого анкора
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4))
anchors = torch.concat(anchors, 1).to(device)
# накладываем маску, чтобы не выйти за границы
valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
# anchors = torch.where(valid_mask, anchors, float('inf'))
# anchors[valid_mask] = torch.inf # valid_mask [1, 8400, 1]
anchors = torch.where(valid_mask, anchors, torch.inf)
return anchors, valid_mask
def _get_decoder_input(self,
memory,
spatial_shapes,
denoising_class=None,
denoising_bbox_unact=None):
"""
Передаём в классификационную и локализационную гоовы наш набор токенов,
не забываем добавить анкоры к выходу локализационной головы.
Далее выбираем top k токенов с самым большим скором (в плане вероятности
для классов) и соответствующие им якори.
"""
bs, _, _ = memory.shape
# prepare input for decoder
if self.training or self.eval_spatial_size is None:
anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
else:
anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device)
# memory = torch.where(valid_mask, memory, 0)
memory = valid_mask.to(memory.dtype) * memory # TODO fix type error for onnx export
output_memory = self.enc_output(memory)
enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1]))
enc_topk_bboxes = F.sigmoid(reference_points_unact)
if denoising_bbox_unact is not None:
reference_points_unact = torch.concat(
[denoising_bbox_unact, reference_points_unact], 1)
enc_topk_logits = enc_outputs_class.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]))
# extract region features
if self.learnt_init_query:
target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1])
else:
target = output_memory.gather(dim=1, \
index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
target = target.detach()
if denoising_class is not None:
target = torch.concat([denoising_class, target], 1)
return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
Код RT-DETR decoder (forward)
def forward(self, feats, targets=None):
# input projection and embedding
(memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
# prepare denoising training
if self.training and self.num_denoising > 0:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
get_contrastive_denoising_training_group(targets, \
self.num_classes,
self.num_queries,
self.denoising_class_embed,
num_denoising=self.num_denoising,
label_noise_ratio=self.label_noise_ratio,
box_noise_scale=self.box_noise_scale, )
else:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
# decoder
out_bboxes, out_logits = self.decoder(
target,
init_ref_points_unact,
memory,
spatial_shapes,
level_start_index,
self.dec_bbox_head,
self.dec_score_head,
self.query_pos_head,
attn_mask=attn_mask)
if self.training and dn_meta is not None:
dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
if self.training and self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
if self.training and dn_meta is not None:
out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
out['dn_meta'] = dn_meta
return out
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class, outputs_coord)]
Код всего RT-DETR
class RTDETR(nn.Module):
__inject__ = ['backbone', 'encoder', 'decoder', ]
def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None):
super().__init__()
self.backbone = backbone
self.decoder = decoder
self.encoder = encoder
self.multi_scale = multi_scale
def forward(self, x, targets=None):
if self.multi_scale and self.training:
sz = np.random.choice(self.multi_scale)
x = F.interpolate(x, size=[sz, sz])
x = self.backbone(x)
x = self.encoder(x)
x = self.decoder(x, targets)
return x
def deploy(self, ):
self.eval()
for m in self.modules():
if hasattr(m, 'convert_to_deploy'):
m.convert_to_deploy()
return self
Сравнение с другими детекторами
Хотя RT-DETR превосходит по скорости и качеству SOTA real-time и другие end-to-end детекторы аналогичного размера, он всё ещё содержит основное ограничение архитектур по типу DETR — хуже детектирует маленькие и средние объекты в отличие от real-time детекторов 🤔.
Заключение
- DETR — хорошая база, компоненты которой можно улучшать. Количество статей об этом исчисляется десятками. RT-DETR — комбинация таких различных идей.
- Иногда можно упрощать вычисления по картам признаков, например, как это было с \( S_{5} \).
- На практике у RT-DETR всё ещё есть “плохие” черты архитектур на трансформерах, на которые нужно обращать внимание. Например:
- Важен внимательный подбор гиперпараметров;
- Возможен уход лосса в None в середине обучения.
Полезные ссылки
- RT-DETR — официальный репозиторий