
    sgn              	          d Z ddlZddlZddlmZmZmZmZ ddlZddl	m
c mZ ddlZddlm
Z
 ddlmZmZmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZ ddlmZmZmZm Z  ddl!m"Z"  e jF                  e$      Z%dZ&dZ'g dZ(dZ)dZ*d1dejV                  de,de-dejV                  fdZ. G d de
j^                        Z0 G d de
j^                        Z1 G d de
j^                        Z2 G d de
j^                        Z3 G d d e
j^                        Z4 G d! d"e
j^                        Z5 G d# d$e
j^                        Z6 G d% d&e
j^                        Z7 G d' d(e      Z8d)Z9d*Z: ed+e9       G d, d-e8             Z; ed.e9       G d/ d0e8             Z<y)2zPyTorch PVT model.    N)IterableOptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputImageClassifierOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )	PvtConfigr   zZetatech/pvt-tiny-224)r   2   i   ztabby, tabby catinput	drop_probtrainingreturnc                    |dk(  s|s| S d|z
  }| j                   d   fd| j                  dz
  z  z   }|t        j                  || j                  | j
                        z   }|j                          | j                  |      |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
            r   r   )r   )dtypedevice)shapendimtorchrandr   r    floor_div)r   r   r   	keep_probr!   random_tensoroutputs          W/var/www/html/venv/lib/python3.12/site-packages/transformers/models/pvt/modeling_pvt.py	drop_pathr+   6   s     CxII[[^

Q 77E

5ELL YYMYYy!M1FM    c                   x     e Zd ZdZd	dee   ddf fdZdej                  dej                  fdZ	de
fdZ xZS )
PvtDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   r   c                 0    t         |           || _        y N)super__init__r   )selfr   	__class__s     r*   r2   zPvtDropPath.__init__N   s    "r,   hidden_statesc                 D    t        || j                  | j                        S r0   )r+   r   r   r3   r5   s     r*   forwardzPvtDropPath.forwardR   s    FFr,   c                 8    dj                  | j                        S )Nzp={})formatr   )r3   s    r*   
extra_reprzPvtDropPath.extra_reprU   s    }}T^^,,r,   r0   )__name__
__module____qualname____doc__r   floatr2   r#   Tensorr8   strr;   __classcell__r4   s   @r*   r.   r.   K   sG    b#(5/ #T #GU\\ Gell G-C -r,   r.   c                        e Zd ZdZ	 ddedeeee   f   deeee   f   dedededef fd	Z	d
e
j                  dedede
j                  fdZde
j                  dee
j                  eef   fdZ xZS )PvtPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    config
image_size
patch_sizestridenum_channelshidden_size	cls_tokenc                    t         	|           || _        t        |t        j
                  j                        r|n||f}t        |t        j
                  j                        r|n||f}|d   |d   z  |d   |d   z  z  }|| _        || _        || _	        || _
        t        j                  t        j                  d|r|dz   n||            | _        |r*t        j                  t        j                   dd|            nd | _        t        j$                  ||||      | _        t        j(                  ||j*                        | _        t        j.                  |j0                        | _        y )Nr   r   kernel_sizerJ   eps)p)r1   r2   rG   
