
    sg&                      d Z ddlmZ ddlZddlmZ ddlmZmZm	Z	 ddl
ZddlZddlmZ ddlmZ dd	lmZmZmZmZmZ dd
lmZmZ ddlmZmZmZmZ ddl m!Z!m"Z"m#Z#m$Z$  ejJ                  e&      Z'dZ(dZ)e G d de             Z*e G d de             Z+ G d dejX                  jZ                        Z. G d dejX                  jZ                        Z/ G d dejX                  jZ                        Z0 G d dejX                  jZ                        Z1 G d dejX                  jZ                        Z2 G d dejX                  jZ                        Z3 G d  d!ejX                  jZ                        Z4 G d" d#ejX                  jZ                        Z5 G d$ d%ejX                  jZ                        Z6 G d& d'ejX                  jZ                        Z7 G d( d)ejX                  jZ                        Z8 G d* d+ejX                  jZ                        Z9 G d, d-ejX                  jZ                        Z: G d. d/ejX                  jZ                        Z; G d0 d1ejX                  jZ                        Z< G d2 d3e      Z=d4Z>d5Z? ed6d7e>       G d8 d9e=             Z@y):z
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
    )annotationsN)	dataclass)OptionalTupleUnion   )ACT2FN)TFBaseModelOutput)TFModelInputTypeTFPreTrainedModelkeras
shape_listunpack_inputs)flattenfunctional_layernorm)ModelOutputadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )	SamConfigSamMaskDecoderConfigSamPromptEncoderConfigSamVisionConfigr   zfacebook/sam-vit-hugec                  J    e Zd ZU dZdZded<   dZded<   dZded<   dZded	<   y)
TFSamVisionEncoderOutputa  
    Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Ntf.Tensor | Noneimage_embeds	tf.Tensorlast_hidden_stateTuple[tf.Tensor, ...] | Nonehidden_states
attentions)	__name__
__module____qualname____doc__r   __annotations__r    r"   r#        Z/var/www/html/venv/lib/python3.12/site-packages/transformers/models/sam/modeling_tf_sam.pyr   r   +   s5    , &*L")#'y'26M/6/3J,3r*   r   c                  X    e Zd ZU dZdZded<   dZded<   dZded<   dZded<   dZ	ded	<   y)
TFSamImageSegmentationOutputa  
    Base class for Segment-Anything model's output

    Args:
        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
            The iou scores of the predicted masks.
        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low resolutions masks. Needs to be post-processed by the processor
        vision_hidden_states  (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions  (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nr   
iou_scores
pred_masksr!   vision_hidden_statesvision_attentionsmask_decoder_attentions)
r$   r%   r&   r'   r.   r(   r/   r0   r1   r2   r)   r*   r+   r-   r-   I   sA    6 !J	  J	 9=6=6:3:<@9@r*   r-   c                  0     e Zd ZdZ fdZd ZddZ xZS )TFSamPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                   t        |   di | |j                  |j                  }}|j                  |j
                  }}t        |t        j                  j                        r|n||f}t        |t        j                  j                        r|n||f}|d   |d   z  |d   |d   z  z  }|| _        || _        || _        || _
        t        j                  j                  |||d      | _        y )Nr   r   
projectionkernel_sizestridesnamer)   )super__init__
image_size
patch_sizenum_channelshidden_size
isinstancecollectionsabcIterablenum_patchesr   layersConv2Dr6   )	selfconfigkwargsr=   r>   r?   r@   rE   	__class__s	           r+   r<   zTFSamPatchEmbeddings.__init__t   s    "6"!'!2!2F4E4EJ
$*$7$79K9Kk#-j+//:R:R#SZZdfpYq
#-j+//:R:R#SZZdfpYq
!!}
15*Q-:VW=:XY$$(&,,--Z, . 
r*   c                V   t        |      \  }}}}|| j                  k7  rt        d      || j                  d   k7  s|| j                  d   k7  r2t        d| d| d| j                  d    d| j                  d    d	      | j	                  t        j                  |g d	            }|S )
NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*z) doesn't match model (z).r      r   r   perm)r   r?   
ValueErrorr=   r6   tf	transpose)rH   pixel_values
batch_sizer?   heightwidth
embeddingss          r+   callzTFSamPatchEmbeddings.call   s    2<\2J/
L&%4,,,w  T__Q''5DOOA4F+F$VHAeW4KDOO\]L^K__`aeapapqras`ttvw  __R\\,\%RS
r*   c                   | j                   ry d| _         t        | dd       \t        j                  | j                  j
                        5  | j                  j                  d d d | j                  g       d d d        y y # 1 sw Y   y xY w)NTr6   )builtgetattrrS   
name_scoper6   r:   buildr?   rH   input_shapes     r+   r_   zTFSamPatchEmbeddings.build   s}    ::
4t,8t334 M%%tT49J9J&KLM M 9M Ms   *A??BN)r$   r%   r&   r'   r<   rZ   r_   __classcell__rK   s   @r+   r4   r4   m   s    
 Mr*   r4   c                  .     e Zd Z fdZddZddZ xZS )TFSamMLPBlockc                "   t        |   di | t        j                  j	                  |j
                  d      | _        t        j                  j	                  |j                  d      | _        t        |j                     | _        || _        y )Nlin1r:   lin2r)   )r;   r<   r   rF   Densemlp_dimrh   r@   rj   r	   
hidden_actactrI   rH   rI   rJ   rK   s      r+   r<   zTFSamMLPBlock.__init__   sl    "6"LL&&v~~F&C	LL&&v'9'9&G	&++,r*   c                l    | j                  |      }| j                  |      }| j                  |      }|S rb   )rh   rn   rj   rH   r"   s     r+   rZ   zTFSamMLPBlock.call   s2    		-0/		-0r*   c                "   | j                   ry d| _         t        | dd       dt        j                  | j                  j
                        5  | j                  j                  d d | j                  j                  g       d d d        t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d | j                  j                  g       d d d        y y # 1 sw Y   |xY w# 1 sw Y   y xY w)NTrh   rj   )r\   r]   rS   r^   rh   r:   r_   rI   r@   rj   rl   r`   s     r+   r_   zTFSamMLPBlock.build   s    ::
