
    sg`                        d dl mZ d dlmZmZ d dlmZ d dlZd dl	m
Z d dlmZmZmZ d dlmZmZ ddlmZmZmZ ddlmZmZmZmZ dd	lmZmZ d
dlm Z  dZ!dZ" G d dejF                        Z$ G d dejF                        Z% G d dejF                        Z& G d dejF                        Z' G d dejF                        Z( G d dejF                        Z) G d dejF                        Z* G d dejF                        Z+ G d dejF                        Z, G d  d!ejF                        Z- G d" d#ejF                        Z. G d$ d%ejF                        Z/ G d& d'e      Z0 G d( d)ejF                        Z1 ed*e!       G d+ d,e0             Z2d-Z3 ee2e3        ee2ee .        G d/ d0ejF                        Z4 G d1 d2ejF                        Z5 ed3e!       G d4 d5e0             Z6d6Z7 ee6e7        ee6ee .       y)7    )partial)OptionalTupleN)
FrozenDictfreezeunfreeze)flatten_dictunflatten_dict   )"FlaxBaseModelOutputWithNoAttention,FlaxBaseModelOutputWithPoolingAndNoAttention(FlaxImageClassifierOutputWithNoAttention)ACT2FNFlaxPreTrainedModel append_replace_return_docstringsoverwrite_call_docstring)add_start_docstrings%add_start_docstrings_to_model_forward   )ResNetConfiga  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
aA  
    Args:
        pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                   4    e Zd ZdZej
                  d        Zy)IdentityzIdentity function.c                     |S N )selfxkwargss      b/var/www/html/venv/lib/python3.12/site-packages/transformers/models/resnet/modeling_flax_resnet.py__call__zIdentity.__call__\   s        N)__name__
__module____qualname____doc__nncompactr    r   r!   r   r   r   Y   s    ZZ r!   r   c                       e Zd ZU eed<   dZeed<   dZeed<   dZee	   ed<   e
j                  Ze
j                  ed<   d	 Zdd
e
j                  dede
j                  fdZy)FlaxResNetConvLayerout_channelsr   kernel_sizer   striderelu
activationdtypec                    t        j                  | j                  | j                  | j                  f| j                  | j                  dz  | j
                  dt         j                  j                  ddd| j
                              | _        t        j                  dd	| j
                  
      | _
        | j                  t        | j                     | _        y t               | _        y )N   F       @fan_outnormal)modedistributionr/   )r+   stridespaddingr/   use_biaskernel_init?h㈵>momentumepsilonr/   )r&   Convr*   r+   r,   r/   initializersvariance_scalingconvolution	BatchNormnormalizationr.   r   r   activation_funcr   s    r   setupzFlaxResNetConvLayer.setuph   s    77))4+;+;<KK$$)**889[ckokuku8v
  \\3TZZX:>//:Uvdoo6[c[er!   r   deterministicreturnc                 p    | j                  |      }| j                  ||      }| j                  |      }|S N)use_running_average)rC   rE   rF   r   r   rI   hidden_states       r   r    zFlaxResNetConvLayer.__call__u   s=    ''*)),M)Z++L9r!   NT)r"   r#   r$   int__annotations__r+   r,   r.   r   strjnpfloat32r/   rH   ndarrayboolr    r   r!   r   r)   r)   a   sc    KFCO &J&{{E399"f#++ d ckk r!   r)   c                       e Zd ZU dZeed<   ej                  Zej                  ed<   d Z	d
dej                  dedej                  fdZy	)FlaxResNetEmbeddingszO
    ResNet Embeddings (stem) composed of a single aggressive convolution.
    configr/   c                     t        | j                  j                  dd| j                  j                  | j                        | _        t        t        j                  ddd      | _        y )N   r1   )r+   r,   r.   r/   )r   r   )r1   r1   )r   r   r]   )window_shaper7   r8   )	r)   rZ   embedding_size
hidden_actr/   embedderr   r&   max_poolrG   s    r   rH   zFlaxResNetEmbeddings.setup   sN    +KK&&{{--**
  &&Zjkr!   pixel_valuesrI   rJ   c                     |j                   d   }|| j                  j                  k7  rt        d      | j	                  ||      }| j                  |      }|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rI   )shaperZ   num_channels
ValueErrorra   rb   )r   rc   rI   rh   	embeddings        r   r    zFlaxResNetEmbeddings.__call__   s\    #))"-4;;333w  MM,mML	MM),	r!   NrP   )r"   r#   r$   r%   r   rR   rT   rU   r/   rH   rV   rW   r    r   r!   r   rY   rY   |   sL     {{E399"	lS[[  QTQ\Q\ r!   rY   c                       e Zd ZU dZeed<   dZeed<   ej                  Z	ej                  ed<   d Z
ddej                  ded	ej                  fd
Zy)FlaxResNetShortCutz
    ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    r*   r1   r,   r/   c                    t        j                  | j                  d| j                  dt         j                  j                  ddd      | j                        | _        t        j                  dd	| j                  
      | _	        y )Nr]   Fr2   r3   truncated_normal)r5   r6   )r+   r7   r9   r:   r/   r;   r<   r=   )
r&   r@   r*   r,   rA   rB   r/   rC   rD   rE   rG   s    r   rH   zFlaxResNetShortCut.setup   se    77KK889[m8n**
  \\3TZZXr!   r   rI   rJ   c                 N    | j                  |      }| j                  ||      }|S rL   )rC   rE   rN   s       r   r    zFlaxResNetShortCut.__call__   s-    ''*)),M)Zr!   NrP   )r"   r#   r$   r%   rQ   rR   r,   rT   rU   r/   rH   rV   rW   r    r   r!   r   rl   rl      sR    
 FCO{{E399"	Y#++ d ckk r!   rl   c                       e Zd ZU eed<   dZeed<   ej                  Zej                  ed<   d Z	ddej                  dedej                  fd	Zy
)FlaxResNetBasicLayerCollectionr*   r   r,   r/   c                     t        | j                  | j                  | j                        t        | j                  d | j                        g| _        y )Nr,   r/   )r.   r/   )r)   r*   r,   r/   layerrG   s    r   rH   z$FlaxResNetBasicLayerCollection.setup   s;     1 1$++TZZX 1 1d$**U

r!   rO   rI   rJ   c                 <    | j                   D ]  } |||      } |S Nrf   rt   r   rO   rI   rt   s       r   r    z'FlaxResNetBasicLayerCollection.__call__   )    ZZ 	LE ]KL	Lr!   NrP   )r"   r#   r$   rQ   rR   r,   rT   rU   r/   rH   rV   rW   r    r   r!   r   rq   rq      sM    FCO{{E399"
S[[  QTQ\Q\ r!   rq   c                       e Zd ZU dZeed<   eed<   dZeed<   dZee	   ed<   e
j                  Ze
j                  ed<   d	 Zdd
efdZy)FlaxResNetBasicLayerzO
    A classic ResNet's residual layer composed by two `3x3` convolutions.
    in_channelsr*   r   r,   r-   r.   r/   c                 T   | j                   | j                  k7  xs | j                  dk7  }|r,t        | j                  | j                  | j                        nd | _        t        | j                  | j                  | j                        | _        t        | j                     | _
        y )Nr   rs   )r*   r,   r/   )r|   r*   r,   rl   r/   shortcutrq   rt   r   r.   rF   r   should_apply_shortcuts     r   rH   zFlaxResNetBasicLayer.setup   s     $ 0 0D4E4E E YXYIY % t00DJJW 	
 4**;;**


  &doo6r!   rI   c                     |}| j                  ||      }| j                  | j                  ||      }||z  }| j                  |      }|S rv   )rt   r~   rF   r   rO   rI   residuals       r   r    zFlaxResNetBasicLayer.__call__   sU    zz,mzL==$}}X]}KH ++L9r!   NrP   )r"   r#   r$   r%   rQ   rR   r,   r.   r   rS   rT   rU   r/   rH   rW   r    r   r!   r   r{   r{      sO     FCO &J&{{E399"7	D 	r!   r{   c                       e Zd ZU eed<   dZeed<   dZee   ed<   dZ	eed<   e
j                  Ze
j                  ed<   d	 Zdd
e
j                  dede
j                  fdZy)#FlaxResNetBottleNeckLayerCollectionr*   r   r,   r-   r.      	reductionr/   c           	          | j                   | j                  z  }t        |d| j                  d      t        || j                  | j                  d      t        | j                   dd | j                  d      g| _        y )Nr   0)r+   r/   name1)r,   r/   r   2)r+   r.   r/   r   )r*   r   r)   r/   r,   rt   )r   reduces_channelss     r   rH   z)FlaxResNetBottleNeckLayerCollection.setup   sl    ,,>   0atzzX[\ 0DJJ]`a 1 1qTY]YcYcjmn

