a
    öDfNT                     @   s   d dl Z d dlZd dlZd dlmZ ddlmZ	 dd Z
G dd dZG dd	 d	Zd
d ZG dd dZdd Zdd Zdd Zdd Zdd Zdd Ze Zg dZG dd dZG dd dZdS )    N   )interpreterc                 C   s   t }| d dkr*t| dd  }||S |j|j|j|j|j|j|j	|j
|j|j|j|j|j|j|j|j|j|jd}||  S )Nr   *   )Zfp8e4nvZfp8e5Zfp8e4b15Z
fp8e4b15x4Zfp16Zbf16Zfp32Zfp64i1i8Zi16Zi32Zi64u8u16u32Zu64B)tl	str_to_typointer_typeZ
float8e4nvZfloat8e5Zfloat8e4b15Zfloat8e4b15x4float16Zbfloat16float32float64int1int8int16int32int64uint8uint16uint32uint64)namelanguagetyZtys r   g/nfs/NAS7/SABIOD/METHODE/ermites/ermites_venv/lib/python3.9/site-packages/triton/runtime/interpreter.pyr      s0    
r   c                   @   s   e Zd Zdd Zdd ZdS )TensorHandlec                 C   s   || _ || _d S N)datadtype)selfr"   r#   r   r   r   __init__)   s    zTensorHandle.__init__c                 C   s   t | j S r!   )boolr"   allr$   r   r   r   __bool__-   s    zTensorHandle.__bool__N)__name__
__module____qualname__r%   r)   r   r   r   r   r    '   s   r    c                   @   s   e Zd Zdd Zdd ZdS )BlockPointerHandlec                 C   s(   || _ || _|| _|| _|| _|| _d S r!   )baseshapestridesoffsetstensor_shapeorderr$   r.   r/   r0   r1   r2   r3   r   r   r   r%   3   s    zBlockPointerHandle.__init__c           
      C   s   | j jj}|jd }| j}t| j j| j}tj| jt	d}t
t|D ]~}dgt| }|| ||< | j| jt||  |}	|||	 | j| j tj }||v rHt||	| j| jk }qHt|| j j}||fS )N   r#   r   )r.   r#   
element_typrimitive_bitwidthr2   npbroadcast_tor"   Zonesr&   rangelenr1   arangereshaper0   astyper   logical_andr/   r    )
r$   boundary_checkdtype_ttZn_bytesr2   ptrsmasksZdimZ
bcast_dimsoffr   r   r   materialize_pointers;   s    

  z'BlockPointerHandle.materialize_pointersN)r*   r+   r,   r%   rF   r   r   r   r   r-   1   s   r-   c                    s    fdd}|S )Nc                    s    fdd}|S )Nc                     s$   | i |}t |j | i |S r!   )r    r"   )argskwargsret)compute_ret_tyfnr   r   wrappedP   s    z*wrap_ret.<locals>.wrapper.<locals>.wrappedr   )rK   rL   rJ   )rK   r   wrapperN   s    zwrap_ret.<locals>.wrapperr   )rJ   rN   r   rM   r   wrap_retL   s    rO   c                   @   s:  e Zd ZddddZdd Zdd Zd	d
 Zdd Zdd Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( Zd)d* Zd+d, Zd-d. Zd/d. Zd0d. Zd1d. Zd2d. Zd3d. Zd4d. Zd5d6 Zd7d8 Z d9d: Z!d;d. Z"d<d. Z#d=d. Z$d>d. Z%d?d. Z&d@d. Z'dAd. Z(dBd. Z)dCd. Z*dDd. Z+dEd. Z,dFd. Z-dGd. Z.dHd. Z/dId. Z0dJd. Z1dKd. Z2dLd. Z3dMd. Z4dNd. Z5dOd. Z6dPd. Z7dQd. Z8dRd. Z9dSd. Z:dTd. Z;dUd. Z<dVd. Z=dWd. Z>dXd. Z?dYd. Z@dZd. ZAd[d. ZBd\d. ZCd]d. ZDd^d. ZEd_d. ZFd`d. ZGdad. ZHdbd. ZIdcd. ZJddd. ZKded. ZLdfd. ZMdgd. ZNdhd. ZOdidj ZPdkd. ZQdldm ZRdnd. ZSdod. ZTdpd. ZUdqd. ZVdrd. ZWdsd. ZXdtd. ZYdud. ZZdvd. Z[dwd. Z\dxdy ZZdzd{ Z]d|d} Z^d~d Z_dd Z`dd Zadd Zbdd Zcdd Zddd Zedd ZfdS )BuilderNreturnc                 C   s
   d | _ d S r!   )archr(   r   r   r   r%   [   s    zBuilder.__init__c                 C   sF   || j d k sJ || j d k s$J || j d k s6J |||f| _d S )Nr   r   r   )grid_dimgrid_idx)r$   xyzr   r   r   set_grid_idx_   s    zBuilder.set_grid_idxc                 C   s   |||f| _ d S r!   )rT   )r$   ZnxnyZnzr   r   r   set_grid_dime   s    zBuilder.set_grid_dimc                 C   s   t |tjrttjS tjttjtjttjtjttjtj	ttj	tj
