
    sg p              	          d Z ddlZddlmZ ddlmZmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZmZmZ ddlmZmZmZ dd	lmZmZ dd
lmZmZmZ ddlmZ ddlmZ  ej>                  e       Z!dZ"dZ#g dZ$dZ%dZ&e G d de             Z'd?de	jP                  de)de*de	jP                  fdZ+ G d dejX                        Z- G d dejX                        Z. G d dejX                        Z/ G d d ejX                        Z0 G d! d"ejX                        Z1 G d# d$ejX                        Z2 G d% d&ejX                        Z3 G d' d(ejX                        Z4 G d) d*ejX                        Z5 G d+ d,ejX                        Z6 G d- d.ejX                        Z7 G d/ d0ejX                        Z8 G d1 d2ejX                        Z9 G d3 d4ejX                        Z: G d5 d6e      Z;d7Z<d8Z= ed9e<       G d: d;e;             Z> ed<e<       G d= d>e;             Z?y)@zPyTorch CvT model.    N)	dataclass)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forward)$ImageClassifierOutputWithNoAttentionModelOutput)PreTrainedModel find_pruneable_heads_and_indicesprune_linear_layer)logging   )	CvtConfigr   zmicrosoft/cvt-13)r   i     r   ztabby, tabby catc                       e Zd ZU dZdZej                  ed<   dZej                  ed<   dZ	e
eej                  df      ed<   y)BaseModelOutputWithCLSTokena  
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
            Classification token at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    Nlast_hidden_statecls_token_value.hidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   r        W/var/www/html/venv/lib/python3.12/site-packages/transformers/models/cvt/modeling_cvt.pyr   r   /   sI     ,0u((/)-OU&&-=AM8E%"3"3S"89:Ar%   r   input	drop_probtrainingreturnc                    |dk(  s|s| S d|z
  }| j                   d   fd| j                  dz
  z  z   }|t        j                  || j                  | j
                        z   }|j                          | j                  |      |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
            r   r   )r   )dtypedevice)shapendimr!   randr-   r.   floor_div)r'   r(   r)   	keep_probr/   random_tensoroutputs          r&   	drop_pathr7   E   s     CxII[[^

Q 77E

5ELL YYMYYy!M1FMr%   c                   x     e Zd ZdZd	dee   ddf fdZdej                  dej                  fdZ	de
fdZ xZS )
CvtDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr(   r*   c                 0    t         |           || _        y N)super__init__r(   )selfr(   	__class__s     r&   r=   zCvtDropPath.__init__]   s    "r%   r   c                 D    t        || j                  | j                        S r;   )r7   r(   r)   )r>   r   s     r&   forwardzCvtDropPath.forwarda   s    FFr%   c                 8    dj                  | j                        S )Nzp={})formatr(   )r>   s    r&   
extra_reprzCvtDropPath.extra_reprd   s    }}T^^,,r%   r;   )r   r   r   r    r   floatr=   r!   TensorrA   strrD   __classcell__r?   s   @r&   r9   r9   Z   sG    b#(5/ #T #GU\\ Gell G-C -r%   r9   c                   (     e Zd ZdZ fdZd Z xZS )CvtEmbeddingsz'
    Construct the CvT embeddings.
    c                     t         |           t        |||||      | _        t	        j
                  |      | _        y )N)
