
"""Flax whisper model."""

import math
import random
from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor
from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
    FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig


logger = logging.get_logger(__name__)


_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"

remat = nn_partitioning.remat


def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jnp.ndarray:
    """Returns sinusoids for positional embedding"""
    length, channels = shape
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
    scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)


WHISPER_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.) This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
            inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
            and [`~FlaxPreTrainedModel.to_bf16`].
"""

WHISPER_INPUTS_DOCSTRING = r"""
    Args:
        input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
            Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
            tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
            is not used. By default the silence in the input log mel spectrogram is ignored.
        decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
            the starting token for `decoder_input_ids` generation.
        decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
            in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
            use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
            spectrogram is ignored.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
            Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
            is not used. By default the silence in the input log mel spectrogram is ignored.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

WHISPER_DECODE_INPUTS_DOCSTRING = r"""
    Args:
        decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids)
        encoder_outputs (`tuple(tuple(numpy.ndarray)`):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
            but it is not used. By default the silence in the input log mel spectrogram is ignored.
        decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
            in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                      e Zd ZU eed<   eed<   eed<   dZeed<   dZe	ed<   dZ
e	ed	<   ej                  Zej                  ed
<   ddZ	 	 	 	 ddej                  deej                     deej                     de	de	deej                     fdZdej                  fdZdej                  fdZej,                  deej                  ej                  ej                  f   fd       Zy)FlaxWhisperAttentionconfig	embed_dim	num_heads        dropoutFcausalTbiasr7   r$   Nc                    | j                   | j                  z  | _        | j                  | j                  z  | j                   k7  r&t        d| j                    d| j                   d      t	        t
        j                  | j                   | j                  t        j
                  j                  j                  | j                  j                              } || j                        | _         |d      | _         || j                        | _         || j                        | _        | j$                  r>t'        t)        j*                  d| j                  j,                  fd	      d	      | _        y y )
Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r7   kernel_init)use_biasFr"   boolr7   )rC   rD   head_dimr*   r   nnDenser7   jaxinitializersnormalrB   init_stdrH   q_projk_projv_projout_projrG   r
   r-   onesmax_target_positionscausal_mask)selfdenses     r=   setupzFlaxWhisperAttention.setup   s   $..8==4>>)T^^;MdnnM]$T^^$4B8 
 HHNN**++224;;3G3GH	
 TYY/U+TYY/tyy1;;/!T[[==>fMU[ D r?   hidden_stateskey_value_statesattention_mask
init_cachedeterministicc                    |d u}|j                   d   }| j                  |      }|r#| j                  |      }	| j                  |      }
n"| j                  |      }	| j                  |      }
| j	                  |      }| j	                  |	      }	| j	                  |
      }
| j
                  r|j                   d   |	j                   d   }}| j                  dd      r[| j                  d   d   }| j                  d   d   j                   d   }t        j                  | j                  dd|dfdd||f      }n| j                  d d d d d |d |f   }t        j                  ||f|j                   dd  z         }|N| j
                  rBt        j                  t        j                  |d      j                         }t        ||      }n(| j
                  r}n|t        j                  |d      }| j
                  r,| j                  dd      s|r| j                  |	|
||      \  }	}
}|t        j                   |dkD  t        j"                  |j                   d      j%                  | j&                        t        j"                  |j                   t        j(                  | j&                        j*                        j%                  | j&                              }nd }d }|s | j,                  dkD  r| j/                  d	      }t1        ||	||| j,                  d
|| j&                  d 	      }t        j2                  d||
      }| j5                  |      }| j7                  |      }||fS )Nr   r"   cache
cached_keycache_index)r(   rE   rF   T)rH   dropout_rngdropout_ratebroadcast_dropoutrd   r7   	precisionz...hqk,...khd->...qhd)r6   rV   rW   rX   _split_headsrG   has_variable	variablesr   dynamic_slicer\   r-   broadcast_toexpand_dimsr	   _concatenate_to_cacheselectfullr4   r7   finfominrF   make_rngr   einsum_merge_headsrY   )r]   r`   ra   rb   rc   rd   is_cross_attention
batch_sizequery_states
key_statesvalue_statesquery_length
key_length
mask_shiftmax_decoder_lengthr\   attention_biasrk   attn_weightsattn_outputs                       r=   __call__zFlaxWhisperAttention.__call__   s(    .T9"((+
{{=1%56J;;'78L]3J;;}5L((6&&z2
((6;;'3'9'9!'<j>N>Nq>Q*L  ,7!^^G4]C
%)^^G%<\%J%P%PQR%S"!//$$:q)<);< #..q!]l]KZK/OP**;HYHYZ[Z\H]8]^K %$++ --coonS[.\^i^o^opN*>;GN[[(N' __^(KN
 ;;D--g|D
7;7Q7QL,84Jn
 % ZZ"--s3::4::F--syy/D/H/HIPPQUQ[Q[\N "N!3--	2K4#"'**

 jj!8,U''4mmK0L((r?   c                 p    |j                  |j                  d d | j                  | j                  fz         S Nr&   )r0   r6   rD   rO   r]   hidden_states     r=   ro   z!FlaxWhisperAttention._split_heads:  s2    ##L$6$6r$:dnndmm=\$\]]r?   c                 Z    |j                  |j                  d d | j                  fz         S r   )r0   r6   rC   r   s     r=   r|   z!FlaxWhisperAttention._merge_heads=  s,    ##L$6$6r$:dnn=N$NOOr?   c                 (   | j                  dd      }| j                  ddt        j                  |j                  |j
                        }| j                  ddt        j                  |j                  |j
                        }| j                  ddd       }|r|j                  j                  ^ }	}
}}|j                  }dt        |	      z  |ddfz   }t        j                  |j                  ||      }t        j                  |j                  ||      }||_        ||_        |j                  d   }|j                  |z   |_        t        j                  t        j                  |
      ||z   k  t        |	      d||
fz         }t        ||      }|||fS )	Nrf   rg   cached_valuerh   c                  L    t        j                  dt         j                        S )Nr   rN   )r-   arrayint32 r?   r=   <lambda>z<FlaxWhisperAttention._concatenate_to_cache.<locals>.<lambda>F  s    CIIaWZW`W`Da r?   )r   r   r"   )rp   variabler-   zerosr6   r7   valuelenr   dynamic_update_slicers   r/   tupler	   )r]   r5   r   queryrb   is_initializedrg   r   rh   
batch_dims
max_lengthrD   depth_per_head	cur_indexindicesnum_updated_cache_vectorspad_masks                    r=   ru   z*FlaxWhisperAttention._concatenate_to_cache@  st    **7LA]]7L#))SYYPSPYPYZ
}}WnciiV[VaVabmmG]<abAKAQAQAWAW>ZY#))IS_,	1a/@@G**:+;+;S'JC,,\-?-?PE"J!&L(-A% + 1 14M MK ''

:&5N)NNj!Q(A:$NNH +8^DNE>))r?   r$   N)NNFT)__name__
__module____qualname__r#   __annotations__intrF   floatrG   rM   rH   r-   float32r7   r_   ndarrayr   r   r   ro   r|   rP   compactru   r   r?   r=   rA   rA      s   NNGUFDD${{E399"8 3704 "V){{V) #3;;/V) !-	V)
 V) V) 
s{{	V)p^CKK ^PCKK P ZZ*%PSP[P[]`]h]hjmjujuPuJv * *r?   rA   c                       e Zd ZU eed<   ej                  Zej                  ed<   ddZ	 	 ddej                  dej                  de
d	e
deej                     f
d
Zy)FlaxWhisperEncoderLayerrB   r7   r$   Nc                 L   | j                   j                  | _        t        | j                   | j                  | j                   j                  | j                   j
                  | j                        | _        t        j                  | j                  d      | _
        t        j                  | j                   j                        | _        t        | j                   j                     | _        t        j                  | j                   j"                        | _        t        j&                  | j                   j(                  | j                  t*        j                  j,                  j/                  | j                   j0                              | _        t        j&                  | j                  | j                  t*        j                  j,                  j/                  | j                   j0                              | _        t        j                  | j                  d      | _        y )NrB   rC   rD   rF   r7   h㈵>r7   epsilonraterJ   )rB   d_modelrC   rA   encoder_attention_headsattention_dropoutr7   	self_attnrP   	LayerNormself_attn_layer_normDropoutrF   dropout_layerr   activation_functionactivation_fnactivation_dropoutactivation_dropout_layerrQ   encoder_ffn_dimrR   rS   rT   rU   fc1fc2final_layer_normr]   s    r=   r_   zFlaxWhisperEncoderLayer.setupd  sV   ,,-;;nnkk99KK11**
 %'LLtzz5$Q!ZZT[[-@-@A#DKK$C$CD(*

8V8V(W%88KK''**++224;;3G3GH

 88NN$**#&&:M:M:T:TUYU`U`UiUi:j
 !#4::u Mr?   r`   rb   output_attentionsrd   c                 |   |}| j                  |      }| j                  ||      \  }}| j                  ||      }||z   }|}| j                  |      }| j	                  | j                  |            }| j                  ||      }| j                  |      }| j                  ||      }||z   }|f}|r||fz  }|S )N)r`   rb   rd   )r   r   r   r   r   r   r   r   )r]   r`   rb   r   rd   residualr   outputss           r=   r   z FlaxWhisperEncoderLayer.__call__{  s     !11-@&*nn=aon&p#|**=*V =0 --m<**488M+BC55mS`5a/**=*V =0 "&Gr?   r   )TT)r   r   r   r#   r   r-   r   r7   r_   r   rM   r   r   r   r?   r=   r   r   `  sn    {{E399"N6 #'"{{   	
  
s{{	r?   r   c            	           e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   d Z
	 	 	 	 dde	de	de	d	e	fd
Zy)!FlaxWhisperEncoderLayerCollectionrB   r7   Fgradient_checkpointingc           	         | j                   rjt        t        d      }t        | j                  j
                        D cg c]*  } || j                  t        |      | j                        , c}| _        n[t        | j                  j
                        D cg c]-  }t        | j                  t        |      | j                        / c}| _        | j                  j                  | _
        y c c}w c c}w )N)r&   r   static_argnumsnamer7   )r   rematr   rangerB   encoder_layersstrr7   layersencoder_layerdrop	layerdrop)r]   !FlaxWhisperEncoderCheckpointLayeris      r=   r_   z'FlaxWhisperEncoderLayerCollection.setup  s    &&056M^d0e- t{{99: 2$++CFRVR\R\]DK t{{99: (#a&

SDK 66
   /C.2C3rd   r   output_hidden_statesreturn_dictc                 6   |rdnd }|rdnd }| j                   D ]P  }	|r||fz   }t        j                  dd      }
|s|
| j                  k  rd}n |	||||      }|d   }|sH||d   fz   }R |r||fz  }|||f}|st	        d |D              S t        |||      S )Nr   r   r"   )NNc              3   &   K   | ]	  }||  y wNr   .0vs     r=   	<genexpr>z=FlaxWhisperEncoderLayerCollection.__call__.<locals>.<genexpr>       =qq}=   last_hidden_stater`   
attentions)r   randomuniformr   r   r   )r]   r`   rb   rd   r   r   r   all_attentionsall_hidden_statesencoder_layerdropout_probabilitylayer_outputsr   s                r=   r   z*FlaxWhisperEncoderLayerCollection.__call__  s      1d"6BD![[ 	FM#$58H$H!"(..A"6 &9DNN&J , -!"%!	! *!,M !/=3C2E!E!	F$  -!11 "3^D=G==="+;LYg
 	
r?   N)TFFTr   r   r   r#   r   r-   r   r7   r   rM   r_   r   r   r?   r=   r   r     se    {{E399"#(D(7$ #"'%* (
 	(

  (
 #(
 (
r?   r   c                       e Zd ZU eed<   ej                  Zej                  ed<   ddZ	 	 	 	 	 ddej                  dej                  de
ej                     d	e
ej                     d
edededeej                     fdZy)FlaxWhisperDecoderLayerrB   r7   r$   Nc                 P   | j                   j                  | _        t        | j                   | j                  | j                   j                  | j                   j
                  d| j                        | _        t        j                  | j                   j                        | _        t        | j                   j                     | _        t        j                  | j                   j                        | _        t        j"                  | j                  d      | _        t        | j                   | j                  | j                   j                  | j                   j
                  | j                        | _        t        j"                  | j                  d      | _        t        j*                  | j                   j,                  | j                  t.        j                  j0                  j3                  | j                   j4                              | _        t        j*                  | j                  | j                  t.        j                  j0                  j3                  | j                   j4                              | _        t        j"                  | j                  d      | _        y )NT)rB   rC   rD   rF   rG   r7   r   r   r   r   rJ   )rB   r   rC   rA   decoder_attention_headsr   r7   r   rP   r   rF   r   r   r   r   r   r   r   r   encoder_attnencoder_attn_layer_normrQ   decoder_ffn_dimrR   rS   rT   rU   r   r   r   r   s    r=   r_   zFlaxWhisperDecoderLayer.setup  s   ,,-;;nnkk99KK11**
  ZZT[[-@-@A#DKK$C$CD(*

8V8V(W%$&LLtzz5$Q!0;;nnkk99KK11**
 (*||$**e'T$88KK''**++224;;3G3GH

 88NN$**#&&:M:M:T:TUYU`U`UiUi:j
 !#4::u Mr?   r`   rb   encoder_hidden_statesencoder_attention_maskrc   r   rd   c                    |}| j                  |      }| j                  |||      \  }}	| j                  ||      }||z   }d }
|B|}| j                  |      }| j	                  |||      \  }}
| j                  ||      }||z   }|}| j                  |      }| j                  | j                  |            }| j                  ||      }| j                  |      }| j                  ||      }||z   }|f}|r||	|
fz  }|S )N)r`   rb   rc   r   )r`   ra   rb   )
r   r   r   r   r   r   r   r   r   r   )r]   r`   rb   r   r   rc   r   rd   r   self_attn_weightscross_attn_weightsr   s               r=   r   z FlaxWhisperDecoderLayer.__call__  sZ    !11-@ ,0>>'S] ,: ,
(( **=*V =0 " ,$H 88GM040A0A+!65 1B 1-M-
 !..}M.ZM$}4M !--m<**488M+BC55mS`5a/**=*V =0 ")+=>>Gr?   r   )NNFTT)r   r   r   r#   r   r-   r   r7   r_   r   r   rM   r   r   r   r?   r=   r   r     s    {{E399"NJ 8<8< "&"0{{0 0  (4	0
 !) 50 0  0 0 
s{{	0r?   r   c                       e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   d Z
	 	 	 	 	 	 	 ddeej                     deej                     d	e	d
e	de	de	de	fdZy)!FlaxWhisperDecoderLayerCollectionrB   r7   Fr   c           	         | j                   rjt        t        d      }t        | j                  j
                        D cg c]*  } || j                  t        |      | j                        , c}| _        n[t        | j                  j
                        D cg c]-  }t        | j                  t        |      | j                        / c}| _        | j                  j                  | _
        y c c}w c c}w )N)         r   r   )r   r   r   r   rB   decoder_layersr   r7   r   decoder_layerdropr   )r]   !FlaxWhisperDecoderCheckpointLayerr   s      r=   r_   z'FlaxWhisperDecoderLayerCollection.setup4  s    &&056M^g0h- t{{99: 2$++CFRVR\R\]DK t{{99: (#a&

SDK 66
r   Nr   r   rd   rc   r   r   r   c
           
      h   |rdnd }
|rdnd }|r|dnd }| j                   D ]_  }|r|
|fz  }
t        j                  dd      }|s|| j                  k  rd}n ||||||||      }|d   }|sK||d   fz  }|W||d   fz  }a |r|
|fz  }
||
||g}|	st	        d |D              S t        ||
||      S )Nr   r   r"   NNNr&   c              3   &   K   | ]	  }||  y wr   r   r   s     r=   r   z=FlaxWhisperDecoderLayerCollection.__call__.<locals>.<genexpr>s  r   r   r   r`   r   cross_attentions)r   r   r   r   r   r   )r]   r`   rb   r   r   rd   rc   r   r   r   r   all_self_attnsall_cross_attentionsdecoder_layerr   r   r   s                    r=   r   z*FlaxWhisperDecoderLayerCollection.__call__B  s    #7BD0d&7<Q<]rdh![[ 	@M#!m%55!"(..A"6 &9DNN&J 2 -!")*%!! *!,M =#3"55(4(]1-=,??(/	@4  -!11 "3^EYZ=G===<++%1	
 	
r?   )NNTFFFT)r   r   r   r#   r   r-   r   r7   r   rM   r_   r   r   r   r   r?   r=   r  r  /  s    {{E399"#(D(7$ 8<8<" "'%* 8
  (4	8

 !) 58
 8
 8
  8
 #8
 8
r?   r  c                       e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   ddZ
	 	 	 	 ddej                  d	e	d
e	de	de	deej                     fdZy)FlaxWhisperEncoderrB   r7   Fr   r$   Nc           	         t        j                  | j                  j                  ddt        j                   j
                  j                  | j                  j                        | j                        | _	        t        j                  | j                  j                  dddt        j                   j
                  j                  | j                  j                        | j                        | _
        t        j                  | j                  j                        | _        t        | j                  | j                  | j                        | _        t        j"                  | j                  j$                  | j                  j                  | j                  t&              | _        t        j*                  | j                  d	
      | _        y )N)r   r"   )kernel_sizepaddingrK   r7   r&   )r  stridesr  rK   r7   r   r7   r   )r7   embedding_initr   r   )rP   ConvrB   r   rR   rS   rT   rU   r7   conv1conv2r   rF   r   r   r   r   Embedmax_source_positionsr>   embed_positionsr   
layer_normr   s    r=   r_   zFlaxWhisperEncoder.setup  s.   WWKK++224;;3G3GH**

 WWKK++224;;3G3GH**

  ZZT[[-@-@A7KK**#'#>#>
  "xxKK,,KK**4	 
 ,,TZZGr?   input_featuresr   r   r   rd   c           	         |j                   dd  | j                  j                  | j                  j                  dz  fk7  rMt	        d|j                   dd   d| j                  j                   d| j                  j                  dz   d      |j                  ddd      }t        j                  j                  | j                  |      d	      }t        j                  j                  | j                  |      d	      }| j                  t        j                  | j                  j                              }t        j                  j                  |      }||z   }| j!                  ||
      }| j#                  |d ||||      }|d   }	| j%                  |	      }	d }|r|d   }|d d |	fz   }|s#|	|f|r|dd  n|dd  z   }t'        d |D              S t)        |	||j*                        S )Nr"   r&   zqinput_features.shape[1:], must be equal to (self.config.num_mel_bins, self.config.max_source_positions * 2) (got z, but should be (z, z))r   F)approximater   )rb   rd   r   r   r   r'   c              3   &   K   | ]	  }||  y wr   r   r   s     r=   r   z.FlaxWhisperEncoder.__call__.<locals>.<genexpr>  r   r   r   )r6   rB   num_mel_binsr  r*   	transposerR   rP   gelur  r  r   r-   r/   r   stop_gradientr   r   r!  r   r   r   )
r]   r"  r   r   r   rd   r`   r   r   last_hidden_statess
             r=   r   zFlaxWhisperEncoder.__call__  s    #(@(@$++BbBbefBf'gg??M?S?STUTV?W>X Y[[--.b1Q1QTU1U0VVXZ  (11!Q:DJJ~$>ERDJJ}$=5Q..szz$++:Z:Z/[\''//@%7**=*V++'/!5#  
 %QZ!__-?@ #AJM)#2.2D1FFM)=9L`WQR[fmnonpfqrG=G==="0'))
 	
r?   r   FFTT)r   r   r   r#   r   r-   r   r7   r   rM   r_   r   r   r   r   r?   r=   r  r  }  s    {{E399"#(D( HJ #(%* "4
4
  4
 #	4

 4
 4
 
s{{	4
r?   r  c                      e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   ddZ
	 	 	 	 	 	 ddej                  d	ej                  d
ej                  deej                     de	de	de	de	de	deej                     fdZy)FlaxWhisperDecoderrB   r7   Fr   r$   Nc                 L   t        j                  | j                  j                  | j                  j                  | j
                        | _        t        j                  | j                  j                  | j                  j                  | j
                        | _        t        | j                  | j
                  | j                        | _        t        j                  | j                  j                        | _        t        j                  | j
                  d      | _        y )NrN   r  r   r   r   )rP   r  rB   
vocab_sizer   r7   embed_tokensr[   r   r  r   r   r   rF   r   r   r!  r   s    r=   r_   zFlaxWhisperDecoder.setup  s    HHT[[%;%;T[[=P=PX\XbXbc!xx(H(H$++J]J]eieoeop7KKtzz$B]B]
  ZZT[[-@-@A,,TZZFr?   	input_idsrb   position_idsr   rc   r   r   r   rd   c
           
         | j                  |      }
| j                  |      }|
|z   }| j                  ||	      }| j                  ||||	||||      }|d   }| j	                  |      }d }|r|d   }|d d |fz   }|s#||f|r|dd  n|dd  z   }t        d |D              S t        |||j                  |j                        S )	Nr   )rb   r   rd   rc   r   r   r   r   r"   r'   r&   c              3   &   K   | ]	  }||  y wr   r   r   s     r=   r   z.FlaxWhisperDecoder.__call__.<locals>.<genexpr>  r   r   r  )	r0  r   r   r   r!  r   r   r   r  )r]   r1  rb   r2  r   rc   r   r   r   rd   input_embedsposition_embedsr`   r   r*  s                  r=   r   zFlaxWhisperDecoder.__call__  s    ((3..|<$6**=*V++)"7'!/!5#  	
 %QZ!__-?@ #AJM)#2.2D1FFM)=9L`WQR[fmnonpfqrG=G===<0'))$55	
 	
r?   r   )NFFFTT)r   r   r   r#   r   r-   r   r7   r   rM   r_   r   r   r   r   r   r?   r=   r-  r-    s    {{E399"#(D(
G" 8< "'%* "/
;;/
 /
 kk	/

  (4/
 /
  /
 #/
 /
 /
 
s{{	/
r?   r-  c                       e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   ddZ
	 	 	 	 ddej                  dej                  d	ej                  d
ej                  de	de	de	de	fdZd Zd Zy)FlaxWhisperModulerB   r7   Fr   Nc                     t        | j                  | j                  | j                        | _        t        | j                  | j                  | j                        | _        y )Nr  )r  rB   r7   r   encoderr-  decoderr   s    r=   r_   zFlaxWhisperModule.setup#  sF    )KKtzz$B]B]
 *KKtzz$B]B]
r?   r"  decoder_input_idsdecoder_attention_maskdecoder_position_idsr   r   r   rd   c	           
      $   | j                  |||||      }	| j                  ||||	d   ||||      }
|s|
|	z   S t        |
j                  |
j                  |
j
                  |
j                  |	j                  |	j                  |	j
                        S )N)r   r   r   rd   r   )r1  rb   r2  r   r   r   r   rd   )r   decoder_hidden_statesdecoder_attentionsr  encoder_last_hidden_stater   encoder_attentions)r:  r;  r   r   r`   r   r  )r]   r"  r<  r=  r>  r   r   r   rd   encoder_outputsdecoder_outputss              r=   r   zFlaxWhisperModule.__call__+  s     ,,/!5#' ' 
 ,,'1-"1!"4/!5#' ' 	
 "_44%-??"1"?"?.99,==&5&G&G"1"?"?.99
 	
r?   c                     | j                   S r   )r:  r   s    r=   _get_encoder_modulez%FlaxWhisperModule._get_encoder_moduleV      ||r?   c                     | j                   S r   )r;  r   s    r=   _get_decoder_modulez%FlaxWhisperModule._get_decoder_moduleY  rH  r?   r   r+  )r   r   r   r#   r   r-   r   r7   r   rM   r_   r   r   rG  rJ  r   r?   r=   r8  r8    s    {{E399"#(D(
 #(%* ")
)
 ;;)
 !$	)

 "kk)
  )
 #)
 )
 )
Vr?   r8  c                   ~    e Zd ZU eZdZeed<   dZdZ	e
j                  ed<   ddej                  ddfd	ed
ee   dedej                   dedef fdZd Zd&dej*                  j,                  d
ededefdZd Z ee       eee      	 	 	 	 	 	 	 d'dej<                  deej<                     dee   dee   dee   dede defd              Z! ee"       ee#e      	 	 	 	 	 	 	 	 	 	 d(deej<                     deej<                     d eej<                     d!e dee   dee   dee   dede defd"              Z$ e%e&      	 	 	 	 	 	 	 	 	 	 d(dej<                  d#ej<                  deej<                     deej<                     d$eej<                     d eej<                     dee   dee   dee   dede defd%       Z' xZ(S ))FlaxWhisperPreTrainedModelmodelbase_model_prefixr"  Nmodule_classr   TFrB   input_shapeseedr7   _do_initr   c                      | j                   d|||d|}|d|j                  d|j                  z  f}t        	|   ||||||       y )NrB   r7   r   r"   r&   )rP  rQ  r7   rR  r   )rO  r&  r  super__init__)
r]   rB   rP  rQ  r7   rR  r   kwargsmodule	__class__s
            r=   rV  z#FlaxWhisperPreTrainedModel.__init__c  sd     #""w&Vlwpvwf111v7R7R3RSK[tSXcklr?   c                 ^    | j                  | j                  | j                  d      | _        y )NTrT  )rO  rB   r7   _moduler   s    r=   enable_gradient_checkpointingz8FlaxWhisperPreTrainedModel.enable_gradient_checkpointingr  s*    ((;;**#' ) 
r?   rngparamsr$   c                    t        j                  |d      }|j                  d   j                  | j                  j
                        }t        j                  |d   dfd      }t        j                  |      }|j                  \  }}t        j                  t        j                  |      d d d f   ||f      }	t        j                  j                  |      \  }
}|
|d}| j                  j                  |||||	      d	   }|dt        t!        |            }t        t!        |            }| j"                  D ]
  }||   ||<    t               | _        t%        t'        |            S |S )
Nf4rN   .r'   r   r"   i4r^  rF   )r"  r<  r=  r>  r^  )r-   r   atsetrB   eos_token_id	ones_liker6   rs   r/   rR   r   splitrX  initr   r   _missing_keysr   r   )r]   r]  rP  r^  r"  r<  r=  r~   sequence_lengthr>  
params_rngrk   rngsrandom_paramsmissing_keys                  r=   init_weightsz'FlaxWhisperPreTrainedModel.init_weightsy  s_   ;d;'**9599$++:R:RSII{1~q&9F!$/@!A&7&=&=#
O"//

?0KDRSG0TWacrVst"%**"2"23"7
K$=(()/#9!5 ) 
  (-)@AM!(6"23F#11 A&3K&@{#A!$D.011  r?   c           	         t        j                  ||fd      }t        j                  |      }t        j                  t        j                  t        j
                  |      j                  d         |j                        }d }| j                  j                  t        j                  j                  d      ||||d   d|      }t        |d         S )	a+  
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        rb  rN   r'   c                 8    | j                         } ||||fi |S r   rJ  rX  r<  r=  r>  rW  decoder_modules         r=   _decoder_forwardz?FlaxWhisperPreTrainedModel.init_cache.<locals>._decoder_forward  s0    #779N!!&$ 	 r?   r   T)r<  r=  r>  r   rc   methodrf   )r-   rZ   rg  rs   r/   
atleast_2dr6   rX  ri  rR   r   r   r   )	r]   r~   r   rD  r<  r=  r>  rv  init_variabless	            r=   rc   z%FlaxWhisperPreTrainedModel.init_cache  s      HHj*%=TJ!$/@!A"//JJs~~&78>>rBCEVE\E\ 
	 ))JJq!/#9!5"1!"4# * 
 w/00r?   output_typeconfig_classrb   r   r   r   trainrk   c	           
      H   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }i }
|||
d<   d }| j                  j                  d|xs | j                  it        j                  |d      |||| |
|      S )a  
        Returns:

        Example:

        ```python
        >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
        >>> input_features = inputs.input_features
        >>> encoder_outputs = model.encode(input_features=input_features)
        ```rF   c                 4    | j                         } ||fi |S r   )rG  )rX  r"  rW  encode_modules       r=   _encoder_forwardz;FlaxWhisperPreTrainedModel.encode.<locals>._encoder_forward  s     "668M :6::r?   r^  r`  rN   )r"  r   r   r   rd   rm  rw  	rB   r   r   r   rX  applyr^  r-   r   )r]   r"  rb   r   r   r   r}  r^  rk   rW  rm  r  s               r=   encodez!FlaxWhisperPreTrainedModel.encode  s    < 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++BYBY ")DO	; {{  v,-99^4@/!5##)# ! 	
 		
r?   r   r=  r>  past_key_valuesc                 V   ||n| j                   j                  }||n| j                   j                  }|	|	n| j                   j                  }	|d   }|j                  \  }}|Y|t        d      ||j                  d      |z  dz
  }n2t        j                  t        j                  |      dddf   ||f      }|t        j                  ||f      }i }|||d<   d|xs | j                  i}|r	||d<   dg}nd	}d
 }| j                  j                  |t        j                  |d      t        j                  |d      t        j                  |d      ||||	|
 |||      }||	r|\  }}t        |d         |d<   |S |"|	s |\  }}|dd t        |d         fz   |dd z   }|S )a  
        Returns:

        Example:

        ```python
        >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
        >>> from datasets import load_dataset
        >>> import jax.numpy as jnp

        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features

        >>> encoder_outputs = model.encode(input_features=input_features)
        >>> decoder_start_token_id = model.config.decoder_start_token_id

        >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> last_decoder_hidden_states = outputs.last_hidden_state
        ```Nr   KMake sure to provide `decoder_position_ids` when passing `past_key_values`.r'   r"   rF   r^  rf   Fc                 :    | j                         } |d|||d|S )Nr1  rb   r2  r   rs  rt  s         r=   rv  z;FlaxWhisperPreTrainedModel.decode.<locals>._decoder_forwardK  s5    #779N! +51 	 r?   rb  rN   r<  r=  r>  r   r   r   r   rd   rm  mutablerw  r  )rB   r   r   r   r6   r*   cumsumr-   rs   r/   rZ   r^  rX  r  r   r   )r]   r<  rD  r   r=  r>  r  r   r   r   r}  r^  rk   r   r~   rk  rm  inputsr  rv  r   pasts                         r=   decodez!FlaxWhisperPreTrainedModel.decode  s   R 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++BYBY / 2&7&=&=#
O'* !noo%1(>(E(Eb(ILb(bfg'g$'*'7'7JJ/a8::W($ ")%(XXz?.K%L" ")DOF1dkk2
 -F7OiGG	 ++##!ii(9F#&99-C4#P!$+?t!L"7/!5##)# $ 
  &;#MGT)1$w-)@G%&N(#MGTbqkXd7m%<$>>LGr?   r<  r2  c                    ||n| j                   j                  }||n| j                   j                  }|	|	n| j                   j                  }	|[||j	                  d      |z  dz
  }nA|j
                  \  }}t        j                  t        j                  |      d d d f   ||f      }|t        j                  |      }|d|ini }| j                  j                  d|xs | j                  it        j                  |d      t        j                  |d      t        j                  |d      t        j                  |d      |||	|
 |
      S )	Nr'   r"   rF   r^  r`  rN   rb  )	r"  r<  r=  r>  r   r   r   rd   rm  )rB   r   r   r   r  r6   r-   rs   r/   rg  rX  r  r^  r   )r]   r"  r<  rb   r=  r2  r>  r   r   r   r}  r^  rk   r~   rk  rm  s                   r=   r   z#FlaxWhisperPreTrainedModel.__call__n  s^     2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++BYBY  '%1(>(E(Eb(ILb(bfg'g$.?.E.E+
O'*'7'7JJ/a8::W($ ")%(]]3D%E" ,7+B	;'{{  v,-99^4@!ii(9F#&99-C4#P!$+?t!L/!5##) ! 
 	
r?   r   NNNNFNN
NNNNNNNFNN))r   r   r   r#   r|  rN  r   r   main_input_namerO  rP   Moduler-   r   r   r   r7   rM   rV  r\  rR   r   r   r   rp  rc   r   WHISPER_ENCODE_INPUTS_DOCSTRINGr!   r   r   r   dictr  WHISPER_DECODE_INPUTS_DOCSTRINGr   r  r   WHISPER_INPUTS_DOCSTRINGr   __classcell__rY  s   @r=   rL  rL  ]  s9    L$s$&O"L"))"
 #';;',mm 3Zm 	m
 yym m !%m
!

 2 2 ! !PZ !fp !B'1R 9:+>][ 15,0/3&*#4
4
 !-4
 $D>	4

 'tn4
 d^4
 4
 4
 4
 \ ;4
l 9:+Xgtu
 9=8<6: $,0/3&*#o !) 5	o
 !) 5o 's{{3o o $D>o 'tno d^o o o o v ;ob ++CD
 158<.26:,0/3&*#/
/
 ;;/
 !-	/

 !) 5/
 s{{+/
 's{{3/
 $D>/
 'tn/
 d^/
 /
 /
 /
 E/
r?   rL  zaThe bare Whisper Model transformer outputting raw hidden-states without any specific head on top.c                   R    e Zd ZU eed<   ej                  Zej                  ed<   eZ	y)FlaxWhisperModelrB   r7   N)
r   r   r   r#   r   r-   r   r7   r8  rO  r   r?   r=   r  r    s!    
 {{E399"$Lr?   r  c                       e Zd ZU eed<   ej                  Zej                  ed<   dZe	ed<   ddZ
d Zd Z	 	 	 	 	 	 	 	 dd	ej                  d
ej                  dej                  dej                  de	de	de	de	fdZy))FlaxWhisperForConditionalGenerationModulerB   r7   Fr   Nc                 T   t        | j                  | j                  | j                        | _        t        j                  | j                  j                  d| j                  t        j
                  j                  j                  | j                  j                              | _        y )NrT  F)rL   r7   rK   )r8  rB   r7   r   rM  rP   rQ   r/  rR   rS   rT   rU   lm_headr   s    r=   r_   z/FlaxWhisperForConditionalGenerationModule.setup  sn    &;;djjIdId

 xxKK""**++224;;3G3GH	
r?   c                 .    | j                   j                  S r   )rM  r:  r   s    r=   rG  z=FlaxWhisperForConditionalGenerationModule._get_encoder_module      zz!!!r?   c                 .    | j                   j                  S r   )rM  r;  r   s    r=   rJ  z=FlaxWhisperForConditionalGenerationModule._get_decoder_module  r  r?   r=  r>  r2  rb   r   r   r   rd   c           
         | j                  |||||||	|
      }|d   }| j                  j                  r[| j                   j                  j                  j
                  d   d   }| j                  j                  dd|j                  ii|      }n| j                  |      }|	s|f|dd  z   }|S t        ||j                  |j                  |j                  |j                  |j                  |j                        S )N)r"  r<  r=  r>  r   r   r   rd   r   r^  	embeddingkernelr"   )logitsr@  rA  r  rB  r   rC  )rM  rB   tie_word_embeddingsr;  r0  rq   r  r  Tr   r@  rA  r  rB  r   rC  )r]   r"  r<  r=  r>  r2  rb   r   r   r   rd   r   r`   shared_embedding	lm_logitsoutputs                   r=   r   z2FlaxWhisperForConditionalGenerationModule.__call__  s    **)/#9!5/!5#'  	
  
;;**#zz11>>HHRS^_**HxAQASAS6T+UWdeI]3I\GABK/FM"")"?"?&99$55&-&G&G")"?"?&99
 	
r?   r   )NNNNFFTT)r   r   r   r#   r   r-   r   r7   r   rM   r_   rG  rJ  r   r   r   r?   r=   r  r    s    {{E399"#(D(	
"" /3,0$(&*"'%* ",
 !$	,

 "kk,
 kk,
 ,
  ,
 #,
 ,
 ,
r?   r  z0The Whisper Model with a language modeling head.c                       e Zd ZU eZej                  Zej                  ed<    e	e
       eee      	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     dedee   dee   d	ee   d
ededefd              Z	 	 	 	 	 	 d fd	Z	 	 	 ddeej,                     deej,                     fdZd Z xZS )#FlaxWhisperForConditionalGenerationr7   rz  r   r=  r>  r  r   r   r   r}  r^  rk   c                     ||n j                   j                  }||n j                   j                  }|	|	n j                   j                  }	|d   }|j                  \  }}|Y|t        d      ||j                  d      |z  dz
  }n2t        j                  t        j                  |      dddf   ||f      }|t        j                  ||fd      }i }|||d<   d	|xs  j                  i}|r	||d
<   d
g}nd} fd} j                  j                  |t        j                  |d      t        j                  |d      t        j                  |d      ||||	|
 |||      }||\  }}n|\  \  }}}|	r.t        ||j                   |j"                  |j$                        }n	|f|dd z   }||	rt'        d
         |d<   |S ||	s|dd t'        d
         fz   |dd z   }|S )a  
        Returns:

        Example:

        ```python
        >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
        >>> input_features = inputs.input_features
        >>> encoder_outputs = model.encode(input_features=input_features)
        >>> decoder_start_token_id = model.config.decoder_start_token_id

        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> last_decoder_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            if decoder_attention_mask is not None:
                decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
            else:
                decoder_position_ids = jnp.broadcast_to(
                    jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
                )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized; it has to be marked as mutable so that
        # it can be updated by the attention modules during decoding.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to the model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def generate(
        self,
        input_features,
        generation_config=None,
        logits_processor=None,
        return_timestamps=None,
        task=None,
        language=None,
        is_multilingual=None,
        **kwargs,
    ):
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            generation_config.return_timestamps = return_timestamps

        if task is not None:
            generation_config.task = task

        if is_multilingual is not None:
            generation_config.is_multilingual = is_multilingual

        if language is not None:
            generation_config.language = language

        if kwargs is not None and "decoder_input_ids" in kwargs:
            decoder_input_length = len(kwargs["decoder_input_ids"])
        else:
            decoder_input_length = 1

        forced_decoder_ids = []

        if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
            if hasattr(generation_config, "language"):
                forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
            else:
                forced_decoder_ids.append((1, None))

            if hasattr(generation_config, "task"):
                forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

        if (
            hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
        ) or return_timestamps:
            logits_processor = [
                FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
            ]
        else:
            if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

        if len(forced_decoder_ids) > 0:
            generation_config.forced_decoder_ids = forced_decoder_ids

        return super().generate(input_features, generation_config, logits_processor=logits_processor, **kwargs)

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and
        # x < cache_length. But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs


FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
    Returns:

    Transcription example:

    ```python
    >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
    >>> input_features = inputs.input_features
    >>> generated_ids = model.generate(input_ids=input_features)
    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    >>> transcription
    ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
    ```
"""

