# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from examples/modular-transformers/modular_test_detr.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_test_detr.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from collections.abc import Callable from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from ... import initialization as init from ...activations import ACT2FN from ...backbone_utils import load_backbone from ...integrations import use_kernel_forward_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from .configuration_test_detr import TestDetrConfig @dataclass @auto_docstring( custom_intro=""" Base class for outputs of the TEST_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding losses. """ ) class TestDetrDecoderOutput(BaseModelOutputWithCrossAttentions): r""" cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a layernorm. intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): Stacked intermediate reference points (reference points of each layer of the decoder). """ intermediate_hidden_states: torch.FloatTensor | None = None intermediate_reference_points: torch.FloatTensor | None = None @dataclass @auto_docstring( custom_intro=""" Base class for outputs of the Deformable DETR encoder-decoder model. """ ) class TestDetrModelOutput(ModelOutput): r""" init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): Initial reference points sent through the Transformer decoder. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): Stacked intermediate hidden states (output of each layer of the decoder). intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): Stacked intermediate reference points (reference points of each layer of the decoder). enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are picked as region proposals in the first stage. Output of bounding box binary classification (i.e. foreground and background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. """ init_reference_points: torch.FloatTensor | None = None last_hidden_state: torch.FloatTensor | None = None intermediate_hidden_states: torch.FloatTensor | None = None intermediate_reference_points: torch.FloatTensor | None = None decoder_hidden_states: tuple[torch.FloatTensor] | None = None decoder_attentions: tuple[torch.FloatTensor] | None = None cross_attentions: tuple[torch.FloatTensor] | None = None encoder_last_hidden_state: torch.FloatTensor | None = None encoder_hidden_states: tuple[torch.FloatTensor] | None = None encoder_attentions: tuple[torch.FloatTensor] | None = None enc_outputs_class: torch.FloatTensor | None = None enc_outputs_coord_logits: torch.FloatTensor | None = None @use_kernel_forward_from_hub("MultiScaleDeformableAttention") class MultiScaleDeformableAttention(nn.Module): def forward( self, value: Tensor, value_spatial_shapes: Tensor, value_spatial_shapes_list: list[tuple], level_start_index: Tensor, sampling_locations: Tensor, attention_weights: Tensor, im2col_step: int, ): batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes_list): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width # -> batch_size*num_heads, hidden_dim, height, width value_l_ = ( value_list[level_id] .flatten(2) .transpose(1, 2) .reshape(batch_size * num_heads, hidden_dim, height, width) ) # batch_size, num_queries, num_heads, num_points, 2 # -> batch_size, num_heads, num_queries, num_points, 2 # -> batch_size*num_heads, num_queries, num_points, 2 sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) # batch_size*num_heads, hidden_dim, num_queries, num_points sampling_value_l_ = nn.functional.grid_sample( value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False, ) sampling_value_list.append(sampling_value_l_) # (batch_size, num_queries, num_heads, num_levels, num_points) # -> (batch_size, num_heads, num_queries, num_levels, num_points) # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) attention_weights = attention_weights.transpose(1, 2).reshape( batch_size * num_heads, 1, num_queries, num_levels * num_points ) output = ( (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) .sum(-1) .view(batch_size, num_heads * hidden_dim, num_queries) ) return output.transpose(1, 2).contiguous() class TestDetrFrozenBatchNorm2d(nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed. Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than torchvision.models.resnet[18,34,50,101] produce nans. """ def __init__(self, n): super().__init__() self.register_buffer("weight", torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def forward(self, x): # move reshapes to the beginning # to make it user-friendly weight = self.weight.reshape(1, -1, 1, 1) bias = self.bias.reshape(1, -1, 1, 1) running_var = self.running_var.reshape(1, -1, 1, 1) running_mean = self.running_mean.reshape(1, -1, 1, 1) epsilon = 1e-5 scale = weight * (running_var + epsilon).rsqrt() bias = bias - running_mean * scale return x * scale + bias def replace_batch_norm(model): r""" Recursively replace all `torch.nn.BatchNorm2d` with `TestDetrFrozenBatchNorm2d`. Args: model (torch.nn.Module): input model """ for name, module in model.named_children(): if isinstance(module, nn.BatchNorm2d): new_module = TestDetrFrozenBatchNorm2d(module.num_features) if module.weight.device != torch.device("meta"): new_module.weight.copy_(module.weight) new_module.bias.copy_(module.bias) new_module.running_mean.copy_(module.running_mean) new_module.running_var.copy_(module.running_var) model._modules[name] = new_module if len(list(module.children())) > 0: replace_batch_norm(module) class TestDetrConvEncoder(nn.Module): """ Convolutional backbone, using either the AutoBackbone API or one from the timm library. nn.BatchNorm2d layers are replaced by TestDetrFrozenBatchNorm2d as defined above. """ def __init__(self, config): super().__init__() self.config = config backbone = load_backbone(config) self.intermediate_channel_sizes = backbone.channels # replace batch norm by frozen batch norm with torch.no_grad(): replace_batch_norm(backbone) # We used to load with timm library directly instead of the AutoBackbone API # so we need to unwrap the `backbone._backbone` module to load weights without mismatch is_timm_model = False if hasattr(backbone, "_backbone"): backbone = backbone._backbone is_timm_model = True self.model = backbone backbone_model_type = config.backbone_config.model_type if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if is_timm_model: if "layer2" not in name and "layer3" not in name and "layer4" not in name: parameter.requires_grad_(False) else: if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name: parameter.requires_grad_(False) def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): # send pixel_values through the model to get list of feature maps features = self.model(pixel_values) if isinstance(features, dict): features = features.feature_maps out = [] for feature_map in features: # downsample pixel_mask to match shape of corresponding feature_map mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] out.append((feature_map, mask)) return out class TestDetrSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__( self, num_position_features: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None, ): super().__init__() if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") self.num_position_features = num_position_features self.temperature = temperature self.normalize = normalize self.scale = 2 * math.pi if scale is None else scale @compile_compatible_method_lru_cache(maxsize=1) def forward( self, shape: torch.Size, device: torch.device | str, dtype: torch.dtype, mask: torch.Tensor | None = None, ) -> torch.Tensor: if mask is None: mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) y_embed = mask.cumsum(1, dtype=dtype) x_embed = mask.cumsum(2, dtype=dtype) if self.normalize: eps = 1e-6 y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format # expected by the encoder pos = pos.flatten(2).permute(0, 2, 1) return pos class TestDetrLearnedPositionEmbedding(nn.Module): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, embedding_dim=256): super().__init__() self.row_embeddings = nn.Embedding(50, embedding_dim) self.column_embeddings = nn.Embedding(50, embedding_dim) @compile_compatible_method_lru_cache(maxsize=1) def forward( self, shape: torch.Size, device: torch.device | str, dtype: torch.dtype, mask: torch.Tensor | None = None, ): height, width = shape[-2:] width_values = torch.arange(width, device=device) height_values = torch.arange(height, device=device) x_emb = self.column_embeddings(width_values) y_emb = self.row_embeddings(height_values) pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) pos = pos.repeat(shape[0], 1, 1, 1) # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format # expected by the encoder pos = pos.flatten(2).permute(0, 2, 1) return pos def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, scaling: float | None = None, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class TestDetrSelfAttention(nn.Module): """ Multi-headed self-attention from 'Attention Is All You Need' paper. In TEST_DETR, position embeddings are added to both queries and keys (but not values) in self-attention. """ def __init__( self, config: TestDetrConfig, hidden_size: int, num_attention_heads: int, dropout: float = 0.0, bias: bool = True, ): super().__init__() self.config = config self.head_dim = hidden_size // num_attention_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = dropout self.is_causal = False self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_embeddings: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: """ Position embeddings are added to both queries and keys (but not values). """ input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class TestDetrMultiscaleDeformableAttention(nn.Module): """ Multiscale deformable attention as proposed in Deformable DETR. """ def __init__(self, config: TestDetrConfig, num_heads: int, n_points: int): super().__init__() self.attn = MultiScaleDeformableAttention() if config.d_model % num_heads != 0: raise ValueError( f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" ) dim_per_head = config.d_model // num_heads # check if dim_per_head is power of 2 if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): warnings.warn( "You'd better set embed_dim (d_model) in TestDetrMultiscaleDeformableAttention to make the" " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" " implementation." ) self.im2col_step = 64 self.d_model = config.d_model self.n_levels = config.num_feature_levels self.n_heads = num_heads self.n_points = n_points self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) self.value_proj = nn.Linear(config.d_model, config.d_model) self.output_proj = nn.Linear(config.d_model, config.d_model) self.disable_custom_kernels = config.disable_custom_kernels def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, encoder_hidden_states=None, encoder_attention_mask=None, position_embeddings: torch.Tensor | None = None, reference_points=None, spatial_shapes=None, spatial_shapes_list=None, level_start_index=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys if position_embeddings is not None: hidden_states = hidden_states + position_embeddings batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape total_elements = sum(height * width for height, width in spatial_shapes_list) torch_compilable_check( total_elements == sequence_length, "Make sure to align the spatial shapes with the sequence length of the encoder hidden states", ) value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask value = value.masked_fill(~attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 ) attention_weights = self.attention_weights(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels * self.n_points ) attention_weights = F.softmax(attention_weights, -1).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] if num_coordinates == 2: offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) sampling_locations = ( reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :] ) elif num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") output = self.attn( value, spatial_shapes, spatial_shapes_list, level_start_index, sampling_locations, attention_weights, self.im2col_step, ) output = self.output_proj(output) return output, attention_weights class TestDetrMLP(nn.Module): def __init__(self, config: TestDetrConfig, hidden_size: int, intermediate_size: int): super().__init__() self.fc1 = nn.Linear(hidden_size, intermediate_size) self.fc2 = nn.Linear(intermediate_size, hidden_size) self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.dropout = config.dropout def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states class TestDetrEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: TestDetrConfig): super().__init__() self.hidden_size = config.d_model self.self_attn = TestDetrMultiscaleDeformableAttention( config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points, ) self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) self.dropout = config.dropout self.mlp = TestDetrMLP(config, self.hidden_size, config.encoder_ffn_dim) self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, spatial_position_embeddings: torch.Tensor | None = None, reference_points=None, spatial_shapes=None, spatial_shapes_list=None, level_start_index=None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Input to the layer. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Attention mask. position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings, to be added to `hidden_states`. reference_points (`torch.FloatTensor`, *optional*): Reference points. spatial_shapes (`torch.LongTensor`, *optional*): Spatial shapes of the backbone feature maps. level_start_index (`torch.LongTensor`, *optional*): Level start index. """ residual = hidden_states hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, position_embeddings=spatial_position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) if self.training: if not torch.isfinite(hidden_states).all(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) return hidden_states class TestDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: TestDetrConfig): super().__init__() self.hidden_size = config.d_model self.self_attn = TestDetrSelfAttention( config=config, hidden_size=self.hidden_size, num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) self.encoder_attn = TestDetrMultiscaleDeformableAttention( config, num_heads=config.decoder_attention_heads, n_points=config.decoder_n_points, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size) self.mlp = TestDetrMLP(config, self.hidden_size, config.decoder_ffn_dim) self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, object_queries_position_embeddings: torch.Tensor | None = None, reference_points=None, spatial_shapes=None, spatial_shapes_list=None, level_start_index=None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(seq_len, batch, embed_dim)`. position_embeddings (`torch.FloatTensor`, *optional*): Position embeddings that are added to the queries and keys in the self-attention layer. reference_points (`torch.FloatTensor`, *optional*): Reference points. spatial_shapes (`torch.LongTensor`, *optional*): Spatial shapes. level_start_index (`torch.LongTensor`, *optional*): Level start index. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. """ residual = hidden_states # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=object_queries_position_embeddings, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states # Cross-Attention hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, attention_mask=encoder_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, position_embeddings=object_queries_position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) return hidden_states @auto_docstring class TestDetrPreTrainedModel(PreTrainedModel): config: TestDetrConfig base_model_prefix = "model" main_input_name = "pixel_values" input_modalities = ("image",) supports_gradient_checkpointing = True _no_split_modules = [ r"TestDetrConvEncoder", r"TestDetrEncoderLayer", r"TestDetrDecoderLayer", ] _supports_sdpa = True _supports_flash_attn = True _supports_attention_backend = True _supports_flex_attn = True _keys_to_ignore_on_load_unexpected = [ r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked" ] @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, TestDetrLearnedPositionEmbedding): init.uniform_(module.row_embeddings.weight) init.uniform_(module.column_embeddings.weight) elif isinstance(module, TestDetrMultiscaleDeformableAttention): init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads ) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) .view(module.n_heads, 1, 1, 2) .repeat(1, module.n_levels, module.n_points, 1) ) for i in range(module.n_points): grid_init[:, :, i, :] *= i + 1 init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) init.constant_(module.attention_weights.weight, 0.0) init.constant_(module.attention_weights.bias, 0.0) init.xavier_uniform_(module.value_proj.weight) init.constant_(module.value_proj.bias, 0.0) init.xavier_uniform_(module.output_proj.weight) init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d)): init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: init.zeros_(module.bias) elif isinstance(module, nn.Embedding): init.normal_(module.weight, mean=0.0, std=std) # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): init.zeros_(module.weight[module.padding_idx]) if hasattr(module, "reference_points") and not self.config.two_stage: init.xavier_uniform_(module.reference_points.weight, gain=1.0) init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): init.normal_(module.level_embed) class TestDetrEncoder(TestDetrPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a [`TestDetrEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers. Args: config: TestDetrConfig """ _can_record_outputs = { "hidden_states": TestDetrEncoderLayer, "attentions": OutputRecorder(TestDetrMultiscaleDeformableAttention, layer_name="self_attn", index=1), } def __init__(self, config: TestDetrConfig): super().__init__(config) self.dropout = config.dropout self.layers = nn.ModuleList([TestDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) # Initialize weights and apply final processing self.post_init() @merge_with_config_defaults @capture_outputs def forward( self, inputs_embeds=None, attention_mask=None, spatial_position_embeddings=None, spatial_shapes=None, spatial_shapes_list=None, level_start_index=None, valid_ratios=None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: - 1 for pixel features that are real (i.e. **not masked**), - 0 for pixel features that are padding (i.e. **masked**). [What are attention masks?](../glossary#attention-mask) spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer. spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): Spatial shapes of each feature map. level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): Starting index of each feature map. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): Ratio of valid area in each feature level. """ hidden_states = inputs_embeds hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) spatial_shapes_tuple = tuple(spatial_shapes_list) reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, **kwargs, ) return BaseModelOutput(last_hidden_state=hidden_states) @staticmethod def get_reference_points(spatial_shapes_list, valid_ratios, device): """ Get reference points for each feature map. Used in decoder. Args: spatial_shapes_list (`list[tuple[int, int]]`): Spatial shapes of each feature map. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): Valid ratios of each feature map. device (`torch.device`): Device on which to create the tensors. Returns: `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` """ reference_points_list = [] for level, (height, width) in enumerate(spatial_shapes_list): ref_y, ref_x = torch.meshgrid( torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), indexing="ij", ) # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36 ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def inverse_sigmoid(x, eps=1e-5): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1 / x2) class TestDetrDecoder(TestDetrPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TestDetrDecoderLayer`]. The decoder updates the query embeddings through multiple self-attention and cross-attention layers. Some tweaks for Deformable DETR: - `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass. - it also returns a stack of intermediate outputs and reference points from all decoding layers. Args: config: TestDetrConfig """ _can_record_outputs = { "hidden_states": TestDetrDecoderLayer, "attentions": OutputRecorder(TestDetrSelfAttention, layer_name="self_attn", index=1), "cross_attentions": OutputRecorder(TestDetrMultiscaleDeformableAttention, layer_name="encoder_attn", index=1), } def __init__(self, config: TestDetrConfig): super().__init__(config) self.dropout = config.dropout self.layers = nn.ModuleList([TestDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) # hack implementation for iterative bounding box refinement and two-stage Deformable DETR self.bbox_embed = None self.class_embed = None # Initialize weights and apply final processing self.post_init() @merge_with_config_defaults @capture_outputs def forward( self, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, object_queries_position_embeddings=None, reference_points=None, spatial_shapes=None, spatial_shapes_list=None, level_start_index=None, valid_ratios=None, **kwargs: Unpack[TransformersKwargs], ): r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): The query embeddings that are passed into the decoder. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected in `[0, 1]`: - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer. reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): Spatial shapes of the feature maps. level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): Indexes for the start of each feature level. In range `[0, sequence_length]`. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): Ratio of valid area in each feature level. """ if inputs_embeds is not None: hidden_states = inputs_embeds # decoder layers intermediate = () intermediate_reference_points = () for idx, decoder_layer in enumerate(self.layers): num_coordinates = reference_points.shape[-1] if num_coordinates == 4: reference_points_input = ( reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] ) elif reference_points.shape[-1] == 2: reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] else: raise ValueError("Reference points' last dimension must be of size 2") hidden_states = decoder_layer( hidden_states, object_queries_position_embeddings, reference_points_input, spatial_shapes, spatial_shapes_list, level_start_index, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask, **kwargs, ) # hack implementation for iterative bounding box refinement if self.bbox_embed is not None: tmp = self.bbox_embed[idx](hidden_states) num_coordinates = reference_points.shape[-1] if num_coordinates == 4: new_reference_points = tmp + inverse_sigmoid(reference_points) new_reference_points = new_reference_points.sigmoid() elif num_coordinates == 2: new_reference_points = tmp new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) new_reference_points = new_reference_points.sigmoid() else: raise ValueError( f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}" ) reference_points = new_reference_points.detach() intermediate += (hidden_states,) intermediate_reference_points += (reference_points,) # Keep batch_size as first dimension intermediate = torch.stack(intermediate, dim=1) intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) return TestDetrDecoderOutput( last_hidden_state=hidden_states, intermediate_hidden_states=intermediate, intermediate_reference_points=intermediate_reference_points, ) @auto_docstring( custom_intro=""" The bare Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without any specific head on top. """ ) class TestDetrModel(TestDetrPreTrainedModel): def __init__(self, config: TestDetrConfig): super().__init__(config) # Create backbone self.backbone = TestDetrConvEncoder(config) # Create positional encoding if config.position_embedding_type == "sine": self.position_embedding = TestDetrSinePositionEmbedding(config.d_model // 2, normalize=True) elif config.position_embedding_type == "learned": self.position_embedding = TestDetrLearnedPositionEmbedding(config.d_model // 2) else: raise ValueError(f"Not supported {config.position_embedding_type}") # Create input projection layers if config.num_feature_levels > 1: num_backbone_outs = len(self.backbone.intermediate_channel_sizes) input_proj_list = [] for _ in range(num_backbone_outs): in_channels = self.backbone.intermediate_channel_sizes[_] input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, config.d_model, kernel_size=1), nn.GroupNorm(32, config.d_model), ) ) for _ in range(config.num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( nn.Conv2d( in_channels, config.d_model, kernel_size=3, stride=2, padding=1, ), nn.GroupNorm(32, config.d_model), ) ) in_channels = config.d_model self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList( [ nn.Sequential( nn.Conv2d( self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1, ), nn.GroupNorm(32, config.d_model), ) ] ) if not config.two_stage: self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model * 2) self.encoder = TestDetrEncoder(config) self.decoder = TestDetrDecoder(config) self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model)) if config.two_stage: self.enc_output = nn.Linear(config.d_model, config.d_model) self.enc_output_norm = nn.LayerNorm(config.d_model) self.pos_trans = nn.Linear(config.d_model * 2, config.d_model * 2) self.pos_trans_norm = nn.LayerNorm(config.d_model * 2) else: self.reference_points = nn.Linear(config.d_model, 2) self.post_init() def freeze_backbone(self): for name, param in self.backbone.model.named_parameters(): param.requires_grad_(False) def unfreeze_backbone(self): for name, param in self.backbone.model.named_parameters(): param.requires_grad_(True) def get_valid_ratio(self, mask, dtype=torch.float32): """Get the valid ratio of all feature maps.""" _, height, width = mask.shape valid_height = torch.sum(mask[:, :, 0], 1) valid_width = torch.sum(mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio def get_proposal_pos_embed(self, proposals): """Get the position embedding of the proposals.""" num_pos_feats = self.config.d_model // 2 temperature = 10000 scale = 2 * math.pi # Compute position embeddings in float32 to avoid overflow with large temperature values in fp16 proposals_dtype = proposals.dtype dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) # batch_size, num_queries, 4 proposals = proposals.sigmoid().to(torch.float32) * scale # batch_size, num_queries, 4, 128 pos = proposals[:, :, :, None] / dim_t # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512 pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) # Convert back to target dtype after all computations are done return pos.to(proposals_dtype) def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): """Generate the encoder output proposals from encoded enc_output. Args: enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. Returns: `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to directly predict a bounding box. (without the need of a decoder) - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse sigmoid. """ batch_size = enc_output.shape[0] proposals = [] _cur = 0 for level, (height, width) in enumerate(spatial_shapes): mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid( torch.linspace( 0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device, ), torch.linspace( 0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device, ), indexing="ij", ) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse sigmoid output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf")) output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) # assign each pixel as an object query object_query = enc_output object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0)) object_query = object_query.masked_fill(~output_proposals_valid, float(0)) object_query = self.enc_output_norm(self.enc_output(object_query)) return object_query, output_proposals @auto_docstring @can_return_tuple def forward( self, pixel_values: torch.FloatTensor, pixel_mask: torch.LongTensor | None = None, decoder_attention_mask: torch.FloatTensor | None = None, encoder_outputs: torch.FloatTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, decoder_inputs_embeds: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor] | TestDetrModelOutput: r""" decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): Not used by default. Can be used to mask object queries. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you can choose to directly pass a flattened representation of an image. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an embedded representation. Examples: ```python >>> from transformers import AutoImageProcessor, TestDetrModel >>> from PIL import Image >>> import requests >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> image_processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr") >>> model = TestDetrModel.from_pretrained("SenseTime/deformable-detr") >>> inputs = image_processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) [1, 300, 256] ```""" batch_size, num_channels, height, width = pixel_values.shape device = pixel_values.device if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device) # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper) # First, sent pixel_values + pixel_mask through Backbone to obtain the features # which is a list of tuples features = self.backbone(pixel_values, pixel_mask) # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) sources = [] masks = [] position_embeddings_list = [] for level, (source, mask) in enumerate(features): sources.append(self.input_proj[level](source)) masks.append(mask) if mask is None: raise ValueError("No attention mask was provided") # Generate position embeddings for this feature level pos = self.position_embedding(shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask).to( source.dtype ) position_embeddings_list.append(pos) # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage if self.config.num_feature_levels > len(sources): _len_sources = len(sources) for level in range(_len_sources, self.config.num_feature_levels): if level == _len_sources: source = self.input_proj[level](features[-1][0]) else: source = self.input_proj[level](sources[-1]) mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to( torch.bool )[0] pos_l = self.position_embedding( shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask ).to(source.dtype) sources.append(source) masks.append(mask) position_embeddings_list.append(pos_l) # Create queries query_embeds = None if not self.config.two_stage: query_embeds = self.query_position_embeddings.weight # Prepare encoder inputs (by flattening) source_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes_list = [] for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)): batch_size, num_channels, height, width = source.shape spatial_shape = (height, width) spatial_shapes_list.append(spatial_shape) source = source.flatten(2).transpose(1, 2) mask = mask.flatten(1) lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) source_flatten.append(source) mask_flatten.append(mask) source_flatten = torch.cat(source_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder # Also provide spatial_shapes, level_start_index and valid_ratios if encoder_outputs is None: encoder_outputs = self.encoder( inputs_embeds=source_flatten, attention_mask=mask_flatten, spatial_position_embeddings=lvl_pos_embed_flatten, spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs, ) # Fifth, prepare decoder inputs batch_size, _, num_channels = encoder_outputs[0].shape enc_outputs_class = None enc_outputs_coord_logits = None if self.config.two_stage: object_query_embedding, output_proposals = self.gen_encoder_output_proposals( encoder_outputs[0], ~mask_flatten, spatial_shapes_list ) # hack implementation for two-stage Deformable DETR # apply a detection head to each pixel (A.4 in paper) # linear projection for bounding box binary classification (i.e. foreground and background) enc_outputs_class = self.decoder.class_embed[-1](object_query_embedding) # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch) delta_bbox = self.decoder.bbox_embed[-1](object_query_embedding) enc_outputs_coord_logits = delta_bbox + output_proposals # only keep top scoring `config.two_stage_num_proposals` proposals topk = self.config.two_stage_num_proposals topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_logits = torch.gather( enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4), ) topk_coords_logits = topk_coords_logits.detach() reference_points = topk_coords_logits.sigmoid() init_reference_points = reference_points pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits))) query_embed, target = torch.split(pos_trans_out, num_channels, dim=2) else: query_embed, target = torch.split(query_embeds, num_channels, dim=1) query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1) target = target.unsqueeze(0).expand(batch_size, -1, -1) reference_points = self.reference_points(query_embed).sigmoid() init_reference_points = reference_points decoder_outputs = self.decoder( inputs_embeds=target, object_queries_position_embeddings=query_embed, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs, ) return TestDetrModelOutput( init_reference_points=init_reference_points, last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_reference_points=decoder_outputs.intermediate_reference_points, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, enc_outputs_class=enc_outputs_class, enc_outputs_coord_logits=enc_outputs_coord_logits, )