ttj
tjttjtjttjtjttjtjttjtjttjtjttji}|| S r!   )
isinstancer   r   r9   r#   r   r   r   r   r   r   r   r   r   r   r   )r$   Ztt_dtypeZnp_typesr   r   r   np_dtypeh   s    zBuilder.np_dtypec                 C   s   t jS r!   )r   r   r(   r   r   r   get_half_ty{   s    zBuilder.get_half_tyc                 C   s   t jS r!   )r   r   r(   r   r   r   get_float_ty~   s    zBuilder.get_float_tyc                 C   s   t jS r!   )r   r   r(   r   r   r   get_int64_ty   s    zBuilder.get_int64_tyc                 C   s   t ||S r!   )r   r   )r$   Zelt_tyZ
addr_spacer   r   r   
get_ptr_ty   s    zBuilder.get_ptr_tyc                 C   s   t ||S r!   )r   tensor)r$   r#   r/   r   r   r   get_block_ty   s    zBuilder.get_block_tyc                 C   s   t tj|gtjdtjS Nr6   )r    r9   arrayr   r   r$   valuer   r   r   	get_int32   s    zBuilder.get_int32c                 C   s   t tj|gtjdtjS rd   )r    r9   re   r   r   rf   r   r   r   	get_int64   s    zBuilder.get_int64c                 C   s   t tj|gtjdtjS rd   )r    r9   re   r   r   rf   r   r   r   get_fp16   s    zBuilder.get_fp16c                 C   s   t tj|gtjdtjS rd   )r    r9   re   r   r   rf   r   r   r   get_fp32   s    zBuilder.get_fp32c                 C   s   t tjdg| |d|S Nr   r6   )r    r9   re   r]   )r$   typer   r   r   get_null_value   s    zBuilder.get_null_valuec                 C   s.   | j d usJ ttj| j | gtjdtjS rd   )rU   r    r9   re   r   r   r$   axisr   r   r   create_get_program_id   s    zBuilder.create_get_program_idc                 C   s    t tj| j| gtjdtjS rd   )r    r9   re   rT   r   r   ro   r   r   r   create_get_num_programs   s    zBuilder.create_get_num_programsc                 C   s0   t tj|jtdtj}d }| ||||||S rd   )r    r9   	ones_liker"   r&   r   r   create_masked_load)r$   ptr_0_1is_volatilemaskotherr   r   r   create_load   s    zBuilder.create_loadc                 C   s*   t tj|jtdtj}| |||d d S rd   )r    r9   rs   r"   r&   r   r   create_masked_store)r$   ru   valrv   rw   ry   r   r   r   create_store   s    zBuilder.create_storec           
      C   sP   |j j}| |}|d u r0ttj|j|d|}t|j|j|j|}	t|	|S rd   )	r#   r7   r]   r    r9   rs   r"   _interpreterload)