patch_sizenum_channels	embed_dimstridepadding)r<   r=   CvtConvEmbeddingsconvolution_embeddingsr   Dropoutdropout)r>   rM   rN   rO   rP   rQ   dropout_rater?   s          r&   r=   zCvtEmbeddings.__init__m   s:    &7!	Z`jq'
# zz,/r%   c                 J    | j                  |      }| j                  |      }|S r;   )rS   rU   )r>   pixel_valueshidden_states      r&   rA   zCvtEmbeddings.forwardt   s&    22<@||L1r%   r   r   r   r    r=   rA   rH   rI   s   @r&   rK   rK   h   s    0r%   rK   c                   (     e Zd ZdZ fdZd Z xZS )rR   z"
    Image to Conv Embedding.
    c                     t         |           t        |t        j                  j
                        r|n||f}|| _        t        j                  |||||      | _	        t        j                  |      | _        y )N)kernel_sizerP   rQ   )r<   r=   
isinstancecollectionsabcIterablerM   r   Conv2d
projection	LayerNormnormalization)r>   rM   rN   rO   rP   rQ   r?   s         r&   r=   zCvtConvEmbeddings.__init__   sa    #-j+//:R:R#SZZdfpYq
$))L)\blst\\)4r%   c                     | j                  |      }|j                  \  }}}}||z  }|j                  |||      j                  ddd      }| j                  r| j	                  |      }|j                  ddd      j                  ||||      }|S Nr      r   )rc   r/   viewpermutere   )r>   rX   
batch_sizerN   heightwidthhidden_sizes          r&   rA   zCvtConvEmbeddings.forward   s    |42>2D2D/
L&%un#((\;OWWXY[\^_`--l;L#++Aq!499*lTZ\abr%   rZ   rI   s   @r&   rR   rR   z   s    5
r%   rR   c                   $     e Zd Z fdZd Z xZS )CvtSelfAttentionConvProjectionc           	          t         |           t        j                  |||||d|      | _        t        j
                  |      | _        y )NF)r]   rQ   rP   biasgroups)r<   r=   r   rb   convolutionBatchNorm2dre   )r>   rO   r]   rQ   rP   r?   s        r&   r=   z'CvtSelfAttentionConvProjection.__init__   sG    99#
  ^^I6r%   c                 J    | j                  |      }| j                  |      }|S r;   )rt   re   r>   rY   s     r&   rA   z&CvtSelfAttentionConvProjection.forward   s(    ''5)),7r%   r   r   r   r=   rA   rH   rI   s   @r&   rp   rp      s    7r%   rp   c                       e Zd Zd Zy) CvtSelfAttentionLinearProjectionc                 z    |j                   \  }}}}||z  }|j                  |||      j                  ddd      }|S rg   )r/   ri   rj   )r>   rY   rk   rN   rl   rm   rn   s          r&   rA   z(CvtSelfAttentionLinearProjection.forward   sK    2>2D2D/
L&%un#((\;OWWXY[\^_`r%   N)r   r   r   rA   r$   r%   r&   rz   rz      s    r%   rz   c                   &     e Zd Zd fd	Zd Z xZS )CvtSelfAttentionProjectionc                 p    t         |           |dk(  rt        ||||      | _        t	               | _        y )Ndw_bn)r<   r=   rp   convolution_projectionrz   linear_projection)r>   rO   r]   rQ   rP   projection_methodr?   s         r&   r=   z#CvtSelfAttentionProjection.__init__   s7    '*HT_ahjp*qD'!A!Cr%   c                 J    | j                  |      }| j                  |      }|S r;   )r   r   rw   s     r&   rA   z"CvtSelfAttentionProjection.forward   s(    22<@--l;r%   )r   rx   rI   s   @r&   r}   r}      s    Dr%   r}   c                   .     e Zd Z	 d fd	Zd Zd Z xZS )CvtSelfAttentionc                    t         |           |dz  | _        || _        || _        || _        t        |||||dk(  rdn|      | _        t        |||||      | _        t        |||||      | _	        t        j                  |||	      | _        t        j                  |||	      | _        t        j                  |||	      | _        t        j                  |
      | _        y )Ng      avglinear)r   )rr   )r<   r=   scalewith_cls_tokenrO   	num_headsr}   convolution_projection_queryconvolution_projection_keyconvolution_projection_valuer   Linearprojection_queryprojection_keyprojection_valuerT   rU   )r>   r   rO   r]   	padding_q
padding_kvstride_q	stride_kvqkv_projection_methodqkv_biasattention_drop_rater   kwargsr?   s                r&   r=   zCvtSelfAttention.__init__   s     	_
,"",F*?5*HhNc-
) +E{J	Mb+
' -G{J	Mb-
) !#		)YX N ii	98L "		)YX Nzz"56r%   c                     |j                   \  }}}| j                  | j                  z  }|j                  ||| j                  |      j	                  dddd      S )Nr   rh   r   r   )r/   rO   r   ri   rj   )r>   rY   rk   rn   _head_dims         r&   "rearrange_for_multi_head_attentionz3CvtSelfAttention.rearrange_for_multi_head_attention   sV    %1%7%7"
K>>T^^3  [$..(S[[\]_`bcefggr%   c                 `   | j                   rt        j                  |d||z  gd      \  }}|j                  \  }}}|j	                  ddd      j                  ||||      }| j                  |      }| j                  |      }	| j                  |      }
| j                   rKt        j                  |	fd      }	t        j                  ||fd      }t        j                  ||
fd      }
| j                  | j                  z  }| j                  | j                  |	            }	| j                  | j                  |            }| j                  | j                  |
            }