overwrite_call_docstring(
    FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxWhisperForAudioClassificationModule(nn.Module):
    config: WhisperConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self) -> None:
        self.encoder = FlaxWhisperEncoder(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.config.is_encoder_decoder = False
        num_layers = self.config.num_hidden_layers + 1
        if self.config.use_weighted_layer_sum:
            self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
        self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_features,
        encoder_outputs=None,
        output_attentions=None,
        output_hidden_states: bool = True,
        return_dict: bool = True,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if self.config.use_weighted_layer_sum:
            hidden_states = jnp.stack(encoder_outputs, axis=1)
            norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
            hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
        else:
            hidden_states = encoder_outputs[0]

        hidden_states = self.projector(hidden_states)
        pooled_output = jnp.mean(hidden_states, axis=1)

        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + encoder_outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
    module_class = FlaxWhisperForAudioClassificationModule
    dtype: jnp.dtype = jnp.float32

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_features = jnp.zeros(input_shape, dtype="f4")
        input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_features=input_features,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_features: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            input_features=jnp.array(input_features, dtype="f4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            rngs=rngs,
        )


FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
    Returns:

    Classification example:

    ```python
    >>> import jax.numpy as jnp
    >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
    >>> from datasets import load_dataset

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
    >>> model = FlaxWhisperForAudioClassification.from_pretrained(
    ...     "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
    ... )
    >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True, trust_remote_code=True)

    >>> sample = next(iter(ds))

    >>> inputs = feature_extractor(
    ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
    ... )
    >>> input_features = inputs.input_features

    >>> logits = model(input_features).logits

    >>> predicted_class_ids = jnp.argmax(logits).item()
    >>> predicted_label = model.config.id2label[predicted_class_ids]
    >>> predicted_label
    'af_za'
    ```
"""

overwrite_call_docstring(
    FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)