4&2tyy~~. G		tT[[-D-D EFG4&2tyy~~. C		tT[[-@-@ ABC C 3G GC Cs   3C9<3D9DDr"   r   returnr   rb   r$   r%   r&   r<   rZ   r_   rc   rd   s   @r+   rf   rf      s    	Cr*   rf   c                  6     e Zd ZdZd fd	Z fdZddZ xZS )TFSamLayerNormaA  LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    c                    t        |   di | || _        || _        || _        | j                  dvrt        d| j                         y )N)channels_lastchannels_firstzUnsupported data format: r)   )r;   r<   epsdata_formatnormalized_shapeNotImplementedError)rH   r}   r{   r|   rJ   rK   s        r+   r<   zTFSamLayerNorm.__init__   sY    "6"& 0#FF%(A$BRBRAS&TUU Gr*   c                    | j                  | j                  dd      | _        | j                  | j                  dd      | _        t        |   |       y )Nonesweightshapeinitializerr:   zerosbias)
add_weightr}   r   r   r;   r_   rH   ra   rK   s     r+   r_   zTFSamLayerNorm.build   sI    ooD,A,Av\doeOO$*?*?W[aOb	k"r*   c                    | j                   dk(  r0t        || j                  | j                  | j                  d      }|S | j                   dk(  r.t        || j                  | j                  | j                  d      }|S )Nry   )r   r   epsilonaxisrz   r   )r|   r   r   r   r{   )rH   xs     r+   rZ   zTFSamLayerNorm.call   sq    .$Qt{{TXT\T\cefA  !11$Qt{{TXT\T\cdeAr*   )gư>ry   )r   r   rt   r   )r$   r%   r&   r'   r<   r_   rZ   rc   rd   s   @r+   rw   rw      s    
V#
r*   rw   c                  D     e Zd ZdZd fd	ZddZd	dZd
dZddZ xZ	S )TFSamAttentionz
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    c                ~   t        |   di | |j                  | _        ||j                  n|}|j                  |z  | _        |j
                  | _        | j                  |j
                  z  dk7  rt        d      t        j                  j                  | j                  d      | _
        t        j                  j                  | j                  d      | _        t        j                  j                  | j                  d      | _        t        j                  j                  | j                  d      | _        y )	Nr   z,num_attention_heads must divide hidden_size.q_projri   k_projv_projout_projr)   )r;   r<   r@   attention_downsample_rateinternal_dimnum_attention_headsrR   r   rF   rk   r   r   r   r   )rH   rI   downsample_raterJ   rK   s       r+   r<   zTFSamAttention.__init__   s    "6"!-->M>U&::[j"../A#)#=#= v999Q>KLLll(():):(Jll(():):(Jll(():):(J**4+;+;**Mr*   c                    t        |      \  }}}}||z  }t        j                  |||z  |||f      }t        j                  |g d      S )Nr   rO   r   r   rP   r   rS   reshaperT   )rH   r"   r   batchpoint_batch_sizen_tokenschannel
c_per_heads           r+   _separate_headszTFSamAttention._separate_heads   sX    5?5N27 33


E$44h@SU_`
 ||M==r*   c                    t        |      \  }}}}t        j                  |g d      }t        j                  ||t        j                  d|g      z  ||||z  f      S )Nr   rP   r   )r   rS   rT   r   
reduce_max)rH   r"   r   r   n_headsr   r   s          r+   _recombine_headszTFSamAttention._recombine_heads   sa    /9-/H,w*]FzzbmmQ(8$9::<LhX_blXlm
 	
r*   c                   | j                  |      }| j                  |      }| j                  |      }t        |      d   }| j	                  || j
                        }| j	                  || j
                        }| j	                  || j
                        }t        |      \  }}}}t        j                  |t        j                  |g d            }|t        j                  j                  t        |            z  }t        j                  j                  |d      }t        j                  ||      }| j                  ||      }| j                  |      }|S )Nr   r   r   r   rO   rP   r   r   )r   r   r   r   r   r   rS   matmulrT   mathsqrtfloatnnsoftmaxr   r   )	rH   querykeyvaluer   _r   attnouts	            r+   rZ   zTFSamAttention.call   s#   E"kk#E"%e,Q/$$UD,D,DE""3(@(@A$$UD,D,DE )/1ayy2<<,7
 bggll5#455uu}}T}+ iie$##C)9:mmC 
r*   c                   | j                   ry d| _         t        | dd       Zt        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        t        | dd       Zt        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        t        | dd       Zt        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        t        | dd       [t        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        y y # 1 sw Y   AxY w# 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   y xY w)NTr   r   r   r   )r\   r]   rS   r^   r   r:   r_   r@   r   r   r   r   r`   s     r+   r_   zTFSamAttention.build  s   ::
44(4t{{//0 B!!4t/?/?"@AB44(4t{{//0 B!!4t/?/?"@AB44(4t{{//0 B!!4t/?/?"@AB4T*6t}}112 E##T41B1B$CDE E 7B BB BB BE Es0   )F32)G )G )G3F= G	GG!rb   )r"   r   r   intrt   r   )r"   r   r   r   rt   r   )r   r   r   r   r   r   rt   r   )
r$   r%   r&   r'   r<   r   r   rZ   r_   rc   rd   s   @r+   r   r      s#    
N >
6Er*   r   c                  H     e Zd Zdd fdZ	 d	 	 	 	 	 	 	 	 	 ddZddZ xZS )	TFSamTwoWayAttentionBlockc                   t        |   di | |j                  | _        |j                  | _        t	        |dd      | _        t        j                  j                  | j                  d      | _	        t	        ||d      | _
        t        j                  j                  | j                  d      | _        t        |d	      | _        t        j                  j                  | j                  d
      | _        t        j                  j                  | j                  d      | _        t	        ||d      | _        || _        y)a  
        A transformer block with four layers:
            (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
            sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optionalk*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        r   	self_attn)r   r:   layer_norm1r   r:   cross_attn_token_to_imagelayer_norm2mlpri   layer_norm3layer_norm4cross_attn_image_to_tokenNr)   )r;   r<   r@   layer_norm_epsr   r   r   rF   LayerNormalizationr   r   r   rf   r   r   r   r   skip_first_layer_pe)rH   rI   r   r   rJ   rK   s        r+   r<   z"TFSamTwoWayAttentionBlock.__init__#  s    	"6"!--$33'T <<::4CVCV]j:k)7$=D_*
