
"""TFRAG model implementation."""

from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...configuration_utils import PretrainedConfig
from ...generation import TFLogitsProcessorList
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    keras,
    shape_list,
    unpack_inputs,
)
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "RagConfig"


@dataclass
class TFRetrievAugLMMarginOutput(ModelOutput):
    """
    Base class for retriever-augmented marginalized model outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
            each vocabulary token.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
            (see `past_key_values` input) to speed up sequential decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`.
        retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
            Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
            the `doc_scores`.
        retrieved_doc_ids (`tf.Tensor` (int32) of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
            The indexes of the embedded documents retrieved by the retriever.
        context_input_ids (`tf.Tensor`(int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
        context_attention_mask (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.
        question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
            model.
        question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
        question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
        generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
        generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
        generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    Ntf.Tensor | Noneloss	tf.TensorlogitsList[tf.Tensor] | Nonepast_key_values
doc_scoresretrieved_doc_embedsretrieved_doc_idscontext_input_idscontext_attention_mask"question_encoder_last_hidden_stateTuple[tf.Tensor, ...] | Nonequestion_enc_hidden_statesquestion_enc_attentionsgenerator_enc_last_hidden_stategenerator_enc_hidden_statesgenerator_enc_attentionsgenerator_dec_hidden_statesgenerator_dec_attentions)__name__
__module____qualname____doc__r   __annotations__r   r    r!   r"   r#   r$   r%   r&   r(   r)   r*   r+   r,   r-   r.        Z/var/www/html/venv/lib/python3.12/site-packages/transformers/models/rag/modeling_tf_rag.pyr   r   /   s    BH "D
!FI.2O+2#'J '-1*1*.'.*.'./3,3;?&(8??C <C<@9@8<#%5<@D!=D=A:A@D!=D=A:Ar5   r   c                      e Zd ZU dZdZded<   dZded<   dZded<   dZded	<   dZ	ded
<   dZ
ded<   dZded<   dZded<   dZded<   dZded<   dZded<   dZded<   dZded<   dZded<   dZded<   y)TFRetrievAugLMOutputa  
    Args:
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
            each vocabulary token.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
            (see `past_key_values` input) to speed up sequential decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`.
        retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
            Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
            the `doc_scores`.
        retrieved_doc_ids (`tf.Tensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
            The indexes of the embedded documents retrieved by the retriever.
        context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
        context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.
        question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
            model.
        question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
        question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
        generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
        generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
        generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    logits: tf.Tensor = None
    past_key_values: List[tf.Tensor] | None = None
    doc_scores: tf.Tensor = None
    retrieved_doc_embeds: tf.Tensor | None = None
    retrieved_doc_ids: tf.Tensor | None = None
    context_input_ids: tf.Tensor | None = None
    context_attention_mask: tf.Tensor | None = None
    question_encoder_last_hidden_state: tf.Tensor | None = None
    question_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None
    question_enc_attentions: Tuple[tf.Tensor, ...] | None = None
    generator_enc_last_hidden_state: tf.Tensor | None = None
    generator_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None
    generator_enc_attentions: Tuple[tf.Tensor, ...] | None = None
    generator_dec_hidden_states: Tuple[tf.Tensor, ...] | None = None
    generator_dec_attentions: Tuple[tf.Tensor, ...] | None = None


class TFRagPreTrainedModel(TFPreTrainedModel):
    r"""
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever-augmented model that encapsulates three components: a question encoder, a dataset retriever and a
    generator. The question encoder and the generator are trainable, while the retriever is just an indexed dataset.

    ragposition_idsNc                   |j                         D ci c]%  \  }}|j                  d      r|t        d      d |' }}}|j                         D ci c]%  \  }}|j                  d      r|t        d      d |' }	}}|j                         D ]  }
|d|
z   = 
 |	j                         D ]  }
|d|
z   = 
 |j	                  dd      }|R|J d       ddlm} d|vrdd	lm} |j                  |      }||d<    |j                  |g|d
| j                  d|}|	j	                  dd      }|O|J d       ddlm} d|	vrdd	lm} |j                  |      }||	d<    |j                  |fd| j                  d|	}|j                  dd      }|+t        j                  |j                  |j                  fi |} | ||||      S c c}}w c c}}w )a  
        Instantiates a question encoder and a generator from one or two base classes of the library from pretrained
        model checkpoints.

        Params:
            question_encoder_pretrained_model_name_or_path (`str`, *optional*):
                Information necessary to initiate the question encoder. Can be either:

                    - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
                      `google-bert/bert-base-uncased`.
                    - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
                      `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case,
                      `question_encoder_from_pt` should be set to `True`.

            generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the generator. Can be either:

                    - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
                      `google-t5/t5-small`.
                    - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
                      `facebook/bart-base`.
                    - A path to a *directory* containing model weights saved using
                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case,
                      `generator_from_pt` should be set to `True`.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            retriever ([`RagRetriever`], *optional*):
                The retriever to use.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the question_encoder configuration, use the prefix *question_encoder_* for each
                  configuration parameter.
                - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import RagRetriever, TFRagModel

        >>> # initialize a RAG from two pretrained models.
        >>> model = TFRagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
        ... )
        >>> # alternatively, initializing from pytorch pretrained models can also be done
        >>> model = TFRagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base",
        ...     "facebook/bart-base",
        ...     generator_from_pt=True,
        ...     question_encoder_from_pt=True,
        ... )

        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./rag")

        >>> # load retriever
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # load fine-tuned model with retriever
        >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever)
        ```"""

        kwargs_question_encoder = {
            argument[len("question_encoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("question_encoder_")
        }

        kwargs_generator = {
            argument[len("generator_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("generator_")
        }

        # remove question_encoder and generator kwargs from kwargs
        for key in kwargs_question_encoder.keys():
            del kwargs["question_encoder_" + key]
        for key in kwargs_generator.keys():
            del kwargs["generator_" + key]

        # Load and initialize the question_encoder and generator
        question_encoder = kwargs_question_encoder.pop("model", None)
        if question_encoder is None:
            assert question_encoder_pretrained_model_name_or_path is not None, (
                "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
                " be defined"
            )

            from ..auto.modeling_tf_auto import TFAutoModel

            if "config" not in kwargs_question_encoder:
                from ..auto.configuration_auto import AutoConfig

                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
                kwargs_question_encoder["config"] = question_encoder_config

            question_encoder = TFAutoModel.from_pretrained(
                question_encoder_pretrained_model_name_or_path,
                *model_args,
                name="question_encoder",
                load_weight_prefix=cls.load_weight_prefix,
                **kwargs_question_encoder,
            )

        generator = kwargs_generator.pop("generator_model", None)
        if generator is None:
            assert generator_pretrained_model_name_or_path is not None, (
                "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
                " to be defined"
            )

            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM

            if "config" not in kwargs_generator:
                from ..auto.configuration_auto import AutoConfig

                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
                kwargs_generator["config"] = generator_config

            generator = TFAutoModelForSeq2SeqLM.from_pretrained(
                generator_pretrained_model_name_or_path,
                name="generator",
                load_weight_prefix=cls.load_weight_prefix,
                **kwargs_generator,
            )

        # instantiate config with corresponding kwargs
        config = kwargs.get("config", None)
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)


RAG_START_DOCSTRING = r"""

    RAG is a sequence-to-sequence model which encapsulates two core components: a question encoder and a generator.
    During a forward pass, we encode the input with the question encoder and pass it to the retriever to extract
    relevant context documents. The documents are then prepended to the input. Such contextualized inputs are passed to
    the generator.

    The question encoder can be any *autoencoding* model, preferably [`TFDPRQuestionEncoder`], and the generator can be
    any *seq2seq* model, preferably [`TFBartForConditionalGeneration`].

    The model can be initialized with a [`RagRetriever`] for end-to-end generation or used in combination with the
    outputs of a retriever in multiple steps---see examples for more details. The model is compatible with any
    *autoencoding* model as the `question_encoder` and any *seq2seq* model with a language model head as the `generator`.
    It has been tested with [`TFDPRQuestionEncoder`] as the `question_encoder` and [`TFBartForConditionalGeneration`]
    as the `generator`.
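
    A minimal composition sketch (illustrative only; the constructor arguments follow the Args described below, and
    the checkpoint names are assumptions borrowed from the examples elsewhere in this file):

    ```python
    >>> from transformers import RagConfig, RagRetriever, TFAutoModel, TFAutoModelForSeq2SeqLM, TFRagModel

    >>> question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True)
    >>> generator = TFAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base", from_pt=True)
    >>> config = RagConfig.from_question_encoder_generator_configs(question_encoder.config, generator.config)
    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
    ... )
    >>> model = TFRagModel(
    ...     config=config, question_encoder=question_encoder, generator=generator, retriever=retriever
    ... )
    ```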

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
    general usage and behavior.

    The model is in a developing state, as it currently fully supports eager mode only and may not be exportable in
    SavedModel format.

    Args:
        config ([`RagConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
        question_encoder ([`TFPreTrainedModel`]):
            An encoder model compatible with the faiss index encapsulated by the `retriever`.
        generator ([`TFPreTrainedModel`]):
            A seq2seq model used as the generator in the RAG architecture.
        retriever ([`RagRetriever`]):
            A retriever class encapsulating a faiss index queried to obtain context documents for current inputs.
a  
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.
        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`TFRagModel`]) model during decoding.
        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        past_key_values (`tuple(tuple(tf.Tensor))`):
            Tuple consists of two elements: `encoder_outputs` of the RAG model (see `encoder_outputs`) and
            `past_key_values` of the underlying generator. Can be used to speed up decoding. `past_key_values` are used
            in the ([`RagTokenForGeneration`]) model during decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.

            If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
            forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.

            If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
            forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_retrieved(`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`TFRetrievAugLMOutput`] instead of a plain tuple.
        n_docs (`int`, *optional*, defaults to `config.n_docs`):
            Number of documents to retrieve and/or number of documents for which to generate an answer.
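
    A minimal sketch of how `doc_scores` can be computed when the retriever is used separately (this mirrors the
    usage examples in the model classes below; `model` and `retriever` are assumed to be an already initialized RAG
    model and [`RagRetriever`]):

    ```python
    >>> import tensorflow as tf

    >>> question_hidden_states = model.question_encoder(input_ids)[0]
    >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
    >>> doc_scores = tf.squeeze(
    ...     tf.matmul(
    ...         tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
    ...     ),
    ...     axis=1,
    ... )
    ```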
"""


@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
class TFRagModel(TFRagPreTrainedModel):
    load_weight_prefix = "tf_rag_model_1"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        load_weight_prefix: Optional[str] = None,
        **kwargs,
    ):
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or a question_encoder and a generator has to be provided."

        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        else:
            assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
        super().__init__(config, **kwargs)

        if question_encoder is None:
            from ..auto.modeling_tf_auto import TFAutoModel

            question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder")

        if generator is None:
            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM

            load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix
            generator = TFAutoModelForSeq2SeqLM.from_config(
                config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator"
            )

        self.retriever = retriever
        if self.retriever is not None:
            assert isinstance(
                retriever, RagRetriever
            ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
            self.retriever = retriever

        self.question_encoder = question_encoder
        self.generator = generator

    def set_retriever(self, retriever: RagRetriever):
        self.retriever = retriever

    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | None = None,
        doc_scores: np.ndarray | tf.Tensor | None = None,
        context_input_ids: np.ndarray | tf.Tensor | None = None,
        context_attention_mask: np.ndarray | tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_retrieved: bool | None = None,
        n_docs: int | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> TFRetrievAugLMOutput:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, TFRagModel
        >>> import tensorflow as tf

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True)

        >>> input_dict = tokenizer.prepare_seq2seq_batch(
        ...     "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
        ... )
        >>> input_ids = input_dict["input_ids"]
        >>> outputs = model(input_ids)
        ```"""
        assert (
            "decoder_cached_states" not in kwargs
        ), "Please use past_key_values to cache intermediate outputs"

        # aliasing to minimize code changes
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # whether the retriever has to be used
        has_to_retrieve = (
            self.retriever is not None
            and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
            and encoder_outputs is None
        )

        # encoder_outputs are pre-computed during RAG-token generation
        if encoder_outputs is None:
            if has_to_retrieve:
                question_enc_outputs = self.question_encoder(
                    input_ids, attention_mask=attention_mask, return_dict=True, training=training
                )
                # hidden states of the question encoder
                question_encoder_last_hidden_state = question_enc_outputs[0]

                retriever_outputs = self.retriever(
                    input_ids,
                    question_encoder_last_hidden_state.numpy(),
                    prefix=self.generator.config.prefix,
                    n_docs=n_docs,
                    return_tensors="tf",
                )
                context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
                    retriever_outputs["context_input_ids"],
                    retriever_outputs["context_attention_mask"],
                    retriever_outputs["retrieved_doc_embeds"],
                    retriever_outputs["doc_ids"],
                )

                context_input_ids = tf.cast(context_input_ids, tf.int32)
                context_attention_mask = tf.cast(context_attention_mask, tf.int32)
                retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32)
                retrieved_doc_ids = tf.cast(retrieved_doc_ids, tf.int32)

                # compute doc_scores
                doc_scores = tf.squeeze(
                    tf.matmul(
                        tf.expand_dims(question_encoder_last_hidden_state, axis=1),
                        retrieved_doc_embeds,
                        transpose_b=True,
                    ),
                    axis=1,
                )
            else:
                assert context_input_ids is not None, (
                    "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
                    " set a retriever using the `set_retriever(...)` function."
                )
                assert context_attention_mask is not None, (
                    "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
                    " can set a retriever using the `set_retriever(...)` function."
                )
                assert doc_scores is not None, (
                    "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
                    " retriever using the `set_retriever(...)` function."
                )

        assert (
            doc_scores is not None
        ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."

        assert (context_input_ids.shape[0] % n_docs) == 0, (
            f"The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
            f" {context_input_ids.shape[0]}."
        )

        # repeat the decoder inputs once per retrieved document
        if decoder_input_ids is not None:
            decoder_input_ids = tf.repeat(decoder_input_ids, n_docs, axis=0)

        if decoder_attention_mask is not None:
            decoder_attention_mask = tf.repeat(decoder_attention_mask, n_docs, axis=0)

        gen_outputs = self.generator(
            context_input_ids,
            attention_mask=context_attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=True,
            training=training,
        )

        if not has_to_retrieve:
            question_encoder_last_hidden_state = None
            question_enc_hidden_states = None
            question_enc_attentions = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None
        else:
            question_enc_hidden_states = question_enc_outputs.hidden_states
            question_enc_attentions = question_enc_outputs.attentions

        if not has_to_retrieve or not output_retrieved:
            # don't output retrieved docs
            context_input_ids = None
            context_attention_mask = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None

        return TFRetrievAugLMOutput(
            logits=gen_outputs.logits,
            doc_scores=doc_scores,
            past_key_values=gen_outputs.past_key_values,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            retrieved_doc_embeds=retrieved_doc_embeds,
            retrieved_doc_ids=retrieved_doc_ids,
            question_encoder_last_hidden_state=question_encoder_last_hidden_state,
            question_enc_hidden_states=question_enc_hidden_states,
            question_enc_attentions=question_enc_attentions,
            generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
            generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
            generator_enc_attentions=gen_outputs.encoder_attentions,
            generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
            generator_dec_attentions=gen_outputs.decoder_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        with tf.name_scope(self.generator.name):
            self.generator.build(None)
        with tf.name_scope(self.question_encoder.name):
            self.question_encoder.build(None)
    A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    c            	          e Zd ZdZ	 	 	 	 d	 	 	 	 	 	 	 d fdZddZ	 	 	 	 	 	 ddZed        Zed        Z	ed        Z
edd	       Zdd
Ze ee       eee      	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 dd                     Zddddddd e       f	 	 	 ddZd Zd ZddZddZddZddZ xZS )TFRagTokenForGenerationz!tf_rag_token_for_generation_1/ragNc                    |||J d       |+t        j                  |j                  |j                  fi |}t        |   |       t        ||||| j                  d      | _        y NzHEither a configuration or an encoder and a generator has to be provided.r;   )rD   rF   rJ   rM   rI   rH   r   rW   rD   rp   rq   rl   rI   r;   rt   rD   rF   rJ   rM   r\   ru   s         r6   rq   z TFRagTokenForGeneration.__init__       !(Y-B	VU	V 
 >FF '')9)9=CF 	  -#66
r5   c                &    || j                   _        y rw   r;   rM   rx   s     r6   ry   z%TFRagTokenForGeneration.set_retriever      &r5   c           
     4    ||d d dd f   }d ||||||d|d	S )NT)	r   r   r!   r%   r   r    r   do_marginalizer   r4   )	rt   r   r    r   r   r   r!   r   r\   s	            r6   prepare_inputs_for_generationz5TFRagTokenForGeneration.prepare_inputs_for_generation  sB     & 1!RS& 9 .$&4!2.""

 
	
r5   c                .    | j                   j                  S rw   r   rt   s    r6   rM   z!TFRagTokenForGeneration.retriever"      xx!!!r5   c                .    | j                   j                  S rw   r;   rJ   r   s    r6   rJ   z!TFRagTokenForGeneration.generator&  r   r5   c                .    | j                   j                  S rw   r;   rF   r   s    r6   rF   z(TFRagTokenForGeneration.question_encoder*      xx(((r5   c                N    fd}t         j                  j                  ||       S )a,  
        RAG-specific `_gather_beams`: gathers the beam slices indexed by beam_indices into new beam array. If the
        nested tensor has a shape mismatch with the beam indices, then it means it is the cache. In that case, isolates
        and takes care of the extra dimension for ndocs.
        c                   | j                   d   j                   d   k7  }|rW| j                   d   j                   d   z  }j                   d   }t        j                  | |d|g| j                   dd        } t        j                  | dd      }|r+t        j                  |z  dg|j                   dd        }|S )Nr   r   rA   r   )paramsindicesr   
batch_dimsr	   )r   r   reshapegather)tensoris_rag_cacher   
batch_sizegathered_tensorbeam_indicess        r6   	gather_fnz8TFRagTokenForGeneration._gather_beams.<locals>.gather_fn6  s    !<<?l.@.@.CCLaL,>,>q,AA)//2
FZV,WfllSTSUFV,WX iiv|RS`abO"$**_zF?RTV>sYhYnYnopoqYr>s"t""r5   )r   nestmap_structure)nestedr   
batch_axisr   s    `  r6   _gather_beamsz%TFRagTokenForGeneration._gather_beams.  s!    	#  ww$$Y77r5   c                   ||n| j                   j                  }t        j                  j	                  |d      }t        j
                  ||j                  d   |z  |d|j                  d   g      }t        j                  j	                  |d      }t        j                  |d      }t        j                  |d      }||z   }t        j                  |d      S )Nr   r   r   r   )	rD   r   r   nnlog_softmaxr   r   r   reduce_logsumexp)rt   
seq_logitsr!   r   seq_logprobsdoc_logprobslog_prob_sums          r6   marginalizez#TFRagTokenForGeneration.marginalizeH  s    !-4;;3E3E uu(("(=zz,1A1A!1D1NPVXZ\f\l\lmo\p0qruu((!(<~~l<~~l<#l2""<a88r5   rz   c                >   d|vsJ d       |r|n| j                   j                  }|r|n| j                   j                  }|||}d}
| j                  |||||||	|||
|||||      }d}|j                  }|C|J | j                  |j                  |j                  ||| j                   j                  |      }|r| j                  ||j                  |      }t        di d|d|d	|j                  d
|j                  d|j                  d|j                  d|j                  d|j                  d|j                  d|j                   d|j"                  d|j$                  d|j&                  d|j(                  d|j*                  d|j,                  S )a  
        do_marginalize (`bool`, *optional*):
            If `True`, the logits are marginalized over all documents by making use of
            `tf.nn.log_softmax`.
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss according to the Rag-Token model formulation. See
            https://arxiv.org/pdf/2005.11401.pdf Section 2.1 for details about Rag-Token formulation. Indices should be
            in `[0, ..., config.vocab_size - 1]`.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.math.reduce_sum`
            operation.
        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
            Legacy dictionary, which is required so that model can use *generate()* function.

        Returns:

        Example:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, RagRetriever, TFRagTokenForGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)

        >>> input_dict = tokenizer.prepare_seq2seq_batch(
        ...     "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
        ... )
        >>> outputs = model(input_dict, output_retrieved=True)

        >>> # or use retriever separately
        >>> # 1. Encode
        >>> input_ids = input_dict["input_ids"]
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
        >>> doc_scores = tf.squeeze(
        ...     tf.matmul(
        ...         tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
        ...     ),
        ...     axis=1,
        ... )
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     inputs=None,
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=input_dict["labels"],
        ... )

        >>> # or directly generate
        >>> generated = model.generate(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ... )
        >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```r}   r~   NFr   r   r   r   r$   r%   r!   r    r   r   r   r   r   r   reduce_lossepsilonr   r   r   r    r!   r$   r%   r"   r#   r&   r(   r)   r*   r+   r,   r-   r.   r4   )rD   r   r   r;   r   get_nllr!   label_smoothingr   r   r    r$   r%   r"   r#   r&   r(   r)   r*   r+   r,   r-   r.   )rt   r   r   r   r   r   r    r!   r$   r%   r   r   r   r   r   r   labelsr   r   r   r\   outputsr   r   s                           r6   r   zTFRagTokenForGeneration.callT  s   v $61	FE	F1 ,:t{{?Y?Y%0kdkk6M6M ($*!I(()+/#9/#9!+/!5-  
$ $000<<""'33   D %%fg.@.@&IF) 


 $33
 ))	

 &77
 $+#A#A
 ")!=!=
 &77
 07/Y/Y
 (/'I'I
 %,$C$C
 -4,S,S
 )0(K(K
 &-%E%E
 )0(K(K
  &-%E%E!
 	
r5   c	                   | j                   t        j                         j                  di |	}
n| j                  j
                  | j                  7|4| j                  ||      d   }| j                  ||j                         j                  t        j                        | j                  j                  j                  d      }|d   |d   |d   }}}t        j                  |t        j                         }t        j                  |t        j                         }t        j                  |t        j                        }t        j"                  t        j$                  |d	      |d
      }t        j&                  |d	      }|j(                  d   z  dk(  sJ d d|j(                  d    d       |j(                  d   z  | j*                  j                  j-                         } |||j.                  j0                  d
      }t        j2                  j4                  z  dft        j                  j6                  t        j                               }|d   }d fd	} ||j4                        } ||j4                        |d<   t        j8                  |j4                  d	      }||
d<   ||
d<   ||
d<   |
d<   | j;                  t        j(                  |      d   |      }j4                  dk(  rb | j<                  d|j>                  j@                  jB                  |j.                  j0                  jD                  jF                  d	|
S j4                  dkD  rЉj4                  jH                  k  r&tK        dj4                   djH                   d      fd} ||      } ||
d         |
d<    ||
d   d         |
d   d<    | jL                  d|j>                  j@                  jB                  |j.                  j0                  jD                  jF                  d	|
S tK        dj4                         )!a|  
        Implements TFRAG token decoding.

        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
                forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
            doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model is not initialized with a `retriever`, `doc_scores` has to be provided to the forward
                pass. `doc_scores` can be computed via `question_encoder_last_hidden_state` and `retrieved_doc_embeds`.
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`TFLogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model.

        Return:
            `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
            second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
            due to the `eos_token_id`.
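
        Example (an illustrative sketch; it assumes the `facebook/rag-token-nq` checkpoint and its dummy-index
        retriever can be loaded, as in the examples of [`~TFRagTokenForGeneration.call`]):

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, TFRagTokenForGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)

        >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", return_tensors="tf")
        >>> generated = model.generate(input_ids=input_dict["input_ids"])
        >>> print(tokenizer.batch_decode(generated, skip_special_tokens=True))
        ```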
        r   r   r   r   r$   r%   r"   r   r   Tr   r   r   r   )r   r   r   r   r   last_hidden_statec                    | j                   dd }df|z   }t        j                  | |      } |f|z   }t        j                  | |      } |z  z  f|z   }t        j                  | |      S )z
            Broadcast tensor with `num_beams` replica, with correct order Input: tensor of shape (batch_size*n_docs ,
            d) Output: tensor of shape (batch_size*num_beams*n_docs , d)
            r   N)r   r   r   broadcast_to)r   	num_beamsd_shape_list	new_shaper   r   s       r6   extend_enc_outputz;TFRagTokenForGeneration.generate.<locals>.extend_enc_outputj  s     "<<+L $Q/,>IZZ	2F $Y7,FI__VY7F $i/&8:\II::fi00r5   )r   r!   r   r   r   r   )generation_configinput_ids_seq_lengthlogits_processor)	r   
max_lengthpad_token_ideos_token_idr  r   r   output_scoresreturn_dict_in_generatezwBeam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences, got z and z (respectivelly)c                j    t        |       }t        j                  | dj                  g|dd z         S )zFUnflattens the first, flat batch*beam dimension of a non-scalar array.r   r   N)r   r   r   r   )r   r   r  s     r6   unflatten_beam_dimz<TFRagTokenForGeneration.generate.<locals>.unflatten_beam_dim  s7    "6*zz&2/@/J/J*KeTUTVi*WXXr5   uH   `num_beams` has to be an integer strictly superior to 0 (≥ 1), but is r4   rw   )'r  copydeepcopyupdaterD   r   rM   rF   r   astypenpr   rJ   r   r   r   r   r   r   r   r   r;   get_encoderr   r   fillr   decoder_start_token_idr   _get_logits_processorgreedy_searchr  r  r  r	  r
  num_return_sequences
ValueErrorbeam_search)rt   r   r   r$   r%   r!   r   r  r  r\   model_kwargsquestion_hidden_statesoutr"   encoderr   r   r   r  pre_processorr  r   s         ``             @r6   generatez TFRagTokenForGeneration.generate  s   F $ $ 6 6 MM*;</(//9&9 "-4;;3E3E >>%*;*C%)%:%:9Uc%:%def%g"..&,,.55bjjA~~,,33# ! C '(,-*+ 8L5 !#(9288 D%'WW-CRXX%N"#%77+?#L  5A>@TbfJ JQ7J!''*V39 	
[\b[c d!''*+1.	
9
 ',,Q/69
(($$002!'1/AA!2!G!G
 GG+555q9GG%<<bhhG
 ,,?@	1, "33IUfUpUp!q/@):)D)D0
+, YYz+<+F+FQO
 &0\"*9&')?%&!'X22/!#*;!<R!@- 3 
 &&!+%4%% +,77.;;.;;!."3"E"E%6%K%K/==(9(Q(Q   ((1, **->-S-SS 22C2M2M1N O)>>??OQ Y
 !33D E-?M]@^-_L)*CU./0CDDL*+,?@ $4## +,77.;;.;;!."3"E"E%6%K%K/==(9(Q(Q   Z[l[v[vZwx r5   c                J    | j                   j                  j                         S rw   )r;   rJ   get_input_embeddingsr   s    r6   r!  z,TFRagTokenForGeneration.get_input_embeddings  s    xx!!6688r5   c                J    | j                   j                  j                         S rw   )r;   rJ   get_output_embeddingsr   s    r6   r#  z-TFRagTokenForGeneration.get_output_embeddings  s    xx!!7799r5   c           
        |)| j                   j                  j                  }|J d       | j                   j                  j                  }|J d       t	        j
                  t        |      d   dft	        j                  ||j                              }t	        j                  ||ddddf   gd      }t	        j                  |dk(  t	        j
                  t        |      t	        j                  ||j                              |      }t        j                  j                  |t	        j                  d|j                              }t	        j                  |g      5  t	        j                  |      }ddd       |S # 1 sw Y   |S xY w)zCShift input ids one token to the right, and pad with start_token_idNzself.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as generator, see Bart docs for more informationz1self.model.config.pad_token_id has to be defined.r   r   r   i)rJ   rD   r  r  r   r  r   r   dtypeconcatwhere	debuggingassert_greater_equalcontrol_dependenciesidentity)rt   r   start_token_idr  start_tokensshifted_input_idsassert_gte0s          r6   shift_tokens_rightz*TFRagTokenForGeneration.shift_tokens_right  sb    !!^^22IIN!- A-
 ~~,,99'\)\\'ww
9 5a 8!<bggnV_VeVe>fgII|Yq#2#v5F&GL HH%GGJ01277<3YZ
 ll778I277STVgVmVmKno $$k]3 	? ",= >	? ! 	? ! s   !FFc           
        ||n| j                   j                  }t        j                  |d d dd f   t        j                  |j
                  d   dgt        j                  | j                   j                  j                  |j                              gd      }| j                  |||      }| j                  ||d|      }|S )Nr   r   r   T)from_logitsr   )rD   r   r   r&  r  r   r   rJ   r  r%  r   hf_compute_loss)	rt   r   r!   targetr   r   r   rag_logprobsr   s	            r6   r   zTFRagTokenForGeneration.get_nll  s    !-4;;3E3E AqrE]BGGV\\!_a$8"''$++BWBWBdBdflfrfr:stu
 ''
JG##FLdXc#dr5   c                   t         j                  j                  dt         j                  j                  j                        }|du r<d}t        j                  ||d|z
        }t
        j                  j                  |      }|}t        j                  |d      }	t        j                  |	| j                  j                  j                        }
t        j                  t        j                  |d|j                  d	   f      |
      }t        j                  |	|
      } |||      }t        j                   |d
       }t        j                   |      }||j                  d   z  }d|z
  |z  ||z  z   }|S )z(CrossEntropyLoss that ignores pad tokensT)r2  	reductionFg&.>r   )clip_value_minclip_value_max)r   r   rA   r         ?)r   lossesSparseCategoricalCrossentropy	ReductionSUMr   clip_by_valuemathlogr   	not_equalrD   rJ   r  boolean_maskr   
reduce_sum)rt   r   y_predsmooth_epsilonr2  r   loss_fnepsr   melted_labelsactive_lossreduced_logitsnll_losssmooth_losseps_ir   s                   r6   r3  z'TFRagTokenForGeneration.hf_compute_loss  s=    ,,<<ll,,00 = 

 %C%%fSQRUXQXYFWW[[(F

651ll=$++2G2G2T2TUFRa<Q)RT_`<6>2}}^"==mmK0!5!5b!99n$05;3FFr5   c                    | j                   ry d| _         t        | dd       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   y xY wNTr;   r   getattrr   r   r;   rH   r   r   s     r6   r   zTFRagTokenForGeneration.build  e    ::
4%1txx}}- %t$% % 2% %   A11A:NNNNrD   r   rF   r   rJ   r   rM   r   r   )NNNNNN)r   rw   NNNNNNNNNNNNNNNNNNF)(r   r   r   r   r   r   r   r   r   r   r    r   r!   r   r$   r   r%   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rf   r   r   r   r   r   )F        N)rY  TF)r/   r0   r1   rI   rq   ry   r   propertyrM   rJ   rF   staticmethodr   r   r   r   r   r   r   r   r   r   r  r!  r#  r0  r   r3  r   r   r   s   @r6   r   r     s\    = .28<15,0
*
 6
 /	

 *
:' 
6 " " " " ) ) 8 82
9 *+GH+ETcd .28<;?@D9=MQ48;?@D!%)-,0(,!&*04#'#')V
*V
 6V
 9	V

 !>V
 7V
 KV
 2V
 9V
 !>V
 V
 'V
 *V
 &V
 V
  $!V
" .#V
$ !%V
& !'V
( )V
, 
$-V
 e I V
t .2+/#.0S*S )Sj9:!B<%r5   r   zx
    A TF RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
    load_weight_prefix = "tf_rag_sequence_for_generation_1/rag"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an encoder and a generator has to be provided."

        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        super().__init__(config)

        # instantiate model
        self.rag = TFRagModel(
            config=config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=retriever,
            load_weight_prefix=self.load_weight_prefix,
            name="rag",
        )

    def set_retriever(self, retriever: RagRetriever):
        self.rag.retriever = retriever

    @property
    def retriever(self):
        return self.rag.retriever

    @property
    def generator(self):
        return self.rag.generator

    @property
    def question_encoder(self):
        return self.rag.question_encoder

    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: np.ndarray | tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | None = None,
        doc_scores: np.ndarray | tf.Tensor | None = None,
        context_input_ids: np.ndarray | tf.Tensor | None = None,
        context_attention_mask: np.ndarray | tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_retrieved: bool | None = None,
        n_docs: int | None = None,
        exclude_bos_score: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        reduce_loss: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,  # needs kwargs for generation
    ) -> TFRetrievAugLMMarginOutput:
        r"""
        exclude_bos_score (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
            the loss.
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss according to the Rag-Sequence model
            formulation. See https://arxiv.org/pdf/2005.11401.pdf Section 2.1 for details about the Rag-Sequence
            formulation, and the short sketch of the marginalized objective below this argument list. Indices
            should be in `[0, ..., config.vocab_size - 1]`.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.Tensor.sum`
            operation.
        kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
            Legacy dictionary, which is required so that the model can use the *generate()* function.
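
        A rough sketch (our notation, not verbatim from the paper) of the objective optimized when `labels` is
        provided: the retriever's document posterior is combined with the generator's token probabilities and the
        result is marginalized over the `n_docs` retrieved documents,

            p(y | x) = \sum_{z} p_{\eta}(z | x) \prod_{i} p_{\theta}(y_i | x, z, y_{1:i-1}),

        and the returned loss is `-log p(y | x)`, optionally label-smoothed and summed over the batch when
        `reduce_loss=True`.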

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, TFRagSequenceForGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = TFRagSequenceForGeneration.from_pretrained(
        ...     "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
        ... )

        >>> input_dict = tokenizer.prepare_seq2seq_batch(
        ...     "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
        ... )
        >>> outputs = model(input_dict, output_retrieved=True)
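        >>> # (a sketch, not part of the original example) with output_retrieved=True the retrieved document
        >>> # ids post-processed by the retriever are exposed on the output
        >>> retrieved_doc_ids = outputs.retrieved_doc_ids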

        >>> # or use retriever separately
        >>> # 1. Encode
        >>> input_ids = input_dict["input_ids"]
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
        >>> doc_scores = tf.squeeze(
        ...     tf.matmul(
        ...         tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
        ...     ),
        ...     axis=1,
        ... )
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     inputs=None,
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=input_dict["labels"],
        ... )

        >>> # or directly generate
        >>> generated = model.generate(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ... )
        >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
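        >>> # a rough sketch (not part of the original example): compute the marginalized loss directly by
        >>> # passing labels together with the pre-computed retrieval tensors
        >>> outputs = model(
        ...     inputs=None,
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     labels=input_dict["labels"],
        ...     reduce_loss=True,
        ... )
        >>> loss = outputs.loss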
        ```"""

        assert (
            "decoder_cached_states" not in kwargs
        ), "Please use past_key_values to cache intermediate outputs"

        exclude_bos_score = exclude_bos_score if exclude_bos_score else self.config.exclude_bos_score
        reduce_loss = reduce_loss if reduce_loss else self.config.reduce_loss

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = labels
            use_cache = False

        outputs = self.rag(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            doc_scores=doc_scores,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_retrieved=output_retrieved,
            n_docs=n_docs,
            training=training,
        )

        loss = None
        if labels is not None:
            loss = self.get_nll(
                outputs.logits,
                outputs.doc_scores,
                labels,
                reduce_loss=reduce_loss,
                epsilon=self.config.label_smoothing,
                n_docs=n_docs,
                exclude_bos_score=exclude_bos_score,
            )

        return TFRetrievAugLMMarginOutput(
            loss=loss,
            logits=outputs.logits,
            doc_scores=outputs.doc_scores,
            past_key_values=outputs.past_key_values,
            context_input_ids=outputs.context_input_ids,
            context_attention_mask=outputs.context_attention_mask,
            retrieved_doc_embeds=outputs.retrieved_doc_embeds,
            retrieved_doc_ids=outputs.retrieved_doc_ids,
            question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
            generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
            generator_enc_hidden_states=outputs.generator_enc_hidden_states,
            generator_enc_attentions=outputs.generator_enc_attentions,
            generator_dec_hidden_states=outputs.generator_dec_hidden_states,
            generator_dec_attentions=outputs.generator_dec_attentions,
            generator_cross_attentions=outputs.generator_cross_attentions,
        )
     6    t        j                  d d dd f   t        j                  j                  d   dgt        j                   j
                  j                  j                  j                              gd       j
                  j                  xs   j
                  j                  j                  }||n j
                  j                  }t        j                  t        j                  d d df   |            }	|d uxr |	}
 fd}t         j                  j                  |d      }t        j                  ||j                  d   |z  |d|j                  d   f      }t         j                  j                  |d      }t        j                   |d      }t        j                   |d      }|d d d d d dd d f   }|d d d d ddd d f   }|d d d d dd d d f   }t        j                  |||z   |gd      }t        j                   d      t        j                   d      t        j"                  |d      t%        j                        t%        |j                        k(  sJ d } ||      }t        j&                  |dd	
      } |||      \  }}|r&|
r$t        j&                  |d d d d dd f   d      }nt        j&                  |d      }t        j&                  |d      }t         j(                  j+                  |d      }t         j(                  j+                  |d      }| }| }|r*t        j&                  |      }t        j&                  |      }||j                  d   z  }d|z
  |z  ||z  z   }|S )Nr   r   r   c                   t        j                  t        j                  j                  j                  j
                  j                              }t        j                  |      r.t        j                  |d|       } t        j                  |d|      }t        j                  | d      t        j                  |d      fS )NrY  r   r   )
r   equalr   rD   rJ   r  r%  
reduce_anyr'  r   )ll
smooth_objpad_maskrt   r4  s      r6   
_mask_padsz6TFRagSequenceForGeneration.get_nll.<locals>._mask_pads  s    xx0E0E0R0RTZT`T`(abH}}X&XXhR0XXhZ@
::br*BJJz,KKKr5   r   rA   c                    d }t        j                  | d| j                  d   f      }|j                  }t        j                  |d      } |||      }t        j                  ||      S )Nc                   t        j                  t        j                  t        j                  |      d   |j                        |d d df   gd      }t        j
                  | |      }t        j                  |d      S )Nr   )r%  r   r   )r   stackranger   r%  	gather_ndr   )r4  	id_tensoridxresults       r6   gather2dzJTFRagSequenceForGeneration.get_nll.<locals>.torch_gather.<locals>.gather2d  sc    hh))<Q)?y WYbcdfgcgYhiprsfc2~~f266r5   r   )r   r   )r   r   r   )paramrr  ru  r4  target_shapert  s         r6   torch_gatherz8TFRagSequenceForGeneration.get_nll.<locals>.torch_gather  s\    7
 ZZEKKO'<=F$??L

9g6Ifi0F::fl33r5   )rr  T)r   keepdimsr:  )r   r&  r  r   r   rD   rJ   r  r%  bos_token_idr   
reduce_allrg  r   r   r   r   r   rP   rD  r@  r   )rt   r   r!   r4  r   r   rd  r   rz  equal_bos_token_id_alluse_bosrl  r   r   first_token_scoressecond_token_scores	remainderr5  rx  ri  rj  rL  rM  rN  r   s   `  `                     r6   r   z"TFRagSequenceForGeneration.get_nll  s'    AqrE]BGGV\\!_a$8"''$++BWBWBdBdflfrfr:stu
 {{//U4;;3H3H3U3U!-4;;3E3E!#rxxq!tl/S!Td*E/E	L uu(("(=zz:++A.&8&"jFVFVWYFZ[
 uu((!(<~~l<~~l< *!QA+6*1a1a<8 Aqr1-	yy"46IL6XZc!dklm Q/R06626<< C(:(:$;;;;	4 ,&9]]<b4H
#B
3J r!Q(|!4Br*B]]:A6
WW%%bq%1WW--jq-A
3!k}}X.H--4K,,,R00g)EK,??r5   c
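    # Rough intuition (a gloss, not from the original comments): unlike RAG-Token, which marginalizes over the
    # retrieved documents at every decoding step, the RAG-Sequence loss above adds the document log-prior once per
    # sequence (folded into the second token's scores) and then takes a log-sum-exp over the document axis, so each
    # candidate answer is scored against the full set of retrieved documents.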
                   |	|	n| j                   j                  }	||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|	|J d       | j
                  ]|[| j                  ||      d   }| j                  ||j                         | j                  j                   j                  |	d      d   }g }||
d<   ||
d	<   d|
d
<   ||j                  d   n|j                  d   |	z  }t        |      D ]  }|||	z  |dz   |	z   } | j                  j                  |fi |
}|r`t        j                  t        |D ci c]*  }t!        |j                         j#                               |, c}j%                                     }|j                  d   }|*t        j&                  |||dz    |df      } | ||d      }n|J d       |J d       t        j&                  ||df      }|||	z  |dz   |	z   }t        j&                  ||df      }|||dz   ddf   }t        j&                  ||df      } | d||||d      }t        j(                  j+                  |d    |      d   }|j-                  t        j.                  ||              | j1                  || j                   j                  j2                        S c c}w )a  
        Implements RAG sequence "thorough" decoding: candidate answers are first generated from each retrieved
        document and every candidate is then re-scored against the full document set, keeping the highest-scoring
        sequences. Read the [`~generation.GenerationMixin.generate`] documentation for more information on how to
        set other generate input parameters.

        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for
                tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention
                masks?](../glossary#attention-mask)
            context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                retriever.
            context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever. If the model is not initialized with a `retriever` or `input_ids` is not given,
                `context_input_ids` and `context_attention_mask` have to be provided to the forward pass. They are
                returned by [`~RagRetriever.__call__`].
            doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever` or
                `input_ids` is not given, `doc_scores` has to be provided to the forward pass. `doc_scores` are
                returned by [`~RagRetriever.__call__`].
            do_deduplication (`bool`, *optional*):
                Whether or not to deduplicate the generations from different context documents for a given input. Has
                to be set to `False` if used while training with a distributed backend.
            num_return_sequences (`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch. Note that this
                is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function,
                where we set `num_return_sequences` to `num_beams`.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            n_docs (`int`, *optional*, defaults to `config.n_docs`):
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`]

        Return:
            `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
            second dimension (sequence length) is either equal to `max_length` or shorter if all batches finished early
            due to the `eos_token_id`.
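
        Example (a rough usage sketch, assuming the tokenizer, retriever and model setup from the `call` example):

        ```python
        >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", return_tensors="tf")
        >>> generated = model.generate(
        ...     input_ids=input_dict["input_ids"],
        ...     attention_mask=input_dict["attention_mask"],
        ...     num_beams=4,
        ...     num_return_sequences=1,
        ...     do_deduplication=True,
        ... )
        >>> answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```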
        Nz= At least one of input_ids or context_input_ids must be givenr   r   r   r   r$   r   r  r   r   T)r   rd  zMake sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.)r   r$   r%   r!   r   rd  r   )k)r  )rD   r   do_deduplicationr  r   rM   rF   r   rJ   r   r   rp  r  r   ro  listre   tolistvaluestiler@  top_kappendr   _cat_and_padr  )rt   r   r   r$   r%   r!   r  r  r   r   r  num_doc_return_sequencesr  hyposr   indexgenerator_input_idsoutput_sequencesr  num_candidatesnew_input_idsr   individual_input_idsindividual_attention_maskindividual_doc_scorestop_cand_indss                             r6   r  z#TFRagSequenceForGeneration.generateA  sz   t "-4;;3E3E/?/K+QUQ\Q\QmQm$8$D $++JjJj 	! "+!6IDKK<Q<Q	 !%6%B	KJ	KB >>%*;*C%)%:%:9Uc%:%def%g" $&,,.~~,,33# !/ ! "!# $-[!/8+,)-%&+4+@Y__Q'FWF]F]^_F`djFj
:& 4	EE"3EFNeaiSYEY"Z6t~~66#    #%88DVf1gQR#aggi6F6F6H2I12L1g1n1n1p,q#r -33N
 $ "	%%!)(D~WXFY Z}5EY]^-9 T9 "- J-
 (*ww'.!)<($ -C56>UZ]^U^bhTh,i),.GG4MP^`aOb,c)(25EAI3F3I(J%(*0EXYGZ([%"&:+D4+&* GGMMGFO+;@XMYZ[\M LL#3]CDi4	El   T[[5J5J5W5W XXY 2hs   %/K-c                $   t        | D cg c]  }|j                  d    c}      t        | D cg c]  }|j                  d    c}      f}t        j                  ||      }t        j
                  |      }d}| D ]K  }||||j                  d   z   d |j                  d   f   j                  |       ||j                  d   z  }M t        j                  |      }t        j                  || d   d   d   j                        S c c}w c c}w )Nr   r   )
sumr   maxr   r  Variableassignconvert_to_tensorr   r%  )tensorsr  tr  outputinds         r6   r  z'TFRagSequenceForGeneration._cat_and_pad  s     W556QX<YAQWWQZ<Y8ZZ	L1 V$  	A3qwwqz))<QWWQZ<78??B1771:C	 %%f-wwvwqz!}Q/5566 6<Ys
   DD
c                    | j                   ry d| _         t        | dd       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   y xY wrP  rQ  r   s     r6   r   z TFRagSequenceForGeneration.build  rS  rT  rU  rV  r   rW  )(r   r   r   r   r   r   r   r   r   r   r    z4Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]]r!   r   r$   r   r%   r   r   Optional[bool]r   r  r   r  r   r  r   zOptional[int]rd  r  r   r   r   r  r   r  r   r   rf   z3Union[Tuple[tf.Tensor], TFRetrievAugLMMarginOutput])FrY  FN)	NNNNNNNNNrX  rw   )r/   r0   r1   rI   rq   ry   rZ  rM   rJ   rF   r   r   r   r   r   r   r   r   r  r[  r  r   r   r   s   @r6   r]  r]  !  s6    @ .28<15,0
*
 6
 /	

 *
:' " " " " ) ) *+GH+ETcd .28<;?@D9=PT48;?@D$(,0/3+/ $,004&*&*)R
*R
 6R
 9	R

 !>R
 7R
 NR
 2R
 9R
 !>R
 "R
 *R
 -R
 )R
 R
  *!R
" .#R
$ $%R
& $'R
( )R
, 
=-R
 e I R
j osRl .2+/#!LY*LY )LY\ 7 7(%r5   r]  )/r2   
__future__r   r  dataclassesr   typingr   r   r   r   r   r  
tensorflowr   configuration_utilsr
   
generationr   modeling_tf_utilsr   r   r   r   r   r   utilsr   r   r   r   configuration_ragr   retrieval_ragr   
get_loggerr/   loggerr   r   r8   r:   RAG_START_DOCSTRINGr   rl   r   r]  r4   r5   r6   <module>r     sH    " "  ! / /   3 /  l k ( ' 
		H	% TB TB TBn OB; OB OBdho, hoV& R?  D '':;i.% i. <i.X ' 	{%24P {%{%| ' 	C%!57S C%C%r5   