t        j                   d|	|g      | j"                  z  }t        j$                  j&                  j)                  |d      }| j+                  |      }t        j                   d||
g      }|j                  \  }}}}|j	                  dddd      j-                         j                  ||| j                  |z        }|S )	Nr   r   rh   dimzbhlk,bhtk->bhltzbhlt,bhtv->bhlvr   )r   r!   splitr/   rj   ri   r   r   r   catrO   r   r   r   r   r   einsumr   r   
functionalsoftmaxrU   
contiguous)r>   rY   rl   rm   	cls_tokenrk   rn   rN   keyqueryvaluer   attention_scoreattention_probscontextr   s                   r&   rA   zCvtSelfAttention.forward   s   &+kk,FUN@SUV&W#I|0<0B0B-
K#++Aq!499*lTZ\ab--l;11,?11,?IIy%0a8E))Y,!4CIIy%0a8E>>T^^3778M8Me8TU55d6I6I#6NO778M8Me8TU,,'85#,G$**T((--55o25N,,7,,0?E2JK&}}1k1//!Q1-88:??
KY]YgYgjrYrsr%   T)r   r   r   r=   r   rA   rH   rI   s   @r&   r   r      s     '7Rhr%   r   c                   (     e Zd ZdZ fdZd Z xZS )CvtSelfOutputz
    The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    c                     t         |           t        j                  ||      | _        t        j
                  |      | _        y r;   )r<   r=   r   r   denserT   rU   )r>   rO   	drop_rater?   s      r&   r=   zCvtSelfOutput.__init__  s0    YYy)4
zz),r%   c                 J    | j                  |      }| j                  |      }|S r;   r   rU   r>   rY   input_tensors      r&   rA   zCvtSelfOutput.forward  s$    zz,/||L1r%   rZ   rI   s   @r&   r   r     s    
-
r%   r   c                   .     e Zd Z	 d fd	Zd Zd Z xZS )CvtAttentionc                     t         |           t        |||||||||	|
|      | _        t	        ||      | _        t               | _        y r;   )r<   r=   r   	attentionr   r6   setpruned_heads)r>   r   rO   r]   r   r   r   r   r   r   r   r   r   r?   s                r&   r=   zCvtAttention.__init__   sW     	)!
 $Iy9Er%   c                 >   t        |      dk(  ry t        || j                  j                  | j                  j                  | j
                        \  }}t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _	        t        | j                  j                  |d      | j                  _        | j                  j                  t        |      z
  | j                  _        | j                  j                  | j                  j                  z  | j                  _        | j
                  j                  |      | _        y )Nr   r   r   )lenr   r   num_attention_headsattention_head_sizer   r   r   r   r   r6   r   all_head_sizeunion)r>   headsindexs      r&   prune_headszCvtAttention.prune_heads@  s   u:?74>>55t~~7Y7Y[_[l[l
u
  2$..2F2FN/0B0BEJ1$..2F2FN.t{{/@/@%QO .2^^-O-ORUV[R\-\*'+~~'I'IDNNLnLn'n$ --33E:r%   c                 P    | j                  |||      }| j                  ||      }|S r;   )r   r6   )r>   rY   rl   rm   self_outputattention_outputs         r&   rA   zCvtAttention.forwardR  s+    nn\65A;;{LAr%   r   )r   r   r   r=   r   rA   rH   rI   s   @r&   r   r     s     "@;$ r%   r   c                   $     e Zd Z fdZd Z xZS )CvtIntermediatec                     t         |           t        j                  |t	        ||z              | _        t        j                         | _        y r;   )r<   r=   r   r   intr   GELU
activation)r>   rO   	mlp_ratior?   s      r&   r=   zCvtIntermediate.__init__Y  s7    YYy#i).C*DE
'')r%   c                 J    | j                  |      }| j                  |      }|S r;   )r   r   rw   s     r&   rA   zCvtIntermediate.forward^  s$    zz,/|4r%   rx   rI   s   @r&   r   r   X  s    $
r%   r   c                   $     e Zd Z fdZd Z xZS )	CvtOutputc                     t         |           t        j                  t	        ||z        |      | _        t        j                  |      | _        y r;   )r<   r=   r   r   r   r   rT   rU   )r>   rO   r   r   r?   s       r&   r=   zCvtOutput.__init__e  s:    YYs9y#899E
zz),r%   c                 T    | j                  |      }| j                  |      }||z   }|S r;   r   r   s      r&   rA   zCvtOutput.forwardj  s.    zz,/||L1#l2r%   rx   rI   s   @r&   r   r   d  s    -
r%   r   c                   ,     e Zd ZdZ	 d fd	Zd Z xZS )CvtLayerzb
    CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
    c                 Z   t         |           t        |||||||||	|