& !<<::4CVCV]j:k e4 <<::4CVCV]j:k <<::4CVCV]j:k)7$=D_*
& $7 r*   c                   | j                   r| j                  |||      }n||z   }| j                  |||      }||z   }| j                  |      }||z   }||z   }| j                  |||      }||z   }| j	                  |      }| j                  |      }	||	z   }| j                  |      }||z   }||z   }| j                  |||      }||z   }| j                  |      }||f}
|r|
|fz   }
|
S |
dz   }
|
S )Nr   r   r   rb   )	r   r   r   r   r   r   r   r   r   )rH   querieskeysquery_point_embeddingkey_point_embeddingoutput_attentionsr   attn_outr   mlp_outoutputss              r+   rZ   zTFSamTwoWayAttentionBlock.callH  sE    ##nn7wnOG33E~~EuG~LH(G""7+ //((113d1SH$""7+ ((7#G#""7+ //((11g1Vh%D/+G  'Gr*   c                
   | j                   ry d| _         t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       [t        j                  | j                  j
                        5  | j                  j                  d d d | j                  g       d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       [t        j                  | j                  j
                        5  | j                  j                  d d d | j                  g       d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       [t        j                  | j                  j
                        5  | j                  j                  d d d | j                  g       d d d        t        | dd       [t        j                  | j                  j
                        5  | j                  j                  d d d | j                  g       d d d        t        | d	d       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   xY w# 1 sw Y   _xY w# 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   jxY w# 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   y xY w)
NTr   r   r   r   r   r   r   r   )r\   r]   rS   r^   r   r:   r_   r   r@   r   r   r   r   r   r   r`   s     r+   r_   zTFSamTwoWayAttentionBlock.buildy  s   ::
4d+7t~~223 +$$T*+4-9t//445 M  &&dD$:J:J'KLM44d;Gt==BBC ;..44T:;4-9t//445 M  &&dD$:J:J'KLM4%1txx}}- %t$%4-9t//445 M  &&dD$:J:J'KLM4-9t//445 M  &&dD$:J:J'KLM44d;Gt==BBC ;..44T:; ; H)+ +M M; ;M M% %M MM M; ;s`   L%*L,L9'*MM)*M *M-9M9L),L69MMM M*-M69N)rO   F)r   r   r   boolF)
r   r   r   r   r   r   r   r   r   r   rb   ru   rd   s   @r+   r   r   "  sI    #7V #(// /  )	/
 '/  /b;r*   r   c                  R     e Zd Zd fdZ	 	 	 d	 	 	 	 	 	 	 	 	 	 	 	 	 ddZddZ xZS )TFSamTwoWayTransformerc           	     t   t        |   di | || _        |j                  | _        g | _        t        | j                        D ]/  }| j                  j                  t        ||dk(  d|              1 t        |d      | _	        t        j                  j                  |j                  d      | _        y )	Nr   	layers_._)r   r:   final_attn_token_to_imageri   layer_norm_final_attnr   r)   )r;   r<   rI   num_hidden_layersrF   rangeappendr   r   r   r   r   r   r   )rH   rI   rJ   irK   s       r+   r<   zTFSamTwoWayTransformer.__init__  s    "6"!'!9!9t--. 	vAKK8VW[\V\fopqordstu	v *8E`)a&%*\\%D%D))0G &E &
"r*   c                *   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }d}|t	        d      t        j                  t        |d      d      d d d f   }t        j                  t        |d      d      d d d f   }|}|}	| j                  D ]  }
 |
||	|||      \  }}	}|s||fz   } ||z   }|	|z   }| j                  |||	      }||z   }| j                  |      }||	|fS )Nr)   z&You have to specify an image_embeddingrO   )r   rO   r   rP   )r   r   r   r   r   r   )rI   r   output_hidden_statesuse_return_dictrR   rS   rT   r   rF   r   r   )rH   point_embeddingsimage_embeddingsimage_positional_embeddingsr   r   return_dictall_attentionsr   r   layerattention_outputsr   r   r   s                  r+   rZ   zTFSamTwoWayTransformer.call  sb    2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]#EFF<<0@!(D9UVWY]V]^&(ll7;VXY3Z\e&fghjngn&o# # [[ 
	GE/4&6$?"30,GT, !!/3D2F!F
	G **00113d1SH$,,W5n,,r*   c                   | j                   ry d| _         t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d d | j                  j                  g       d d d        | j                  D ];  }t        j                  |j
                        5  |j                  d        d d d        = y # 1 sw Y   xY w# 1 sw Y   `xY w# 1 sw Y   `xY w)NTr   r   )r\   r]   rS   r^   r   r:   r_   r   rI   r@   rF   rH   ra   r   s      r+   r_   zTFSamTwoWayTransformer.build  s   ::
44d;Gt==BBC ;..44T:;40$7Ct99>>? ^**00$dDKKD[D[1\]^[[ 	"Euzz* "D!" "	"; ;^ ^" "s$   D,%4D8E,D58EE	rI   r   NNN)r   r   r   r   r   r   r   Optional[bool]r   r   r   r   rt   zUnion[Tuple, TFBaseModelOutput]rb   ru   rd   s   @r+   r   r     sb    
( -1/3&*0-#0- $0- &/	0-
 *0- -0- $0- 
)0-d"r*   r   c                  D     e Zd Z	 d	 	 	 	 	 	 	 	 	 d fdZd ZddZ xZS )TFSamFeedForwardc           	        t        |   di | || _        t        j                  j                         | _        t        j                  j                  ||fd      | _        t        j                  j                  ||fd      | _	        t        |dz
        D cg c](  }t        j                  j                  ||fd|       * c}| _        || _        || _        || _        y c c}w )Nproj_in)ra   r:   proj_outrO   r   r)   )r;   r<   
num_layersr   rF   ReLU
activationrk   r   r   r   sigmoid_output
hidden_dim	input_dim)	rH   r   r   
output_dimr   r   rJ   r   rK   s	           r+   r<   zTFSamFeedForward.__init__  s     	"6"$,,++-||))*9,U^)_**:J=Wa*b :>*
 LLz
}YWXVY?[
 -$"
s   -C&c                    | j                  |      }| j                  |      }| j                  D ]  }| j                   ||            } | j                  |      }| j                  rt        j                  |      }|S rb   )r   r   rF   r   r   rS   sigmoid)rH   r"   r   s      r+   rZ   zTFSamFeedForward.call  ss    ]36[[ 	BE OOE-,@AM	B m4JJ}5Mr*   c                   | j                   ry d| _         t        | dd       Zt        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        t        | dd       Zt        j                  | j                  j
                        5  | j                  j                  d d | j                  g       d d d        t        | dd       X| j                  D ]H  }t        j                  |j
                        5  |j                  d d | j                  g       d d d        J y y # 1 sw Y   xY w# 1 sw Y   {xY w# 1 sw Y   nxY w)NTr   r   rF   )r\   r]   rS   r^   r   r:   r_   r   r   r   rF   r   s      r+   r_   zTFSamFeedForward.build  s1   ::
4D)5t||001 A""D$#?@A4T*6t}}112 C##T4$ABC44(4 ?]]5::. ?KKtT__ =>? ?? 5A AC C? ?s$   )E	2)EE!	EE!E*	r   )
r   r   r   r   r   r   r   r   r   r   rb   ru   rd   s   @r+   r   r     s=    hm##*-#;>#LO#ae# 	?r*   r   c                  N     e Zd Zd fdZddZ	 d	 	 	 	 	 	 	 	 	 	 	 	 	 ddZ xZS )TFSamMaskDecoderc           
     :   t        |   di | |j                  | _        |j                  | _        |j                  dz   | _        t        |d      | _        t        j                  j                  | j                  dz  dddd      | _
        t        j                  j                  | j                  d	z  ddd
d      | _        t        | j                  dz  dd      | _        t        j                  j                   | _        g }t%        | j                        D ]:  }|t'        | j                  | j                  | j                  d	z  dd|       gz  }< || _        t'        | j                  |j*                  | j                  |j,                  d      | _        y )Nr   transformerri      rO   upscale_conv1rz   )r8   r9   r:   r|      upscale_conv2upscale_layer_norm)r|   r:   r   zoutput_hypernetworks_mlps_._iou_prediction_headr)   )r;   r<   r@   num_multimask_outputsnum_mask_tokensr   r   r   rF   Conv2DTransposer  r  rw   r  rS   r   gelur   r   r   output_hypernetworks_mlpsiou_head_hidden_dimiou_head_depthr  )rH   rI   rJ   	mlps_listr   rK   s        r+   r<   zTFSamMaskDecoder.__init__  s   "6"!--%+%A%A"%;;a?1&}M"\\99!q!/_o : 
 #\\99!q!/_o : 
 #1!/?FZ#
 %%**	t++, 		A $$$$$$)7s; I		 *3&#3&&  !!&$
 r*   c                   | j                   ry d| _         | j                  d| j                  fdd      | _        | j                  | j                  | j                  fdd      | _        t        | dd       Mt        j                  | j                  j                        5  | j                  j                  d        d d d        t        | dd       [t        j                  | j                  j                        5  | j                  j                  d | j                  d d g       d d d        t        | dd       ^t        j                  | j                  j                        5  | j                  j                  d | j                  d	z  d d g       d d d        t        | d
d       Mt        j                  | j                  j                        5  | j                  j                  d        d d d        t        | dd       Mt        j                  | j                  j                        5  | j                  j                  d        d d d        | j                   D ];  }t        j                  |j                        5  |j                  d        d d d        = y # 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   #xY w# 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   xY w)NTr   ziou_token.weight)r   r:   	trainablezmask_tokens.weightr   r  r  r   r  r  )r\   r   r@   	iou_tokenr  mask_tokensr]   rS   r^   r   r:   r_   r  r  r  r  r
  )rH   ra   r   s      r+   r_   zTFSamMaskDecoder.build?  so   ::
