
    sg              	          d Z ddlZddlZddlZddlmZ ddlmZm	Z	m
Z
 ddlZddlZddlmZmZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& ddl'm(Z(  e"jR                  e*      Z+dZ,dZ-g dZ.dZ/dZ0e G d de             Z1e G d de             Z2e G d de             Z3e G d de             Z4d Z5d Z6dKdej                  de7d e8d!ej                  fd"Z9 G d# d$ejt                        Z; G d% d&ejt                        Z< G d' d(ejt                        Z= G d) d*ejt                        Z> G d+ d,ejt                        Z? G d- d.ejt                        Z@ G d/ d0ejt                        ZA G d1 d2ejt                        ZB G d3 d4ejt                        ZC G d5 d6ejt                        ZD G d7 d8ejt                        ZE G d9 d:ejt                        ZF G d; d<e      ZGd=ZHd>ZI e d?eH       G d@ dAeG             ZJ e dBeH       G dC dDeG             ZK e dEeH       G dF dGeG             ZL e dHeH       G dI dJeGe&             ZMy)Lz!PyTorch Swinv2 Transformer model.    N)	dataclass)OptionalTupleUnion)Tensornn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutput)PreTrainedModel) find_pruneable_heads_and_indicesmeshgridprune_linear_layer)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings	torch_int)BackboneMixin   )Swinv2Configr   z(microsoft/swinv2-tiny-patch4-window8-256)r   @   i   zEgyptian catc                       e Zd ZU dZdZej                  ed<   dZe	e
ej                  df      ed<   dZe	e
ej                  df      ed<   dZe	e
ej                  df      ed<   y)Swinv2EncoderOutputa  
    Swinv2 encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r    torchFloatTensor__annotations__r!   r   r   r"   r#        ]/var/www/html/venv/lib/python3.12/site-packages/transformers/models/swinv2/modeling_swinv2.pyr   r   >   sx    2 ,0u((/=AM8E%"3"3S"89:A:>Ju00#567>FJHU5+<+<c+A%BCJr,   r   c                       e Zd ZU dZdZej                  ed<   dZe	ej                     ed<   dZ
e	eej                  df      ed<   dZe	eej                  df      ed<   dZe	eej                  df      ed<   y)	Swinv2ModelOutputaV  
    Swinv2 model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr    pooler_output.r!   r"   r#   )r$   r%   r&   r'   r    r(   r)   r*   r0   r   r!   r   r"   r#   r+   r,   r-   r/   r/   `   s    6 ,0u((/15M8E--.5=AM8E%"3"3S"89:A:>Ju00#567>FJHU5+<+<c+A%BCJr,   r/   c                      e Zd ZU dZdZeej                     ed<   dZ	ej                  ed<   dZ
eeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   ed	        Zy)
Swinv2MaskedImageModelingOutputa  
    Swinv2 masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlossreconstruction.r!   r"   r#   c                 N    t        j                  dt               | j                  S )Nzlogits attribute is deprecated and will be removed in version 5 of Transformers. Please use the reconstruction attribute to retrieve the final output instead.)warningswarnFutureWarningr4   selfs    r-   logitsz&Swinv2MaskedImageModelingOutput.logits   s%    ]	

 """r,   )r$   r%   r&   r'   r3   r   r(   r)   r*   r4   r!   r   r"   r#   propertyr;   r+   r,   r-   r2   r2      s    6 )-D(5$$
%,(,NE%%,=AM8E%"3"3S"89:A:>Ju00#567>FJHU5+<+<c+A%BCJ# #r,   r2   c                       e Zd ZU dZdZeej                     ed<   dZ	ej                  ed<   dZ
eeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   y)	Swinv2ImageClassifierOutputa  
    Swinv2 outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr3   r;   .r!   r"   r#   )r$   r%   r&   r'   r3   r   r(   r)   r*   r;   r!   r   r"   r#   r+   r,   r-   r>   r>      s    6 )-D(5$$
%, $FE$=AM8E%"3"3S"89:A:>Ju00#567>FJHU5+<+<c+A%BCJr,   r>   c                     | j                   \  }}}}| j                  |||z  |||z  ||      } | j                  dddddd      j                         j                  d|||      }|S )z2
    Partitions the given input into windows.
    r   r   r            shapeviewpermute
contiguous)input_featurewindow_size
batch_sizeheightwidthnum_channelswindowss          r-   window_partitionrP      s}     /<.A.A+J|!&&Fk);8Lk[gM ##Aq!Q15@@BGGKYdfrsGNr,   c                     | j                   d   }| j                  d||z  ||z  |||      } | j                  dddddd      j                         j                  d|||      } | S )z?
    Merges windows to produce higher resolution features.
    rC   r   r   r   r@   rA   rB   rD   )rO   rJ   rL   rM   rN   s        r-   window_reverserR      sn     ==$Lll2v4e{6JKYdfrsGooaAq!Q/::<AA"feUabGNr,   input	drop_probtrainingreturnc                    |dk(  s|s| S d|z
  }| j                   d   fd| j                  dz
  z  z   }|t        j                  || j                  | j
                        z   }|j                          | j                  |      |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
            r   r   )r   )dtypedevice)rE   ndimr(   randrY   rZ   floor_div)rS   rT   rU   	keep_probrE   random_tensoroutputs          r-   	drop_pathrb      s     CxII[[^

Q 77E

5ELL YYMYYy!M1FMr,   c                   x     e Zd ZdZd	dee   ddf fdZdej                  dej                  fdZ	de
fdZ xZS )
Swinv2DropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).NrT   rV   c                 0    t         |           || _        y N)super__init__rT   )r:   rT   	__class__s     r-   rh   zSwinv2DropPath.__init__	  s    "r,   r!   c                 D    t        || j                  | j                        S rf   )rb   rT   rU   r:   r!   s     r-   forwardzSwinv2DropPath.forward  s    FFr,   c                 8    dj                  | j                        S )Nzp={})formatrT   r9   s    r-   
extra_reprzSwinv2DropPath.extra_repr  s    }}T^^,,r,   rf   )r$   r%   r&   r'   r   floatrh   r(   r   rl   strro   __classcell__ri   s   @r-   rd   rd     sG    b#(5/ #T #GU\\ Gell G-C -r,   rd   c            
            e Zd ZdZd fd	Zdej                  dededej                  fdZ	 	 dde	ej                     d	e	ej                     d
edeej                     fdZ xZS )Swinv2EmbeddingszW
    Construct the patch and position embeddings. Optionally, also the mask token.
    c                 ~   t         |           t        |      | _        | j                  j                  }| j                  j
                  | _        |r4t        j                  t        j                  dd|j                              nd | _        |j                  r=t        j                  t        j                  d|dz   |j                              | _        nd | _        t        j                  |j                        | _        t        j"                  |j$                        | _        |j(                  | _        || _        y )Nr   )rg   rh   Swinv2PatchEmbeddingspatch_embeddingsnum_patches	grid_size
patch_gridr   	Parameterr(   zeros	embed_dim
mask_tokenuse_absolute_embeddingsposition_embeddings	LayerNormnormDropouthidden_dropout_probdropout
patch_sizeconfig)r:   r   use_mask_tokenry   ri   s       r-   rh   zSwinv2Embeddings.__init__  s     5f =++77//99O]",,u{{1a9I9I'JKcg))')||EKK;QR?TZTdTd4e'fD$'+D$LL!1!12	zz&"<"<= ++r,   
embeddingsrL   rM   rV   c                    |j                   d   dz
  }| j                  j                   d   dz
  }t        j                  j	                         s||k(  r||k(  r| j                  S | j                  ddddf   }| j                  ddddf   }|j                   d   }|| j
                  z  }	|| j
                  z  }
t        |dz        }|j                  d|||      }|j                  dddd      }t        j                  j                  ||	|
fdd	
      }|j                  dddd      j                  dd|      }t        j                  ||fd      S )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   NrC         ?r   r   r@   bicubicF)sizemodealign_cornersdim)rE   r   r(   jit
is_tracingr   r   reshaperG   r   
functionalinterpolaterF   cat)r:   r   rL   rM   ry   num_positionsclass_pos_embedpatch_pos_embedr   
new_height	new_widthsqrt_num_positionss               r-   interpolate_pos_encodingz)Swinv2Embeddings.interpolate_pos_encoding-  s`    !&&q)A-0066q9A= yy##%+*F6UZ?+++221bqb59221ab59r"t.
T__,	&}c'9:)11!5GI[]`a)11!Q1=--33i(	 4 
 *11!Q1=BB1b#Nyy/?;CCr,   pixel_valuesbool_masked_posr   c                    |j                   \  }}}}| j                  |      \  }}	| j                  |      }|j                         \  }
}}|K| j                  j                  |
|d      }|j                  d      j                  |      }|d|z
  z  ||z  z   }| j                  (|r|| j                  |||      z   }n|| j                  z   }| j                  |      }||	fS )NrC         ?)rE   rx   r   r   r   expand	unsqueezetype_asr   r   r   )r:   r   r   r   _rN   rL   rM   r   output_dimensionsrK   seq_lenmask_tokensmasks                 r-   rl   zSwinv2Embeddings.forwardU  s     *6););&<(,(=(=l(K%
%YYz*
!+!2
GQ&//00WbIK",,R088ED#sTz2[45GGJ##/''$*G*G
TZ\a*bb
'$*B*BB
\\*-
,,,r,   )FNF)r$   r%   r&   r'   rh   r(   r   intr   r   r)   
BoolTensorboolr   rl   rr   rs   s   @r-   ru   ru     s    &&D5<< &D &DUX &D]b]i]i &DV 7;).	-u001- "%"2"23- #'	-
 
u||	-r,   ru   c                   v     e Zd ZdZ fdZd Zdeej                     de	ej                  e	e   f   fdZ xZS )rw   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    t         |           |j                  |j                  }}|j                  |j
                  }}t        |t        j                  j                        r|n||f}t        |t        j                  j                        r|n||f}|d   |d   z  |d   |d   z  z  }|| _        || _        || _        || _
        |d   |d   z  |d   |d   z  f| _        t        j                  ||||      | _        y )Nr   r   )kernel_sizestride)rg   rh   
image_sizer   rN   r~   
isinstancecollectionsabcIterablery   rz   r   Conv2d
projection)r:   r   r   r   rN   hidden_sizery   ri   s          r-   rh   zSwinv2PatchEmbeddings.__init__y  s    !'!2!2F4E4EJ
$*$7$79I9Ik#-j+//:R:R#SZZdfpYq
#-j+//:R:R#SZZdfpYq
!!}
15*Q-:VW=:XY$$(&$Q-:a=8*Q-:VW=:XY))L+:^hir,   c                 n   || j                   d   z  dk7  rDd| j                   d   || j                   d   z  z
  f}t        j                  j                  ||      }|| j                   d   z  dk7  rFddd| j                   d   || j                   d   z  z
  f}t        j                  j                  ||      }|S )Nr   r   )r   r   r   pad)r:   r   rL   rM   
pad_valuess        r-   	maybe_padzSwinv2PatchEmbeddings.maybe_pad  s    4??1%%*T__Q/%$//!:L2LLMJ==,,\:FLDOOA&&!+Q4??1#5QRAS8S#STJ==,,\:FLr,   r   rV   c                     |j                   \  }}}}| j                  |||      }| j                  |      }|j                   \  }}}}||f}|j                  d      j	                  dd      }||fS )Nr@   r   )rE   r   r   flatten	transpose)r:   r   r   rN   rL   rM   r   r   s           r-   rl   zSwinv2PatchEmbeddings.forward  s}    )5););&<~~lFEB__\2
(..1fe#UO''*44Q:
,,,r,   )r$   r%   r&   r'   rh   r   r   r(   r)   r   r   r   rl   rr   rs   s   @r-   rw   rw   r  sF    j	-HU->->$? 	-E%,,X]^aXbJbDc 	-r,   rw   c            	            e Zd ZdZej
                  fdee   dedej                  ddf fdZ	d Z
d	ej                  d
eeef   dej                  fdZ xZS )Swinv2PatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    input_resolutionr   
norm_layerrV   Nc                     t         |           || _        || _        t	        j
                  d|z  d|z  d      | _         |d|z        | _        y )NrA   r@   Fbias)rg   rh   r   r   r   Linear	reductionr   )r:   r   r   r   ri   s       r-   rh   zSwinv2PatchMerging.__init__  sI     01s7AG%@q3w'	r,   c                     |dz  dk(  xs |dz  dk(  }|r.ddd|dz  d|dz  f}t         j                  j                  ||      }|S )Nr@   r   r   )r   r   r   )r:   rI   rL   rM   
should_padr   s         r-   r   zSwinv2PatchMerging.maybe_pad  sU    qjAo:519>
Q519a!<JMM--mZHMr,   rI   input_dimensionsc                    |\  }}|j                   \  }}}|j                  ||||      }| j                  |||      }|d d dd ddd dd d f   }|d d dd ddd dd d f   }	|d d dd ddd dd d f   }
|d d dd ddd dd d f   }t        j                  ||	|
|gd      }|j                  |dd|z        }| j                  |      }| j                  |      }|S )Nr   r@   r   rC   rA   )rE   rF   r   r(   r   r   r   )r:   rI   r   rL   rM   rK   r   rN   input_feature_0input_feature_1input_feature_2input_feature_3s               r-   rl   zSwinv2PatchMerging.forward  s   ((5(;(;%
C%**:vulS}feD'14a4Aq(89'14a4Aq(89'14a4Aq(89'14a4Aq(89		?O_Ve"fhjk%**:r1|;KL}5		-0r,   )r$   r%   r&   r'   r   r   r   r   Modulerh   r   r(   r   rl   rr   rs   s   @r-   r   r     sr    
 XZWcWc (s (# (299 (hl (U\\ U3PS8_ Y^YeYe r,   r   c                        e Zd Zddgf fd	Zd Z	 	 	 d
dej                  deej                     deej                     dee	   de
ej                     f
d	Z xZS )Swinv2SelfAttentionr   c           
         t         |           ||z  dk7  rt        d| d| d      || _        t	        ||z        | _        | j                  | j
                  z  | _        t        |t        j                  j                        r|n||f| _        || _        t        j                  t        j                   dt        j"                  |ddf      z              | _        t        j&                  t        j(                  ddd	
      t        j*                  d	      t        j(                  d|d
            | _        t        j.                  | j                  d   dz
   | j                  d   t        j0                        j3                         }t        j.                  | j                  d   dz
   | j                  d   t        j0                        j3                         }t        j4                  t7        ||gd            j9                  ddd      j;                         j=                  d      }|d   dkD  r;|d d d d d d dfxx   |d   dz
  z  cc<   |d d d d d d dfxx   |d   dz
  z  cc<   nS|dkD  rN|d d d d d d dfxx   | j                  d   dz
  z  cc<   |d d d d d d dfxx   | j                  d   dz
  z  cc<   |dz  }t        j>                  |      t        j@                  t        jB                  |      dz         z  tE        j@                  d      z  }|jG                  tI        | j,                  jK                               jL                        }| jO                  d|d       t        j.                  | j                  d         }	t        j.                  | j                  d         }
t        j4                  t7        |	|
gd            }t        jP                  |d      }|d d d d d f   |d d d d d f   z
  }|j9                  ddd      j;                         }|d d d d dfxx   | j                  d   dz
  z  cc<   |d d d d dfxx   | j                  d   dz
  z  cc<   |d d d d dfxx   d| j                  d   z  dz
  z  cc<   |jS                  d      }| jO                  d|d       t        j(                  | j                  | j                  |jT                  
      | _+        t        j(                  | j                  | j                  d
      | _,        t        j(                  | j                  | j                  |jT                  
      | _-        t        j\                  |j^                        | _0        y )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r@   i   Tr   )inplaceFrY   ij)indexing   r   relative_coords_table)
persistentrC   relative_position_index)1rg   rh   
ValueErrornum_attention_headsr   attention_head_sizeall_head_sizer   r   r   r   rJ   pretrained_window_sizer   r|   r(   logoneslogit_scale
Sequentialr   ReLUcontinuous_position_bias_mlparangeint64rp   stackr   rG   rH   r   signlog2absmathtonext
parametersrY   register_bufferr   sumqkv_biasquerykeyvaluer   attention_probs_dropout_probr   )r:   r   r   	num_headsrJ   r   relative_coords_hrelative_coords_wr   coords_hcoords_wcoordscoords_flattenrelative_coordsr   ri   s                  r-   rh   zSwinv2SelfAttention.__init__  s   ?a#C5(^_h^iijk  $- #&sY#7 !558P8PP%k;??3K3KLKS^`kRl 	 '=#<<		"uzz9aQRBS7T2T(UV,.MMIIa4("''$*?3PY`eAf-
)
 "LL4+;+;A+>+B)CTEUEUVWEX`e`k`klrrt!LL4+;+;A+>+B)CTEUEUVWEX`e`k`klrrtKK"35F!GRVWXWQ1Z\Yq\	 	 "!$q(!!Q1*-1G1JQ1NN-!!Q1*-1G1JQ1NN-1_!!Q1*-1A1A!1Dq1HH-!!Q1*-1A1A!1Dq1HH-"JJ,-

599EZ;[^a;a0bbeienenopeqq 	 !6 8 8d>_>_>j>j>l9m9s9s t46KX]^ << 0 0 34<< 0 0 34Xx&:TJKvq1(At4~aqj7QQ)11!Q:EEG1a D$4$4Q$7!$;; 1a D$4$4Q$7!$;; 1a A(8(8(;$;a$?? "1"5"5b"968O\abYYt1143E3EFOO\
99T//1C1C%PYYt1143E3EFOO\
zz&"E"EFr,   c                     |j                         d d | j                  | j                  fz   }|j                  |      }|j	                  dddd      S )NrC   r   r@   r   r   )r   r   r   rF   rG   )r:   xnew_x_shapes      r-   transpose_for_scoresz(Swinv2SelfAttention.transpose_for_scores  sL    ffhsmt'?'?AYAY&ZZFF;yyAq!$$r,   r!   attention_mask	head_maskoutput_attentionsrV   c                 z   |j                   \  }}}| j                  |      }| j                  | j                  |            }	| j                  | j	                  |            }
| j                  |      }t
        j                  j                  |d      t
        j                  j                  |	d      j                  dd      z  }t        j                  | j                  t        j                  d            j                         }||z  }| j                  | j                         j#                  d| j$                        }|| j&                  j#                  d         j#                  | j(                  d   | j(                  d   z  | j(                  d   | j(                  d   z  d      }|j+                  ddd      j-                         }d	t        j.                  |      z  }||j1                  d      z   }||j                   d   }|j#                  ||z  || j$                  ||      |j1                  d      j1                  d      z   }||j1                  d      j1                  d      z   }|j#                  d| j$                  ||      }t
        j                  j3                  |d      }| j5                  |      }|||z  }t        j6                  ||
      }|j+                  dddd
      j-                         }|j9                         d d | j:                  fz   }|j#                  |      }|r||f}|S |f}|S )NrC   r   g      Y@)maxr   r   r@      r   )rE   r   r	  r   r   r   r   	normalizer   r(   clampr   r   r   expr   r   rF   r   r   rJ   rG   rH   sigmoidr   softmaxr   matmulr   r   )r:   r!   r
  r  r  rK   r   rN   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresr   relative_position_bias_tablerelative_position_bias
mask_shapeattention_probscontext_layernew_context_layer_shapeoutputss                        r-   rl   zSwinv2SelfAttention.forward  s.    )6(;(;%
C JJ}5--dhh}.EF	//

=0IJ//0AB ==22;B2G"--JaJa2 Kb K

)B
 kk$"2"28LMQQS+k9'+'H'HIcIc'd'i'i(((
$ ">d>Z>Z>_>_`b>c!d!i!iQ$"2"21"55t7G7G7JTM]M]^_M`7`bd"
 "8!?!?1a!H!S!S!U!#emm4J&K!K+.D.N.Nq.QQ%'--a0J/44j(*d6N6NPSUX ((+55a8 9  0.2J2J12M2W2WXY2ZZ/44R9Q9QSVX[\ --//0@b/I ,,7  -	9O_kB%--aAq9DDF"/"4"4"6s";t?Q?Q>S"S%**+BC6G=/2 O\M]r,   NNF)r$   r%   r&   rh   r	  r(   r   r   r)   r   r   rl   rr   rs   s   @r-   r   r     s    TUWXSY ;Gz% 7;15,1;||; !!2!23; E--.	;
 $D>; 
u||	;r,   r   c                   n     e Zd Z fdZdej
                  dej
                  dej
                  fdZ xZS )Swinv2SelfOutputc                     t         |           t        j                  ||      | _        t        j
                  |j                        | _        y rf   )rg   rh   r   r   denser   r   r   r:   r   r   ri   s      r-   rh   zSwinv2SelfOutput.__init__V  s6    YYsC(
zz&"E"EFr,   r!   input_tensorrV   c                 J    | j                  |      }| j                  |      }|S rf   r'  r   )r:   r!   r)  s      r-   rl   zSwinv2SelfOutput.forward[  s$    

=1]3r,   r$   r%   r&   rh   r(   r   rl   rr   rs   s   @r-   r%  r%  U  s2    G
U\\  RWR^R^ r,   r%  c                        e Zd Zd	 fd	Zd Z	 	 	 d
dej                  deej                     deej                     dee	   de
ej                     f
dZ xZS )Swinv2Attentionc           
          t         |           t        ||||t        |t        j
                  j                        r|n||f      | _        t        ||      | _	        t               | _        y )Nr   r   r   rJ   r   )rg   rh   r   r   r   r   r   r:   r%  ra   setpruned_heads)r:   r   r   r   rJ   r   ri   s         r-   rh   zSwinv2Attention.__init__c  sc    '#0+//2J2JK $:(*@A
	 'vs3Er,   c                 >   t        |      dk(  ry t        || j                  j                  | j                  j                  | j
                        \  }}t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _	        t        | j                  j                  |d      | j                  _        | j                  j                  t        |      z
  | j                  _        | j                  j                  | j                  j                  z  | j                  _        | j
                  j                  |      | _        y )Nr   r   r   )lenr   r:   r   r   r2  r   r   r   r   ra   r'  r   union)r:   headsindexs      r-   prune_headszSwinv2Attention.prune_headsq  s   u:?749900$))2O2OQUQbQb
u
 -TYY__eD		*499==%@		,TYY__eD		.t{{/@/@%QO )-		(E(EE
(R		%"&))"?"?$))B_B_"_		 --33E:r,   r!   r
  r  r  rV   c                 j    | j                  ||||      }| j                  |d   |      }|f|dd  z   }|S )Nr   r   )r:   ra   )r:   r!   r
  r  r  self_outputsattention_outputr"  s           r-   rl   zSwinv2Attention.forward  sG     yy	K\];;|AF#%QR(88r,   r   r#  )r$   r%   r&   rh   r8  r(   r   r   r)   r   r   rl   rr   rs   s   @r-   r.  r.  b  st    ";* 7;15,1
||
 !!2!23
 E--.	

 $D>
 
u||	
r,   r.  c                   V     e Zd Z fdZdej
                  dej
                  fdZ xZS )Swinv2Intermediatec                    t         |           t        j                  |t	        |j
                  |z              | _        t        |j                  t              rt        |j                     | _        y |j                  | _        y rf   )rg   rh   r   r   r   	mlp_ratior'  r   
hidden_actrq   r   intermediate_act_fnr(  s      r-   rh   zSwinv2Intermediate.__init__  sa    YYsC(8(83(>$?@
f''-'-f.?.?'@D$'-'8'8D$r,   r!   rV   c                 J    | j                  |      }| j                  |      }|S rf   )r'  rB  rk   s     r-   rl   zSwinv2Intermediate.forward  s&    

=100?r,   r,  rs   s   @r-   r>  r>    s#    9U\\ ell r,   r>  c                   V     e Zd Z fdZdej
                  dej
                  fdZ xZS )Swinv2Outputc                     t         |           t        j                  t	        |j
                  |z        |      | _        t        j                  |j                        | _	        y rf   )
rg   rh   r   r   r   r@  r'  r   r   r   r(  s      r-   rh   zSwinv2Output.__init__  sF    YYs6#3#3c#9:C@
zz&"<"<=r,   r!   rV   c                 J    | j                  |      }| j                  |      }|S rf   r+  rk   s     r-   rl   zSwinv2Output.forward  s$    

=1]3r,   r,  rs   s   @r-   rE  rE    s#    >
U\\ ell r,   rE  c                        e Zd Z	 d fd	Zdeeeef   eeef   f   fdZd Zd Z	 	 dde	j                  deeef   dee	j                     d	ee   dee	j                  e	j                  f   f
d
Z xZS )Swinv2Layerc           
      n   t         	|           || _        | j                  |j                  |j                  f||f      \  }}|d   | _        |d   | _        t        |||| j                  t        |t        j                  j                        r|n||f      | _        t        j                  ||j                        | _        |dkD  rt!        |      nt        j"                         | _        t'        ||      | _        t+        ||      | _        t        j                  ||j                        | _        y )Nr   r0  epsrX   )rg   rh   r   _compute_window_shiftrJ   
shift_sizer.  r   r   r   r   	attentionr   r   layer_norm_epslayernorm_beforerd   Identityrb   r>  intermediaterE  ra   layernorm_after)
r:   r   r   r   r   drop_path_raterN  r   rJ   ri   s
            r-   rh   zSwinv2Layer.__init__  s    	 0"&"<"<!3!34z:6N#
Z 'q>$Q-(((0+//2J2JK $:(*@A
 !#Sf6K6K L;IC;O7UWU`U`Ub.vs;"63/!||CV5J5JKr,   rV   c                     t        | j                  |      D cg c]  \  }}||k  r|n| }}}t        | j                  ||      D cg c]  \  }}}||k  rdn| }}}}||fS c c}}w c c}}}w Nr   )zipr   )r:   target_window_sizetarget_shift_sizerwrJ   srN  s           r-   rM  z!Swinv2Layer._compute_window_shift  s~    69$:O:OQc6dedaAFq)ee8;D<Q<QS^`q8rssWQ116aq(s
sJ&& fss   A'A-c           	         | j                   dkD  ryt        j                  d||df|      }t        d| j                         t        | j                   | j                          t        | j                    d       f}t        d| j                         t        | j                   | j                          t        | j                    d       f}d}|D ]  }|D ]  }	||d d ||	d d f<   |dz  }  t        || j                        }
|
j                  d| j                  | j                  z        }
|
j                  d      |
j                  d      z
  }|j                  |dk7  t        d            j                  |dk(  t        d            }|S d }|S )Nr   r   r   rC   r@   g      YrX   )
rN  r(   r}   slicerJ   rP   rF   r   masked_fillrp   )r:   rL   rM   rY   img_maskheight_sliceswidth_slicescountheight_slicewidth_slicemask_windows	attn_masks               r-   get_attn_maskzSwinv2Layer.get_attn_mask  s   ??Q{{Avua#8FHa$***+t'''$//)9:t&-M a$***+t'''$//)9:t&-L
 E - #/ K@EHQk1<=QJE
 ,Hd6F6FGL',,R1A1ADDTDT1TUL$..q1L4J4J14MMI!--i1neFmLXXYbfgYginoristI  Ir,   c                     | j                   || j                   z  z
  | j                   z  }| j                   || j                   z  z
  | j                   z  }ddd|d|f}t        j                  j                  ||      }||fS rW  )rJ   r   r   r   )r:   r!   rL   rM   	pad_right
pad_bottomr   s          r-   r   zSwinv2Layer.maybe_pad  s    %%0@0@(@@DDTDTT	&&$2B2B)BBdFVFVV
Ay!Z8
))-Dj((r,   r!   r   r  r  c                    |\  }}|j                         \  }}}	|}
|j                  ||||	      }| j                  |||      \  }}|j                  \  }}}}| j                  dkD  r1t        j                  || j                   | j                   fd      }n|}t        || j                        }|j                  d| j                  | j                  z  |	      }| j                  |||j                        }||j                  |j                        }| j                  ||||      }|d   }|j                  d| j                  | j                  |	      }t        || j                  ||      }| j                  dkD  r/t        j                  || j                  | j                  fd      }n|}|d   dkD  xs |d   dkD  }|r|d d d |d |d d f   j                         }|j                  |||z  |	      }| j!                  |      }|
| j#                  |      z   }| j%                  |      }| j'                  |      }|| j#                  | j)                  |            z   }|r	||d	   f}|S |f}|S )
Nr   )r   r@   )shiftsdimsrC   r   )r  r   rB   r   )r   rF   r   rE   rN  r(   rollrP   rJ   ri  rY   r   rZ   rO  rR   rH   rQ  rb   rS  ra   rT  )r:   r!   r   r  r  rL   rM   rK   r   channelsshortcutr   
height_pad	width_padshifted_hidden_stateshidden_states_windowsrh  attention_outputsr;  attention_windowsshifted_windows
was_paddedlayer_outputlayer_outputss                           r-   rl   zSwinv2Layer.forward  s    )"/"4"4"6
Ax  &**:vuhO$(NN=&%$P!z&3&9&9#:y!??Q$)JJ}tFVY]YhYhXhEipv$w!$1! !11FHXHX Y 5 : :2t?O?ORVRbRb?bdl m&&z9MDWDW&X	 !%:%A%ABI NN!9iK\ + 
 -Q/,11"d6F6FHXHXZbc():D<L<LjZcd ??Q %

?DOOUYUdUdCelr s /]Q&;*Q-!*;
 1!WfWfufa2G H S S U-22:v~xX--.?@ 4>>-#@@((7{{<0$t~~d6J6J<6X'YY@Q'8';< YeWfr,   )rX   r   r   r   )r$   r%   r&   rh   r   r   rM  ri  r   r(   r   r   r)   r   rl   rr   rs   s   @r-   rI  rI    s    qrL2'eTYZ]_bZbTcejknpsksetTtNu '
8) 26,18||8  S/8 E--.	8
 $D>8 
u||U\\)	*8r,   rI  c                        e Zd Z	 d fd	Z	 	 d	dej
                  deeef   deej                     dee
   deej
                     f
dZ xZS )
Swinv2Stagec	           
      |   t         |           || _        || _        g }	t	        |      D ]?  }
t        ||||||
   |
dz  dk(  rdn|j                  dz  |      }|	j                  |       A t        j                  |	      | _
        |& |||t        j                        | _        d| _        y d | _        d| _        y )Nr@   r   )r   r   r   r   rU  rN  r   )r   r   F)rg   rh   r   r   rangerI  rJ   appendr   
ModuleListblocksr   
downsamplepointing)r:   r   r   r   depthr   rb   r  r   r  iblockri   s               r-   rh   zSwinv2Stage.__init__+  s     	u 
	!A!1#(|!"Q!1&2D2D2I'=E MM% 
	! mmF+ !()9sr||\DO  #DOr,   r!   r   r  r  rV   c                    |\  }}t        | j                        D ]  \  }}|||   nd }	 ||||	|      }
|
d   }  |}| j                  )|dz   dz  |dz   dz  }}||||f}| j                  ||      }n||||f}|||f}|r|
dd  z  }|S )Nr   r   r@   )	enumerater  r  )r:   r!   r   r  r  rL   rM   r  layer_modulelayer_head_maskr|  !hidden_states_before_downsamplingheight_downsampledwidth_downsampledr   stage_outputss                   r-   rl   zSwinv2Stage.forwardG  s     )(5 
	-OA|.7.CilO( !	M *!,M
	- -:)??&5;aZA4EPQ	VWGW 1!'0BDU V OO,MO_`M!' >&(IK\]]12..Mr,   r<  r   )r$   r%   r&   rh   r(   r   r   r   r   r)   r   rl   rr   rs   s   @r-   r~  r~  *  sm    mn@ 26,1 ||   S/  E--.	 
 $D>  
u||	 r,   r~  c                        e Zd Zd fd	Z	 	 	 	 	 ddej
                  deeef   deej                     dee
   dee
   dee
   dee
   d	eeef   fd
Z xZS )Swinv2Encoderc                 :   t         	|           t        |j                        | _        || _        | j
                  j                  |j                  }t        j                  d|j                  t        |j                              D cg c]  }|j                          }}g }t        | j                        D ]  }t        |t        |j                  d|z  z        |d   d|z  z  |d   d|z  z  f|j                  |   |j                   |   |t        |j                  d |       t        |j                  d |dz           || j                  dz
  k  rt"        nd ||         }|j%                  |        t'        j(                  |      | _        d| _        y c c}w )Nr   r@   r   )r   r   r   r  r   rb   r  r   F)rg   rh   r4  depths
num_layersr   pretrained_window_sizesr(   linspacerU  r   itemr  r~  r   r~   r   r   r  r   r  layersgradient_checkpointing)
r:   r   rz   r  r  dprr  i_layerstageri   s
            r-   rh   zSwinv2Encoder.__init__k  st   fmm,;;..:&,&D&D#!&63H3H#fmmJ\!]^Aqvvx^^T__- 	!G((1g:56"+A,1g:">	!QRT[Q[@\!]mmG, **73c&--"9:S}QX[\Q\A]=^_29DOOa<O2O-VZ'>w'G	E MM% 	! mmF+&+## _s   	Fr!   r   r  r  output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrV   c                 V   |rdnd }|rdnd }	|rdnd }
|rE|j                   \  }}} |j                  |g|| }|j                  dddd      }||fz  }|	|fz  }	t        | j                        D ]  \  }}|||   nd }| j
                  r+| j                  r| j                  |j                  |||      }n |||||      }|d   }|d   }|d   }|d   |d   f}|rP|rN|j                   \  }}} |j                  |g|d   |d   f| }|j                  dddd      }||fz  }|	|fz  }	nI|rG|sE|j                   \  }}} |j                  |g|| }|j                  dddd      }||fz  }|	|fz  }	|s
|
|dd  z  }
 |st        d |||
|	fD              S t        |||
|		      S )
Nr+   r   r   r   r@   r  rC   c              3   $   K   | ]  }|| 
 y wrf   r+   ).0vs     r-   	<genexpr>z(Swinv2Encoder.forward.<locals>.<genexpr>  s      = s   )r    r!   r"   r#   )rE   rF   rG   r  r  r  rU   _gradient_checkpointing_func__call__tupler   )r:   r!   r   r  r  r  r  r  all_hidden_statesall_reshaped_hidden_statesall_self_attentionsrK   r   r   reshaped_hidden_stater  r  r  r|  r  r   s                        r-   rl   zSwinv2Encoder.forward  s    #7BD+?RT"$5b4)6)<)<&J;$6M$6$6z$bDT$bVa$b!$9$A$A!Q1$M!-!11&+@*BB&(5 (	9OA|.7.CilO**t}} $ A A ))=:JO! !-!$#%	! *!,M0=a0@- -a 0 1" 57H7LM#(P-N-T-T*
A{ )O(I(N(N)"3A"68I!8L!M)OZ)% )>(E(EaAq(Q%!&G%II!*/D.FF*%.V-:-@-@*
A{(:(:(::(fHX(fZe(f%(=(E(EaAq(Q%!m%55!*/D.FF* #}QR'88#Q(	9T  '):<OQkl   #++*#=	
 	
r,   ))r   r   r   r   )NFFFT)r$   r%   r&   rh   r(   r   r   r   r   r)   r   r   r   rl   rr   rs   s   @r-   r  r  j  s    ,: 26,1/4CH&*L
||L
  S/L
 E--.	L

 $D>L
 'tnL
 3;4.L
 d^L
 
u))	*L
r,   r  c                   ,    e Zd ZdZeZdZdZdZdgZ	d Z
y)Swinv2PreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    swinv2r   Tr~  c                    t        |t        j                  t        j                  f      rm|j                  j
                  j                  d| j                  j                         |j                  %|j                  j
                  j                          yyt        |t        j                        rJ|j                  j
                  j                          |j                  j
                  j                  d       yy)zInitialize the weightsrX   )meanstdNr   )r   r   r   r   weightdatanormal_r   initializer_ranger   zero_r   fill_)r:   modules     r-   _init_weightsz#Swinv2PreTrainedModel._init_weights  s    fryy"))45 MM&&CT[[5R5R&S{{&  &&( '-KK""$MM$$S) .r,   N)r$   r%   r&   r'   r   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modulesr  r+   r,   r-   r  r    s,    
  L $O&*#&
*r,   r  aI  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Swinv2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, default `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z`The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.c                       e Zd Zd fd	Zd Zd Z ee       ee	e
ede      	 	 	 	 	 	 	 ddeej                     deej                      deej                     d	ee   d
ee   dedee   deee
f   fd              Z xZS )Swinv2Modelc                    t         |   |       || _        t        |j                        | _        t        |j                  d| j
                  dz
  z  z        | _        t        ||      | _
        t        || j                  j                        | _        t        j                  | j                  |j                         | _        |rt        j$                  d      nd | _        | j)                          y )Nr@   r   )r   rK  )rg   rh   r   r4  r  r  r   r~   num_featuresru   r   r  r{   encoderr   r   rP  	layernormAdaptiveAvgPool1dpooler	post_init)r:   r   add_pooling_layerr   ri   s       r-   rh   zSwinv2Model.__init__  s     fmm, 0 0119L3M MN*6.Q$VT__-G-GHd&7&7V=R=RS1Bb**1- 	r,   c                 .    | j                   j                  S rf   r   rx   r9   s    r-   get_input_embeddingsz Swinv2Model.get_input_embeddings%      ///r,   c                     |j                         D ]7  \  }}| j                  j                  |   j                  j	                  |       9 y)z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  layerrO  r8  )r:   heads_to_pruner  r6  s       r-   _prune_headszSwinv2Model._prune_heads(  sE    
 +002 	CLE5LLu%//;;EB	Cr,   vision)
checkpointoutput_typer  modalityexpected_outputr   r   r  r  r  r   r  rV   c                    ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|t	        d      | j                  |t        | j                   j                              }| j                  |||      \  }}	| j                  ||	||||      }
|
d   }| j                  |      }d}| j                  7| j                  |j                  dd            }t        j                  |d      }|s||f|
dd z   }|S t        |||
j                   |
j"                  |
j$                        S )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r   r   )r  r  r  r  r   r   r@   )r    r0   r!   r"   r#   )r   r  r  use_return_dictr   get_head_maskr4  r  r   r  r  r  r   r(   r   r/   r!   r"   r#   )r:   r   r   r  r  r  r   r  embedding_outputr   encoder_outputssequence_outputpooled_outputra   s                 r-   rl   zSwinv2Model.forward0  sp   , 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]?@@ &&y#dkk6H6H2IJ	-1__/Tl .= .
** ,,/!5# ' 
 *!,..9;;" KK(A(A!Q(GHM!MM-;M%}58KKFM -')77&11#2#I#I
 	
r,   )TFNNNNNFN)r$   r%   r&   rh   r  r  r   SWINV2_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr/   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r(   r)   r   r   r   r   rl   rr   rs   s   @r-   r  r    s    0C ++BC&%$. 596:15,0/3).&*>
u001>
 "%"2"23>
 E--.	>

 $D>>
 'tn>
 #'>
 d^>
 
u''	(>
 D>
r,   r  aY  Swinv2 Model with a decoder on top for masked image modeling, as proposed in
[SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                        e Zd Z fdZ ee       eee      	 	 	 	 	 	 	 dde	e
j                     de	e
j                     de	e
j                     de	e   de	e   ded	e	e   d
eeef   fd              Z xZS )Swinv2ForMaskedImageModelingc                    t         |   |       t        |dd      | _        t	        |j
                  d|j                  dz
  z  z        }t        j                  t        j                  ||j                  dz  |j                  z  d      t        j                  |j                              | _        | j                          y )NFT)r  r   r@   r   )in_channelsout_channelsr   )rg   rh   r  r  r   r~   r  r   r   r   encoder_striderN   PixelShuffledecoderr  )r:   r   r  ri   s      r-   rh   z%Swinv2ForMaskedImageModeling.__init__  s     !&ERVW6++aF4E4E4I.JJK}}II(v7L7La7ORXReRe7est OOF112	
 	r,   r  r  r   r   r  r  r  r   r  rV   c           	         ||n| j                   j                  }| j                  |||||||      }|d   }	|	j                  dd      }	|	j                  \  }
}}t        j                  |dz        x}}|	j                  |
|||      }	| j                  |	      }d}|| j                   j                  | j                   j                  z  }|j                  d||      }|j                  | j                   j                  d      j                  | j                   j                  d      j                  d      j                         }t        j                  j!                  ||d	      }||z  j#                         |j#                         d
z   z  | j                   j$                  z  }|s|f|dd z   }||f|z   S |S t'        |||j(                  |j*                  |j,                        S )aQ  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 256, 256]
        ```N)r   r  r  r  r   r  r   r   r@   r   rC   none)r   gh㈵>)r3   r4   r!   r"   r#   )r   r  r  r   rE   r   floorr   r  r   r   repeat_interleaver   rH   r   r   l1_lossr   rN   r2   r!   r"   r#   )r:   r   r   r  r  r  r   r  r"  r  rK   rN   sequence_lengthrL   rM   reconstructed_pixel_valuesmasked_im_lossr   r   reconstruction_lossra   s                        r-   rl   z$Swinv2ForMaskedImageModeling.forward  s   R &1%<k$++B]B]+++/!5%=#  
 "!*)33Aq94C4I4I1
L/OS$899)11*lFTYZ &*\\/%B"&;;))T[[-C-CCD-55b$EO11$++2H2H!L""4;;#9#91=1	  #%--"7"7F`lr"7"s1D8==?488:PTCTUX\XcXcXpXppN02WQR[@F3A3M^%.YSYY.5!//))#*#A#A
 	
r,   r  )r$   r%   r&   rh   r   r  r   r2   r  r   r(   r)   r   r   r   r   rl   rr   rs   s   @r-   r  r  y  s      ++BC+JYhi 596:15,0/3).&*T
u001T
 "%"2"23T
 E--.	T

 $D>T
 'tnT
 #'T
 d^T
 
u55	6T
 j DT
r,   r  a  
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    c                        e Zd Z fdZ ee       eeee	e
      	 	 	 	 	 	 	 ddeej                     deej                     deej                     dee   dee   ded	ee   d
eeef   fd              Z xZS )Swinv2ForImageClassificationc                 >   t         |   |       |j                  | _        t        |      | _        |j                  dkD  r4t        j                  | j                  j                  |j                        nt        j                         | _	        | j                          y rW  )rg   rh   
num_labelsr  r  r   r   r  rR  
classifierr  )r:   r   ri   s     r-   rh   z%Swinv2ForImageClassification.__init__  sx      ++!&) GMFWFWZ[F[BIIdkk..0A0ABacalalan 	
 	r,   )r  r  r  r  r   r  labelsr  r  r   r  rV   c                 .   ||n| j                   j                  }| j                  ||||||      }|d   }	| j                  |	      }
d}|| j                   j                  | j
                  dk(  rd| j                   _        nl| j
                  dkD  rL|j                  t        j                  k(  s|j                  t        j                  k(  rd| j                   _        nd| j                   _        | j                   j                  dk(  rIt               }| j
                  dk(  r& ||
j                         |j                               }n ||
|      }n| j                   j                  dk(  r=t               } ||
j                  d| j
                        |j                  d            }n,| j                   j                  dk(  rt               } ||
|      }|s|
f|dd z   }||f|z   S |S t        ||
|j                   |j"                  |j$                  	      S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r  r  r  r   r  r   
regressionsingle_label_classificationmulti_label_classificationrC   r@   )r3   r;   r!   r"   r#   )r   r  r  r  problem_typer  rY   r(   longr   r   squeezer
   rF   r	   r>   r!   r"   r#   )r:   r   r  r  r  r  r   r  r"  r  r;   r3   loss_fctra   s                 r-   rl   z$Swinv2ForImageClassification.forward  s   . &1%<k$++B]B]++/!5%=#  
  
/{{''/??a'/;DKK,__q(fllejj.HFLL\a\e\eLe/LDKK,/KDKK,{{''<7"9??a'#FNN$4fnn6FGD#FF3D))-JJ+-B @&++b/R))-II,./Y,F)-)9TGf$EvE*!//))#*#A#A
 	
r,   r  )r$   r%   r&   rh   r   r  r   _IMAGE_CLASS_CHECKPOINTr>   r  _IMAGE_CLASS_EXPECTED_OUTPUTr   r(   r)   
LongTensorr   r   r   rl   rr   rs   s   @r-   r  r    s    " ++BC*/$4	 5915-1,0/3).&*@
u001@
 E--.@
 ))*	@

 $D>@
 'tn@
 #'@
 d^@
 
u11	2@
 D@
r,   r  zO
    Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
    c                        e Zd Z fdZd Z ee       eee	      	 	 	 d
de
dee   dee   dee   def
d	              Z xZS )Swinv2Backbonec           	         t         |   |       t         | 	  |       |j                  gt	        t        |j                              D cg c]  }t        |j                  d|z  z         c}z   | _        t        |      | _
        t        || j                  j                        | _        | j                          y c c}w )Nr@   )rg   rh   _init_backboner~   r  r4  r  r   r  ru   r   r  r{   r  r  )r:   r   r  ri   s      r-   rh   zSwinv2Backbone.__init__a  s     v&#--.X]^abhbobo^pXq1rST#f6F6FA6M2N1rr*62$VT__-G-GH 	 2ss   "B>c                 .    | j                   j                  S rf   r  r9   s    r-   r  z#Swinv2Backbone.get_input_embeddingsl  r  r,   r  r   r  r  r  rV   c           	         ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }| j	                  |      \  }}| j                  ||d|dd|      }|r|j                  n|d   }d}	t        | j                  |      D ]  \  }
}|
| j                  v s|	|fz  }	 |s|	f}|r	||d   fz  }|r	||d   fz  }|S t        |	|r|j                  nd|j                        S )	a]  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 2048, 7, 7]
        ```NT)r  r  r  r  r  rC   r+   r   r@   )feature_mapsr!   r"   )r   r  r  r  r   r  r#   rX  stage_namesout_featuresr   r!   r"   )r:   r   r  r  r  r  r   r"  r!   r  r  hidden_statera   s                r-   rl   zSwinv2Backbone.forwardo  sD   F &1%<k$++B]B]$8$D $++JjJj 	 2C1N-TXT_T_TqTq-1__\-J**,,/!%59#  
 ;F667SU;#&t'7'7#G 	0E<)))/	0 "_F#71:-' 71:-'M%3G'//T))
 	
r,   )NNN)r$   r%   r&   rh   r  r   r  r   r   r  r   r   r   rl   rr   rs   s   @r-   r  r  Z  s    	0 ++BC>X -1/3&*F
F
 $D>F
 'tn	F

 d^F
 
F
 Y DF
r,   r  )rX   F)Nr'   collections.abcr   r   r6   dataclassesr   typingr   r   r   r(   torch.utils.checkpointr   r   torch.nnr	   r
   r   activationsr   modeling_outputsr   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   r   r   r   r   r   utils.backbone_utilsr   configuration_swinv2r   
get_loggerr$   loggerr  r  r  r  r  r   r/   r2   r>   rP   rR   rp   r   rb   r   rd   ru   rw   r   r   r%  r.  r>  rE  rI  r~  r  r  SWINV2_START_DOCSTRINGr  r  r  r  r  r+   r,   r-   <module>r      s   (    ! ) )    A A ! . - [ [   2 . 
		H	% ! A %  E -  K+ K K@  K  K  KF )#k )# )#X  K+  K  KH	U\\ e T V[VbVb *-RYY -Y-ryy Y-z(-BII (-V3 3l~")) ~D
ryy 
+bii +^  	299 	z")) zz=")) =@f
BII f
T*O *2	  0 f
a
' a

a
H 	 g
#8 g
g
T   V
#8 V
! V
r  	W
*M W
W
r,   