isinstancecollectionsabcr   rH   rI   rK   num_patchesr   	Parameterr#   randnposition_embeddingszerosrM   Conv2d
projection	LayerNormlayer_norm_eps
layer_normDropouthidden_dropout_probdropout)
r3   rG   rH   rI   rJ   rK   rL   rM   rW   r4   s
            r*   r2   zPvtPatchEmbeddings.__init__`   s0    	#-j+//:R:R#SZZdfpYq
#-j+//:R:R#SZZdfpYq
!!}
15*Q-:VW=:XY$$(&#%<<KKi;?[+V$
  JSekk!Q&DEX\))L+6Zde,,{8M8MNzzF$>$>?r,   
embeddingsheightwidthr   c                    ||z  }t         j                  j                         s<|| j                  j                  | j                  j                  z  k(  r| j
                  S |j                  d||d      j                  dddd      }t        j                  |||fd      }|j                  dd||z        j                  ddd      }|S )Nr   r   r      bilinear)sizemode)
r#   jit
is_tracingrG   rH   rZ   reshapepermuteFinterpolate)r3   rd   re   rf   rW   interpolated_embeddingss         r*   interpolate_pos_encodingz+PvtPatchEmbeddings.interpolate_pos_encoding|   s    un yy##%+9O9ORVR]R]RhRh9h*h+++''65"=EEaAqQ
"#--
&%Wa"b"9"A"A!RRW"X"`"`abdegh"i&&r,   pixel_valuesc                    |j                   \  }}}}|| j                  k7  rt        d      | j                  |      }|j                   ^ }}}|j	                  d      j                  dd      }| j                  |      }| j                  | j                  j                  |dd      }	t        j                  |	|fd      }| j                  | j                  d d dd f   ||      }
t        j                  | j                  d d d df   |
fd      }
n| j                  | j                  ||      }
| j                  ||
z         }|||fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.ri   r   rh   dim)r!   rK   
ValueErrorr]   flatten	transposer`   rM   expandr#   catrt   rZ   rc   )r3   ru   
batch_sizerK   re   rf   patch_embed_rd   rM   rZ   s              r*   r8   zPvtPatchEmbeddings.forward   sM   2>2D2D/
L&%4,,,w  ool3'--FE!))!,66q!<__[1
>>%--j"bAIIz#:BJ"&"?"?@X@XYZ\]\^Y^@_agin"o"'))T-E-Ea!e-LNa,bhi"j"&"?"?@X@XZ`bg"h\\*/B"BC
65((r,   F)r<   r=   r>   r?   r   r   intr   boolr2   r#   rA   rt   r   r8   rC   rD   s   @r*   rF   rF   Y   s      @@ #x},-@ #x},-	@
 @ @ @ @8	'5<< 	' 	'UX 	']b]i]i 	')ELL )U5<<c;Q5R )r,   rF   c                   `     e Zd Zdedef fdZdej                  dej                  fdZ xZ	S )PvtSelfOutputrG   rL   c                     t         |           t        j                  ||      | _        t        j
                  |j                        | _        y r0   )r1   r2   r   Lineardensera   rb   rc   )r3   rG   rL   r4   s      r*   r2   zPvtSelfOutput.__init__   s6    YY{K8
zz&"<"<=r,   r5   r   c                 J    | j                  |      }| j                  |      }|S r0   )r   rc   r7   s     r*   r8   zPvtSelfOutput.forward   s$    

=1]3r,   )
r<   r=   r>   r   r   r2   r#   rA   r8   rC   rD   s   @r*   r   r      s1    >y >s >
U\\ ell r,   r   c                        e Zd ZdZdedededef fdZdedej                  fd	Z
	 ddej                  d
edededeej                     f
dZ xZS )PvtEfficientSelfAttentionzpEfficient self-attention mechanism with reduction of the sequence [PvT paper](https://arxiv.org/abs/2102.12122).rG   rL   num_attention_headssequences_reduction_ratioc                    t         |           || _        || _        | j                  | j                  z  dk7  r&t	        d| j                   d| j                   d      t        | j                  | j                  z        | _        | j                  | j                  z  | _        t        j                  | j                  | j                  |j                        | _        t        j                  | j                  | j                  |j                        | _        t        j                  | j                  | j                  |j                        | _        t        j                  |j                        | _        || _        |dkD  rEt        j$                  ||||      | _        t        j(                  ||j*                        | _        y y )	Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ())biasr   rO   rQ   )r1   r2   rL   r   ry   r   attention_head_sizeall_head_sizer   r   qkv_biasquerykeyvaluera   attention_probs_dropout_probrc   r   r\   sequence_reductionr^   r_   r`   r3   rG   rL   r   r   r4   s        r*   r2   z"PvtEfficientSelfAttention.__init__   sr    	&#6 d666!;#D$4$4#5 622316 
 $'t'7'7$:R:R'R#S !558P8PPYYt//1C1C&//Z
99T--t/A/AXYYt//1C1C&//Z
zz&"E"EF)B&$q(&(ii[6OXq'D# !ll;F<Q<QRDO	 )r,   r5   r   c                     |j                         d d | j                  | j                  fz   }|j                  |      }|j	                  dddd      S )Nrh   r   ri   r   r   )rk   r   r   viewrp   )r3   r5   	new_shapes      r*   transpose_for_scoresz.PvtEfficientSelfAttention.transpose_for_scores   sT    !&&("-1I1I4KcKc0dd	%**95$$Q1a00r,   re   rf   output_attentionsc                    | j                  | j                  |            }| j                  dkD  r{|j                  \  }}}|j	                  ddd      j                  ||||      }| j                  |      }|j                  ||d      j	                  ddd      }| j                  |      }| j                  | j                  |            }	| j                  | j                  |            }
t        j                  ||	j                  dd            }|t        j                  | j                        z  }t         j"                  j%                  |d      }| j'                  |      }t        j                  ||
      }|j	                  dddd      j)                         }|j+                         d d | j,                  fz   }|j/                  |      }|r||f}|S |f}|S )Nr   r   ri   rh   rw   r   )r   r   r   r!   rp   ro   r   r`   r   r   r#   matmulr{   mathsqrtr   r   
functionalsoftmaxrc   
contiguousrk   r   r   )r3   r5   re   rf   r   query_layerr~   seq_lenrK   	key_layervalue_layerattention_scoresattention_probscontext_layernew_context_layer_shapeoutputss                   r*   r8   z!PvtEfficientSelfAttention.forward   s    //

=0IJ))A-0=0C0C-J)11!Q:BB:|]cejkM 33MBM)11*lBOWWXY[\^_`M OOM:M--dhh}.EF	//