43C3C/DK]imn??'')9)9:AUae + 
 4-9t//445 -  &&t,-4$/;t11667 O""(($0@0@$)MNO4$/;t11667 T""(($0@0@A0EtT)RST4-t4@t66;;< 4''--d344.5At77<<= 5((..t4511 	 Csxx(  		$   	 - -O OT T4 45 5   sH   J9*J"!-J/J<&K8KJ"J,/J9<KKK	c           
        t        |      \  }}}	}
t        j                  j                  dt        j                  |      d         }t        j
                  | j                  | j                  gd      }t        j                  |d d d d f   ||ddg      }t        |      d   dk7  rt        j
                  ||fd      }n|}t        j                  || j                  j                        }||z   }t        j                  ||d      }t        j                  ||d      }| j                  ||||      \  }}}|d d d d dd d f   }|d d d d dd| j                  z   d d f   }t        j                  |d      }t        j                  |||z  ||	|
g      }| j!                  |      }| j#                  | j%                  |            }| j#                  | j'                  |            }g }t)        | j                        D ]*  }| j*                  |   }| ||d d d d |d d f         gz  }, t        j,                  |d      }t        |      \  }}}	}
t        j                  |||||	|
z  g      }t        j                  ||z  ||d|	|
g      }| j/                  |      }|rt1        dd       }nt1        dd      }|d d d d |d d d d f   }|d d d d |f   }||f}|r||fz   }|S |d	z   }|S )
Nr   r   r   rO   )r   r   r   r   r   rP   r   rb   )r   rS   r   maximumr   concatr  r  tilecastdtyperepeatr   r  rT   r   r  r   r  r  r   r
  stackr  slice)rH   r   r   sparse_prompt_embeddingsdense_prompt_embeddingsmultimask_outputr   rV   r?   rW   rX   r   output_tokenstokensr   point_embeddingr#   iou_token_outmask_tokens_outupscaled_embeddinghyper_in_listr   current_mlphyper_inr   masksiou_pred
mask_slicer   s                                r+   rZ   zTFSamMaskDecoder.call[  s&    3==M2N/
L&%77??1bhh7O.PQR.ST		4>>43C3C"D1M$a-(:7GA*N
 ./2a7YY/GHqQF"F7764>>+?+?@+.EE99%57GaP&(ii0KM]de&f#8<8H8H--(C/	 9I 9
5): (1a
3)!QQ9M9M5M0NPQ*QR<<(8|L::&6FV9VXdflns8tu!//0@A!__T-D-DEW-XY!__T-?-?@R-STt++, 	HA88;Kk/!Q1**EFGGM	H 88M2)34F)G&<ZZ-=|VV[^ \
 

8&88:GWY[]cej:kl++M:q$Jq!JaJ1,-Aq*,-(#-G  'Gr*   r   rb   )r   r   r   r   r  r   r  r   r  r   r   r   rt   Tuple[tf.Tensor, tf.Tensor]r$   r%   r&   r<   r_   rZ   rc   rd   s   @r+   r   r     se    (
T F -1J#J &/J #,	J
 "+J J *J 
%Jr*   r   c                  0     e Zd Z fdZ fdZddZ xZS )TFSamPositionalEmbeddingc                Z    t        |   di | |j                  dz  | _        || _        y )NrO   r)   )r;   r<   r@   scalerI   ro   s      r+   r<   z!TFSamPositionalEmbedding.__init__  s,    "6"''1,
r*   c                    | j                  dd| j                  j                  ft        j                  j                  d| j                        d      | _        t        | %  |       y )Npositional_embeddingrO           meanstddevFr:   r   r   r  )
r   rI   num_pos_featsr   initializersRandomNormalr/  r1  r;   r_   r   s     r+   r_   zTFSamPositionalEmbedding.build  s\    $(OO'dkk//0**77S7T	 %4 %
! 	k"r*   c           
        t        j                  |      }|t        j                  t        j                  |dddddddf   t         j                        |d   z  t        j                  |dddddddf   t         j                        |d   z  gd      }d|z  dz
  }t        j                  || j
                  j                        }t        j                  || j
                        }dt        j                  z  |z  }t        j                  t        j                  |      t        j                  |      gd      S )z8Positionally encode points that are normalized to [0,1].Nr   r   r   r   rO   )rS   identityr  r  float32r1  r  r   nppir  sincos)rH   input_coordsra   coordinatess       r+   rZ   zTFSamPositionalEmbedding.call  s    kk,/"((GGK1a
3RZZ@;q>QGGK1a
3RZZ@;q>Q K +o)ggk4+D+D+J+JKiiT-F-FG"%%i+-yy"&&-rvvk/BC"MMr*   rb   r+  rd   s   @r+   r-  r-    s    
#Nr*   r-  c                  .     e Zd Zd fdZd ZddZ xZS )TFSamMaskEmbeddingc                V   t        |   di | |j                  dz  | _        t        |j                     | _        t        j                  j                  | j                  ddd      | _	        t        j                  j                  |j                  ddd      | _
        t        j                  j                  |j                  dd      | _        t        | j                  |j                  d	
      | _        t        | j                  dz  |j                  d
      | _        || _        y )Nr   rO   conv1r7   conv2r   conv3)r8   r:   r   ri   r   r)   )r;   r<   mask_input_channelsr	   rm   r   r   rF   rG   rF  rG  r@   rH  rw   r   r   r   rI   ro   s      r+   r<   zTFSamMaskEmbedding.__init__  s    "6"#)#=#=#B  !2!23\\(()A)AqZ[bi(j