||      | _        t	        ||      | _        t        |||      | _        |dkD  rt        |      nt        j                         | _        t        j                  |      | _        t        j                  |      | _        y )Nr,   )r(   )r<   r=   r   r   r   intermediater   r6   r9   r   Identityr7   rd   layernorm_beforelayernorm_after)r>   r   rO   r]   r   r   r   r   r   r   r   r   r   drop_path_rater   r?   s                  r&   r=   zCvtLayer.__init__v  s    " 	%!
 ,IyA	9i@BPSVBV~>\^\g\g\i "Y 7!||I6r%   c                    | j                  | j                  |      ||      }|}| j                  |      }||z   }| j                  |      }| j	                  |      }| j                  ||      }| j                  |      }|S r;   )r   r   r7   r   r   r6   )r>   rY   rl   rm   self_attention_outputr   layer_outputs          r&   rA   zCvtLayer.forward  s     $!!,/!

 1>>*:; (,6 ++L9((6 {{<>~~l3r%   r   rZ   rI   s   @r&   r   r   q  s    & %7Nr%   r   c                   $     e Zd Z fdZd Z xZS )CvtStagec                 x   t         |           || _        || _        | j                  j                  | j                     rFt        j                  t        j                  dd| j                  j                  d               | _        t        |j                  | j                     |j                  | j                     | j                  dk(  r|j                  n|j                  | j                  dz
     |j                  | j                     |j                  | j                     |j                  | j                           | _        t        j"                  d|j$                  | j                     |j&                  |         D cg c]  }|j)                          }}t        j*                  t-        |j&                  | j                           D cg c]T  }t/        |j0                  | j                     |j                  | j                     |j2                  | j                     |j4                  | j                     |j6                  | j                     |j8                  | j                     |j:                  | j                     |j<                  | j                     |j>                  | j                     |j@                  | j                     |j                  | j                     || j                     |jB                  | j                     |j                  | j                           W c} | _"        y c c}w c c}w )Nr   r   r   )rM   rP   rN   rO   rQ   rV   )r   rO   r]   r   r   r   r   r   r   r   r   r   r   r   )#r<   r=   configstager   r   	Parameterr!   randnrO   rK   patch_sizespatch_striderN   patch_paddingr   	embeddinglinspacer   depthitem
Sequentialranger   r   
kernel_qkvr   r   r   r   r   r   r   r   layers)r>   r   r   xdrop_path_ratesr   r?   s         r&   r=   zCvtStage.__init__  sw   
;;  ,\\%++aDKK<Q<QRT<U*VWDN&))$**5&&tzz204