=0IJ !<<Y5H5HR5PQ+dii8P8P.QQ --//0@b/I ,,7_kB%--aAq9DDF"/"4"4"6s";t?Q?Q>S"S%**+BC6G=/2 O\M]r,   r   )r<   r=   r>   r?   r   r   r@   r2   r#   rA   r   r   r   r8   rC   rD   s   @r*   r   r      s    zSS.1SHKShmS:1# 1%,, 1 #(*||* * 	*
  * 
u||	*r,   r   c                        e Zd Zdedededef fdZd Z	 ddej                  ded	ed
e
deej                     f
dZ xZS )PvtAttentionrG   rL   r   r   c                     t         |           t        ||||      | _        t	        ||      | _        t               | _        y )N)rL   r   r   )rL   )r1   r2   r   r3   r   r)   setpruned_headsr   s        r*   r2   zPvtAttention.__init__   sB     	-# 3&?	
	 $FDEr,   c                 >   t        |      dk(  ry t        || j                  j                  | j                  j                  | j
                        \  }}t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _        t        | j                  j                  |      | j                  _	        t        | j                  j                  |d      | j                  _        | j                  j                  t        |      z
  | j                  _        | j                  j                  | j                  j                  z  | j                  _        | j
                  j                  |      | _        y )Nr   r   rw   )lenr   r3   r   r   r   r   r   r   r   r)   r   r   union)r3   headsindexs      r*   prune_headszPvtAttention.prune_heads	  s   u:?749900$))2O2OQUQbQb
u
 -TYY__eD		*499==%@		,TYY__eD		.t{{/@/@%QO )-		(E(EE
(R		%"&))"?"?$))B_B_"_		 --33E:r,   r5   re   rf   r   r   c                 h    | j                  ||||      }| j                  |d         }|f|dd  z   }|S )Nr   r   )r3   r)   )r3   r5   re   rf   r   self_outputsattention_outputr   s           r*   r8   zPvtAttention.forward  sE     yy?PQ;;|A7#%QR(88r,   r   )r<   r=   r>   r   r   r@   r2   r   r#   rA   r   r   r8   rC   rD   s   @r*   r   r      sn    "".1"HK"hm";& _d"\\36?BW[	u||	r,   r   c            
       z     e Zd Z	 	 d	dededee   dee   f fdZdej                  dej                  fdZ	 xZ
S )
PvtFFNrG   in_featureshidden_featuresout_featuresc                 j   t         |           ||n|}t        j                  ||      | _        t        |j                  t              rt        |j                     | _	        n|j                  | _	        t        j                  ||      | _
        t        j                  |j                        | _        y r0   )r1   r2   r   r   dense1rT   