r!   rO   rI   rJ   c                 <    | j                   D ]  } |||      } |S rv   rw   rx   s       r   r    z,FlaxResNetBottleNeckLayerCollection.__call__   ry   r!   NrP   )r"   r#   r$   rQ   rR   r,   r.   r   rS   r   rT   rU   r/   rH   rV   rW   r    r   r!   r   r   r      se    FCO &J&Is{{E399"
S[[  QTQ\Q\ r!   r   c                       e Zd ZU dZeed<   eed<   dZeed<   dZee	   ed<   dZ
eed	<   ej                  Zej                  ed
<   d Zddej                  dedej                  fdZy)FlaxResNetBottleNeckLayera$  
    A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the
    input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution
    remaps the reduced features to `out_channels`.
    r|   r*   r   r,   r-   r.   r   r   r/   c                    | j                   | j                  k7  xs | j                  dk7  }|r,t        | j                  | j                  | j                        nd | _        t        | j                  | j                  | j                  | j                  | j                        | _	        t        | j                     | _        y )Nr   rs   )r,   r.   r   r/   )r|   r*   r,   rl   r/   r~   r   r.   r   rt   r   rF   r   s     r   rH   zFlaxResNetBottleNeckLayer.setup  s     $ 0 0D4E4E E YXYIY % t00DJJW 	 9;;nn**

  &doo6r!   rO   rI   rJ   c                     |}| j                   | j                  ||      }| j                  ||      }||z  }| j                  |      }|S rv   )r~   rt   rF   r   s       r   r    z"FlaxResNetBottleNeckLayer.__call__!  sS    ==$}}X]}KHzz,> ++L9r!   NrP   )r"   r#   r$   r%   rQ   rR   r,   r.   r   rS   r   rT   rU   r/   rH   rV   rW   r    r   r!   r   r   r     sr     FCO &J&Is{{E399"7$S[[  QTQ\Q\ r!   r   c                       e Zd ZU dZeed<   eed<   eed<   dZeed<   dZeed<   e	j                  Ze	j                  ed<   d	 Zdd
