
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...utils import (
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal,
    is_torch_greater_or_equal,
    logging,
)
from ..gemma.modeling_gemma import (
    GemmaForCausalLM,
    GemmaForSequenceClassification,
    GemmaForTokenClassification,
    GemmaModel,
    GemmaPreTrainedModel,
    GemmaRMSNorm,
    GemmaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention


_CHECKPOINT_FOR_DOC = "google/gemma2-7b"

logger = logging.get_logger(__name__)


class Gemma2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma2-7B.
    e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Gemma2Model`].
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
            size of the sliding window.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.

    ```python
    >>> from transformers import Gemma2Model, Gemma2Config
    >>> # Initializing a Gemma2 gemma2-7b style configuration
    >>> configuration = Gemma2Config()
    >>> # Initializing a model from the gemma2-7b style configuration
    >>> model = Gemma2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
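    >>> # Illustrative only: the size arguments documented above can be overridden,
    >>> # e.g. to build a small configuration for quick tests
    >>> tiny_config = Gemma2Config(num_hidden_layers=2, hidden_size=128, intermediate_size=512, num_attention_heads=4, num_key_value_heads=2, head_dim=32)
    >>> tiny_model = Gemma2Model(tiny_config)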
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=2304,
        intermediate_size=9216,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        query_pre_attn_scalar=256,
        sliding_window=4096,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        cache_implementation="hybrid",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.cache_implementation = cache_implementation
__module____qualname____doc__
model_typekeys_to_ignore_at_inferencer(   __classcell__r?   s   @r@   r   r   <   sn    FP J#4"5 - $ ! $#%369 69rA   r   c                       e Zd Zy)Gemma2RMSNormNrJ   rK   rL   r&   rA   r@   rS   rS          rA   rS   c                   $     e Zd Z fdZd Z xZS )	Gemma2MLPc                    t         |           || _        |j                  | _        |j                  | _        t        j                  | j                  | j                  d      | _        t        j                  | j                  | j                  d      | _        t        j                  | j                  | j                  d      | _	        t        |j                     | _        y )NFbias)r'   r(   configr+   r,   nnLinear	gate_projup_proj	down_projr   r7   act_fnr=   r[   r?   s     r@   r(   zGemma2MLP.__init__   s    !--!'!9!94#3#3T5K5KRWXyy!1!143I3IPUV4#9#94;K;KRWXV556rA   c                     | j                  | j                  | j                  |            | j                  |      z        S N)r`   ra   r^   r_   )r=   xs     r@   forwardzGemma2MLP.forward   s0    ~~dkk$..*;<t||ANOOrA   )rJ   rK   rL   r(   rf   rP   rQ   s   @r@   rW   rW      s    7PrA   rW   c                       e Zd Zy)Gemma2RotaryEmbeddingNrT   r&   rA   r@   rh   rh      rU   rA   rh   r[   querykeyvaluemaskreturnc                    t        || j                        }t        || j                        }t        j                  ||j	                  dd            | j
                  z  }| j                  3|| j                  z  }t        j                  |      }|| j                  z  }|#|d d d d d d d |j                  d   f   }	||	z   }t        j                  j                  |dt        j                        j                  |j                        }t        j                  j                  || j                   | j"                        }t        j                  ||      }
|
j	                  dd      j%                         }
|
|fS )Nr   r   )dimdtype)ptrainingrG   )r   num_key_value_groupstorchmatmul	transposescalingr;   tanhshaper\   
functionalsoftmaxfloat32torr   dropoutr6   rt   
contiguous)r[   ri   rj   rk   rl   _kwargs
key_statesvalue_statesattn_weightscausal_maskattn_outputs              r@   eager_attention_forwardr      sM    3 ; ;<JUF$?$?@L<<z';';Aq'ABV^^SL$$0#f&C&CCzz,/#f&C&CC1a$:j&6&6r&:$::;#k1 ==((2U]](SVVW\WbWbcL==((9Q9Q\b\k\k(lL,,|\:K''1-88:K$$rA   target_dtypec                 R   |+|j                   d   }|d d d d d |f   }|d d d d d |f   }|j                  dd      }|j                  dd      }	|j                  dd      }
| j                  r| j                  nd}|j                  }|t
        j                  k(  r3|j                  |      }|	j                  |      }	|
j                  |      }
t        ||	|
||| j                  | j                  | j                  | j                  t        d      r| j                  nd       }|d fS )NrG   r   rH   z2.6.0)r   softmax_scale	is_causalr9   use_top_left_masksoftcap)r{   rx   rt   r6   rr   rv   r~   r   r   ry   r   r9   _flash_attn_uses_top_left_maskr   r;   )r[   ri   rj   rk   rl   r   r   seq_lenquery_statesr   r   dropout_rateinput_dtyper   s                 r@   flash_attention_forwardr      s$    **Q-aHWHn%aHWHn% ??1a(Lq!$J??1a(L/56++CL$$Kemm##|4]]<0
#|4*nn"",, ??1OPW1X--^bK rA   output_attentionsc           	            fd}t        ||||d j                  |      }|sd }	n|\  }}	|j                  dd      j                         }||	fS )Nc                 |    j                   }|t        j                  | |z        z  } | |   d   |   |   z   S | S )Nr   )r;   rv   rz   )scorebhq_idxkv_idxsoft_capr[   rl   s         r@   tanh_softcapz,flex_attention_forward.<locals>.tanh_softcap+  sN    005::eh&677471:e,V444rA   T)	score_mod
enable_gqascale
return_lserG   r   )r   ry   rx   r   )
r[   ri   rj   rk   rl   r   r   r   r   r   s
   `   `     r@   flex_attention_forwardr   "  sg     !nn$K $/!\''1-88:K$$rA   c           	      P   t        || j                        }t        || j                        }|}||d d d d d d d |j                  d   f   }|j                  j                  dk(  r2|0|j                         }|j                         }|j                         }||j                  d   dkD  rdnd}t        j                  j                  j                  ||||| j                  r| j                  nd|| j                        }|j                  dd      j                         }|d fS )	Nro   cudarG   TFrH   )	attn_mask	dropout_pr   r   r   )r   ru   r{   devicetyper   rv   r\   r|   scaled_dot_product_attentionrt   r6   ry   rx   )	r[   ri   rj   rk   rl   r   r   r   r   s	            r@   sdpa_attention_forwardr   D  s    C44
5CeV889EK!!Q?SYYr]?":; ||F"{'>  "nn  " $+A0BI((%%BB.4oo&**3nn C K ''1-88:KrA   )flash_attention_2r   eagersdpac                   ,    e Zd ZdZddedee   f fdZ	 	 	 	 	 	 ddej                  deej                     deej                     dee   d	ed
edeej                     deej                  eej                     eeej                        f   fdZ xZS )Gemma2Attentionz=Multi-headed attention from 'Attention Is All You Need' paperr[   	layer_idxc                 (   t         |           || _        || _        |j                  | _        |j
                  | _        |j                  | _        |j                  | _        |j                  | _	        | j                  | j                  z  | _
        |j                  | _        |j                  | _        d| _        |j                  dz  | _        t!        |dz        s|j"                  nd | _        |j$                  | _        | j
                  | j                  z  dk7  r&t'        d| j
                   d| j                   d      t)        j*                  | j
                  | j                  | j                  z  |j,                        | _        t)        j*                  | j
                  | j                  | j                  z  |j,                        | _        t)        j*                  | j
                  | j                  | j                  z  |j,                        | _        t)        j*                  | j                  | j                  z  | j
                  |j,                        | _        t7        | j                  | j                  | j                  	      | _        y )
NTg      r   r   z?hidden_size must be divisible by num_heads (got `hidden_size`: z and `num_heads`: z).rY   )r*   base)r'   r(   r[   r   r6   r+   r.   	num_headsr/   r0   ru   r*   r4   r   r8   ry   boolr9   r;   
ValueErrorr\   r]   r5   q_projk_projv_projo_projrh   
rotary_embr=   r[   r   r?   s      r@   r(   zGemma2Attention.__init__v  s   "!'!9!9!--33#)#=#= $(NNd6N6N$N!'-'E'E$ ++33T9;?	A;Nf33TX&,&C&C#dnn,1QRVRbRbQc$T^^$4B8 
 ii 0 0$..4==2PW]WlWlmii 0 0$2J2JT]]2Zagavavwii 0 0$2J2JT]]2Zagavavwii >@P@PW]WlWlm/MM$($@$@
rA   hidden_statesattention_maskposition_idspast_key_valuer   r3   cache_positionrm   c                    |j                         \  }}	}
| j                  |      }| j                  |      }| j                  |      }|j	                  ||	| j
                  | j                        j                  dd      }|j	                  ||	| j                  | j                        j                  dd      }|j	                  ||	| j                  | j                        j                  dd      }| j                  ||      \  }}t        ||||      \  }}|2||| j                  |d}|j                  ||| j                  |      \  }}|r0| j                  j                  dv rt         j#                  d       d}n| j                  j                  }t%        |   | |||||      \  }}|j'                  ||	d      j)                         }| j+                  |      }|sd }|||fS )	NrG   r   )sincosr9   r   )r   r   zMSetting `attention_type` to `flex_attention` because `output_attentions=True`r   )r   rp   )sizer   r   r   viewr   r/   rx   r0   r   r   r9   updater   r[   _attn_implementationloggerwarning_onceGEMMA2_ATTENTION_FUNCTIONreshaper   r   )r=   r   r   r   r   r   r3   r   bszq_len_r   r   r   r   r   cache_kwargsattention_typer   r   s                       r@   rf   zGemma2Attention.forward  s    &**,UA{{=1[[/
{{=1#((eT^^T]]S]]^_abc__S%1I1I4==Yccdeghi
#((eT5M5Mt}}]gghiklm??<>S#7jRUWZ#[ j% "&"5"5"0	L (6'<'<ZW[WeWegs't$J!A!AEb!b op-N![[==N$=n$M,
L.\m%
!\ "))#ub9DDFkk+. LL.88rA   rd   NNNFFN)rJ   rK   rL   rM   r   r   intr(   rv   Tensor
LongTensorr   r   r   rf   rP   rQ   s   @r@   r   r   s  s    G
| 
 
H 2637*."'5919||19 !.19 u//0	19
 !19  19 19 !!1!1219 
u||Xell3XeELL>Q5RR	S19rA   r   c                   0     e Zd Zddedee   f fdZ xZS )Gemma2FlashAttention2r[   r   c                 r    t         |   ||       d| j                  _        t        j                  d       y )Nr   The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GemmaAttention` class! It will be removed in v4.48r'   r(   r[   r   r   r   r   s      r@   r(   zGemma2FlashAttention2.__init__  s2    ++>(S	
rA   rd   rJ   rK   rL   r   r   r   r(   rP   rQ   s   @r@   r   r         
| 
 
 
rA   r   c                   0     e Zd Zddedee   f fdZ xZS )Gemma2SdpaAttentionr[   r   c                 r    t         |   ||       d| j                  _        t        j                  d       y )Nr   r   r   r   s      r@   r(   zGemma2SdpaAttention.__init__  s2    ++1(S	
rA   rd   r   rQ   s   @r@   r   r     r   rA   r   c                   (    e Zd Zdedef fdZ	 	 	 	 	 	 ddej                  deej                     deej                     dee
   dee   d	ee   d
eej                     deej                  eeej                  ej                  f      f   fdZ xZS )Gemma2DecoderLayerr[   r   c                    t         |           |j                  | _        || _        t	        |dz         | _        t        ||      | _        t        |      | _	        t        |j                  |j                        | _        t        |j                  |j                        | _        t        |j                  |j                        | _        t        |j                  |j                        | _        |j                   | _        y )Nr   )r[   r   )eps)r'   r(   r+   r[   r   
is_slidingr   	self_attnrW   mlprS   r2   input_layernormpost_attention_layernormpre_feedforward_layernormpost_feedforward_layernormr9   r   s      r@   r(   zGemma2DecoderLayer.__init__  s    !--"9q=11()LV$,V-?-?VEXEXY(5f6H6HfNaNa(b%)6v7I7IvObOb)c&*78J8JPVPcPc*d'$33rA   r   r   r   r   r   r3   r   rm   c           	         | j                   r|| j                  j                  dk(  r||d d | j                   d f   }nt	        j
                  |j                        j                  }t	        j                  t	        j                  |t        j                        | j                         }	t	        j                  |	||      }|j                  d   dk  r|d d d d d d | j                   d f   }|}
| j                  |      }| j                  |||||||      \  }}}| j                  |      }|
|z   }|}
| j!                  |      }| j#                  |      }| j%                  |      }|
|z   }|f}|r||fz  }|r||fz  }|S )Nr   rr   )diagonalrp   rG   )r   r   r   r   r   r3   r   )r   r[   r   r9   rv   finforr   mintril	ones_liker   wherer{   r   r   r   r   r   r   )r=   r   r   r   r   r   r3   r   	min_dtypesliding_window_maskresidualself_attn_weightspresent_key_valueoutputss                 r@   rf   zGemma2DecoderLayer.forward  s    ??~9{{//3FF!-%3A8K8K7K7M4M%NN!KK(;(;<@@	&+jjOON%**EQUQdQdPd'# "'-@)^!\!''+q0%3Aq!d>Q>Q=Q=S4S%TN ,,]; ?Cnn')%)/) ?M ?
;(*; 55mD =0 66}E/77F =0 ")++G)++GrA   r   )rJ   rK   rL   r   r   r(   rv   r   r   r   r   r   r   FloatTensorrf   rP   rQ   s   @r@   r   r     s    4| 4 4" 2637*.,1$)597||7 !.7 u//0	7
 !7 $D>7 D>7 !!1!127 
u  (51B1BEDUDU1U+V"WW	X7rA   r   c                   4     e Zd ZdZeddef fd       Z xZS )Gemma2PreTrainedModelFhard_check_onlyc                 Z    t         |   ||      }|s|j                  dk(  rd|_        |S )z
        Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
        SDPA reduces the model performance on Gemma2 because of the logits softcapping.
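
        Example (illustrative; assumes the standard `from_pretrained` API, where SDPA can still be
        requested explicitly even though it is no longer the silent default):

        ```python
        from transformers import Gemma2ForCausalLM

        model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b", attn_implementation="sdpa")
        ```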
        )r   r   r   )r'   _check_and_enable_sdpar   )clsr[   r   r?   s      r@   r   z,Gemma2PreTrainedModel._check_and_enable_sdpa+  s8     //X 6#>#>&#H*1F'rA   F)rJ   rK   rL   _supports_quantized_cacheclassmethodr   r   rP   rQ   s   @r@   r   r   (  s"     %T  rA   r   c                       e Zd Zdef fdZ	 	 	 	 	 	 	 	 	 	 ddej                  deej                     deej                     dee	   deej                     dee   d	ee   d
ee   dee   deej                     deeef   fdZ ej                          dej                  dej                  dej                  de	d	ef
d       Z xZS )Gemma2Modelr[   c           	          t         |   |       t        j                  t	        |j
                        D cg c]  }t        ||       c}      | _        | j                          y c c}w rd   )	r'   r(   r\   
ModuleListranger-   r   layers	post_initr   s      r@   r(   zGemma2Model.__init__;  sS     mmDI&JbJbDcdy	2d
 	 es   A'	input_idsr   r   r    inputs_embedsr3   r   output_hidden_statesreturn_dictr   rm   c                 6   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|	|	n| j                   j                  }	|d u |d uz  rt        d      | j                  r%| j                  r|rt        j                  d       d}|| j                  |      }|rL|J| j                  s>|j                  \  }}}t        | j                   ||| j                  |j                        }|
F||j                         nd}t!        j"                  |||j                  d   z   |j                        }
||
j%                  d      }| j'                  |||
||      }|}t!        j(                  | j                   j*                  dz  |j                  	      }||z  }|rd
nd }|rd
nd }| j,                  d | j                   j.                   D ]e  }|r||fz  }| j                  r/| j                  r#| j1                  |j2                  |||||||
      }n ||||||||
      }|d   }|s]||d   fz  }g | j5                  |      }|r||fz  }|r|nd }|	st7        d ||||fD              S t9        ||||      S )Nz:You must specify exactly one of input_ids or inputs_embedszX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.F)
batch_sizemax_cache_lenr   rr   r   rG   )r   g      ?r   r&   )r   r   r   r   r3   r   c              3   &   K   | ]	  }||  y wrd   r&   ).0vs     r@   	<genexpr>z&Gemma2Model.forward.<locals>.<genexpr>  s     tqfgfsts   )last_hidden_stater    r   
attentions)r[   r   r  r3   use_return_dictr   gradient_checkpointingrt   r   r   embed_tokensr{   r	   r   rr   get_seq_lengthrv   arange	unsqueeze_update_causal_masktensorr+   r  r-   _gradient_checkpointing_func__call__normtupler   )r=   r
  r   r   r    r  r3   r   r  r  r   r  r   r   past_seen_tokensr   r   
normalizerall_hidden_statesall_self_attnsdecoder_layerlayer_outputs
next_caches                          r@   rf   zGemma2Model.forwardB  s    2C1N-TXT_T_TqTq$8$D $++JjJj 	 "+!6IDKK<Q<Q	%0%<k$++B]B]-t";<YZZ&&4==Yj I  --i8M0%2%8%8"J)%%{{#))O !CRC^==?de"\\ "2]5H5H5K"KTaThThN )33A6L..M>?L]

 &
 \\$++"9"93">mFYFYZ
%
2 #7BD0d![[)H4;;+H+HI 	6M#!m%55!**t}} $ A A!**! #%"	! !.!#.!-#2&7'#1! *!,M =#3"55;	6> 		-0-!11(1_t
t]J@QSa$bttt&+&+%	
 	
rA   input_tensorc           
      V   | j                   j                  dk(  r|S |j                  |j                  }}|j                  d   }t        |t              r|j                         }	n ||j                  d   n|j                  d   }	| j                  |||	||||j                  d         }
|
S )Nr   rG   rp   r   sequence_lengthtarget_lengthrr   r   r   r  )	r[   r   rr   r   r{   
isinstancer	   get_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_position)r=   r   r*  r   r    r   rr   r   r-  r.  r   s              r@   r  zGemma2Model._update_causal_mask  s     ;;++/BB!!$**L,?,?v&,,Q/o{3+??AM8F8RN004XdXjXjklXmM PP+')#))!, Q 
 rA   )
NNNNNNNNNN)rJ   rK   rL   r   r(   rv   r   r   r   r	   r   r   r   r   r   rf   no_gradr  rP   rQ   s   @r@   r  r  :  sU   |  '+15371559$(,0/3&*59q
##q
 !.q
 u//0	q

 "+.q
   1 12q
 D>q
 $D>q
 'tnq
 d^q
 !!1!12q
 
u--	.q
f U]]_   ll  	 
 %      rA   r  c                   H    e Zd Z fdZ	 	 	 	 	 	 	 	 	 	 	 	 ddej
                  deej                     deej
                     dee   deej                     deej
                     dee
   d	ee
   d
ee
   dee
   deej
                     dedeeef   fdZ	 	 	 	 	 	 	 ddZ xZS )Gemma2ForCausalLMc                 d    t         |   |       t        |      | _        | j	                          y rd   r'   r(   r  modelr	  rb   s     r@   r(   zGemma2ForCausalLM.__init__  &      (
rA   r
  r   r   r    r  labelsr3   r   r  r  r   num_logits_to_keeprm   c                 D   | j                   rF| j                  j                  dk7  r-t        j	                  d| j                  j                   d       ||n| j                  j
                  }|	|	n| j                  j                  }	|
|
n| j                  j                  }
| j                  ||||||||	|
|
      }|d   }| j                  |dd| dddf         }| j                  j                  G|| j                  j                  z  }t        j                  |      }|| j                  j                  z  }d}| | j                  ||| j                  fi |}|
s|f|dd z   }||f|z   S |S t        |||j                   |j"                  |j$                        S )	an  
        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```r   zhIt is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `zp`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.N)
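
        # Gemma2-specific detail of this forward pass: when `config.final_logit_softcapping` is set,
        # the logits are softcapped as logits = tanh(logits / cap) * cap before the loss is computed.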
r
  r   r   r    r  r3   r   r  r  r   r   rG   )losslogitsr    r   r  )rt   r[   r   r   r   r   r  r  r7  lm_headr:   rv   rz   loss_functionr)   r   r    r   r  )r=   r
  r   r   r    r  r9  r3   r   r  r  r   r:  loss_kwargsr   r   r=  r<  outputs                      r@   rf   zGemma2ForCausalLM.forward  s   @ ==T[[==H#{{??@  Aqr 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]**)%+'/!5#)  
  
mA0B/B/CQ,FGH;;..:dkkAAAFZZ'FdkkAAAF%4%%ffdooUUDY,F'+'7D7V#CVC%#33!//))
 	
rA   c	           	         |D||d d |j                   d    d f   }n(|j                   d   |j                   d   k7  r	|d d |f   }|t|r|j                         j                  d      dz
  }|j                  |dk(  d       |r9|d d |j                   d    d f   }|j	                  t
        j                        }||d   dk(  r|d d}
n#|j	                  t
        j                        d d}
t        |t              r|j                  dk(  r| j                  j                  dk(  s|
d	   #|
d	   j                   \  }}}|
d	   j                  }n!|
d
   j                   \  }}|
d
   j                  }| j                  j                  |||j                         | j                   j"                  j$                  |||      }|||
d<   |
j'                  |||||d       |
S )Nr   rG   rp   )memory_format)r  r
  )r
  r  r   r   r  r
  r,  r:  )r   r   r    r3   r   )r{   longcumsummasked_fill_clonerv   contiguous_formatr/  r	   ndimr[   r   r   r7  r1  r0  r>  weightrr   r   )r=   r
  r    r   r  r   r   r3   r:  r>   model_inputsr  r-  r   r   s                  r@   prepare_inputs_for_generationz/Gemma2ForCausalLM.prepare_inputs_for_generation/  s   " &(%a.*>*>q*A)A)C&CD	#~';';A'>>%a&78	%,*>)..077;a?L%%n&91=+A	0B/B/D,DE  ,11@W@W1X $):a)?-:NL *3uG^G^)_rvwL 4##q(KK448KKO,81=o1N1T1T.
OQ%o6==.:;.G.M.M+
O%k299!ZZ]] /-AACll))//-% ^ N )1CL-. ,"0#2&"0	
 rA   )NNNNNNNNNNNr   )NNNNNTN)rJ   rK   rL   r(   rv   r   r   r   r	   r   r   r   r   r   r   rf   rL  rP   rQ   s   @r@   r4  r4    s:    '+15371559-1$(,0/3&*59"#N
##N
 !.N
 u//0	N

 "+.N
   1 12N
 ))*N
 D>N
 $D>N
 'tnN
 d^N
 !!1!12N
  N
 
u,,	-N
f LrA   r4  c                        e Zd Z fdZ xZS )Gemma2ForSequenceClassificationc                 d    t         |   |       t        |      | _        | j	                          y rd   r6  rb   s     r@   r(   z(Gemma2ForSequenceClassification.__init__  r8  rA   rJ   rK   rL   r(   rP   rQ   s   @r@   rN  rN  ~       rA   rN  c                        e Zd Z fdZ xZS )Gemma2ForTokenClassificationc                 d    t         |   |       t        |      | _        | j	                          y rd   r6  rb   s     r@   r(   z%Gemma2ForTokenClassification.__init__  r8  rA   rP  rQ   s   @r@   rS  rS    rQ  rA   rS  r   )@typingr   r   r   rv   torch.nnr\   torch.utils.checkpointactivationsr   cache_utilsr   r	   configuration_utilsr
   modeling_outputsr   r   utilsr   r   r   r   gemma.modeling_gemmar   r   r   r   r   r   r   r   r   modeling_flash_attention_utilsr   !torch.nn.attention.flex_attentionr   _CHECKPOINT_FOR_DOC
get_loggerrJ   r   r   rS   ModulerW   rh   r   r   float16rr   r   r   r   r   r   r   r   r   r   r   r  r4  rN  rS  r&   rA   r@   <module>rd     s    * )    ! - 3 
 
 
 JU#@ ) 			H	%B9# B9J	L 	P		 P	0 	%%<<% 
% <<	%
 5<<
 % 5<<%&%F !&**<<* 
* <<	*
 5<<
 * ++* 5<<*f $%%<<% 
% <<	%
 5<<
 % % 5<<%,,//0%D$$<<$ 
$ <<	$
 5<<
 $ 5<<$P 1,$"	 U9bii U9p
O 

/ 
F FR0 $\*3 \~b( bJ&D #> rA   