hidden_actrB   r   intermediate_act_fndense2ra   rb   rc   )r3   rG   r   r   r   r4   s        r*   r2   zPvtFFN.__init__&  s     	'3'?|[ii_=f''-'-f.?.?'@D$'-'8'8D$ii>zz&"<"<=r,   r5   r   c                     | j                  |      }| j                  |      }| j                  |      }| j                  |      }| j                  |      }|S r0   )r   r   rc   r   r7   s     r*   r8   zPvtFFN.forward7  sP    M200?]3M2]3r,   )NN)r<   r=   r>   r   r   r   r2   r#   rA   r8   rC   rD   s   @r*   r   r   %  sY    
 *.&*>> > "#	>
 sm>"U\\ ell r,   r   c                   f     e Zd Zdedededededef fdZddej                  d	ed
ede	fdZ
 xZS )PvtLayerrG   rL   r   r+   r   	mlp_ratioc                 v   t         |           t        j                  ||j                        | _        t        ||||      | _        |dkD  rt        |      nt        j                         | _
        t        j                  ||j                        | _        t        ||z        }t        |||      | _        y )NrQ   )rG   rL   r   r   r   )rG   r   r   )r1   r2   r   r^   r_   layer_norm_1r   	attentionr.   Identityr+   layer_norm_2r   r   mlp)	r3   rG   rL   r   r+   r   r   mlp_hidden_sizer4   s	           r*   r2   zPvtLayer.__init__A  s     	LL&:O:OP%# 3&?	
 4=s?Y/LL&:O:OPkI56[Rabr,   r5   re   rf   r   c                    | j                  | j                  |      |||      }|d   }|dd  }| j                  |      }||z   }| j                  | j	                  |            }| j                  |      }||z   }	|	f|z   }|S )N)r5   re   rf   r   r   r   )r   r   r+   r   r   )
r3   r5   re   rf   r   self_attention_outputsr   r   
mlp_outputlayer_outputs
             r*   r8   zPvtLayer.forwardW  s    !%++M:/	 "0 "
 2!4(,>>*:;(=8XXd//>?
^^J/
$z1/G+r,   r   )r<   r=   r>   r   r   r@   r2   r#   rA   r   r8   rC   rD   s   @r*   r   r   @  so    cc c !	c
 c $)c c,U\\ 3 s _c r,   r   c                   x     e Zd Zdef fdZ	 	 	 d	dej                  dee   dee   dee   de	e
ef   f
dZ xZS )

PvtEncoderrG   c                    t         	|           || _        t        j                  d|j
                  t        |j                              j                         }g }t        |j                        D ]  }|j                  t        ||dk(  r|j                  n| j                  j                  d|dz   z  z  |j                  |   |j                  |   |dk(  r|j                   n|j"                  |dz
     |j"                  |   ||j                  dz
  k(                t%        j&                  |      | _        g }d}t        |j                        D ]  }g }|dk7  r||j                  |dz
     z  }t        |j                  |         D ]\  }|j                  t+        ||j"                  |   |j,                  |   |||z      |j.                  |   |j0                  |                ^ |j                  t%        j&                  |              t%        j&                  |      | _        t%        j4                  |j"                  d   |j6                        | _        y )Nr   ri   r   )rG   rH   rI   rJ   rK   rL   rM   )rG   rL   r   r+   r   r   rh   rQ   )r1   r2   rG   r#   linspacedrop_path_ratesumdepthstolistrangenum_encoder_blocksappendrF   rH   patch_sizesstridesrK   hidden_sizesr   
ModuleListpatch_embeddingsr   r   sequence_reduction_ratios
mlp_ratiosblockr^   r_   r`   )
r3   rG   drop_path_decaysrd   iblockscurlayersjr4   s
            r*   r2   zPvtEncoder.__init__o  s*    !>>!V-B-BCDVW^^` 
v001 	A"!45Fv00@V@V[\abefaf[g@h%11!4!>>!,89Q!4!4FDWDWXY\]X]D^ & 3 3A 66#<#<q#@@
	 !#j 9 v001 	1AFAvv}}QU++6==+, 