r$   rC   ry   rz   cache_modifiereviction_policyrx   rB   Zdtype_nprI   r   r   r   rt      s    
zBuilder.create_masked_loadc                 C   s   t |j|j|jS r!   )r   storer"   )r$   rC   rg   ry   r   r   r   r   r   r|      s    zBuilder.create_masked_storec                 C   s*   t |tjr|j}t|j| ||S r!   )r\   r   rb   r#   r    r"   r?   r]   r$   srcdst_typer   r   r   	cast_impl   s    zBuilder.cast_implc                 C   s   |  ||S r!   r   r   r   r   r   <lambda>       zBuilder.<lambda>c                 C   s   |  ||S r!   r   r   r   r   r   r      r   c                 C   s   |  ||S r!   r   r   r   r   r   r      r   c                 C   s   |  ||S r!   r   r   r   r   r   r      r   c                 C   s   |  ||S r!   r   r   r   r   r   r      r   c                 C   s   |  ||S r!   r   r   r   r   r   r      r   c                 C   s   |  ||S r!   r   )r$   r   r   	is_signedr   r   r   r      r   c                 C   s   dsJ d S )Nzfloat8 not NotImplemented yetr   r   r   r   r   create_fp_to_fp   s    zBuilder.create_fp_to_fpc                 C   s   t |j| ||S r!   )r    r"   viewr]   r   r   r   r   create_bitcast   s    zBuilder.create_bitcastc                 C   s   t ||j|j|jS r!   r    r"   r#   )r$   lhsrhsopr   r   r   	binary_op   s    zBuilder.binary_opc                 C   s   |  ||tjS r!   r   r9   addr$   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   multiplyr   r   r   r   r      r   c                 C   s   |  ||tjS r!   )r   r9   divider   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   	remainderr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   subtractr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zfloor_divider   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   )r   r9   Z
left_shiftr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zright_shiftr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zminimumr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   maximumr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Z
less_equalr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zlessr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zgreater_equalr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   Zgreaterr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   equalr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r9   	not_equalr   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   r   r   r   r   r   r      r   c                 C   s   |  ||tjS r!   )r   r9   Zbitwise_andr   r   r   r   r      r   c                 C   s   |  ||tjS r!   )r   r9   Zbitwise_xorr   r   r   r   r      r   c                 C   s   |  ||tjS r!   )r   r9   Z
bitwise_orr   r   r   r   r      r   c                 C   s   t ||j|j|j|jS r!   r   )r$   r   r   rz   r   r   r   r   
ternary_op   s    zBuilder.ternary_opc                 C   s   |  |||tjS r!   )r   r9   where)r$   Zcondr   r   r   r   r   r     r   c                 C   s   t ||j|jS r!   r   )r$   argr   r   r   r   unary_op  s    zBuilder.unary_opc                 C   s   |  |tjS r!   )r   r9   expr$   r   r   r   r   r     r   c                 C   s   |  |tjS r!   )r   r9   cosr   r   r   r   r     r   c                 C   s   |  |tjS r!   )r   r9   sinr   r   r   r   r   	  r   c                 C   s   |  |tjS r!   )r   r9   logr   r   r   r   r   
  r   c                 C   s   |  |tjS r!   )r   r9   sqrtr   r   r   r   r     r   c                 C   s   |  |tjS r!   r   r9   absr   r   r   r   r     r   c                 C   s   |  |tjS r!   r   r   r   r   r   r     r   c                 C   s   |  ||tjS r!   )r   r9   dotr   r   r   r   r     r   c                 C   s   t |j||jS r!   )r    r"   r>   r#   )r$   r   r/   ZallowReorderr   r   r   r     r   c                 C   s   |  |tjS r!   )r   r9   Z	transposer   r   r   r   r     r   c                 C   s   t t|j|j|j |jS r!   )r    r9   r   r"   r#   )r$   abdZ