\\(()C)CQR\]dk(l
\\((););QX(Y
)$*B*BFDYDY`mn)$*B*BQ*FH]H]dqrr*   c                P   t        j                  |d      }| j                  |      }| j                  |      }| j	                  |      }| j                  |      }| j                  |      }| j	                  |      }| j                  |      }t        j                  |d      }|S )NrN   rP   r   r   r   rO   )rS   rT   rF  r   r   rG  r   rH  )rH   r'  r"   dense_embeddingss       r+   rZ   zTFSamMaskEmbedding.call  s    U6

5)((76

=1((76::m4<<(8|Lr*   c                `   | j                   ry d| _         t        j                  d      5  | j                  j	                  g d       d d d        t        j                  d      5  | j
                  j	                  d d d | j                  g       d d d        t        j                  d      5  | j                  j	                  d d d | j                  dz  g       d d d        t        j                  d      5  | j                  j	                  d d d | j                  g       d d d        t        j                  d      5  | j                  j	                  d d d | j                  dz  g       d d d        y # 1 sw Y   -xY w# 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   yxY w# 1 sw Y   y xY w)	NTrF  )NNNr   rG  rH  r   r   r   )
r\   rS   r^   rF  r_   rG  rI  rH  r   r   r`   s     r+   r_   zTFSamMaskEmbedding.build  st   ::
]]7# 	4JJ23	4]]7# 	KJJdD$0H0HIJ	K]]7# 	OJJdD$0H0H10LMN	O]]=) 	Q""D$d6N6N#OP	Q]]=) 	U""D$d6N6NQR6R#ST	U 	U	4 	4	K 	K	O 	O	Q 	Q	U 	Us;   E3%*F ,-F6*F=-F$3E= F	FF!$F-rI   r   rb   ru   rd   s   @r+   rD  rD    s    	 Ur*   rD  c                  X     e Zd Zd fdZddZddZd	dZ	 	 	 	 	 	 	 	 	 	 	 	 d
dZ xZS )TFSamPromptEncoderc                   t        |   di | || _        t        |d      | _        d | _        |j                  |j                  f| _        |j                  | _        g | _	        |j                  | _
        d | _        || _        y )N
mask_embedri   r)   )r;   r<   shared_embeddingrD  rR  no_mask_embedimage_embedding_sizer=   input_image_sizepoint_embedr@   not_a_point_embedrI   )rH   rI   shared_patch_embeddingrJ   rK   s       r+   r<   zTFSamPromptEncoder.__init__  s}    "6" 6,V,G!%+%@%@&B]B]$^! & 1 1!--!%r*   c                
   | j                  dd| j                  ft        j                  j	                  dd      d      | _        t        | j                  j                        D cg c]F  }| j                  d| d	d| j                  ft        j                  j	                  dd      d      H c}| _	        | j                  d
d| j                  ft        j                  j	                  dd      d      | _
        t        j                  d      5  | j                  j                  d | j                  j                  | j                  j                   | j                  j                   f       d d d        | j"                  ry d| _        t%        | dd       Nt        j                  | j                  j&                        5  | j                  j                  d        d d d        y y c c}w # 1 sw Y   ~xY w# 1 sw Y   y xY w)Nzno_mask_embed.weightr   r2  g{Gz?r3  Tr6  zpoint_embed_._z.weightznot_a_point_embed.weightrR  )r   r@   r   r8  r9  rT  r   rI   num_point_embeddingsrW  rX  rS   r^   rR  r_   rI  r=   r\   r]   r:   )rH   ra   r   s      r+   r_   zTFSamPromptEncoder.build  s   !__'d&&'**77S7N	 - 
 4;;;;<
  OO%aS0$**+!..;;T;R	  
 "&+d&&'**77S7N	 "1 "
 ]]<( 	OO!!t{{668N8NPTP[P[PfPfg	 ::
4t,8t334 ,%%d+, , 91
	 	, ,s    'AG(AG-G9-G69Hc                p   |dz   }|rt        |      d   t        |      d   dt        |      d   f}t        |      d   t        |      d   df}t        j                  ||j                        }t        j                  ||j                         }t        j
                  ||gd      }t        j
                  ||gd      }| j                  | j                  f}| j                  ||      }	t        j                  |d   dk(  | j                  d   |	      }	t        j                  |d   d	k7  |	t        j                  |	            }	t        j                  |dk(  d
d
d
d
d
d
d
f   |	| j                  d   z   |	      }	t        j                  |dk(  d
d
d
d
d
d
d
f   |	| j                  d   z   |	      }	|	S )zEmbeds point prompts.      ?r   r   r   r  rO   r   ).NiN)r   rS   r   r  r   r  rV  rS  whererX  
zeros_likerW  )
rH   pointslabelspadtarget_point_shapetarget_labels_shapepadding_pointpadding_labelra   r   s
             r+   _embed_pointsz TFSamPromptEncoder._embed_points*  s   #",V"4Q"7F9KA9NPQS]^dSefhSi!j#-f#5a#8*V:LQ:OQR"SHH%7v||LMWW%8MMMYY6Q?FYY6Q?F,,d.C.CD//D((6)#4#:D<R<RST<UWfg((9$MM/*

 ((q[!Q4-(/D<L<LQ<O*OQ`
 ((q[!Q4-(/D<L<LQ<O*OQ`
 r*   c                   |dz   }t        |      dd \  }}t        j                  |||ddf      }| j                  | j                  f}| j	                  ||      }|t        j
                  t        j                  t        |      d         dddddf   dk(  | j                  d   d   | j                  d   d         z  }|S )zEmbeds box prompts.r]  NrO   r   r   )r   rS   r   rV  rS  r_  r   rW  )rH   boxesrV   nb_boxescoordsra   corner_embeddings          r+   _embed_boxeszTFSamPromptEncoder._embed_boxesF  s    )%0!4
HEJ!Q#?@,,d.C.CD00EBHHHHZ 01!45dD!T6IJaOQ"Q"
 	

  r*   c                   d}||t        |      dd \  }}|t        d      | j                  |||du       }t        j                  ||d| j
                  f|j                        }t        j                  ||gd      }|=t        |      d   }| j                  |      }	||	}nt        j                  ||	gd      }|| j                  |      }
nY| j                  d   }
t        j                  |
d      }
t        j                  |
|d	| j                  d   | j                  d	   f      }
|/t        j                  |dd	| j
                  f|
j                        }||
fS )
al  
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            points (`tf.Tensor`, *optional*):
                point coordinates and labels to embed.
            boxes (`tf.Tensor`, *optional*):
                boxes to embed
            masks (`tf.Tensor`, *optional*):
                masks to embed
        NrO   z5If points are provided, labels must also be provided.)rc  r   r^  r   )r   r   r   r   r   )r   rR   rh  rS   r   r@   r  r  rn  rR  rT  r   r  rU  )rH   rV   input_pointsinput_labelsinput_boxesinput_maskssparse_embeddingsr   r   box_embeddingsrL  s              r+   rZ   zTFSamPromptEncoder.callT  s   & !#+5l+CBQ+G(J(# !XYY#11,S^bfSf1h "-q$2B2BCK[KaKa! !#		+<>N*OVW X"#K03J!..{;N ($2!$&II/@..QXY$Z!"#{;#11!4!zz*:MJ!ww :q$2K2KA2NPTPiPijkPl"m  $ "*aD<L<L)MUeUkUk l "222r*   rN  rb   )ra  r   rb  r   rc  r   rt   r   )rj  r   rt   r   )rV   zOptional[int]rp  z%Optional[Tuple[tf.Tensor, tf.Tensor]]rq  r   rr  r   rs  r   rt   r*  )	r$   r%   r&   r<   r_   rh  rn  rZ   rc   rd   s   @r+   rP  rP    sW    !,F8 /3!/3 </3 '	/3
 &/3 &/3 
%/3r*   rP  c                  `     e Zd ZdZ fdZddZddZ	 	 	 	 	 	 	 	 	 	 	 	 	 	 d	dZd
ddZ xZ	S )TFSamVisionAttentionz=Multi-head Attention block with relative position embeddings.c                ~   t        |   d	i | |dk(  r2|j                  |j                  z  |j                  |j                  z  fn||f}|| _        |j
                  | _        |j                  |j
                  z  }|| _        |dz  | _        |j                  | _
        t        j                  j                  |j                  dz  |j                  d      | _        t        j                  j                  |j                  d      | _        |j"                  | _        | j"                  r|t%        d      || _        y )
Nr   g      r   qkv)use_biasr:   projri   zBInput size must be provided if using relative positional encoding.r)   )r;   r<   r=   r>   
input_sizer   r@   head_dimr/  attention_dropoutdropoutr   rF   rk   qkv_biasry  r{  use_rel_posrR   rI   )rH   rI   window_sizerJ   r|  r}  rK   s         r+   r<   zTFSamVisionAttention.__init__  s    "6" a &"3"33V5F5F&J[J[5[\{+ 	
 %#)#=#= %%)C)CC t^
//<<%%f&8&81&<v]b%cLL&&v'9'9&G	!--! !effr*   c                   | j                   p| j                  d| j                   d   z  dz
  | j                  fdd      | _        | j                  d| j                   d   z  dz
  | j                  fdd      | _        | j
                  ry d| _        t        | d	d       dt        j                  | j                  j                        5  | j                  j                  d d | j                  j                  g       d d d        t        | d
d       et        j                  | j                  j                        5  | j                  j                  d d | j                  j                  g       d d d        y y # 1 sw Y   |xY w# 1 sw Y   y xY w)NrO   r   r   r   	rel_pos_hr   	rel_pos_wTry  r{  )r|  r   r}  r  r  r\   r]   rS   r^   ry  r:   r_   rI   r@   r{  r`   s     r+   r_   zTFSamVisionAttention.build  sX   ??&!__4??1--14==Aw]h - DN "__4??1--14==Aw]h - DN ::
4%1txx}}- FdDKK,C,CDEF4&2tyy~~. G		tT[[-D-D EFG G 3F FG Gs   3E583F5E>F
c                   t        dt        ||      z  dz
        }|j                  d   |k7  rnt        j                  j                  t        j                  |d|j                  d   df      ||j                  d   fd      }t        j                  |d|f      }n|}t        j                  t        j                  |t        j                        d      t        ||z  d      z  }t        j                  t        j                  |t        j                        d      t        ||z  d      z  }||z
  |dz
  t        ||z  d      z  z   }t        j                  |t        j                  |t        j                              S )	a  
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`tf.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        rO   r   r   r   bilinear)sizemethodr^  g      ?)r   maxr   rS   imageresizer   expand_dimsr   r<  gatherr  int32)	rH   q_sizek_sizerel_posmax_rel_distrel_pos_resizedq_coordsk_coordsrelative_coordss	            r+   get_rel_posz TFSamVisionAttention.get_rel_pos  s=     1s6622Q67==|+ hhoo

7Qa(8"$=>"GMM!$45! . O
 !jj2|:LMO%O >>"((6"DaH3vX^`cKdd>>"((6"DaH3vX^`cKdd#h.6A:Vf_VYAZ2ZZyy"''/288*LMMr*   c                   |\  }}|\  }	}
| j                  ||	|      }| j                  ||
|      }t        |      \  }}}t        j                  |||||f      }t        j                  d||      }t        j                  d||      }t        j                  |||||	|
f      }|t        j
                  |d      z   t        j
                  |d      z   }t        j                  ||||z  |	|
z  f      }|S )a  
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`tf.Tensor`):
                attention map.
            query (`tf.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`tf.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`tf.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`tf.Tensor`):
                attention map with added relative positional embeddings.
        zbhwc,hkc->bhwkzbhwc,wkc->bhwkr   r   )r  r   rS   r   einsumr  )rH   r   r   r  r  r  r  query_heightquery_width