e	j                  dede	j                  fdZy)FlaxResNetStageLayersCollection4
    A ResNet stage composed by stacked layers.
    rZ   r|   r*   r1   r,   depthr/   c                    | j                   j                  dk(  rt        nt        } || j                  | j
                  | j                  | j                   j                  | j                  d      g}t        | j                  dz
        D ]\  }|j                   || j
                  | j
                  | j                   j                  | j                  t        |dz                      ^ || _        y )N
bottleneckr   )r,   r.   r/   r   r   )r.   r/   r   )rZ   
layer_typer   r{   r|   r*   r,   r`   r/   ranger   appendrS   layers)r   rt   r   is       r   rH   z%FlaxResNetStageLayersCollection.setup8  s    -1[[-C-C|-S)Ym   !!{{;;11jj

 tzzA~& 		AMM%%%%#{{55**QU		 r!   r   rI   rJ   c                 @    |}| j                   D ]  } |||      } |S rv   r   )r   r   rI   rO   rt   s        r   r    z(FlaxResNetStageLayersCollection.__call__T  s.    [[ 	LE ]KL	Lr!   NrP   r"   r#   r$   r%   r   rR   rQ   r,   r   rT   rU   r/   rH   rV   rW   r    r   r!   r   r   r   ,  sf     FCOE3N{{E399"8#++ d ckk r!   r   c                       e Zd ZU dZeed<   eed<   eed<   dZeed<   dZeed<   e	j                  Ze	j                  ed<   d	 Zdd
e	j                  dede	j                  fdZy)FlaxResNetStager   rZ   r|   r*   r1   r,   r   r/   c                     t        | j                  | j                  | j                  | j                  | j
                  | j                        | _        y )N)r|   r*   r,   r   r/   )r   rZ   r|   r*   r,   r   r/   r   rG   s    r   rH   zFlaxResNetStage.setupg  s<    5KK((**;;****
r!   r   rI   rJ   c                 (    | j                  ||      S rv   r   )r   r   rI   s      r   r    zFlaxResNetStage.__call__q  s    {{1M{::r!   NrP   r   r   r!   r   r   r   [  sf     FCOE3N{{E399"
;#++ ;d ;ckk ;r!   r   c            	           e Zd ZU eed<   ej                  Zej                  ed<   d Z	 	 d
dej                  de
de
defdZy	)FlaxResNetStageCollectionrZ   r/   c                 v   t        | j                  j                  | j                  j                  dd        }t        | j                  | j                  j                  | j                  j                  d   | j                  j
                  rdnd| j                  j                  d   | j                  d      g}t        t        || j                  j                  dd              D ]K  \  }\  \  }}}|j                  t        | j                  |||| j                  t        |dz                      M || _        y )Nr   r   r1   r   )r,   r   r/   r   )r   r/   r   )ziprZ   hidden_sizesr   r_   downsample_in_first_stagedepthsr/   	enumerater   rS   stages)r   in_out_channelsr   r   r|   r*   r   s          r   rH   zFlaxResNetStageCollection.setupy  s   dkk668P8PQRQS8TU**((+ KKAAqqkk((+jj

 8A_VZVaVaVhVhijikVlAm7n 	3A3+lUMM[,e[_[e[elopqtupulvw	
 r!   rO   output_hidden_statesrI   rJ   c                     |rdnd }| j                   D ]&  }|r||j                  dddd      fz   } |||      }( ||fS )Nr   r   r   r   r1   rf   )r   	transpose)r   rO   r   rI   hidden_statesstage_modules         r   r    z"FlaxResNetStageCollection.__call__  s]     3 KK 	SL# -1G1G1aQR1S0U U'MRL		S ]**r!   N)FTr"   r#   r$   r   rR   rT   rU   r/   rH   rV   rW   r   r    r   r!   r   r   r   u  sV    {{E399"0 &+"	+kk+ #+ 	+
 
,+r!   r   c                       e Zd ZU eed<   ej                  Zej                  ed<   d Z	 	 	 ddej                  de
de
de
def
d	Zy
)FlaxResNetEncoderrZ   r/   c                 P    t        | j                  | j                        | _        y )Nr/   )r   rZ   r/   r   rG   s    r   rH   zFlaxResNetEncoder.setup  s    /4::Nr!   rO   r   return_dictrI   rJ   c                     | j                  |||      \  }}|r||j                  dddd      fz   }|st        d ||fD              S t        ||      S )N)r   rI   r   r   r   r1   c              3   &   K   | ]	  }||  y wr   r   ).0vs     r   	<genexpr>z-FlaxResNetEncoder.__call__.<locals>.<genexpr>  s     SqQ]Ss   )last_hidden_stater   )r   r   tupler   )r   rO   r   r   rI   r   s         r   r    zFlaxResNetEncoder.__call__  st     '+kk/CS` '2 '
#m  )\-C-CAq!Q-O,QQMS\=$ASSS1*'
 	
r!   N)FTTr   r   r!   r   r   r     sd    {{E399"O &+ "
kk
 #
 	

 
 
,
r!   r   c                       e Zd ZU dZeZdZdZdZe	j                  ed<   ddej                  dfd	ed
edej                  def fdZddej&                  j(                  dededefdZ ee      	 	 	 	 ddededee   dee   fd       Z xZS )FlaxResNetPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    resnetrc   Nmodule_class)r      r   r   r   TrZ   seedr/   _do_initc                      | j                   d||d|}|$d|j                  |j                  |j                  f}t        |   ||||||       y )NrZ   r/   r   )input_shaper   r/   r   r   )r   
image_sizerh   super__init__)	r   rZ   r   r   r/   r   r   module	__class__s	           r   r   z"FlaxResNetPreTrainedModel.__init__  sc     #""H&HHf//1B1BFDWDWXK[tSXcklr!   rngr   paramsrJ   c                 X   t        j                  || j                        }d|i}| j                  j	                  ||d      }|dt        t        |            }t        t        |            }| j                  D ]
  }||   ||<    t               | _        t        t        |            S |S )Nr   r   F)r   )rT   zerosr/   r   initr	   r   _missing_keyssetr   r
   )r   r   r   r   rc   rngsrandom_paramsmissing_keys           r   init_weightsz&FlaxResNetPreTrainedModel.init_weights  s    yyDJJ?#((|(O(-)@AM!(6"23F#11 A&3K&@{#A!$D.011  r!   trainr   r   c           	         ||n| j                   j                  }||n| j                   j                  }t        j                  |d      }i }| j
                  j                  ||d   n| j                  d   ||d   n| j                  d   dt        j                  |t        j                        | ||||rdg      S d      S )N)r   r1   r   r   r   batch_stats)r   r   r   F)r   mutable)