%$*$7$7$:,2,F,Fq,I"237";282R2RST2U"("3"3A"6	
 MM"--/0!	1$ ]]6*
 ,,v':':2'>FDYDYZr,   ru   r   output_hidden_statesreturn_dictr   c                 2   |rdnd }|rdnd }|j                   d   }t        | j                        }|}	t        t	        | j
                  | j                              D ]|  \  }
\  }} ||	      \  }	}}|D ]&  } ||	|||      }|d   }	|r	||d   fz   }|s!||	fz   }( |
|dz
  k7  sI|	j                  |||d      j                  dddd      j                         }	~ | j                  |	      }	|r||	fz   }|st        d |	||fD              S t        |	||      S )	N r   r   rh   r   ri   c              3   &   K   | ]	  }||  y wr0   r   ).0vs     r*   	<genexpr>z%PvtEncoder.forward.<locals>.<genexpr>  s     mq_`_lms   last_hidden_stater5   
attentions)r!   r   r   	enumeratezipr   ro   rp   r   r`   tupler   )r3   ru   r   r   r   all_hidden_statesall_self_attentionsr~   
num_blocksr5   idxembedding_layerblock_layerre   rf   r   layer_outputss                    r*   r8   zPvtEncoder.forward  si    #7BD$5b4!''*
_
$3<SAVAVX\XbXb=c3d 	v/C//;+:=+I(M65$ M %mVUDU V -a 0$*=qAQ@S*S''(9]<L(L%M j1n$ - 5 5j&%QS T \ \]^`acdfg h s s u	v 6 1]4D Dm]4EGZ$[mmm++*
 	
r,   )FFT)r<   r=   r>   r   r2   r#   FloatTensorr   r   r   r   r   r8   rC   rD   s   @r*   r   r   n  sn    0[y 0[j -2/4&*#
''#
 $D>#
 'tn	#

 d^#
 
uo%	&#
r,   r   c                   x    e Zd ZdZeZdZdZg Zde	e
j                  e
j                  e
j                  f   ddfdZy)PvtPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    pvtru   moduler   Nc                    t        |t        j                        rt        j                  j	                  |j
                  j                  d| j                  j                        |j
                  _        |j                  %|j                  j                  j                          yyt        |t        j                        rJ|j                  j                  j                          |j
                  j                  j                  d       yt        |t              rt        j                  j	                  |j                  j                  d| j                  j                        |j                  _        |j                  Zt        j                  j	                  |j                  j                  d| j                  j                        |j                  _        yyy)zInitialize the weightsr   )meanstdNg      ?)rT   r   r   inittrunc_normal_weightdatarG   initializer_ranger   zero_r^   fill_rF   rZ   rM   )r3   r	  s     r*   _init_weightsz PvtPreTrainedModel._init_weights  sS   fbii( "$!6!6v}}7I7IPSY]YdYdYvYv!6!wFMM{{&  &&( '-KK""$MM$$S) 23.0gg.C.C**//KK11 /D /F&&+
 +(*(=(=$$))55 )> )  % , 4r,   )r<   r=   r>   r?   r   config_classbase_model_prefixmain_input_name_no_split_modulesr   r   r   r\   r^   r  r   r,   r*   r  r    sK    
 L$OE"))RYY*L$M RV r,   r  aG  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`~PvtConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a
  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`PvtImageProcessor.__call__`]
            for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zSThe bare Pvt encoder outputting raw hidden-states without any specific head on top.c                        e Zd Zdef fdZd Z eej                  d             e	e
eede      	 	 	 ddej                  dee   d	ee   d
ee   deeef   f
d              Z xZS )PvtModelrG   c                 r    t         |   |       || _        t        |      | _        | j                          y r0   )r1   r2   rG   r   encoder	post_initr3   rG   r4   s     r*   r2   zPvtModel.__init__  s1      "&) 	r,   c                     |j                         D ]7  \  }}| j                  j                  |   j                  j	                  |       9 y)z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  layerr   r   )r3   heads_to_pruner!  r   s       r*   _prune_headszPvtModel._prune_heads  sE    
 +002 	CLE5LLu%//;;EB	Cr,   %(batch_size, channels, height, width)vision)