key_height	key_widthrelative_position_heightrelative_position_widthrV   r   dimreshaped_queryrel_hrel_ws                      r+   add_decomposed_rel_posz+TFSamVisionAttention.add_decomposed_rel_pos  s    > %+!k &
I#'#3#3L*i#X "&"2"2;	9"U'.
AsEJkSV+WX		*N<TU		*N<STzz$\;
T] ^_bnnU44r~~eRT7UUzz$\K-GV_I_ `ar*   c           	        t        |      \  }}}}t        j                  | j                  |      |||z  d| j                  df      }t        j
                  |d      }t        j                  t        j                  |d|| j                  z  ||z  df      d      \  }	}
}t        j                  |	| j                  z  |
d      }| j                  r.| j                  ||	| j                  | j                  ||f||f      }t        j                  j                  |d      }|r,t        j                  j                  || j                  	      }n|}t        j                  ||z  || j                  ||df      }t        j
                  |d
      }t        j                  ||||| j                   j"                  f      }| j%                  |      }|r||f}|S |d f}|S )Nr   r   )rO   r   r   r   r   rP   r   r   T)transpose_b)rate)r   rO   r   r   r   )r   rS   r   ry  r   rT   unstackr   r/  r  r  r  r  r   r   r  rI   r@   r{  )rH   r"   r   trainingrV   rW   rX   r   ry  r   r   r   attn_weights
attn_probsattn_outputr   s                   r+   rZ   zTFSamVisionAttention.call  s   '1-'@$
FE1jj-0:v~qRVRjRjln2opll3_5JJJJsQ
T-E-E EvPU~WYZ[bc
sE yy!3SdK66eT^^T^^fe_W]_dVeL uu}}\};|$,,GJ%Jjje!3j$BZBZ\bdikm5noll;_Ejjz65$++JaJa.bcii,"L1G  #D)Gr*   rb   )r  r   r  r   r  r   rt   r   )r   r   r   r   r  r   r  r   r  Tuple[int, int]r  r  rt   r   FFrs   )
r$   r%   r&   r'   r<   r_   r  r  rZ   rc   rd   s   @r+   rw  rw    sl    G0G("NH++ + 	+
 +  +  + 
+Z" "r*   rw  c                  f     e Zd Z fdZddZ	 	 	 	 	 	 	 	 	 	 ddZ	 	 d	 	 	 	 	 	 	 d	dZd
dZ xZS )TFSamVisionLayerc                J   t        |   di | t        j                  j	                  |j
                  d      | _        t        ||d      | _        t        j                  j	                  |j
                  d      | _	        t        |d      | _        || _        || _        y )Nr   r   r   ri   r   r   r)   )r;   r<   r   rF   r   r   r   rw  r   r   rf   r   r  rI   )rH   rI   r  rJ   rK   s       r+   r<   zTFSamVisionLayer.__init__,  s    "6" <<::6CXCX_l:m(6J	 <<::6CXCX_l:m e4&r*   c           	     l   t        |      \  }}}}|||z  z
  |z  }|||z  z
  |z  }|dkD  s|dkD  r"t        j                  |ddgd|gd|gddgg      }||z   ||z   }
}	t        j                  |||	|z  ||
|z  ||g      }t        j                  t        j                  |g d      d|||g      }||	|
ffS )Nr   r   r   r   rO   r      rP   r   )r   rS   rc  r   rT   )rH   r"   r  rV   rW   rX   r   pad_hpad_w
pad_height	pad_widthwindowss               r+   window_partitionz!TFSamVisionLayer.window_partition5  s    -7-F*
FE7v33{Bu{22kA19	FF=Aq6Au:5zTUWXSY2Z[M &I


{2KkAY[fhop
 **LL-?@2{T_ahBi
 Y///r*   c           	     *   |\  }}|\  }}t        |      d   ||z  |z  |z  z  }	t        j                  ||	||z  ||z  ||dg      }
t        j                  t        j                  |
g d      |	||dg      }
||kD  s||kD  r|
d d d |d |d d f   }
|
S )Nr   r   r  rP   r   )rH   r  r  padding_shapeoriginal_shaper  r  rW   rX   rV   r"   s              r+   window_unpartitionz#TFSamVisionLayer.window_unpartitionG  s     !.
I&(+
Y0F+0UYd0de


j*";Y+=UWbdoqst
 

LL-?@:z[dfhBi
 )e"3)!WfWfufa*?@Mr*   c                   |}| j                  |      }| j                  dkD  r=|j                  d   |j                  d   }}| j                  || j                        \  }}| j	                  |||      \  }}| j                  dkD  r | j                  || j                  f      }||z   }| j                  |      }	|| j                  |	      z   }|f}
|r|
|fz  }
|
S )Nr   r   rO   )r"   r   r  )r   r  r   r  r   r  r   r   )rH   r"   r   r  residualrW   rX   r  r  layernorm_outputr   s              r+   rZ   zTFSamVisionLayer.callX  s    !((7a)//2M4G4G4JEF+/+@+@PTP`P`+a(M=&*ii'/ '0 '
#|
 a 33M4CSCSUbekmrdstM =0++M:%1A(BB "&Gr*   c                   | j                   ry d| _         t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d d | j                  j                  g       d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d d | j                  j                  g       d d d        t        | dd       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   2xY w# 1 sw Y   xY w# 1 sw Y   ~xY w# 1 sw Y   y xY w)NTr   r   r   r   )r\   r]   rS   r^   r   r:   r_   rI   r@   r   r   r   r`   s     r+   r_   zTFSamVisionLayer.buildw  s{   ::
4-9t//445 T  &&dD$++:Q:Q'RST4&2tyy~~. &		%&4-9t//445 T  &&dD$++:Q:Q'RST4%1txx}}- %t$% % 2T T& &T T% %0   4F/=F<4G	G/F9<GGG)r"   r   r  r   rt   z!Tuple[tf.Tensor, Tuple[int, int]])
r  r   r  r   r  r  r  r  rt   r   r  )r"   r   r   r   r  r   rt   zTuple[tf.Tensor]rb   )	r$   r%   r&   r<   r  r  rZ   r_   rc   rd   s   @r+   r  r  +  sr    0$ /2CRds	( -2#(	  * !	
 