rZ   r   r   rT   r   r   applyr   arrayrU   )r   rc   r   r   r   r   r   s          r   r    z"FlaxResNetPreTrainedModel.__call__  s     %9$D $++JjJj 	 &1%<k$++BYBY}}\<@ {{  .4.@&*dkkRZF[8>8Jvm4PTP[P[\iPj IIl#++6I ',]O ! 
 	
 38 ! 
 	
r!   r   )NFNN)r"   r#   r$   r%   r   config_classbase_model_prefixmain_input_namer   r&   ModulerR   rT   rU   rQ   r/   rW   r   jaxrandomPRNGKeyr   r   r   r   RESNET_INPUTS_DOCSTRINGdictr   r    __classcell__)r   s   @r   r   r     s    
  L $O"L"))"
 %;;mm 	m
 yym m!

 2 2 ! !PZ !fp !$ ++BC /3&*
 
 	

 'tn
 d^
 D
r!   r   c            	       t    e Zd ZU eed<   ej                  Zej                  ed<   d Z	 	 	 d
de	de	de	de
fdZy	)FlaxResNetModulerZ   r/   c                     t        | j                  | j                        | _        t	        | j                  | j                        | _        t        t        j                  d      | _	        y )Nr   )r   r   r   )r8   )
rY   rZ   r/   ra   r   encoderr   r&   avg_poolpoolerrG   s    r   rH   zFlaxResNetModule.setup  sF    ,T[[

K(DJJG KK$
r!   rI   r   r   rJ   c                    ||n| j                   j                  }||n| j                   j                  }| j                  ||      }| j	                  ||||      }|d   }| j                  ||j                  d   |j                  d   f|j                  d   |j                  d   f      j                  dddd      }|j                  dddd      }|s
||f|dd  z   S t        |||j                        S )	Nrf   )r   r   rI   r   r   r1   )r^   r7   r   )r   pooler_outputr   )
rZ   r   use_return_dictra   r   r   rg   r   r   r   )	r   rc   rI   r   r   embedding_outputencoder_outputsr   pooled_outputs	            r   r    zFlaxResNetModule.__call__  s-    %9$D $++JjJj 	 &1%<k$++B]B]==]=S,,!5#'	 ' 
 ,A.+11!46G6M6Ma6PQ&,,Q/1B1H1H1KL $ 
 )Aq!Q
	 	 .771aC%}58KKK;/')77
 	