a,,VEUEUVZV`V`cdVdEe&&tzz2((4))$**5
 .3^^Av?T?TUYU_U_?`bhbnbnotbu-vw1668wwmm$ v||DJJ78#" ! $..tzz:$..tzz: & 1 1$** =$..tzz:%00<$..tzz:#__TZZ8*0*F*Ftzz*R#__TZZ8(.(B(B4::(N$..tzz:#24::#>$..tzz:#)#3#3DJJ#?
 xs   L2EL7c                 Z   d }| j                  |      }|j                  \  }}}}|j                  ||||z        j                  ddd      }| j                  j
                  | j                     r6| j
                  j                  |dd      }t        j                  ||fd      }| j                  D ]  } ||||      }|} | j                  j
                  | j                     rt        j                  |d||z  gd      \  }}|j                  ddd      j                  ||||      }||fS )Nr   rh   r   r   r   )r   r/   ri   rj   r   r   r   expandr!   r   r   r   )	r>   rY   r   rk   rN   rl   rm   layerlayer_outputss	            r&   rA   zCvtStage.forward  s'   	~~l32>2D2D/
L&%#((\6E>RZZ[\^_abc;;  ,--j"bAI 99i%>AFL[[ 	)E!,>M(L	) ;;  ,&+kk,FUN@SUV&W#I|#++Aq!499*lTZ\abY&&r%   rx   rI   s   @r&   r   r     s    &
P'r%   r   c                   &     e Zd Z fdZddZ xZS )
CvtEncoderc                     t         |           || _        t        j                  g       | _        t        t        |j                              D ]'  }| j
                  j                  t        ||             ) y r;   )r<   r=   r   r   
ModuleListstagesr   r   r   appendr   )r>   r   	stage_idxr?   s      r&   r=   zCvtEncoder.__init__  s[    mmB's6<<01 	<IKKx	:;	<r%   c                     |rdnd }|}d }t        | j                        D ]  \  }} ||      \  }}|s||fz   } |st        d |||fD              S t        |||      S )Nr$   c              3   &   K   | ]	  }||  y wr;   r$   ).0vs     r&   	<genexpr>z%CvtEncoder.forward.<locals>.<genexpr>  s     bqTUTabs   r   r   r   )	enumerater   tupler   )	r>   rX   output_hidden_statesreturn_dictall_hidden_statesrY   r   r   stage_modules	            r&   rA   zCvtEncoder.forward  s    "6BD#	!*4;;!7 	HA&2<&@#L)#$5$G!	H
 b\9>O$Pbbb**%+
 	
r%   )FTrx   rI   s   @r&   r   r     s    <
r%   r   c                   (    e Zd ZdZeZdZdZdgZd Z	y)CvtPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    cvtrX   r   c                    t        |t        j                  t        j                  f      rt        j                  j                  |j                  j                  d| j                  j                        |j                  _        |j                  %|j                  j                  j                          yyt        |t        j                        rJ|j                  j                  j                          |j                  j                  j                  d       yt        |t              r| j                  j                  |j                      rrt        j                  j                  t#        j$                  dd| j                  j&                  d         d| j                  j                        |j                  _        yyy)zInitialize the weightsr,   )meanstdNg      ?r   r   )r^   r   r   rb   inittrunc_normal_weightdatar   initializer_rangerr   zero_rd   fill_r   r   r   r!   zerosrO   )r>   modules     r&   _init_weightsz CvtPreTrainedModel._init_weights  s0   fryy"))45!#!6!6v}}7I7IPSY]YdYdYvYv!6!wFMM{{&  &&( '-KK""$MM$$S)){{$$V\\2(*(=(=KK1dkk&;&;B&?@sPTP[P[PmPm )> )  % 3 *r%   N)
r   r   r   r    r   config_classbase_model_prefixmain_input_name_no_split_modulesr  r$   r%   r&   r  r    s&    
 L$O#r%   r  aE  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aE  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
            for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
z]The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.c                        e Zd Zd
 fd	Zd Z ee       eee	e
de      	 	 	 ddeej                     dee   dee   deee	f   fd	              Z xZS )CvtModelc                 r    t         |   |       || _        t        |      | _        | j                          y r;   )r<   r=   r   r   encoder	post_init)r>   r   add_pooling_layerr?   s      r&   r=   zCvtModel.__init__D  s-     !&)r%   c                     |j                         D ]7  \  }}| j                  j                  |   j                  j	                  |       9 y)z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr%  r   r   r   )r>   heads_to_pruner   r   s       r&   _prune_headszCvtModel._prune_headsJ  sE    
 +002 	CLE5LLu%//;;EB	Cr%   vision)
checkpointoutput_typer  modalityexpected_outputrX   r
  r  r*   c                    ||n| j                   j                  }||n| j                   j                  }|t        d      | j	                  |||      }|d   }|s	|f|dd  z   S t        ||j                  |j                        S )Nz You have to specify pixel_valuesr
  r  r   r   r  )r   r
  use_return_dict
ValueErrorr%  r   r   r   )r>   rX   r
  r  encoder_outputssequence_outputs         r&   rA   zCvtModel.forwardR  s     %9$D $++JjJj 	 &1%<k$++B]B]?@@,,!5# ' 

 *!,#%(;;;*-+;;)77
 	
r%   r   )NNN)r   r   r   r=   r+  r   CVT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r!   rF   boolr   r   rA   rH   rI   s   @r&   r#  r#  ?  s    
C ++?@&/$. 04/3&*	
u||,
 'tn
 d^	

 
u11	2
 A
r%   r#  z
    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                        e Zd Z fdZ ee       eeee	e
      	 	 	 	 d	deej                     deej                     dee   dee   deeef   f
d              Z xZS )
CvtForImageClassificationc                    t         |   |       |j                  | _        t        |d      | _        t        j                  |j                  d         | _        |j                  dkD  r-t        j                  |j                  d   |j                        nt        j                         | _        | j                          y )NF)r'  r   r   )r<   r=   
num_labelsr#  r  r   rd   rO   	layernormr   r   
classifierr&  )r>   r   r?   s     r&   r=   z"CvtForImageClassification.__init__  s      ++Fe<f&6&6r&:; CIBSBSVWBWBIIf&&r*F,=,=>]_]h]h]j 	
 	r%   )r-  r.  r  r0  rX   labelsr
  r  r*   c                 b   ||n| j                   j                  }| j                  |||      }|d   }|d   }| j                   j                  d   r| j	                  |      }nI|j
                  \  }}	}
}|j                  ||	|
|z        j                  ddd      }| j	                  |      }|j                  d      }| j                  |      }d}|| j                   j                  | j                   j                  dk(  rd| j                   _
        nv| j                   j                  dkD  rL|j                  t        j                  k(  s|j                  t        j                  k(  rd	| j                   _
        nd
| j                   _
        | j                   j                  dk(  rSt!               }| j                   j                  dk(  r& ||j#                         |j#                               }n |||      }n| j                   j                  d	k(  rGt%               } ||j                  d| j                   j                        |j                  d            }n,| j                   j                  d
k(  rt'               } |||      }|s|f|dd z   }||f|z   S |S t)        |||j*                        S )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr2  r   r   r   rh   r   
regressionsingle_label_classificationmulti_label_classification)losslogitsr   )r   r3  r  r   r@  r/   ri   rj   r  rA  problem_typer?  r-   r!   longr   r
   squeezer	   r   r   r   )r>   rX   rB  r
  r  outputsr6  r   rk   rN   rl   rm   sequence_output_meanrH  rG  loss_fctr6   s                    r&   rA   z!CvtForImageClassification.forward  s`   ( &1%<k$++B]B]((!5#  
 "!*AJ	;;  $"nnY7O6E6K6K3Jfe-22:|VV[^\ddefhiklmO"nn_=O.333:!56{{''/;;))Q./;DKK,[[++a/V\\UZZ5OSYS_S_chclclSl/LDKK,/KDKK,{{''<7"9;;))Q.#FNN$4fnn6FGD#FF3D))-JJ+-B0F0F GUWY))-II,./Y,F)-)9TGf$EvE3f\c\q\qrrr%   )NNNN)r   r   r   r=   r   r7  r   _IMAGE_CLASS_CHECKPOINTr   r9  _IMAGE_CLASS_EXPECTED_OUTPUTr   r!   rF   r;  r   r   rA   rH   rI   s   @r&   r=  r=  y  s     ++?@*8$4	 04)-/3&*<su||,<s &<s 'tn	<s
 d^<s 
u::	;<s A<sr%   r=  )r,   F)@r    collections.abcr_   dataclassesr   typingr   r   r   r!   torch.utils.checkpointr   torch.nnr   r	   r
   
file_utilsr   r   r   modeling_outputsr   r   modeling_utilsr   r   r   utilsr   configuration_cvtr   
get_loggerr   loggerr9  r8  r:  rO  rP  r   rF   rE   r;  r7   Moduler9   rK   rR   rp   rz   r}   r   r   r   r   r   r   r   r   r  CVT_START_DOCSTRINGr7  r#  r=  r$   r%   r&   <module>r_     s     ! ) )    A A q q Q c c  ( 
		H	%  ) )  - 1  B+ B B*U\\ e T V[VbVb *-")) -BII $		 2RYY (ryy 
 
Nryy NbBII "6 299 6 r	bii 	
		 
?ryy ?D:'ryy :'z
 
8 6	 
  c3
! 3
	3
l  Rs 2 RsRsr%   