>%r*   r  c                  .     e Zd Zd fdZd ZddZ xZS )TFSamVisionNeckc                l   t        |   di | || _        t        j                  j                  |j                  ddd      | _        t        |j                  d      | _	        t        j                  j                  |j                  dddd	
      | _
        t        |j                  d      | _        y )Nr   FrF  )r8   rz  r:   r   ri   r   samerG  )r8   paddingrz  r:   r   r)   )r;   r<   rI   r   rF   rG   output_channelsrF  rw   r   rG  r   ro   s      r+   r<   zTFSamVisionNeck.__init__  s    "6"\\((""	 ) 

 *&*@*@}U\\(("" ) 

 *&*@*@}Ur*   c                    | j                  |      }| j                  |      }| j                  |      }| j                  |      }t	        j
                  |g d      }|S )NrK  rP   )rF  r   rG  r   rS   rT   rq   s     r+   rZ   zTFSamVisionNeck.call  sT    

=1((7

=1((7]Fr*   c                   | j                   ry d| _         t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d d | j                  j                  g       d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       et        j                  | j                  j
                        5  | j                  j                  d d d | j                  j                  g       d d d        t        | dd       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   2xY w# 1 sw Y   xY w# 1 sw Y   ~xY w# 1 sw Y   y xY w)NTrF  r   rG  r   )r\   r]   rS   r^   rF  r:   r_   rI   r@   r   rG  r  r   r`   s     r+   r_   zTFSamVisionNeck.build  s   ::
4$'3tzz/ N

  $dDKK4K4K!LMN4-9t//445 -  &&t,-4$'3tzz/ R

  $dDKK4O4O!PQR4-9t//445 -  &&t,- - :N N- -R R- -r  rI   r   rb   ru   rd   s   @r+   r  r    s    V(-r*   r  c                  X     e Zd Zd fdZddZd Z	 	 	 	 	 d	 	 	 	 	 	 	 	 	 	 	 ddZ xZS )	TFSamVisionEncoderc                x   t        |   di | || _        |j                  | _        t	        |d      | _        d | _        g | _        t        |j                        D ]H  }t        |||j                  vr|j                  ndd|       }| j                  j                  |       J t        |d      | _        y )Npatch_embedri   r   r   )r  r:   neckr)   )r;   r<   rI   r=   r4   r  	pos_embedrF   r   r   r  global_attn_indexesr  r   r  r  )rH   rI   rJ   r   r   rK   s        r+   r<   zTFSamVisionEncoder.__init__  s    "6" ++/]Kv//0 	&A$236;U;U2UF..[\ _E
 KKu%	& $F8	r*   c                   | j                   ry d| _         | j                  j                  r| j                  d| j                  j                  | j                  j
                  z  | j                  j                  | j                  j
                  z  | j                  j                  gddd      | _        t        | dd       Mt        j                  | j                  j                        5  | j                  j                  d        d d d        t        | dd       Mt        j                  | j                  j                        5  | j                  j                  d        d d d        | j                  D ];  }t        j                  |j                        5  |j                  d        d d d        = y # 1 sw Y   xY w# 1 sw Y   `xY w# 1 sw Y   `xY w)NTr   r   r  )r   r   r  r:   r  r  )r\   rI   use_abs_posr   r=   r>   r@   r  r]   rS   r^   r  r:   r_   r  rF   r   s      r+   r_   zTFSamVisionEncoder.build  sp   ::
;;""!__KK**dkk.D.DDKK**dkk.D.DDKK++	 $  - 
DN 4-9t//445 -  &&t,-4&2tyy~~. &		%&[[ 	"Euzz* "D!" "	"- -& &" "s$   )F2F>G
2F;>G
G	c                    | j                   S rb   )r  rH   s    r+   get_input_embeddingsz'TFSamVisionEncoder.get_input_embeddings  s    r*   c                &   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|t	        d      | j                  |      }| j                  || j                  z   }|rdnd }|rdnd }t        | j                        D ])  \  }	}
|r||fz   } |
|||      }|d   }|s!||d   fz   }+ |r||fz   }| j                  |      }|s|f}|r||fz   }|r||fz   }|S t        |||      S )Nz You have to specify pixel_valuesr)   )r   r  r   r   )r    r"   r#   )rI   r   r   r   rR   r  r  	enumeraterF   r  r   )rH   rU   r   r   r   r  r"   all_hidden_statesall_self_attentionsr   layer_modulelayer_outputsr   s                r+   rZ   zTFSamVisionEncoder.call  sh    2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]?@@((6>>%)DNN:M"6BD$5b4(5 		POA|#$58H$H!(J[fnoM)!,M &9]1=M<O&O#		P   1]4D D		-0$&G#!%6$88 !%8$::N'++*
 	
r*   r  rb   )NNNNF)rU   r   r   r   r   r   r   r   r  r   rt   z&Union[Tuple, TFSamVisionEncoderOutput])r$   r%   r&   r<   r_   r  rZ   rc   rd   s   @r+   r  r    sb    9("8 
 *.,0/3&*#(4
&4
 *4
 -	4

 $4
 !4
 
04
r*   r  c                      e Zd ZeZdZdZy)TFSamPreTrainedModelsamrU   N)r$   r%   r&   r   config_classbase_model_prefixmain_input_namer)   r*   r+   r  r  $  s    L$Or*   r  aF  
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
    general usage and behavior.

    Parameters:
        config ([`SamConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
a6  
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
            details.
        input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
            dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
            input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically done by the processor.
        input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
            the number of boxes per image and the coordinates of the top left and botton right point of the box. In the
            order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box

        input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).

        image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
            method, and then feed them to the `call` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zYSegment Anything Model (SAM) for generating segmentation masks, given an input image and z) optional 2D location and bounding boxes.c                       e Zd ZdgZ fdZd Zd Z	 	 	 d
	 	 	 	 	 ddZ	 	 	 	 d	 	 	 	 	 	 	 ddZe	 e
e      	 	 	 	 	 	 	 	 	 	 	 d	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 dd              ZddZdd	Z xZS )
TFSamModelz4prompt_encoder.shared_embedding.positional_embeddingc                *   t        |   |fi | t        |j                  d      | _        t        |j                  d      | _        t        |j                  | j                  d      | _	        t        |j                  d      | _        || _        y )Nshared_image_embeddingri   vision_encoderprompt_encodermask_decoder)r;   r<   r-  vision_configr  r  r  rP  prompt_encoder_configr  r   mask_decoder_configr  rI   ro   s      r+   r<   zTFSamModel.__init__  s    *6*&>v?S?SZr&s#01E1EL\]0(($*E*EL\
 -V-G-Gn]r*   c                6    | j                   j                         S rb   )r  r  r  s    r+   r  zTFSamModel.get_input_embeddings  s    ""7799r*   c                   | j                   j                  j                  }t        j                  ||f      }t        j
                  j                  |d      dz
  }t        j
                  j                  |d      dz
  }||z  }||z  }| j                  t        j                  ||gd            }t        j                  t        j                  |g d      d      S )Nr   r   r]  r   r   )rO   r   r   rP   )rI   r  rU  rS   r   r   cumsumr  r  r  rT   )rH   r  gridy_embedx_embedr1  s         r+   $get_image_wide_positional_embeddingsz/TFSamModel.get_image_wide_positional_embeddings  s    {{00EEwwd|$''..A..4''..A..4D.D.#::288WgDV]_;`a~~bll+?iPWXYYr*   c                :    | j                  ||||      }|d   }|S )a  
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.

        )r   r   r   r   )r  )rH   rU   r   r   r   vision_outputr   s          r+   get_image_embeddingszTFSamModel.get_image_embeddings  s8    * ++/!5#	 , 
 )+r*   c                0    | j                  ||||      }|S )a  
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        )rp  rq  rr  rs  )r  )rH   rp  rq  rr  rs  prompt_outputs         r+   get_prompt_embeddingsz TFSamModel.get_prompt_embeddings  s-    0 ++%%##	 , 
 r*   c                   ||n| j                   j                  }|	|	n| j                   j                  }	|