r!   N)TFT)r"   r#   r$   r   rR   rT   rU   r/   rH   rW   r   r    r   r!   r   r   r   	  sW    {{E399"
 #%* &
 &
 #	&

 &
 
6&
r!   r   zOThe bare ResNet model outputting raw features without any specific head on top.c                       e Zd ZeZy)FlaxResNetModelN)r"   r#   r$   r   r   r   r!   r   r   r   @  s	    
 $Lr!   r   an  
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxResNetModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50")
    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
)output_typer   c                       e Zd ZU eed<   ej                  Zej                  ed<   d Zdej                  dej                  fdZ
y)FlaxResNetClassifierCollectionrZ   r/   c                 z    t        j                  | j                  j                  | j                  d      | _        y )Nr   )r/   r   )r&   DenserZ   
num_labelsr/   
classifierrG   s    r   rH   z$FlaxResNetClassifierCollection.setupf  s%    ((4;;#9#9RUVr!   r   rJ   c                 $    | j                  |      S r   )r   )r   r   s     r   r    z'FlaxResNetClassifierCollection.__call__i  s    q!!r!   N)r"   r#   r$   r   rR   rT   rU   r/   rH   rV   r    r   r!   r   r   r   b  s;    {{E399"W"#++ "#++ "r!   r   c                   j    e Zd ZU eed<   ej                  Zej                  ed<   d Z	 	 	 	 dde	fdZ
y)&FlaxResNetForImageClassificationModulerZ   r/   c                     t        | j                  | j                        | _        | j                  j                  dkD  r't        | j                  | j                        | _        y t               | _        y )Nr   r   r   )r   rZ   r/   r   r   r   r   r   rG   s    r   rH   z,FlaxResNetForImageClassificationModule.setupq  sL    &dkkL;;!!A%<T[[PTPZPZ[DO&jDOr!   NrI   c                    ||n| j                   j                  }| j                  ||||      }|r|j                  n|d   }| j	                  |d d d d ddf         }|s|f|dd  z   }|S t        ||j                        S )N)rI   r   r   r   r   r1   )logitsr   )rZ   r   r   r   r   r   r   )	r   rc   rI   r   r   outputsr   r  outputs	            r   r    z/FlaxResNetForImageClassificationModule.__call__y  s     &1%<k$++B]B]++'!5#	  
 2=--'!*q!Qz!:;Y,FM7vU\UjUjkkr!   )NTNN)r"   r#   r$   r   rR   rT   rU   r/   rH   rW   r    r   r!   r   r  r  m  s>    {{E399") "!l lr!   r  z
    ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                       e Zd ZeZy) FlaxResNetForImageClassificationN)r"   r#   r$   r  r   r   r!   r   r
  r
    s	     :Lr!   r
  a]  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
)8	functoolsr   typingr   r   
flax.linenlinenr&   r   	jax.numpynumpyrT   flax.core.frozen_dictr   r   r   flax.traverse_utilr	   r
   modeling_flax_outputsr   r   r   modeling_flax_utilsr   r   r   r   utilsr   r   configuration_resnetr   RESNET_START_DOCSTRINGr   r   r   r)   rY   rl   rq   r{   r   r   r   r   r   r   r   r   r   FLAX_VISION_MODEL_DOCSTRINGr   r  r
  FLAX_VISION_CLASSIF_DOCSTRINGr   r!   r   <module>r     s     "  
  > > ; 
  Q .! H
 ryy ")) 6299 < 6RYY ""299 "J")) ,(		 (V,bii ,^;bii ;4'+		 '+T
		 
<I
 3 I
X4
ryy 4
n U$/ $	$ ( *E F  !M\h
"RYY "$lRYY $lN  :'@ ::! 6 9;X Y  $2Ziur!   