allow_tf32ZmaxNumImpreciseAccr   r   r   
create_dot  s    zBuilder.create_dotc                 C   s   t tj||tjdtjS rd   )r    r9   r=   r   r   )r$   startstopr   r   r   create_make_range  s    zBuilder.create_make_rangec                 C   s.   |j j}t|j|jd |jtj  |j S )Nr5   )r#   r7   r    r"   r8   r?   r9   r   )r$   ru   offsetrB   r   r   r   create_addptr  s    zBuilder.create_addptrc           
      C   s2   | |\}}|d u sJ d }	| |||	|||S r!   )rF   rt   )
r$   ru   rA   Zpadding_optionr   r   rx   rC   rD   rz   r   r   r   create_tensor_pointer_load   s    z"Builder.create_tensor_pointer_loadc                 C   s    | |\}}| |||||S r!   )rF   r|   )r$   ru   rg   rA   r   r   rC   rD   r   r   r   create_tensor_pointer_store'  s    z#Builder.create_tensor_pointer_storec                 C   s   t t|j||jS r!   )r    r9   Zexpand_dimsr"   r#   )r$   r   rp   r   r   r   create_expand_dims+  s    zBuilder.create_expand_dimsc                 C   s   t t|j||jS r!   )r    r9   r:   r"   r#   r$   r   r/   r   r   r   create_broadcast.  s    zBuilder.create_broadcastc                 C   s   t |jtj|S r!   )r    r"   r?   r9   r   )r$   r}   Zdst_tyr   r   r   create_int_to_ptr1  s    zBuilder.create_int_to_ptrc                 C   s&   t tj||jd | |jd|jS rl   )r    r9   fullr"   r]   r#   r   r   r   r   create_splat:  s    zBuilder.create_splatc                 C   s   t |||t|||S r!   )r-   r9   re   r4   r   r   r   create_make_block_ptrg  s    zBuilder.create_make_block_ptrc                 C   sd   t |jt |ksJ t|j|j|j|j|j|j}tt |D ]}|j|  j	|| j	7  _	q@|S r!   )
r<   r1   r-   r.   r/   r0   r2   r3   r;   r"   )r$   ru   r1   rI   ir   r   r   create_advancej  s
    zBuilder.create_advance)gr*   r+   r,   r%   rY   r[   r]   r^   r_   r`   ra   rc   rh   ri   rj   rk   rn   rq   rr   r{   r~   rt   r|   r   Zcreate_si_to_fpZcreate_ui_to_fpZcreate_fp_to_siZcreate_fp_to_uiZcreate_fp_extZcreate_fp_truncZcreate_int_castr   r   r   Zcreate_faddZcreate_fmulZcreate_fdivZcreate_fremZcreate_fsubZ
create_mulZcreate_sdivZcreate_udivZcreate_sremZcreate_uremZ
create_addZ
create_subZ
create_shlZcreate_lshrZcreate_ashrZcreate_minsiZcreate_minuiZcreate_minfZcreate_maxsiZcreate_maxuiZcreate_maxfZcreate_icmpSLEZcreate_icmpSLTZcreate_icmpSGEZcreate_icmpSGTZcreate_icmpULEZcreate_icmpULTZcreate_icmpUGEZcreate_icmpUGTZcreate_icmpEQZcreate_icmpNEZcreate_fcmpOLTZcreate_fcmpOGTZcreate_fcmpOLEZcreate_fcmpOGEZcreate_fcmpOEQZcreate_fcmpONEZcreate_fcmpULTZcreate_fcmpUGTZcreate_fcmpULEZcreate_fcmpUGEZcreate_fcmpUEQZcreate_fcmpUNEZ
create_andZ
create_xorZ	create_orr   Zcreate_selectr   Z
create_expZ
create_cosZ
create_sinZ
create_logZcreate_sqrtZcreate_fabsZcreate_iabsr   Zcreate_reshapeZcreate_transr   r   r   r   r   r   r   r   r   r   r   r   r   r   rP   Y   s   	-rP   c                    s"   |d fdd
}t | || d S )N)memberc                    s$   | |i dd |  D d iS )Nc                 S   s   i | ]\}}|d kr||qS )_builderr   .0kvr   r   r   
<dictcomp>t  s   z0patch_attr.<locals>.<lambda>.<locals>.<dictcomp>r   )items)r   rG   rH   builderr   r   r   s  s   zpatch_attr.<locals>.<lambda>)setattr)objr   r   r   
new_memberr   r   r   
patch_attrr  s    r   c                 C   sZ   t | D ]"\}}tj|r
t| ||| q
dd | _dd | _dd | _dd | _	d S )Nc                 S   s   t | jjS r!   )inthandler"   r(   r   r   r   r   ~  r   z$_patch_lang_tensor.<locals>.<lambda>c                 S   s   dS )NTr   r(   r   r   r   r     r   c                 S   s   t | jjS r!   )strr   r"   r(   r   r   r   r     r   c                 S   s   | j j|S r!   )r   r"   __getitem__)r$   Zslicesr   r   r   r     r   )
inspect
getmembersr   core
is_builtinr   	__index__r)   __str__r   )rb   r   r   r   r   r   r   _patch_lang_tensorz  s    


r   c                 C   s@   t | D ]"\}}tj|r
t| ||| q
dd }|| _d S )Nc                 S   sP   |j j}tjtjd}|| | jj|d}t| j	|j
}tjt|| j	|S )N)r   Z_sum_combine)rp   )rK   r*   r9   maxsumr   r"   r   Z
block_typer#   r/   r   rb   r    )inputrp   Z
combine_fnrK   mappingrI   ret_typer   r   r   _new_reduce  s    z%_patch_lang_core.<locals>._new_reduce)r   r   r   r   r   r   reduce)langr   r   r   r   r   r   r   _patch_lang_core  s
    
r   c                    sn   | j }ddddddd  fdd	}d
d }t|D ]2\}}| v rXt|||| q6t|||| q6d S )Nr   ZarccosZarcsinexp2log2r   )r   acosasinr   r   r   c                    s    fdd}|S )Nc                     sd   | d j }| d j}dd | D } dd | D }tt  | i |}tjt|||}|S )Nr   c                 S   s   g | ]}|j jqS r   r   r"   r   r   r   r   r   
<listcomp>  r   zF_patch_lang_math.<locals>.make_numpy.<locals>.impl.<locals>.<listcomp>c                 S   s   i | ]\}}||j jqS r   r   r   r   r   r   r     r   zF_patch_lang_math.<locals>.make_numpy.<locals>.impl.<locals>.<dictcomp>)	rm   r#   r   getattrr9   r   r   rb   r    )rG   rH   r   Z	ret_dtyperI   )r   r   r   r   impl  s    

z2_patch_lang_math.<locals>.make_numpy.<locals>.implr   )r   r   r   r   r   
make_numpy  s    	z$_patch_lang_math.<locals>.make_numpyc                    s    fdd}|S )Nc                     s   t d  d  dd S )N
zU not supported in interpreter mode: no known numpy implementation.
If you think that z in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
)NotImplementedError)rG   rH   r   r   r   fallback  s
    z9_patch_lang_math.<locals>.make_fallback.<locals>.fallbackr   )r   r   r   r   r   make_fallback  s    z'_patch_lang_math.<locals>.make_fallback)mathr   r   r   )r   r   r   r   r   r   r   r   r   r   _patch_lang_math  s    	r   c                 C   s   t | trNttjjjtjjj| }t	t
j| gt
jd|}t||S t| drttjjjtjjj| }t	t
j|  gt
jd|}t||S | S )Nr6   data_ptr)r\   r   r   tritonZruntimejitZJITFunctionZ_type_ofZ_key_ofr    r9   re   r   r   rb   hasattrr   r   )r   r   r   r   r   r   _implicit_cvt  s    
 
 r  c                 C   s   t | tjr| jS | S r!   )r\   r  ZTensorWrapperr.   )rb   r   r   r   _unwrap  s    r  )Z	num_warpsZ
num_stagesZnum_ctasZenable_warp_specializationZenable_fp_fusionc                   @   s$   e Zd Zdd Zdd Zdd ZdS )GridExecutorc                    sN   ddl m || _|| _|| _fdd|j D   fdd|D | _d S )Nr   _normalize_tyc                    s   i | ]\}}| |qS r   r   )r   r   r   r  r   r   r     r   z)GridExecutor.__init__.<locals>.<dictcomp>c                    s   g | ]}  |d kr|qS )Z	constexpr)get)r   r   )__annotations__r   r   r     r   z)GridExecutor.__init__.<locals>.<listcomp>)r  r  rK   	arg_namesgridr
  r   
constexprs)r$   rK   r  r  r   )r
  r  r   r%     s    zGridExecutor.__init__c                 C   s^   dd | j j D }t|dks*J dtt|d d| t|d | t|d | d S )Nc                 S   s"   g | ]\}}|t t jfv r|qS r   r   r   r   _rg   r   r   r   r     r   z,GridExecutor._patch_lang.<locals>.<listcomp>r   :triton.language must be visible from within jit'd functionr   rb   )rK   __globals__r   r<   r   r   r   r   r$   r   r   r   r   r   _patch_lang  s
    zGridExecutor._patch_langc                    s2  dd |D }dd |  D } t tj jg|R i |} fdd|  D }t jrn |n j}t|dksJ |ddt|   }tj	|  t
|d D ]F}t
|d	 D ]4}t
|d
 D ]"}t|||  jf i | qqqt||D ],\}	}
t|	dr t|	|
|	j q d S )Nc                 S   s&   g | ]}t |d rt| n|qS )r   )r  r  cpur   r   r   r   r     r   z)GridExecutor.__call__.<locals>.<listcomp>c                 S   s   i | ]\}}|t vr||qS r   RESERVED_KWSr   r   r   r   r     r   z)GridExecutor.__call__.<locals>.<dictcomp>c                    s(   i | ] \}}|| j v r|nt|qS r   )r  r  )r   r   r   r(   r   r   r     r      )r   r   r   r   r   )r   r  r   r   getcallargsrK   callabler  r<   r[   r;   rY   zipr  r  Zcopy_toZdevice)r$   Zargs_devrH   Zargs_hstrG   r  rV   rW   rX   Zarg_devZarg_hstr   r(   r   __call__  s"    

zGridExecutor.__call__N)r*   r+   r,   r%   r  r  r   r   r   r   r    s   	r  c                   @   s2   e Zd Zdd ZddddZdd Zd	d
 ZdS )InterpretedFunctionc                 C   sP   dd | j j D }t|dks*J dtt|d d| t|d | d S )Nc                 S   s"   g | ]\}}|t t jfv r|qS r   r  r  r   r   r   r   	  r   z3InterpretedFunction._patch_lang.<locals>.<listcomp>r   r  r   rb   )rK   r  r   r<   r   r   r   r  r   r   r   r    s    zInterpretedFunction._patch_langNrQ   c                    s<   | _  fdd}| _t|}dd |j D  _d S )Nc                     s4   |d }dd |  D }t j j|| i |S )Nr  c                 S   s$   i | ]\}}|t d g vr||qS )r  r  r   r   r   r   r     r   z=InterpretedFunction.__init__.<locals>.run.<locals>.<dictcomp>)r   r  rK   r  )rG   rH   r  r(   r   r   run  s    z)InterpretedFunction.__init__.<locals>.runc                 S   s   g | ]
}|j qS r   r   )r   r   r   r   r   r     r   z0InterpretedFunction.__init__.<locals>.<listcomp>)rK   r  r   	signature
parametersvaluesr  )r$   rK   r  r   r   r(   r   r%     s
    
zInterpretedFunction.__init__c                 C   s   t | j| j|S r!   )r  rK   r  )r$   r  r   r   r   r     s    zInterpretedFunction.__getitem__c                 O   s   |  t | j|i |S r!   )r  r   rK   )r$   rG   rH   r   r   r   r    s    
zInterpretedFunction.__call__)r*   r+   r,   r  r%   r   r  r   r   r   r   r    s   r  )r   numpyr9   r  Ztriton.languager   r   Z_C.libtriton.tritonr   r   r   r    r-   rO   rP   r   r   r   r   r  r  r   r  r  r  r   r   r   r   <module>   s*   
  
+,