|
n| j                   j                  }
||t	        d      ||t	        d      |=t        |j                        dk7  r%t	        ddj                  |j                              |=t        |j                        dk7  r%t	        ddj                  |j                              |>|<t        |      d   }t        |      d   }||k7  rt	        d	j                  ||            |tt        j                  |d | j                   j                  j                  | j                   j                  j                  | j                   j                  j                  g      }| j                         }|t        |      d
   nt        |      d
   }t        j                  ||d
      }d }d }|)| j!                  |||	d|      }|d   }|	r|d   }|r|d   }|4|2t        j"                  |d d d d d d d
f   t        j$                        }|X|j                  d
   |j                  d
   k7  r9t	        ddj                  |j                  d
   |j                  d
         ddd      | j'                  t        |      d
   ||||      \  }}| j)                  ||||||      \  }}}|
s||f}|	r||fz   }|r|||fz   }|S t+        |||||      S )Nz9Either pixel_values or image_embeddings must be provided.z>Only one of pixel_values and image_embeddings can be provided.r   zlThe input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.z got {}.r   zMThe input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.r   zQYou should provide as many bounding boxes as input points per box. Got {} and {}.r   r   T)r   r   r   r  r    r"   r#   r^  zNThe batch size of the image embeddings and the input points must be the same. zGot {} and {} respectively.zS if you want to pass multiple points for the same image, make sure that you passed zS input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and zK input_labels of shape (batch_size, point_batch_size, num_points_per_image))rV   rp  rq  rr  rs  )r   r   r  r  r  r   r.   r/   r0   r1   r2   )rI   r   r   r   rR   lenr   formatr   rS   ensure_shaper  r?   r=   r  r  r  	ones_liker  r  r  r-   )rH   rU   rp  rq  rr  rs  r   r  r   r   r   r  rJ   r   box_batch_sizer   rV   r1   r0   vision_outputsrt  rL  low_res_masksiou_predictionsr2   outputs                             r+   rZ   zTFSamModel.call  s   " 2C1N-TXT_T_TqTq$8$D $++JjJj 	 &1%<k$++B]B]$4$<XYY#(8(D]^^#L,>,>(?1(D~!!,"4"45  "s;+<+<'='B_!!+"3"34  #(?),7:'4Q7N>1 gnn(. 
 #??KK--::KK--88KK--88	L '+&O&O&Q#4@4LZ-a0R\]mRnopRq
&(ii0KZ^_&`# ##!00"3%9 ! 1 N  ..AB#'5o'F$ $2<$@!#(<<<Q1aZ(@QL#(8(>(>q(A\EWEWXYEZ(Z`-445E5K5KA5NP\PbPbcdPefee]  /3.A.A!"23A6%%## /B /
++ CGBSBS-(C%6$4-/ CT C
?(? %}5F##7"99 #46M"NNM+&$!5/$;
 	
r*   c                   | j                   j                  rt        j                  |j                        nd }| j                   j
                  rt        j                  |j                        nd }t        |j                  |j                  | j                   j                  r|nd | j                   j
                  r|nd | j                   j
                  r|j                        S d       S )Nr  )rI   r   rS   convert_to_tensorr0   r   r1   r-   r.   r/   r2   )rH   r  hsattnss       r+   serving_outputzTFSamModel.serving_outputY  s    BF++BbBbR!!&"="=>hlBF++B_B_$$V%=%=>ei+(((('+{{'G'GT'+{{'D'De$FJkkFcFcF$B$B
 	

 jn
 	
r*   c                `   | j                   ry d| _         t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       Mt        j                  | j                  j
                        5  | j                  j                  d        d d d        t        | dd       Nt        j                  | j                  j
                        5  | j                  j                  d        d d d        y y # 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   ~xY w# 1 sw Y   y xY w)NTr  r  r  r  )
r\   r]   rS   r^   r  r:   r_   r  r  r  r`   s     r+   r_   zTFSamModel.builde  s`   ::
4148Dt::??@ 8++11$784)40<t22778 0##))$/04)40<t22778 0##))$/04.:t00556 .!!''-. . ;8 80 00 0. .s0   E?%F?FF$?F	FF!$F-r   )r   r   r   r   r   r   )NNNN)rp  r   rq  r   rr  r   rs  r   )NNNNNNTNNNF)rU   zTFModelInputType | Nonerp  r   rq  r   rr  r   rs  r   r   r   r  r   r   bool | Noner   r  r   r  r  r   rt   z/TFSamImageSegmentationOutput | Tuple[tf.Tensor])r  r-   rt   r-   rb   )r$   r%   r&   _keys_to_ignore_on_load_missingr<   r  r  r  r  r   r   SAM_INPUTS_DOCSTRINGrZ   r  r_   rc   rd   s   @r+   r  r  z  s]    (_&_#	:	Z -1/3&*  *  -	 
 $ @ *.)-(,(,& ' &	
 &@ *+?@ 15)-)-(,(,-1!%)-,0#'|
-|
 '|
 '	|

 &|
 &|
 +|
 |
 '|
 *|
 !|
 |
 
9|
 A |
|

.r*   r  )Ar'   
__future__r   rB   dataclassesr   typingr   r   r   numpyr=  
tensorflowrS   activations_tfr	   modeling_tf_outputsr
   modeling_tf_utilsr   r   r   r   r   tf_utilsr   r   utilsr   r   r   r   configuration_samr   r   r   r   
get_loggerr$   logger_CONFIG_FOR_DOC_CHECKPOINT_FOR_DOCr   r-   rF   Layerr4   rf   rw   r   r   r   r   r   r-  rD  rP  rw  r  r  r  r  SAM_START_DOCSTRINGr  r  r)   r*   r+   <module>r#     sQ  
 #  ! ) )   $ 4 f f 5 f f g g 
		H	%-  4{ 4 4:  A;  A  AF*M5<<-- *MZCELL&& C4U\\'' 6PEU\\'' PEfr; 2 2 r;jN"U\\// N"b)?u||)) )?XQu||)) Qh#Nu||11 #NL'U++ 'UTK3++ K3\b5<<-- bJ[%u||)) [%|--ell(( --`h
++ h
V%, %  = @ _/
u.% u.
u.r*   