@@ -3031,6 +3031,54 @@ static Index ad_var_memop_remap(Index index, bool input) {
30313031
30323032// ==========================================================================
30333033
3034+ uint64_t ad_var_tile (Index index, uint32_t count) {
3035+ JitVar result = JitVar::steal (jit_var_tile (jit_index (index), count));
3036+
3037+ if (likely (is_detached (index)))
3038+ return result.release ();
3039+ else {
3040+ VarInfo info = jit_set_backend (jit_index (index));
3041+
3042+ JitVar input_size_var = JitVar::steal (jit_var_literal (info.backend , VarType::UInt32, &info.size , 1 , 0 ));
3043+ JitVar offset = JitVar::steal (jit_var_counter (info.backend , result.size ()));
3044+ offset = JitVar::steal (jit_var_mod (offset.index (), input_size_var.index ()));
3045+
3046+ uint64_t one_u64 = 1 ;
3047+ JitVar mask = JitVar::steal (jit_var_literal (info.backend , VarType::Bool, &one_u64, 1 , 0 ));
3048+
3049+ return ad_var_new (" tile" , std::move (result),
3050+ SpecialArg (index, new Gather (
3051+ GenericArray<uint32_t >::borrow (offset.index ()),
3052+ JitMask::borrow (mask.index ()),
3053+ ReduceMode::Auto)));
3054+ }
3055+ }
3056+
3057+ uint64_t ad_var_repeat (Index index, uint32_t count, size_t max_size) {
3058+ JitVar result = JitVar::steal (jit_var_repeat (jit_index (index), count, max_size));
3059+
3060+ if (likely (is_detached (index)))
3061+ return result.release ();
3062+ else {
3063+ VarInfo info = jit_set_backend (jit_index (index));
3064+
3065+ JitVar offset = JitVar::steal (jit_var_counter (info.backend , result.size ()));
3066+ JitVar divisor = JitVar::steal (jit_var_literal (info.backend , VarType::UInt32, &count, 1 , 0 ));
3067+ offset = JitVar::steal (jit_var_div (offset.index (), divisor.index ()));
3068+
3069+ uint64_t one_u64 = 1 ;
3070+ JitVar mask = JitVar::steal (jit_var_literal (info.backend , VarType::Bool, &one_u64, 1 , 0 ));
3071+
3072+ return ad_var_new (" repeat" , std::move (result),
3073+ SpecialArg (index, new Gather (
3074+ GenericArray<uint32_t >::borrow (offset.index ()),
3075+ JitMask::borrow (mask.index ()),
3076+ ReduceMode::Auto)));
3077+ }
3078+ }
3079+
3080+ // ==========================================================================
3081+
30343082uint64_t ad_var_gather (Index source, JitIndex offset, JitIndex mask, ReduceMode mode) {
30353083 JitVar result = JitVar::steal (jit_var_gather (jit_index (source), offset, mask));
30363084
0 commit comments