checkpointoutput_typer  modalityexpected_outputru   r   r   r   r   c                 ,   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }| j	                  ||||      }|d   }|s	|f|dd  z   S t        ||j                  |j                        S )Nru   r   r   r   r   r   r   )rG   r   r   use_return_dictr  r   r5   r   )r3   ru   r   r   r   encoder_outputssequence_outputs          r*   r8   zPvtModel.forward  s     2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B],,%/!5#	 ' 
 *!,#%(;;;-)77&11
 	
r,   )NNN)r<   r=   r>   r   r2   r#  r   PVT_INPUTS_DOCSTRINGr:   r   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr#   r  r   r   r   r   r8   rC   rD   s   @r*   r  r    s    
y C ++?+F+FGn+op&#$. -1/3&*
''
 $D>
 'tn	

 d^
 
uo%	&
 q
r,   r  z
    Pvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                        e Zd Zdeddf fdZ eej                  d             ee	e
ee      	 	 	 	 ddeej                     deej                     d	ee   d
ee   dee   deee
f   fd              Z xZS )PvtForImageClassificationrG   r   Nc                 0   t         |   |       |j                  | _        t        |      | _        |j                  dkD  r-t        j                  |j                  d   |j                        nt        j                         | _	        | j                          y )Nr   rh   )r1   r2   
num_labelsr  r  r   r   r   r   
classifierr  r  s     r*   r2   z"PvtForImageClassification.__init__L  sy      ++F# FLEVEVYZEZBIIf))"-v/@/@A`b`k`k`m 	
 	r,   r$  )r&  r'  r  r)  ru   labelsr   r   r   c                 (   ||n| j                   j                  }| j                  ||||      }|d   }| j                  |dddddf         }d}	|| j                   j                  | j
                  dk(  rd| j                   _        nl| j
                  dkD  rL|j                  t        j                  k(  s|j                  t        j                  k(  rd| j                   _        nd| j                   _        | j                   j                  dk(  rIt               }
| j
                  dk(  r& |
|j                         |j                               }	n |
||      }	n| j                   j                  dk(  r=t               }
 |
|j                  d| j
                        |j                  d            }	n,| j                   j                  dk(  rt               }
 |
||      }	|s|f|dd z   }|	|	f|z   S |S t        |	||j                   |j"                  	      S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr+  r   r   
regressionsingle_label_classificationmulti_label_classificationrh   )losslogitsr5   r   )rG   r,  r  r7  problem_typer6  r   r#   longr   r
   squeezer	   r   r   r   r5   r   )r3   ru   r8  r   r   r   r   r.  r>  r=  loss_fctr)   s               r*   r8   z!PvtForImageClassification.forwardZ  s   * &1%<k$++B]B]((%/!5#	  
 "!*Aq!9:{{''/??a'/;DKK,__q(fllejj.HFLL\a\e\eLe/LDKK,/KDKK,{{''<7"9??a'#FNN$4fnn6FGD#FF3D))-JJ+-B @&++b/R))-II,./Y,F)-)9TGf$EvE$!//))	
 	
r,   )NNNN)r<   r=   r>   r   r2   r   r/  r:   r   _IMAGE_CLASS_CHECKPOINTr   r1  _IMAGE_CLASS_EXPECTED_OUTPUTr   r#   rA   r   r   r   r8   rC   rD   s   @r*   r4  r4  D  s    y T  ++?+F+FGn+op*)$4	 *.,0/3&*;
u||,;
 &;
 $D>	;

 'tn;
 d^;
 
u++	,;
 q;
r,   r4  )r   F)=r?   rU   r   typingr   r   r   r   r#   torch.nn.functionalr   r   rq   torch.utils.checkpointtorch.nnr   r	   r
   activationsr   modeling_outputsr   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   configuration_pvtr   
get_loggerr<   loggerr1  r0  r2  rC  rD  rA   r@   r   r+   Moduler.   rF   r   r   r   r   r   r   r  PVT_START_DOCSTRINGr/  r  r4  r   r,   r*   <module>rS     s  "    3 3      A A ! F - Q  ) 
		H	%- % 1 1 U\\ e T V[VbVb *-")) -A) A)H	BII 	O		 Od'299 'TRYY 6+ryy +\V
 V
r! !H	    Y7
! 7
	7
t  Q
 2 Q
Q
r,   