diff --git "a/debug/debug-phase4.py" "b/debug/debug-phase4.py" new file mode 100644--- /dev/null +++ "b/debug/debug-phase4.py" @@ -0,0 +1,10895 @@ +# from tvm.script import ir as I +# from tvm.script import tir as T +# from tvm.script import relax as R + +@I.ir_module +class Module: + I.module_attrs({"external_mods": [metadata["runtime.Module"][0], metadata["runtime.Module"][1], metadata["runtime.Module"][2], metadata["runtime.Module"][3], metadata["runtime.Module"][4], metadata["runtime.Module"][5], metadata["runtime.Module"][6], metadata["runtime.Module"][7], metadata["runtime.Module"][8], metadata["runtime.Module"][9], metadata["runtime.Module"][10], metadata["runtime.Module"][11], metadata["runtime.Module"][12], metadata["runtime.Module"][13], metadata["runtime.Module"][14]]}) + @T.prim_func(private=True) + def NT_matmul(layer_norm356: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + NT_matmul_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + model_decoder_layers_0_self_attn_q_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local") + layer_norm356_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + 
for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(1)): + with T.block("layer_norm356_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) + T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280)) + T.reads(layer_norm356[v0, v1, v2]) + T.writes(layer_norm356_shared[v0, v1, v2]) + layer_norm356_shared[v0, v1, v2] = layer_norm356[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_layers_0_self_attn_q_proj_weight5_local"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 
* T.int64(2) + ax0_ax1_fused_1) + T.reads(model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1]) + T.writes(model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1]) + model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) + T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) + T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + 
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads() + T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), 
T.int64(0), v0]) + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul[T.int64(0), T.int64(0), v0] = T.float16(0) + NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + + @T.prim_func(private=True) + def NT_matmul3(layer_norm452: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_embed_tokens_weight5: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(51866)), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + NT_matmul_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(51866)), scope="local") + NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(51866)), scope="local") + model_decoder_embed_tokens_weight5_local = T.alloc_buffer((T.int64(51866), T.int64(1280)), "float16", scope="local") + layer_norm452_shared = 
T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(12967), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(1)): + with T.block("layer_norm452_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1 * T.int64(64) + ax2_2 + ax2_3) + T.reads(layer_norm452[v0, v1, v2]) + T.writes(layer_norm452_shared[v0, v1, v2]) + layer_norm452_shared[v0, v1, v2] = layer_norm452[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init < T.int64(51866)) + T.reads() + T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float32(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 
256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(2)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_embed_tokens_weight5_local"): + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 < T.int64(51866)) + T.reads(model_decoder_embed_tokens_weight5[v0, v1]) + T.writes(model_decoder_embed_tokens_weight5_local[v0, v1]) + model_decoder_embed_tokens_weight5_local[v0, v1] = model_decoder_embed_tokens_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_2, vax1_fused_u_fused_0 = T.axis.remap("RR", [ax1_fused_u_fused_2, ax1_fused_u_fused_0]) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2 < T.int64(51866)) + T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm452_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused], model_decoder_embed_tokens_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) 
+ T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + T.Cast("float32", layer_norm452_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) * T.Cast("float32", model_decoder_embed_tokens_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0) + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax2_fused_0_ax2_fused_1_fused % T.int64(4) + (ax2_fused_2_0 + ax2_fused_2_1)) < T.int64(51866)) + T.reads() + T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float32(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + 
ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax2_fused_0_ax2_fused_1_fused % T.int64(4) + (ax2_fused_2_0 + ax2_fused_2_1)) < T.int64(51866)) + T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0) + v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax1_fused_0_ax1_fused_1_fused % T.int64(4) + ax1_fused_2) < T.int64(51866)) + T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul[T.int64(0), T.int64(0), v0] = T.float32(0) + NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] 
+ + @T.prim_func(private=True) + def add(var_reshape708: T.handle, var_reshape709: T.handle, var_T_add: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape708 = T.match_buffer(var_reshape708, (batch_size, T.int64(1), T.int64(1280)), "float16") + reshape709 = T.match_buffer(var_reshape709, (batch_size, T.int64(1), T.int64(1280)), "float16") + T_add = T.match_buffer(var_T_add, (batch_size, T.int64(1), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(reshape708[v0, T.int64(0), v1], reshape709[v0, T.int64(0), v1]) + T.writes(T_add[v0, T.int64(0), v1]) + T_add[v0, T.int64(0), v1] = reshape708[v0, T.int64(0), v1] + reshape709[v0, T.int64(0), v1] + + @T.prim_func(private=True) + def add4(var_add: T.handle, var_lv610: T.handle, var_T_add: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + add = T.match_buffer(var_add, (batch_size, T.int64(1500), T.int64(1280)), "float16") + lv610 = T.match_buffer(var_lv610, (batch_size, T.int64(1500), T.int64(1280)), "float16") + T_add = T.match_buffer(var_T_add, (batch_size, T.int64(1500), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(batch_size, 
(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) + T.reads(add[v0, v1, v2], lv610[v0, v1, v2]) + T.writes(T_add[v0, v1, v2]) + T_add[v0, v1, v2] = add[v0, v1, v2] + lv610[v0, v1, v2] + + @T.prim_func(private=True) + def add5(var_reshape385: T.handle, var_reshape386: T.handle, var_T_add: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + reshape385 = T.match_buffer(var_reshape385, (T.int64(1), seq_len, T.int64(1280)), "float16") + reshape386 = T.match_buffer(var_reshape386, (T.int64(1), seq_len, T.int64(1280)), "float16") + T_add = T.match_buffer(var_T_add, (T.int64(1), seq_len, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280)) + T.reads(reshape385[T.int64(0), v0, v1], reshape386[T.int64(0), v0, v1]) + T.writes(T_add[T.int64(0), v0, v1]) + T_add[T.int64(0), v0, v1] = reshape385[T.int64(0), v0, v1] + reshape386[T.int64(0), v0, v1] + + @T.prim_func + def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], 
"kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size)) + num_seq = T.int32(is_size_var=True) + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") + # with T.block("root"): + for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"): + for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("block"): + vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size) + vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size) + T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size) + T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv]) + T.writes(logits[seq_ids[vs], vv]) + logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-3.4028234663852886e+38)) + + @T.prim_func + def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, 
(batch_size, vocab_size)) + num_token = T.int32(is_size_var=True) + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + logit_bias = T.match_buffer(var_logit_bias, (num_token,)) + # with T.block("root"): + for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"): + for p1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp]) + T.writes(logits[pos2seq_id[vp], token_ids[vp]]) + logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp] + + @T.prim_func + def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size)) + num_seq = T.int32(is_size_var=True) + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + num_token = T.int32(is_size_var=True) + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") + penalties = T.match_buffer(var_penalties, (num_seq, 3)) + # with T.block("root"): + for p0 in T.thread_binding((num_token + 
1023) // 1024, thread="blockIdx.x"): + for p1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp]) + T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1]) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > T.float32(0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2]) + + @T.prim_func(private=True) + def argsort_thrust(var_probs: T.handle, var_lv: T.handle, var_topk_gpu_v1: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int64(), T.int64() + data_buf = T.match_buffer(var_probs, (batch_size, vocab_size), align=8) + workspace_buf = T.match_buffer(var_lv, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8) + indices_buf = T.match_buffer(var_topk_gpu_v1, (batch_size, vocab_size), "int32", align=8) + # with T.block("root"): + value_buf = T.alloc_buffer((batch_size, vocab_size), align=8) + with T.block("topk_gpu"): + T.reads() + T.writes() + T.call_packed("tvm.contrib.thrust.sort", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(value_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(indices_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 
0, 2, 0, T.int64(0)), 0, T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0))) + + @T.prim_func + def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + B = T.int32(is_size_var=True) + Q = T.match_buffer(Q_handle, (B, 20, 64), "float16") + max_num_pages = T.int32(is_size_var=True) + pages = T.match_buffer(pages_handle, (max_num_pages, 2, 20, 16, 64), "float16") + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) + nnz_pages = T.int32(is_size_var=True) + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) + length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1) + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1) + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) + output = T.match_buffer(output_handle, (B, 20, 64), "float16") + lse = T.match_buffer(lse_handle, (B, 20)) + # with T.block("root"): + sm_scale: T.float32 = T.float32(0.18033688011112042) + for bx in T.thread_binding(B, thread="blockIdx.x"): 
+ for fused_by_bz in T.thread_binding(20, thread="blockIdx.y"): + for ty in T.thread_binding(1, thread="threadIdx.y"): + for tx in T.thread_binding(16, thread="threadIdx.x"): + for tz in T.thread_binding(32, thread="threadIdx.z"): + with T.block("attn"): + T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 20 + ty + fused_by_bz % 20, tx * 4 - 32:tx * 4 - 32 + 68]) + T.writes(output[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty]) + Q_local = T.alloc_buffer((4,), "float16", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + K_smem = T.alloc_buffer((64, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((64, 64), "float16", scope="shared") + O_allreduce = T.alloc_buffer((32, 1, 64), scope="shared") + md_allreduce = T.alloc_buffer((32, 1, 2), scope="shared") + S_reduce_local = T.alloc_buffer((1,), scope="local") + t0 = T.alloc_buffer((1,), scope="local") + S_local = T.alloc_buffer((2,), scope="local") + QK_local = T.alloc_buffer((4,), scope="local") + V_local = T.alloc_buffer((4,), "float16", scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_prev = T.alloc_buffer((1,), scope="local") + other_m = T.alloc_buffer((1,), scope="local") + other_d = T.alloc_buffer((1,), scope="local") + exp_mprev = T.alloc_buffer((1,), scope="local") + exp_otherm = T.alloc_buffer((1,), scope="local") + other_o = T.alloc_buffer((4,), scope="local") + st_m = T.alloc_buffer((1,), scope="local") + st_d = T.alloc_buffer((1,), scope="local") + O_local = T.alloc_buffer((4,), scope="local") + by: T.int32 = fused_by_bz % 20 + bz: T.int32 = fused_by_bz // 20 + batch_idx: T.int32 = bx + cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] + cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] + kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 
1) * 16 + length_info[batch_idx], 0) + st_m[0] = T.float32(-50000) + st_d[0] = T.float32(1) + for vec in T.vectorized(4): + O_local[vec] = T.float32(0) + for vec in T.vectorized(4): + Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, Q[bx, by + bz + ty, tx * 4 + vec + 32] * T.float16(-1), Q[bx, by + bz + ty, tx * 4 + vec - 32]))), Q[bx, by + bz + ty, tx * 4 + vec]) + for iterator in range((kv_chunk_len[0] + 63) // 64): + tile_start_s: T.int32 = (tz + ty) * 2 + tile_start_g: T.int32 = (iterator * 32 + tz + ty) * 2 + for j in range(2): + with T.block("KV_load"): + T.reads() + T.writes() + row_g: T.int32 = tile_start_g + j + if row_g < kv_chunk_len[0]: + seq_offset: T.int32 = row_g + page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + for vec in T.vectorized(4): + K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, pages[page_no, 0, by, page_offset, tx * 4 + vec + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, tx * 4 + vec - 32]))), pages[page_no, 0, by, page_offset, tx * 4 + vec]) + V_smem[tile_start_s + j, tx 
* 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] + else: + for vec in T.vectorized(4): + K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0) + V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0) + T.tvm_storage_sync("shared") + m_prev[0] = st_m[0] + for j in range(2): + for vec in T.vectorized(4): + QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale + S_reduce_local[0] = T.float32(0) + for vec in T.unroll(4): + S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] + with T.block("block_cross_thread"): + T.reads(S_reduce_local[0]) + T.writes(t0[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) + S_local[j] = T.float32(-50000) + if (iterator * 32 + tz) * 2 + j < kv_chunk_len[0]: + S_local[j] = t0[0] + st_m[0] = T.max(st_m[0], S_local[j]) + o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) + st_d[0] = st_d[0] * o_scale + for j in range(2): + S_local[j] = T.exp2(S_local[j] - st_m[0]) + st_d[0] = st_d[0] + S_local[j] + for j in T.vectorized(4): + O_local[j] = O_local[j] * o_scale + for j in range(2): + for vec in T.vectorized(4): + V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec] + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j] + for vec in T.vectorized(4): + O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] + md_allreduce[tz, ty, 0] = st_m[0] + md_allreduce[tz, ty, 1] = st_d[0] + T.tvm_storage_sync("shared") + st_m[0] = T.float32(-50000) + st_d[0] = T.float32(1) + for vec in T.vectorized(4): + O_local[vec] = T.float32(0) + for j in range(32): + m_prev[0] = st_m[0] + d_prev[0] = st_d[0] + other_m[0] = md_allreduce[j, ty, 0] + other_d[0] = md_allreduce[j, ty, 1] + for vec in T.vectorized(4): + other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] + st_m[0] = 
T.max(st_m[0], other_m[0]) + st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] / st_d[0] + for vec in T.vectorized(4): + output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) + lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0]) + + @T.prim_func + def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + B = T.int32(is_size_var=True) + Q = T.match_buffer(Q_handle, (B, 20, 64), "float16") + max_num_pages = T.int32(is_size_var=True) + pages = T.match_buffer(pages_handle, (max_num_pages, 2, 20, 16, 64), "float16") + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) + nnz_pages = T.int32(is_size_var=True) + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) + length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1) + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", 
offset_factor=1) + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) + output = T.match_buffer(output_handle, (B, 20, 64), "float16") + lse = T.match_buffer(lse_handle, (B, 20)) + # with T.block("root"): + sm_scale: T.float32 = T.float32(0.18033688011112042) + for bx in T.thread_binding(B, thread="blockIdx.x"): + for fused_by_bz in T.thread_binding(20, thread="blockIdx.y"): + for ty in T.thread_binding(1, thread="threadIdx.y"): + for tx in T.thread_binding(16, thread="threadIdx.x"): + for tz in T.thread_binding(32, thread="threadIdx.z"): + with T.block("attn"): + T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 20 + ty + fused_by_bz % 20, tx * 4 - 32:tx * 4 - 32 + 68]) + T.writes(output[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty]) + Q_local = T.alloc_buffer((4,), "float16", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + K_smem = T.alloc_buffer((64, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((64, 64), "float16", scope="shared") + O_allreduce = T.alloc_buffer((32, 1, 64), scope="shared") + md_allreduce = T.alloc_buffer((32, 1, 2), scope="shared") + S_reduce_local = T.alloc_buffer((1,), scope="local") + t0 = T.alloc_buffer((1,), scope="local") + S_local = T.alloc_buffer((2,), scope="local") + QK_local = T.alloc_buffer((4,), scope="local") + V_local = T.alloc_buffer((4,), "float16", scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_prev = T.alloc_buffer((1,), scope="local") + other_m = T.alloc_buffer((1,), scope="local") + other_d = T.alloc_buffer((1,), scope="local") + exp_mprev = T.alloc_buffer((1,), scope="local") + exp_otherm = T.alloc_buffer((1,), scope="local") + other_o = T.alloc_buffer((4,), scope="local") + st_m = T.alloc_buffer((1,), scope="local") + st_d = T.alloc_buffer((1,), scope="local") + O_local = T.alloc_buffer((4,), 
scope="local") + by: T.int32 = fused_by_bz % 20 + bz: T.int32 = fused_by_bz // 20 + batch_idx: T.int32 = bx + cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] + cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] + kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0) + st_m[0] = T.float32(-50000) + st_d[0] = T.float32(1) + for vec in T.vectorized(4): + O_local[vec] = T.float32(0) + for vec in T.vectorized(4): + Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, Q[bx, by + bz + ty, tx * 4 + vec + 32] * T.float16(-1), Q[bx, by + bz + ty, tx * 4 + vec - 32]))), Q[bx, by + bz + ty, tx * 4 + vec]) + for iterator in range((kv_chunk_len[0] + 63) // 64): + tile_start_s: T.int32 = (tz + ty) * 2 + tile_start_g: T.int32 = (iterator * 32 + tz + ty) * 2 + for j in range(2): + with T.block("KV_load"): + T.reads() + T.writes() + row_g: T.int32 = tile_start_g + j + if row_g < kv_chunk_len[0]: + seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx]) + page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + for vec in T.vectorized(4): + K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, 
T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, pages[page_no, 0, by, page_offset, tx * 4 + vec + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, tx * 4 + vec - 32]))), pages[page_no, 0, by, page_offset, tx * 4 + vec]) + V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] + else: + for vec in T.vectorized(4): + K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0) + V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0) + T.tvm_storage_sync("shared") + m_prev[0] = st_m[0] + for j in range(2): + for vec in T.vectorized(4): + QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale + S_reduce_local[0] = T.float32(0) + for vec in T.unroll(4): + S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] + with T.block("block_cross_thread"): + T.reads(S_reduce_local[0]) + T.writes(t0[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) + S_local[j] = T.float32(-50000) + if (iterator * 32 + tz) * 2 + j < kv_chunk_len[0]: + S_local[j] = t0[0] + st_m[0] = T.max(st_m[0], S_local[j]) + o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) + st_d[0] = st_d[0] * o_scale + for j in range(2): + S_local[j] = T.exp2(S_local[j] - st_m[0]) + st_d[0] = st_d[0] + S_local[j] + for j in T.vectorized(4): + O_local[j] = O_local[j] * o_scale + for j in range(2): + for vec in T.vectorized(4): + V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec] + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * 
S_local[j] + for vec in T.vectorized(4): + O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] + md_allreduce[tz, ty, 0] = st_m[0] + md_allreduce[tz, ty, 1] = st_d[0] + T.tvm_storage_sync("shared") + st_m[0] = T.float32(-50000) + st_d[0] = T.float32(1) + for vec in T.vectorized(4): + O_local[vec] = T.float32(0) + for j in range(32): + m_prev[0] = st_m[0] + d_prev[0] = st_d[0] + other_m[0] = md_allreduce[j, ty, 0] + other_d[0] = md_allreduce[j, ty, 1] + for vec in T.vectorized(4): + other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] + st_m[0] = T.max(st_m[0], other_m[0]) + st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] + for vec in T.vectorized(4): + O_local[vec] = O_local[vec] / st_d[0] + for vec in T.vectorized(4): + output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) + lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0]) + + @T.prim_func + def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + total_len = T.int32(is_size_var=True) + q = T.match_buffer(var_q, (total_len, 20, 64), 
"float16") + batch_size = T.int32(is_size_var=True) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) + max_num_pages = T.int32(is_size_var=True) + pages = T.match_buffer(var_pages, (max_num_pages, 2, 20, 16, 64), "float16") + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) + nnz_pages = T.int32(is_size_var=True) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) + length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) + output = T.match_buffer(var_output, (total_len, 20, 64), "float16") + lse = T.match_buffer(var_lse, (total_len, 20)) + # with T.block("root"): + for lbx in T.thread_binding(16, thread="blockIdx.x"): + for lby in T.thread_binding(20, thread="blockIdx.y"): + for lty in T.thread_binding(4, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = T.alloc_buffer((1,), "int32", scope="local") + batch_idx = T.alloc_buffer((1,), "int32", scope="local") + batch_tiles = T.alloc_buffer((1,), "int32", scope="local") + batch_rows = T.alloc_buffer((1,), "int32", scope="local") + iterator = T.alloc_buffer((1,), "int32", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") + K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + S_smem = T.alloc_buffer((32, 16), scope="shared") + S_local = T.alloc_buffer((32, 16), scope="local") + O_local = T.alloc_buffer((32, 64), scope="local") + m_smem = T.alloc_buffer((32,), 
scope="shared") + m_prev_smem = T.alloc_buffer((32,), scope="shared") + d_smem = T.alloc_buffer((32,), scope="shared") + m_new = T.alloc_buffer((1,), scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_new = T.alloc_buffer((1,), scope="local") + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = q_indptr[1] - q_indptr[0] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] = tile_id[0] - batch_tiles[0] + batch_idx[0] = batch_idx[0] + 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * 32 + q_indptr_val: T.int32 = q_indptr[b_idx] + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0) + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + m_smem[row] = T.float32(-50000) + d_smem[row] = T.float32(1) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = T.float32(0) + T.tvm_storage_sync("shared") + for li_lj_fused_0 in range(4): + for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_lj_fused_2 in T.thread_binding(32, 
thread="threadIdx.x"): + for li_lj_fused_3 in T.vectorized(4): + with T.block("Q_load"): + i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) + j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = q_indptr_val + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j]) + else: + Q_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for iterator_1 in range((kv_chunk_len[0] + 15) // 16): + L_kv_start: T.int32 = iterator_1 * 16 + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("K_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32 = cur_L + page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 
64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, pages[page_no, 0, by, page_offset, j + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, j - 32]))), pages[page_no, 0, by, page_offset, j]) + else: + K_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("V_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32 = cur_L + page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + V_smem[i, j] = pages[page_no, 1, by, page_offset, j] + else: + V_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) + T.writes(S_local[0:32, 0:16]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(2, 2): + with T.block("S_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) + j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) + T.reads() + T.writes(S_local[i, j]) + S_local[i, j] = T.float32(0) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in 
T.thread_binding(32, thread="threadIdx.x"): + for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): + with T.block("S_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + k = T.axis.reduce(64, lk_0 * 8 + lk_1) + T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) + T.writes(S_local[i, j]) + S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.18033688011112042) + T.tvm_storage_sync("shared") + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(2, 2): + with T.block("S_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + T.reads(S_local[i, j]) + T.writes(S_smem[i, j]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update1"): + T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) + T.writes(m_prev[i], m_new[i], d_new[i]) + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + row_: T.int32 = LH_start + row + for j in range(16): + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + with T.block("update"): + T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) + T.writes(S_smem[row, 0:16]) + for j in range(16): + if row < 32: + row_: T.int32 = LH_start + 
row + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update"): + T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) + T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) + for j in range(16): + d_new[i] = d_new[i] + S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) + T.writes(O_local[0:32, 0:64]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(4, 4): + with T.block("O_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) + j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): + with T.block("O_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + k = T.axis.reduce(16, lk_0 * 8 + lk_1) + T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) + for 
li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) + T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) + for li_0 in range(1): + for li_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_2 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("lse_store"): + i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) + T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) + T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) + T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + tile_id[0] = tile_id[0] + 16 + + @T.prim_func + def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, 
"max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + total_len = T.int32(is_size_var=True) + q = T.match_buffer(var_q, (total_len, 20, 64), "float16") + batch_size = T.int32(is_size_var=True) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) + max_num_pages = T.int32(is_size_var=True) + pages = T.match_buffer(var_pages, (max_num_pages, 2, 20, 16, 64), "float16") + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) + nnz_pages = T.int32(is_size_var=True) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) + length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) + output = T.match_buffer(var_output, (total_len, 20, 64), "float16") + lse = T.match_buffer(var_lse, (total_len, 20)) + # with T.block("root"): + for lbx in T.thread_binding(16, thread="blockIdx.x"): + for lby in T.thread_binding(20, thread="blockIdx.y"): + for lty in T.thread_binding(4, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = T.alloc_buffer((1,), "int32", scope="local") + batch_idx = T.alloc_buffer((1,), "int32", scope="local") + batch_tiles = T.alloc_buffer((1,), "int32", scope="local") + batch_rows = T.alloc_buffer((1,), "int32", scope="local") + iterator = T.alloc_buffer((1,), "int32", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") + K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + S_smem = 
T.alloc_buffer((32, 16), scope="shared") + S_local = T.alloc_buffer((32, 16), scope="local") + O_local = T.alloc_buffer((32, 64), scope="local") + m_smem = T.alloc_buffer((32,), scope="shared") + m_prev_smem = T.alloc_buffer((32,), scope="shared") + d_smem = T.alloc_buffer((32,), scope="shared") + m_new = T.alloc_buffer((1,), scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_new = T.alloc_buffer((1,), scope="local") + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = q_indptr[1] - q_indptr[0] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] = tile_id[0] - batch_tiles[0] + batch_idx[0] = batch_idx[0] + 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * 32 + q_indptr_val: T.int32 = q_indptr[b_idx] + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0) + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + m_smem[row] = T.float32(-50000) + d_smem[row] = T.float32(1) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads() + 
T.writes(O_local[i, j]) + O_local[i, j] = T.float32(0) + T.tvm_storage_sync("shared") + for li_lj_fused_0 in range(4): + for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for li_lj_fused_3 in T.vectorized(4): + with T.block("Q_load"): + i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) + j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = q_indptr_val + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j]) + else: + Q_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for iterator_1 in range((kv_chunk_len[0] + 15) // 16): + L_kv_start: T.int32 = iterator_1 * 16 + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("K_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, 
b_idx] + length_info[1, b_idx]) + page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, pages[page_no, 0, by, page_offset, j + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, j - 32]))), pages[page_no, 0, by, page_offset, j]) + else: + K_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("V_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) + page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] + page_offset: T.int32 = seq_offset % 16 + V_smem[i, j] = pages[page_no, 1, by, page_offset, j] + else: + V_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) + T.writes(S_local[0:32, 0:16]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, 
lj_1_init in T.grid(2, 2): + with T.block("S_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) + j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) + T.reads() + T.writes(S_local[i, j]) + S_local[i, j] = T.float32(0) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): + with T.block("S_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + k = T.axis.reduce(64, lk_0 * 8 + lk_1) + T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) + T.writes(S_local[i, j]) + S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.18033688011112042) + T.tvm_storage_sync("shared") + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(2, 2): + with T.block("S_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + T.reads(S_local[i, j]) + T.writes(S_smem[i, j]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update1"): + T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) + T.writes(m_prev[i], m_new[i], d_new[i]) + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + row_: T.int32 = LH_start + row + for j in range(16): + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - 
q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + with T.block("update"): + T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) + T.writes(S_smem[row, 0:16]) + for j in range(16): + if row < 32: + row_: T.int32 = LH_start + row + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update"): + T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) + T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) + for j in range(16): + d_new[i] = d_new[i] + S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) + T.writes(O_local[0:32, 0:64]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(4, 4): + with T.block("O_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) + j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): + with 
T.block("O_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + k = T.axis.reduce(16, lk_0 * 8 + lk_1) + T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) + T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) + for li_0 in range(1): + for li_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_2 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("lse_store"): + i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) + T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) + T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) + T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + tile_id[0] = tile_id[0] + 16 + + @T.prim_func + def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: 
T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + qo_len = T.int32(is_size_var=True) + q = T.match_buffer(var_q, (qo_len, 20, 64), "float16") + batch_size = T.int32(is_size_var=True) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) + kv_len = T.int32(is_size_var=True) + k = T.match_buffer(var_k, (kv_len, 20, 64), "float16") + v = T.match_buffer(var_v, (kv_len, 20, 64), "float16") + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) + output = T.match_buffer(var_output, (qo_len, 20, 64), "float16") + lse = T.match_buffer(var_lse, (qo_len, 20)) + # with T.block("root"): + for lbx in T.thread_binding(16, thread="blockIdx.x"): + for lby in T.thread_binding(20, thread="blockIdx.y"): + for lty in T.thread_binding(4, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = T.alloc_buffer((1,), "int32", scope="local") + batch_idx = T.alloc_buffer((1,), "int32", scope="local") + batch_tiles = T.alloc_buffer((1,), "int32", scope="local") + batch_rows = T.alloc_buffer((1,), "int32", scope="local") + iterator = T.alloc_buffer((1,), "int32", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") + 
K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + S_smem = T.alloc_buffer((32, 16), scope="shared") + S_local = T.alloc_buffer((32, 16), scope="local") + O_local = T.alloc_buffer((32, 64), scope="local") + m_smem = T.alloc_buffer((32,), scope="shared") + m_prev_smem = T.alloc_buffer((32,), scope="shared") + d_smem = T.alloc_buffer((32,), scope="shared") + m_new = T.alloc_buffer((1,), scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_new = T.alloc_buffer((1,), scope="local") + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = q_indptr[1] - q_indptr[0] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] = tile_id[0] - batch_tiles[0] + batch_idx[0] = batch_idx[0] + 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + q_indptr_val: T.int32 = q_indptr[b_idx] + LH_start: T.int32 = tile_id[0] * 32 + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + m_smem[row] = T.float32(-50000) + d_smem[row] = T.float32(1) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = T.float32(0) + T.tvm_storage_sync("shared") + for li_lj_fused_0 in range(4): + for 
li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for li_lj_fused_3 in T.vectorized(4): + with T.block("Q_load"): + i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) + j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = q_indptr_val + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j]) + else: + Q_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for iterator_1 in range((kv_chunk_len[0] + 15) // 16): + L_kv_start: T.int32 = iterator_1 * 16 + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("K_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) 
/ T.float32(64))) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, k[L_kv_base + cur_L, by, j + 32] * T.float16(-1), k[L_kv_base + cur_L, by, j - 32]))), k[L_kv_base + cur_L, by, j]) + else: + K_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("V_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_start + i + if cur_L < kv_chunk_len[0]: + V_smem[i, j] = v[L_kv_base + cur_L, by, j] + else: + V_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) + T.writes(S_local[0:32, 0:16]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(2, 2): + with T.block("S_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) + j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) + T.reads() + T.writes(S_local[i, j]) + S_local[i, j] = T.float32(0) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): + with T.block("S_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 
+ li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + k_1 = T.axis.reduce(64, lk_0 * 8 + lk_1) + T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) + T.writes(S_local[i, j]) + S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.18033688011112042) + T.tvm_storage_sync("shared") + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(2, 2): + with T.block("S_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + T.reads(S_local[i, j]) + T.writes(S_smem[i, j]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update1"): + T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) + T.writes(m_prev[i], m_new[i], d_new[i]) + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + row_: T.int32 = LH_start + row + for j in range(16): + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + with T.block("update"): + T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) + T.writes(S_smem[row, 0:16]) + for j in range(16): + if row < 32: + row_: T.int32 = LH_start + row + if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] 
- m_new[i]) + else: + S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update"): + T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) + T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) + for j in range(16): + d_new[i] = d_new[i] + S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) + T.writes(O_local[0:32, 0:64]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(4, 4): + with T.block("O_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) + j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): + with T.block("O_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + k_1 = T.axis.reduce(16, lk_0 * 8 + lk_1) + T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_store"): + i = 
T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) + T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) + for li_0 in range(1): + for li_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_2 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("lse_store"): + i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) + T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) + T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) + T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + tile_id[0] = tile_id[0] + 16 + + @T.prim_func + def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + qo_len = T.int32(is_size_var=True) + q = T.match_buffer(var_q, (qo_len, 20, 64), "float16") + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) + 
kv_len = T.int32(is_size_var=True) + k = T.match_buffer(var_k, (kv_len, 20, 64), "float16") + v = T.match_buffer(var_v, (kv_len, 20, 64), "float16") + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) + mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1) + tree_size = T.int32(is_size_var=True) + mask = T.match_buffer(var_mask, (tree_size,), "int32", offset_factor=1) + output = T.match_buffer(var_output, (qo_len, 20, 64), "float16") + lse = T.match_buffer(var_lse, (qo_len, 20)) + # with T.block("root"): + for lbx in T.thread_binding(16, thread="blockIdx.x"): + for lby in T.thread_binding(20, thread="blockIdx.y"): + for lty in T.thread_binding(4, thread="threadIdx.y"): + for ltx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = T.alloc_buffer((1,), "int32", scope="local") + batch_idx = T.alloc_buffer((1,), "int32", scope="local") + batch_tiles = T.alloc_buffer((1,), "int32", scope="local") + batch_rows = T.alloc_buffer((1,), "int32", scope="local") + iterator = T.alloc_buffer((1,), "int32", scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") + K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") + S_smem = T.alloc_buffer((32, 16), scope="shared") + S_local = T.alloc_buffer((32, 16), scope="local") + O_local = T.alloc_buffer((32, 64), scope="local") + m_smem = T.alloc_buffer((32,), scope="shared") + m_prev_smem = T.alloc_buffer((32,), scope="shared") + d_smem = T.alloc_buffer((32,), scope="shared") + m_new = T.alloc_buffer((1,), scope="local") + m_prev = T.alloc_buffer((1,), scope="local") + d_new = T.alloc_buffer((1,), 
scope="local") + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = q_indptr[1] - q_indptr[0] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] = tile_id[0] - batch_tiles[0] + batch_idx[0] = batch_idx[0] + 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] + batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * 32 + q_indptr_val: T.int32 = q_indptr[b_idx] + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + m_smem[row] = T.float32(-50000) + d_smem[row] = T.float32(1) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = T.float32(0) + T.tvm_storage_sync("shared") + for li_lj_fused_0 in range(4): + for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for li_lj_fused_3 in T.vectorized(4): + with T.block("Q_load"): + i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) + j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = q_indptr_val + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] 
= T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * q[cur_L, cur_H_qo, j] + T.Cast("float16", T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]), q[cur_L, cur_H_qo, j]) + else: + Q_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + for iterator_1 in range((kv_chunk_len[0] + 15) // 16): + L_kv_start: T.int32 = iterator_1 * 16 + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz_ly_fused_0 in range(2): + for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for lz_ly_fused_3 in T.vectorized(4): + with T.block("KV_load"): + i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) + j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) + T.reads() + T.writes() + cur_L: T.int32 = L_kv_base + L_kv_start + i + if L_kv_start + i < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * k[cur_L, by, j] + T.Cast("float16", T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * T.if_then_else(j < 32, k[cur_L, by, j + 32] * T.float16(-1), k[cur_L, by, j - 32]), k[cur_L, by, j]) + V_smem[i, j] = v[cur_L, by, j] + else: + K_smem[i, j] = T.float16(0) + V_smem[i, j] = T.float16(0) + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) + T.writes(S_local[0:32, 0:16]) + for li_0_lj_0_fused_0_init in 
T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(2, 2): + with T.block("S_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) + j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) + T.reads() + T.writes(S_local[i, j]) + S_local[i, j] = T.float32(0) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): + with T.block("S_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + k_1 = T.axis.reduce(64, lk_0 * 8 + lk_1) + T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) + T.writes(S_local[i, j]) + S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.18033688011112042) + T.tvm_storage_sync("shared") + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(2, 2): + with T.block("S_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) + j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) + T.reads(S_local[i, j]) + T.writes(S_smem[i, j]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update1"): + T.reads(m_smem[row], kv_chunk_len[0], mask[mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start:mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + 
L_kv_start + 16], mn_indptr[b_idx], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) + T.writes(m_prev[i], m_new[i], d_new[i]) + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + row_: T.int32 = LH_start + row + for j in range(16): + if L_kv_start + j < kv_chunk_len[0] and mask[mn_indptr[b_idx] + row_ * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + (L_kv_start + j)] == 1: + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + with T.block("update"): + T.reads(kv_chunk_len[0], mask[mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start:mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start + 16], mn_indptr[b_idx], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) + T.writes(S_smem[row, 0:16]) + for j in range(16): + if row < 32: + row_: T.int32 = LH_start + row + if L_kv_start + j < kv_chunk_len[0] and mask[mn_indptr[b_idx] + row_ * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + (L_kv_start + j)] == 1: + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) + for i in range(1): + row: T.int32 = i * 32 * 4 + ty * 32 + tx + if row < 32: + with T.block("update"): + T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) + T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) + for j in range(16): + d_new[i] = d_new[i] + S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + with T.block(""): + T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) + T.writes(O_local[0:32, 0:64]) + for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): + for li_1_init, lj_1_init in T.grid(4, 4): + with 
T.block("O_gemm_init"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) + j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) + T.reads() + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): + with T.block("O_gemm_update"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + k_1 = T.axis.reduce(16, lk_0 * 8 + lk_1) + T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) + T.writes(O_local[i, j]) + O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) + for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for li_1, lj_1 in T.grid(4, 4): + with T.block("O_store"): + i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) + j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) + T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) + T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) + for li_0 in range(1): + for li_1 in T.thread_binding(4, thread="threadIdx.y"): + for li_2 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("lse_store"): + i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) + T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) + T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) + 
T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) + cur_H_qo: T.int32 = by + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + tile_id[0] = tile_id[0] + 16 + + @T.prim_func(private=True) + def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + num_nodes, vocab_size = T.int32(is_size_var=True), T.int64() + draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size)) + draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32") + model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size)) + token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32") + token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32") + uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,)) + nbatch = T.int32(is_size_var=True) + token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32") + # with T.block("root"): + child_ptr = T.alloc_buffer((1,), "int32", scope="local") + parent_ptr = T.alloc_buffer((1,), "int32", scope="local") + child_token = T.alloc_buffer((1,), "int32", scope="local") + done = T.alloc_buffer((1,), "bool", scope="local") + psum = T.alloc_buffer((1,), scope="local") + t0 = T.alloc_buffer((1,), scope="local") + model_prob_local = T.alloc_buffer((1,), scope="local") + draft_prob_local = T.alloc_buffer((1,), scope="local") + p_child = 
T.alloc_buffer((1,), scope="local") + q_child = T.alloc_buffer((1,), scope="local") + uniform_sample = T.alloc_buffer((1,), scope="local") + pred_shared = T.alloc_buffer((1,), "bool", scope="shared") + pred_local = T.alloc_buffer((1,), "bool", scope="local") + for _bx in T.thread_binding(nbatch, thread="blockIdx.x"): + for _tx in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]]) + T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * 
T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b]) + parent_ptr[0] = token_tree_parent_ptr[b] + child_ptr[0] = token_tree_first_child[parent_ptr[0]] + done[0] = T.bool(False) + while not done[0]: + T.tvm_storage_sync("shared") + if child_ptr[0] == -1: + done[0] = T.bool(True) + T.tvm_storage_sync("shared") + else: + if tx == 0: + child_token[0] = draft_tokens[child_ptr[0]] + p_child[0] = model_probs[parent_ptr[0], child_token[0]] + q_child[0] = draft_probs[child_ptr[0], child_token[0]] + uniform_sample[0] = uniform_samples[child_ptr[0]] + pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0] + T.tvm_storage_sync("shared") + pred_local[0] = pred_shared[0] + if pred_local[0]: + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + psum[0] = T.float32(0) + for i in range((vocab_size + T.int64(1023)) // T.int64(1024)): + if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] + draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0)) + psum[0] = psum[0] + model_prob_local[0] + with T.block("block_cross_thread"): + T.reads(psum[0]) + T.writes(t0[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx) + if t0[0] < T.float32(9.9999999999999995e-08): + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + for i in range((vocab_size + T.int64(1023)) // T.int64(1024)): + if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] + draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] + model_prob_local[0] = 
T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0)) + model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0] + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + if tx == 0: + token_tree_parent_ptr[b] = parent_ptr[0] + + @T.prim_func + def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size)) + temperature = T.match_buffer(var_temperature, (batch_size,)) + num_chunks = T.int64(is_size_var=True) + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) + # with T.block("root"): + temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared") + temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared") + for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("max"): + v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0) + v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1) + v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1) + 
T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2]) + T.writes(temp_max_shared[v0, v1]) + with T.init(): + temp_max_shared[v0, v1] = T.float32(-3.4028234663852886e+38) + temp_max_shared[v0, v1] = T.max(temp_max_shared[v0, v1], T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38))) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("sum_exp"): + v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0) + v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1) + v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1) + T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1]) + T.writes(temp_sum_shared[v0, v1]) + with T.init(): + temp_sum_shared[v0, v1] = T.float32(0) + temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38)) - temp_max_shared[v0, v1]), T.Cast("float32", T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38)) == temp_max_shared[v0, v1])), T.float32(0)) + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_0 in 
T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("log"): + v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks) + v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks) + v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1) + T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1)) + T.reads(temperature[v0], temp_sum_shared[v0, v1], temp_max_shared[v0, v1]) + T.writes(chunked_sum[v0, v1], chunked_max[v0, v1]) + chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum_shared[v0, v1]), temp_sum_shared[v0, v1]) + chunked_max[v0, v1] = temp_max_shared[v0, v1] + + @T.prim_func + def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + num_pages = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, 20, 16, 64), "float16") + copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1) + total_copy_length = T.int32() + copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1) + with T.block("root"): + T.reads() + T.writes() + for bhd_o in T.thread_binding((batch_size * 1280 + 1023) // 1024, thread="blockIdx.x"): + for bhd_i in T.thread_binding(1024, thread="threadIdx.x"): + b: T.int32 = (bhd_o * 1024 + bhd_i) // 1280 + h: T.int32 = (bhd_o * 1024 + bhd_i) // 64 % 20 + d: T.int32 = (bhd_o * 1024 + bhd_i) % 64 + if bhd_o * 1024 + bhd_i < batch_size * 20 * 64: + for i in 
range(copy_length_indptr[b + 1] - copy_length_indptr[b]): + src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] + dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] + pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d] + pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d] + + @T.prim_func(private=True) + def concatenate(var_reshape710: T.handle, var_reshape711: T.handle, var_reshape712: T.handle, var_T_concat: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape710 = T.match_buffer(var_reshape710, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + reshape711 = T.match_buffer(var_reshape711, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + reshape712 = T.match_buffer(var_reshape712, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + T_concat = T.match_buffer(var_T_concat, (batch_size, T.int64(1), T.int64(60), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_concat"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) + v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(3840)) + T.reads(reshape712[v0, T.int64(0), v1 + T.int64(-40), v2], reshape711[v0, T.int64(0), v1 + T.int64(-20), v2], reshape710[v0, T.int64(0), v1, v2]) + T.writes(T_concat[v0, T.int64(0), v1, v2]) + T_concat[v0, T.int64(0), v1, v2] 
= T.if_then_else(T.int64(40) <= v1, reshape712[v0, T.int64(0), v1 - T.int64(40), v2], T.if_then_else(T.int64(20) <= v1, reshape711[v0, T.int64(0), v1 + T.int64(-20), v2], reshape710[v0, T.int64(0), v1, v2])) + + @T.prim_func(private=True) + def concatenate1(var_reshape387: T.handle, var_reshape388: T.handle, var_reshape389: T.handle, var_T_concat: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + reshape387 = T.match_buffer(var_reshape387, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + reshape388 = T.match_buffer(var_reshape388, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + reshape389 = T.match_buffer(var_reshape389, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + T_concat = T.match_buffer(var_T_concat, (T.int64(1), seq_len, T.int64(60), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_concat"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) + v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(3840)) + T.reads(reshape389[T.int64(0), v0, v1 + T.int64(-40), v2], reshape388[T.int64(0), v0, v1 + T.int64(-20), v2], reshape387[T.int64(0), v0, v1, v2]) + T.writes(T_concat[T.int64(0), v0, v1, v2]) + T_concat[T.int64(0), v0, v1, v2] = T.if_then_else(T.int64(40) <= v1, reshape389[T.int64(0), v0, v1 - T.int64(40), v2], T.if_then_else(T.int64(20) <= v1, reshape388[T.int64(0), v0, v1 + T.int64(-20), v2], reshape387[T.int64(0), 
v0, v1, v2])) + + @T.prim_func + def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + num_pages, page_size = T.int32(), T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, 20, page_size, 64), "float16") + # with T.block("root"): + for b in T.thread_binding((copy_length * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for t in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("copy"): + vh = T.axis.spatial(20, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(64)))) + vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(64)) // T.int64(64)) + vd = T.axis.spatial(64, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(64))) + T.reads(pages[src_page_id, 0:2, vh, vp, vd]) + T.writes(pages[tgt_page_id, 0:2, vh, vp, vd]) + pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] + pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] + + @T.prim_func(private=True) + def cumsum(var_sorted_probs: T.handle, var_lv1: T.handle, var_exclusive_scan_thrust: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int64(), T.int64() + data_buf = T.match_buffer(var_sorted_probs, (batch_size, vocab_size), align=8) + workspace_buf = T.match_buffer(var_lv1, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8) + output_buf = T.match_buffer(var_exclusive_scan_thrust, (batch_size, 
vocab_size), align=8) + with T.block("exclusive_scan_thrust"): + T.reads() + T.writes() + T.call_packed("tvm.contrib.thrust.sum_scan", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(output_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.bool(False), T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0))) + + @T.prim_func + def full(var_result: T.handle, value: T.int32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + batch_size = T.int32(is_size_var=True) + result = T.match_buffer(var_result, (batch_size, 1), "int32") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding((batch_size + 1023) // 1024, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("block"): + v0 = T.axis.spatial(batch_size, ax0_fused_0 * 1024 + ax0_fused_1) + T.where(ax0_fused_0 * 1024 + ax0_fused_1 < batch_size) + T.reads() + T.writes(result[v0, 0]) + result[v0, 0] = value + + @T.prim_func(private=True) + def fused_NT_matmul1_add8_gelu2(layer_norm358: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_fc1_weight5: T.Buffer((T.int64(5120), T.int64(1280)), "float16"), model_decoder_layers_0_fc1_bias5: T.Buffer((T.int64(5120),), "float16"), T_multiply_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) 
+ # with T.block("root"): + NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") + NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") + NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") + model_decoder_layers_0_fc1_weight5_local = T.alloc_buffer((T.int64(5120), T.int64(1280)), "float16", scope="local") + layer_norm358_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(1280), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(1)): + with T.block("layer_norm358_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1 * T.int64(64) + ax2_2 + ax2_3) + T.reads(layer_norm358[v0, v1, v2]) + T.writes(layer_norm358_shared[v0, v1, v2]) + layer_norm358_shared[v0, v1, v2] = layer_norm358[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(5120), 
u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(2)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_layers_0_fc1_weight5_local"): + v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) + T.reads(model_decoder_layers_0_fc1_weight5[v0, v1]) + T.writes(model_decoder_layers_0_fc1_weight5_local[v0, v1]) + model_decoder_layers_0_fc1_weight5_local[v0, v1] = model_decoder_layers_0_fc1_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_2, vax1_fused_u_fused_0 = T.axis.remap("RR", [ax1_fused_u_fused_2, ax1_fused_u_fused_0]) + T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm358_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + 
vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused], model_decoder_layers_0_fc1_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm358_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused] * model_decoder_layers_0_fc1_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused] + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0) + v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads() + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 
= T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0) + v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(4), 
thread="threadIdx.y"): + for ax0_fused_2 in range(T.int64(1)): + with T.block("T_multiply_2"): + v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) + T.reads(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_fc1_bias5[v0]) + T.writes(T_multiply_intermediate[T.int64(0), T.int64(0), v0]) + T_multiply_intermediate[T.int64(0), T.int64(0), v0] = (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_fc1_bias5[v0]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_fc1_bias5[v0]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) + + @T.prim_func(private=True) + def fused_NT_matmul2_add7_add6(gelu130: T.Buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16"), model_decoder_layers_0_fc2_weight5: T.Buffer((T.int64(1280), T.int64(5120)), "float16"), model_decoder_layers_0_fc2_bias5: T.Buffer((T.int64(1280),), "float16"), add1227: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + model_decoder_layers_0_fc2_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(5120)), "float16", scope="local") + gelu130_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): + for 
u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(2)): + with T.block("gelu130_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(5120), ax2_0 * T.int64(1024) + ax2_1 * T.int64(64) + ax2_2 * T.int64(2) + ax2_3) + T.reads(gelu130[v0, v1, v2]) + T.writes(gelu130_shared[v0, v1, v2]) + gelu130_shared[v0, v1, v2] = gelu130[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(20), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_layers_0_fc2_weight5_local"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + 
u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(5120), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) + T.reads(model_decoder_layers_0_fc2_weight5[v0, v1]) + T.writes(model_decoder_layers_0_fc2_weight5_local[v0, v1]) + model_decoder_layers_0_fc2_weight5_local[v0, v1] = model_decoder_layers_0_fc2_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) + T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], gelu130_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_fc2_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = 
NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + gelu130_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_fc2_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads() + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], 
NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0_fused_2 in range(T.int64(1)): + with T.block("T_add_1"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) + T.reads(add1227[T.int64(0), T.int64(0), v0], 
NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_fc2_bias5[v0]) + T.writes(T_add_intermediate_1[T.int64(0), T.int64(0), v0]) + T_add_intermediate_1[T.int64(0), T.int64(0), v0] = add1227[T.int64(0), T.int64(0), v0] + (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_fc2_bias5[v0]) + + @T.prim_func(private=True) + def fused_NT_matmul_add7(layer_norm356: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_bias5: T.Buffer((T.int64(1280),), "float16"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + model_decoder_layers_0_self_attn_q_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local") + layer_norm356_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_2 in 
T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(1)): + with T.block("layer_norm356_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) + T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280)) + T.reads(layer_norm356[v0, v1, v2]) + T.writes(layer_norm356_shared[v0, v1, v2]) + layer_norm356_shared[v0, v1, v2] = layer_norm356[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_layers_0_self_attn_q_proj_weight5_local"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) + T.reads(model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1]) + 
T.writes(model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1]) + model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) + T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + 
vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads() + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + 
NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0_fused_2 in range(T.int64(1)): + with T.block("T_add"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) + T.reads(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_self_attn_q_proj_bias5[v0]) + T.writes(T_add_intermediate[T.int64(0), T.int64(0), v0]) + T_add_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_self_attn_q_proj_bias5[v0] + + 
@T.prim_func(private=True) + def fused_NT_matmul_add7_add6(reshape1361: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_out_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_out_proj_bias5: T.Buffer((T.int64(1280),), "float16"), add1220: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") + model_decoder_layers_0_self_attn_out_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local") + reshape1361_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") + for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_3 in T.vectorized(T.int64(1)): + with T.block("reshape1361_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) + 
T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280)) + T.reads(reshape1361[v0, v1, v2]) + T.writes(reshape1361_shared[v0, v1, v2]) + reshape1361_shared[v0, v1, v2] = reshape1361[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): + for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): + with T.block("model_decoder_layers_0_self_attn_out_proj_weight5_local"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) + v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) + T.reads(model_decoder_layers_0_self_attn_out_proj_weight5[v0, v1]) + T.writes(model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, v1]) + model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_out_proj_weight5[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): + for 
ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) + T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], reshape1361_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) + T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + reshape1361_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * 
T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_2_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads() + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + 
NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] + for ax1_fused_2 in range(T.int64(1)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) + T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0_fused_2 in range(T.int64(1)): + with T.block("T_add_1"): + v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) + T.reads(add1220[T.int64(0), T.int64(0), v0], NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_self_attn_out_proj_bias5[v0]) + T.writes(T_add_intermediate_1[T.int64(0), T.int64(0), v0]) + T_add_intermediate_1[T.int64(0), T.int64(0), v0] = add1220[T.int64(0), T.int64(0), v0] + (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_self_attn_out_proj_bias5[v0]) + + @T.prim_func(private=True) + def fused_add4_maximum_minimum(p_add4: T.handle, p_lv611: T.handle, p_output0: T.handle): + T.func_attr({"tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)}) + batch_size = T.int64() + add4 = T.match_buffer(p_add4, (batch_size, T.int64(1500), T.int64(1280)), "float16") + lv611 = T.match_buffer(p_lv611, (batch_size, T.int64(1500), T.int64(1280)), "float16") + T_minimum_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1500), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_minimum"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) + T.reads(add4[v0, v1, v2], lv611[v0, v1, v2]) + T.writes(T_minimum_intermediate[v0, v1, v2]) + T_minimum_intermediate[v0, v1, v2] = T.min(T.max(add4[v0, v1, v2] + lv611[v0, v1, v2], T.float16(-65504)), T.float16(65504)) + + @T.prim_func(private=True) + def fused_conv1d1_add2_gelu1(p_gelu: T.handle, model_encoder_conv2_weight: T.Buffer((T.int64(1280), T.int64(1280), T.int64(3)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16"), p_output0: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + gelu = T.match_buffer(p_gelu, (batch_size, T.int64(1280), T.int64(3000)), "float16") + T_multiply_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1280), T.int64(1500)), "float16") + # with T.block("root"): + conv1d_ncw_intermediate_shared = T.alloc_buffer((batch_size, T.int64(1280), T.int64(1500)), "float16", scope="shared") + for ax0_ax1_ax2_fused in T.thread_binding(batch_size * T.int64(1920000), thread="blockIdx.x"): + for ax0, ax1, ax2 in T.grid(T.int64(1), 
T.int64(1), T.int64(1)): + for ax3_ax4_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax3_ax4_fused_0 in T.serial(T.int64(15), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("conv1d_ncw"): + v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(1920000) + ax0) + v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(1920000) // T.int64(1500) + ax1) + v2 = T.axis.spatial(T.int64(1500), ax0_ax1_ax2_fused % T.int64(1500) + ax2) + v3 = T.axis.reduce(T.int64(1280), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) // T.int64(3)) + v4 = T.axis.reduce(T.int64(3), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) % T.int64(3)) + T.reads(gelu[v0, v3, v2 * T.int64(2) + v4 - T.int64(1)], model_encoder_conv2_weight[v1, v3, v4]) + T.writes(conv1d_ncw_intermediate_shared[v0, v1, v2]) + with T.init(): + conv1d_ncw_intermediate_shared[v0, v1, v2] = T.float16(0) + conv1d_ncw_intermediate_shared[v0, v1, v2] = conv1d_ncw_intermediate_shared[v0, v1, v2] + T.if_then_else(T.int64(1) <= v2 * T.int64(2) + v4 and v2 * T.int64(2) + v4 < T.int64(3001), gelu[v0, v3, v2 * T.int64(2) + v4 - T.int64(1)], T.float16(0)) * model_encoder_conv2_weight[v1, v3, v4] + for ax3 in range(T.int64(1)): + for ax4_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax4_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_multiply_2"): + v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(1920000) // T.int64(1500)) + v2 = T.axis.spatial(T.int64(1500), ax0_ax1_ax2_fused % T.int64(1500)) + v3 = T.axis.spatial(T.int64(1), ax3) + v4 = T.axis.spatial(T.int64(1), ax4_0 * T.int64(256) + ax4_1) + T.where(ax4_0 * T.int64(256) + ax4_1 < T.int64(1)) + T.reads(conv1d_ncw_intermediate_shared[v0, v1, v2], lv3[T.int64(0), v1, T.int64(0)]) + 
T.writes(T_multiply_intermediate[v0, v1, v2]) + T_multiply_intermediate[v0, v1, v2] = (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv3[T.int64(0), v1, T.int64(0)]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv3[T.int64(0), v1, T.int64(0)]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) + + @T.prim_func(private=True) + def fused_conv1d_add1_gelu(p_input_features: T.handle, model_encoder_conv1_weight: T.Buffer((T.int64(1280), T.int64(128), T.int64(3)), "float16"), lv1: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16"), p_output0: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + input_features = T.match_buffer(p_input_features, (batch_size, T.int64(128), T.int64(3000)), "float16") + T_multiply_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1280), T.int64(3000)), "float16") + # with T.block("root"): + conv1d_ncw_intermediate_shared = T.alloc_buffer((batch_size, T.int64(1280), T.int64(3000)), "float16", scope="shared") + for ax0_ax1_ax2_fused in T.thread_binding(batch_size * T.int64(3840000), thread="blockIdx.x"): + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + for ax3_ax4_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax3_ax4_fused_0 in T.serial(T.int64(2), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("conv1d_ncw"): + v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(3840000) + ax0) + v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(3840000) // T.int64(3000) + ax1) + v2 = T.axis.spatial(T.int64(3000), ax0_ax1_ax2_fused % T.int64(3000) + ax2) + v3 = T.axis.reduce(T.int64(128), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) // T.int64(3)) + v4 = T.axis.reduce(T.int64(3), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) % T.int64(3)) + T.where(ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1 < 
T.int64(384)) + T.reads(input_features[v0, v3, v2 + v4 - T.int64(1)], model_encoder_conv1_weight[v1, v3, v4]) + T.writes(conv1d_ncw_intermediate_shared[v0, v1, v2]) + with T.init(): + conv1d_ncw_intermediate_shared[v0, v1, v2] = T.float16(0) + conv1d_ncw_intermediate_shared[v0, v1, v2] = conv1d_ncw_intermediate_shared[v0, v1, v2] + T.if_then_else(T.int64(1) <= v2 + v4 and v2 + v4 < T.int64(3001), input_features[v0, v3, v2 + v4 - T.int64(1)], T.float16(0)) * model_encoder_conv1_weight[v1, v3, v4] + for ax3 in range(T.int64(1)): + for ax4_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax4_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_multiply_2"): + v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(3840000)) + v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(3840000) // T.int64(3000)) + v2 = T.axis.spatial(T.int64(3000), ax0_ax1_ax2_fused % T.int64(3000)) + v3 = T.axis.spatial(T.int64(1), ax3) + v4 = T.axis.spatial(T.int64(1), ax4_0 * T.int64(256) + ax4_1) + T.where(ax4_0 * T.int64(256) + ax4_1 < T.int64(1)) + T.reads(conv1d_ncw_intermediate_shared[v0, v1, v2], lv1[T.int64(0), v1, T.int64(0)]) + T.writes(T_multiply_intermediate[v0, v1, v2]) + T_multiply_intermediate[v0, v1, v2] = (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv1[T.int64(0), v1, T.int64(0)]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv1[T.int64(0), v1, T.int64(0)]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) + + @T.prim_func(private=True) + def fused_reshape20_reshape20_add6(take7: T.Buffer((T.int64(1), T.int64(1280)), "float16"), take8: T.Buffer((T.int64(1), T.int64(1280)), "float16"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in 
T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(take7[T.int64(0), v0], take8[T.int64(0), v0]) + T.writes(T_add_intermediate[T.int64(0), T.int64(0), v0]) + T_add_intermediate[T.int64(0), T.int64(0), v0] = take7[T.int64(0), v0] + take8[T.int64(0), v0] + + @T.prim_func(private=True) + def fused_reshape21_reshape21_reshape21_concatenate2_reshape22(add1221: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), lv1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), add1222: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_reshape_intermediate_1_2_3: T.Buffer((T.int64(1), T.int64(60), T.int64(64)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding(T.int64(4), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape_3"): + v0 = T.axis.spatial(T.int64(60), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(64)) + v1 = T.axis.spatial(T.int64(64), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(64)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(3840)) + T.reads(add1222[T.int64(0), T.int64(0), (v0 - T.int64(40)) * T.int64(64) + v1], lv1[T.int64(0), T.int64(0), (v0 + T.int64(-20)) * T.int64(64) + v1], add1221[T.int64(0), T.int64(0), v0 * T.int64(64) + v1]) + T.writes(T_reshape_intermediate_1_2_3[T.int64(0), v0, v1]) + T_reshape_intermediate_1_2_3[T.int64(0), v0, v1] = T.if_then_else(T.int64(40) <= v0, add1222[T.int64(0), T.int64(0), (v0 - T.int64(40)) * T.int64(64) + v1], T.if_then_else(T.int64(20) <= v0, lv1[T.int64(0), T.int64(0), (v0 + T.int64(-20)) * T.int64(64) + 
v1], add1221[T.int64(0), T.int64(0), v0 * T.int64(64) + v1])) + + @T.prim_func(private=True) + def fused_reshape21_reshape25(add1225: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(20), T.int64(64)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape_1"): + v0 = T.axis.spatial(T.int64(20), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(64)) + v1 = T.axis.spatial(T.int64(64), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(64)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(1280)) + T.reads(add1225[T.int64(0), T.int64(0), v0 * T.int64(64) + v1]) + T.writes(T_reshape_intermediate_1[T.int64(0), v0, v1]) + T_reshape_intermediate_1[T.int64(0), v0, v1] = add1225[T.int64(0), T.int64(0), v0 * T.int64(64) + v1] + + @T.prim_func(private=True) + def fused_reshape23_reshape24(lv265: T.Buffer((T.int64(1), T.int64(20), T.int64(64)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape_1"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(lv265[T.int64(0), v0 // T.int64(64), v0 % T.int64(64)]) + T.writes(T_reshape_intermediate_1[T.int64(0), T.int64(0), v0]) + T_reshape_intermediate_1[T.int64(0), T.int64(0), v0] = lv265[T.int64(0), v0 // T.int64(64), v0 % T.int64(64)] + + @T.prim_func(private=True) + def 
fused_reshape9(packed_params_1: T.Buffer((T.int64(1280),), "float16"), T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(packed_params_1[v0]) + T.writes(T_reshape_intermediate[T.int64(0), v0, T.int64(0)]) + T_reshape_intermediate[T.int64(0), v0, T.int64(0)] = packed_params_1[v0] + + @T.prim_func + def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, 60, 64), "float16") + position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1) + q = T.match_buffer(var_q, (seq_len, 20, 64), "float16") + k = T.match_buffer(var_k, (seq_len, 20, 64), "float16") + v = T.match_buffer(var_v, (seq_len, 20, 64), "float16") + # with T.block("root"): + for iters_0_iters_1_iters_2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for iters_0_iters_1_iters_2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("llama_fused_rope"): + s = T.axis.spatial(seq_len, 
(iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) // T.int64(3840)) + h = T.axis.spatial(60, T.Cast("int32", (iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) % T.int64(3840) // T.int64(64))) + d = T.axis.spatial(64, T.Cast("int32", (iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) % T.int64(64))) + T.where(iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1 < seq_len * T.int64(3840)) + T.reads(position_map[s], qkv[s, h, d - 32:d - 32 + 65]) + T.writes(q[s, h, d], k[s, h - 20, d], v[s, h - 40, d]) + if h < 20: + q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Cast("float16", T.cos(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", qkv[s, h, d]) + T.sin(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1), qkv[s, h, d - 32]))), qkv[s, h, d]) + else: + if h < 40: + k[s, h - 20, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Cast("float16", T.cos(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", qkv[s, h, d]) + T.sin(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1), qkv[s, h, d - 32]))), qkv[s, h, d]) + else: + v[s, h - 40, d] = qkv[s, h, d] + + @T.prim_func(private=True) + def fused_transpose_add3(packed_params_4: T.Buffer((T.int64(1500), T.int64(1280)), "float16"), p_gelu1: T.handle, p_output0: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + gelu1 = T.match_buffer(p_gelu1, (batch_size, T.int64(1280), T.int64(1500)), "float16") + T_add_intermediate = 
T.match_buffer(p_output0, (batch_size, T.int64(1500), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) + T.reads(gelu1[v0, v2, v1], packed_params_4[v1, v2]) + T.writes(T_add_intermediate[v0, v1, v2]) + T_add_intermediate[v0, v1, v2] = gelu1[v0, v2, v1] + packed_params_4[v1, v2] + + @T.prim_func + def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + m, n = T.int32(is_size_var=True), T.int32(is_size_var=True) + src = T.match_buffer(var_src, (m, n)) + batch_size = T.int32(is_size_var=True) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (batch_size, n)) + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("gather_2d"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n) + v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + 
ax0_ax1_fused_1) % n) + T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n) + T.reads(src[indices[v0], v1], indices[v0]) + T.writes(dst[v0, v1]) + dst[v0, v1] = src[indices[v0], v1] + + @T.prim_func(private=True) + def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) + indices = T.match_buffer(B, (batch, vocab_size), "int32") + renorm_prob = T.match_buffer(C, (batch, 1)) + out_batch = T.int64() + usample = T.match_buffer(D, (out_batch, 1)) + sample_indices = T.match_buffer(E, (out_batch, 1), "int32") + output_index = T.match_buffer(F, (out_batch, 1), "int32") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((out_batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_get_index_from_sorted"): + v0 = T.axis.spatial(out_batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * out_batch) // vocab_size) + v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < out_batch * vocab_size) + T.reads(usample[v0, T.int64(0)], cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1):v1 - T.int64(1) + T.int64(2)], sample_indices[v0, T.int64(0)], renorm_prob[sample_indices[v0, T.int64(0)], 0], indices[sample_indices[v0, T.int64(0)], T.min(T.int64(0), v1):T.min(T.int64(0), v1) + (v1 + T.int64(1))]) + T.writes(output_index[v0, 0]) + if usample[v0, T.int64(0)] < cumsum_sorted[sample_indices[v0, 
T.int64(0)], v1] / renorm_prob[sample_indices[v0, T.int64(0)], 0] or v1 + T.int64(1) == vocab_size: + if v1 == T.int64(0): + output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], 0] + else: + if usample[v0, T.int64(0)] >= cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1)] / renorm_prob[sample_indices[v0, T.int64(0)], 0]: + output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], v1] + + @T.prim_func(private=True) + def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) + top_p = T.match_buffer(B, (batch, 1)) + top_k = T.match_buffer(C, (batch, 1), "int32") + renorm_prob = T.match_buffer(D, (batch, 1)) + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_get_renorm_prob"): + v0 = T.axis.spatial(batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * batch) // vocab_size) + v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch * vocab_size) + T.reads(cumsum_sorted[v0, T.min(T.min(T.int64(0), v1), v1 + T.int64(1)):T.min(T.min(T.int64(0), v1), v1 + T.int64(1)) + (v1 + T.int64(2))], top_p[v0, 0], top_k[v0, 0]) + T.writes(renorm_prob[v0, 0]) + if not (cumsum_sorted[v0, 0] < top_p[v0, 0] and top_k[v0, 0] > 1): + renorm_prob[v0, 0] = cumsum_sorted[v0, 0] + else: + if cumsum_sorted[v0, v1] < top_p[v0, 0] and v1 + T.int64(1) < T.Cast("int64", top_k[v0, 0]): 
+ if v1 + T.int64(1) == vocab_size: + renorm_prob[v0, 0] = cumsum_sorted[v0, v1] + else: + if not (cumsum_sorted[v0, v1 + T.int64(1)] < top_p[v0, 0] and v1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v0, 0])): + renorm_prob[v0, 0] = cumsum_sorted[v0, v1 + T.int64(1)] + + @T.prim_func(private=True) + def index(var_layer_norm355: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + layer_norm355 = T.match_buffer(var_layer_norm355, (T.int64(1), seq_len, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("index"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(layer_norm355[T.int64(0), seq_len - T.int64(1), v0]) + T.writes(index[T.int64(0), T.int64(0), v0]) + index[T.int64(0), T.int64(0), v0] = layer_norm355[T.int64(0), seq_len - T.int64(1), v0] + + @T.prim_func(private=True) + def layer_norm(var_add578: T.handle, model_decoder_layers_0_self_attn_layer_norm_weight3: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias3: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + add578 = T.match_buffer(var_add578, (batch_size, T.int64(1), T.int64(1280)), "float16") + T_layer_norm = T.match_buffer(var_T_layer_norm, (batch_size, T.int64(1), T.int64(1280)), "float16") + # with T.block("root"): + 
add578_red_temp_v0_shared = T.alloc_buffer((batch_size, T.int64(1)), scope="shared") + add578_red_temp_v1_shared = T.alloc_buffer((batch_size, T.int64(1)), scope="shared") + for ax0_fused in T.thread_binding(batch_size, thread="blockIdx.x"): + for ax0 in range(T.int64(1)): + for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("add578_red_temp"): + v0 = T.axis.spatial(batch_size, ax0_fused + ax0) + v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1) + T.reads(add578[v0, T.int64(0), v1]) + T.writes(add578_red_temp_v0_shared[v0, T.int64(0)], add578_red_temp_v1_shared[v0, T.int64(0)]) + with T.init(): + add578_red_temp_v0_shared[v0, T.int64(0)] = T.float32(0) + add578_red_temp_v1_shared[v0, T.int64(0)] = T.float32(0) + v_add578_red_temp_v0: T.float32 = add578_red_temp_v0_shared[v0, T.int64(0)] + T.Cast("float32", add578[v0, T.int64(0), v1]) + v_add578_red_temp_v1: T.float32 = add578_red_temp_v1_shared[v0, T.int64(0)] + T.Cast("float32", add578[v0, T.int64(0), v1]) * T.Cast("float32", add578[v0, T.int64(0), v1]) + add578_red_temp_v0_shared[v0, T.int64(0)] = v_add578_red_temp_v0 + add578_red_temp_v1_shared[v0, T.int64(0)] = v_add578_red_temp_v1 + for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_layer_norm"): + v0 = T.axis.spatial(batch_size, ax0_fused) + v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) + T.reads(add578[v0, T.int64(0), v1], add578_red_temp_v0_shared[v0, T.int64(0)], add578_red_temp_v1_shared[v0, T.int64(0)], model_decoder_layers_0_self_attn_layer_norm_weight3[v1], model_decoder_layers_0_self_attn_layer_norm_bias3[v1]) + T.writes(T_layer_norm[v0, T.int64(0), v1]) + T_layer_norm[v0, T.int64(0), v1] = 
T.Cast("float16", (T.Cast("float32", add578[v0, T.int64(0), v1]) - add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004)) * T.rsqrt(add578_red_temp_v1_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004) - add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004) * (add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_decoder_layers_0_self_attn_layer_norm_weight3[v1] + model_decoder_layers_0_self_attn_layer_norm_bias3[v1] + + @T.prim_func(private=True) + def layer_norm1(var_add: T.handle, model_encoder_layers_0_self_attn_layer_norm_weight: T.Buffer((T.int64(1280),), "float16"), model_encoder_layers_0_self_attn_layer_norm_bias: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + add = T.match_buffer(var_add, (batch_size, T.int64(1500), T.int64(1280)), "float16") + T_layer_norm = T.match_buffer(var_T_layer_norm, (batch_size, T.int64(1500), T.int64(1280)), "float16") + # with T.block("root"): + add_red_temp_v0_shared = T.alloc_buffer((batch_size, T.int64(1500)), scope="shared") + add_red_temp_v1_shared = T.alloc_buffer((batch_size, T.int64(1500)), scope="shared") + for ax0_ax1_fused in T.thread_binding(batch_size * T.int64(1500), thread="blockIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("add_red_temp"): + v0 = T.axis.spatial(batch_size, ax0_ax1_fused // T.int64(1500) + ax0) + v1 = T.axis.spatial(T.int64(1500), ax0_ax1_fused % T.int64(1500) + ax1) + v2 = T.axis.reduce(T.int64(1280), ax2_fused_0 * T.int64(256) + ax2_fused_1) + T.reads(add[v0, v1, v2]) + T.writes(add_red_temp_v0_shared[v0, v1], add_red_temp_v1_shared[v0, 
v1]) + with T.init(): + add_red_temp_v0_shared[v0, v1] = T.float32(0) + add_red_temp_v1_shared[v0, v1] = T.float32(0) + v_add_red_temp_v0: T.float32 = add_red_temp_v0_shared[v0, v1] + T.Cast("float32", add[v0, v1, v2]) + v_add_red_temp_v1: T.float32 = add_red_temp_v1_shared[v0, v1] + T.Cast("float32", add[v0, v1, v2]) * T.Cast("float32", add[v0, v1, v2]) + add_red_temp_v0_shared[v0, v1] = v_add_red_temp_v0 + add_red_temp_v1_shared[v0, v1] = v_add_red_temp_v1 + for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_layer_norm"): + v0 = T.axis.spatial(batch_size, ax0_ax1_fused // T.int64(1500)) + v1 = T.axis.spatial(T.int64(1500), ax0_ax1_fused % T.int64(1500)) + v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1) + T.reads(add[v0, v1, v2], add_red_temp_v0_shared[v0, v1], add_red_temp_v1_shared[v0, v1], model_encoder_layers_0_self_attn_layer_norm_weight[v2], model_encoder_layers_0_self_attn_layer_norm_bias[v2]) + T.writes(T_layer_norm[v0, v1, v2]) + T_layer_norm[v0, v1, v2] = T.Cast("float16", (T.Cast("float32", add[v0, v1, v2]) - add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004)) * T.rsqrt(add_red_temp_v1_shared[v0, v1] * T.float32(0.00078125000000000004) - add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004) * (add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_encoder_layers_0_self_attn_layer_norm_weight[v2] + model_encoder_layers_0_self_attn_layer_norm_bias[v2] + + @T.prim_func(private=True) + def layer_norm2(var_add257: T.handle, model_decoder_layers_0_self_attn_layer_norm_weight2: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias2: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len 
= T.int64() + add257 = T.match_buffer(var_add257, (T.int64(1), seq_len, T.int64(1280)), "float16") + T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), seq_len, T.int64(1280)), "float16") + # with T.block("root"): + add257_red_temp_v0_shared = T.alloc_buffer((T.int64(1), seq_len), scope="shared") + add257_red_temp_v1_shared = T.alloc_buffer((T.int64(1), seq_len), scope="shared") + for ax0_fused in T.thread_binding(seq_len, thread="blockIdx.x"): + for ax0 in range(T.int64(1)): + for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("add257_red_temp"): + v0 = T.axis.spatial(seq_len, ax0_fused + ax0) + v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1) + T.reads(add257[T.int64(0), v0, v1]) + T.writes(add257_red_temp_v0_shared[T.int64(0), v0], add257_red_temp_v1_shared[T.int64(0), v0]) + with T.init(): + add257_red_temp_v0_shared[T.int64(0), v0] = T.float32(0) + add257_red_temp_v1_shared[T.int64(0), v0] = T.float32(0) + v_add257_red_temp_v0: T.float32 = add257_red_temp_v0_shared[T.int64(0), v0] + T.Cast("float32", add257[T.int64(0), v0, v1]) + v_add257_red_temp_v1: T.float32 = add257_red_temp_v1_shared[T.int64(0), v0] + T.Cast("float32", add257[T.int64(0), v0, v1]) * T.Cast("float32", add257[T.int64(0), v0, v1]) + add257_red_temp_v0_shared[T.int64(0), v0] = v_add257_red_temp_v0 + add257_red_temp_v1_shared[T.int64(0), v0] = v_add257_red_temp_v1 + for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_layer_norm"): + v0 = T.axis.spatial(seq_len, ax0_fused) + v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) + T.reads(add257[T.int64(0), v0, v1], add257_red_temp_v0_shared[T.int64(0), v0], add257_red_temp_v1_shared[T.int64(0), 
v0], model_decoder_layers_0_self_attn_layer_norm_weight2[v1], model_decoder_layers_0_self_attn_layer_norm_bias2[v1]) + T.writes(T_layer_norm[T.int64(0), v0, v1]) + T_layer_norm[T.int64(0), v0, v1] = T.Cast("float16", (T.Cast("float32", add257[T.int64(0), v0, v1]) - add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004)) * T.rsqrt(add257_red_temp_v1_shared[T.int64(0), v0] * T.float32(0.00078125000000000004) - add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004) * (add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_decoder_layers_0_self_attn_layer_norm_weight2[v1] + model_decoder_layers_0_self_attn_layer_norm_bias2[v1] + + @T.prim_func(private=True) + def layer_norm3(add1220: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_layer_norm_weight5: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias5: T.Buffer((T.int64(1280),), "float16"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + add1220_red_temp_v0_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") + add1220_red_temp_v1_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") + for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for ax0 in range(T.int64(1)): + for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("add1220_red_temp"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1) + T.reads(add1220[T.int64(0), T.int64(0), v1]) + T.writes(add1220_red_temp_v0_shared[T.int64(0), T.int64(0)], add1220_red_temp_v1_shared[T.int64(0), 
T.int64(0)]) + with T.init(): + add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] = T.float32(0) + add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] = T.float32(0) + v_add1220_red_temp_v0: T.float32 = add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] + T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) + v_add1220_red_temp_v1: T.float32 = add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] + T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) * T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) + add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] = v_add1220_red_temp_v0 + add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] = v_add1220_red_temp_v1 + for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_layer_norm"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) + T.reads(add1220[T.int64(0), T.int64(0), v1], add1220_red_temp_v0_shared[T.int64(0), T.int64(0)], add1220_red_temp_v1_shared[T.int64(0), T.int64(0)], model_decoder_layers_0_self_attn_layer_norm_weight5[v1], model_decoder_layers_0_self_attn_layer_norm_bias5[v1]) + T.writes(T_layer_norm[T.int64(0), T.int64(0), v1]) + T_layer_norm[T.int64(0), T.int64(0), v1] = T.Cast("float16", (T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) - add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004)) * T.rsqrt(add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004) - add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004) * (add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_decoder_layers_0_self_attn_layer_norm_weight5[v1] + model_decoder_layers_0_self_attn_layer_norm_bias5[v1] + + @T.prim_func + def merge_state_inplace(v: 
T.handle, s: T.handle, v_other: T.handle, s_other: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True) + V = T.match_buffer(v, (N, H, D), "float16") + S = T.match_buffer(s, (N, H)) + V_other = T.match_buffer(v_other, (N, H, D), "float16") + S_other = T.match_buffer(s_other, (N, H)) + # with T.block("root"): + for bx in T.thread_binding(N, thread="blockIdx.x"): + for by in T.thread_binding(1, thread="blockIdx.y"): + for ty in T.thread_binding(20, thread="threadIdx.y"): + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block("merge"): + T.reads(S[bx, ty + by * 20], S_other[bx, ty + by * 20], V[bx, ty + by * 20, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 20, tx * 4:tx * 4 + 4]) + T.writes(V[bx, ty + by * 20, tx * 4:tx * 4 + 4], S[bx, ty + by * 20]) + s_val = T.alloc_buffer((1,), scope="local") + s_other_val = T.alloc_buffer((1,), scope="local") + s_max = T.alloc_buffer((1,), scope="local") + scale = T.alloc_buffer((1,), scope="local") + other_scale = T.alloc_buffer((1,), scope="local") + v_vec = T.alloc_buffer((4,), "float16", scope="local") + v_other_vec = T.alloc_buffer((4,), "float16", scope="local") + s_val[0] = S[bx, ty + by * 20] + s_other_val[0] = S_other[bx, ty + by * 20] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + for vec in T.vectorized(4): + v_vec[vec] = V[bx, ty + by * 20, tx * 4 + vec] + for vec in T.vectorized(4): 
+ v_other_vec[vec] = V_other[bx, ty + by * 20, tx * 4 + vec] + for vec in range(4): + v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0]) + for vec in T.vectorized(4): + V[bx, ty + by * 20, tx * 4 + vec] = v_vec[vec] + S[bx, ty + by * 20] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + @T.prim_func + def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + n, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(var_prob, (n, vocab_size)) + batch_size = T.int64() + uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1)) + row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32") + token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32") + # with T.block("root"): + aggregate = T.alloc_buffer((), scope="local") + sample_id_local = T.alloc_buffer((), "int32", scope="local") + step_iter = T.alloc_buffer((), "int32", scope="local") + for bx in T.thread_binding(batch_size, thread="blockIdx.x"): + row_idx: T.int32 = row_indices[bx, 0] + for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + u: T.float32 = uniform_samples[bx, 0] + aggregate[()] = T.Cast("float32", 0) + step_iter[()] = 0 + while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)): + with T.block(""): + T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()]) + T.writes(sample_id_local[()], aggregate[()]) + prob_gt_threshold = T.alloc_buffer((T.int64(4),), 
scope="local") + cumsum = T.alloc_buffer((T.int64(512),), scope="shared") + greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local") + mask = T.alloc_buffer((T.int64(4),), "bool", scope="local") + valid = T.alloc_buffer((T.int64(4),), "bool", scope="local") + indices = T.alloc_buffer((T.int64(4),), "int32", scope="local") + step_aggregate = T.alloc_buffer((), scope="local") + for v in T.unroll(T.int64(4)): + idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v + prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0)) + prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0), prob_local, T.Cast("float32", 0)) + valid[v] = prob_local > T.float32(0) and idx < vocab_size + with T.block(""): + T.reads(prob_gt_threshold[T.int64(0):T.int64(4)]) + T.writes(step_aggregate[()]) + local_sum = T.alloc_buffer((), scope="local") + shared_buf = T.alloc_buffer((T.int64(128),), scope="shared") + idx: T.int64 = ty * T.int64(32) + tx + local_sum[()] = T.Cast("float32", 0) + for i in T.unroll(T.int64(4)): + local_sum[()] = local_sum[()] + prob_gt_threshold[i] + shared_buf[idx] = local_sum[()] + for i in T.unroll(T.int64(7)): + if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): + shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)] + step_aggregate[()] = shared_buf[0] + if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)): + for i in T.unroll(T.int64(1), T.int64(4)): + prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)] + for i in T.vectorized(T.int64(4)): + cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i] + for i in T.unroll(T.int64(5)): + for j in T.vectorized(T.int64(4)): + idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + if tx >= T.shift_left(T.int64(1), i): + cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - 
T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)] + for i in T.unroll(T.int64(1), T.int64(4)): + for j in T.vectorized(T.int64(4)): + if ty == T.int64(0): + idx: T.int64 = i * T.int64(128) + tx * T.int64(4) + cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)] + for v in T.unroll(T.int64(4)): + greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07) + with T.block(""): + T.reads(greater_than_u[T.int64(0):T.int64(4)]) + T.writes(mask[T.int64(0):T.int64(4)]) + shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared") + tx_idx: T.int64 = ty * T.int64(32) + tx + shared_buf[tx_idx] = greater_than_u[T.int64(3)] + mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0]) + for i in T.unroll(T.int64(1), T.int64(4)): + mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)]) + for v in T.unroll(T.int64(4)): + mask[v] = mask[v] and valid[v] + indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v) + with T.block(""): + T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)]) + T.writes(sample_id_local[()]) + local_sum = T.alloc_buffer((), "int32", scope="local") + shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared") + idx: T.int64 = ty * T.int64(32) + tx + local_sum[()] = T.Cast("int32", vocab_size - T.int64(1)) + for i in T.unroll(T.int64(4)): + if mask[i]: + local_sum[()] = T.min(local_sum[()], indices[i]) + shared_buf[idx] = local_sum[()] + for i in T.unroll(T.int64(7)): + if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): + shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)]) + sample_id_local[()] = shared_buf[0] + aggregate[()] = aggregate[()] + step_aggregate[()] + step_iter[()] = 
step_iter[()] + 1 + if tx == T.int64(0) and ty == T.int64(0): + token_ids[bx, 0] = sample_id_local[()] + + @T.prim_func(private=True) + def reshape(var_lv: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv = T.match_buffer(var_lv, (batch_size, T.int64(1500), T.int64(1280)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1280) // T.int64(64)) + v3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(64)) + T.reads(lv[v0, v1, v2 * T.int64(64) + v3]) + T.writes(T_reshape[v0, v1, v2, v3]) + T_reshape[v0, v1, v2, v3] = lv[v0, v1, v2 * T.int64(64) + v3] + + @T.prim_func(private=True) + def reshape1(var_reshape256: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape256 = T.match_buffer(var_reshape256, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size * T.int64(1500), T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size * T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.reads(reshape256[v0 // T.int64(1500), v0 % T.int64(1500), v1, v2]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = reshape256[v0 // T.int64(1500), v0 % T.int64(1500), v1, v2] + + @T.prim_func(private=True) + def reshape10(var_lv4: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv4 = T.match_buffer(var_lv4, (batch_size * T.int64(1500), T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1280) // T.int64(64)) + v3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(64)) + T.reads(lv4[v0 * T.int64(1500) + v1, v2, v3]) + T.writes(T_reshape[v0, v1, v2, v3]) + T_reshape[v0, v1, v2, v3] = lv4[v0 * T.int64(1500) + v1, v2, v3] + + @T.prim_func(private=True) + def reshape11(var_reshape6: T.handle, 
var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape6 = T.match_buffer(var_reshape6, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) + v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) + v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) + T.reads(reshape6[v0, v1, v2 // T.int64(64), v2 % T.int64(64)]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = reshape6[v0, v1, v2 // T.int64(64), v2 % T.int64(64)] + + @T.prim_func(private=True) + def reshape12(var_input_ids: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + input_ids = T.match_buffer(var_input_ids, (T.int64(1), seq_len), "int32") + T_reshape = T.match_buffer(var_T_reshape, (seq_len,), "int32") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding((seq_len + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < seq_len) + T.reads(input_ids[T.int64(0), v0]) + T.writes(T_reshape[v0]) + T_reshape[v0] = input_ids[T.int64(0), v0] + + @T.prim_func(private=True) + def reshape13(var_take: T.handle, var_T_reshape: 
T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + take = T.match_buffer(var_take, (seq_len, T.int64(1280)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280)) + T.reads(take[v0, v1]) + T.writes(T_reshape[T.int64(0), v0, v1]) + T_reshape[T.int64(0), v0, v1] = take[v0, v1] + + @T.prim_func(private=True) + def reshape14(var_lv416: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + lv416 = T.match_buffer(var_lv416, (T.int64(1), seq_len, T.int64(1280)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * 
T.int64(1280)) + T.reads(lv416[T.int64(0), v0, v1 * T.int64(64) + v2]) + T.writes(T_reshape[T.int64(0), v0, v1, v2]) + T_reshape[T.int64(0), v0, v1, v2] = lv416[T.int64(0), v0, v1 * T.int64(64) + v2] + + @T.prim_func(private=True) + def reshape15(var_concat: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + concat = T.match_buffer(var_concat, (T.int64(1), seq_len, T.int64(60), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(60), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) + v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(3840)) + T.reads(concat[T.int64(0), v0, v1, v2]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = concat[T.int64(0), v0, v1, v2] + + @T.prim_func(private=True) + def reshape16(var_lv69: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + lv69 = T.match_buffer(var_lv69, (seq_len, T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(1280)) + T.reads(lv69[v0, v1, v2]) + T.writes(T_reshape[T.int64(0), v0, v1, v2]) + T_reshape[T.int64(0), v0, v1, v2] = lv69[v0, v1, v2] + + @T.prim_func(private=True) + def reshape17(var_reshape391: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + reshape391 = T.match_buffer(var_reshape391, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280)) + T.reads(reshape391[T.int64(0), v0, v1 // T.int64(64), v1 % T.int64(64)]) + T.writes(T_reshape[T.int64(0), v0, v1]) + T_reshape[T.int64(0), v0, v1] = reshape391[T.int64(0), v0, v1 // T.int64(64), v1 % T.int64(64)] + + @T.prim_func(private=True) + def reshape18(var_reshape393: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + reshape393 = 
T.match_buffer(var_reshape393, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(1280)) + T.reads(reshape393[T.int64(0), v0, v1, v2]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = reshape393[T.int64(0), v0, v1, v2] + + @T.prim_func(private=True) + def reshape19(input_ids: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) + T.reads(input_ids[T.int64(0), T.int64(0)]) + T.writes(T_reshape[T.int64(0)]) + T_reshape[T.int64(0)] = input_ids[T.int64(0), T.int64(0)] + + @T.prim_func(private=True) + def reshape2(var_input_ids: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + input_ids = T.match_buffer(var_input_ids, (batch_size, T.int64(1)), "int32") + T_reshape = 
T.match_buffer(var_T_reshape, (batch_size,), "int32") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding((batch_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < batch_size) + T.reads(input_ids[v0, T.int64(0)]) + T.writes(T_reshape[v0]) + T_reshape[v0] = input_ids[v0, T.int64(0)] + + @T.prim_func(private=True) + def reshape3(var_take3: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + take3 = T.match_buffer(var_take3, (batch_size, T.int64(1280)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(take3[v0, v1]) + T.writes(T_reshape[v0, T.int64(0), v1]) + T_reshape[v0, T.int64(0), v1] = take3[v0, v1] + + @T.prim_func(private=True) + def reshape4(var_lv224: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv224 = T.match_buffer(var_lv224, (batch_size, T.int64(1), T.int64(1280)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for 
ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) + T.reads(lv224[v0, T.int64(0), v1 * T.int64(64) + v2]) + T.writes(T_reshape[v0, T.int64(0), v1, v2]) + T_reshape[v0, T.int64(0), v1, v2] = lv224[v0, T.int64(0), v1 * T.int64(64) + v2] + + @T.prim_func(private=True) + def reshape5(var_concat32: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + concat32 = T.match_buffer(var_concat32, (batch_size, T.int64(1), T.int64(60), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(60), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) + v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(3840)) + T.reads(concat32[v0, 
T.int64(0), v1, v2]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = concat32[v0, T.int64(0), v1, v2] + + @T.prim_func(private=True) + def reshape6(var_lv134: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv134 = T.match_buffer(var_lv134, (batch_size, T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) + T.reads(lv134[v0, v1, v2]) + T.writes(T_reshape[v0, T.int64(0), v1, v2]) + T_reshape[v0, T.int64(0), v1, v2] = lv134[v0, v1, v2] + + @T.prim_func(private=True) + def reshape7(var_reshape714: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape714 = T.match_buffer(var_reshape714, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): 
+ v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(reshape714[v0, T.int64(0), v1 // T.int64(64), v1 % T.int64(64)]) + T.writes(T_reshape[v0, T.int64(0), v1]) + T_reshape[v0, T.int64(0), v1] = reshape714[v0, T.int64(0), v1 // T.int64(64), v1 % T.int64(64)] + + @T.prim_func(private=True) + def reshape8(var_reshape716: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape716 = T.match_buffer(var_reshape716, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") + T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(20), T.int64(64)), "float16") + # with T.block("root"): + for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_reshape"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) + v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) + T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) + T.reads(reshape716[v0, T.int64(0), v1, v2]) + T.writes(T_reshape[v0, v1, v2]) + T_reshape[v0, v1, v2] = reshape716[v0, T.int64(0), v1, v2] + + @T.prim_func + def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, 
var_top_prob_probs: T.handle, var_top_prob_indices: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) + batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) + unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size)) + sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32") + num_samples = T.int32(is_size_var=True) + sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32") + sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32") + num_positions = T.int32(is_size_var=True) + top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32") + sampled_values = T.match_buffer(var_sampled_values, (num_samples,)) + top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,)) + top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32") + # with T.block("root"): + for ax0_fused_0 in T.thread_binding((num_positions + num_samples + 1023) // 1024, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("block"): + v0 = T.axis.spatial(num_positions + num_samples, ax0_fused_0 * 1024 + ax0_fused_1) + T.where(ax0_fused_0 * 1024 + ax0_fused_1 < num_positions + num_samples) + T.reads(top_prob_offsets[v0], sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], unsorted_probs[T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]):T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]) + (T.max(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - 
num_positions]) + 1 - T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions])), T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]):T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]) + (T.max(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]))], sample_indices[v0 + (0 - num_positions)], sampling_results[v0 + (0 - num_positions)]) + T.writes(top_prob_indices[v0], top_prob_probs[v0], sampled_values[v0 + (0 - num_positions)]) + if v0 < num_positions: + row: T.int32 = top_prob_offsets[v0] // vocab_size + col: T.int32 = top_prob_offsets[v0] % vocab_size + top_prob_indices[v0] = sorted_indices[row, col] + top_prob_probs[v0] = unsorted_probs[row, sorted_indices[row, col]] + else: + vj: T.int32 = v0 - num_positions + sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]] + + @T.prim_func + def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True) + src = T.match_buffer(var_src, (batch_size, n)) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + m = T.int32(is_size_var=True) + dst = T.match_buffer(var_dst, (m, n)) + # with T.block("root"): + for 
ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("scatter_2d"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n) + v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n) + T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n) + T.reads(src[v0, v1], indices[v0]) + T.writes(dst[indices[v0], v1]) + dst[indices[v0], v1] = src[v0, v1] + + @T.prim_func + def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size)) + temperature = T.match_buffer(var_temperature, (batch_size,)) + num_chunks = T.int64(is_size_var=True) + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size)) + # with T.block("root"): + temp_max_shared = T.alloc_buffer((batch_size,), scope="shared") + temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared") + for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): + for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, 
"pragma_unroll_explicit": 1}): + with T.block("max"): + v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) + v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) + T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) + T.reads(chunked_max[v0, v1]) + T.writes(temp_max_shared[v0]) + with T.init(): + temp_max_shared[v0] = T.float32(-3.4028234663852886e+38) + temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1]) + for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): + with T.block("sum_exp"): + v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) + v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) + T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) + T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0]) + T.writes(temp_sum_shared[v0]) + with T.init(): + temp_sum_shared[v0] = T.float32(0) + temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1]) + for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): + for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"): + for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("log_pad"): + v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) + v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks) + v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2) + T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0]) + T.writes(softmax[v0, v1 * 
T.int64(4096) + v2]) + if v1 * T.int64(4096) + v2 < vocab_size: + softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0]) + + @T.prim_func(private=True) + def take(model_decoder_embed_tokens_weight3: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), var_reshape707: T.handle, var_T_take: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + reshape707 = T.match_buffer(var_reshape707, (batch_size,), "int32") + T_take = T.match_buffer(var_T_take, (batch_size, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_take"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(model_decoder_embed_tokens_weight3[reshape707[v0], v1], reshape707[v0]) + T.writes(T_take[v0, v1]) + T_take[v0, v1] = model_decoder_embed_tokens_weight3[reshape707[v0], v1] + + @T.prim_func(private=True) + def take1(model_decoder_embed_positions_weight3: T.Buffer((T.int64(448), T.int64(1280)), "float16"), var_lv133: T.handle, var_T_take: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv133 = T.match_buffer(var_lv133, (batch_size,), "int32") + T_take = T.match_buffer(var_T_take, (batch_size, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 
in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_take"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(model_decoder_embed_positions_weight3[lv133[v0], v1], lv133[v0]) + T.writes(T_take[v0, v1]) + T_take[v0, v1] = model_decoder_embed_positions_weight3[lv133[v0], v1] + + @T.prim_func(private=True) + def take2(var_layer_norm161: T.handle, var_logit_positions: T.handle, var_T_take: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + seq_len = T.int64() + layer_norm161 = T.match_buffer(var_layer_norm161, (T.int64(1), seq_len, T.int64(1280)), "float16") + batch_size = T.int64() + logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32") + T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(1280)), "float16") + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_take"): + v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) + v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) + T.reads(layer_norm161[T.int64(0), logit_positions[v0], v1], logit_positions[v0]) + T.writes(T_take[T.int64(0), v0, v1]) + T_take[T.int64(0), v0, v1] = layer_norm161[T.int64(0), logit_positions[v0], v1] + + @T.prim_func(private=True) + def 
take3(model_decoder_embed_tokens_weight5: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), reshape1353: T.Buffer((T.int64(1),), "int32"), T_take: T.Buffer((T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_take"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(model_decoder_embed_tokens_weight5[reshape1353[T.int64(0)], v0], reshape1353[T.int64(0)]) + T.writes(T_take[T.int64(0), v0]) + T_take[T.int64(0), v0] = model_decoder_embed_tokens_weight5[reshape1353[T.int64(0)], v0] + + @T.prim_func(private=True) + def take4(model_decoder_embed_positions_weight5: T.Buffer((T.int64(448), T.int64(1280)), "float16"), lv264: T.Buffer((T.int64(1),), "int32"), T_take: T.Buffer((T.int64(1), T.int64(1280)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("T_take"): + v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) + T.reads(model_decoder_embed_positions_weight5[lv264[T.int64(0)], v0], lv264[T.int64(0)]) + T.writes(T_take[T.int64(0), v0]) + T_take[T.int64(0), v0] = model_decoder_embed_positions_weight5[lv264[T.int64(0)], v0] + + @T.prim_func(private=True) + def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, 
"max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size, vocab_size = T.int64(), T.int64() + probs = T.match_buffer(var_probs, (batch_size, vocab_size)) + lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32") + batch_size_1, vocab_size_1 = T.int64(), T.int64() + take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1)) + # with T.block("root"): + for ax0_ax1_fused_0 in T.thread_binding((batch_size_1 * vocab_size_1 + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("take_sorted_probs"): + v0 = T.axis.spatial(batch_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size_1 * batch_size_1) // vocab_size_1) + v1 = T.axis.spatial(vocab_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size_1) + T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size_1 * vocab_size_1) + T.reads(probs[v0, lv1[v0, v1]], lv1[v0, v1]) + T.writes(take_sorted_probs[v0, v1]) + take_sorted_probs[v0, v1] = probs[v0, lv1[v0, v1]] + + @T.prim_func + def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + num_pages, page_size = T.int64(), T.int64(is_size_var=True) + pages = T.match_buffer(var_pages, (num_pages, 2, 20, page_size, 64), "float16") + seqlen = T.int64(is_size_var=True) + position_map = 
T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1) + k_data = T.match_buffer(var_k_data, (32, seqlen, 20, 64), "float16") + v_data = T.match_buffer(var_v_data, (32, seqlen, 20, 64), "float16") + # with T.block("root"): + for p_h_d_fused_0 in T.thread_binding((seqlen * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for p_h_d_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("copy0"): + vp = T.axis.spatial(seqlen, (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) // T.int64(1280)) + vh = T.axis.spatial(20, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(1280) // T.int64(64))) + vd = T.axis.spatial(64, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(64))) + T.where(p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1 < seqlen * T.int64(1280)) + T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] + k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] + v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] + + @T.prim_func + def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + num_pages = T.int64() + pages = T.match_buffer(var_pages, 
(num_pages, 2, 20, 16, 64), "float16") + ntoken = T.int64(is_size_var=True) + k_data = T.match_buffer(var_k_data, (ntoken, 20, 64), "float16") + v_data = T.match_buffer(var_v_data, (ntoken, 20, 64), "float16") + position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1) + # with T.block("root"): + for global_pos_h_f_fused_0 in T.thread_binding((ntoken * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): + for global_pos_h_f_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + if position_map[(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)] != -1: + with T.block("k_transpose_append"): + vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)) + vh = T.axis.spatial(20, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(1280) // T.int64(64))) + vf = T.axis.spatial(64, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(64))) + T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(1280)) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] + pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)) + vh = T.axis.spatial(20, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(1280) // T.int64(64))) + vf = T.axis.spatial(64, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(64))) + T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(1280)) + T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) + 
T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] + pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] + + @T.prim_func(private=True) + def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + B, N = T.int32(), T.int32() + prob = T.match_buffer(var_prob, (B, N)) + top_p_arr = T.match_buffer(var_top_p_arr, (B,)) + init_pivots = T.match_buffer(var_init_pivots, (B, 3)) + final_pivot = T.match_buffer(var_final_pivot, (B,)) + final_lsum = T.match_buffer(var_final_lsum, (B,)) + # with T.block("root"): + pivot = T.alloc_buffer((3,), scope="local") + top_p = T.alloc_buffer((1,), scope="local") + L = T.alloc_buffer((1,), scope="shared") + R_1 = T.alloc_buffer((1,), scope="shared") + L_local = T.alloc_buffer((1,), scope="local") + R_local = T.alloc_buffer((1,), scope="local") + q = T.alloc_buffer((1,), scope="local") + lsum = T.alloc_buffer((3,), scope="local") + lmin_broadcast = T.alloc_buffer((1,), scope="shared") + lmin_broadcast_local = T.alloc_buffer((1,), scope="local") + lmin = T.alloc_buffer((3,), scope="local") + cmin = T.alloc_buffer((3,), "int32", scope="local") + total_sum = T.alloc_buffer((1,), scope="local") + it = T.alloc_buffer((1,), "int32", scope="local") + es_local = T.alloc_buffer((1,), "bool", scope="local") + es = T.alloc_buffer((1,), "bool", scope="shared") + find_pivot_local = T.alloc_buffer((1,), "bool", scope="local") + find_pivot = T.alloc_buffer((1,), "bool", scope="shared") + total_sum_reduce = T.alloc_buffer((1,), scope="local") + lsum_reduce = T.alloc_buffer((1,), 
scope="local") + lmin_reduce = T.alloc_buffer((1,), scope="local") + cmin_reduce = T.alloc_buffer((1,), "int32", scope="local") + for _bx in T.thread_binding(B, thread="blockIdx.x"): + for _tx in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0]) + T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0]) + top_p[0] = top_p_arr[b] + if tx == 0: + L[0] = T.float32(1) - top_p[0] + R_1[0] = T.float32(9.9999999999999995e-08) + find_pivot[0] = T.bool(False) + T.tvm_storage_sync("shared") + L_local[0] = L[0] + R_local[0] = R_1[0] + for i in T.unroll(3): + pivot[i] = init_pivots[b, i] + find_pivot_local[0] = T.bool(False) + if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08): + if tx == 0: + final_lsum[b] = T.float32(1) + final_pivot[b] = T.float32(0) + find_pivot_local[0] = T.bool(True) + while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]): + T.tvm_storage_sync("shared") + for pidx in T.unroll(3): + lsum[pidx] = T.float32(0) + lmin[pidx] = 
T.float32(3.4028234663852886e+38) + cmin[pidx] = 0 + total_sum[0] = T.float32(0) + it[0] = 0 + es_local[0] = T.bool(False) + while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]: + q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0)) + total_sum[0] = total_sum[0] + q[0] + for pidx in T.unroll(3): + if q[0] >= pivot[pidx]: + lsum[pidx] = lsum[pidx] + q[0] + if lmin[pidx] > q[0]: + lmin[pidx] = q[0] + cmin[pidx] = 1 + else: + if lmin[pidx] == q[0]: + cmin[pidx] = cmin[pidx] + 1 + it[0] = it[0] + 1 + if it[0] % 32 == 0: + with T.block("block_cross_thread"): + T.reads(total_sum[0]) + T.writes(total_sum_reduce[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx) + if tx == 0: + es[0] = T.float32(1) - total_sum_reduce[0] < pivot[2] + T.tvm_storage_sync("shared") + es_local[0] = es[0] + T.tvm_storage_sync("shared") + for pidx in range(3): + with T.block("block_cross_thread"): + T.reads(lsum[pidx]) + T.writes(lsum_reduce[0]) + T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx) + with T.block("block_cross_thread"): + T.reads(lmin[pidx]) + T.writes(lmin_reduce[0]) + T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx) + if tx == 0: + lmin_broadcast[0] = lmin_reduce[0] + T.tvm_storage_sync("shared") + lmin_broadcast_local[0] = lmin_broadcast[0] + if lmin[pidx] > lmin_broadcast_local[0]: + cmin[pidx] = 0 + if tx == 0: + lsum[pidx] = lsum_reduce[0] + lmin[pidx] = lmin_reduce[0] + with T.block("block_cross_thread"): + T.reads(cmin[pidx]) + T.writes(cmin_reduce[0]) + 
T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) + T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx) + if tx == 0: + cmin[pidx] = cmin_reduce[0] + T.tvm_storage_sync("shared") + if tx == 0: + it[0] = 0 + while it[0] < 3 and not find_pivot_local[0]: + if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]: + find_pivot[0] = T.bool(True) + find_pivot_local[0] = T.bool(True) + final_pivot[b] = pivot[it[0]] + final_lsum[b] = lsum[it[0]] + else: + if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]: + R_1[0] = pivot[it[0]] + final_lsum[b] = lsum[it[0]] + else: + if lsum[it[0]] < top_p[0]: + L[0] = pivot[it[0]] + it[0] = it[0] + 1 + T.tvm_storage_sync("shared") + L_local[0] = L[0] + R_local[0] = R_1[0] + find_pivot_local[0] = find_pivot[0] + for pidx in T.unroll(3): + pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4) + if tx == 0: + if not find_pivot_local[0]: + final_pivot[b] = R_local[0] + if R_local[0] == T.float32(9.9999999999999995e-08): + final_lsum[b] = lsum[2] + + @T.prim_func(private=True) + def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle): + T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + B, N = T.int32(), T.int32() + prob = T.match_buffer(var_prob, (B, N)) + final_pivot = T.match_buffer(var_final_pivot, (B,)) + final_lsum = T.match_buffer(var_final_lsum, (B,)) + renorm_prob = T.match_buffer(var_renorm_prob, (B, N)) + # with T.block("root"): + pivot = T.alloc_buffer((1,), scope="local") + lsum = T.alloc_buffer((1,), scope="local") + for _by 
in T.thread_binding(B, thread="blockIdx.y"): + for _bx in T.thread_binding((B + 511) // B, thread="blockIdx.x"): + for _tx in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("CTA"): + by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx]) + T.reads(final_pivot[by], final_lsum[by], prob[by, T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx:T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx + (T.Select(0 <= (B + 511) // B, (N - 1) // ((B + 511) // B * 1024) * ((B + 511) // B), 0 - (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + 1)], pivot[0], lsum[0]) + T.writes(pivot[0], lsum[0], renorm_prob[by, T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx:T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx + (T.Select(0 <= (B + 511) // B, (N - 1) // ((B + 511) // B * 1024) * ((B + 511) // B), 0 - (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + 1)]) + pivot[0] = final_pivot[by] + lsum[0] = final_lsum[by] + for i in range(((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024)): + if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N: + renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0)) + + @R.function + def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")): + batch_size = T.int64() + 
vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + lv: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]), R.dtype("uint8"), R.prim_value(0), R.str("global")) + lv1 = R.call_tir(cls.argsort_thrust, (probs, lv), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32")) + lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32")) + gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1 + R.output(gv1) + return gv1 + + @R.function + def batch_compute_cross_attn_kv(encoder_hidden_states: R.Tensor(("batch_size", 1500, 1280), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Object: + batch_size = T.int64() + R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_decoder_layers_0_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[498] + model_decoder_layers_0_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[499] + model_decoder_layers_0_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[500] + model_decoder_layers_1_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[522] + model_decoder_layers_1_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[523] + model_decoder_layers_1_encoder_attn_v_proj_bias1: 
R.Tensor((1280,), dtype="float16") = packed_params[524] + model_decoder_layers_2_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[546] + model_decoder_layers_2_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[547] + model_decoder_layers_2_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[548] + model_decoder_layers_3_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[570] + model_decoder_layers_3_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[571] + model_decoder_layers_3_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[572] + model_decoder_layers_4_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[594] + model_decoder_layers_4_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[595] + model_decoder_layers_4_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[596] + model_decoder_layers_5_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[618] + model_decoder_layers_5_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[619] + model_decoder_layers_5_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[620] + model_decoder_layers_6_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[642] + model_decoder_layers_6_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[643] + model_decoder_layers_6_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[644] + model_decoder_layers_7_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[666] + model_decoder_layers_7_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[667] + 
model_decoder_layers_7_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[668] + model_decoder_layers_8_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[690] + model_decoder_layers_8_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[691] + model_decoder_layers_8_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[692] + model_decoder_layers_9_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[714] + model_decoder_layers_9_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[715] + model_decoder_layers_9_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[716] + model_decoder_layers_10_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[738] + model_decoder_layers_10_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[739] + model_decoder_layers_10_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[740] + model_decoder_layers_11_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[762] + model_decoder_layers_11_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[763] + model_decoder_layers_11_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[764] + model_decoder_layers_12_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[786] + model_decoder_layers_12_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[787] + model_decoder_layers_12_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[788] + model_decoder_layers_13_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[810] + model_decoder_layers_13_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[811] + model_decoder_layers_13_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[812] + model_decoder_layers_14_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[834] + model_decoder_layers_14_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[835] + model_decoder_layers_14_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[836] + model_decoder_layers_15_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[858] + model_decoder_layers_15_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[859] + model_decoder_layers_15_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[860] + model_decoder_layers_16_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[882] + model_decoder_layers_16_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[883] + model_decoder_layers_16_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[884] + model_decoder_layers_17_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[906] + model_decoder_layers_17_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[907] + model_decoder_layers_17_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[908] + model_decoder_layers_18_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[930] + model_decoder_layers_18_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[931] + model_decoder_layers_18_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[932] + model_decoder_layers_19_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[954] + 
model_decoder_layers_19_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[955] + model_decoder_layers_19_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[956] + model_decoder_layers_20_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[978] + model_decoder_layers_20_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[979] + model_decoder_layers_20_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[980] + model_decoder_layers_21_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1002] + model_decoder_layers_21_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1003] + model_decoder_layers_21_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1004] + model_decoder_layers_22_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1026] + model_decoder_layers_22_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1027] + model_decoder_layers_22_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1028] + model_decoder_layers_23_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1050] + model_decoder_layers_23_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1051] + model_decoder_layers_23_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1052] + model_decoder_layers_24_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1074] + model_decoder_layers_24_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1075] + model_decoder_layers_24_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1076] + model_decoder_layers_25_encoder_attn_k_proj_weight1: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[1098] + model_decoder_layers_25_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1099] + model_decoder_layers_25_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1100] + model_decoder_layers_26_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1122] + model_decoder_layers_26_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1123] + model_decoder_layers_26_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1124] + model_decoder_layers_27_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1146] + model_decoder_layers_27_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1147] + model_decoder_layers_27_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1148] + model_decoder_layers_28_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1170] + model_decoder_layers_28_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1171] + model_decoder_layers_28_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1172] + model_decoder_layers_29_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1194] + model_decoder_layers_29_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1195] + model_decoder_layers_29_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1196] + model_decoder_layers_30_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1218] + model_decoder_layers_30_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1219] + model_decoder_layers_30_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = 
packed_params[1220] + model_decoder_layers_31_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1242] + model_decoder_layers_31_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1243] + model_decoder_layers_31_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1244] + lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_0_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape256 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_0_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_0_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape257 = R.call_tir(cls.reshape, (lv_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape258 = R.call_tir(cls.reshape1, (reshape256,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape259 = R.call_tir(cls.reshape1, (reshape257,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv36: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", paged_kv_cache, R.prim_value(0), reshape258, reshape259, sinfo_args=(R.Object,)) + lv1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_1_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape260 = R.call_tir(cls.reshape, (lv1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv1_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_1_encoder_attn_v_proj_weight1, encoder_hidden_states, 
model_decoder_layers_1_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape261 = R.call_tir(cls.reshape, (lv1_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape262 = R.call_tir(cls.reshape1, (reshape260,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape263 = R.call_tir(cls.reshape1, (reshape261,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv37: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv36, R.prim_value(1), reshape262, reshape263, sinfo_args=(R.Object,)) + lv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_2_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape264 = R.call_tir(cls.reshape, (lv2,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv2_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_2_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_2_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape265 = R.call_tir(cls.reshape, (lv2_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape266 = R.call_tir(cls.reshape1, (reshape264,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape267 = R.call_tir(cls.reshape1, (reshape265,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv38: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv37, R.prim_value(2), reshape266, reshape267, sinfo_args=(R.Object,)) + lv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_3_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape268 = 
R.call_tir(cls.reshape, (lv3,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv3_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_3_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_3_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape269 = R.call_tir(cls.reshape, (lv3_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape270 = R.call_tir(cls.reshape1, (reshape268,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape271 = R.call_tir(cls.reshape1, (reshape269,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv39: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv38, R.prim_value(3), reshape270, reshape271, sinfo_args=(R.Object,)) + lv4 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_4_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape272 = R.call_tir(cls.reshape, (lv4,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv4_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_4_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_4_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape273 = R.call_tir(cls.reshape, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape274 = R.call_tir(cls.reshape1, (reshape272,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape275 = R.call_tir(cls.reshape1, (reshape273,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv40: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv39, R.prim_value(4), reshape274, reshape275, 
sinfo_args=(R.Object,)) + lv5 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_5_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape276 = R.call_tir(cls.reshape, (lv5,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv5_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_5_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_5_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape277 = R.call_tir(cls.reshape, (lv5_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape278 = R.call_tir(cls.reshape1, (reshape276,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape279 = R.call_tir(cls.reshape1, (reshape277,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv41: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv40, R.prim_value(5), reshape278, reshape279, sinfo_args=(R.Object,)) + lv6 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_6_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape280 = R.call_tir(cls.reshape, (lv6,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv6_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_6_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_6_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape281 = R.call_tir(cls.reshape, (lv6_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape282 = R.call_tir(cls.reshape1, (reshape280,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape283 = 
R.call_tir(cls.reshape1, (reshape281,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv42: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv41, R.prim_value(6), reshape282, reshape283, sinfo_args=(R.Object,)) + lv7 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_7_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape284 = R.call_tir(cls.reshape, (lv7,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv7_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_7_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_7_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape285 = R.call_tir(cls.reshape, (lv7_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape286 = R.call_tir(cls.reshape1, (reshape284,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape287 = R.call_tir(cls.reshape1, (reshape285,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv43: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv42, R.prim_value(7), reshape286, reshape287, sinfo_args=(R.Object,)) + lv8 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_8_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape288 = R.call_tir(cls.reshape, (lv8,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv8_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_8_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_8_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + 
reshape289 = R.call_tir(cls.reshape, (lv8_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape290 = R.call_tir(cls.reshape1, (reshape288,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape291 = R.call_tir(cls.reshape1, (reshape289,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv44: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv43, R.prim_value(8), reshape290, reshape291, sinfo_args=(R.Object,)) + lv9 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_9_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape292 = R.call_tir(cls.reshape, (lv9,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv9_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_9_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_9_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape293 = R.call_tir(cls.reshape, (lv9_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape294 = R.call_tir(cls.reshape1, (reshape292,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape295 = R.call_tir(cls.reshape1, (reshape293,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv45: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv44, R.prim_value(9), reshape294, reshape295, sinfo_args=(R.Object,)) + lv10 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_10_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape296 = R.call_tir(cls.reshape, (lv10,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv10_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_10_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_10_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape297 = R.call_tir(cls.reshape, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape298 = R.call_tir(cls.reshape1, (reshape296,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape299 = R.call_tir(cls.reshape1, (reshape297,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv46: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv45, R.prim_value(10), reshape298, reshape299, sinfo_args=(R.Object,)) + lv11 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_11_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape300 = R.call_tir(cls.reshape, (lv11,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv11_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_11_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_11_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape301 = R.call_tir(cls.reshape, (lv11_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape302 = R.call_tir(cls.reshape1, (reshape300,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape303 = R.call_tir(cls.reshape1, (reshape301,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv47: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv46, R.prim_value(11), reshape302, reshape303, sinfo_args=(R.Object,)) + lv12 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_decoder_layers_12_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape304 = R.call_tir(cls.reshape, (lv12,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv12_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_12_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_12_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape305 = R.call_tir(cls.reshape, (lv12_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape306 = R.call_tir(cls.reshape1, (reshape304,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape307 = R.call_tir(cls.reshape1, (reshape305,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv48: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv47, R.prim_value(12), reshape306, reshape307, sinfo_args=(R.Object,)) + lv13 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_13_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape308 = R.call_tir(cls.reshape, (lv13,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv13_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_13_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_13_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape309 = R.call_tir(cls.reshape, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape310 = R.call_tir(cls.reshape1, (reshape308,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape311 = R.call_tir(cls.reshape1, (reshape309,), out_sinfo=R.Tensor((batch_size * 1500, 20, 
64), dtype="float16")) + lv49: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv48, R.prim_value(13), reshape310, reshape311, sinfo_args=(R.Object,)) + lv14 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_14_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape312 = R.call_tir(cls.reshape, (lv14,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv14_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_14_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_14_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape313 = R.call_tir(cls.reshape, (lv14_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape314 = R.call_tir(cls.reshape1, (reshape312,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape315 = R.call_tir(cls.reshape1, (reshape313,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv50: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv49, R.prim_value(14), reshape314, reshape315, sinfo_args=(R.Object,)) + lv15 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_15_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape316 = R.call_tir(cls.reshape, (lv15,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv15_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_15_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_15_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape317 = R.call_tir(cls.reshape, (lv15_1,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape318 = R.call_tir(cls.reshape1, (reshape316,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape319 = R.call_tir(cls.reshape1, (reshape317,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv51: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv50, R.prim_value(15), reshape318, reshape319, sinfo_args=(R.Object,)) + lv16 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_16_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape320 = R.call_tir(cls.reshape, (lv16,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv16_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_16_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_16_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape321 = R.call_tir(cls.reshape, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape322 = R.call_tir(cls.reshape1, (reshape320,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape323 = R.call_tir(cls.reshape1, (reshape321,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv52: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv51, R.prim_value(16), reshape322, reshape323, sinfo_args=(R.Object,)) + lv17 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_17_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape324 = R.call_tir(cls.reshape, (lv17,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv17_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_17_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_17_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape325 = R.call_tir(cls.reshape, (lv17_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape326 = R.call_tir(cls.reshape1, (reshape324,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape327 = R.call_tir(cls.reshape1, (reshape325,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv53: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv52, R.prim_value(17), reshape326, reshape327, sinfo_args=(R.Object,)) + lv18 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_18_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape328 = R.call_tir(cls.reshape, (lv18,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv18_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_18_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_18_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape329 = R.call_tir(cls.reshape, (lv18_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape330 = R.call_tir(cls.reshape1, (reshape328,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape331 = R.call_tir(cls.reshape1, (reshape329,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv54: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv53, R.prim_value(18), reshape330, reshape331, sinfo_args=(R.Object,)) + lv19 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_decoder_layers_19_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape332 = R.call_tir(cls.reshape, (lv19,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv19_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_19_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_19_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape333 = R.call_tir(cls.reshape, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape334 = R.call_tir(cls.reshape1, (reshape332,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape335 = R.call_tir(cls.reshape1, (reshape333,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv55: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv54, R.prim_value(19), reshape334, reshape335, sinfo_args=(R.Object,)) + lv20 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_20_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape336 = R.call_tir(cls.reshape, (lv20,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv20_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_20_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_20_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape337 = R.call_tir(cls.reshape, (lv20_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape338 = R.call_tir(cls.reshape1, (reshape336,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape339 = R.call_tir(cls.reshape1, (reshape337,), out_sinfo=R.Tensor((batch_size * 1500, 20, 
64), dtype="float16")) + lv56: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv55, R.prim_value(20), reshape338, reshape339, sinfo_args=(R.Object,)) + lv21 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_21_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape340 = R.call_tir(cls.reshape, (lv21,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv21_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_21_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_21_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape341 = R.call_tir(cls.reshape, (lv21_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape342 = R.call_tir(cls.reshape1, (reshape340,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape343 = R.call_tir(cls.reshape1, (reshape341,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv57: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv56, R.prim_value(21), reshape342, reshape343, sinfo_args=(R.Object,)) + lv22 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_22_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape344 = R.call_tir(cls.reshape, (lv22,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv22_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_22_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_22_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape345 = R.call_tir(cls.reshape, (lv22_1,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape346 = R.call_tir(cls.reshape1, (reshape344,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape347 = R.call_tir(cls.reshape1, (reshape345,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv58: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv57, R.prim_value(22), reshape346, reshape347, sinfo_args=(R.Object,)) + lv23 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_23_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape348 = R.call_tir(cls.reshape, (lv23,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv23_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_23_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_23_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape349 = R.call_tir(cls.reshape, (lv23_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape350 = R.call_tir(cls.reshape1, (reshape348,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape351 = R.call_tir(cls.reshape1, (reshape349,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv59: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv58, R.prim_value(23), reshape350, reshape351, sinfo_args=(R.Object,)) + lv24 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_24_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape352 = R.call_tir(cls.reshape, (lv24,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv24_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_24_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_24_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape353 = R.call_tir(cls.reshape, (lv24_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape354 = R.call_tir(cls.reshape1, (reshape352,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape355 = R.call_tir(cls.reshape1, (reshape353,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv60: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv59, R.prim_value(24), reshape354, reshape355, sinfo_args=(R.Object,)) + lv25 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_25_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape356 = R.call_tir(cls.reshape, (lv25,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv25_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_25_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_25_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape357 = R.call_tir(cls.reshape, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape358 = R.call_tir(cls.reshape1, (reshape356,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape359 = R.call_tir(cls.reshape1, (reshape357,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv61: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv60, R.prim_value(25), reshape358, reshape359, sinfo_args=(R.Object,)) + lv26 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_decoder_layers_26_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape360 = R.call_tir(cls.reshape, (lv26,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv26_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_26_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_26_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape361 = R.call_tir(cls.reshape, (lv26_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape362 = R.call_tir(cls.reshape1, (reshape360,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape363 = R.call_tir(cls.reshape1, (reshape361,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv62: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv61, R.prim_value(26), reshape362, reshape363, sinfo_args=(R.Object,)) + lv27 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_27_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape364 = R.call_tir(cls.reshape, (lv27,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv27_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_27_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_27_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape365 = R.call_tir(cls.reshape, (lv27_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape366 = R.call_tir(cls.reshape1, (reshape364,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape367 = R.call_tir(cls.reshape1, (reshape365,), out_sinfo=R.Tensor((batch_size * 1500, 20, 
64), dtype="float16")) + lv63: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv62, R.prim_value(27), reshape366, reshape367, sinfo_args=(R.Object,)) + lv28 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_28_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape368 = R.call_tir(cls.reshape, (lv28,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv28_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_28_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_28_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape369 = R.call_tir(cls.reshape, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape370 = R.call_tir(cls.reshape1, (reshape368,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape371 = R.call_tir(cls.reshape1, (reshape369,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv64: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv63, R.prim_value(28), reshape370, reshape371, sinfo_args=(R.Object,)) + lv29 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_29_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape372 = R.call_tir(cls.reshape, (lv29,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv29_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_29_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_29_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape373 = R.call_tir(cls.reshape, (lv29_1,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape374 = R.call_tir(cls.reshape1, (reshape372,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape375 = R.call_tir(cls.reshape1, (reshape373,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv65: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv64, R.prim_value(29), reshape374, reshape375, sinfo_args=(R.Object,)) + lv30 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_30_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape376 = R.call_tir(cls.reshape, (lv30,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv30_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_30_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_30_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape377 = R.call_tir(cls.reshape, (lv30_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape378 = R.call_tir(cls.reshape1, (reshape376,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape379 = R.call_tir(cls.reshape1, (reshape377,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv66: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv65, R.prim_value(30), reshape378, reshape379, sinfo_args=(R.Object,)) + lv31 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_31_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape380 = R.call_tir(cls.reshape, (lv31,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv31_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_31_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_31_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape381 = R.call_tir(cls.reshape, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape382 = R.call_tir(cls.reshape1, (reshape380,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape383 = R.call_tir(cls.reshape1, (reshape381,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + gv1: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv66, R.prim_value(31), reshape382, reshape383, sinfo_args=(R.Object,)) + R.output(gv1) + return gv1 + + @R.function + def batch_decode(input_ids: R.Tensor(("batch_size", 1), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), 
dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor(("batch_size", 1, 51866), dtype="float32"): + batch_size = T.int64() + R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_decoder_embed_tokens_weight3: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] + model_decoder_embed_positions_weight3: R.Tensor((448, 1280), dtype="float16") = packed_params[488] + model_decoder_layers_0_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] + model_decoder_layers_0_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] + model_decoder_layers_0_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[491] + model_decoder_layers_0_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[492] + model_decoder_layers_0_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[493] + model_decoder_layers_0_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] + model_decoder_layers_0_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[495] + model_decoder_layers_0_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[496] + model_decoder_layers_0_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[497] + model_decoder_layers_0_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] + model_decoder_layers_0_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[502] + model_decoder_layers_0_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] + model_decoder_layers_0_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[504] + model_decoder_layers_0_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[505] + model_decoder_layers_0_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[506] + model_decoder_layers_0_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] + model_decoder_layers_0_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[508] + model_decoder_layers_0_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] + model_decoder_layers_0_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[510] + model_decoder_layers_0_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[511] + model_decoder_layers_0_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[512] + model_decoder_layers_1_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] + model_decoder_layers_1_self_attn_v_proj_weight3: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[514] + model_decoder_layers_1_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[515] + model_decoder_layers_1_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] + model_decoder_layers_1_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[517] + model_decoder_layers_1_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] + model_decoder_layers_1_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[519] + model_decoder_layers_1_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[520] + model_decoder_layers_1_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[521] + model_decoder_layers_1_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] + model_decoder_layers_1_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[526] + model_decoder_layers_1_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] + model_decoder_layers_1_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[528] + model_decoder_layers_1_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[529] + model_decoder_layers_1_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[530] + model_decoder_layers_1_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] + model_decoder_layers_1_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[532] + model_decoder_layers_1_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] + model_decoder_layers_1_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[534] + model_decoder_layers_1_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[535] + 
model_decoder_layers_1_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[536] + model_decoder_layers_2_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] + model_decoder_layers_2_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] + model_decoder_layers_2_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[539] + model_decoder_layers_2_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] + model_decoder_layers_2_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[541] + model_decoder_layers_2_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] + model_decoder_layers_2_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[543] + model_decoder_layers_2_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[544] + model_decoder_layers_2_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[545] + model_decoder_layers_2_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] + model_decoder_layers_2_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[550] + model_decoder_layers_2_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] + model_decoder_layers_2_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[552] + model_decoder_layers_2_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[553] + model_decoder_layers_2_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[554] + model_decoder_layers_2_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] + model_decoder_layers_2_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[556] + model_decoder_layers_2_fc2_weight3: 
R.Tensor((1280, 5120), dtype="float16") = packed_params[557] + model_decoder_layers_2_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[558] + model_decoder_layers_2_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[559] + model_decoder_layers_2_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[560] + model_decoder_layers_3_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] + model_decoder_layers_3_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] + model_decoder_layers_3_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[563] + model_decoder_layers_3_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] + model_decoder_layers_3_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[565] + model_decoder_layers_3_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] + model_decoder_layers_3_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[567] + model_decoder_layers_3_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[568] + model_decoder_layers_3_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[569] + model_decoder_layers_3_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] + model_decoder_layers_3_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[574] + model_decoder_layers_3_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] + model_decoder_layers_3_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[576] + model_decoder_layers_3_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[577] + model_decoder_layers_3_encoder_attn_layer_norm_bias3: R.Tensor((1280,), 
dtype="float16") = packed_params[578] + model_decoder_layers_3_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] + model_decoder_layers_3_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[580] + model_decoder_layers_3_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] + model_decoder_layers_3_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[582] + model_decoder_layers_3_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[583] + model_decoder_layers_3_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[584] + model_decoder_layers_4_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] + model_decoder_layers_4_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] + model_decoder_layers_4_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[587] + model_decoder_layers_4_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] + model_decoder_layers_4_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[589] + model_decoder_layers_4_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] + model_decoder_layers_4_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[591] + model_decoder_layers_4_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[592] + model_decoder_layers_4_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[593] + model_decoder_layers_4_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] + model_decoder_layers_4_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[598] + model_decoder_layers_4_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] + 
model_decoder_layers_4_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[600] + model_decoder_layers_4_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[601] + model_decoder_layers_4_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[602] + model_decoder_layers_4_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] + model_decoder_layers_4_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[604] + model_decoder_layers_4_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] + model_decoder_layers_4_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[606] + model_decoder_layers_4_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[607] + model_decoder_layers_4_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[608] + model_decoder_layers_5_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] + model_decoder_layers_5_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] + model_decoder_layers_5_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[611] + model_decoder_layers_5_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] + model_decoder_layers_5_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[613] + model_decoder_layers_5_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] + model_decoder_layers_5_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[615] + model_decoder_layers_5_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[616] + model_decoder_layers_5_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[617] + model_decoder_layers_5_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[621] + model_decoder_layers_5_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[622] + model_decoder_layers_5_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] + model_decoder_layers_5_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[624] + model_decoder_layers_5_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[625] + model_decoder_layers_5_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[626] + model_decoder_layers_5_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] + model_decoder_layers_5_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[628] + model_decoder_layers_5_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] + model_decoder_layers_5_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[630] + model_decoder_layers_5_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[631] + model_decoder_layers_5_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[632] + model_decoder_layers_6_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] + model_decoder_layers_6_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] + model_decoder_layers_6_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[635] + model_decoder_layers_6_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[636] + model_decoder_layers_6_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[637] + model_decoder_layers_6_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] + model_decoder_layers_6_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[639] + 
model_decoder_layers_6_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[640] + model_decoder_layers_6_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[641] + model_decoder_layers_6_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] + model_decoder_layers_6_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[646] + model_decoder_layers_6_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] + model_decoder_layers_6_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[648] + model_decoder_layers_6_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[649] + model_decoder_layers_6_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[650] + model_decoder_layers_6_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] + model_decoder_layers_6_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[652] + model_decoder_layers_6_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] + model_decoder_layers_6_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[654] + model_decoder_layers_6_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[655] + model_decoder_layers_6_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[656] + model_decoder_layers_7_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] + model_decoder_layers_7_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[658] + model_decoder_layers_7_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[659] + model_decoder_layers_7_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] + model_decoder_layers_7_self_attn_q_proj_bias3: R.Tensor((1280,), 
dtype="float16") = packed_params[661] + model_decoder_layers_7_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] + model_decoder_layers_7_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[663] + model_decoder_layers_7_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[664] + model_decoder_layers_7_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[665] + model_decoder_layers_7_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] + model_decoder_layers_7_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[670] + model_decoder_layers_7_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] + model_decoder_layers_7_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[672] + model_decoder_layers_7_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[673] + model_decoder_layers_7_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[674] + model_decoder_layers_7_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] + model_decoder_layers_7_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[676] + model_decoder_layers_7_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] + model_decoder_layers_7_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[678] + model_decoder_layers_7_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[679] + model_decoder_layers_7_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[680] + model_decoder_layers_8_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] + model_decoder_layers_8_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[682] + 
model_decoder_layers_8_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[683] + model_decoder_layers_8_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] + model_decoder_layers_8_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[685] + model_decoder_layers_8_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] + model_decoder_layers_8_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[687] + model_decoder_layers_8_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[688] + model_decoder_layers_8_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[689] + model_decoder_layers_8_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] + model_decoder_layers_8_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[694] + model_decoder_layers_8_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] + model_decoder_layers_8_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[696] + model_decoder_layers_8_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[697] + model_decoder_layers_8_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[698] + model_decoder_layers_8_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] + model_decoder_layers_8_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[700] + model_decoder_layers_8_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[701] + model_decoder_layers_8_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[702] + model_decoder_layers_8_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[703] + model_decoder_layers_8_final_layer_norm_bias3: R.Tensor((1280,), 
dtype="float16") = packed_params[704] + model_decoder_layers_9_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] + model_decoder_layers_9_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] + model_decoder_layers_9_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[707] + model_decoder_layers_9_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[708] + model_decoder_layers_9_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[709] + model_decoder_layers_9_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] + model_decoder_layers_9_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[711] + model_decoder_layers_9_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[712] + model_decoder_layers_9_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[713] + model_decoder_layers_9_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] + model_decoder_layers_9_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[718] + model_decoder_layers_9_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] + model_decoder_layers_9_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[720] + model_decoder_layers_9_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[721] + model_decoder_layers_9_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[722] + model_decoder_layers_9_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] + model_decoder_layers_9_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[724] + model_decoder_layers_9_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] + 
model_decoder_layers_9_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[726] + model_decoder_layers_9_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[727] + model_decoder_layers_9_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[728] + model_decoder_layers_10_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] + model_decoder_layers_10_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] + model_decoder_layers_10_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[731] + model_decoder_layers_10_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] + model_decoder_layers_10_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[733] + model_decoder_layers_10_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] + model_decoder_layers_10_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[735] + model_decoder_layers_10_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[736] + model_decoder_layers_10_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[737] + model_decoder_layers_10_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] + model_decoder_layers_10_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[742] + model_decoder_layers_10_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[743] + model_decoder_layers_10_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[744] + model_decoder_layers_10_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[745] + model_decoder_layers_10_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[746] + 
model_decoder_layers_10_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] + model_decoder_layers_10_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[748] + model_decoder_layers_10_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] + model_decoder_layers_10_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[750] + model_decoder_layers_10_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[751] + model_decoder_layers_10_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[752] + model_decoder_layers_11_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] + model_decoder_layers_11_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] + model_decoder_layers_11_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[755] + model_decoder_layers_11_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] + model_decoder_layers_11_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[757] + model_decoder_layers_11_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[758] + model_decoder_layers_11_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[759] + model_decoder_layers_11_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[760] + model_decoder_layers_11_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[761] + model_decoder_layers_11_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[765] + model_decoder_layers_11_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[766] + model_decoder_layers_11_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] + model_decoder_layers_11_encoder_attn_out_proj_bias3: 
R.Tensor((1280,), dtype="float16") = packed_params[768] + model_decoder_layers_11_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[769] + model_decoder_layers_11_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[770] + model_decoder_layers_11_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] + model_decoder_layers_11_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[772] + model_decoder_layers_11_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] + model_decoder_layers_11_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[774] + model_decoder_layers_11_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[775] + model_decoder_layers_11_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[776] + model_decoder_layers_12_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] + model_decoder_layers_12_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] + model_decoder_layers_12_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[779] + model_decoder_layers_12_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] + model_decoder_layers_12_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[781] + model_decoder_layers_12_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] + model_decoder_layers_12_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[783] + model_decoder_layers_12_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[784] + model_decoder_layers_12_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[785] + model_decoder_layers_12_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] + 
model_decoder_layers_12_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[790] + model_decoder_layers_12_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] + model_decoder_layers_12_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[792] + model_decoder_layers_12_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[793] + model_decoder_layers_12_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[794] + model_decoder_layers_12_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] + model_decoder_layers_12_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[796] + model_decoder_layers_12_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] + model_decoder_layers_12_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[798] + model_decoder_layers_12_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[799] + model_decoder_layers_12_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[800] + model_decoder_layers_13_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[801] + model_decoder_layers_13_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[802] + model_decoder_layers_13_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[803] + model_decoder_layers_13_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[804] + model_decoder_layers_13_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[805] + model_decoder_layers_13_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[806] + model_decoder_layers_13_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[807] + model_decoder_layers_13_self_attn_layer_norm_weight3: 
R.Tensor((1280,), dtype="float16") = packed_params[808] + model_decoder_layers_13_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[809] + model_decoder_layers_13_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[813] + model_decoder_layers_13_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[814] + model_decoder_layers_13_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[815] + model_decoder_layers_13_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[816] + model_decoder_layers_13_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[817] + model_decoder_layers_13_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[818] + model_decoder_layers_13_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[819] + model_decoder_layers_13_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[820] + model_decoder_layers_13_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[821] + model_decoder_layers_13_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[822] + model_decoder_layers_13_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[823] + model_decoder_layers_13_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[824] + model_decoder_layers_14_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[825] + model_decoder_layers_14_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[826] + model_decoder_layers_14_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[827] + model_decoder_layers_14_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[828] + model_decoder_layers_14_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[829] 
+ model_decoder_layers_14_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[830] + model_decoder_layers_14_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[831] + model_decoder_layers_14_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[832] + model_decoder_layers_14_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[833] + model_decoder_layers_14_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[837] + model_decoder_layers_14_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[838] + model_decoder_layers_14_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[839] + model_decoder_layers_14_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[840] + model_decoder_layers_14_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[841] + model_decoder_layers_14_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[842] + model_decoder_layers_14_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[843] + model_decoder_layers_14_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[844] + model_decoder_layers_14_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[845] + model_decoder_layers_14_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[846] + model_decoder_layers_14_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[847] + model_decoder_layers_14_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[848] + model_decoder_layers_15_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[849] + model_decoder_layers_15_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[850] + 
model_decoder_layers_15_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[851] + model_decoder_layers_15_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[852] + model_decoder_layers_15_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[853] + model_decoder_layers_15_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[854] + model_decoder_layers_15_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[855] + model_decoder_layers_15_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[856] + model_decoder_layers_15_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[857] + model_decoder_layers_15_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[861] + model_decoder_layers_15_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[862] + model_decoder_layers_15_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[863] + model_decoder_layers_15_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[864] + model_decoder_layers_15_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[865] + model_decoder_layers_15_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[866] + model_decoder_layers_15_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[867] + model_decoder_layers_15_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[868] + model_decoder_layers_15_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[869] + model_decoder_layers_15_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[870] + model_decoder_layers_15_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[871] + model_decoder_layers_15_final_layer_norm_bias3: 
R.Tensor((1280,), dtype="float16") = packed_params[872] + model_decoder_layers_16_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[873] + model_decoder_layers_16_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[874] + model_decoder_layers_16_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[875] + model_decoder_layers_16_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[876] + model_decoder_layers_16_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[877] + model_decoder_layers_16_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[878] + model_decoder_layers_16_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[879] + model_decoder_layers_16_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[880] + model_decoder_layers_16_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[881] + model_decoder_layers_16_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[885] + model_decoder_layers_16_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[886] + model_decoder_layers_16_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[887] + model_decoder_layers_16_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[888] + model_decoder_layers_16_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[889] + model_decoder_layers_16_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[890] + model_decoder_layers_16_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[891] + model_decoder_layers_16_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[892] + model_decoder_layers_16_fc2_weight3: R.Tensor((1280, 5120), 
dtype="float16") = packed_params[893] + model_decoder_layers_16_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[894] + model_decoder_layers_16_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[895] + model_decoder_layers_16_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[896] + model_decoder_layers_17_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[897] + model_decoder_layers_17_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[898] + model_decoder_layers_17_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[899] + model_decoder_layers_17_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[900] + model_decoder_layers_17_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[901] + model_decoder_layers_17_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[902] + model_decoder_layers_17_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[903] + model_decoder_layers_17_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[904] + model_decoder_layers_17_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[905] + model_decoder_layers_17_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[909] + model_decoder_layers_17_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[910] + model_decoder_layers_17_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[911] + model_decoder_layers_17_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[912] + model_decoder_layers_17_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[913] + model_decoder_layers_17_encoder_attn_layer_norm_bias3: R.Tensor((1280,), 
dtype="float16") = packed_params[914] + model_decoder_layers_17_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[915] + model_decoder_layers_17_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[916] + model_decoder_layers_17_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[917] + model_decoder_layers_17_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[918] + model_decoder_layers_17_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[919] + model_decoder_layers_17_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[920] + model_decoder_layers_18_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[921] + model_decoder_layers_18_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[922] + model_decoder_layers_18_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[923] + model_decoder_layers_18_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[924] + model_decoder_layers_18_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[925] + model_decoder_layers_18_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[926] + model_decoder_layers_18_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[927] + model_decoder_layers_18_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[928] + model_decoder_layers_18_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[929] + model_decoder_layers_18_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[933] + model_decoder_layers_18_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[934] + model_decoder_layers_18_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[935] + 
model_decoder_layers_18_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[936] + model_decoder_layers_18_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[937] + model_decoder_layers_18_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[938] + model_decoder_layers_18_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[939] + model_decoder_layers_18_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[940] + model_decoder_layers_18_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[941] + model_decoder_layers_18_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[942] + model_decoder_layers_18_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[943] + model_decoder_layers_18_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[944] + model_decoder_layers_19_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[945] + model_decoder_layers_19_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[946] + model_decoder_layers_19_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[947] + model_decoder_layers_19_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[948] + model_decoder_layers_19_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[949] + model_decoder_layers_19_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[950] + model_decoder_layers_19_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[951] + model_decoder_layers_19_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[952] + model_decoder_layers_19_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[953] + model_decoder_layers_19_encoder_attn_q_proj_weight3: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[957] + model_decoder_layers_19_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[958] + model_decoder_layers_19_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[959] + model_decoder_layers_19_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[960] + model_decoder_layers_19_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[961] + model_decoder_layers_19_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[962] + model_decoder_layers_19_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[963] + model_decoder_layers_19_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[964] + model_decoder_layers_19_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[965] + model_decoder_layers_19_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[966] + model_decoder_layers_19_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[967] + model_decoder_layers_19_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[968] + model_decoder_layers_20_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[969] + model_decoder_layers_20_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[970] + model_decoder_layers_20_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[971] + model_decoder_layers_20_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[972] + model_decoder_layers_20_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[973] + model_decoder_layers_20_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[974] + model_decoder_layers_20_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = 
packed_params[975] + model_decoder_layers_20_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[976] + model_decoder_layers_20_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[977] + model_decoder_layers_20_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[981] + model_decoder_layers_20_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[982] + model_decoder_layers_20_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[983] + model_decoder_layers_20_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[984] + model_decoder_layers_20_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[985] + model_decoder_layers_20_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[986] + model_decoder_layers_20_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[987] + model_decoder_layers_20_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[988] + model_decoder_layers_20_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[989] + model_decoder_layers_20_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[990] + model_decoder_layers_20_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[991] + model_decoder_layers_20_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[992] + model_decoder_layers_21_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[993] + model_decoder_layers_21_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[994] + model_decoder_layers_21_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[995] + model_decoder_layers_21_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[996] + 
model_decoder_layers_21_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[997] + model_decoder_layers_21_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[998] + model_decoder_layers_21_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[999] + model_decoder_layers_21_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1000] + model_decoder_layers_21_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1001] + model_decoder_layers_21_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005] + model_decoder_layers_21_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1006] + model_decoder_layers_21_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007] + model_decoder_layers_21_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1008] + model_decoder_layers_21_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1009] + model_decoder_layers_21_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1010] + model_decoder_layers_21_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011] + model_decoder_layers_21_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1012] + model_decoder_layers_21_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013] + model_decoder_layers_21_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1014] + model_decoder_layers_21_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1015] + model_decoder_layers_21_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1016] + model_decoder_layers_22_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017] + 
model_decoder_layers_22_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018] + model_decoder_layers_22_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1019] + model_decoder_layers_22_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020] + model_decoder_layers_22_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1021] + model_decoder_layers_22_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022] + model_decoder_layers_22_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1023] + model_decoder_layers_22_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1024] + model_decoder_layers_22_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1025] + model_decoder_layers_22_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029] + model_decoder_layers_22_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1030] + model_decoder_layers_22_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031] + model_decoder_layers_22_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1032] + model_decoder_layers_22_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1033] + model_decoder_layers_22_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1034] + model_decoder_layers_22_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035] + model_decoder_layers_22_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1036] + model_decoder_layers_22_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037] + model_decoder_layers_22_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1038] + 
model_decoder_layers_22_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1039] + model_decoder_layers_22_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1040] + model_decoder_layers_23_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041] + model_decoder_layers_23_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042] + model_decoder_layers_23_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1043] + model_decoder_layers_23_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044] + model_decoder_layers_23_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1045] + model_decoder_layers_23_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046] + model_decoder_layers_23_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1047] + model_decoder_layers_23_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1048] + model_decoder_layers_23_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1049] + model_decoder_layers_23_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053] + model_decoder_layers_23_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1054] + model_decoder_layers_23_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055] + model_decoder_layers_23_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1056] + model_decoder_layers_23_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1057] + model_decoder_layers_23_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1058] + model_decoder_layers_23_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = 
packed_params[1059] + model_decoder_layers_23_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1060] + model_decoder_layers_23_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061] + model_decoder_layers_23_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1062] + model_decoder_layers_23_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1063] + model_decoder_layers_23_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1064] + model_decoder_layers_24_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065] + model_decoder_layers_24_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066] + model_decoder_layers_24_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1067] + model_decoder_layers_24_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068] + model_decoder_layers_24_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1069] + model_decoder_layers_24_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070] + model_decoder_layers_24_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1071] + model_decoder_layers_24_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1072] + model_decoder_layers_24_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1073] + model_decoder_layers_24_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077] + model_decoder_layers_24_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1078] + model_decoder_layers_24_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079] + model_decoder_layers_24_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1080] + 
model_decoder_layers_24_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1081] + model_decoder_layers_24_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1082] + model_decoder_layers_24_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083] + model_decoder_layers_24_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1084] + model_decoder_layers_24_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085] + model_decoder_layers_24_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1086] + model_decoder_layers_24_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1087] + model_decoder_layers_24_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1088] + model_decoder_layers_25_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089] + model_decoder_layers_25_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090] + model_decoder_layers_25_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1091] + model_decoder_layers_25_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092] + model_decoder_layers_25_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1093] + model_decoder_layers_25_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094] + model_decoder_layers_25_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1095] + model_decoder_layers_25_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1096] + model_decoder_layers_25_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1097] + model_decoder_layers_25_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101] + 
model_decoder_layers_25_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1102] + model_decoder_layers_25_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103] + model_decoder_layers_25_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1104] + model_decoder_layers_25_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1105] + model_decoder_layers_25_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1106] + model_decoder_layers_25_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107] + model_decoder_layers_25_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1108] + model_decoder_layers_25_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109] + model_decoder_layers_25_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1110] + model_decoder_layers_25_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1111] + model_decoder_layers_25_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1112] + model_decoder_layers_26_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113] + model_decoder_layers_26_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114] + model_decoder_layers_26_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1115] + model_decoder_layers_26_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116] + model_decoder_layers_26_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1117] + model_decoder_layers_26_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118] + model_decoder_layers_26_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1119] + 
model_decoder_layers_26_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1120] + model_decoder_layers_26_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1121] + model_decoder_layers_26_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125] + model_decoder_layers_26_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1126] + model_decoder_layers_26_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127] + model_decoder_layers_26_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1128] + model_decoder_layers_26_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1129] + model_decoder_layers_26_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1130] + model_decoder_layers_26_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131] + model_decoder_layers_26_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1132] + model_decoder_layers_26_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133] + model_decoder_layers_26_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1134] + model_decoder_layers_26_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1135] + model_decoder_layers_26_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1136] + model_decoder_layers_27_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137] + model_decoder_layers_27_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138] + model_decoder_layers_27_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1139] + model_decoder_layers_27_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140] + 
model_decoder_layers_27_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1141] + model_decoder_layers_27_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142] + model_decoder_layers_27_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1143] + model_decoder_layers_27_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1144] + model_decoder_layers_27_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1145] + model_decoder_layers_27_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149] + model_decoder_layers_27_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1150] + model_decoder_layers_27_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151] + model_decoder_layers_27_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1152] + model_decoder_layers_27_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1153] + model_decoder_layers_27_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1154] + model_decoder_layers_27_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155] + model_decoder_layers_27_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1156] + model_decoder_layers_27_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157] + model_decoder_layers_27_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1158] + model_decoder_layers_27_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1159] + model_decoder_layers_27_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1160] + model_decoder_layers_28_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161] + 
model_decoder_layers_28_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162] + model_decoder_layers_28_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1163] + model_decoder_layers_28_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164] + model_decoder_layers_28_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1165] + model_decoder_layers_28_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166] + model_decoder_layers_28_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1167] + model_decoder_layers_28_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1168] + model_decoder_layers_28_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1169] + model_decoder_layers_28_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173] + model_decoder_layers_28_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1174] + model_decoder_layers_28_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175] + model_decoder_layers_28_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1176] + model_decoder_layers_28_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1177] + model_decoder_layers_28_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1178] + model_decoder_layers_28_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179] + model_decoder_layers_28_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1180] + model_decoder_layers_28_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181] + model_decoder_layers_28_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1182] + 
model_decoder_layers_28_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1183] + model_decoder_layers_28_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1184] + model_decoder_layers_29_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185] + model_decoder_layers_29_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186] + model_decoder_layers_29_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1187] + model_decoder_layers_29_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188] + model_decoder_layers_29_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1189] + model_decoder_layers_29_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190] + model_decoder_layers_29_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1191] + model_decoder_layers_29_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1192] + model_decoder_layers_29_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1193] + model_decoder_layers_29_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197] + model_decoder_layers_29_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1198] + model_decoder_layers_29_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199] + model_decoder_layers_29_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1200] + model_decoder_layers_29_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1201] + model_decoder_layers_29_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1202] + model_decoder_layers_29_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = 
packed_params[1203] + model_decoder_layers_29_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1204] + model_decoder_layers_29_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205] + model_decoder_layers_29_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1206] + model_decoder_layers_29_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1207] + model_decoder_layers_29_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1208] + model_decoder_layers_30_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209] + model_decoder_layers_30_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210] + model_decoder_layers_30_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1211] + model_decoder_layers_30_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212] + model_decoder_layers_30_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1213] + model_decoder_layers_30_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214] + model_decoder_layers_30_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1215] + model_decoder_layers_30_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1216] + model_decoder_layers_30_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1217] + model_decoder_layers_30_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221] + model_decoder_layers_30_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1222] + model_decoder_layers_30_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223] + model_decoder_layers_30_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1224] + 
model_decoder_layers_30_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1225] + model_decoder_layers_30_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1226] + model_decoder_layers_30_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227] + model_decoder_layers_30_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1228] + model_decoder_layers_30_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229] + model_decoder_layers_30_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1230] + model_decoder_layers_30_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1231] + model_decoder_layers_30_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1232] + model_decoder_layers_31_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233] + model_decoder_layers_31_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234] + model_decoder_layers_31_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1235] + model_decoder_layers_31_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236] + model_decoder_layers_31_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1237] + model_decoder_layers_31_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238] + model_decoder_layers_31_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1239] + model_decoder_layers_31_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1240] + model_decoder_layers_31_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1241] + model_decoder_layers_31_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245] + 
model_decoder_layers_31_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1246] + model_decoder_layers_31_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247] + model_decoder_layers_31_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1248] + model_decoder_layers_31_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1249] + model_decoder_layers_31_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1250] + model_decoder_layers_31_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251] + model_decoder_layers_31_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1252] + model_decoder_layers_31_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] + model_decoder_layers_31_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1254] + model_decoder_layers_31_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1255] + model_decoder_layers_31_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1256] + model_decoder_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1257] + model_decoder_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1258] + reshape707 = R.call_tir(cls.reshape2, (input_ids,), out_sinfo=R.Tensor((batch_size,), dtype="int32")) + take3 = R.call_tir(cls.take, (model_decoder_embed_tokens_weight3, reshape707), out_sinfo=R.Tensor((batch_size, 1280), dtype="float16")) + reshape708 = R.call_tir(cls.reshape3, (take3,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv133: R.Tensor((batch_size,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((batch_size,), dtype="int32"),)) + take4 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight3, lv133), 
out_sinfo=R.Tensor((batch_size, 1280), dtype="float16")) + reshape709 = R.call_tir(cls.reshape3, (take4,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add578 = R.call_tir(cls.add, (reshape708, reshape709), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm162 = R.call_tir(cls.layer_norm, (add578, model_decoder_layers_0_self_attn_layer_norm_weight3, model_decoder_layers_0_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv224 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_q_proj_weight3, layer_norm162, model_decoder_layers_0_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape710 = R.call_tir(cls.reshape4, (lv224,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_0_self_attn_k_proj_weight3, layer_norm162), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape711 = R.call_tir(cls.reshape4, (lv65,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv225 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_v_proj_weight3, layer_norm162, model_decoder_layers_0_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape712 = R.call_tir(cls.reshape4, (lv225,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat32 = R.call_tir(cls.concatenate, (reshape710, reshape711, reshape712), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape713 = R.call_tir(cls.reshape5, (concat32,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv134 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape713), 
out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape714 = R.call_tir(cls.reshape6, (lv134,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape715 = R.call_tir(cls.reshape7, (reshape714,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv226 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_out_proj_weight3, reshape715, model_decoder_layers_0_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add582 = R.call_tir(cls.add, (add578, lv226), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm163 = R.call_tir(cls.layer_norm, (add582, model_decoder_layers_0_encoder_attn_layer_norm_weight3, model_decoder_layers_0_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv227 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight3, layer_norm163, model_decoder_layers_0_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape716 = R.call_tir(cls.reshape4, (lv227,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape717 = R.call_tir(cls.reshape8, (reshape716,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv135 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape717), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape718 = R.call_tir(cls.reshape6, (lv135,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape719 = R.call_tir(cls.reshape7, (reshape718,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv228 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight3, reshape719, 
model_decoder_layers_0_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add585 = R.call_tir(cls.add, (add582, lv228), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm164 = R.call_tir(cls.layer_norm, (add585, model_decoder_layers_0_final_layer_norm_weight3, model_decoder_layers_0_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv32 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_0_fc1_weight3, layer_norm164, model_decoder_layers_0_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv229 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_0_fc2_weight3, lv32, model_decoder_layers_0_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add588 = R.call_tir(cls.add, (add585, lv229), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm165 = R.call_tir(cls.layer_norm, (add588, model_decoder_layers_1_self_attn_layer_norm_weight3, model_decoder_layers_1_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv230 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_q_proj_weight3, layer_norm165, model_decoder_layers_1_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape720 = R.call_tir(cls.reshape4, (lv230,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_1_self_attn_k_proj_weight3, layer_norm165), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape721 = R.call_tir(cls.reshape4, (lv66,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv231 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_v_proj_weight3, layer_norm165, model_decoder_layers_1_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape722 = R.call_tir(cls.reshape4, (lv231,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat33 = R.call_tir(cls.concatenate, (reshape720, reshape721, reshape722), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape723 = R.call_tir(cls.reshape5, (concat33,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv136 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape723), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape724 = R.call_tir(cls.reshape6, (lv136,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape725 = R.call_tir(cls.reshape7, (reshape724,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv232 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_out_proj_weight3, reshape725, model_decoder_layers_1_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add592 = R.call_tir(cls.add, (add588, lv232), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm166 = R.call_tir(cls.layer_norm, (add592, model_decoder_layers_1_encoder_attn_layer_norm_weight3, model_decoder_layers_1_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv233 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight3, layer_norm166, model_decoder_layers_1_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape726 = R.call_tir(cls.reshape4, (lv233,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape727 = R.call_tir(cls.reshape8, (reshape726,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv137 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape727), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape728 = R.call_tir(cls.reshape6, (lv137,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape729 = R.call_tir(cls.reshape7, (reshape728,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv234 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight3, reshape729, model_decoder_layers_1_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add595 = R.call_tir(cls.add, (add592, lv234), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm167 = R.call_tir(cls.layer_norm, (add595, model_decoder_layers_1_final_layer_norm_weight3, model_decoder_layers_1_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv33 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_1_fc1_weight3, layer_norm167, model_decoder_layers_1_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv235 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_1_fc2_weight3, lv33, model_decoder_layers_1_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add598 = R.call_tir(cls.add, (add595, lv235), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm168 = R.call_tir(cls.layer_norm, (add598, model_decoder_layers_2_self_attn_layer_norm_weight3, model_decoder_layers_2_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) 
+ lv236 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_q_proj_weight3, layer_norm168, model_decoder_layers_2_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape730 = R.call_tir(cls.reshape4, (lv236,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_2_self_attn_k_proj_weight3, layer_norm168), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape731 = R.call_tir(cls.reshape4, (lv67,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv237 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_v_proj_weight3, layer_norm168, model_decoder_layers_2_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape732 = R.call_tir(cls.reshape4, (lv237,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat34 = R.call_tir(cls.concatenate, (reshape730, reshape731, reshape732), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape733 = R.call_tir(cls.reshape5, (concat34,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv138 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape733), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape734 = R.call_tir(cls.reshape6, (lv138,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape735 = R.call_tir(cls.reshape7, (reshape734,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv238 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_out_proj_weight3, reshape735, model_decoder_layers_2_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + add602 = R.call_tir(cls.add, (add598, lv238), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm169 = R.call_tir(cls.layer_norm, (add602, model_decoder_layers_2_encoder_attn_layer_norm_weight3, model_decoder_layers_2_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv239 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight3, layer_norm169, model_decoder_layers_2_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape736 = R.call_tir(cls.reshape4, (lv239,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape737 = R.call_tir(cls.reshape8, (reshape736,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv139 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape737), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape738 = R.call_tir(cls.reshape6, (lv139,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape739 = R.call_tir(cls.reshape7, (reshape738,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv240 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight3, reshape739, model_decoder_layers_2_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add605 = R.call_tir(cls.add, (add602, lv240), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm170 = R.call_tir(cls.layer_norm, (add605, model_decoder_layers_2_final_layer_norm_weight3, model_decoder_layers_2_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv34 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", 
(model_decoder_layers_2_fc1_weight3, layer_norm170, model_decoder_layers_2_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv241 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_2_fc2_weight3, lv34, model_decoder_layers_2_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add608 = R.call_tir(cls.add, (add605, lv241), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm171 = R.call_tir(cls.layer_norm, (add608, model_decoder_layers_3_self_attn_layer_norm_weight3, model_decoder_layers_3_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv242 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_q_proj_weight3, layer_norm171, model_decoder_layers_3_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape740 = R.call_tir(cls.reshape4, (lv242,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv68 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_3_self_attn_k_proj_weight3, layer_norm171), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape741 = R.call_tir(cls.reshape4, (lv68,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv243 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_v_proj_weight3, layer_norm171, model_decoder_layers_3_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape742 = R.call_tir(cls.reshape4, (lv243,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat35 = R.call_tir(cls.concatenate, (reshape740, reshape741, reshape742), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape743 = R.call_tir(cls.reshape5, (concat35,), out_sinfo=R.Tensor((batch_size, 60, 64), 
dtype="float16")) + lv140 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape743), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape744 = R.call_tir(cls.reshape6, (lv140,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape745 = R.call_tir(cls.reshape7, (reshape744,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv244 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_out_proj_weight3, reshape745, model_decoder_layers_3_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add612 = R.call_tir(cls.add, (add608, lv244), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm172 = R.call_tir(cls.layer_norm, (add612, model_decoder_layers_3_encoder_attn_layer_norm_weight3, model_decoder_layers_3_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv245 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight3, layer_norm172, model_decoder_layers_3_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape746 = R.call_tir(cls.reshape4, (lv245,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape747 = R.call_tir(cls.reshape8, (reshape746,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv141 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape747), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape748 = R.call_tir(cls.reshape6, (lv141,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape749 = R.call_tir(cls.reshape7, (reshape748,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv246 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight3, reshape749, model_decoder_layers_3_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add615 = R.call_tir(cls.add, (add612, lv246), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm173 = R.call_tir(cls.layer_norm, (add615, model_decoder_layers_3_final_layer_norm_weight3, model_decoder_layers_3_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv35 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_3_fc1_weight3, layer_norm173, model_decoder_layers_3_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv247 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_3_fc2_weight3, lv35, model_decoder_layers_3_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add618 = R.call_tir(cls.add, (add615, lv247), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm174 = R.call_tir(cls.layer_norm, (add618, model_decoder_layers_4_self_attn_layer_norm_weight3, model_decoder_layers_4_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv248 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_q_proj_weight3, layer_norm174, model_decoder_layers_4_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape750 = R.call_tir(cls.reshape4, (lv248,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv69 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_4_self_attn_k_proj_weight3, layer_norm174), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape751 = R.call_tir(cls.reshape4, (lv69,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv249 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_v_proj_weight3, layer_norm174, model_decoder_layers_4_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape752 = R.call_tir(cls.reshape4, (lv249,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat36 = R.call_tir(cls.concatenate, (reshape750, reshape751, reshape752), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape753 = R.call_tir(cls.reshape5, (concat36,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv142 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape753), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape754 = R.call_tir(cls.reshape6, (lv142,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape755 = R.call_tir(cls.reshape7, (reshape754,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv250 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_out_proj_weight3, reshape755, model_decoder_layers_4_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add622 = R.call_tir(cls.add, (add618, lv250), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm175 = R.call_tir(cls.layer_norm, (add622, model_decoder_layers_4_encoder_attn_layer_norm_weight3, model_decoder_layers_4_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv251 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight3, layer_norm175, model_decoder_layers_4_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + 
reshape756 = R.call_tir(cls.reshape4, (lv251,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape757 = R.call_tir(cls.reshape8, (reshape756,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv143 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape757), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape758 = R.call_tir(cls.reshape6, (lv143,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape759 = R.call_tir(cls.reshape7, (reshape758,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv252 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight3, reshape759, model_decoder_layers_4_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add625 = R.call_tir(cls.add, (add622, lv252), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm176 = R.call_tir(cls.layer_norm, (add625, model_decoder_layers_4_final_layer_norm_weight3, model_decoder_layers_4_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv36 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_4_fc1_weight3, layer_norm176, model_decoder_layers_4_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv253 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_4_fc2_weight3, lv36, model_decoder_layers_4_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add628 = R.call_tir(cls.add, (add625, lv253), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm177 = R.call_tir(cls.layer_norm, (add628, model_decoder_layers_5_self_attn_layer_norm_weight3, model_decoder_layers_5_self_attn_layer_norm_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv254 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_q_proj_weight3, layer_norm177, model_decoder_layers_5_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape760 = R.call_tir(cls.reshape4, (lv254,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv70 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_5_self_attn_k_proj_weight3, layer_norm177), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape761 = R.call_tir(cls.reshape4, (lv70,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv255 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_v_proj_weight3, layer_norm177, model_decoder_layers_5_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape762 = R.call_tir(cls.reshape4, (lv255,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat37 = R.call_tir(cls.concatenate, (reshape760, reshape761, reshape762), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape763 = R.call_tir(cls.reshape5, (concat37,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv144 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape763), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape764 = R.call_tir(cls.reshape6, (lv144,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape765 = R.call_tir(cls.reshape7, (reshape764,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv256 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_out_proj_weight3, reshape765, 
model_decoder_layers_5_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add632 = R.call_tir(cls.add, (add628, lv256), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm178 = R.call_tir(cls.layer_norm, (add632, model_decoder_layers_5_encoder_attn_layer_norm_weight3, model_decoder_layers_5_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv257 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight3, layer_norm178, model_decoder_layers_5_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape766 = R.call_tir(cls.reshape4, (lv257,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape767 = R.call_tir(cls.reshape8, (reshape766,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv145 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape767), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape768 = R.call_tir(cls.reshape6, (lv145,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape769 = R.call_tir(cls.reshape7, (reshape768,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv258 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight3, reshape769, model_decoder_layers_5_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add635 = R.call_tir(cls.add, (add632, lv258), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm179 = R.call_tir(cls.layer_norm, (add635, model_decoder_layers_5_final_layer_norm_weight3, model_decoder_layers_5_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv37 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_5_fc1_weight3, layer_norm179, model_decoder_layers_5_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv259 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_5_fc2_weight3, lv37, model_decoder_layers_5_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add638 = R.call_tir(cls.add, (add635, lv259), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm180 = R.call_tir(cls.layer_norm, (add638, model_decoder_layers_6_self_attn_layer_norm_weight3, model_decoder_layers_6_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv260 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_q_proj_weight3, layer_norm180, model_decoder_layers_6_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape770 = R.call_tir(cls.reshape4, (lv260,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv71 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_6_self_attn_k_proj_weight3, layer_norm180), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape771 = R.call_tir(cls.reshape4, (lv71,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv261 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_v_proj_weight3, layer_norm180, model_decoder_layers_6_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape772 = R.call_tir(cls.reshape4, (lv261,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat38 = R.call_tir(cls.concatenate, (reshape770, reshape771, reshape772), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + 
reshape773 = R.call_tir(cls.reshape5, (concat38,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv146 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape773), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape774 = R.call_tir(cls.reshape6, (lv146,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape775 = R.call_tir(cls.reshape7, (reshape774,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv262 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_out_proj_weight3, reshape775, model_decoder_layers_6_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add642 = R.call_tir(cls.add, (add638, lv262), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm181 = R.call_tir(cls.layer_norm, (add642, model_decoder_layers_6_encoder_attn_layer_norm_weight3, model_decoder_layers_6_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv263 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight3, layer_norm181, model_decoder_layers_6_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape776 = R.call_tir(cls.reshape4, (lv263,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape777 = R.call_tir(cls.reshape8, (reshape776,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv147 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape777), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape778 = R.call_tir(cls.reshape6, (lv147,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape779 = 
R.call_tir(cls.reshape7, (reshape778,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv264 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight3, reshape779, model_decoder_layers_6_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add645 = R.call_tir(cls.add, (add642, lv264), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm182 = R.call_tir(cls.layer_norm, (add645, model_decoder_layers_6_final_layer_norm_weight3, model_decoder_layers_6_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv38 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_6_fc1_weight3, layer_norm182, model_decoder_layers_6_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv265 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_6_fc2_weight3, lv38, model_decoder_layers_6_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add648 = R.call_tir(cls.add, (add645, lv265), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm183 = R.call_tir(cls.layer_norm, (add648, model_decoder_layers_7_self_attn_layer_norm_weight3, model_decoder_layers_7_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv266 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_q_proj_weight3, layer_norm183, model_decoder_layers_7_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape780 = R.call_tir(cls.reshape4, (lv266,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv72 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_7_self_attn_k_proj_weight3, layer_norm183), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape781 = R.call_tir(cls.reshape4, (lv72,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv267 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_v_proj_weight3, layer_norm183, model_decoder_layers_7_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape782 = R.call_tir(cls.reshape4, (lv267,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat39 = R.call_tir(cls.concatenate, (reshape780, reshape781, reshape782), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape783 = R.call_tir(cls.reshape5, (concat39,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv148 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape783), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape784 = R.call_tir(cls.reshape6, (lv148,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape785 = R.call_tir(cls.reshape7, (reshape784,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv268 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_out_proj_weight3, reshape785, model_decoder_layers_7_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add652 = R.call_tir(cls.add, (add648, lv268), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm184 = R.call_tir(cls.layer_norm, (add652, model_decoder_layers_7_encoder_attn_layer_norm_weight3, model_decoder_layers_7_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv269 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight3, layer_norm184, 
model_decoder_layers_7_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape786 = R.call_tir(cls.reshape4, (lv269,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape787 = R.call_tir(cls.reshape8, (reshape786,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv149 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape787), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape788 = R.call_tir(cls.reshape6, (lv149,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape789 = R.call_tir(cls.reshape7, (reshape788,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv270 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight3, reshape789, model_decoder_layers_7_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add655 = R.call_tir(cls.add, (add652, lv270), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm185 = R.call_tir(cls.layer_norm, (add655, model_decoder_layers_7_final_layer_norm_weight3, model_decoder_layers_7_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv39 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_7_fc1_weight3, layer_norm185, model_decoder_layers_7_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv271 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_7_fc2_weight3, lv39, model_decoder_layers_7_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add658 = R.call_tir(cls.add, (add655, lv271), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm186 = R.call_tir(cls.layer_norm, (add658, 
model_decoder_layers_8_self_attn_layer_norm_weight3, model_decoder_layers_8_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv272 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_q_proj_weight3, layer_norm186, model_decoder_layers_8_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape790 = R.call_tir(cls.reshape4, (lv272,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv73 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_8_self_attn_k_proj_weight3, layer_norm186), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape791 = R.call_tir(cls.reshape4, (lv73,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv273 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_v_proj_weight3, layer_norm186, model_decoder_layers_8_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape792 = R.call_tir(cls.reshape4, (lv273,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat40 = R.call_tir(cls.concatenate, (reshape790, reshape791, reshape792), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape793 = R.call_tir(cls.reshape5, (concat40,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv150 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape793), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape794 = R.call_tir(cls.reshape6, (lv150,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape795 = R.call_tir(cls.reshape7, (reshape794,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv274 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_out_proj_weight3, reshape795, model_decoder_layers_8_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add662 = R.call_tir(cls.add, (add658, lv274), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm187 = R.call_tir(cls.layer_norm, (add662, model_decoder_layers_8_encoder_attn_layer_norm_weight3, model_decoder_layers_8_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv275 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight3, layer_norm187, model_decoder_layers_8_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape796 = R.call_tir(cls.reshape4, (lv275,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape797 = R.call_tir(cls.reshape8, (reshape796,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv151 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape797), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape798 = R.call_tir(cls.reshape6, (lv151,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape799 = R.call_tir(cls.reshape7, (reshape798,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv276 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_encoder_attn_out_proj_weight3, reshape799, model_decoder_layers_8_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add665 = R.call_tir(cls.add, (add662, lv276), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm188 = R.call_tir(cls.layer_norm, (add665, model_decoder_layers_8_final_layer_norm_weight3, 
model_decoder_layers_8_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv40 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_8_fc1_weight3, layer_norm188, model_decoder_layers_8_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv277 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_8_fc2_weight3, lv40, model_decoder_layers_8_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add668 = R.call_tir(cls.add, (add665, lv277), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm189 = R.call_tir(cls.layer_norm, (add668, model_decoder_layers_9_self_attn_layer_norm_weight3, model_decoder_layers_9_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv278 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_q_proj_weight3, layer_norm189, model_decoder_layers_9_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape800 = R.call_tir(cls.reshape4, (lv278,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv74 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_9_self_attn_k_proj_weight3, layer_norm189), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape801 = R.call_tir(cls.reshape4, (lv74,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv279 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_v_proj_weight3, layer_norm189, model_decoder_layers_9_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape802 = R.call_tir(cls.reshape4, (lv279,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat41 = 
R.call_tir(cls.concatenate, (reshape800, reshape801, reshape802), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape803 = R.call_tir(cls.reshape5, (concat41,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv152 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape803), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape804 = R.call_tir(cls.reshape6, (lv152,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape805 = R.call_tir(cls.reshape7, (reshape804,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv280 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_out_proj_weight3, reshape805, model_decoder_layers_9_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add672 = R.call_tir(cls.add, (add668, lv280), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm190 = R.call_tir(cls.layer_norm, (add672, model_decoder_layers_9_encoder_attn_layer_norm_weight3, model_decoder_layers_9_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv281 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight3, layer_norm190, model_decoder_layers_9_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape806 = R.call_tir(cls.reshape4, (lv281,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape807 = R.call_tir(cls.reshape8, (reshape806,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv153 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape807), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape808 = 
R.call_tir(cls.reshape6, (lv153,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape809 = R.call_tir(cls.reshape7, (reshape808,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv282 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight3, reshape809, model_decoder_layers_9_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add675 = R.call_tir(cls.add, (add672, lv282), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm191 = R.call_tir(cls.layer_norm, (add675, model_decoder_layers_9_final_layer_norm_weight3, model_decoder_layers_9_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv41 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_9_fc1_weight3, layer_norm191, model_decoder_layers_9_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv283 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_9_fc2_weight3, lv41, model_decoder_layers_9_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add678 = R.call_tir(cls.add, (add675, lv283), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm192 = R.call_tir(cls.layer_norm, (add678, model_decoder_layers_10_self_attn_layer_norm_weight3, model_decoder_layers_10_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv284 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_q_proj_weight3, layer_norm192, model_decoder_layers_10_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape810 = R.call_tir(cls.reshape4, (lv284,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv75 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_10_self_attn_k_proj_weight3, layer_norm192), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape811 = R.call_tir(cls.reshape4, (lv75,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv285 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_v_proj_weight3, layer_norm192, model_decoder_layers_10_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape812 = R.call_tir(cls.reshape4, (lv285,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat42 = R.call_tir(cls.concatenate, (reshape810, reshape811, reshape812), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape813 = R.call_tir(cls.reshape5, (concat42,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv154 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape813), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape814 = R.call_tir(cls.reshape6, (lv154,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape815 = R.call_tir(cls.reshape7, (reshape814,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv286 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_out_proj_weight3, reshape815, model_decoder_layers_10_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add682 = R.call_tir(cls.add, (add678, lv286), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm193 = R.call_tir(cls.layer_norm, (add682, model_decoder_layers_10_encoder_attn_layer_norm_weight3, model_decoder_layers_10_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv287 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight3, layer_norm193, model_decoder_layers_10_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape816 = R.call_tir(cls.reshape4, (lv287,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape817 = R.call_tir(cls.reshape8, (reshape816,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv155 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape817), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape818 = R.call_tir(cls.reshape6, (lv155,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape819 = R.call_tir(cls.reshape7, (reshape818,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv288 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight3, reshape819, model_decoder_layers_10_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add685 = R.call_tir(cls.add, (add682, lv288), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm194 = R.call_tir(cls.layer_norm, (add685, model_decoder_layers_10_final_layer_norm_weight3, model_decoder_layers_10_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv42 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_10_fc1_weight3, layer_norm194, model_decoder_layers_10_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv289 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_10_fc2_weight3, lv42, model_decoder_layers_10_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add688 = 
R.call_tir(cls.add, (add685, lv289), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm195 = R.call_tir(cls.layer_norm, (add688, model_decoder_layers_11_self_attn_layer_norm_weight3, model_decoder_layers_11_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv290 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_self_attn_q_proj_weight3, layer_norm195, model_decoder_layers_11_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape820 = R.call_tir(cls.reshape4, (lv290,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv76 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_11_self_attn_k_proj_weight3, layer_norm195), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape821 = R.call_tir(cls.reshape4, (lv76,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv291 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_self_attn_v_proj_weight3, layer_norm195, model_decoder_layers_11_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape822 = R.call_tir(cls.reshape4, (lv291,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat43 = R.call_tir(cls.concatenate, (reshape820, reshape821, reshape822), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape823 = R.call_tir(cls.reshape5, (concat43,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv156 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape823), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape824 = R.call_tir(cls.reshape6, (lv156,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape825 = 
R.call_tir(cls.reshape7, (reshape824,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv292 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_self_attn_out_proj_weight3, reshape825, model_decoder_layers_11_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add692 = R.call_tir(cls.add, (add688, lv292), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm196 = R.call_tir(cls.layer_norm, (add692, model_decoder_layers_11_encoder_attn_layer_norm_weight3, model_decoder_layers_11_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv293 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight3, layer_norm196, model_decoder_layers_11_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape826 = R.call_tir(cls.reshape4, (lv293,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape827 = R.call_tir(cls.reshape8, (reshape826,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv157 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape827), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape828 = R.call_tir(cls.reshape6, (lv157,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape829 = R.call_tir(cls.reshape7, (reshape828,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv294 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight3, reshape829, model_decoder_layers_11_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add695 = R.call_tir(cls.add, (add692, lv294), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + layer_norm197 = R.call_tir(cls.layer_norm, (add695, model_decoder_layers_11_final_layer_norm_weight3, model_decoder_layers_11_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv43 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_11_fc1_weight3, layer_norm197, model_decoder_layers_11_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv295 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_11_fc2_weight3, lv43, model_decoder_layers_11_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add698 = R.call_tir(cls.add, (add695, lv295), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm198 = R.call_tir(cls.layer_norm, (add698, model_decoder_layers_12_self_attn_layer_norm_weight3, model_decoder_layers_12_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv296 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_q_proj_weight3, layer_norm198, model_decoder_layers_12_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape830 = R.call_tir(cls.reshape4, (lv296,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv77 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_12_self_attn_k_proj_weight3, layer_norm198), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape831 = R.call_tir(cls.reshape4, (lv77,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv297 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_v_proj_weight3, layer_norm198, model_decoder_layers_12_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape832 
= R.call_tir(cls.reshape4, (lv297,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat44 = R.call_tir(cls.concatenate, (reshape830, reshape831, reshape832), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape833 = R.call_tir(cls.reshape5, (concat44,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv158 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape833), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape834 = R.call_tir(cls.reshape6, (lv158,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape835 = R.call_tir(cls.reshape7, (reshape834,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv298 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_out_proj_weight3, reshape835, model_decoder_layers_12_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add702 = R.call_tir(cls.add, (add698, lv298), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm199 = R.call_tir(cls.layer_norm, (add702, model_decoder_layers_12_encoder_attn_layer_norm_weight3, model_decoder_layers_12_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv299 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight3, layer_norm199, model_decoder_layers_12_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape836 = R.call_tir(cls.reshape4, (lv299,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape837 = R.call_tir(cls.reshape8, (reshape836,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv159 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, 
R.prim_value(12), R.prim_value(T.float32(1)), reshape837), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape838 = R.call_tir(cls.reshape6, (lv159,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape839 = R.call_tir(cls.reshape7, (reshape838,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv300 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight3, reshape839, model_decoder_layers_12_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add705 = R.call_tir(cls.add, (add702, lv300), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm200 = R.call_tir(cls.layer_norm, (add705, model_decoder_layers_12_final_layer_norm_weight3, model_decoder_layers_12_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv44 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_12_fc1_weight3, layer_norm200, model_decoder_layers_12_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv301 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_12_fc2_weight3, lv44, model_decoder_layers_12_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add708 = R.call_tir(cls.add, (add705, lv301), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm201 = R.call_tir(cls.layer_norm, (add708, model_decoder_layers_13_self_attn_layer_norm_weight3, model_decoder_layers_13_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv302 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_q_proj_weight3, layer_norm201, model_decoder_layers_13_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + reshape840 = R.call_tir(cls.reshape4, (lv302,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv78 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_13_self_attn_k_proj_weight3, layer_norm201), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape841 = R.call_tir(cls.reshape4, (lv78,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv303 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_v_proj_weight3, layer_norm201, model_decoder_layers_13_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape842 = R.call_tir(cls.reshape4, (lv303,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat45 = R.call_tir(cls.concatenate, (reshape840, reshape841, reshape842), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape843 = R.call_tir(cls.reshape5, (concat45,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv160 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape843), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape844 = R.call_tir(cls.reshape6, (lv160,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape845 = R.call_tir(cls.reshape7, (reshape844,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv304 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_out_proj_weight3, reshape845, model_decoder_layers_13_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add712 = R.call_tir(cls.add, (add708, lv304), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm202 = R.call_tir(cls.layer_norm, (add712, 
model_decoder_layers_13_encoder_attn_layer_norm_weight3, model_decoder_layers_13_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv305 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight3, layer_norm202, model_decoder_layers_13_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape846 = R.call_tir(cls.reshape4, (lv305,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape847 = R.call_tir(cls.reshape8, (reshape846,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv161 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape847), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape848 = R.call_tir(cls.reshape6, (lv161,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape849 = R.call_tir(cls.reshape7, (reshape848,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv306 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight3, reshape849, model_decoder_layers_13_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add715 = R.call_tir(cls.add, (add712, lv306), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm203 = R.call_tir(cls.layer_norm, (add715, model_decoder_layers_13_final_layer_norm_weight3, model_decoder_layers_13_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv45 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_13_fc1_weight3, layer_norm203, model_decoder_layers_13_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv307 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_13_fc2_weight3, lv45, model_decoder_layers_13_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add718 = R.call_tir(cls.add, (add715, lv307), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm204 = R.call_tir(cls.layer_norm, (add718, model_decoder_layers_14_self_attn_layer_norm_weight3, model_decoder_layers_14_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv308 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_q_proj_weight3, layer_norm204, model_decoder_layers_14_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape850 = R.call_tir(cls.reshape4, (lv308,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv79 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_14_self_attn_k_proj_weight3, layer_norm204), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape851 = R.call_tir(cls.reshape4, (lv79,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv309 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_v_proj_weight3, layer_norm204, model_decoder_layers_14_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape852 = R.call_tir(cls.reshape4, (lv309,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat46 = R.call_tir(cls.concatenate, (reshape850, reshape851, reshape852), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape853 = R.call_tir(cls.reshape5, (concat46,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv162 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), 
R.prim_value(T.float32(1)), reshape853), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape854 = R.call_tir(cls.reshape6, (lv162,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape855 = R.call_tir(cls.reshape7, (reshape854,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv310 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_out_proj_weight3, reshape855, model_decoder_layers_14_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add722 = R.call_tir(cls.add, (add718, lv310), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm205 = R.call_tir(cls.layer_norm, (add722, model_decoder_layers_14_encoder_attn_layer_norm_weight3, model_decoder_layers_14_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv311 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight3, layer_norm205, model_decoder_layers_14_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape856 = R.call_tir(cls.reshape4, (lv311,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape857 = R.call_tir(cls.reshape8, (reshape856,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv163 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape857), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape858 = R.call_tir(cls.reshape6, (lv163,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape859 = R.call_tir(cls.reshape7, (reshape858,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv312 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_14_encoder_attn_out_proj_weight3, reshape859, model_decoder_layers_14_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add725 = R.call_tir(cls.add, (add722, lv312), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm206 = R.call_tir(cls.layer_norm, (add725, model_decoder_layers_14_final_layer_norm_weight3, model_decoder_layers_14_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv46 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_14_fc1_weight3, layer_norm206, model_decoder_layers_14_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv313 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_14_fc2_weight3, lv46, model_decoder_layers_14_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add728 = R.call_tir(cls.add, (add725, lv313), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm207 = R.call_tir(cls.layer_norm, (add728, model_decoder_layers_15_self_attn_layer_norm_weight3, model_decoder_layers_15_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv314 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_q_proj_weight3, layer_norm207, model_decoder_layers_15_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape860 = R.call_tir(cls.reshape4, (lv314,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv80 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_15_self_attn_k_proj_weight3, layer_norm207), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape861 = R.call_tir(cls.reshape4, (lv80,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + 
lv315 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_v_proj_weight3, layer_norm207, model_decoder_layers_15_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape862 = R.call_tir(cls.reshape4, (lv315,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat47 = R.call_tir(cls.concatenate, (reshape860, reshape861, reshape862), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape863 = R.call_tir(cls.reshape5, (concat47,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv164 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape863), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape864 = R.call_tir(cls.reshape6, (lv164,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape865 = R.call_tir(cls.reshape7, (reshape864,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv316 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_out_proj_weight3, reshape865, model_decoder_layers_15_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add732 = R.call_tir(cls.add, (add728, lv316), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm208 = R.call_tir(cls.layer_norm, (add732, model_decoder_layers_15_encoder_attn_layer_norm_weight3, model_decoder_layers_15_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv317 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight3, layer_norm208, model_decoder_layers_15_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape866 = R.call_tir(cls.reshape4, (lv317,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape867 = R.call_tir(cls.reshape8, (reshape866,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv165 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape867), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape868 = R.call_tir(cls.reshape6, (lv165,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape869 = R.call_tir(cls.reshape7, (reshape868,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv318 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight3, reshape869, model_decoder_layers_15_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add735 = R.call_tir(cls.add, (add732, lv318), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm209 = R.call_tir(cls.layer_norm, (add735, model_decoder_layers_15_final_layer_norm_weight3, model_decoder_layers_15_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv47 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_15_fc1_weight3, layer_norm209, model_decoder_layers_15_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv319 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_15_fc2_weight3, lv47, model_decoder_layers_15_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add738 = R.call_tir(cls.add, (add735, lv319), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm210 = R.call_tir(cls.layer_norm, (add738, model_decoder_layers_16_self_attn_layer_norm_weight3, model_decoder_layers_16_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + lv320 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_q_proj_weight3, layer_norm210, model_decoder_layers_16_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape870 = R.call_tir(cls.reshape4, (lv320,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv81 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_16_self_attn_k_proj_weight3, layer_norm210), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape871 = R.call_tir(cls.reshape4, (lv81,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv321 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_v_proj_weight3, layer_norm210, model_decoder_layers_16_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape872 = R.call_tir(cls.reshape4, (lv321,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat48 = R.call_tir(cls.concatenate, (reshape870, reshape871, reshape872), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape873 = R.call_tir(cls.reshape5, (concat48,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv166 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape873), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape874 = R.call_tir(cls.reshape6, (lv166,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape875 = R.call_tir(cls.reshape7, (reshape874,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv322 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_out_proj_weight3, reshape875, model_decoder_layers_16_self_attn_out_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add742 = R.call_tir(cls.add, (add738, lv322), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm211 = R.call_tir(cls.layer_norm, (add742, model_decoder_layers_16_encoder_attn_layer_norm_weight3, model_decoder_layers_16_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv323 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight3, layer_norm211, model_decoder_layers_16_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape876 = R.call_tir(cls.reshape4, (lv323,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape877 = R.call_tir(cls.reshape8, (reshape876,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv167 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape877), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape878 = R.call_tir(cls.reshape6, (lv167,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape879 = R.call_tir(cls.reshape7, (reshape878,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv324 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_encoder_attn_out_proj_weight3, reshape879, model_decoder_layers_16_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add745 = R.call_tir(cls.add, (add742, lv324), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm212 = R.call_tir(cls.layer_norm, (add745, model_decoder_layers_16_final_layer_norm_weight3, model_decoder_layers_16_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv48 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_16_fc1_weight3, layer_norm212, model_decoder_layers_16_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv325 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_16_fc2_weight3, lv48, model_decoder_layers_16_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add748 = R.call_tir(cls.add, (add745, lv325), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm213 = R.call_tir(cls.layer_norm, (add748, model_decoder_layers_17_self_attn_layer_norm_weight3, model_decoder_layers_17_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv326 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_q_proj_weight3, layer_norm213, model_decoder_layers_17_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape880 = R.call_tir(cls.reshape4, (lv326,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv82 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_17_self_attn_k_proj_weight3, layer_norm213), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape881 = R.call_tir(cls.reshape4, (lv82,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv327 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_v_proj_weight3, layer_norm213, model_decoder_layers_17_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape882 = R.call_tir(cls.reshape4, (lv327,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat49 = R.call_tir(cls.concatenate, (reshape880, reshape881, reshape882), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) 
+ reshape883 = R.call_tir(cls.reshape5, (concat49,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv168 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape883), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape884 = R.call_tir(cls.reshape6, (lv168,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape885 = R.call_tir(cls.reshape7, (reshape884,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv328 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_out_proj_weight3, reshape885, model_decoder_layers_17_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add752 = R.call_tir(cls.add, (add748, lv328), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm214 = R.call_tir(cls.layer_norm, (add752, model_decoder_layers_17_encoder_attn_layer_norm_weight3, model_decoder_layers_17_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv329 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight3, layer_norm214, model_decoder_layers_17_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape886 = R.call_tir(cls.reshape4, (lv329,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape887 = R.call_tir(cls.reshape8, (reshape886,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv169 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape887), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape888 = R.call_tir(cls.reshape6, (lv169,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape889 = 
R.call_tir(cls.reshape7, (reshape888,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv330 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight3, reshape889, model_decoder_layers_17_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add755 = R.call_tir(cls.add, (add752, lv330), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm215 = R.call_tir(cls.layer_norm, (add755, model_decoder_layers_17_final_layer_norm_weight3, model_decoder_layers_17_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv49 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_17_fc1_weight3, layer_norm215, model_decoder_layers_17_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv331 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_17_fc2_weight3, lv49, model_decoder_layers_17_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add758 = R.call_tir(cls.add, (add755, lv331), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm216 = R.call_tir(cls.layer_norm, (add758, model_decoder_layers_18_self_attn_layer_norm_weight3, model_decoder_layers_18_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv332 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_q_proj_weight3, layer_norm216, model_decoder_layers_18_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape890 = R.call_tir(cls.reshape4, (lv332,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv83 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_18_self_attn_k_proj_weight3, 
layer_norm216), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape891 = R.call_tir(cls.reshape4, (lv83,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv333 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_v_proj_weight3, layer_norm216, model_decoder_layers_18_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape892 = R.call_tir(cls.reshape4, (lv333,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat50 = R.call_tir(cls.concatenate, (reshape890, reshape891, reshape892), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape893 = R.call_tir(cls.reshape5, (concat50,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv170 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape893), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape894 = R.call_tir(cls.reshape6, (lv170,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape895 = R.call_tir(cls.reshape7, (reshape894,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv334 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_out_proj_weight3, reshape895, model_decoder_layers_18_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add762 = R.call_tir(cls.add, (add758, lv334), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm217 = R.call_tir(cls.layer_norm, (add762, model_decoder_layers_18_encoder_attn_layer_norm_weight3, model_decoder_layers_18_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv335 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_18_encoder_attn_q_proj_weight3, layer_norm217, model_decoder_layers_18_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape896 = R.call_tir(cls.reshape4, (lv335,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape897 = R.call_tir(cls.reshape8, (reshape896,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv171 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape897), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape898 = R.call_tir(cls.reshape6, (lv171,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape899 = R.call_tir(cls.reshape7, (reshape898,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv336 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight3, reshape899, model_decoder_layers_18_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add765 = R.call_tir(cls.add, (add762, lv336), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm218 = R.call_tir(cls.layer_norm, (add765, model_decoder_layers_18_final_layer_norm_weight3, model_decoder_layers_18_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv50 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_18_fc1_weight3, layer_norm218, model_decoder_layers_18_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv337 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_18_fc2_weight3, lv50, model_decoder_layers_18_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add768 = R.call_tir(cls.add, (add765, lv337), out_sinfo=R.Tensor((batch_size, 1, 
1280), dtype="float16")) + layer_norm219 = R.call_tir(cls.layer_norm, (add768, model_decoder_layers_19_self_attn_layer_norm_weight3, model_decoder_layers_19_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv338 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_q_proj_weight3, layer_norm219, model_decoder_layers_19_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape900 = R.call_tir(cls.reshape4, (lv338,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv84 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_19_self_attn_k_proj_weight3, layer_norm219), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape901 = R.call_tir(cls.reshape4, (lv84,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv339 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_v_proj_weight3, layer_norm219, model_decoder_layers_19_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape902 = R.call_tir(cls.reshape4, (lv339,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat51 = R.call_tir(cls.concatenate, (reshape900, reshape901, reshape902), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape903 = R.call_tir(cls.reshape5, (concat51,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv172 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape903), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape904 = R.call_tir(cls.reshape6, (lv172,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape905 = R.call_tir(cls.reshape7, (reshape904,), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + lv340 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_out_proj_weight3, reshape905, model_decoder_layers_19_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add772 = R.call_tir(cls.add, (add768, lv340), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm220 = R.call_tir(cls.layer_norm, (add772, model_decoder_layers_19_encoder_attn_layer_norm_weight3, model_decoder_layers_19_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv341 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight3, layer_norm220, model_decoder_layers_19_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape906 = R.call_tir(cls.reshape4, (lv341,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape907 = R.call_tir(cls.reshape8, (reshape906,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv173 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape907), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape908 = R.call_tir(cls.reshape6, (lv173,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape909 = R.call_tir(cls.reshape7, (reshape908,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv342 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight3, reshape909, model_decoder_layers_19_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add775 = R.call_tir(cls.add, (add772, lv342), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm221 = R.call_tir(cls.layer_norm, (add775, 
model_decoder_layers_19_final_layer_norm_weight3, model_decoder_layers_19_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv51 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_19_fc1_weight3, layer_norm221, model_decoder_layers_19_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv343 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_19_fc2_weight3, lv51, model_decoder_layers_19_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add778 = R.call_tir(cls.add, (add775, lv343), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm222 = R.call_tir(cls.layer_norm, (add778, model_decoder_layers_20_self_attn_layer_norm_weight3, model_decoder_layers_20_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv344 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_self_attn_q_proj_weight3, layer_norm222, model_decoder_layers_20_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape910 = R.call_tir(cls.reshape4, (lv344,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv85 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_20_self_attn_k_proj_weight3, layer_norm222), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape911 = R.call_tir(cls.reshape4, (lv85,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv345 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_self_attn_v_proj_weight3, layer_norm222, model_decoder_layers_20_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape912 = R.call_tir(cls.reshape4, (lv345,), out_sinfo=R.Tensor((batch_size, 1, 
20, 64), dtype="float16")) + concat52 = R.call_tir(cls.concatenate, (reshape910, reshape911, reshape912), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape913 = R.call_tir(cls.reshape5, (concat52,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv174 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape913), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape914 = R.call_tir(cls.reshape6, (lv174,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape915 = R.call_tir(cls.reshape7, (reshape914,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv346 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_self_attn_out_proj_weight3, reshape915, model_decoder_layers_20_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add782 = R.call_tir(cls.add, (add778, lv346), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm223 = R.call_tir(cls.layer_norm, (add782, model_decoder_layers_20_encoder_attn_layer_norm_weight3, model_decoder_layers_20_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv347 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight3, layer_norm223, model_decoder_layers_20_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape916 = R.call_tir(cls.reshape4, (lv347,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape917 = R.call_tir(cls.reshape8, (reshape916,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv175 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape917), 
out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape918 = R.call_tir(cls.reshape6, (lv175,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape919 = R.call_tir(cls.reshape7, (reshape918,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv348 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight3, reshape919, model_decoder_layers_20_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add785 = R.call_tir(cls.add, (add782, lv348), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm224 = R.call_tir(cls.layer_norm, (add785, model_decoder_layers_20_final_layer_norm_weight3, model_decoder_layers_20_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv52 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_20_fc1_weight3, layer_norm224, model_decoder_layers_20_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv349 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_20_fc2_weight3, lv52, model_decoder_layers_20_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add788 = R.call_tir(cls.add, (add785, lv349), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm225 = R.call_tir(cls.layer_norm, (add788, model_decoder_layers_21_self_attn_layer_norm_weight3, model_decoder_layers_21_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv350 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_q_proj_weight3, layer_norm225, model_decoder_layers_21_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape920 = R.call_tir(cls.reshape4, (lv350,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv86 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_21_self_attn_k_proj_weight3, layer_norm225), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape921 = R.call_tir(cls.reshape4, (lv86,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv351 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_v_proj_weight3, layer_norm225, model_decoder_layers_21_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape922 = R.call_tir(cls.reshape4, (lv351,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat53 = R.call_tir(cls.concatenate, (reshape920, reshape921, reshape922), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape923 = R.call_tir(cls.reshape5, (concat53,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv176 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape923), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape924 = R.call_tir(cls.reshape6, (lv176,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape925 = R.call_tir(cls.reshape7, (reshape924,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv352 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_out_proj_weight3, reshape925, model_decoder_layers_21_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add792 = R.call_tir(cls.add, (add788, lv352), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm226 = R.call_tir(cls.layer_norm, (add792, model_decoder_layers_21_encoder_attn_layer_norm_weight3, model_decoder_layers_21_encoder_attn_layer_norm_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv353 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight3, layer_norm226, model_decoder_layers_21_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape926 = R.call_tir(cls.reshape4, (lv353,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape927 = R.call_tir(cls.reshape8, (reshape926,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv177 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape927), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape928 = R.call_tir(cls.reshape6, (lv177,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape929 = R.call_tir(cls.reshape7, (reshape928,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv354 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight3, reshape929, model_decoder_layers_21_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add795 = R.call_tir(cls.add, (add792, lv354), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm227 = R.call_tir(cls.layer_norm, (add795, model_decoder_layers_21_final_layer_norm_weight3, model_decoder_layers_21_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv53 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_21_fc1_weight3, layer_norm227, model_decoder_layers_21_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv355 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_21_fc2_weight3, lv53, model_decoder_layers_21_fc2_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add798 = R.call_tir(cls.add, (add795, lv355), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm228 = R.call_tir(cls.layer_norm, (add798, model_decoder_layers_22_self_attn_layer_norm_weight3, model_decoder_layers_22_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv356 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_q_proj_weight3, layer_norm228, model_decoder_layers_22_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape930 = R.call_tir(cls.reshape4, (lv356,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv87 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_22_self_attn_k_proj_weight3, layer_norm228), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape931 = R.call_tir(cls.reshape4, (lv87,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv357 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_v_proj_weight3, layer_norm228, model_decoder_layers_22_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape932 = R.call_tir(cls.reshape4, (lv357,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat54 = R.call_tir(cls.concatenate, (reshape930, reshape931, reshape932), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape933 = R.call_tir(cls.reshape5, (concat54,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv178 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape933), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape934 = R.call_tir(cls.reshape6, (lv178,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape935 = R.call_tir(cls.reshape7, (reshape934,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv358 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_out_proj_weight3, reshape935, model_decoder_layers_22_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add802 = R.call_tir(cls.add, (add798, lv358), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm229 = R.call_tir(cls.layer_norm, (add802, model_decoder_layers_22_encoder_attn_layer_norm_weight3, model_decoder_layers_22_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv359 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight3, layer_norm229, model_decoder_layers_22_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape936 = R.call_tir(cls.reshape4, (lv359,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape937 = R.call_tir(cls.reshape8, (reshape936,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv179 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape937), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape938 = R.call_tir(cls.reshape6, (lv179,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape939 = R.call_tir(cls.reshape7, (reshape938,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv360 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight3, reshape939, model_decoder_layers_22_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add805 = 
R.call_tir(cls.add, (add802, lv360), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm230 = R.call_tir(cls.layer_norm, (add805, model_decoder_layers_22_final_layer_norm_weight3, model_decoder_layers_22_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv54 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_22_fc1_weight3, layer_norm230, model_decoder_layers_22_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv361 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_22_fc2_weight3, lv54, model_decoder_layers_22_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add808 = R.call_tir(cls.add, (add805, lv361), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm231 = R.call_tir(cls.layer_norm, (add808, model_decoder_layers_23_self_attn_layer_norm_weight3, model_decoder_layers_23_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv362 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_q_proj_weight3, layer_norm231, model_decoder_layers_23_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape940 = R.call_tir(cls.reshape4, (lv362,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv88 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_23_self_attn_k_proj_weight3, layer_norm231), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape941 = R.call_tir(cls.reshape4, (lv88,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv363 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_v_proj_weight3, layer_norm231, 
model_decoder_layers_23_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape942 = R.call_tir(cls.reshape4, (lv363,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat55 = R.call_tir(cls.concatenate, (reshape940, reshape941, reshape942), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape943 = R.call_tir(cls.reshape5, (concat55,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv180 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape943), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape944 = R.call_tir(cls.reshape6, (lv180,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape945 = R.call_tir(cls.reshape7, (reshape944,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv364 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_out_proj_weight3, reshape945, model_decoder_layers_23_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add812 = R.call_tir(cls.add, (add808, lv364), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm232 = R.call_tir(cls.layer_norm, (add812, model_decoder_layers_23_encoder_attn_layer_norm_weight3, model_decoder_layers_23_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv365 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight3, layer_norm232, model_decoder_layers_23_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape946 = R.call_tir(cls.reshape4, (lv365,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape947 = R.call_tir(cls.reshape8, (reshape946,), out_sinfo=R.Tensor((batch_size, 20, 64), 
dtype="float16")) + lv181 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape947), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape948 = R.call_tir(cls.reshape6, (lv181,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape949 = R.call_tir(cls.reshape7, (reshape948,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv366 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight3, reshape949, model_decoder_layers_23_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add815 = R.call_tir(cls.add, (add812, lv366), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm233 = R.call_tir(cls.layer_norm, (add815, model_decoder_layers_23_final_layer_norm_weight3, model_decoder_layers_23_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv55 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_23_fc1_weight3, layer_norm233, model_decoder_layers_23_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv367 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_23_fc2_weight3, lv55, model_decoder_layers_23_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add818 = R.call_tir(cls.add, (add815, lv367), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm234 = R.call_tir(cls.layer_norm, (add818, model_decoder_layers_24_self_attn_layer_norm_weight3, model_decoder_layers_24_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv368 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_q_proj_weight3, 
layer_norm234, model_decoder_layers_24_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape950 = R.call_tir(cls.reshape4, (lv368,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv89 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_24_self_attn_k_proj_weight3, layer_norm234), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape951 = R.call_tir(cls.reshape4, (lv89,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv369 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_v_proj_weight3, layer_norm234, model_decoder_layers_24_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape952 = R.call_tir(cls.reshape4, (lv369,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat56 = R.call_tir(cls.concatenate, (reshape950, reshape951, reshape952), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape953 = R.call_tir(cls.reshape5, (concat56,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv182 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape953), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape954 = R.call_tir(cls.reshape6, (lv182,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape955 = R.call_tir(cls.reshape7, (reshape954,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv370 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_out_proj_weight3, reshape955, model_decoder_layers_24_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add822 = R.call_tir(cls.add, (add818, lv370), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + 
layer_norm235 = R.call_tir(cls.layer_norm, (add822, model_decoder_layers_24_encoder_attn_layer_norm_weight3, model_decoder_layers_24_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv371 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight3, layer_norm235, model_decoder_layers_24_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape956 = R.call_tir(cls.reshape4, (lv371,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape957 = R.call_tir(cls.reshape8, (reshape956,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv183 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape957), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape958 = R.call_tir(cls.reshape6, (lv183,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape959 = R.call_tir(cls.reshape7, (reshape958,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv372 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight3, reshape959, model_decoder_layers_24_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add825 = R.call_tir(cls.add, (add822, lv372), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm236 = R.call_tir(cls.layer_norm, (add825, model_decoder_layers_24_final_layer_norm_weight3, model_decoder_layers_24_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv56 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_24_fc1_weight3, layer_norm236, model_decoder_layers_24_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + 
lv373 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_24_fc2_weight3, lv56, model_decoder_layers_24_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add828 = R.call_tir(cls.add, (add825, lv373), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm237 = R.call_tir(cls.layer_norm, (add828, model_decoder_layers_25_self_attn_layer_norm_weight3, model_decoder_layers_25_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv374 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_q_proj_weight3, layer_norm237, model_decoder_layers_25_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape960 = R.call_tir(cls.reshape4, (lv374,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv90 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_25_self_attn_k_proj_weight3, layer_norm237), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape961 = R.call_tir(cls.reshape4, (lv90,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv375 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_v_proj_weight3, layer_norm237, model_decoder_layers_25_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape962 = R.call_tir(cls.reshape4, (lv375,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat57 = R.call_tir(cls.concatenate, (reshape960, reshape961, reshape962), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape963 = R.call_tir(cls.reshape5, (concat57,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), 
R.prim_value(T.float32(1)), reshape963), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape964 = R.call_tir(cls.reshape6, (lv184,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape965 = R.call_tir(cls.reshape7, (reshape964,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv376 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_out_proj_weight3, reshape965, model_decoder_layers_25_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add832 = R.call_tir(cls.add, (add828, lv376), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm238 = R.call_tir(cls.layer_norm, (add832, model_decoder_layers_25_encoder_attn_layer_norm_weight3, model_decoder_layers_25_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv377 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_encoder_attn_q_proj_weight3, layer_norm238, model_decoder_layers_25_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape966 = R.call_tir(cls.reshape4, (lv377,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape967 = R.call_tir(cls.reshape8, (reshape966,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv185 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape967), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape968 = R.call_tir(cls.reshape6, (lv185,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape969 = R.call_tir(cls.reshape7, (reshape968,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv378 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_25_encoder_attn_out_proj_weight3, reshape969, model_decoder_layers_25_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add835 = R.call_tir(cls.add, (add832, lv378), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm239 = R.call_tir(cls.layer_norm, (add835, model_decoder_layers_25_final_layer_norm_weight3, model_decoder_layers_25_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv57 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_25_fc1_weight3, layer_norm239, model_decoder_layers_25_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv379 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_25_fc2_weight3, lv57, model_decoder_layers_25_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add838 = R.call_tir(cls.add, (add835, lv379), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm240 = R.call_tir(cls.layer_norm, (add838, model_decoder_layers_26_self_attn_layer_norm_weight3, model_decoder_layers_26_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv380 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_q_proj_weight3, layer_norm240, model_decoder_layers_26_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape970 = R.call_tir(cls.reshape4, (lv380,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv91 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_26_self_attn_k_proj_weight3, layer_norm240), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape971 = R.call_tir(cls.reshape4, (lv91,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + 
lv381 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_v_proj_weight3, layer_norm240, model_decoder_layers_26_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape972 = R.call_tir(cls.reshape4, (lv381,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat58 = R.call_tir(cls.concatenate, (reshape970, reshape971, reshape972), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape973 = R.call_tir(cls.reshape5, (concat58,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv186 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape973), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape974 = R.call_tir(cls.reshape6, (lv186,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape975 = R.call_tir(cls.reshape7, (reshape974,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv382 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_out_proj_weight3, reshape975, model_decoder_layers_26_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add842 = R.call_tir(cls.add, (add838, lv382), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm241 = R.call_tir(cls.layer_norm, (add842, model_decoder_layers_26_encoder_attn_layer_norm_weight3, model_decoder_layers_26_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv383 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight3, layer_norm241, model_decoder_layers_26_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape976 = R.call_tir(cls.reshape4, (lv383,), 
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape977 = R.call_tir(cls.reshape8, (reshape976,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv187 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape977), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape978 = R.call_tir(cls.reshape6, (lv187,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape979 = R.call_tir(cls.reshape7, (reshape978,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv384 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_encoder_attn_out_proj_weight3, reshape979, model_decoder_layers_26_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add845 = R.call_tir(cls.add, (add842, lv384), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm242 = R.call_tir(cls.layer_norm, (add845, model_decoder_layers_26_final_layer_norm_weight3, model_decoder_layers_26_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv58 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_26_fc1_weight3, layer_norm242, model_decoder_layers_26_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv385 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_26_fc2_weight3, lv58, model_decoder_layers_26_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add848 = R.call_tir(cls.add, (add845, lv385), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm243 = R.call_tir(cls.layer_norm, (add848, model_decoder_layers_27_self_attn_layer_norm_weight3, model_decoder_layers_27_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), 
dtype="float16")) + lv386 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_q_proj_weight3, layer_norm243, model_decoder_layers_27_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape980 = R.call_tir(cls.reshape4, (lv386,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv92 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_27_self_attn_k_proj_weight3, layer_norm243), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape981 = R.call_tir(cls.reshape4, (lv92,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv387 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_v_proj_weight3, layer_norm243, model_decoder_layers_27_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape982 = R.call_tir(cls.reshape4, (lv387,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat59 = R.call_tir(cls.concatenate, (reshape980, reshape981, reshape982), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape983 = R.call_tir(cls.reshape5, (concat59,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv188 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape983), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape984 = R.call_tir(cls.reshape6, (lv188,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape985 = R.call_tir(cls.reshape7, (reshape984,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv388 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_out_proj_weight3, reshape985, model_decoder_layers_27_self_attn_out_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add852 = R.call_tir(cls.add, (add848, lv388), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm244 = R.call_tir(cls.layer_norm, (add852, model_decoder_layers_27_encoder_attn_layer_norm_weight3, model_decoder_layers_27_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv389 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight3, layer_norm244, model_decoder_layers_27_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape986 = R.call_tir(cls.reshape4, (lv389,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape987 = R.call_tir(cls.reshape8, (reshape986,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape987), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape988 = R.call_tir(cls.reshape6, (lv189,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape989 = R.call_tir(cls.reshape7, (reshape988,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv390 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight3, reshape989, model_decoder_layers_27_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add855 = R.call_tir(cls.add, (add852, lv390), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm245 = R.call_tir(cls.layer_norm, (add855, model_decoder_layers_27_final_layer_norm_weight3, model_decoder_layers_27_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv59 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_27_fc1_weight3, layer_norm245, model_decoder_layers_27_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv391 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_27_fc2_weight3, lv59, model_decoder_layers_27_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add858 = R.call_tir(cls.add, (add855, lv391), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm246 = R.call_tir(cls.layer_norm, (add858, model_decoder_layers_28_self_attn_layer_norm_weight3, model_decoder_layers_28_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv392 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_q_proj_weight3, layer_norm246, model_decoder_layers_28_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape990 = R.call_tir(cls.reshape4, (lv392,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv93 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_28_self_attn_k_proj_weight3, layer_norm246), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape991 = R.call_tir(cls.reshape4, (lv93,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv393 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_v_proj_weight3, layer_norm246, model_decoder_layers_28_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape992 = R.call_tir(cls.reshape4, (lv393,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat60 = R.call_tir(cls.concatenate, (reshape990, reshape991, reshape992), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) 
+ reshape993 = R.call_tir(cls.reshape5, (concat60,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv190 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape993), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape994 = R.call_tir(cls.reshape6, (lv190,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape995 = R.call_tir(cls.reshape7, (reshape994,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv394 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_out_proj_weight3, reshape995, model_decoder_layers_28_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add862 = R.call_tir(cls.add, (add858, lv394), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm247 = R.call_tir(cls.layer_norm, (add862, model_decoder_layers_28_encoder_attn_layer_norm_weight3, model_decoder_layers_28_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv395 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight3, layer_norm247, model_decoder_layers_28_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape996 = R.call_tir(cls.reshape4, (lv395,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape997 = R.call_tir(cls.reshape8, (reshape996,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv191 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape997), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape998 = R.call_tir(cls.reshape6, (lv191,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape999 = 
R.call_tir(cls.reshape7, (reshape998,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv396 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight3, reshape999, model_decoder_layers_28_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add865 = R.call_tir(cls.add, (add862, lv396), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm248 = R.call_tir(cls.layer_norm, (add865, model_decoder_layers_28_final_layer_norm_weight3, model_decoder_layers_28_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv60 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_28_fc1_weight3, layer_norm248, model_decoder_layers_28_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv397 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_28_fc2_weight3, lv60, model_decoder_layers_28_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add868 = R.call_tir(cls.add, (add865, lv397), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm249 = R.call_tir(cls.layer_norm, (add868, model_decoder_layers_29_self_attn_layer_norm_weight3, model_decoder_layers_29_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv398 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_q_proj_weight3, layer_norm249, model_decoder_layers_29_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1000 = R.call_tir(cls.reshape4, (lv398,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv94 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_29_self_attn_k_proj_weight3, 
layer_norm249), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1001 = R.call_tir(cls.reshape4, (lv94,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv399 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_v_proj_weight3, layer_norm249, model_decoder_layers_29_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1002 = R.call_tir(cls.reshape4, (lv399,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat61 = R.call_tir(cls.concatenate, (reshape1000, reshape1001, reshape1002), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape1003 = R.call_tir(cls.reshape5, (concat61,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv192 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1003), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1004 = R.call_tir(cls.reshape6, (lv192,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1005 = R.call_tir(cls.reshape7, (reshape1004,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv400 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_out_proj_weight3, reshape1005, model_decoder_layers_29_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add872 = R.call_tir(cls.add, (add868, lv400), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm250 = R.call_tir(cls.layer_norm, (add872, model_decoder_layers_29_encoder_attn_layer_norm_weight3, model_decoder_layers_29_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv401 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_29_encoder_attn_q_proj_weight3, layer_norm250, model_decoder_layers_29_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1006 = R.call_tir(cls.reshape4, (lv401,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1007 = R.call_tir(cls.reshape8, (reshape1006,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv193 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1007), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1008 = R.call_tir(cls.reshape6, (lv193,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1009 = R.call_tir(cls.reshape7, (reshape1008,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv402 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight3, reshape1009, model_decoder_layers_29_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add875 = R.call_tir(cls.add, (add872, lv402), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm251 = R.call_tir(cls.layer_norm, (add875, model_decoder_layers_29_final_layer_norm_weight3, model_decoder_layers_29_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv61 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_29_fc1_weight3, layer_norm251, model_decoder_layers_29_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv403 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_29_fc2_weight3, lv61, model_decoder_layers_29_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add878 = R.call_tir(cls.add, (add875, lv403), out_sinfo=R.Tensor((batch_size, 
1, 1280), dtype="float16")) + layer_norm252 = R.call_tir(cls.layer_norm, (add878, model_decoder_layers_30_self_attn_layer_norm_weight3, model_decoder_layers_30_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv404 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_q_proj_weight3, layer_norm252, model_decoder_layers_30_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1010 = R.call_tir(cls.reshape4, (lv404,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv95 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_30_self_attn_k_proj_weight3, layer_norm252), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1011 = R.call_tir(cls.reshape4, (lv95,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv405 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_v_proj_weight3, layer_norm252, model_decoder_layers_30_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1012 = R.call_tir(cls.reshape4, (lv405,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + concat62 = R.call_tir(cls.concatenate, (reshape1010, reshape1011, reshape1012), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape1013 = R.call_tir(cls.reshape5, (concat62,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1013), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1014 = R.call_tir(cls.reshape6, (lv194,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1015 = R.call_tir(cls.reshape7, (reshape1014,), out_sinfo=R.Tensor((batch_size, 1, 
1280), dtype="float16")) + lv406 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_out_proj_weight3, reshape1015, model_decoder_layers_30_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add882 = R.call_tir(cls.add, (add878, lv406), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm253 = R.call_tir(cls.layer_norm, (add882, model_decoder_layers_30_encoder_attn_layer_norm_weight3, model_decoder_layers_30_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv407 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight3, layer_norm253, model_decoder_layers_30_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1016 = R.call_tir(cls.reshape4, (lv407,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1017 = R.call_tir(cls.reshape8, (reshape1016,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv195 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1017), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1018 = R.call_tir(cls.reshape6, (lv195,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1019 = R.call_tir(cls.reshape7, (reshape1018,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv408 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight3, reshape1019, model_decoder_layers_30_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add885 = R.call_tir(cls.add, (add882, lv408), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm254 = R.call_tir(cls.layer_norm, (add885, 
model_decoder_layers_30_final_layer_norm_weight3, model_decoder_layers_30_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv62 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_30_fc1_weight3, layer_norm254, model_decoder_layers_30_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv409 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_30_fc2_weight3, lv62, model_decoder_layers_30_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add888 = R.call_tir(cls.add, (add885, lv409), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm255 = R.call_tir(cls.layer_norm, (add888, model_decoder_layers_31_self_attn_layer_norm_weight3, model_decoder_layers_31_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv410 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_q_proj_weight3, layer_norm255, model_decoder_layers_31_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1020 = R.call_tir(cls.reshape4, (lv410,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv96 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_31_self_attn_k_proj_weight3, layer_norm255), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1021 = R.call_tir(cls.reshape4, (lv96,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + lv411 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_v_proj_weight3, layer_norm255, model_decoder_layers_31_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1022 = R.call_tir(cls.reshape4, (lv411,), out_sinfo=R.Tensor((batch_size, 
1, 20, 64), dtype="float16")) + concat63 = R.call_tir(cls.concatenate, (reshape1020, reshape1021, reshape1022), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) + reshape1023 = R.call_tir(cls.reshape5, (concat63,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) + lv196 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1023), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1024 = R.call_tir(cls.reshape6, (lv196,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1025 = R.call_tir(cls.reshape7, (reshape1024,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv412 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_out_proj_weight3, reshape1025, model_decoder_layers_31_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add892 = R.call_tir(cls.add, (add888, lv412), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm256 = R.call_tir(cls.layer_norm, (add892, model_decoder_layers_31_encoder_attn_layer_norm_weight3, model_decoder_layers_31_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv413 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight3, layer_norm256, model_decoder_layers_31_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + reshape1026 = R.call_tir(cls.reshape4, (lv413,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1027 = R.call_tir(cls.reshape8, (reshape1026,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + lv197 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1027), 
out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) + reshape1028 = R.call_tir(cls.reshape6, (lv197,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) + reshape1029 = R.call_tir(cls.reshape7, (reshape1028,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv414 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight3, reshape1029, model_decoder_layers_31_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add895 = R.call_tir(cls.add, (add892, lv414), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm257 = R.call_tir(cls.layer_norm, (add895, model_decoder_layers_31_final_layer_norm_weight3, model_decoder_layers_31_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + lv63 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_31_fc1_weight3, layer_norm257, model_decoder_layers_31_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) + lv415 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_31_fc2_weight3, lv63, model_decoder_layers_31_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + add898 = R.call_tir(cls.add, (add895, lv415), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + layer_norm258 = R.call_tir(cls.layer_norm, (add898, model_decoder_layer_norm_weight3, model_decoder_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) + gv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul4_cublas", (model_decoder_embed_tokens_weight3, layer_norm258), out_sinfo=R.Tensor((batch_size, 1, 51866), dtype="float32")) + R.output(gv3) + return gv3 + + @R.function + def batch_encode(input_features: R.Tensor(("batch_size", 128, 3000), dtype="float16"), paged_kv_cache: R.Object, 
packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"))) -> R.Tensor(("batch_size", 1500, 1280), dtype="float16"): + batch_size = T.int64() + R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_encoder_conv1_weight: R.Tensor((1280, 128, 3), dtype="float16") = packed_params[0] + lv: R.Tensor((1280,), dtype="float16") = packed_params[1] + lv1 = R.call_tir(cls.fused_reshape9, (lv,), out_sinfo=R.Tensor((1, 1280, 1), dtype="float16")) + model_encoder_conv2_weight: R.Tensor((1280, 1280, 3), dtype="float16") = packed_params[2] + lv2: R.Tensor((1280,), dtype="float16") = packed_params[3] + lv3 = R.call_tir(cls.fused_reshape9, (lv2,), out_sinfo=R.Tensor((1, 1280, 1), dtype="float16")) + lv4 = R.call_tir(cls.fused_conv1d_add1_gelu, (input_features, model_encoder_conv1_weight, lv1), out_sinfo=R.Tensor((batch_size, 1280, 3000), dtype="float16")) + lv5 = R.call_tir(cls.fused_conv1d1_add2_gelu1, (lv4, model_encoder_conv2_weight, lv3), out_sinfo=R.Tensor((batch_size, 1280, 1500), dtype="float16")) + lv6: R.Tensor((1500, 1280), dtype="float16") = packed_params[4] + lv7 = R.call_tir(cls.fused_transpose_add3, (lv6, lv5), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + model_encoder_layers_0_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[5] + model_encoder_layers_0_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[6] + model_encoder_layers_0_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[7] + model_encoder_layers_0_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[8] + model_encoder_layers_0_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[9] + model_encoder_layers_0_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[10] + 
model_encoder_layers_0_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[11] + model_encoder_layers_0_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[12] + model_encoder_layers_0_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[13] + model_encoder_layers_0_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[14] + model_encoder_layers_0_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[15] + model_encoder_layers_0_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[16] + model_encoder_layers_0_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[17] + model_encoder_layers_0_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[18] + model_encoder_layers_0_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[19] + model_encoder_layers_1_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[20] + model_encoder_layers_1_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[21] + model_encoder_layers_1_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[22] + model_encoder_layers_1_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[23] + model_encoder_layers_1_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[24] + model_encoder_layers_1_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[25] + model_encoder_layers_1_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[26] + model_encoder_layers_1_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[27] + model_encoder_layers_1_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[28] + model_encoder_layers_1_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[29] + 
model_encoder_layers_1_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[30] + model_encoder_layers_1_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[31] + model_encoder_layers_1_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[32] + model_encoder_layers_1_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[33] + model_encoder_layers_1_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[34] + model_encoder_layers_2_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[35] + model_encoder_layers_2_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[36] + model_encoder_layers_2_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[37] + model_encoder_layers_2_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[38] + model_encoder_layers_2_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[39] + model_encoder_layers_2_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[40] + model_encoder_layers_2_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[41] + model_encoder_layers_2_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[42] + model_encoder_layers_2_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[43] + model_encoder_layers_2_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[44] + model_encoder_layers_2_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[45] + model_encoder_layers_2_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[46] + model_encoder_layers_2_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[47] + model_encoder_layers_2_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[48] + model_encoder_layers_2_final_layer_norm_bias: R.Tensor((1280,), 
dtype="float16") = packed_params[49] + model_encoder_layers_3_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[50] + model_encoder_layers_3_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[51] + model_encoder_layers_3_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[52] + model_encoder_layers_3_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[53] + model_encoder_layers_3_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[54] + model_encoder_layers_3_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[55] + model_encoder_layers_3_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[56] + model_encoder_layers_3_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[57] + model_encoder_layers_3_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[58] + model_encoder_layers_3_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[59] + model_encoder_layers_3_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[60] + model_encoder_layers_3_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[61] + model_encoder_layers_3_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[62] + model_encoder_layers_3_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[63] + model_encoder_layers_3_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[64] + model_encoder_layers_4_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[65] + model_encoder_layers_4_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[66] + model_encoder_layers_4_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[67] + model_encoder_layers_4_self_attn_q_proj_weight: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[68] + model_encoder_layers_4_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[69] + model_encoder_layers_4_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[70] + model_encoder_layers_4_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[71] + model_encoder_layers_4_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[72] + model_encoder_layers_4_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[73] + model_encoder_layers_4_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[74] + model_encoder_layers_4_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[75] + model_encoder_layers_4_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[76] + model_encoder_layers_4_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[77] + model_encoder_layers_4_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[78] + model_encoder_layers_4_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[79] + model_encoder_layers_5_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[80] + model_encoder_layers_5_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[81] + model_encoder_layers_5_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[82] + model_encoder_layers_5_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[83] + model_encoder_layers_5_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[84] + model_encoder_layers_5_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[85] + model_encoder_layers_5_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[86] + model_encoder_layers_5_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = 
packed_params[87] + model_encoder_layers_5_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[88] + model_encoder_layers_5_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[89] + model_encoder_layers_5_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[90] + model_encoder_layers_5_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[91] + model_encoder_layers_5_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[92] + model_encoder_layers_5_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[93] + model_encoder_layers_5_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[94] + model_encoder_layers_6_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[95] + model_encoder_layers_6_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[96] + model_encoder_layers_6_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[97] + model_encoder_layers_6_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[98] + model_encoder_layers_6_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[99] + model_encoder_layers_6_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[100] + model_encoder_layers_6_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[101] + model_encoder_layers_6_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[102] + model_encoder_layers_6_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[103] + model_encoder_layers_6_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[104] + model_encoder_layers_6_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[105] + model_encoder_layers_6_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[106] + 
model_encoder_layers_6_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[107] + model_encoder_layers_6_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[108] + model_encoder_layers_6_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[109] + model_encoder_layers_7_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[110] + model_encoder_layers_7_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[111] + model_encoder_layers_7_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[112] + model_encoder_layers_7_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[113] + model_encoder_layers_7_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[114] + model_encoder_layers_7_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[115] + model_encoder_layers_7_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[116] + model_encoder_layers_7_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[117] + model_encoder_layers_7_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[118] + model_encoder_layers_7_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[119] + model_encoder_layers_7_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[120] + model_encoder_layers_7_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[121] + model_encoder_layers_7_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[122] + model_encoder_layers_7_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[123] + model_encoder_layers_7_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[124] + model_encoder_layers_8_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[125] + 
model_encoder_layers_8_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[126] + model_encoder_layers_8_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[127] + model_encoder_layers_8_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[128] + model_encoder_layers_8_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[129] + model_encoder_layers_8_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[130] + model_encoder_layers_8_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[131] + model_encoder_layers_8_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[132] + model_encoder_layers_8_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[133] + model_encoder_layers_8_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[134] + model_encoder_layers_8_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[135] + model_encoder_layers_8_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[136] + model_encoder_layers_8_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[137] + model_encoder_layers_8_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[138] + model_encoder_layers_8_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[139] + model_encoder_layers_9_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[140] + model_encoder_layers_9_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[141] + model_encoder_layers_9_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[142] + model_encoder_layers_9_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[143] + model_encoder_layers_9_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[144] + 
model_encoder_layers_9_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[145] + model_encoder_layers_9_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[146] + model_encoder_layers_9_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[147] + model_encoder_layers_9_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[148] + model_encoder_layers_9_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[149] + model_encoder_layers_9_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[150] + model_encoder_layers_9_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[151] + model_encoder_layers_9_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[152] + model_encoder_layers_9_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[153] + model_encoder_layers_9_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[154] + model_encoder_layers_10_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[155] + model_encoder_layers_10_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[156] + model_encoder_layers_10_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[157] + model_encoder_layers_10_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[158] + model_encoder_layers_10_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[159] + model_encoder_layers_10_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[160] + model_encoder_layers_10_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[161] + model_encoder_layers_10_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[162] + model_encoder_layers_10_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = 
packed_params[163] + model_encoder_layers_10_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[164] + model_encoder_layers_10_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[165] + model_encoder_layers_10_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[166] + model_encoder_layers_10_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[167] + model_encoder_layers_10_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[168] + model_encoder_layers_10_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[169] + model_encoder_layers_11_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[170] + model_encoder_layers_11_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[171] + model_encoder_layers_11_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[172] + model_encoder_layers_11_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[173] + model_encoder_layers_11_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[174] + model_encoder_layers_11_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[175] + model_encoder_layers_11_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[176] + model_encoder_layers_11_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[177] + model_encoder_layers_11_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[178] + model_encoder_layers_11_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[179] + model_encoder_layers_11_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[180] + model_encoder_layers_11_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[181] + model_encoder_layers_11_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[182] + 
model_encoder_layers_11_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[183] + model_encoder_layers_11_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[184] + model_encoder_layers_12_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[185] + model_encoder_layers_12_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[186] + model_encoder_layers_12_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[187] + model_encoder_layers_12_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[188] + model_encoder_layers_12_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[189] + model_encoder_layers_12_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[190] + model_encoder_layers_12_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[191] + model_encoder_layers_12_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[192] + model_encoder_layers_12_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[193] + model_encoder_layers_12_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[194] + model_encoder_layers_12_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[195] + model_encoder_layers_12_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[196] + model_encoder_layers_12_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[197] + model_encoder_layers_12_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[198] + model_encoder_layers_12_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[199] + model_encoder_layers_13_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[200] + model_encoder_layers_13_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[201] + model_encoder_layers_13_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[202] + model_encoder_layers_13_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[203] + model_encoder_layers_13_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[204] + model_encoder_layers_13_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[205] + model_encoder_layers_13_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[206] + model_encoder_layers_13_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[207] + model_encoder_layers_13_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[208] + model_encoder_layers_13_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[209] + model_encoder_layers_13_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[210] + model_encoder_layers_13_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[211] + model_encoder_layers_13_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[212] + model_encoder_layers_13_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[213] + model_encoder_layers_13_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[214] + model_encoder_layers_14_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[215] + model_encoder_layers_14_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[216] + model_encoder_layers_14_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[217] + model_encoder_layers_14_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[218] + model_encoder_layers_14_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[219] + model_encoder_layers_14_self_attn_out_proj_weight: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[220] + model_encoder_layers_14_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[221] + model_encoder_layers_14_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[222] + model_encoder_layers_14_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[223] + model_encoder_layers_14_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[224] + model_encoder_layers_14_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[225] + model_encoder_layers_14_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[226] + model_encoder_layers_14_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[227] + model_encoder_layers_14_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[228] + model_encoder_layers_14_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[229] + model_encoder_layers_15_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[230] + model_encoder_layers_15_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[231] + model_encoder_layers_15_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[232] + model_encoder_layers_15_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[233] + model_encoder_layers_15_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[234] + model_encoder_layers_15_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[235] + model_encoder_layers_15_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[236] + model_encoder_layers_15_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[237] + model_encoder_layers_15_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[238] + model_encoder_layers_15_fc1_weight: R.Tensor((5120, 
1280), dtype="float16") = packed_params[239] + model_encoder_layers_15_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[240] + model_encoder_layers_15_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[241] + model_encoder_layers_15_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[242] + model_encoder_layers_15_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[243] + model_encoder_layers_15_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[244] + model_encoder_layers_16_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[245] + model_encoder_layers_16_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[246] + model_encoder_layers_16_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[247] + model_encoder_layers_16_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[248] + model_encoder_layers_16_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[249] + model_encoder_layers_16_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[250] + model_encoder_layers_16_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[251] + model_encoder_layers_16_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[252] + model_encoder_layers_16_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[253] + model_encoder_layers_16_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[254] + model_encoder_layers_16_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[255] + model_encoder_layers_16_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[256] + model_encoder_layers_16_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[257] + model_encoder_layers_16_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = 
packed_params[258] + model_encoder_layers_16_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[259] + model_encoder_layers_17_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[260] + model_encoder_layers_17_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[261] + model_encoder_layers_17_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[262] + model_encoder_layers_17_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[263] + model_encoder_layers_17_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[264] + model_encoder_layers_17_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[265] + model_encoder_layers_17_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[266] + model_encoder_layers_17_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[267] + model_encoder_layers_17_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[268] + model_encoder_layers_17_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[269] + model_encoder_layers_17_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[270] + model_encoder_layers_17_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[271] + model_encoder_layers_17_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[272] + model_encoder_layers_17_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[273] + model_encoder_layers_17_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[274] + model_encoder_layers_18_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[275] + model_encoder_layers_18_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[276] + model_encoder_layers_18_self_attn_v_proj_bias: R.Tensor((1280,), 
dtype="float16") = packed_params[277] + model_encoder_layers_18_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[278] + model_encoder_layers_18_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[279] + model_encoder_layers_18_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[280] + model_encoder_layers_18_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[281] + model_encoder_layers_18_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[282] + model_encoder_layers_18_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[283] + model_encoder_layers_18_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[284] + model_encoder_layers_18_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[285] + model_encoder_layers_18_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[286] + model_encoder_layers_18_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[287] + model_encoder_layers_18_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[288] + model_encoder_layers_18_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[289] + model_encoder_layers_19_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[290] + model_encoder_layers_19_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[291] + model_encoder_layers_19_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[292] + model_encoder_layers_19_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[293] + model_encoder_layers_19_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[294] + model_encoder_layers_19_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[295] + 
model_encoder_layers_19_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[296] + model_encoder_layers_19_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[297] + model_encoder_layers_19_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[298] + model_encoder_layers_19_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[299] + model_encoder_layers_19_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[300] + model_encoder_layers_19_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[301] + model_encoder_layers_19_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[302] + model_encoder_layers_19_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[303] + model_encoder_layers_19_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[304] + model_encoder_layers_20_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[305] + model_encoder_layers_20_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[306] + model_encoder_layers_20_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[307] + model_encoder_layers_20_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[308] + model_encoder_layers_20_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[309] + model_encoder_layers_20_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[310] + model_encoder_layers_20_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[311] + model_encoder_layers_20_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[312] + model_encoder_layers_20_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[313] + model_encoder_layers_20_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = 
packed_params[314] + model_encoder_layers_20_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[315] + model_encoder_layers_20_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[316] + model_encoder_layers_20_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[317] + model_encoder_layers_20_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[318] + model_encoder_layers_20_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[319] + model_encoder_layers_21_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[320] + model_encoder_layers_21_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[321] + model_encoder_layers_21_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[322] + model_encoder_layers_21_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[323] + model_encoder_layers_21_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[324] + model_encoder_layers_21_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[325] + model_encoder_layers_21_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[326] + model_encoder_layers_21_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[327] + model_encoder_layers_21_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[328] + model_encoder_layers_21_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[329] + model_encoder_layers_21_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[330] + model_encoder_layers_21_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[331] + model_encoder_layers_21_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[332] + model_encoder_layers_21_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[333] + 
model_encoder_layers_21_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[334] + model_encoder_layers_22_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[335] + model_encoder_layers_22_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[336] + model_encoder_layers_22_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[337] + model_encoder_layers_22_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[338] + model_encoder_layers_22_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[339] + model_encoder_layers_22_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[340] + model_encoder_layers_22_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[341] + model_encoder_layers_22_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[342] + model_encoder_layers_22_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[343] + model_encoder_layers_22_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[344] + model_encoder_layers_22_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[345] + model_encoder_layers_22_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[346] + model_encoder_layers_22_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[347] + model_encoder_layers_22_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[348] + model_encoder_layers_22_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[349] + model_encoder_layers_23_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[350] + model_encoder_layers_23_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[351] + model_encoder_layers_23_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = 
packed_params[352] + model_encoder_layers_23_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[353] + model_encoder_layers_23_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[354] + model_encoder_layers_23_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[355] + model_encoder_layers_23_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[356] + model_encoder_layers_23_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[357] + model_encoder_layers_23_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[358] + model_encoder_layers_23_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[359] + model_encoder_layers_23_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[360] + model_encoder_layers_23_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[361] + model_encoder_layers_23_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[362] + model_encoder_layers_23_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[363] + model_encoder_layers_23_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[364] + model_encoder_layers_24_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[365] + model_encoder_layers_24_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[366] + model_encoder_layers_24_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[367] + model_encoder_layers_24_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[368] + model_encoder_layers_24_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[369] + model_encoder_layers_24_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[370] + model_encoder_layers_24_self_attn_out_proj_bias: R.Tensor((1280,), 
dtype="float16") = packed_params[371] + model_encoder_layers_24_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[372] + model_encoder_layers_24_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[373] + model_encoder_layers_24_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[374] + model_encoder_layers_24_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[375] + model_encoder_layers_24_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[376] + model_encoder_layers_24_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[377] + model_encoder_layers_24_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[378] + model_encoder_layers_24_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[379] + model_encoder_layers_25_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[380] + model_encoder_layers_25_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[381] + model_encoder_layers_25_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[382] + model_encoder_layers_25_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[383] + model_encoder_layers_25_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[384] + model_encoder_layers_25_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[385] + model_encoder_layers_25_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[386] + model_encoder_layers_25_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[387] + model_encoder_layers_25_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[388] + model_encoder_layers_25_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[389] + model_encoder_layers_25_fc1_bias: R.Tensor((5120,), 
dtype="float16") = packed_params[390] + model_encoder_layers_25_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[391] + model_encoder_layers_25_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[392] + model_encoder_layers_25_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[393] + model_encoder_layers_25_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[394] + model_encoder_layers_26_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[395] + model_encoder_layers_26_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[396] + model_encoder_layers_26_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[397] + model_encoder_layers_26_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[398] + model_encoder_layers_26_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[399] + model_encoder_layers_26_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[400] + model_encoder_layers_26_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[401] + model_encoder_layers_26_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[402] + model_encoder_layers_26_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[403] + model_encoder_layers_26_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[404] + model_encoder_layers_26_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[405] + model_encoder_layers_26_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[406] + model_encoder_layers_26_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[407] + model_encoder_layers_26_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[408] + model_encoder_layers_26_final_layer_norm_bias: R.Tensor((1280,), 
dtype="float16") = packed_params[409] + model_encoder_layers_27_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[410] + model_encoder_layers_27_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[411] + model_encoder_layers_27_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[412] + model_encoder_layers_27_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[413] + model_encoder_layers_27_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[414] + model_encoder_layers_27_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[415] + model_encoder_layers_27_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[416] + model_encoder_layers_27_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[417] + model_encoder_layers_27_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[418] + model_encoder_layers_27_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[419] + model_encoder_layers_27_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[420] + model_encoder_layers_27_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[421] + model_encoder_layers_27_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[422] + model_encoder_layers_27_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[423] + model_encoder_layers_27_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[424] + model_encoder_layers_28_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[425] + model_encoder_layers_28_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[426] + model_encoder_layers_28_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[427] + model_encoder_layers_28_self_attn_q_proj_weight: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[428] + model_encoder_layers_28_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[429] + model_encoder_layers_28_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[430] + model_encoder_layers_28_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[431] + model_encoder_layers_28_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[432] + model_encoder_layers_28_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[433] + model_encoder_layers_28_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[434] + model_encoder_layers_28_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[435] + model_encoder_layers_28_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[436] + model_encoder_layers_28_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[437] + model_encoder_layers_28_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[438] + model_encoder_layers_28_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[439] + model_encoder_layers_29_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[440] + model_encoder_layers_29_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[441] + model_encoder_layers_29_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[442] + model_encoder_layers_29_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[443] + model_encoder_layers_29_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[444] + model_encoder_layers_29_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[445] + model_encoder_layers_29_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[446] + 
model_encoder_layers_29_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[447] + model_encoder_layers_29_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[448] + model_encoder_layers_29_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[449] + model_encoder_layers_29_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[450] + model_encoder_layers_29_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[451] + model_encoder_layers_29_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[452] + model_encoder_layers_29_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[453] + model_encoder_layers_29_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[454] + model_encoder_layers_30_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[455] + model_encoder_layers_30_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[456] + model_encoder_layers_30_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[457] + model_encoder_layers_30_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[458] + model_encoder_layers_30_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[459] + model_encoder_layers_30_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[460] + model_encoder_layers_30_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[461] + model_encoder_layers_30_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[462] + model_encoder_layers_30_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[463] + model_encoder_layers_30_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[464] + model_encoder_layers_30_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[465] + 
model_encoder_layers_30_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[466] + model_encoder_layers_30_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[467] + model_encoder_layers_30_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[468] + model_encoder_layers_30_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[469] + model_encoder_layers_31_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[470] + model_encoder_layers_31_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[471] + model_encoder_layers_31_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[472] + model_encoder_layers_31_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[473] + model_encoder_layers_31_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[474] + model_encoder_layers_31_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[475] + model_encoder_layers_31_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[476] + model_encoder_layers_31_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[477] + model_encoder_layers_31_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[478] + model_encoder_layers_31_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[479] + model_encoder_layers_31_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[480] + model_encoder_layers_31_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[481] + model_encoder_layers_31_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[482] + model_encoder_layers_31_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[483] + model_encoder_layers_31_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[484] + 
model_encoder_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[485] + model_encoder_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[486] + layer_norm = R.call_tir(cls.layer_norm1, (lv7, model_encoder_layers_0_self_attn_layer_norm_weight, model_encoder_layers_0_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv608 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_q_proj_weight, layer_norm, model_encoder_layers_0_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape = R.call_tir(cls.reshape, (lv608,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv131 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_0_self_attn_k_proj_weight, layer_norm), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape1 = R.call_tir(cls.reshape, (lv131,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv609 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_v_proj_weight, layer_norm, model_encoder_layers_0_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape2 = R.call_tir(cls.reshape, (lv609,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape3 = R.call_tir(cls.reshape1, (reshape,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape4 = R.call_tir(cls.reshape1, (reshape1,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape5 = R.call_tir(cls.reshape1, (reshape2,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv4_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape3, reshape4, reshape5), out_sinfo=R.Tensor((batch_size * 
1500, 20, 64), dtype="float16")) + reshape6 = R.call_tir(cls.reshape10, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape7 = R.call_tir(cls.reshape11, (reshape6,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv610 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_out_proj_weight, reshape7, model_encoder_layers_0_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add4 = R.call_tir(cls.add4, (lv7, lv610), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm1 = R.call_tir(cls.layer_norm1, (add4, model_encoder_layers_0_final_layer_norm_weight, model_encoder_layers_0_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv96 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_0_fc1_weight, layer_norm1, model_encoder_layers_0_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv611 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_0_fc2_weight, lv96, model_encoder_layers_0_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv8 = R.call_tir(cls.fused_add4_maximum_minimum, (add4, lv611), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm2 = R.call_tir(cls.layer_norm1, (lv8, model_encoder_layers_1_self_attn_layer_norm_weight, model_encoder_layers_1_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv612 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_q_proj_weight, layer_norm2, model_encoder_layers_1_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape8 = R.call_tir(cls.reshape, (lv612,), out_sinfo=R.Tensor((batch_size, 1500, 
20, 64), dtype="float16")) + lv132 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_1_self_attn_k_proj_weight, layer_norm2), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape9 = R.call_tir(cls.reshape, (lv132,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv613 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_v_proj_weight, layer_norm2, model_encoder_layers_1_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape10 = R.call_tir(cls.reshape, (lv613,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape11 = R.call_tir(cls.reshape1, (reshape8,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape12 = R.call_tir(cls.reshape1, (reshape9,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape13 = R.call_tir(cls.reshape1, (reshape10,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv5_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape11, reshape12, reshape13), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape14 = R.call_tir(cls.reshape10, (lv5_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape15 = R.call_tir(cls.reshape11, (reshape14,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv614 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_out_proj_weight, reshape15, model_encoder_layers_1_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add11 = R.call_tir(cls.add4, (lv8, lv614), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm3 = R.call_tir(cls.layer_norm1, (add11, 
model_encoder_layers_1_final_layer_norm_weight, model_encoder_layers_1_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv97 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_1_fc1_weight, layer_norm3, model_encoder_layers_1_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv615 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_1_fc2_weight, lv97, model_encoder_layers_1_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv9 = R.call_tir(cls.fused_add4_maximum_minimum, (add11, lv615), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm4 = R.call_tir(cls.layer_norm1, (lv9, model_encoder_layers_2_self_attn_layer_norm_weight, model_encoder_layers_2_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv616 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_q_proj_weight, layer_norm4, model_encoder_layers_2_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape16 = R.call_tir(cls.reshape, (lv616,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv133 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_2_self_attn_k_proj_weight, layer_norm4), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape17 = R.call_tir(cls.reshape, (lv133,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv617 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_v_proj_weight, layer_norm4, model_encoder_layers_2_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape18 = R.call_tir(cls.reshape, (lv617,), out_sinfo=R.Tensor((batch_size, 
1500, 20, 64), dtype="float16")) + reshape19 = R.call_tir(cls.reshape1, (reshape16,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape20 = R.call_tir(cls.reshape1, (reshape17,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape21 = R.call_tir(cls.reshape1, (reshape18,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv6_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape19, reshape20, reshape21), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape22 = R.call_tir(cls.reshape10, (lv6_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape23 = R.call_tir(cls.reshape11, (reshape22,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv618 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_out_proj_weight, reshape23, model_encoder_layers_2_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add18 = R.call_tir(cls.add4, (lv9, lv618), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm5 = R.call_tir(cls.layer_norm1, (add18, model_encoder_layers_2_final_layer_norm_weight, model_encoder_layers_2_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_2_fc1_weight, layer_norm5, model_encoder_layers_2_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv619 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_2_fc2_weight, lv98, model_encoder_layers_2_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv10 = R.call_tir(cls.fused_add4_maximum_minimum, (add18, lv619), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm6 = R.call_tir(cls.layer_norm1, (lv10, model_encoder_layers_3_self_attn_layer_norm_weight, model_encoder_layers_3_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv620 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_q_proj_weight, layer_norm6, model_encoder_layers_3_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape24 = R.call_tir(cls.reshape, (lv620,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv134 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_3_self_attn_k_proj_weight, layer_norm6), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape25 = R.call_tir(cls.reshape, (lv134,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv621 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_v_proj_weight, layer_norm6, model_encoder_layers_3_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape26 = R.call_tir(cls.reshape, (lv621,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape27 = R.call_tir(cls.reshape1, (reshape24,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape28 = R.call_tir(cls.reshape1, (reshape25,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape29 = R.call_tir(cls.reshape1, (reshape26,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv7_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape27, reshape28, reshape29), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape30 = R.call_tir(cls.reshape10, (lv7_1,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape31 = R.call_tir(cls.reshape11, (reshape30,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv622 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_out_proj_weight, reshape31, model_encoder_layers_3_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add25 = R.call_tir(cls.add4, (lv10, lv622), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm7 = R.call_tir(cls.layer_norm1, (add25, model_encoder_layers_3_final_layer_norm_weight, model_encoder_layers_3_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_3_fc1_weight, layer_norm7, model_encoder_layers_3_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv623 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_3_fc2_weight, lv99, model_encoder_layers_3_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv11 = R.call_tir(cls.fused_add4_maximum_minimum, (add25, lv623), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm8 = R.call_tir(cls.layer_norm1, (lv11, model_encoder_layers_4_self_attn_layer_norm_weight, model_encoder_layers_4_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv624 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_q_proj_weight, layer_norm8, model_encoder_layers_4_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape32 = R.call_tir(cls.reshape, (lv624,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv135 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_4_self_attn_k_proj_weight, layer_norm8), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape33 = R.call_tir(cls.reshape, (lv135,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv625 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_v_proj_weight, layer_norm8, model_encoder_layers_4_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape34 = R.call_tir(cls.reshape, (lv625,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape35 = R.call_tir(cls.reshape1, (reshape32,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape36 = R.call_tir(cls.reshape1, (reshape33,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape37 = R.call_tir(cls.reshape1, (reshape34,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv8_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape35, reshape36, reshape37), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape38 = R.call_tir(cls.reshape10, (lv8_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape39 = R.call_tir(cls.reshape11, (reshape38,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv626 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_out_proj_weight, reshape39, model_encoder_layers_4_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add32 = R.call_tir(cls.add4, (lv11, lv626), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm9 = R.call_tir(cls.layer_norm1, (add32, model_encoder_layers_4_final_layer_norm_weight, 
model_encoder_layers_4_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_4_fc1_weight, layer_norm9, model_encoder_layers_4_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv627 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_4_fc2_weight, lv100, model_encoder_layers_4_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv12 = R.call_tir(cls.fused_add4_maximum_minimum, (add32, lv627), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm10 = R.call_tir(cls.layer_norm1, (lv12, model_encoder_layers_5_self_attn_layer_norm_weight, model_encoder_layers_5_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv628 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_q_proj_weight, layer_norm10, model_encoder_layers_5_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape40 = R.call_tir(cls.reshape, (lv628,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv136 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_5_self_attn_k_proj_weight, layer_norm10), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape41 = R.call_tir(cls.reshape, (lv136,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv629 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_v_proj_weight, layer_norm10, model_encoder_layers_5_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape42 = R.call_tir(cls.reshape, (lv629,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + 
reshape43 = R.call_tir(cls.reshape1, (reshape40,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape44 = R.call_tir(cls.reshape1, (reshape41,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape45 = R.call_tir(cls.reshape1, (reshape42,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv9_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape43, reshape44, reshape45), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape46 = R.call_tir(cls.reshape10, (lv9_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape47 = R.call_tir(cls.reshape11, (reshape46,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv630 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_out_proj_weight, reshape47, model_encoder_layers_5_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add39 = R.call_tir(cls.add4, (lv12, lv630), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm11 = R.call_tir(cls.layer_norm1, (add39, model_encoder_layers_5_final_layer_norm_weight, model_encoder_layers_5_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_5_fc1_weight, layer_norm11, model_encoder_layers_5_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv631 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_5_fc2_weight, lv101, model_encoder_layers_5_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv13 = R.call_tir(cls.fused_add4_maximum_minimum, (add39, lv631), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) + layer_norm12 = R.call_tir(cls.layer_norm1, (lv13, model_encoder_layers_6_self_attn_layer_norm_weight, model_encoder_layers_6_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv632 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_q_proj_weight, layer_norm12, model_encoder_layers_6_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape48 = R.call_tir(cls.reshape, (lv632,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv137 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_6_self_attn_k_proj_weight, layer_norm12), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape49 = R.call_tir(cls.reshape, (lv137,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv633 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_v_proj_weight, layer_norm12, model_encoder_layers_6_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape50 = R.call_tir(cls.reshape, (lv633,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape51 = R.call_tir(cls.reshape1, (reshape48,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape52 = R.call_tir(cls.reshape1, (reshape49,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape53 = R.call_tir(cls.reshape1, (reshape50,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv10_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape51, reshape52, reshape53), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape54 = R.call_tir(cls.reshape10, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), 
dtype="float16")) + reshape55 = R.call_tir(cls.reshape11, (reshape54,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv634 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_out_proj_weight, reshape55, model_encoder_layers_6_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add46 = R.call_tir(cls.add4, (lv13, lv634), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm13 = R.call_tir(cls.layer_norm1, (add46, model_encoder_layers_6_final_layer_norm_weight, model_encoder_layers_6_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_6_fc1_weight, layer_norm13, model_encoder_layers_6_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv635 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_6_fc2_weight, lv102, model_encoder_layers_6_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv14 = R.call_tir(cls.fused_add4_maximum_minimum, (add46, lv635), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm14 = R.call_tir(cls.layer_norm1, (lv14, model_encoder_layers_7_self_attn_layer_norm_weight, model_encoder_layers_7_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv636 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_q_proj_weight, layer_norm14, model_encoder_layers_7_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape56 = R.call_tir(cls.reshape, (lv636,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv138 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_encoder_layers_7_self_attn_k_proj_weight, layer_norm14), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape57 = R.call_tir(cls.reshape, (lv138,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv637 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_v_proj_weight, layer_norm14, model_encoder_layers_7_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape58 = R.call_tir(cls.reshape, (lv637,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape59 = R.call_tir(cls.reshape1, (reshape56,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape60 = R.call_tir(cls.reshape1, (reshape57,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape61 = R.call_tir(cls.reshape1, (reshape58,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv11_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape59, reshape60, reshape61), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape62 = R.call_tir(cls.reshape10, (lv11_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape63 = R.call_tir(cls.reshape11, (reshape62,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv638 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_out_proj_weight, reshape63, model_encoder_layers_7_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add53 = R.call_tir(cls.add4, (lv14, lv638), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm15 = R.call_tir(cls.layer_norm1, (add53, model_encoder_layers_7_final_layer_norm_weight, model_encoder_layers_7_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 
1500, 1280), dtype="float16")) + lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_7_fc1_weight, layer_norm15, model_encoder_layers_7_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv639 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_7_fc2_weight, lv103, model_encoder_layers_7_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv15 = R.call_tir(cls.fused_add4_maximum_minimum, (add53, lv639), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm16 = R.call_tir(cls.layer_norm1, (lv15, model_encoder_layers_8_self_attn_layer_norm_weight, model_encoder_layers_8_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv640 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_q_proj_weight, layer_norm16, model_encoder_layers_8_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape64 = R.call_tir(cls.reshape, (lv640,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv139 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_8_self_attn_k_proj_weight, layer_norm16), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape65 = R.call_tir(cls.reshape, (lv139,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv641 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_v_proj_weight, layer_norm16, model_encoder_layers_8_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape66 = R.call_tir(cls.reshape, (lv641,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape67 = R.call_tir(cls.reshape1, (reshape64,), out_sinfo=R.Tensor((batch_size 
* 1500, 20, 64), dtype="float16")) + reshape68 = R.call_tir(cls.reshape1, (reshape65,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape69 = R.call_tir(cls.reshape1, (reshape66,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv12_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape67, reshape68, reshape69), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape70 = R.call_tir(cls.reshape10, (lv12_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape71 = R.call_tir(cls.reshape11, (reshape70,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv642 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_out_proj_weight, reshape71, model_encoder_layers_8_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add60 = R.call_tir(cls.add4, (lv15, lv642), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm17 = R.call_tir(cls.layer_norm1, (add60, model_encoder_layers_8_final_layer_norm_weight, model_encoder_layers_8_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv104 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_8_fc1_weight, layer_norm17, model_encoder_layers_8_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv643 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_8_fc2_weight, lv104, model_encoder_layers_8_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv16 = R.call_tir(cls.fused_add4_maximum_minimum, (add60, lv643), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm18 = R.call_tir(cls.layer_norm1, (lv16, 
model_encoder_layers_9_self_attn_layer_norm_weight, model_encoder_layers_9_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv644 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_q_proj_weight, layer_norm18, model_encoder_layers_9_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape72 = R.call_tir(cls.reshape, (lv644,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv140 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_9_self_attn_k_proj_weight, layer_norm18), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape73 = R.call_tir(cls.reshape, (lv140,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv645 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_v_proj_weight, layer_norm18, model_encoder_layers_9_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape74 = R.call_tir(cls.reshape, (lv645,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape75 = R.call_tir(cls.reshape1, (reshape72,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape76 = R.call_tir(cls.reshape1, (reshape73,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape77 = R.call_tir(cls.reshape1, (reshape74,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv13_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape75, reshape76, reshape77), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape78 = R.call_tir(cls.reshape10, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape79 = R.call_tir(cls.reshape11, (reshape78,), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv646 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_out_proj_weight, reshape79, model_encoder_layers_9_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add67 = R.call_tir(cls.add4, (lv16, lv646), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm19 = R.call_tir(cls.layer_norm1, (add67, model_encoder_layers_9_final_layer_norm_weight, model_encoder_layers_9_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_9_fc1_weight, layer_norm19, model_encoder_layers_9_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv647 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_9_fc2_weight, lv105, model_encoder_layers_9_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv17 = R.call_tir(cls.fused_add4_maximum_minimum, (add67, lv647), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm20 = R.call_tir(cls.layer_norm1, (lv17, model_encoder_layers_10_self_attn_layer_norm_weight, model_encoder_layers_10_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv648 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_q_proj_weight, layer_norm20, model_encoder_layers_10_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape80 = R.call_tir(cls.reshape, (lv648,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv141 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_10_self_attn_k_proj_weight, layer_norm20), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape81 = R.call_tir(cls.reshape, (lv141,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv649 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_v_proj_weight, layer_norm20, model_encoder_layers_10_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape82 = R.call_tir(cls.reshape, (lv649,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape83 = R.call_tir(cls.reshape1, (reshape80,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape84 = R.call_tir(cls.reshape1, (reshape81,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape85 = R.call_tir(cls.reshape1, (reshape82,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv14_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape83, reshape84, reshape85), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape86 = R.call_tir(cls.reshape10, (lv14_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape87 = R.call_tir(cls.reshape11, (reshape86,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv650 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_out_proj_weight, reshape87, model_encoder_layers_10_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add74 = R.call_tir(cls.add4, (lv17, lv650), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm21 = R.call_tir(cls.layer_norm1, (add74, model_encoder_layers_10_final_layer_norm_weight, model_encoder_layers_10_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv106 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_10_fc1_weight, layer_norm21, model_encoder_layers_10_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv651 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_10_fc2_weight, lv106, model_encoder_layers_10_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv18 = R.call_tir(cls.fused_add4_maximum_minimum, (add74, lv651), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm22 = R.call_tir(cls.layer_norm1, (lv18, model_encoder_layers_11_self_attn_layer_norm_weight, model_encoder_layers_11_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv652 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_q_proj_weight, layer_norm22, model_encoder_layers_11_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape88 = R.call_tir(cls.reshape, (lv652,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv142 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_11_self_attn_k_proj_weight, layer_norm22), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape89 = R.call_tir(cls.reshape, (lv142,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv653 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_v_proj_weight, layer_norm22, model_encoder_layers_11_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape90 = R.call_tir(cls.reshape, (lv653,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape91 = R.call_tir(cls.reshape1, (reshape88,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), 
dtype="float16")) + reshape92 = R.call_tir(cls.reshape1, (reshape89,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape93 = R.call_tir(cls.reshape1, (reshape90,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv15_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape91, reshape92, reshape93), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape94 = R.call_tir(cls.reshape10, (lv15_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape95 = R.call_tir(cls.reshape11, (reshape94,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv654 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_out_proj_weight, reshape95, model_encoder_layers_11_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add81 = R.call_tir(cls.add4, (lv18, lv654), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm23 = R.call_tir(cls.layer_norm1, (add81, model_encoder_layers_11_final_layer_norm_weight, model_encoder_layers_11_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_11_fc1_weight, layer_norm23, model_encoder_layers_11_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv655 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_11_fc2_weight, lv107, model_encoder_layers_11_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv19 = R.call_tir(cls.fused_add4_maximum_minimum, (add81, lv655), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm24 = R.call_tir(cls.layer_norm1, (lv19, 
model_encoder_layers_12_self_attn_layer_norm_weight, model_encoder_layers_12_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv656 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_q_proj_weight, layer_norm24, model_encoder_layers_12_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape96 = R.call_tir(cls.reshape, (lv656,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv143 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_12_self_attn_k_proj_weight, layer_norm24), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape97 = R.call_tir(cls.reshape, (lv143,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv657 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_v_proj_weight, layer_norm24, model_encoder_layers_12_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape98 = R.call_tir(cls.reshape, (lv657,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape99 = R.call_tir(cls.reshape1, (reshape96,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape100 = R.call_tir(cls.reshape1, (reshape97,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape101 = R.call_tir(cls.reshape1, (reshape98,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv16_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape99, reshape100, reshape101), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape102 = R.call_tir(cls.reshape10, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape103 = R.call_tir(cls.reshape11, 
(reshape102,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv658 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_out_proj_weight, reshape103, model_encoder_layers_12_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add88 = R.call_tir(cls.add4, (lv19, lv658), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm25 = R.call_tir(cls.layer_norm1, (add88, model_encoder_layers_12_final_layer_norm_weight, model_encoder_layers_12_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_12_fc1_weight, layer_norm25, model_encoder_layers_12_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv659 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_12_fc2_weight, lv108, model_encoder_layers_12_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv20 = R.call_tir(cls.fused_add4_maximum_minimum, (add88, lv659), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm26 = R.call_tir(cls.layer_norm1, (lv20, model_encoder_layers_13_self_attn_layer_norm_weight, model_encoder_layers_13_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv660 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_q_proj_weight, layer_norm26, model_encoder_layers_13_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape104 = R.call_tir(cls.reshape, (lv660,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv144 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_13_self_attn_k_proj_weight, 
layer_norm26), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape105 = R.call_tir(cls.reshape, (lv144,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv661 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_v_proj_weight, layer_norm26, model_encoder_layers_13_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape106 = R.call_tir(cls.reshape, (lv661,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape107 = R.call_tir(cls.reshape1, (reshape104,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape108 = R.call_tir(cls.reshape1, (reshape105,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape109 = R.call_tir(cls.reshape1, (reshape106,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv17_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape107, reshape108, reshape109), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape110 = R.call_tir(cls.reshape10, (lv17_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape111 = R.call_tir(cls.reshape11, (reshape110,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv662 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_out_proj_weight, reshape111, model_encoder_layers_13_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add95 = R.call_tir(cls.add4, (lv20, lv662), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm27 = R.call_tir(cls.layer_norm1, (add95, model_encoder_layers_13_final_layer_norm_weight, model_encoder_layers_13_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) 
+ lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_13_fc1_weight, layer_norm27, model_encoder_layers_13_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv663 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_13_fc2_weight, lv109, model_encoder_layers_13_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv21 = R.call_tir(cls.fused_add4_maximum_minimum, (add95, lv663), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm28 = R.call_tir(cls.layer_norm1, (lv21, model_encoder_layers_14_self_attn_layer_norm_weight, model_encoder_layers_14_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv664 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_q_proj_weight, layer_norm28, model_encoder_layers_14_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape112 = R.call_tir(cls.reshape, (lv664,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv145 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_14_self_attn_k_proj_weight, layer_norm28), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape113 = R.call_tir(cls.reshape, (lv145,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv665 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_v_proj_weight, layer_norm28, model_encoder_layers_14_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape114 = R.call_tir(cls.reshape, (lv665,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape115 = R.call_tir(cls.reshape1, (reshape112,), out_sinfo=R.Tensor((batch_size * 1500, 20, 
64), dtype="float16")) + reshape116 = R.call_tir(cls.reshape1, (reshape113,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape117 = R.call_tir(cls.reshape1, (reshape114,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv18_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape115, reshape116, reshape117), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape118 = R.call_tir(cls.reshape10, (lv18_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape119 = R.call_tir(cls.reshape11, (reshape118,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv666 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_out_proj_weight, reshape119, model_encoder_layers_14_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add102 = R.call_tir(cls.add4, (lv21, lv666), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm29 = R.call_tir(cls.layer_norm1, (add102, model_encoder_layers_14_final_layer_norm_weight, model_encoder_layers_14_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_14_fc1_weight, layer_norm29, model_encoder_layers_14_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv667 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_14_fc2_weight, lv110, model_encoder_layers_14_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv22 = R.call_tir(cls.fused_add4_maximum_minimum, (add102, lv667), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm30 = R.call_tir(cls.layer_norm1, (lv22, 
model_encoder_layers_15_self_attn_layer_norm_weight, model_encoder_layers_15_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv668 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_q_proj_weight, layer_norm30, model_encoder_layers_15_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape120 = R.call_tir(cls.reshape, (lv668,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv146 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_15_self_attn_k_proj_weight, layer_norm30), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape121 = R.call_tir(cls.reshape, (lv146,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv669 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_v_proj_weight, layer_norm30, model_encoder_layers_15_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape122 = R.call_tir(cls.reshape, (lv669,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape123 = R.call_tir(cls.reshape1, (reshape120,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape124 = R.call_tir(cls.reshape1, (reshape121,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape125 = R.call_tir(cls.reshape1, (reshape122,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv19_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape123, reshape124, reshape125), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape126 = R.call_tir(cls.reshape10, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape127 = 
R.call_tir(cls.reshape11, (reshape126,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv670 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_out_proj_weight, reshape127, model_encoder_layers_15_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add109 = R.call_tir(cls.add4, (lv22, lv670), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm31 = R.call_tir(cls.layer_norm1, (add109, model_encoder_layers_15_final_layer_norm_weight, model_encoder_layers_15_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_15_fc1_weight, layer_norm31, model_encoder_layers_15_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv671 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_15_fc2_weight, lv111, model_encoder_layers_15_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv23 = R.call_tir(cls.fused_add4_maximum_minimum, (add109, lv671), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm32 = R.call_tir(cls.layer_norm1, (lv23, model_encoder_layers_16_self_attn_layer_norm_weight, model_encoder_layers_16_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv672 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_q_proj_weight, layer_norm32, model_encoder_layers_16_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape128 = R.call_tir(cls.reshape, (lv672,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv147 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_encoder_layers_16_self_attn_k_proj_weight, layer_norm32), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape129 = R.call_tir(cls.reshape, (lv147,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv673 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_v_proj_weight, layer_norm32, model_encoder_layers_16_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape130 = R.call_tir(cls.reshape, (lv673,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape131 = R.call_tir(cls.reshape1, (reshape128,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape132 = R.call_tir(cls.reshape1, (reshape129,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape133 = R.call_tir(cls.reshape1, (reshape130,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv20_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape131, reshape132, reshape133), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape134 = R.call_tir(cls.reshape10, (lv20_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape135 = R.call_tir(cls.reshape11, (reshape134,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv674 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_out_proj_weight, reshape135, model_encoder_layers_16_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add116 = R.call_tir(cls.add4, (lv23, lv674), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm33 = R.call_tir(cls.layer_norm1, (add116, model_encoder_layers_16_final_layer_norm_weight, model_encoder_layers_16_final_layer_norm_bias), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv112 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_16_fc1_weight, layer_norm33, model_encoder_layers_16_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv675 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_16_fc2_weight, lv112, model_encoder_layers_16_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv24 = R.call_tir(cls.fused_add4_maximum_minimum, (add116, lv675), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm34 = R.call_tir(cls.layer_norm1, (lv24, model_encoder_layers_17_self_attn_layer_norm_weight, model_encoder_layers_17_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv676 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_q_proj_weight, layer_norm34, model_encoder_layers_17_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape136 = R.call_tir(cls.reshape, (lv676,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv148 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_17_self_attn_k_proj_weight, layer_norm34), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape137 = R.call_tir(cls.reshape, (lv148,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv677 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_v_proj_weight, layer_norm34, model_encoder_layers_17_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape138 = R.call_tir(cls.reshape, (lv677,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape139 = 
R.call_tir(cls.reshape1, (reshape136,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape140 = R.call_tir(cls.reshape1, (reshape137,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape141 = R.call_tir(cls.reshape1, (reshape138,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv21_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape139, reshape140, reshape141), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape142 = R.call_tir(cls.reshape10, (lv21_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape143 = R.call_tir(cls.reshape11, (reshape142,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv678 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_out_proj_weight, reshape143, model_encoder_layers_17_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add123 = R.call_tir(cls.add4, (lv24, lv678), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm35 = R.call_tir(cls.layer_norm1, (add123, model_encoder_layers_17_final_layer_norm_weight, model_encoder_layers_17_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_17_fc1_weight, layer_norm35, model_encoder_layers_17_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv679 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_17_fc2_weight, lv113, model_encoder_layers_17_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv25 = R.call_tir(cls.fused_add4_maximum_minimum, (add123, lv679), out_sinfo=R.Tensor((batch_size, 
1500, 1280), dtype="float16")) + layer_norm36 = R.call_tir(cls.layer_norm1, (lv25, model_encoder_layers_18_self_attn_layer_norm_weight, model_encoder_layers_18_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv680 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_q_proj_weight, layer_norm36, model_encoder_layers_18_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape144 = R.call_tir(cls.reshape, (lv680,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv149 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_18_self_attn_k_proj_weight, layer_norm36), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape145 = R.call_tir(cls.reshape, (lv149,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv681 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_v_proj_weight, layer_norm36, model_encoder_layers_18_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape146 = R.call_tir(cls.reshape, (lv681,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape147 = R.call_tir(cls.reshape1, (reshape144,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape148 = R.call_tir(cls.reshape1, (reshape145,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape149 = R.call_tir(cls.reshape1, (reshape146,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv22_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape147, reshape148, reshape149), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape150 = R.call_tir(cls.reshape10, (lv22_1,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape151 = R.call_tir(cls.reshape11, (reshape150,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv682 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_out_proj_weight, reshape151, model_encoder_layers_18_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add130 = R.call_tir(cls.add4, (lv25, lv682), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm37 = R.call_tir(cls.layer_norm1, (add130, model_encoder_layers_18_final_layer_norm_weight, model_encoder_layers_18_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_18_fc1_weight, layer_norm37, model_encoder_layers_18_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv683 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_18_fc2_weight, lv114, model_encoder_layers_18_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv26 = R.call_tir(cls.fused_add4_maximum_minimum, (add130, lv683), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm38 = R.call_tir(cls.layer_norm1, (lv26, model_encoder_layers_19_self_attn_layer_norm_weight, model_encoder_layers_19_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv684 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_q_proj_weight, layer_norm38, model_encoder_layers_19_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape152 = R.call_tir(cls.reshape, (lv684,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv150 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_19_self_attn_k_proj_weight, layer_norm38), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape153 = R.call_tir(cls.reshape, (lv150,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv685 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_v_proj_weight, layer_norm38, model_encoder_layers_19_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape154 = R.call_tir(cls.reshape, (lv685,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape155 = R.call_tir(cls.reshape1, (reshape152,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape156 = R.call_tir(cls.reshape1, (reshape153,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape157 = R.call_tir(cls.reshape1, (reshape154,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv23_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape155, reshape156, reshape157), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape158 = R.call_tir(cls.reshape10, (lv23_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape159 = R.call_tir(cls.reshape11, (reshape158,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv686 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_out_proj_weight, reshape159, model_encoder_layers_19_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add137 = R.call_tir(cls.add4, (lv26, lv686), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm39 = R.call_tir(cls.layer_norm1, (add137, 
model_encoder_layers_19_final_layer_norm_weight, model_encoder_layers_19_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_19_fc1_weight, layer_norm39, model_encoder_layers_19_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv687 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_19_fc2_weight, lv115, model_encoder_layers_19_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv27 = R.call_tir(cls.fused_add4_maximum_minimum, (add137, lv687), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm40 = R.call_tir(cls.layer_norm1, (lv27, model_encoder_layers_20_self_attn_layer_norm_weight, model_encoder_layers_20_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv688 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_q_proj_weight, layer_norm40, model_encoder_layers_20_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape160 = R.call_tir(cls.reshape, (lv688,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv151 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_20_self_attn_k_proj_weight, layer_norm40), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape161 = R.call_tir(cls.reshape, (lv151,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv689 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_v_proj_weight, layer_norm40, model_encoder_layers_20_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape162 = R.call_tir(cls.reshape, (lv689,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape163 = R.call_tir(cls.reshape1, (reshape160,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape164 = R.call_tir(cls.reshape1, (reshape161,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape165 = R.call_tir(cls.reshape1, (reshape162,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv24_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape163, reshape164, reshape165), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape166 = R.call_tir(cls.reshape10, (lv24_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape167 = R.call_tir(cls.reshape11, (reshape166,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv690 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_out_proj_weight, reshape167, model_encoder_layers_20_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add144 = R.call_tir(cls.add4, (lv27, lv690), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm41 = R.call_tir(cls.layer_norm1, (add144, model_encoder_layers_20_final_layer_norm_weight, model_encoder_layers_20_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_20_fc1_weight, layer_norm41, model_encoder_layers_20_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv691 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_20_fc2_weight, lv116, model_encoder_layers_20_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv28 = 
R.call_tir(cls.fused_add4_maximum_minimum, (add144, lv691), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm42 = R.call_tir(cls.layer_norm1, (lv28, model_encoder_layers_21_self_attn_layer_norm_weight, model_encoder_layers_21_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv692 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_q_proj_weight, layer_norm42, model_encoder_layers_21_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape168 = R.call_tir(cls.reshape, (lv692,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv152 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_21_self_attn_k_proj_weight, layer_norm42), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape169 = R.call_tir(cls.reshape, (lv152,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv693 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_v_proj_weight, layer_norm42, model_encoder_layers_21_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape170 = R.call_tir(cls.reshape, (lv693,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape171 = R.call_tir(cls.reshape1, (reshape168,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape172 = R.call_tir(cls.reshape1, (reshape169,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape173 = R.call_tir(cls.reshape1, (reshape170,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv25_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape171, reshape172, reshape173), out_sinfo=R.Tensor((batch_size * 1500, 20, 
64), dtype="float16")) + reshape174 = R.call_tir(cls.reshape10, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape175 = R.call_tir(cls.reshape11, (reshape174,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv694 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_out_proj_weight, reshape175, model_encoder_layers_21_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add151 = R.call_tir(cls.add4, (lv28, lv694), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm43 = R.call_tir(cls.layer_norm1, (add151, model_encoder_layers_21_final_layer_norm_weight, model_encoder_layers_21_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_21_fc1_weight, layer_norm43, model_encoder_layers_21_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv695 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_21_fc2_weight, lv117, model_encoder_layers_21_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv29 = R.call_tir(cls.fused_add4_maximum_minimum, (add151, lv695), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm44 = R.call_tir(cls.layer_norm1, (lv29, model_encoder_layers_22_self_attn_layer_norm_weight, model_encoder_layers_22_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv696 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_q_proj_weight, layer_norm44, model_encoder_layers_22_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape176 = R.call_tir(cls.reshape, (lv696,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv153 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_22_self_attn_k_proj_weight, layer_norm44), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape177 = R.call_tir(cls.reshape, (lv153,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv697 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_v_proj_weight, layer_norm44, model_encoder_layers_22_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape178 = R.call_tir(cls.reshape, (lv697,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape179 = R.call_tir(cls.reshape1, (reshape176,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape180 = R.call_tir(cls.reshape1, (reshape177,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape181 = R.call_tir(cls.reshape1, (reshape178,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv26_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape179, reshape180, reshape181), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape182 = R.call_tir(cls.reshape10, (lv26_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape183 = R.call_tir(cls.reshape11, (reshape182,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv698 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_out_proj_weight, reshape183, model_encoder_layers_22_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add158 = R.call_tir(cls.add4, (lv29, lv698), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm45 = 
R.call_tir(cls.layer_norm1, (add158, model_encoder_layers_22_final_layer_norm_weight, model_encoder_layers_22_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_22_fc1_weight, layer_norm45, model_encoder_layers_22_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv699 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_22_fc2_weight, lv118, model_encoder_layers_22_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv30 = R.call_tir(cls.fused_add4_maximum_minimum, (add158, lv699), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm46 = R.call_tir(cls.layer_norm1, (lv30, model_encoder_layers_23_self_attn_layer_norm_weight, model_encoder_layers_23_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv700 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_q_proj_weight, layer_norm46, model_encoder_layers_23_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape184 = R.call_tir(cls.reshape, (lv700,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv154 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_23_self_attn_k_proj_weight, layer_norm46), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape185 = R.call_tir(cls.reshape, (lv154,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv701 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_v_proj_weight, layer_norm46, model_encoder_layers_23_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape186 = 
R.call_tir(cls.reshape, (lv701,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape187 = R.call_tir(cls.reshape1, (reshape184,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape188 = R.call_tir(cls.reshape1, (reshape185,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape189 = R.call_tir(cls.reshape1, (reshape186,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv27_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape187, reshape188, reshape189), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape190 = R.call_tir(cls.reshape10, (lv27_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape191 = R.call_tir(cls.reshape11, (reshape190,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv702 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_out_proj_weight, reshape191, model_encoder_layers_23_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add165 = R.call_tir(cls.add4, (lv30, lv702), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm47 = R.call_tir(cls.layer_norm1, (add165, model_encoder_layers_23_final_layer_norm_weight, model_encoder_layers_23_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_23_fc1_weight, layer_norm47, model_encoder_layers_23_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv703 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_23_fc2_weight, lv119, model_encoder_layers_23_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) + lv31 = R.call_tir(cls.fused_add4_maximum_minimum, (add165, lv703), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm48 = R.call_tir(cls.layer_norm1, (lv31, model_encoder_layers_24_self_attn_layer_norm_weight, model_encoder_layers_24_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv704 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_q_proj_weight, layer_norm48, model_encoder_layers_24_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape192 = R.call_tir(cls.reshape, (lv704,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv155 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_24_self_attn_k_proj_weight, layer_norm48), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape193 = R.call_tir(cls.reshape, (lv155,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv705 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_v_proj_weight, layer_norm48, model_encoder_layers_24_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape194 = R.call_tir(cls.reshape, (lv705,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape195 = R.call_tir(cls.reshape1, (reshape192,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape196 = R.call_tir(cls.reshape1, (reshape193,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape197 = R.call_tir(cls.reshape1, (reshape194,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv28_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape195, reshape196, reshape197), 
out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape198 = R.call_tir(cls.reshape10, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape199 = R.call_tir(cls.reshape11, (reshape198,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv706 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_out_proj_weight, reshape199, model_encoder_layers_24_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add172 = R.call_tir(cls.add4, (lv31, lv706), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm49 = R.call_tir(cls.layer_norm1, (add172, model_encoder_layers_24_final_layer_norm_weight, model_encoder_layers_24_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv120 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_24_fc1_weight, layer_norm49, model_encoder_layers_24_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv707 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_24_fc2_weight, lv120, model_encoder_layers_24_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv32 = R.call_tir(cls.fused_add4_maximum_minimum, (add172, lv707), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm50 = R.call_tir(cls.layer_norm1, (lv32, model_encoder_layers_25_self_attn_layer_norm_weight, model_encoder_layers_25_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv708 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_q_proj_weight, layer_norm50, model_encoder_layers_25_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape200 = 
R.call_tir(cls.reshape, (lv708,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv156 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_25_self_attn_k_proj_weight, layer_norm50), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape201 = R.call_tir(cls.reshape, (lv156,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv709 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_v_proj_weight, layer_norm50, model_encoder_layers_25_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape202 = R.call_tir(cls.reshape, (lv709,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape203 = R.call_tir(cls.reshape1, (reshape200,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape204 = R.call_tir(cls.reshape1, (reshape201,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape205 = R.call_tir(cls.reshape1, (reshape202,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv29_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape203, reshape204, reshape205), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape206 = R.call_tir(cls.reshape10, (lv29_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape207 = R.call_tir(cls.reshape11, (reshape206,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv710 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_out_proj_weight, reshape207, model_encoder_layers_25_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add179 = R.call_tir(cls.add4, (lv32, lv710), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) + layer_norm51 = R.call_tir(cls.layer_norm1, (add179, model_encoder_layers_25_final_layer_norm_weight, model_encoder_layers_25_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_25_fc1_weight, layer_norm51, model_encoder_layers_25_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv711 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_25_fc2_weight, lv121, model_encoder_layers_25_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv33 = R.call_tir(cls.fused_add4_maximum_minimum, (add179, lv711), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm52 = R.call_tir(cls.layer_norm1, (lv33, model_encoder_layers_26_self_attn_layer_norm_weight, model_encoder_layers_26_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv712 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_q_proj_weight, layer_norm52, model_encoder_layers_26_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape208 = R.call_tir(cls.reshape, (lv712,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv157 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_26_self_attn_k_proj_weight, layer_norm52), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape209 = R.call_tir(cls.reshape, (lv157,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv713 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_v_proj_weight, layer_norm52, model_encoder_layers_26_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) + reshape210 = R.call_tir(cls.reshape, (lv713,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape211 = R.call_tir(cls.reshape1, (reshape208,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape212 = R.call_tir(cls.reshape1, (reshape209,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape213 = R.call_tir(cls.reshape1, (reshape210,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv30_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape211, reshape212, reshape213), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape214 = R.call_tir(cls.reshape10, (lv30_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape215 = R.call_tir(cls.reshape11, (reshape214,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv714 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_out_proj_weight, reshape215, model_encoder_layers_26_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add186 = R.call_tir(cls.add4, (lv33, lv714), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm53 = R.call_tir(cls.layer_norm1, (add186, model_encoder_layers_26_final_layer_norm_weight, model_encoder_layers_26_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_26_fc1_weight, layer_norm53, model_encoder_layers_26_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv715 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_26_fc2_weight, lv122, model_encoder_layers_26_fc2_bias), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv34 = R.call_tir(cls.fused_add4_maximum_minimum, (add186, lv715), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm54 = R.call_tir(cls.layer_norm1, (lv34, model_encoder_layers_27_self_attn_layer_norm_weight, model_encoder_layers_27_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv716 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_q_proj_weight, layer_norm54, model_encoder_layers_27_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape216 = R.call_tir(cls.reshape, (lv716,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv158 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_27_self_attn_k_proj_weight, layer_norm54), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape217 = R.call_tir(cls.reshape, (lv158,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv717 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_v_proj_weight, layer_norm54, model_encoder_layers_27_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape218 = R.call_tir(cls.reshape, (lv717,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape219 = R.call_tir(cls.reshape1, (reshape216,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape220 = R.call_tir(cls.reshape1, (reshape217,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape221 = R.call_tir(cls.reshape1, (reshape218,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv31_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), 
reshape219, reshape220, reshape221), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape222 = R.call_tir(cls.reshape10, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape223 = R.call_tir(cls.reshape11, (reshape222,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv718 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_out_proj_weight, reshape223, model_encoder_layers_27_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add193 = R.call_tir(cls.add4, (lv34, lv718), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm55 = R.call_tir(cls.layer_norm1, (add193, model_encoder_layers_27_final_layer_norm_weight, model_encoder_layers_27_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_27_fc1_weight, layer_norm55, model_encoder_layers_27_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv719 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_27_fc2_weight, lv123, model_encoder_layers_27_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv35 = R.call_tir(cls.fused_add4_maximum_minimum, (add193, lv719), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm56 = R.call_tir(cls.layer_norm1, (lv35, model_encoder_layers_28_self_attn_layer_norm_weight, model_encoder_layers_28_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv720 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_q_proj_weight, layer_norm56, model_encoder_layers_28_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) + reshape224 = R.call_tir(cls.reshape, (lv720,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv159 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_28_self_attn_k_proj_weight, layer_norm56), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape225 = R.call_tir(cls.reshape, (lv159,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv721 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_v_proj_weight, layer_norm56, model_encoder_layers_28_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape226 = R.call_tir(cls.reshape, (lv721,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape227 = R.call_tir(cls.reshape1, (reshape224,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape228 = R.call_tir(cls.reshape1, (reshape225,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape229 = R.call_tir(cls.reshape1, (reshape226,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv32_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape227, reshape228, reshape229), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape230 = R.call_tir(cls.reshape10, (lv32_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape231 = R.call_tir(cls.reshape11, (reshape230,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv722 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_out_proj_weight, reshape231, model_encoder_layers_28_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add200 = R.call_tir(cls.add4, (lv35, lv722), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm57 = R.call_tir(cls.layer_norm1, (add200, model_encoder_layers_28_final_layer_norm_weight, model_encoder_layers_28_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv124 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_28_fc1_weight, layer_norm57, model_encoder_layers_28_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv723 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_28_fc2_weight, lv124, model_encoder_layers_28_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv36 = R.call_tir(cls.fused_add4_maximum_minimum, (add200, lv723), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm58 = R.call_tir(cls.layer_norm1, (lv36, model_encoder_layers_29_self_attn_layer_norm_weight, model_encoder_layers_29_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv724 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_q_proj_weight, layer_norm58, model_encoder_layers_29_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape232 = R.call_tir(cls.reshape, (lv724,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv160 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_29_self_attn_k_proj_weight, layer_norm58), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape233 = R.call_tir(cls.reshape, (lv160,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv725 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_v_proj_weight, layer_norm58, model_encoder_layers_29_self_attn_v_proj_bias), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape234 = R.call_tir(cls.reshape, (lv725,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape235 = R.call_tir(cls.reshape1, (reshape232,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape236 = R.call_tir(cls.reshape1, (reshape233,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape237 = R.call_tir(cls.reshape1, (reshape234,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv33_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape235, reshape236, reshape237), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape238 = R.call_tir(cls.reshape10, (lv33_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape239 = R.call_tir(cls.reshape11, (reshape238,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv726 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_out_proj_weight, reshape239, model_encoder_layers_29_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add207 = R.call_tir(cls.add4, (lv36, lv726), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm59 = R.call_tir(cls.layer_norm1, (add207, model_encoder_layers_29_final_layer_norm_weight, model_encoder_layers_29_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_29_fc1_weight, layer_norm59, model_encoder_layers_29_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv727 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_29_fc2_weight, lv125, 
model_encoder_layers_29_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv37 = R.call_tir(cls.fused_add4_maximum_minimum, (add207, lv727), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm60 = R.call_tir(cls.layer_norm1, (lv37, model_encoder_layers_30_self_attn_layer_norm_weight, model_encoder_layers_30_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv728 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_q_proj_weight, layer_norm60, model_encoder_layers_30_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape240 = R.call_tir(cls.reshape, (lv728,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv161 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_30_self_attn_k_proj_weight, layer_norm60), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape241 = R.call_tir(cls.reshape, (lv161,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv729 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_v_proj_weight, layer_norm60, model_encoder_layers_30_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape242 = R.call_tir(cls.reshape, (lv729,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape243 = R.call_tir(cls.reshape1, (reshape240,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape244 = R.call_tir(cls.reshape1, (reshape241,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape245 = R.call_tir(cls.reshape1, (reshape242,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv34_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(30), 
R.prim_value(T.float32(1)), reshape243, reshape244, reshape245), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape246 = R.call_tir(cls.reshape10, (lv34_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape247 = R.call_tir(cls.reshape11, (reshape246,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv730 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_out_proj_weight, reshape247, model_encoder_layers_30_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add214 = R.call_tir(cls.add4, (lv37, lv730), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm61 = R.call_tir(cls.layer_norm1, (add214, model_encoder_layers_30_final_layer_norm_weight, model_encoder_layers_30_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_30_fc1_weight, layer_norm61, model_encoder_layers_30_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv731 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_30_fc2_weight, lv126, model_encoder_layers_30_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv38 = R.call_tir(cls.fused_add4_maximum_minimum, (add214, lv731), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm62 = R.call_tir(cls.layer_norm1, (lv38, model_encoder_layers_31_self_attn_layer_norm_weight, model_encoder_layers_31_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv732 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_q_proj_weight, layer_norm62, model_encoder_layers_31_self_attn_q_proj_bias), 
out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape248 = R.call_tir(cls.reshape, (lv732,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv162 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_31_self_attn_k_proj_weight, layer_norm62), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape249 = R.call_tir(cls.reshape, (lv162,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + lv733 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_v_proj_weight, layer_norm62, model_encoder_layers_31_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + reshape250 = R.call_tir(cls.reshape, (lv733,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape251 = R.call_tir(cls.reshape1, (reshape248,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape252 = R.call_tir(cls.reshape1, (reshape249,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape253 = R.call_tir(cls.reshape1, (reshape250,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + lv35_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape251, reshape252, reshape253), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) + reshape254 = R.call_tir(cls.reshape10, (lv35_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) + reshape255 = R.call_tir(cls.reshape11, (reshape254,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv734 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_out_proj_weight, reshape255, model_encoder_layers_31_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + add221 = 
R.call_tir(cls.add4, (lv38, lv734), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + layer_norm63 = R.call_tir(cls.layer_norm1, (add221, model_encoder_layers_31_final_layer_norm_weight, model_encoder_layers_31_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv127 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_31_fc1_weight, layer_norm63, model_encoder_layers_31_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) + lv735 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_31_fc2_weight, lv127, model_encoder_layers_31_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + lv39 = R.call_tir(cls.fused_add4_maximum_minimum, (add221, lv735), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + gv = R.call_tir(cls.layer_norm1, (lv39, model_encoder_layer_norm_weight, model_encoder_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) + R.output(gv) + return gv + + @R.function + def batch_prefill(input_ids: R.Tensor((1, "seq_len"), dtype="int32"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, "batch_size", 51866), dtype="float32"): + batch_size = T.int64() + seq_len = T.int64() + R.func_attr({"num_input": 3, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_decoder_embed_tokens_weight2: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] + model_decoder_embed_positions_weight2: R.Tensor((448, 1280), dtype="float16") = packed_params[488] + model_decoder_layers_0_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] + 
model_decoder_layers_0_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] + model_decoder_layers_0_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[491] + model_decoder_layers_0_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] + model_decoder_layers_0_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[493] + model_decoder_layers_0_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] + model_decoder_layers_0_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[495] + model_decoder_layers_0_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[496] + model_decoder_layers_0_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[497] + model_decoder_layers_0_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] + model_decoder_layers_0_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[502] + model_decoder_layers_0_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] + model_decoder_layers_0_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[504] + model_decoder_layers_0_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[505] + model_decoder_layers_0_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[506] + model_decoder_layers_0_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] + model_decoder_layers_0_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[508] + model_decoder_layers_0_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] + model_decoder_layers_0_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[510] + model_decoder_layers_0_final_layer_norm_weight2: R.Tensor((1280,), 
dtype="float16") = packed_params[511] + model_decoder_layers_0_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[512] + model_decoder_layers_1_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] + model_decoder_layers_1_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] + model_decoder_layers_1_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[515] + model_decoder_layers_1_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] + model_decoder_layers_1_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[517] + model_decoder_layers_1_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] + model_decoder_layers_1_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[519] + model_decoder_layers_1_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[520] + model_decoder_layers_1_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[521] + model_decoder_layers_1_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] + model_decoder_layers_1_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[526] + model_decoder_layers_1_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] + model_decoder_layers_1_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[528] + model_decoder_layers_1_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[529] + model_decoder_layers_1_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[530] + model_decoder_layers_1_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] + model_decoder_layers_1_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[532] + 
model_decoder_layers_1_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] + model_decoder_layers_1_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[534] + model_decoder_layers_1_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[535] + model_decoder_layers_1_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[536] + model_decoder_layers_2_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] + model_decoder_layers_2_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] + model_decoder_layers_2_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[539] + model_decoder_layers_2_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] + model_decoder_layers_2_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[541] + model_decoder_layers_2_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] + model_decoder_layers_2_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[543] + model_decoder_layers_2_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[544] + model_decoder_layers_2_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[545] + model_decoder_layers_2_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] + model_decoder_layers_2_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[550] + model_decoder_layers_2_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] + model_decoder_layers_2_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[552] + model_decoder_layers_2_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[553] + 
model_decoder_layers_2_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[554] + model_decoder_layers_2_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] + model_decoder_layers_2_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[556] + model_decoder_layers_2_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] + model_decoder_layers_2_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[558] + model_decoder_layers_2_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[559] + model_decoder_layers_2_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[560] + model_decoder_layers_3_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] + model_decoder_layers_3_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] + model_decoder_layers_3_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[563] + model_decoder_layers_3_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] + model_decoder_layers_3_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[565] + model_decoder_layers_3_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] + model_decoder_layers_3_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[567] + model_decoder_layers_3_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[568] + model_decoder_layers_3_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[569] + model_decoder_layers_3_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] + model_decoder_layers_3_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[574] + model_decoder_layers_3_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[575] + model_decoder_layers_3_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[576] + model_decoder_layers_3_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[577] + model_decoder_layers_3_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[578] + model_decoder_layers_3_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] + model_decoder_layers_3_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[580] + model_decoder_layers_3_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] + model_decoder_layers_3_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[582] + model_decoder_layers_3_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[583] + model_decoder_layers_3_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[584] + model_decoder_layers_4_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] + model_decoder_layers_4_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] + model_decoder_layers_4_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[587] + model_decoder_layers_4_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] + model_decoder_layers_4_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[589] + model_decoder_layers_4_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] + model_decoder_layers_4_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[591] + model_decoder_layers_4_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[592] + model_decoder_layers_4_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[593] + 
model_decoder_layers_4_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] + model_decoder_layers_4_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[598] + model_decoder_layers_4_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] + model_decoder_layers_4_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[600] + model_decoder_layers_4_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[601] + model_decoder_layers_4_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[602] + model_decoder_layers_4_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] + model_decoder_layers_4_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[604] + model_decoder_layers_4_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] + model_decoder_layers_4_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[606] + model_decoder_layers_4_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[607] + model_decoder_layers_4_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[608] + model_decoder_layers_5_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] + model_decoder_layers_5_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] + model_decoder_layers_5_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[611] + model_decoder_layers_5_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] + model_decoder_layers_5_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[613] + model_decoder_layers_5_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] + model_decoder_layers_5_self_attn_out_proj_bias2: R.Tensor((1280,), 
dtype="float16") = packed_params[615] + model_decoder_layers_5_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[616] + model_decoder_layers_5_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[617] + model_decoder_layers_5_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[621] + model_decoder_layers_5_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[622] + model_decoder_layers_5_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] + model_decoder_layers_5_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[624] + model_decoder_layers_5_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[625] + model_decoder_layers_5_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[626] + model_decoder_layers_5_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] + model_decoder_layers_5_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[628] + model_decoder_layers_5_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] + model_decoder_layers_5_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[630] + model_decoder_layers_5_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[631] + model_decoder_layers_5_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[632] + model_decoder_layers_6_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] + model_decoder_layers_6_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] + model_decoder_layers_6_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[635] + model_decoder_layers_6_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[636] + 
model_decoder_layers_6_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[637] + model_decoder_layers_6_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] + model_decoder_layers_6_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[639] + model_decoder_layers_6_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[640] + model_decoder_layers_6_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[641] + model_decoder_layers_6_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] + model_decoder_layers_6_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[646] + model_decoder_layers_6_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] + model_decoder_layers_6_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[648] + model_decoder_layers_6_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[649] + model_decoder_layers_6_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[650] + model_decoder_layers_6_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] + model_decoder_layers_6_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[652] + model_decoder_layers_6_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] + model_decoder_layers_6_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[654] + model_decoder_layers_6_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[655] + model_decoder_layers_6_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[656] + model_decoder_layers_7_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] + model_decoder_layers_7_self_attn_v_proj_weight2: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[658] + model_decoder_layers_7_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[659] + model_decoder_layers_7_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] + model_decoder_layers_7_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[661] + model_decoder_layers_7_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] + model_decoder_layers_7_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[663] + model_decoder_layers_7_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[664] + model_decoder_layers_7_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[665] + model_decoder_layers_7_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] + model_decoder_layers_7_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[670] + model_decoder_layers_7_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] + model_decoder_layers_7_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[672] + model_decoder_layers_7_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[673] + model_decoder_layers_7_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[674] + model_decoder_layers_7_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] + model_decoder_layers_7_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[676] + model_decoder_layers_7_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] + model_decoder_layers_7_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[678] + model_decoder_layers_7_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[679] + 
model_decoder_layers_7_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[680] + model_decoder_layers_8_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] + model_decoder_layers_8_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[682] + model_decoder_layers_8_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[683] + model_decoder_layers_8_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] + model_decoder_layers_8_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[685] + model_decoder_layers_8_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] + model_decoder_layers_8_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[687] + model_decoder_layers_8_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[688] + model_decoder_layers_8_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[689] + model_decoder_layers_8_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] + model_decoder_layers_8_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[694] + model_decoder_layers_8_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] + model_decoder_layers_8_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[696] + model_decoder_layers_8_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[697] + model_decoder_layers_8_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[698] + model_decoder_layers_8_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] + model_decoder_layers_8_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[700] + model_decoder_layers_8_fc2_weight2: 
R.Tensor((1280, 5120), dtype="float16") = packed_params[701] + model_decoder_layers_8_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[702] + model_decoder_layers_8_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[703] + model_decoder_layers_8_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[704] + model_decoder_layers_9_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] + model_decoder_layers_9_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] + model_decoder_layers_9_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[707] + model_decoder_layers_9_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[708] + model_decoder_layers_9_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[709] + model_decoder_layers_9_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] + model_decoder_layers_9_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[711] + model_decoder_layers_9_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[712] + model_decoder_layers_9_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[713] + model_decoder_layers_9_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] + model_decoder_layers_9_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[718] + model_decoder_layers_9_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] + model_decoder_layers_9_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[720] + model_decoder_layers_9_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[721] + model_decoder_layers_9_encoder_attn_layer_norm_bias2: R.Tensor((1280,), 
dtype="float16") = packed_params[722] + model_decoder_layers_9_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] + model_decoder_layers_9_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[724] + model_decoder_layers_9_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] + model_decoder_layers_9_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[726] + model_decoder_layers_9_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[727] + model_decoder_layers_9_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[728] + model_decoder_layers_10_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] + model_decoder_layers_10_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] + model_decoder_layers_10_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[731] + model_decoder_layers_10_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] + model_decoder_layers_10_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[733] + model_decoder_layers_10_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] + model_decoder_layers_10_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[735] + model_decoder_layers_10_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[736] + model_decoder_layers_10_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[737] + model_decoder_layers_10_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] + model_decoder_layers_10_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[742] + model_decoder_layers_10_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[743] + 
model_decoder_layers_10_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[744] + model_decoder_layers_10_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[745] + model_decoder_layers_10_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[746] + model_decoder_layers_10_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] + model_decoder_layers_10_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[748] + model_decoder_layers_10_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] + model_decoder_layers_10_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[750] + model_decoder_layers_10_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[751] + model_decoder_layers_10_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[752] + model_decoder_layers_11_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] + model_decoder_layers_11_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] + model_decoder_layers_11_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[755] + model_decoder_layers_11_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] + model_decoder_layers_11_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[757] + model_decoder_layers_11_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[758] + model_decoder_layers_11_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[759] + model_decoder_layers_11_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[760] + model_decoder_layers_11_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[761] + model_decoder_layers_11_encoder_attn_q_proj_weight2: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[765] + model_decoder_layers_11_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[766] + model_decoder_layers_11_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] + model_decoder_layers_11_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[768] + model_decoder_layers_11_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[769] + model_decoder_layers_11_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[770] + model_decoder_layers_11_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] + model_decoder_layers_11_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[772] + model_decoder_layers_11_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] + model_decoder_layers_11_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[774] + model_decoder_layers_11_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[775] + model_decoder_layers_11_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[776] + model_decoder_layers_12_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] + model_decoder_layers_12_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] + model_decoder_layers_12_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[779] + model_decoder_layers_12_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] + model_decoder_layers_12_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[781] + model_decoder_layers_12_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] + model_decoder_layers_12_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = 
packed_params[783] + model_decoder_layers_12_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[784] + model_decoder_layers_12_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[785] + model_decoder_layers_12_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] + model_decoder_layers_12_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[790] + model_decoder_layers_12_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] + model_decoder_layers_12_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[792] + model_decoder_layers_12_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[793] + model_decoder_layers_12_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[794] + model_decoder_layers_12_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] + model_decoder_layers_12_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[796] + model_decoder_layers_12_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] + model_decoder_layers_12_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[798] + model_decoder_layers_12_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[799] + model_decoder_layers_12_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[800] + model_decoder_layers_13_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[801] + model_decoder_layers_13_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[802] + model_decoder_layers_13_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[803] + model_decoder_layers_13_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[804] + 
model_decoder_layers_13_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[805] + model_decoder_layers_13_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[806] + model_decoder_layers_13_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[807] + model_decoder_layers_13_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[808] + model_decoder_layers_13_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[809] + model_decoder_layers_13_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[813] + model_decoder_layers_13_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[814] + model_decoder_layers_13_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[815] + model_decoder_layers_13_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[816] + model_decoder_layers_13_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[817] + model_decoder_layers_13_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[818] + model_decoder_layers_13_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[819] + model_decoder_layers_13_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[820] + model_decoder_layers_13_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[821] + model_decoder_layers_13_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[822] + model_decoder_layers_13_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[823] + model_decoder_layers_13_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[824] + model_decoder_layers_14_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[825] + model_decoder_layers_14_self_attn_v_proj_weight2: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[826] + model_decoder_layers_14_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[827] + model_decoder_layers_14_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[828] + model_decoder_layers_14_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[829] + model_decoder_layers_14_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[830] + model_decoder_layers_14_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[831] + model_decoder_layers_14_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[832] + model_decoder_layers_14_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[833] + model_decoder_layers_14_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[837] + model_decoder_layers_14_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[838] + model_decoder_layers_14_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[839] + model_decoder_layers_14_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[840] + model_decoder_layers_14_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[841] + model_decoder_layers_14_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[842] + model_decoder_layers_14_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[843] + model_decoder_layers_14_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[844] + model_decoder_layers_14_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[845] + model_decoder_layers_14_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[846] + model_decoder_layers_14_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = 
packed_params[847] + model_decoder_layers_14_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[848] + model_decoder_layers_15_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[849] + model_decoder_layers_15_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[850] + model_decoder_layers_15_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[851] + model_decoder_layers_15_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[852] + model_decoder_layers_15_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[853] + model_decoder_layers_15_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[854] + model_decoder_layers_15_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[855] + model_decoder_layers_15_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[856] + model_decoder_layers_15_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[857] + model_decoder_layers_15_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[861] + model_decoder_layers_15_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[862] + model_decoder_layers_15_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[863] + model_decoder_layers_15_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[864] + model_decoder_layers_15_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[865] + model_decoder_layers_15_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[866] + model_decoder_layers_15_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[867] + model_decoder_layers_15_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[868] + 
model_decoder_layers_15_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[869] + model_decoder_layers_15_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[870] + model_decoder_layers_15_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[871] + model_decoder_layers_15_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[872] + model_decoder_layers_16_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[873] + model_decoder_layers_16_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[874] + model_decoder_layers_16_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[875] + model_decoder_layers_16_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[876] + model_decoder_layers_16_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[877] + model_decoder_layers_16_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[878] + model_decoder_layers_16_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[879] + model_decoder_layers_16_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[880] + model_decoder_layers_16_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[881] + model_decoder_layers_16_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[885] + model_decoder_layers_16_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[886] + model_decoder_layers_16_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[887] + model_decoder_layers_16_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[888] + model_decoder_layers_16_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[889] + 
model_decoder_layers_16_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[890] + model_decoder_layers_16_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[891] + model_decoder_layers_16_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[892] + model_decoder_layers_16_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[893] + model_decoder_layers_16_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[894] + model_decoder_layers_16_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[895] + model_decoder_layers_16_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[896] + model_decoder_layers_17_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[897] + model_decoder_layers_17_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[898] + model_decoder_layers_17_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[899] + model_decoder_layers_17_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[900] + model_decoder_layers_17_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[901] + model_decoder_layers_17_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[902] + model_decoder_layers_17_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[903] + model_decoder_layers_17_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[904] + model_decoder_layers_17_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[905] + model_decoder_layers_17_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[909] + model_decoder_layers_17_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[910] + model_decoder_layers_17_encoder_attn_out_proj_weight2: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[911] + model_decoder_layers_17_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[912] + model_decoder_layers_17_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[913] + model_decoder_layers_17_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[914] + model_decoder_layers_17_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[915] + model_decoder_layers_17_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[916] + model_decoder_layers_17_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[917] + model_decoder_layers_17_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[918] + model_decoder_layers_17_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[919] + model_decoder_layers_17_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[920] + model_decoder_layers_18_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[921] + model_decoder_layers_18_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[922] + model_decoder_layers_18_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[923] + model_decoder_layers_18_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[924] + model_decoder_layers_18_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[925] + model_decoder_layers_18_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[926] + model_decoder_layers_18_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[927] + model_decoder_layers_18_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[928] + model_decoder_layers_18_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[929] + 
model_decoder_layers_18_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[933] + model_decoder_layers_18_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[934] + model_decoder_layers_18_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[935] + model_decoder_layers_18_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[936] + model_decoder_layers_18_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[937] + model_decoder_layers_18_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[938] + model_decoder_layers_18_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[939] + model_decoder_layers_18_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[940] + model_decoder_layers_18_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[941] + model_decoder_layers_18_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[942] + model_decoder_layers_18_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[943] + model_decoder_layers_18_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[944] + model_decoder_layers_19_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[945] + model_decoder_layers_19_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[946] + model_decoder_layers_19_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[947] + model_decoder_layers_19_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[948] + model_decoder_layers_19_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[949] + model_decoder_layers_19_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[950] + model_decoder_layers_19_self_attn_out_proj_bias2: 
R.Tensor((1280,), dtype="float16") = packed_params[951] + model_decoder_layers_19_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[952] + model_decoder_layers_19_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[953] + model_decoder_layers_19_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[957] + model_decoder_layers_19_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[958] + model_decoder_layers_19_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[959] + model_decoder_layers_19_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[960] + model_decoder_layers_19_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[961] + model_decoder_layers_19_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[962] + model_decoder_layers_19_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[963] + model_decoder_layers_19_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[964] + model_decoder_layers_19_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[965] + model_decoder_layers_19_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[966] + model_decoder_layers_19_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[967] + model_decoder_layers_19_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[968] + model_decoder_layers_20_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[969] + model_decoder_layers_20_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[970] + model_decoder_layers_20_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[971] + model_decoder_layers_20_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[972] + model_decoder_layers_20_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[973] + model_decoder_layers_20_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[974] + model_decoder_layers_20_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[975] + model_decoder_layers_20_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[976] + model_decoder_layers_20_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[977] + model_decoder_layers_20_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[981] + model_decoder_layers_20_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[982] + model_decoder_layers_20_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[983] + model_decoder_layers_20_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[984] + model_decoder_layers_20_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[985] + model_decoder_layers_20_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[986] + model_decoder_layers_20_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[987] + model_decoder_layers_20_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[988] + model_decoder_layers_20_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[989] + model_decoder_layers_20_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[990] + model_decoder_layers_20_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[991] + model_decoder_layers_20_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[992] + model_decoder_layers_21_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[993] + 
model_decoder_layers_21_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[994] + model_decoder_layers_21_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[995] + model_decoder_layers_21_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[996] + model_decoder_layers_21_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[997] + model_decoder_layers_21_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[998] + model_decoder_layers_21_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[999] + model_decoder_layers_21_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1000] + model_decoder_layers_21_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1001] + model_decoder_layers_21_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005] + model_decoder_layers_21_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1006] + model_decoder_layers_21_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007] + model_decoder_layers_21_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1008] + model_decoder_layers_21_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1009] + model_decoder_layers_21_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1010] + model_decoder_layers_21_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011] + model_decoder_layers_21_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1012] + model_decoder_layers_21_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013] + model_decoder_layers_21_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1014] + 
model_decoder_layers_21_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1015] + model_decoder_layers_21_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1016] + model_decoder_layers_22_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017] + model_decoder_layers_22_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018] + model_decoder_layers_22_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1019] + model_decoder_layers_22_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020] + model_decoder_layers_22_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1021] + model_decoder_layers_22_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022] + model_decoder_layers_22_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1023] + model_decoder_layers_22_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1024] + model_decoder_layers_22_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1025] + model_decoder_layers_22_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029] + model_decoder_layers_22_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1030] + model_decoder_layers_22_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031] + model_decoder_layers_22_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1032] + model_decoder_layers_22_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1033] + model_decoder_layers_22_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1034] + model_decoder_layers_22_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = 
packed_params[1035] + model_decoder_layers_22_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1036] + model_decoder_layers_22_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037] + model_decoder_layers_22_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1038] + model_decoder_layers_22_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1039] + model_decoder_layers_22_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1040] + model_decoder_layers_23_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041] + model_decoder_layers_23_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042] + model_decoder_layers_23_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1043] + model_decoder_layers_23_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044] + model_decoder_layers_23_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1045] + model_decoder_layers_23_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046] + model_decoder_layers_23_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1047] + model_decoder_layers_23_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1048] + model_decoder_layers_23_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1049] + model_decoder_layers_23_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053] + model_decoder_layers_23_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1054] + model_decoder_layers_23_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055] + model_decoder_layers_23_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1056] + 
model_decoder_layers_23_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1057] + model_decoder_layers_23_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1058] + model_decoder_layers_23_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059] + model_decoder_layers_23_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1060] + model_decoder_layers_23_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061] + model_decoder_layers_23_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1062] + model_decoder_layers_23_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1063] + model_decoder_layers_23_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1064] + model_decoder_layers_24_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065] + model_decoder_layers_24_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066] + model_decoder_layers_24_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1067] + model_decoder_layers_24_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068] + model_decoder_layers_24_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1069] + model_decoder_layers_24_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070] + model_decoder_layers_24_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1071] + model_decoder_layers_24_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1072] + model_decoder_layers_24_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1073] + model_decoder_layers_24_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077] + 
model_decoder_layers_24_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1078] + model_decoder_layers_24_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079] + model_decoder_layers_24_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1080] + model_decoder_layers_24_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1081] + model_decoder_layers_24_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1082] + model_decoder_layers_24_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083] + model_decoder_layers_24_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1084] + model_decoder_layers_24_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085] + model_decoder_layers_24_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1086] + model_decoder_layers_24_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1087] + model_decoder_layers_24_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1088] + model_decoder_layers_25_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089] + model_decoder_layers_25_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090] + model_decoder_layers_25_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1091] + model_decoder_layers_25_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092] + model_decoder_layers_25_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1093] + model_decoder_layers_25_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094] + model_decoder_layers_25_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1095] + 
model_decoder_layers_25_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1096] + model_decoder_layers_25_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1097] + model_decoder_layers_25_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101] + model_decoder_layers_25_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1102] + model_decoder_layers_25_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103] + model_decoder_layers_25_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1104] + model_decoder_layers_25_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1105] + model_decoder_layers_25_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1106] + model_decoder_layers_25_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107] + model_decoder_layers_25_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1108] + model_decoder_layers_25_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109] + model_decoder_layers_25_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1110] + model_decoder_layers_25_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1111] + model_decoder_layers_25_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1112] + model_decoder_layers_26_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113] + model_decoder_layers_26_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114] + model_decoder_layers_26_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1115] + model_decoder_layers_26_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116] + 
model_decoder_layers_26_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1117] + model_decoder_layers_26_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118] + model_decoder_layers_26_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1119] + model_decoder_layers_26_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1120] + model_decoder_layers_26_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1121] + model_decoder_layers_26_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125] + model_decoder_layers_26_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1126] + model_decoder_layers_26_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127] + model_decoder_layers_26_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1128] + model_decoder_layers_26_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1129] + model_decoder_layers_26_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1130] + model_decoder_layers_26_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131] + model_decoder_layers_26_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1132] + model_decoder_layers_26_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133] + model_decoder_layers_26_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1134] + model_decoder_layers_26_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1135] + model_decoder_layers_26_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1136] + model_decoder_layers_27_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137] + 
model_decoder_layers_27_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138] + model_decoder_layers_27_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1139] + model_decoder_layers_27_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140] + model_decoder_layers_27_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1141] + model_decoder_layers_27_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142] + model_decoder_layers_27_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1143] + model_decoder_layers_27_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1144] + model_decoder_layers_27_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1145] + model_decoder_layers_27_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149] + model_decoder_layers_27_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1150] + model_decoder_layers_27_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151] + model_decoder_layers_27_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1152] + model_decoder_layers_27_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1153] + model_decoder_layers_27_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1154] + model_decoder_layers_27_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155] + model_decoder_layers_27_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1156] + model_decoder_layers_27_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157] + model_decoder_layers_27_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1158] + 
model_decoder_layers_27_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1159] + model_decoder_layers_27_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1160] + model_decoder_layers_28_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161] + model_decoder_layers_28_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162] + model_decoder_layers_28_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1163] + model_decoder_layers_28_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164] + model_decoder_layers_28_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1165] + model_decoder_layers_28_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166] + model_decoder_layers_28_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1167] + model_decoder_layers_28_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1168] + model_decoder_layers_28_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1169] + model_decoder_layers_28_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173] + model_decoder_layers_28_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1174] + model_decoder_layers_28_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175] + model_decoder_layers_28_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1176] + model_decoder_layers_28_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1177] + model_decoder_layers_28_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1178] + model_decoder_layers_28_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = 
packed_params[1179] + model_decoder_layers_28_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1180] + model_decoder_layers_28_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181] + model_decoder_layers_28_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1182] + model_decoder_layers_28_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1183] + model_decoder_layers_28_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1184] + model_decoder_layers_29_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185] + model_decoder_layers_29_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186] + model_decoder_layers_29_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1187] + model_decoder_layers_29_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188] + model_decoder_layers_29_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1189] + model_decoder_layers_29_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190] + model_decoder_layers_29_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1191] + model_decoder_layers_29_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1192] + model_decoder_layers_29_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1193] + model_decoder_layers_29_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197] + model_decoder_layers_29_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1198] + model_decoder_layers_29_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199] + model_decoder_layers_29_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1200] + 
model_decoder_layers_29_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1201] + model_decoder_layers_29_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1202] + model_decoder_layers_29_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203] + model_decoder_layers_29_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1204] + model_decoder_layers_29_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205] + model_decoder_layers_29_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1206] + model_decoder_layers_29_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1207] + model_decoder_layers_29_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1208] + model_decoder_layers_30_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209] + model_decoder_layers_30_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210] + model_decoder_layers_30_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1211] + model_decoder_layers_30_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212] + model_decoder_layers_30_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1213] + model_decoder_layers_30_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214] + model_decoder_layers_30_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1215] + model_decoder_layers_30_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1216] + model_decoder_layers_30_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1217] + model_decoder_layers_30_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221] + 
model_decoder_layers_30_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1222] + model_decoder_layers_30_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223] + model_decoder_layers_30_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1224] + model_decoder_layers_30_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1225] + model_decoder_layers_30_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1226] + model_decoder_layers_30_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227] + model_decoder_layers_30_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1228] + model_decoder_layers_30_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229] + model_decoder_layers_30_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1230] + model_decoder_layers_30_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1231] + model_decoder_layers_30_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1232] + model_decoder_layers_31_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233] + model_decoder_layers_31_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234] + model_decoder_layers_31_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1235] + model_decoder_layers_31_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236] + model_decoder_layers_31_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1237] + model_decoder_layers_31_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238] + model_decoder_layers_31_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1239] + 
model_decoder_layers_31_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1240] + model_decoder_layers_31_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1241] + model_decoder_layers_31_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245] + model_decoder_layers_31_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1246] + model_decoder_layers_31_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247] + model_decoder_layers_31_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1248] + model_decoder_layers_31_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1249] + model_decoder_layers_31_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1250] + model_decoder_layers_31_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251] + model_decoder_layers_31_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1252] + model_decoder_layers_31_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] + model_decoder_layers_31_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1254] + model_decoder_layers_31_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1255] + model_decoder_layers_31_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1256] + model_decoder_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1257] + model_decoder_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1258] + reshape384 = R.call_tir(cls.reshape12, (input_ids,), out_sinfo=R.Tensor((seq_len,), dtype="int32")) + take = R.call_tir(cls.take, (model_decoder_embed_tokens_weight2, reshape384), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) + reshape385 = R.call_tir(cls.reshape13, (take,), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv68: R.Tensor((seq_len,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((seq_len,), dtype="int32"),)) + take1 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight2, lv68), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) + reshape386 = R.call_tir(cls.reshape13, (take1,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add257 = R.call_tir(cls.add5, (reshape385, reshape386), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm65 = R.call_tir(cls.layer_norm2, (add257, model_decoder_layers_0_self_attn_layer_norm_weight2, model_decoder_layers_0_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv416 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_q_proj_weight2, layer_norm65, model_decoder_layers_0_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape387 = R.call_tir(cls.reshape14, (lv416,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_0_self_attn_k_proj_weight2, layer_norm65), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape388 = R.call_tir(cls.reshape14, (lv98,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv417 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_v_proj_weight2, layer_norm65, model_decoder_layers_0_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape389 = R.call_tir(cls.reshape14, (lv417,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat = R.call_tir(cls.concatenate1, (reshape387, reshape388, reshape389), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + 
reshape390 = R.call_tir(cls.reshape15, (concat,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv69 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape390), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape391 = R.call_tir(cls.reshape16, (lv69,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape392 = R.call_tir(cls.reshape17, (reshape391,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv418 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_out_proj_weight2, reshape392, model_decoder_layers_0_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add261 = R.call_tir(cls.add5, (add257, lv418), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm66 = R.call_tir(cls.layer_norm2, (add261, model_decoder_layers_0_encoder_attn_layer_norm_weight2, model_decoder_layers_0_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv419 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight2, layer_norm66, model_decoder_layers_0_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape393 = R.call_tir(cls.reshape14, (lv419,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape394 = R.call_tir(cls.reshape18, (reshape393,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv70 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape394), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape395 = R.call_tir(cls.reshape16, (lv70,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape396 = R.call_tir(cls.reshape17, (reshape395,), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv420 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight2, reshape396, model_decoder_layers_0_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add264 = R.call_tir(cls.add5, (add261, lv420), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm67 = R.call_tir(cls.layer_norm2, (add264, model_decoder_layers_0_final_layer_norm_weight2, model_decoder_layers_0_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv64 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_0_fc1_weight2, layer_norm67, model_decoder_layers_0_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv421 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_0_fc2_weight2, lv64, model_decoder_layers_0_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add267 = R.call_tir(cls.add5, (add264, lv421), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm68 = R.call_tir(cls.layer_norm2, (add267, model_decoder_layers_1_self_attn_layer_norm_weight2, model_decoder_layers_1_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv422 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_q_proj_weight2, layer_norm68, model_decoder_layers_1_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape397 = R.call_tir(cls.reshape14, (lv422,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_1_self_attn_k_proj_weight2, layer_norm68), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape398 = 
R.call_tir(cls.reshape14, (lv99,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv423 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_v_proj_weight2, layer_norm68, model_decoder_layers_1_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape399 = R.call_tir(cls.reshape14, (lv423,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat1 = R.call_tir(cls.concatenate1, (reshape397, reshape398, reshape399), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape400 = R.call_tir(cls.reshape15, (concat1,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv71 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape400), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape401 = R.call_tir(cls.reshape16, (lv71,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape402 = R.call_tir(cls.reshape17, (reshape401,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv424 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_out_proj_weight2, reshape402, model_decoder_layers_1_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add271 = R.call_tir(cls.add5, (add267, lv424), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm69 = R.call_tir(cls.layer_norm2, (add271, model_decoder_layers_1_encoder_attn_layer_norm_weight2, model_decoder_layers_1_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv425 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight2, layer_norm69, model_decoder_layers_1_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
reshape403 = R.call_tir(cls.reshape14, (lv425,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape404 = R.call_tir(cls.reshape18, (reshape403,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv72 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape404), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape405 = R.call_tir(cls.reshape16, (lv72,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape406 = R.call_tir(cls.reshape17, (reshape405,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv426 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight2, reshape406, model_decoder_layers_1_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add274 = R.call_tir(cls.add5, (add271, lv426), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm70 = R.call_tir(cls.layer_norm2, (add274, model_decoder_layers_1_final_layer_norm_weight2, model_decoder_layers_1_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_1_fc1_weight2, layer_norm70, model_decoder_layers_1_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv427 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_1_fc2_weight2, lv65, model_decoder_layers_1_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add277 = R.call_tir(cls.add5, (add274, lv427), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm71 = R.call_tir(cls.layer_norm2, (add277, model_decoder_layers_2_self_attn_layer_norm_weight2, model_decoder_layers_2_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv428 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_q_proj_weight2, layer_norm71, model_decoder_layers_2_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape407 = R.call_tir(cls.reshape14, (lv428,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_2_self_attn_k_proj_weight2, layer_norm71), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape408 = R.call_tir(cls.reshape14, (lv100,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv429 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_v_proj_weight2, layer_norm71, model_decoder_layers_2_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape409 = R.call_tir(cls.reshape14, (lv429,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat2 = R.call_tir(cls.concatenate1, (reshape407, reshape408, reshape409), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape410 = R.call_tir(cls.reshape15, (concat2,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv73 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape410), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape411 = R.call_tir(cls.reshape16, (lv73,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape412 = R.call_tir(cls.reshape17, (reshape411,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv430 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_out_proj_weight2, reshape412, model_decoder_layers_2_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + add281 = R.call_tir(cls.add5, (add277, lv430), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm72 = R.call_tir(cls.layer_norm2, (add281, model_decoder_layers_2_encoder_attn_layer_norm_weight2, model_decoder_layers_2_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv431 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight2, layer_norm72, model_decoder_layers_2_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape413 = R.call_tir(cls.reshape14, (lv431,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape414 = R.call_tir(cls.reshape18, (reshape413,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv74 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape414), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape415 = R.call_tir(cls.reshape16, (lv74,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape416 = R.call_tir(cls.reshape17, (reshape415,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv432 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight2, reshape416, model_decoder_layers_2_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add284 = R.call_tir(cls.add5, (add281, lv432), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm73 = R.call_tir(cls.layer_norm2, (add284, model_decoder_layers_2_final_layer_norm_weight2, model_decoder_layers_2_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_2_fc1_weight2, layer_norm73, 
model_decoder_layers_2_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv433 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_2_fc2_weight2, lv66, model_decoder_layers_2_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add287 = R.call_tir(cls.add5, (add284, lv433), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm74 = R.call_tir(cls.layer_norm2, (add287, model_decoder_layers_3_self_attn_layer_norm_weight2, model_decoder_layers_3_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv434 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_q_proj_weight2, layer_norm74, model_decoder_layers_3_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape417 = R.call_tir(cls.reshape14, (lv434,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_3_self_attn_k_proj_weight2, layer_norm74), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape418 = R.call_tir(cls.reshape14, (lv101,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv435 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_v_proj_weight2, layer_norm74, model_decoder_layers_3_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape419 = R.call_tir(cls.reshape14, (lv435,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat3 = R.call_tir(cls.concatenate1, (reshape417, reshape418, reshape419), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape420 = R.call_tir(cls.reshape15, (concat3,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv75 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape420), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape421 = R.call_tir(cls.reshape16, (lv75,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape422 = R.call_tir(cls.reshape17, (reshape421,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv436 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_out_proj_weight2, reshape422, model_decoder_layers_3_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add291 = R.call_tir(cls.add5, (add287, lv436), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm75 = R.call_tir(cls.layer_norm2, (add291, model_decoder_layers_3_encoder_attn_layer_norm_weight2, model_decoder_layers_3_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv437 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight2, layer_norm75, model_decoder_layers_3_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape423 = R.call_tir(cls.reshape14, (lv437,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape424 = R.call_tir(cls.reshape18, (reshape423,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv76 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape424), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape425 = R.call_tir(cls.reshape16, (lv76,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape426 = R.call_tir(cls.reshape17, (reshape425,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv438 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight2, reshape426, model_decoder_layers_3_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add294 = R.call_tir(cls.add5, (add291, lv438), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm76 = R.call_tir(cls.layer_norm2, (add294, model_decoder_layers_3_final_layer_norm_weight2, model_decoder_layers_3_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_3_fc1_weight2, layer_norm76, model_decoder_layers_3_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv439 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_3_fc2_weight2, lv67, model_decoder_layers_3_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add297 = R.call_tir(cls.add5, (add294, lv439), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm77 = R.call_tir(cls.layer_norm2, (add297, model_decoder_layers_4_self_attn_layer_norm_weight2, model_decoder_layers_4_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv440 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_q_proj_weight2, layer_norm77, model_decoder_layers_4_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape427 = R.call_tir(cls.reshape14, (lv440,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_4_self_attn_k_proj_weight2, layer_norm77), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape428 = R.call_tir(cls.reshape14, (lv102,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv441 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_v_proj_weight2, layer_norm77, model_decoder_layers_4_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape429 = R.call_tir(cls.reshape14, (lv441,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat4 = R.call_tir(cls.concatenate1, (reshape427, reshape428, reshape429), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape430 = R.call_tir(cls.reshape15, (concat4,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv77 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape430), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape431 = R.call_tir(cls.reshape16, (lv77,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape432 = R.call_tir(cls.reshape17, (reshape431,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv442 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_out_proj_weight2, reshape432, model_decoder_layers_4_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add301 = R.call_tir(cls.add5, (add297, lv442), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm78 = R.call_tir(cls.layer_norm2, (add301, model_decoder_layers_4_encoder_attn_layer_norm_weight2, model_decoder_layers_4_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv443 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight2, layer_norm78, model_decoder_layers_4_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape433 = R.call_tir(cls.reshape14, (lv443,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape434 = R.call_tir(cls.reshape18, (reshape433,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv78 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape434), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape435 = R.call_tir(cls.reshape16, (lv78,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape436 = R.call_tir(cls.reshape17, (reshape435,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv444 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight2, reshape436, model_decoder_layers_4_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add304 = R.call_tir(cls.add5, (add301, lv444), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm79 = R.call_tir(cls.layer_norm2, (add304, model_decoder_layers_4_final_layer_norm_weight2, model_decoder_layers_4_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv68_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_4_fc1_weight2, layer_norm79, model_decoder_layers_4_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv445 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_4_fc2_weight2, lv68_1, model_decoder_layers_4_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add307 = R.call_tir(cls.add5, (add304, lv445), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm80 = R.call_tir(cls.layer_norm2, (add307, model_decoder_layers_5_self_attn_layer_norm_weight2, model_decoder_layers_5_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv446 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_q_proj_weight2, layer_norm80, model_decoder_layers_5_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape437 = R.call_tir(cls.reshape14, (lv446,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_5_self_attn_k_proj_weight2, layer_norm80), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape438 = R.call_tir(cls.reshape14, (lv103,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv447 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_v_proj_weight2, layer_norm80, model_decoder_layers_5_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape439 = R.call_tir(cls.reshape14, (lv447,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat5 = R.call_tir(cls.concatenate1, (reshape437, reshape438, reshape439), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape440 = R.call_tir(cls.reshape15, (concat5,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv79 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape440), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape441 = R.call_tir(cls.reshape16, (lv79,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape442 = R.call_tir(cls.reshape17, (reshape441,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv448 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_out_proj_weight2, reshape442, model_decoder_layers_5_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add311 = 
R.call_tir(cls.add5, (add307, lv448), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm81 = R.call_tir(cls.layer_norm2, (add311, model_decoder_layers_5_encoder_attn_layer_norm_weight2, model_decoder_layers_5_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv449 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight2, layer_norm81, model_decoder_layers_5_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape443 = R.call_tir(cls.reshape14, (lv449,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape444 = R.call_tir(cls.reshape18, (reshape443,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv80 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape444), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape445 = R.call_tir(cls.reshape16, (lv80,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape446 = R.call_tir(cls.reshape17, (reshape445,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv450 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight2, reshape446, model_decoder_layers_5_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add314 = R.call_tir(cls.add5, (add311, lv450), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm82 = R.call_tir(cls.layer_norm2, (add314, model_decoder_layers_5_final_layer_norm_weight2, model_decoder_layers_5_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv69_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_5_fc1_weight2, layer_norm82, model_decoder_layers_5_fc1_bias2), 
out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv451 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_5_fc2_weight2, lv69_1, model_decoder_layers_5_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add317 = R.call_tir(cls.add5, (add314, lv451), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm83 = R.call_tir(cls.layer_norm2, (add317, model_decoder_layers_6_self_attn_layer_norm_weight2, model_decoder_layers_6_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv452 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_q_proj_weight2, layer_norm83, model_decoder_layers_6_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape447 = R.call_tir(cls.reshape14, (lv452,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv104 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_6_self_attn_k_proj_weight2, layer_norm83), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape448 = R.call_tir(cls.reshape14, (lv104,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv453 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_v_proj_weight2, layer_norm83, model_decoder_layers_6_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape449 = R.call_tir(cls.reshape14, (lv453,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat6 = R.call_tir(cls.concatenate1, (reshape447, reshape448, reshape449), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape450 = R.call_tir(cls.reshape15, (concat6,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv81 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, 
R.prim_value(6), R.prim_value(T.float32(1)), reshape450), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape451 = R.call_tir(cls.reshape16, (lv81,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape452 = R.call_tir(cls.reshape17, (reshape451,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv454 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_out_proj_weight2, reshape452, model_decoder_layers_6_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add321 = R.call_tir(cls.add5, (add317, lv454), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm84 = R.call_tir(cls.layer_norm2, (add321, model_decoder_layers_6_encoder_attn_layer_norm_weight2, model_decoder_layers_6_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv455 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight2, layer_norm84, model_decoder_layers_6_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape453 = R.call_tir(cls.reshape14, (lv455,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape454 = R.call_tir(cls.reshape18, (reshape453,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv82 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape454), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape455 = R.call_tir(cls.reshape16, (lv82,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape456 = R.call_tir(cls.reshape17, (reshape455,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv456 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight2, reshape456, 
model_decoder_layers_6_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add324 = R.call_tir(cls.add5, (add321, lv456), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm85 = R.call_tir(cls.layer_norm2, (add324, model_decoder_layers_6_final_layer_norm_weight2, model_decoder_layers_6_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv70_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_6_fc1_weight2, layer_norm85, model_decoder_layers_6_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv457 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_6_fc2_weight2, lv70_1, model_decoder_layers_6_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add327 = R.call_tir(cls.add5, (add324, lv457), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm86 = R.call_tir(cls.layer_norm2, (add327, model_decoder_layers_7_self_attn_layer_norm_weight2, model_decoder_layers_7_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv458 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_q_proj_weight2, layer_norm86, model_decoder_layers_7_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape457 = R.call_tir(cls.reshape14, (lv458,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_7_self_attn_k_proj_weight2, layer_norm86), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape458 = R.call_tir(cls.reshape14, (lv105,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv459 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", 
(model_decoder_layers_7_self_attn_v_proj_weight2, layer_norm86, model_decoder_layers_7_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape459 = R.call_tir(cls.reshape14, (lv459,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat7 = R.call_tir(cls.concatenate1, (reshape457, reshape458, reshape459), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape460 = R.call_tir(cls.reshape15, (concat7,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv83 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape460), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape461 = R.call_tir(cls.reshape16, (lv83,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape462 = R.call_tir(cls.reshape17, (reshape461,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv460 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_out_proj_weight2, reshape462, model_decoder_layers_7_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add331 = R.call_tir(cls.add5, (add327, lv460), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm87 = R.call_tir(cls.layer_norm2, (add331, model_decoder_layers_7_encoder_attn_layer_norm_weight2, model_decoder_layers_7_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv461 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight2, layer_norm87, model_decoder_layers_7_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape463 = R.call_tir(cls.reshape14, (lv461,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape464 = R.call_tir(cls.reshape18, (reshape463,), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv84 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape464), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape465 = R.call_tir(cls.reshape16, (lv84,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape466 = R.call_tir(cls.reshape17, (reshape465,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv462 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight2, reshape466, model_decoder_layers_7_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add334 = R.call_tir(cls.add5, (add331, lv462), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm88 = R.call_tir(cls.layer_norm2, (add334, model_decoder_layers_7_final_layer_norm_weight2, model_decoder_layers_7_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv71_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_7_fc1_weight2, layer_norm88, model_decoder_layers_7_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv463 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_7_fc2_weight2, lv71_1, model_decoder_layers_7_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add337 = R.call_tir(cls.add5, (add334, lv463), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm89 = R.call_tir(cls.layer_norm2, (add337, model_decoder_layers_8_self_attn_layer_norm_weight2, model_decoder_layers_8_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv464 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_q_proj_weight2, 
layer_norm89, model_decoder_layers_8_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape467 = R.call_tir(cls.reshape14, (lv464,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv106 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_8_self_attn_k_proj_weight2, layer_norm89), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape468 = R.call_tir(cls.reshape14, (lv106,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv465 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_v_proj_weight2, layer_norm89, model_decoder_layers_8_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape469 = R.call_tir(cls.reshape14, (lv465,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat8 = R.call_tir(cls.concatenate1, (reshape467, reshape468, reshape469), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape470 = R.call_tir(cls.reshape15, (concat8,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv85 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape470), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape471 = R.call_tir(cls.reshape16, (lv85,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape472 = R.call_tir(cls.reshape17, (reshape471,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv466 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_out_proj_weight2, reshape472, model_decoder_layers_8_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add341 = R.call_tir(cls.add5, (add337, lv466), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm90 = R.call_tir(cls.layer_norm2, 
(add341, model_decoder_layers_8_encoder_attn_layer_norm_weight2, model_decoder_layers_8_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv467 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight2, layer_norm90, model_decoder_layers_8_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape473 = R.call_tir(cls.reshape14, (lv467,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape474 = R.call_tir(cls.reshape18, (reshape473,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv86 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape474), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape475 = R.call_tir(cls.reshape16, (lv86,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape476 = R.call_tir(cls.reshape17, (reshape475,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv468 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_out_proj_weight2, reshape476, model_decoder_layers_8_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add344 = R.call_tir(cls.add5, (add341, lv468), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm91 = R.call_tir(cls.layer_norm2, (add344, model_decoder_layers_8_final_layer_norm_weight2, model_decoder_layers_8_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv72_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_8_fc1_weight2, layer_norm91, model_decoder_layers_8_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv469 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", 
(model_decoder_layers_8_fc2_weight2, lv72_1, model_decoder_layers_8_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add347 = R.call_tir(cls.add5, (add344, lv469), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm92 = R.call_tir(cls.layer_norm2, (add347, model_decoder_layers_9_self_attn_layer_norm_weight2, model_decoder_layers_9_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv470 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_q_proj_weight2, layer_norm92, model_decoder_layers_9_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape477 = R.call_tir(cls.reshape14, (lv470,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_9_self_attn_k_proj_weight2, layer_norm92), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape478 = R.call_tir(cls.reshape14, (lv107,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv471 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_v_proj_weight2, layer_norm92, model_decoder_layers_9_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape479 = R.call_tir(cls.reshape14, (lv471,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat9 = R.call_tir(cls.concatenate1, (reshape477, reshape478, reshape479), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape480 = R.call_tir(cls.reshape15, (concat9,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv87 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape480), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape481 = 
R.call_tir(cls.reshape16, (lv87,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape482 = R.call_tir(cls.reshape17, (reshape481,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv472 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_out_proj_weight2, reshape482, model_decoder_layers_9_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add351 = R.call_tir(cls.add5, (add347, lv472), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm93 = R.call_tir(cls.layer_norm2, (add351, model_decoder_layers_9_encoder_attn_layer_norm_weight2, model_decoder_layers_9_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv473 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight2, layer_norm93, model_decoder_layers_9_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape483 = R.call_tir(cls.reshape14, (lv473,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape484 = R.call_tir(cls.reshape18, (reshape483,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv88 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape484), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape485 = R.call_tir(cls.reshape16, (lv88,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape486 = R.call_tir(cls.reshape17, (reshape485,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv474 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight2, reshape486, model_decoder_layers_9_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add354 = 
R.call_tir(cls.add5, (add351, lv474), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm94 = R.call_tir(cls.layer_norm2, (add354, model_decoder_layers_9_final_layer_norm_weight2, model_decoder_layers_9_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv73_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_9_fc1_weight2, layer_norm94, model_decoder_layers_9_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv475 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_9_fc2_weight2, lv73_1, model_decoder_layers_9_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add357 = R.call_tir(cls.add5, (add354, lv475), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm95 = R.call_tir(cls.layer_norm2, (add357, model_decoder_layers_10_self_attn_layer_norm_weight2, model_decoder_layers_10_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv476 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_q_proj_weight2, layer_norm95, model_decoder_layers_10_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape487 = R.call_tir(cls.reshape14, (lv476,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_10_self_attn_k_proj_weight2, layer_norm95), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape488 = R.call_tir(cls.reshape14, (lv108,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv477 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_v_proj_weight2, layer_norm95, model_decoder_layers_10_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) + reshape489 = R.call_tir(cls.reshape14, (lv477,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat10 = R.call_tir(cls.concatenate1, (reshape487, reshape488, reshape489), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape490 = R.call_tir(cls.reshape15, (concat10,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv89 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape490), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape491 = R.call_tir(cls.reshape16, (lv89,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape492 = R.call_tir(cls.reshape17, (reshape491,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv478 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_out_proj_weight2, reshape492, model_decoder_layers_10_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add361 = R.call_tir(cls.add5, (add357, lv478), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm96 = R.call_tir(cls.layer_norm2, (add361, model_decoder_layers_10_encoder_attn_layer_norm_weight2, model_decoder_layers_10_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv479 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight2, layer_norm96, model_decoder_layers_10_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape493 = R.call_tir(cls.reshape14, (lv479,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape494 = R.call_tir(cls.reshape18, (reshape493,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv90 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", 
(paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape494), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape495 = R.call_tir(cls.reshape16, (lv90,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape496 = R.call_tir(cls.reshape17, (reshape495,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv480 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight2, reshape496, model_decoder_layers_10_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add364 = R.call_tir(cls.add5, (add361, lv480), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm97 = R.call_tir(cls.layer_norm2, (add364, model_decoder_layers_10_final_layer_norm_weight2, model_decoder_layers_10_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv74_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_10_fc1_weight2, layer_norm97, model_decoder_layers_10_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv481 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_10_fc2_weight2, lv74_1, model_decoder_layers_10_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add367 = R.call_tir(cls.add5, (add364, lv481), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm98 = R.call_tir(cls.layer_norm2, (add367, model_decoder_layers_11_self_attn_layer_norm_weight2, model_decoder_layers_11_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv482 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_q_proj_weight2, layer_norm98, model_decoder_layers_11_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
reshape497 = R.call_tir(cls.reshape14, (lv482,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_11_self_attn_k_proj_weight2, layer_norm98), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape498 = R.call_tir(cls.reshape14, (lv109,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv483 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_v_proj_weight2, layer_norm98, model_decoder_layers_11_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape499 = R.call_tir(cls.reshape14, (lv483,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat11 = R.call_tir(cls.concatenate1, (reshape497, reshape498, reshape499), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape500 = R.call_tir(cls.reshape15, (concat11,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv91 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape500), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape501 = R.call_tir(cls.reshape16, (lv91,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape502 = R.call_tir(cls.reshape17, (reshape501,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv484 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_out_proj_weight2, reshape502, model_decoder_layers_11_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add371 = R.call_tir(cls.add5, (add367, lv484), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm99 = R.call_tir(cls.layer_norm2, (add371, model_decoder_layers_11_encoder_attn_layer_norm_weight2, 
model_decoder_layers_11_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv485 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight2, layer_norm99, model_decoder_layers_11_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape503 = R.call_tir(cls.reshape14, (lv485,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape504 = R.call_tir(cls.reshape18, (reshape503,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv92 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape504), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape505 = R.call_tir(cls.reshape16, (lv92,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape506 = R.call_tir(cls.reshape17, (reshape505,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv486 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight2, reshape506, model_decoder_layers_11_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add374 = R.call_tir(cls.add5, (add371, lv486), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm100 = R.call_tir(cls.layer_norm2, (add374, model_decoder_layers_11_final_layer_norm_weight2, model_decoder_layers_11_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv75_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_11_fc1_weight2, layer_norm100, model_decoder_layers_11_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv487 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_11_fc2_weight2, lv75_1, 
model_decoder_layers_11_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add377 = R.call_tir(cls.add5, (add374, lv487), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm101 = R.call_tir(cls.layer_norm2, (add377, model_decoder_layers_12_self_attn_layer_norm_weight2, model_decoder_layers_12_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv488 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_q_proj_weight2, layer_norm101, model_decoder_layers_12_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape507 = R.call_tir(cls.reshape14, (lv488,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_12_self_attn_k_proj_weight2, layer_norm101), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape508 = R.call_tir(cls.reshape14, (lv110,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv489 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_v_proj_weight2, layer_norm101, model_decoder_layers_12_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape509 = R.call_tir(cls.reshape14, (lv489,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat12 = R.call_tir(cls.concatenate1, (reshape507, reshape508, reshape509), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape510 = R.call_tir(cls.reshape15, (concat12,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv93 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape510), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape511 = R.call_tir(cls.reshape16, (lv93,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape512 = R.call_tir(cls.reshape17, (reshape511,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv490 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_out_proj_weight2, reshape512, model_decoder_layers_12_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add381 = R.call_tir(cls.add5, (add377, lv490), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm102 = R.call_tir(cls.layer_norm2, (add381, model_decoder_layers_12_encoder_attn_layer_norm_weight2, model_decoder_layers_12_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv491 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight2, layer_norm102, model_decoder_layers_12_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape513 = R.call_tir(cls.reshape14, (lv491,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape514 = R.call_tir(cls.reshape18, (reshape513,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv94 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape514), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape515 = R.call_tir(cls.reshape16, (lv94,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape516 = R.call_tir(cls.reshape17, (reshape515,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv492 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight2, reshape516, model_decoder_layers_12_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add384 = R.call_tir(cls.add5, (add381, lv492), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm103 = R.call_tir(cls.layer_norm2, (add384, model_decoder_layers_12_final_layer_norm_weight2, model_decoder_layers_12_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv76_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_12_fc1_weight2, layer_norm103, model_decoder_layers_12_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv493 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_12_fc2_weight2, lv76_1, model_decoder_layers_12_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add387 = R.call_tir(cls.add5, (add384, lv493), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm104 = R.call_tir(cls.layer_norm2, (add387, model_decoder_layers_13_self_attn_layer_norm_weight2, model_decoder_layers_13_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv494 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_q_proj_weight2, layer_norm104, model_decoder_layers_13_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape517 = R.call_tir(cls.reshape14, (lv494,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_13_self_attn_k_proj_weight2, layer_norm104), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape518 = R.call_tir(cls.reshape14, (lv111,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv495 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_v_proj_weight2, layer_norm104, model_decoder_layers_13_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + reshape519 = R.call_tir(cls.reshape14, (lv495,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat13 = R.call_tir(cls.concatenate1, (reshape517, reshape518, reshape519), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape520 = R.call_tir(cls.reshape15, (concat13,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv95 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape520), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape521 = R.call_tir(cls.reshape16, (lv95,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape522 = R.call_tir(cls.reshape17, (reshape521,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv496 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_out_proj_weight2, reshape522, model_decoder_layers_13_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add391 = R.call_tir(cls.add5, (add387, lv496), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm105 = R.call_tir(cls.layer_norm2, (add391, model_decoder_layers_13_encoder_attn_layer_norm_weight2, model_decoder_layers_13_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv497 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight2, layer_norm105, model_decoder_layers_13_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape523 = R.call_tir(cls.reshape14, (lv497,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape524 = R.call_tir(cls.reshape18, (reshape523,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv96 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, 
R.prim_value(13), R.prim_value(T.float32(1)), reshape524), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape525 = R.call_tir(cls.reshape16, (lv96,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape526 = R.call_tir(cls.reshape17, (reshape525,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv498 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight2, reshape526, model_decoder_layers_13_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add394 = R.call_tir(cls.add5, (add391, lv498), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm106 = R.call_tir(cls.layer_norm2, (add394, model_decoder_layers_13_final_layer_norm_weight2, model_decoder_layers_13_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv77_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_13_fc1_weight2, layer_norm106, model_decoder_layers_13_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv499 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_13_fc2_weight2, lv77_1, model_decoder_layers_13_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add397 = R.call_tir(cls.add5, (add394, lv499), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm107 = R.call_tir(cls.layer_norm2, (add397, model_decoder_layers_14_self_attn_layer_norm_weight2, model_decoder_layers_14_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv500 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_q_proj_weight2, layer_norm107, model_decoder_layers_14_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape527 = 
R.call_tir(cls.reshape14, (lv500,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv112 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_14_self_attn_k_proj_weight2, layer_norm107), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape528 = R.call_tir(cls.reshape14, (lv112,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv501 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_v_proj_weight2, layer_norm107, model_decoder_layers_14_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape529 = R.call_tir(cls.reshape14, (lv501,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat14 = R.call_tir(cls.concatenate1, (reshape527, reshape528, reshape529), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape530 = R.call_tir(cls.reshape15, (concat14,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv97 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape530), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape531 = R.call_tir(cls.reshape16, (lv97,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape532 = R.call_tir(cls.reshape17, (reshape531,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv502 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_out_proj_weight2, reshape532, model_decoder_layers_14_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add401 = R.call_tir(cls.add5, (add397, lv502), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm108 = R.call_tir(cls.layer_norm2, (add401, model_decoder_layers_14_encoder_attn_layer_norm_weight2, model_decoder_layers_14_encoder_attn_layer_norm_bias2), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv503 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight2, layer_norm108, model_decoder_layers_14_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape533 = R.call_tir(cls.reshape14, (lv503,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape534 = R.call_tir(cls.reshape18, (reshape533,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv98_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape534), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape535 = R.call_tir(cls.reshape16, (lv98_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape536 = R.call_tir(cls.reshape17, (reshape535,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv504 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_out_proj_weight2, reshape536, model_decoder_layers_14_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add404 = R.call_tir(cls.add5, (add401, lv504), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm109 = R.call_tir(cls.layer_norm2, (add404, model_decoder_layers_14_final_layer_norm_weight2, model_decoder_layers_14_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv78_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_14_fc1_weight2, layer_norm109, model_decoder_layers_14_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv505 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_14_fc2_weight2, lv78_1, model_decoder_layers_14_fc2_bias2), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) + add407 = R.call_tir(cls.add5, (add404, lv505), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm110 = R.call_tir(cls.layer_norm2, (add407, model_decoder_layers_15_self_attn_layer_norm_weight2, model_decoder_layers_15_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv506 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_q_proj_weight2, layer_norm110, model_decoder_layers_15_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape537 = R.call_tir(cls.reshape14, (lv506,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_15_self_attn_k_proj_weight2, layer_norm110), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape538 = R.call_tir(cls.reshape14, (lv113,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv507 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_v_proj_weight2, layer_norm110, model_decoder_layers_15_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape539 = R.call_tir(cls.reshape14, (lv507,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat15 = R.call_tir(cls.concatenate1, (reshape537, reshape538, reshape539), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape540 = R.call_tir(cls.reshape15, (concat15,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv99_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape540), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape541 = R.call_tir(cls.reshape16, (lv99_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape542 
= R.call_tir(cls.reshape17, (reshape541,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv508 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_out_proj_weight2, reshape542, model_decoder_layers_15_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add411 = R.call_tir(cls.add5, (add407, lv508), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm111 = R.call_tir(cls.layer_norm2, (add411, model_decoder_layers_15_encoder_attn_layer_norm_weight2, model_decoder_layers_15_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv509 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight2, layer_norm111, model_decoder_layers_15_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape543 = R.call_tir(cls.reshape14, (lv509,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape544 = R.call_tir(cls.reshape18, (reshape543,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv100_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape544), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape545 = R.call_tir(cls.reshape16, (lv100_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape546 = R.call_tir(cls.reshape17, (reshape545,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv510 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight2, reshape546, model_decoder_layers_15_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add414 = R.call_tir(cls.add5, (add411, lv510), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
layer_norm112 = R.call_tir(cls.layer_norm2, (add414, model_decoder_layers_15_final_layer_norm_weight2, model_decoder_layers_15_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv79_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_15_fc1_weight2, layer_norm112, model_decoder_layers_15_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv511 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_15_fc2_weight2, lv79_1, model_decoder_layers_15_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add417 = R.call_tir(cls.add5, (add414, lv511), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm113 = R.call_tir(cls.layer_norm2, (add417, model_decoder_layers_16_self_attn_layer_norm_weight2, model_decoder_layers_16_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv512 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_q_proj_weight2, layer_norm113, model_decoder_layers_16_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape547 = R.call_tir(cls.reshape14, (lv512,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_16_self_attn_k_proj_weight2, layer_norm113), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape548 = R.call_tir(cls.reshape14, (lv114,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv513 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_v_proj_weight2, layer_norm113, model_decoder_layers_16_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape549 = R.call_tir(cls.reshape14, (lv513,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat16 = R.call_tir(cls.concatenate1, (reshape547, reshape548, reshape549), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape550 = R.call_tir(cls.reshape15, (concat16,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv101_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape550), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape551 = R.call_tir(cls.reshape16, (lv101_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape552 = R.call_tir(cls.reshape17, (reshape551,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv514 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_out_proj_weight2, reshape552, model_decoder_layers_16_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add421 = R.call_tir(cls.add5, (add417, lv514), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm114 = R.call_tir(cls.layer_norm2, (add421, model_decoder_layers_16_encoder_attn_layer_norm_weight2, model_decoder_layers_16_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv515 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight2, layer_norm114, model_decoder_layers_16_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape553 = R.call_tir(cls.reshape14, (lv515,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape554 = R.call_tir(cls.reshape18, (reshape553,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv102_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape554), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape555 = R.call_tir(cls.reshape16, (lv102_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape556 = R.call_tir(cls.reshape17, (reshape555,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv516 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_out_proj_weight2, reshape556, model_decoder_layers_16_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add424 = R.call_tir(cls.add5, (add421, lv516), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm115 = R.call_tir(cls.layer_norm2, (add424, model_decoder_layers_16_final_layer_norm_weight2, model_decoder_layers_16_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv80_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_16_fc1_weight2, layer_norm115, model_decoder_layers_16_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv517 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_16_fc2_weight2, lv80_1, model_decoder_layers_16_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add427 = R.call_tir(cls.add5, (add424, lv517), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm116 = R.call_tir(cls.layer_norm2, (add427, model_decoder_layers_17_self_attn_layer_norm_weight2, model_decoder_layers_17_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv518 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_q_proj_weight2, layer_norm116, model_decoder_layers_17_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape557 = R.call_tir(cls.reshape14, (lv518,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_17_self_attn_k_proj_weight2, layer_norm116), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape558 = R.call_tir(cls.reshape14, (lv115,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv519 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_v_proj_weight2, layer_norm116, model_decoder_layers_17_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape559 = R.call_tir(cls.reshape14, (lv519,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat17 = R.call_tir(cls.concatenate1, (reshape557, reshape558, reshape559), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape560 = R.call_tir(cls.reshape15, (concat17,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv103_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape560), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape561 = R.call_tir(cls.reshape16, (lv103_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape562 = R.call_tir(cls.reshape17, (reshape561,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv520 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_out_proj_weight2, reshape562, model_decoder_layers_17_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add431 = R.call_tir(cls.add5, (add427, lv520), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm117 = R.call_tir(cls.layer_norm2, (add431, model_decoder_layers_17_encoder_attn_layer_norm_weight2, model_decoder_layers_17_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv521 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight2, layer_norm117, model_decoder_layers_17_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape563 = R.call_tir(cls.reshape14, (lv521,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape564 = R.call_tir(cls.reshape18, (reshape563,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv104_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape564), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape565 = R.call_tir(cls.reshape16, (lv104_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape566 = R.call_tir(cls.reshape17, (reshape565,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv522 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight2, reshape566, model_decoder_layers_17_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add434 = R.call_tir(cls.add5, (add431, lv522), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm118 = R.call_tir(cls.layer_norm2, (add434, model_decoder_layers_17_final_layer_norm_weight2, model_decoder_layers_17_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv81_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_17_fc1_weight2, layer_norm118, model_decoder_layers_17_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv523 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_17_fc2_weight2, lv81_1, model_decoder_layers_17_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
add437 = R.call_tir(cls.add5, (add434, lv523), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm119 = R.call_tir(cls.layer_norm2, (add437, model_decoder_layers_18_self_attn_layer_norm_weight2, model_decoder_layers_18_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv524 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_q_proj_weight2, layer_norm119, model_decoder_layers_18_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape567 = R.call_tir(cls.reshape14, (lv524,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_18_self_attn_k_proj_weight2, layer_norm119), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape568 = R.call_tir(cls.reshape14, (lv116,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv525 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_v_proj_weight2, layer_norm119, model_decoder_layers_18_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape569 = R.call_tir(cls.reshape14, (lv525,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat18 = R.call_tir(cls.concatenate1, (reshape567, reshape568, reshape569), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape570 = R.call_tir(cls.reshape15, (concat18,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv105_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape570), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape571 = R.call_tir(cls.reshape16, (lv105_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape572 = R.call_tir(cls.reshape17, 
(reshape571,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv526 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_out_proj_weight2, reshape572, model_decoder_layers_18_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add441 = R.call_tir(cls.add5, (add437, lv526), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm120 = R.call_tir(cls.layer_norm2, (add441, model_decoder_layers_18_encoder_attn_layer_norm_weight2, model_decoder_layers_18_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv527 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_q_proj_weight2, layer_norm120, model_decoder_layers_18_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape573 = R.call_tir(cls.reshape14, (lv527,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape574 = R.call_tir(cls.reshape18, (reshape573,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv106_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape574), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape575 = R.call_tir(cls.reshape16, (lv106_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape576 = R.call_tir(cls.reshape17, (reshape575,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv528 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight2, reshape576, model_decoder_layers_18_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add444 = R.call_tir(cls.add5, (add441, lv528), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm121 = 
R.call_tir(cls.layer_norm2, (add444, model_decoder_layers_18_final_layer_norm_weight2, model_decoder_layers_18_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv82_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_18_fc1_weight2, layer_norm121, model_decoder_layers_18_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv529 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_18_fc2_weight2, lv82_1, model_decoder_layers_18_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add447 = R.call_tir(cls.add5, (add444, lv529), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm122 = R.call_tir(cls.layer_norm2, (add447, model_decoder_layers_19_self_attn_layer_norm_weight2, model_decoder_layers_19_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv530 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_q_proj_weight2, layer_norm122, model_decoder_layers_19_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape577 = R.call_tir(cls.reshape14, (lv530,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_19_self_attn_k_proj_weight2, layer_norm122), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape578 = R.call_tir(cls.reshape14, (lv117,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv531 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_v_proj_weight2, layer_norm122, model_decoder_layers_19_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape579 = R.call_tir(cls.reshape14, (lv531,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat19 = R.call_tir(cls.concatenate1, (reshape577, reshape578, reshape579), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape580 = R.call_tir(cls.reshape15, (concat19,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv107_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape580), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape581 = R.call_tir(cls.reshape16, (lv107_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape582 = R.call_tir(cls.reshape17, (reshape581,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv532 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_out_proj_weight2, reshape582, model_decoder_layers_19_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add451 = R.call_tir(cls.add5, (add447, lv532), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm123 = R.call_tir(cls.layer_norm2, (add451, model_decoder_layers_19_encoder_attn_layer_norm_weight2, model_decoder_layers_19_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv533 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight2, layer_norm123, model_decoder_layers_19_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape583 = R.call_tir(cls.reshape14, (lv533,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape584 = R.call_tir(cls.reshape18, (reshape583,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv108_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape584), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape585 = R.call_tir(cls.reshape16, (lv108_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape586 = R.call_tir(cls.reshape17, (reshape585,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv534 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight2, reshape586, model_decoder_layers_19_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add454 = R.call_tir(cls.add5, (add451, lv534), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm124 = R.call_tir(cls.layer_norm2, (add454, model_decoder_layers_19_final_layer_norm_weight2, model_decoder_layers_19_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv83_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_19_fc1_weight2, layer_norm124, model_decoder_layers_19_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv535 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_19_fc2_weight2, lv83_1, model_decoder_layers_19_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add457 = R.call_tir(cls.add5, (add454, lv535), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm125 = R.call_tir(cls.layer_norm2, (add457, model_decoder_layers_20_self_attn_layer_norm_weight2, model_decoder_layers_20_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv536 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_q_proj_weight2, layer_norm125, model_decoder_layers_20_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape587 = R.call_tir(cls.reshape14, (lv536,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_20_self_attn_k_proj_weight2, layer_norm125), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape588 = R.call_tir(cls.reshape14, (lv118,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv537 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_v_proj_weight2, layer_norm125, model_decoder_layers_20_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape589 = R.call_tir(cls.reshape14, (lv537,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat20 = R.call_tir(cls.concatenate1, (reshape587, reshape588, reshape589), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape590 = R.call_tir(cls.reshape15, (concat20,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv109_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape590), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape591 = R.call_tir(cls.reshape16, (lv109_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape592 = R.call_tir(cls.reshape17, (reshape591,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv538 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_out_proj_weight2, reshape592, model_decoder_layers_20_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add461 = R.call_tir(cls.add5, (add457, lv538), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm126 = R.call_tir(cls.layer_norm2, (add461, model_decoder_layers_20_encoder_attn_layer_norm_weight2, model_decoder_layers_20_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv539 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight2, layer_norm126, model_decoder_layers_20_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape593 = R.call_tir(cls.reshape14, (lv539,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape594 = R.call_tir(cls.reshape18, (reshape593,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv110_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape594), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape595 = R.call_tir(cls.reshape16, (lv110_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape596 = R.call_tir(cls.reshape17, (reshape595,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv540 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight2, reshape596, model_decoder_layers_20_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add464 = R.call_tir(cls.add5, (add461, lv540), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm127 = R.call_tir(cls.layer_norm2, (add464, model_decoder_layers_20_final_layer_norm_weight2, model_decoder_layers_20_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv84_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_20_fc1_weight2, layer_norm127, model_decoder_layers_20_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv541 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_20_fc2_weight2, lv84_1, model_decoder_layers_20_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
add467 = R.call_tir(cls.add5, (add464, lv541), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm128 = R.call_tir(cls.layer_norm2, (add467, model_decoder_layers_21_self_attn_layer_norm_weight2, model_decoder_layers_21_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv542 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_q_proj_weight2, layer_norm128, model_decoder_layers_21_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape597 = R.call_tir(cls.reshape14, (lv542,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_21_self_attn_k_proj_weight2, layer_norm128), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape598 = R.call_tir(cls.reshape14, (lv119,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv543 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_v_proj_weight2, layer_norm128, model_decoder_layers_21_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape599 = R.call_tir(cls.reshape14, (lv543,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat21 = R.call_tir(cls.concatenate1, (reshape597, reshape598, reshape599), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape600 = R.call_tir(cls.reshape15, (concat21,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv111_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape600), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape601 = R.call_tir(cls.reshape16, (lv111_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape602 = R.call_tir(cls.reshape17, 
(reshape601,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv544 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_out_proj_weight2, reshape602, model_decoder_layers_21_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add471 = R.call_tir(cls.add5, (add467, lv544), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm129 = R.call_tir(cls.layer_norm2, (add471, model_decoder_layers_21_encoder_attn_layer_norm_weight2, model_decoder_layers_21_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv545 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight2, layer_norm129, model_decoder_layers_21_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape603 = R.call_tir(cls.reshape14, (lv545,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape604 = R.call_tir(cls.reshape18, (reshape603,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv112_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape604), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape605 = R.call_tir(cls.reshape16, (lv112_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape606 = R.call_tir(cls.reshape17, (reshape605,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv546 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight2, reshape606, model_decoder_layers_21_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add474 = R.call_tir(cls.add5, (add471, lv546), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm130 = 
R.call_tir(cls.layer_norm2, (add474, model_decoder_layers_21_final_layer_norm_weight2, model_decoder_layers_21_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv85_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_21_fc1_weight2, layer_norm130, model_decoder_layers_21_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv547 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_21_fc2_weight2, lv85_1, model_decoder_layers_21_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add477 = R.call_tir(cls.add5, (add474, lv547), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm131 = R.call_tir(cls.layer_norm2, (add477, model_decoder_layers_22_self_attn_layer_norm_weight2, model_decoder_layers_22_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv548 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_q_proj_weight2, layer_norm131, model_decoder_layers_22_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape607 = R.call_tir(cls.reshape14, (lv548,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv120 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_22_self_attn_k_proj_weight2, layer_norm131), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape608 = R.call_tir(cls.reshape14, (lv120,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv549 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_v_proj_weight2, layer_norm131, model_decoder_layers_22_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape609 = R.call_tir(cls.reshape14, (lv549,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat22 = R.call_tir(cls.concatenate1, (reshape607, reshape608, reshape609), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape610 = R.call_tir(cls.reshape15, (concat22,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv113_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape610), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape611 = R.call_tir(cls.reshape16, (lv113_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape612 = R.call_tir(cls.reshape17, (reshape611,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv550 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_out_proj_weight2, reshape612, model_decoder_layers_22_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add481 = R.call_tir(cls.add5, (add477, lv550), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm132 = R.call_tir(cls.layer_norm2, (add481, model_decoder_layers_22_encoder_attn_layer_norm_weight2, model_decoder_layers_22_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv551 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight2, layer_norm132, model_decoder_layers_22_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape613 = R.call_tir(cls.reshape14, (lv551,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape614 = R.call_tir(cls.reshape18, (reshape613,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv114_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape614), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape615 = R.call_tir(cls.reshape16, (lv114_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape616 = R.call_tir(cls.reshape17, (reshape615,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv552 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight2, reshape616, model_decoder_layers_22_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add484 = R.call_tir(cls.add5, (add481, lv552), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm133 = R.call_tir(cls.layer_norm2, (add484, model_decoder_layers_22_final_layer_norm_weight2, model_decoder_layers_22_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv86_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_22_fc1_weight2, layer_norm133, model_decoder_layers_22_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv553 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_22_fc2_weight2, lv86_1, model_decoder_layers_22_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add487 = R.call_tir(cls.add5, (add484, lv553), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm134 = R.call_tir(cls.layer_norm2, (add487, model_decoder_layers_23_self_attn_layer_norm_weight2, model_decoder_layers_23_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv554 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_q_proj_weight2, layer_norm134, model_decoder_layers_23_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape617 = R.call_tir(cls.reshape14, (lv554,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_23_self_attn_k_proj_weight2, layer_norm134), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape618 = R.call_tir(cls.reshape14, (lv121,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv555 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_v_proj_weight2, layer_norm134, model_decoder_layers_23_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape619 = R.call_tir(cls.reshape14, (lv555,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat23 = R.call_tir(cls.concatenate1, (reshape617, reshape618, reshape619), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape620 = R.call_tir(cls.reshape15, (concat23,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv115_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape620), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape621 = R.call_tir(cls.reshape16, (lv115_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape622 = R.call_tir(cls.reshape17, (reshape621,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv556 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_out_proj_weight2, reshape622, model_decoder_layers_23_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add491 = R.call_tir(cls.add5, (add487, lv556), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm135 = R.call_tir(cls.layer_norm2, (add491, model_decoder_layers_23_encoder_attn_layer_norm_weight2, model_decoder_layers_23_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv557 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight2, layer_norm135, model_decoder_layers_23_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape623 = R.call_tir(cls.reshape14, (lv557,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape624 = R.call_tir(cls.reshape18, (reshape623,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv116_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape624), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape625 = R.call_tir(cls.reshape16, (lv116_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape626 = R.call_tir(cls.reshape17, (reshape625,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv558 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight2, reshape626, model_decoder_layers_23_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add494 = R.call_tir(cls.add5, (add491, lv558), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm136 = R.call_tir(cls.layer_norm2, (add494, model_decoder_layers_23_final_layer_norm_weight2, model_decoder_layers_23_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv87_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_23_fc1_weight2, layer_norm136, model_decoder_layers_23_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv559 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_23_fc2_weight2, lv87_1, model_decoder_layers_23_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
add497 = R.call_tir(cls.add5, (add494, lv559), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm137 = R.call_tir(cls.layer_norm2, (add497, model_decoder_layers_24_self_attn_layer_norm_weight2, model_decoder_layers_24_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv560 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_q_proj_weight2, layer_norm137, model_decoder_layers_24_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape627 = R.call_tir(cls.reshape14, (lv560,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_24_self_attn_k_proj_weight2, layer_norm137), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape628 = R.call_tir(cls.reshape14, (lv122,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv561 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_v_proj_weight2, layer_norm137, model_decoder_layers_24_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape629 = R.call_tir(cls.reshape14, (lv561,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat24 = R.call_tir(cls.concatenate1, (reshape627, reshape628, reshape629), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape630 = R.call_tir(cls.reshape15, (concat24,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv117_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape630), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape631 = R.call_tir(cls.reshape16, (lv117_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape632 = R.call_tir(cls.reshape17, 
(reshape631,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv562 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_out_proj_weight2, reshape632, model_decoder_layers_24_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add501 = R.call_tir(cls.add5, (add497, lv562), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm138 = R.call_tir(cls.layer_norm2, (add501, model_decoder_layers_24_encoder_attn_layer_norm_weight2, model_decoder_layers_24_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv563 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight2, layer_norm138, model_decoder_layers_24_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape633 = R.call_tir(cls.reshape14, (lv563,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape634 = R.call_tir(cls.reshape18, (reshape633,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv118_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape634), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape635 = R.call_tir(cls.reshape16, (lv118_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape636 = R.call_tir(cls.reshape17, (reshape635,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv564 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight2, reshape636, model_decoder_layers_24_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add504 = R.call_tir(cls.add5, (add501, lv564), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm139 = 
R.call_tir(cls.layer_norm2, (add504, model_decoder_layers_24_final_layer_norm_weight2, model_decoder_layers_24_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv88_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_24_fc1_weight2, layer_norm139, model_decoder_layers_24_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv565 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_24_fc2_weight2, lv88_1, model_decoder_layers_24_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add507 = R.call_tir(cls.add5, (add504, lv565), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm140 = R.call_tir(cls.layer_norm2, (add507, model_decoder_layers_25_self_attn_layer_norm_weight2, model_decoder_layers_25_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv566 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_q_proj_weight2, layer_norm140, model_decoder_layers_25_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape637 = R.call_tir(cls.reshape14, (lv566,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_25_self_attn_k_proj_weight2, layer_norm140), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape638 = R.call_tir(cls.reshape14, (lv123,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv567 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_v_proj_weight2, layer_norm140, model_decoder_layers_25_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape639 = R.call_tir(cls.reshape14, (lv567,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat25 = R.call_tir(cls.concatenate1, (reshape637, reshape638, reshape639), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape640 = R.call_tir(cls.reshape15, (concat25,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv119_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape640), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape641 = R.call_tir(cls.reshape16, (lv119_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape642 = R.call_tir(cls.reshape17, (reshape641,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv568 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_out_proj_weight2, reshape642, model_decoder_layers_25_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add511 = R.call_tir(cls.add5, (add507, lv568), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm141 = R.call_tir(cls.layer_norm2, (add511, model_decoder_layers_25_encoder_attn_layer_norm_weight2, model_decoder_layers_25_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv569 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_q_proj_weight2, layer_norm141, model_decoder_layers_25_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape643 = R.call_tir(cls.reshape14, (lv569,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape644 = R.call_tir(cls.reshape18, (reshape643,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv120_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape644), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape645 = R.call_tir(cls.reshape16, (lv120_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape646 = R.call_tir(cls.reshape17, (reshape645,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv570 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_out_proj_weight2, reshape646, model_decoder_layers_25_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add514 = R.call_tir(cls.add5, (add511, lv570), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm142 = R.call_tir(cls.layer_norm2, (add514, model_decoder_layers_25_final_layer_norm_weight2, model_decoder_layers_25_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv89_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_25_fc1_weight2, layer_norm142, model_decoder_layers_25_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv571 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_25_fc2_weight2, lv89_1, model_decoder_layers_25_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add517 = R.call_tir(cls.add5, (add514, lv571), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm143 = R.call_tir(cls.layer_norm2, (add517, model_decoder_layers_26_self_attn_layer_norm_weight2, model_decoder_layers_26_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv572 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_q_proj_weight2, layer_norm143, model_decoder_layers_26_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape647 = R.call_tir(cls.reshape14, (lv572,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv124 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_26_self_attn_k_proj_weight2, layer_norm143), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape648 = R.call_tir(cls.reshape14, (lv124,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv573 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_v_proj_weight2, layer_norm143, model_decoder_layers_26_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape649 = R.call_tir(cls.reshape14, (lv573,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat26 = R.call_tir(cls.concatenate1, (reshape647, reshape648, reshape649), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape650 = R.call_tir(cls.reshape15, (concat26,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv121_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape650), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape651 = R.call_tir(cls.reshape16, (lv121_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape652 = R.call_tir(cls.reshape17, (reshape651,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv574 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_out_proj_weight2, reshape652, model_decoder_layers_26_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add521 = R.call_tir(cls.add5, (add517, lv574), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm144 = R.call_tir(cls.layer_norm2, (add521, model_decoder_layers_26_encoder_attn_layer_norm_weight2, model_decoder_layers_26_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv575 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight2, layer_norm144, model_decoder_layers_26_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape653 = R.call_tir(cls.reshape14, (lv575,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape654 = R.call_tir(cls.reshape18, (reshape653,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv122_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape654), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape655 = R.call_tir(cls.reshape16, (lv122_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape656 = R.call_tir(cls.reshape17, (reshape655,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv576 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_out_proj_weight2, reshape656, model_decoder_layers_26_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add524 = R.call_tir(cls.add5, (add521, lv576), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm145 = R.call_tir(cls.layer_norm2, (add524, model_decoder_layers_26_final_layer_norm_weight2, model_decoder_layers_26_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv90_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_26_fc1_weight2, layer_norm145, model_decoder_layers_26_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv577 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_26_fc2_weight2, lv90_1, model_decoder_layers_26_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
add527 = R.call_tir(cls.add5, (add524, lv577), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm146 = R.call_tir(cls.layer_norm2, (add527, model_decoder_layers_27_self_attn_layer_norm_weight2, model_decoder_layers_27_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv578 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_q_proj_weight2, layer_norm146, model_decoder_layers_27_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape657 = R.call_tir(cls.reshape14, (lv578,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_27_self_attn_k_proj_weight2, layer_norm146), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape658 = R.call_tir(cls.reshape14, (lv125,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv579 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_v_proj_weight2, layer_norm146, model_decoder_layers_27_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape659 = R.call_tir(cls.reshape14, (lv579,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat27 = R.call_tir(cls.concatenate1, (reshape657, reshape658, reshape659), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape660 = R.call_tir(cls.reshape15, (concat27,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv123_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape660), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape661 = R.call_tir(cls.reshape16, (lv123_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape662 = R.call_tir(cls.reshape17, 
(reshape661,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv580 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_out_proj_weight2, reshape662, model_decoder_layers_27_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add531 = R.call_tir(cls.add5, (add527, lv580), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm147 = R.call_tir(cls.layer_norm2, (add531, model_decoder_layers_27_encoder_attn_layer_norm_weight2, model_decoder_layers_27_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv581 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight2, layer_norm147, model_decoder_layers_27_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape663 = R.call_tir(cls.reshape14, (lv581,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape664 = R.call_tir(cls.reshape18, (reshape663,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv124_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape664), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape665 = R.call_tir(cls.reshape16, (lv124_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape666 = R.call_tir(cls.reshape17, (reshape665,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv582 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight2, reshape666, model_decoder_layers_27_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add534 = R.call_tir(cls.add5, (add531, lv582), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm148 = 
R.call_tir(cls.layer_norm2, (add534, model_decoder_layers_27_final_layer_norm_weight2, model_decoder_layers_27_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv91_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_27_fc1_weight2, layer_norm148, model_decoder_layers_27_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv583 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_27_fc2_weight2, lv91_1, model_decoder_layers_27_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add537 = R.call_tir(cls.add5, (add534, lv583), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm149 = R.call_tir(cls.layer_norm2, (add537, model_decoder_layers_28_self_attn_layer_norm_weight2, model_decoder_layers_28_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv584 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_q_proj_weight2, layer_norm149, model_decoder_layers_28_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape667 = R.call_tir(cls.reshape14, (lv584,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_28_self_attn_k_proj_weight2, layer_norm149), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape668 = R.call_tir(cls.reshape14, (lv126,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv585 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_v_proj_weight2, layer_norm149, model_decoder_layers_28_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape669 = R.call_tir(cls.reshape14, (lv585,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat28 = R.call_tir(cls.concatenate1, (reshape667, reshape668, reshape669), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape670 = R.call_tir(cls.reshape15, (concat28,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv125_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape670), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape671 = R.call_tir(cls.reshape16, (lv125_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape672 = R.call_tir(cls.reshape17, (reshape671,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv586 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_out_proj_weight2, reshape672, model_decoder_layers_28_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add541 = R.call_tir(cls.add5, (add537, lv586), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm150 = R.call_tir(cls.layer_norm2, (add541, model_decoder_layers_28_encoder_attn_layer_norm_weight2, model_decoder_layers_28_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv587 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight2, layer_norm150, model_decoder_layers_28_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape673 = R.call_tir(cls.reshape14, (lv587,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape674 = R.call_tir(cls.reshape18, (reshape673,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv126_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape674), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape675 = R.call_tir(cls.reshape16, (lv126_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape676 = R.call_tir(cls.reshape17, (reshape675,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv588 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight2, reshape676, model_decoder_layers_28_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add544 = R.call_tir(cls.add5, (add541, lv588), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm151 = R.call_tir(cls.layer_norm2, (add544, model_decoder_layers_28_final_layer_norm_weight2, model_decoder_layers_28_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv92_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_28_fc1_weight2, layer_norm151, model_decoder_layers_28_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv589 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_28_fc2_weight2, lv92_1, model_decoder_layers_28_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add547 = R.call_tir(cls.add5, (add544, lv589), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm152 = R.call_tir(cls.layer_norm2, (add547, model_decoder_layers_29_self_attn_layer_norm_weight2, model_decoder_layers_29_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv590 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_q_proj_weight2, layer_norm152, model_decoder_layers_29_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape677 = R.call_tir(cls.reshape14, (lv590,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv127 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_29_self_attn_k_proj_weight2, layer_norm152), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape678 = R.call_tir(cls.reshape14, (lv127,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv591 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_v_proj_weight2, layer_norm152, model_decoder_layers_29_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape679 = R.call_tir(cls.reshape14, (lv591,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat29 = R.call_tir(cls.concatenate1, (reshape677, reshape678, reshape679), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape680 = R.call_tir(cls.reshape15, (concat29,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv127_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape680), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape681 = R.call_tir(cls.reshape16, (lv127_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape682 = R.call_tir(cls.reshape17, (reshape681,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv592 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_out_proj_weight2, reshape682, model_decoder_layers_29_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add551 = R.call_tir(cls.add5, (add547, lv592), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm153 = R.call_tir(cls.layer_norm2, (add551, model_decoder_layers_29_encoder_attn_layer_norm_weight2, model_decoder_layers_29_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv593 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_q_proj_weight2, layer_norm153, model_decoder_layers_29_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape683 = R.call_tir(cls.reshape14, (lv593,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape684 = R.call_tir(cls.reshape18, (reshape683,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv128 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape684), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape685 = R.call_tir(cls.reshape16, (lv128,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape686 = R.call_tir(cls.reshape17, (reshape685,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv594 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight2, reshape686, model_decoder_layers_29_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add554 = R.call_tir(cls.add5, (add551, lv594), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm154 = R.call_tir(cls.layer_norm2, (add554, model_decoder_layers_29_final_layer_norm_weight2, model_decoder_layers_29_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv93_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_29_fc1_weight2, layer_norm154, model_decoder_layers_29_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv595 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_29_fc2_weight2, lv93_1, model_decoder_layers_29_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
add557 = R.call_tir(cls.add5, (add554, lv595), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm155 = R.call_tir(cls.layer_norm2, (add557, model_decoder_layers_30_self_attn_layer_norm_weight2, model_decoder_layers_30_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv596 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_q_proj_weight2, layer_norm155, model_decoder_layers_30_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape687 = R.call_tir(cls.reshape14, (lv596,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv128_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_30_self_attn_k_proj_weight2, layer_norm155), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape688 = R.call_tir(cls.reshape14, (lv128_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv597 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_v_proj_weight2, layer_norm155, model_decoder_layers_30_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape689 = R.call_tir(cls.reshape14, (lv597,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat30 = R.call_tir(cls.concatenate1, (reshape687, reshape688, reshape689), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape690 = R.call_tir(cls.reshape15, (concat30,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv129 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape690), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape691 = R.call_tir(cls.reshape16, (lv129,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape692 = R.call_tir(cls.reshape17, 
(reshape691,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv598 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_out_proj_weight2, reshape692, model_decoder_layers_30_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add561 = R.call_tir(cls.add5, (add557, lv598), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm156 = R.call_tir(cls.layer_norm2, (add561, model_decoder_layers_30_encoder_attn_layer_norm_weight2, model_decoder_layers_30_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv599 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight2, layer_norm156, model_decoder_layers_30_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape693 = R.call_tir(cls.reshape14, (lv599,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape694 = R.call_tir(cls.reshape18, (reshape693,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv130 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape694), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape695 = R.call_tir(cls.reshape16, (lv130,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape696 = R.call_tir(cls.reshape17, (reshape695,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv600 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight2, reshape696, model_decoder_layers_30_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add564 = R.call_tir(cls.add5, (add561, lv600), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm157 = R.call_tir(cls.layer_norm2, 
(add564, model_decoder_layers_30_final_layer_norm_weight2, model_decoder_layers_30_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv94_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_30_fc1_weight2, layer_norm157, model_decoder_layers_30_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv601 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_30_fc2_weight2, lv94_1, model_decoder_layers_30_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add567 = R.call_tir(cls.add5, (add564, lv601), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm158 = R.call_tir(cls.layer_norm2, (add567, model_decoder_layers_31_self_attn_layer_norm_weight2, model_decoder_layers_31_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv602 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_q_proj_weight2, layer_norm158, model_decoder_layers_31_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape697 = R.call_tir(cls.reshape14, (lv602,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv129_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_31_self_attn_k_proj_weight2, layer_norm158), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape698 = R.call_tir(cls.reshape14, (lv129_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv603 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_v_proj_weight2, layer_norm158, model_decoder_layers_31_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape699 = R.call_tir(cls.reshape14, (lv603,), out_sinfo=R.Tensor((1, seq_len, 20, 64), 
dtype="float16")) + concat31 = R.call_tir(cls.concatenate1, (reshape697, reshape698, reshape699), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape700 = R.call_tir(cls.reshape15, (concat31,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv131 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape700), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape701 = R.call_tir(cls.reshape16, (lv131,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape702 = R.call_tir(cls.reshape17, (reshape701,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv604 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_out_proj_weight2, reshape702, model_decoder_layers_31_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add571 = R.call_tir(cls.add5, (add567, lv604), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm159 = R.call_tir(cls.layer_norm2, (add571, model_decoder_layers_31_encoder_attn_layer_norm_weight2, model_decoder_layers_31_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv605 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight2, layer_norm159, model_decoder_layers_31_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape703 = R.call_tir(cls.reshape14, (lv605,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape704 = R.call_tir(cls.reshape18, (reshape703,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv132 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape704), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + 
reshape705 = R.call_tir(cls.reshape16, (lv132,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape706 = R.call_tir(cls.reshape17, (reshape705,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv606 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight2, reshape706, model_decoder_layers_31_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add574 = R.call_tir(cls.add5, (add571, lv606), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm160 = R.call_tir(cls.layer_norm2, (add574, model_decoder_layers_31_final_layer_norm_weight2, model_decoder_layers_31_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv95_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_31_fc1_weight2, layer_norm160, model_decoder_layers_31_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv607 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_31_fc2_weight2, lv95_1, model_decoder_layers_31_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add577 = R.call_tir(cls.add5, (add574, lv607), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm161 = R.call_tir(cls.layer_norm2, (add577, model_decoder_layer_norm_weight2, model_decoder_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + take2 = R.call_tir(cls.take2, (layer_norm161, logit_positions), out_sinfo=R.Tensor((1, batch_size, 1280), dtype="float16")) + gv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul5_cublas", (model_decoder_embed_tokens_weight2, take2), out_sinfo=R.Tensor((1, batch_size, 51866), dtype="float32")) + R.output(gv2) + return gv2 + +
# create_tir_paged_kv_cache: Relax function that builds and returns the paged KV cache object.
#
# Parameters: five R.Shape values, each carrying exactly one symbolic dimension
#   (max_batch_size, max_total_seq_len, prefill_chunk_size, page_size,
#   support_sliding_window). The body re-declares those five names as T.int64()
#   symbolic variables so they can be referenced in the R.shape([...]) argument.
# Returns: the R.Object produced by the packed call below.
#
# Function attributes set via R.func_attr:
#   - "relax.memory_plan_dynamic_func_output": 1
#   - "tir_non_negative_var": ["vocab_size"]
#   - "tir_var_upper_bound": batch_size <= 8, seq_len <= 15000, total_seq_len <= 1500
# NOTE(review): the seq_len bound (15000) is ten times the total_seq_len bound (1500);
#   presumably both come straight from the compile configuration - confirm this is intended.
@R.function + def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), 
max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object: + max_batch_size = T.int64() + max_total_seq_len = T.int64() + prefill_chunk_size = T.int64() + page_size = T.int64() + support_sliding_window = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module +
# Single pure packed call to "vm.builtin.paged_attention_kv_cache_create_reduced". Arguments, in order:
#   1. R.shape([...]) holding the five symbolic sizes declared above;
#   2. seven R.prim_value scalars (32, 20, 20, 64, 0, 1, 1) and R.const(0, "float16");
#      presumably 32 = layer count and 20/20/64 = query heads / KV heads / head dim, which
#      would match the layer indices up to 31 and the (..., 20, 64) reshapes used by the
#      attention calls in this module - TODO confirm against the runtime builtin's signature.
#      The meaning of the remaining scalars is not visible from this file;
#   3. twelve functions of this module passed by reference through cls (= Module).
# Positional order is the contract with the runtime builtin; do not reorder.
paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.prim_value(32), R.prim_value(20), R.prim_value(20), R.prim_value(64), R.prim_value(0), R.prim_value(1), R.prim_value(1), R.const(0, "float16"), cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, sinfo_args=(R.Object,)) + return paged_kv_cache + + @R.function + def decode(input_ids: R.Tensor((1, 1), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, 1, 51866), dtype="float32"): + R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_decoder_embed_tokens_weight5: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] + model_decoder_embed_positions_weight5: R.Tensor((448, 1280), dtype="float16") = packed_params[488] + model_decoder_layers_0_self_attn_k_proj_weight5: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[489] + model_decoder_layers_0_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] + model_decoder_layers_0_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[491] + model_decoder_layers_0_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] + model_decoder_layers_0_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[493] + model_decoder_layers_0_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] + model_decoder_layers_0_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[495] + model_decoder_layers_0_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[496] + model_decoder_layers_0_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[497] + model_decoder_layers_0_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] + model_decoder_layers_0_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[502] + model_decoder_layers_0_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] + model_decoder_layers_0_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[504] + model_decoder_layers_0_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[505] + model_decoder_layers_0_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[506] + model_decoder_layers_0_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] + model_decoder_layers_0_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[508] + model_decoder_layers_0_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] + model_decoder_layers_0_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[510] + 
model_decoder_layers_0_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[511] + model_decoder_layers_0_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[512] + model_decoder_layers_1_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] + model_decoder_layers_1_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] + model_decoder_layers_1_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[515] + model_decoder_layers_1_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] + model_decoder_layers_1_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[517] + model_decoder_layers_1_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] + model_decoder_layers_1_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[519] + model_decoder_layers_1_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[520] + model_decoder_layers_1_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[521] + model_decoder_layers_1_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] + model_decoder_layers_1_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[526] + model_decoder_layers_1_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] + model_decoder_layers_1_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[528] + model_decoder_layers_1_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[529] + model_decoder_layers_1_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[530] + model_decoder_layers_1_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] + 
model_decoder_layers_1_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[532] + model_decoder_layers_1_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] + model_decoder_layers_1_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[534] + model_decoder_layers_1_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[535] + model_decoder_layers_1_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[536] + model_decoder_layers_2_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] + model_decoder_layers_2_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] + model_decoder_layers_2_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[539] + model_decoder_layers_2_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] + model_decoder_layers_2_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[541] + model_decoder_layers_2_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] + model_decoder_layers_2_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[543] + model_decoder_layers_2_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[544] + model_decoder_layers_2_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[545] + model_decoder_layers_2_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] + model_decoder_layers_2_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[550] + model_decoder_layers_2_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] + model_decoder_layers_2_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[552] + model_decoder_layers_2_encoder_attn_layer_norm_weight5: 
R.Tensor((1280,), dtype="float16") = packed_params[553] + model_decoder_layers_2_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[554] + model_decoder_layers_2_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] + model_decoder_layers_2_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[556] + model_decoder_layers_2_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] + model_decoder_layers_2_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[558] + model_decoder_layers_2_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[559] + model_decoder_layers_2_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[560] + model_decoder_layers_3_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] + model_decoder_layers_3_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] + model_decoder_layers_3_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[563] + model_decoder_layers_3_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] + model_decoder_layers_3_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[565] + model_decoder_layers_3_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] + model_decoder_layers_3_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[567] + model_decoder_layers_3_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[568] + model_decoder_layers_3_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[569] + model_decoder_layers_3_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] + model_decoder_layers_3_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[574] + 
model_decoder_layers_3_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] + model_decoder_layers_3_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[576] + model_decoder_layers_3_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[577] + model_decoder_layers_3_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[578] + model_decoder_layers_3_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] + model_decoder_layers_3_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[580] + model_decoder_layers_3_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] + model_decoder_layers_3_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[582] + model_decoder_layers_3_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[583] + model_decoder_layers_3_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[584] + model_decoder_layers_4_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] + model_decoder_layers_4_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] + model_decoder_layers_4_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[587] + model_decoder_layers_4_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] + model_decoder_layers_4_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[589] + model_decoder_layers_4_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] + model_decoder_layers_4_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[591] + model_decoder_layers_4_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[592] + model_decoder_layers_4_self_attn_layer_norm_bias5: R.Tensor((1280,), 
dtype="float16") = packed_params[593] + model_decoder_layers_4_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] + model_decoder_layers_4_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[598] + model_decoder_layers_4_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] + model_decoder_layers_4_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[600] + model_decoder_layers_4_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[601] + model_decoder_layers_4_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[602] + model_decoder_layers_4_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] + model_decoder_layers_4_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[604] + model_decoder_layers_4_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] + model_decoder_layers_4_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[606] + model_decoder_layers_4_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[607] + model_decoder_layers_4_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[608] + model_decoder_layers_5_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] + model_decoder_layers_5_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] + model_decoder_layers_5_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[611] + model_decoder_layers_5_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] + model_decoder_layers_5_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[613] + model_decoder_layers_5_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] + 
model_decoder_layers_5_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[615] + model_decoder_layers_5_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[616] + model_decoder_layers_5_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[617] + model_decoder_layers_5_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[621] + model_decoder_layers_5_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[622] + model_decoder_layers_5_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] + model_decoder_layers_5_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[624] + model_decoder_layers_5_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[625] + model_decoder_layers_5_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[626] + model_decoder_layers_5_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] + model_decoder_layers_5_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[628] + model_decoder_layers_5_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] + model_decoder_layers_5_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[630] + model_decoder_layers_5_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[631] + model_decoder_layers_5_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[632] + model_decoder_layers_6_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] + model_decoder_layers_6_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] + model_decoder_layers_6_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[635] + model_decoder_layers_6_self_attn_q_proj_weight5: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[636] + model_decoder_layers_6_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[637] + model_decoder_layers_6_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] + model_decoder_layers_6_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[639] + model_decoder_layers_6_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[640] + model_decoder_layers_6_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[641] + model_decoder_layers_6_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] + model_decoder_layers_6_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[646] + model_decoder_layers_6_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] + model_decoder_layers_6_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[648] + model_decoder_layers_6_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[649] + model_decoder_layers_6_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[650] + model_decoder_layers_6_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] + model_decoder_layers_6_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[652] + model_decoder_layers_6_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] + model_decoder_layers_6_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[654] + model_decoder_layers_6_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[655] + model_decoder_layers_6_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[656] + model_decoder_layers_7_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] + 
model_decoder_layers_7_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[658] + model_decoder_layers_7_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[659] + model_decoder_layers_7_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] + model_decoder_layers_7_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[661] + model_decoder_layers_7_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] + model_decoder_layers_7_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[663] + model_decoder_layers_7_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[664] + model_decoder_layers_7_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[665] + model_decoder_layers_7_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] + model_decoder_layers_7_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[670] + model_decoder_layers_7_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] + model_decoder_layers_7_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[672] + model_decoder_layers_7_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[673] + model_decoder_layers_7_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[674] + model_decoder_layers_7_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] + model_decoder_layers_7_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[676] + model_decoder_layers_7_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] + model_decoder_layers_7_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[678] + model_decoder_layers_7_final_layer_norm_weight5: R.Tensor((1280,), 
dtype="float16") = packed_params[679] + model_decoder_layers_7_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[680] + model_decoder_layers_8_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] + model_decoder_layers_8_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[682] + model_decoder_layers_8_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[683] + model_decoder_layers_8_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] + model_decoder_layers_8_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[685] + model_decoder_layers_8_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] + model_decoder_layers_8_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[687] + model_decoder_layers_8_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[688] + model_decoder_layers_8_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[689] + model_decoder_layers_8_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] + model_decoder_layers_8_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[694] + model_decoder_layers_8_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] + model_decoder_layers_8_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[696] + model_decoder_layers_8_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[697] + model_decoder_layers_8_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[698] + model_decoder_layers_8_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] + model_decoder_layers_8_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[700] + 
model_decoder_layers_8_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[701] + model_decoder_layers_8_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[702] + model_decoder_layers_8_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[703] + model_decoder_layers_8_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[704] + model_decoder_layers_9_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] + model_decoder_layers_9_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] + model_decoder_layers_9_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[707] + model_decoder_layers_9_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[708] + model_decoder_layers_9_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[709] + model_decoder_layers_9_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] + model_decoder_layers_9_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[711] + model_decoder_layers_9_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[712] + model_decoder_layers_9_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[713] + model_decoder_layers_9_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] + model_decoder_layers_9_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[718] + model_decoder_layers_9_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] + model_decoder_layers_9_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[720] + model_decoder_layers_9_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[721] + 
model_decoder_layers_9_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[722] + model_decoder_layers_9_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] + model_decoder_layers_9_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[724] + model_decoder_layers_9_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] + model_decoder_layers_9_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[726] + model_decoder_layers_9_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[727] + model_decoder_layers_9_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[728] + model_decoder_layers_10_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] + model_decoder_layers_10_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] + model_decoder_layers_10_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[731] + model_decoder_layers_10_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] + model_decoder_layers_10_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[733] + model_decoder_layers_10_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] + model_decoder_layers_10_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[735] + model_decoder_layers_10_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[736] + model_decoder_layers_10_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[737] + model_decoder_layers_10_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] + model_decoder_layers_10_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[742] + model_decoder_layers_10_encoder_attn_out_proj_weight5: R.Tensor((1280, 
1280), dtype="float16") = packed_params[743] + model_decoder_layers_10_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[744] + model_decoder_layers_10_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[745] + model_decoder_layers_10_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[746] + model_decoder_layers_10_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] + model_decoder_layers_10_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[748] + model_decoder_layers_10_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] + model_decoder_layers_10_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[750] + model_decoder_layers_10_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[751] + model_decoder_layers_10_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[752] + model_decoder_layers_11_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] + model_decoder_layers_11_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] + model_decoder_layers_11_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[755] + model_decoder_layers_11_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] + model_decoder_layers_11_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[757] + model_decoder_layers_11_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[758] + model_decoder_layers_11_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[759] + model_decoder_layers_11_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[760] + model_decoder_layers_11_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[761] + 
model_decoder_layers_11_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[765] + model_decoder_layers_11_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[766] + model_decoder_layers_11_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] + model_decoder_layers_11_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[768] + model_decoder_layers_11_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[769] + model_decoder_layers_11_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[770] + model_decoder_layers_11_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] + model_decoder_layers_11_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[772] + model_decoder_layers_11_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] + model_decoder_layers_11_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[774] + model_decoder_layers_11_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[775] + model_decoder_layers_11_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[776] + model_decoder_layers_12_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] + model_decoder_layers_12_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] + model_decoder_layers_12_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[779] + model_decoder_layers_12_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] + model_decoder_layers_12_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[781] + model_decoder_layers_12_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] + model_decoder_layers_12_self_attn_out_proj_bias5: 
R.Tensor((1280,), dtype="float16") = packed_params[783] + model_decoder_layers_12_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[784] + model_decoder_layers_12_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[785] + model_decoder_layers_12_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] + model_decoder_layers_12_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[790] + model_decoder_layers_12_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] + model_decoder_layers_12_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[792] + model_decoder_layers_12_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[793] + model_decoder_layers_12_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[794] + model_decoder_layers_12_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] + model_decoder_layers_12_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[796] + model_decoder_layers_12_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] + model_decoder_layers_12_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[798] + model_decoder_layers_12_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[799] + model_decoder_layers_12_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[800] + model_decoder_layers_13_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[801] + model_decoder_layers_13_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[802] + model_decoder_layers_13_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[803] + model_decoder_layers_13_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[804] + model_decoder_layers_13_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[805] + model_decoder_layers_13_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[806] + model_decoder_layers_13_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[807] + model_decoder_layers_13_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[808] + model_decoder_layers_13_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[809] + model_decoder_layers_13_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[813] + model_decoder_layers_13_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[814] + model_decoder_layers_13_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[815] + model_decoder_layers_13_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[816] + model_decoder_layers_13_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[817] + model_decoder_layers_13_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[818] + model_decoder_layers_13_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[819] + model_decoder_layers_13_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[820] + model_decoder_layers_13_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[821] + model_decoder_layers_13_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[822] + model_decoder_layers_13_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[823] + model_decoder_layers_13_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[824] + model_decoder_layers_14_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[825] + 
model_decoder_layers_14_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[826] + model_decoder_layers_14_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[827] + model_decoder_layers_14_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[828] + model_decoder_layers_14_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[829] + model_decoder_layers_14_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[830] + model_decoder_layers_14_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[831] + model_decoder_layers_14_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[832] + model_decoder_layers_14_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[833] + model_decoder_layers_14_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[837] + model_decoder_layers_14_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[838] + model_decoder_layers_14_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[839] + model_decoder_layers_14_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[840] + model_decoder_layers_14_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[841] + model_decoder_layers_14_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[842] + model_decoder_layers_14_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[843] + model_decoder_layers_14_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[844] + model_decoder_layers_14_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[845] + model_decoder_layers_14_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[846] + 
model_decoder_layers_14_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[847] + model_decoder_layers_14_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[848] + model_decoder_layers_15_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[849] + model_decoder_layers_15_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[850] + model_decoder_layers_15_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[851] + model_decoder_layers_15_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[852] + model_decoder_layers_15_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[853] + model_decoder_layers_15_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[854] + model_decoder_layers_15_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[855] + model_decoder_layers_15_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[856] + model_decoder_layers_15_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[857] + model_decoder_layers_15_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[861] + model_decoder_layers_15_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[862] + model_decoder_layers_15_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[863] + model_decoder_layers_15_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[864] + model_decoder_layers_15_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[865] + model_decoder_layers_15_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[866] + model_decoder_layers_15_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[867] + 
model_decoder_layers_15_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[868] + model_decoder_layers_15_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[869] + model_decoder_layers_15_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[870] + model_decoder_layers_15_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[871] + model_decoder_layers_15_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[872] + model_decoder_layers_16_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[873] + model_decoder_layers_16_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[874] + model_decoder_layers_16_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[875] + model_decoder_layers_16_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[876] + model_decoder_layers_16_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[877] + model_decoder_layers_16_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[878] + model_decoder_layers_16_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[879] + model_decoder_layers_16_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[880] + model_decoder_layers_16_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[881] + model_decoder_layers_16_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[885] + model_decoder_layers_16_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[886] + model_decoder_layers_16_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[887] + model_decoder_layers_16_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[888] + 
model_decoder_layers_16_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[889] + model_decoder_layers_16_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[890] + model_decoder_layers_16_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[891] + model_decoder_layers_16_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[892] + model_decoder_layers_16_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[893] + model_decoder_layers_16_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[894] + model_decoder_layers_16_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[895] + model_decoder_layers_16_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[896] + model_decoder_layers_17_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[897] + model_decoder_layers_17_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[898] + model_decoder_layers_17_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[899] + model_decoder_layers_17_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[900] + model_decoder_layers_17_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[901] + model_decoder_layers_17_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[902] + model_decoder_layers_17_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[903] + model_decoder_layers_17_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[904] + model_decoder_layers_17_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[905] + model_decoder_layers_17_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[909] + model_decoder_layers_17_encoder_attn_q_proj_bias5: 
R.Tensor((1280,), dtype="float16") = packed_params[910] + model_decoder_layers_17_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[911] + model_decoder_layers_17_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[912] + model_decoder_layers_17_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[913] + model_decoder_layers_17_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[914] + model_decoder_layers_17_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[915] + model_decoder_layers_17_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[916] + model_decoder_layers_17_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[917] + model_decoder_layers_17_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[918] + model_decoder_layers_17_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[919] + model_decoder_layers_17_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[920] + model_decoder_layers_18_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[921] + model_decoder_layers_18_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[922] + model_decoder_layers_18_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[923] + model_decoder_layers_18_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[924] + model_decoder_layers_18_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[925] + model_decoder_layers_18_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[926] + model_decoder_layers_18_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[927] + model_decoder_layers_18_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[928] 
+ model_decoder_layers_18_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[929] + model_decoder_layers_18_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[933] + model_decoder_layers_18_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[934] + model_decoder_layers_18_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[935] + model_decoder_layers_18_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[936] + model_decoder_layers_18_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[937] + model_decoder_layers_18_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[938] + model_decoder_layers_18_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[939] + model_decoder_layers_18_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[940] + model_decoder_layers_18_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[941] + model_decoder_layers_18_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[942] + model_decoder_layers_18_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[943] + model_decoder_layers_18_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[944] + model_decoder_layers_19_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[945] + model_decoder_layers_19_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[946] + model_decoder_layers_19_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[947] + model_decoder_layers_19_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[948] + model_decoder_layers_19_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[949] + model_decoder_layers_19_self_attn_out_proj_weight5: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[950] + model_decoder_layers_19_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[951] + model_decoder_layers_19_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[952] + model_decoder_layers_19_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[953] + model_decoder_layers_19_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[957] + model_decoder_layers_19_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[958] + model_decoder_layers_19_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[959] + model_decoder_layers_19_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[960] + model_decoder_layers_19_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[961] + model_decoder_layers_19_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[962] + model_decoder_layers_19_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[963] + model_decoder_layers_19_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[964] + model_decoder_layers_19_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[965] + model_decoder_layers_19_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[966] + model_decoder_layers_19_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[967] + model_decoder_layers_19_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[968] + model_decoder_layers_20_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[969] + model_decoder_layers_20_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[970] + model_decoder_layers_20_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = 
packed_params[971] + model_decoder_layers_20_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[972] + model_decoder_layers_20_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[973] + model_decoder_layers_20_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[974] + model_decoder_layers_20_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[975] + model_decoder_layers_20_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[976] + model_decoder_layers_20_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[977] + model_decoder_layers_20_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[981] + model_decoder_layers_20_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[982] + model_decoder_layers_20_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[983] + model_decoder_layers_20_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[984] + model_decoder_layers_20_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[985] + model_decoder_layers_20_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[986] + model_decoder_layers_20_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[987] + model_decoder_layers_20_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[988] + model_decoder_layers_20_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[989] + model_decoder_layers_20_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[990] + model_decoder_layers_20_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[991] + model_decoder_layers_20_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[992] + 
model_decoder_layers_21_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[993] + model_decoder_layers_21_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[994] + model_decoder_layers_21_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[995] + model_decoder_layers_21_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[996] + model_decoder_layers_21_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[997] + model_decoder_layers_21_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[998] + model_decoder_layers_21_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[999] + model_decoder_layers_21_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1000] + model_decoder_layers_21_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1001] + model_decoder_layers_21_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005] + model_decoder_layers_21_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1006] + model_decoder_layers_21_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007] + model_decoder_layers_21_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1008] + model_decoder_layers_21_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1009] + model_decoder_layers_21_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1010] + model_decoder_layers_21_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011] + model_decoder_layers_21_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1012] + model_decoder_layers_21_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013] + 
model_decoder_layers_21_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1014] + model_decoder_layers_21_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1015] + model_decoder_layers_21_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1016] + model_decoder_layers_22_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017] + model_decoder_layers_22_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018] + model_decoder_layers_22_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1019] + model_decoder_layers_22_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020] + model_decoder_layers_22_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1021] + model_decoder_layers_22_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022] + model_decoder_layers_22_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1023] + model_decoder_layers_22_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1024] + model_decoder_layers_22_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1025] + model_decoder_layers_22_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029] + model_decoder_layers_22_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1030] + model_decoder_layers_22_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031] + model_decoder_layers_22_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1032] + model_decoder_layers_22_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1033] + model_decoder_layers_22_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = 
packed_params[1034] + model_decoder_layers_22_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035] + model_decoder_layers_22_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1036] + model_decoder_layers_22_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037] + model_decoder_layers_22_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1038] + model_decoder_layers_22_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1039] + model_decoder_layers_22_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1040] + model_decoder_layers_23_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041] + model_decoder_layers_23_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042] + model_decoder_layers_23_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1043] + model_decoder_layers_23_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044] + model_decoder_layers_23_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1045] + model_decoder_layers_23_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046] + model_decoder_layers_23_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1047] + model_decoder_layers_23_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1048] + model_decoder_layers_23_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1049] + model_decoder_layers_23_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053] + model_decoder_layers_23_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1054] + model_decoder_layers_23_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055] + 
model_decoder_layers_23_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1056] + model_decoder_layers_23_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1057] + model_decoder_layers_23_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1058] + model_decoder_layers_23_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059] + model_decoder_layers_23_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1060] + model_decoder_layers_23_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061] + model_decoder_layers_23_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1062] + model_decoder_layers_23_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1063] + model_decoder_layers_23_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1064] + model_decoder_layers_24_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065] + model_decoder_layers_24_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066] + model_decoder_layers_24_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1067] + model_decoder_layers_24_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068] + model_decoder_layers_24_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1069] + model_decoder_layers_24_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070] + model_decoder_layers_24_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1071] + model_decoder_layers_24_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1072] + model_decoder_layers_24_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1073] + 
model_decoder_layers_24_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077] + model_decoder_layers_24_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1078] + model_decoder_layers_24_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079] + model_decoder_layers_24_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1080] + model_decoder_layers_24_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1081] + model_decoder_layers_24_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1082] + model_decoder_layers_24_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083] + model_decoder_layers_24_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1084] + model_decoder_layers_24_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085] + model_decoder_layers_24_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1086] + model_decoder_layers_24_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1087] + model_decoder_layers_24_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1088] + model_decoder_layers_25_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089] + model_decoder_layers_25_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090] + model_decoder_layers_25_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1091] + model_decoder_layers_25_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092] + model_decoder_layers_25_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1093] + model_decoder_layers_25_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094] + 
model_decoder_layers_25_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1095] + model_decoder_layers_25_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1096] + model_decoder_layers_25_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1097] + model_decoder_layers_25_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101] + model_decoder_layers_25_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1102] + model_decoder_layers_25_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103] + model_decoder_layers_25_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1104] + model_decoder_layers_25_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1105] + model_decoder_layers_25_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1106] + model_decoder_layers_25_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107] + model_decoder_layers_25_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1108] + model_decoder_layers_25_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109] + model_decoder_layers_25_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1110] + model_decoder_layers_25_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1111] + model_decoder_layers_25_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1112] + model_decoder_layers_26_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113] + model_decoder_layers_26_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114] + model_decoder_layers_26_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1115] + 
model_decoder_layers_26_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116] + model_decoder_layers_26_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1117] + model_decoder_layers_26_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118] + model_decoder_layers_26_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1119] + model_decoder_layers_26_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1120] + model_decoder_layers_26_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1121] + model_decoder_layers_26_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125] + model_decoder_layers_26_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1126] + model_decoder_layers_26_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127] + model_decoder_layers_26_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1128] + model_decoder_layers_26_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1129] + model_decoder_layers_26_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1130] + model_decoder_layers_26_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131] + model_decoder_layers_26_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1132] + model_decoder_layers_26_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133] + model_decoder_layers_26_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1134] + model_decoder_layers_26_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1135] + model_decoder_layers_26_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1136] + 
model_decoder_layers_27_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137] + model_decoder_layers_27_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138] + model_decoder_layers_27_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1139] + model_decoder_layers_27_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140] + model_decoder_layers_27_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1141] + model_decoder_layers_27_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142] + model_decoder_layers_27_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1143] + model_decoder_layers_27_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1144] + model_decoder_layers_27_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1145] + model_decoder_layers_27_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149] + model_decoder_layers_27_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1150] + model_decoder_layers_27_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151] + model_decoder_layers_27_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1152] + model_decoder_layers_27_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1153] + model_decoder_layers_27_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1154] + model_decoder_layers_27_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155] + model_decoder_layers_27_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1156] + model_decoder_layers_27_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157] + 
model_decoder_layers_27_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1158] + model_decoder_layers_27_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1159] + model_decoder_layers_27_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1160] + model_decoder_layers_28_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161] + model_decoder_layers_28_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162] + model_decoder_layers_28_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1163] + model_decoder_layers_28_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164] + model_decoder_layers_28_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1165] + model_decoder_layers_28_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166] + model_decoder_layers_28_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1167] + model_decoder_layers_28_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1168] + model_decoder_layers_28_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1169] + model_decoder_layers_28_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173] + model_decoder_layers_28_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1174] + model_decoder_layers_28_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175] + model_decoder_layers_28_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1176] + model_decoder_layers_28_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1177] + model_decoder_layers_28_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = 
packed_params[1178] + model_decoder_layers_28_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179] + model_decoder_layers_28_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1180] + model_decoder_layers_28_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181] + model_decoder_layers_28_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1182] + model_decoder_layers_28_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1183] + model_decoder_layers_28_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1184] + model_decoder_layers_29_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185] + model_decoder_layers_29_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186] + model_decoder_layers_29_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1187] + model_decoder_layers_29_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188] + model_decoder_layers_29_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1189] + model_decoder_layers_29_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190] + model_decoder_layers_29_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1191] + model_decoder_layers_29_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1192] + model_decoder_layers_29_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1193] + model_decoder_layers_29_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197] + model_decoder_layers_29_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1198] + model_decoder_layers_29_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199] + 
model_decoder_layers_29_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1200] + model_decoder_layers_29_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1201] + model_decoder_layers_29_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1202] + model_decoder_layers_29_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203] + model_decoder_layers_29_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1204] + model_decoder_layers_29_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205] + model_decoder_layers_29_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1206] + model_decoder_layers_29_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1207] + model_decoder_layers_29_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1208] + model_decoder_layers_30_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209] + model_decoder_layers_30_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210] + model_decoder_layers_30_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1211] + model_decoder_layers_30_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212] + model_decoder_layers_30_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1213] + model_decoder_layers_30_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214] + model_decoder_layers_30_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1215] + model_decoder_layers_30_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1216] + model_decoder_layers_30_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1217] + 
model_decoder_layers_30_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221] + model_decoder_layers_30_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1222] + model_decoder_layers_30_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223] + model_decoder_layers_30_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1224] + model_decoder_layers_30_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1225] + model_decoder_layers_30_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1226] + model_decoder_layers_30_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227] + model_decoder_layers_30_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1228] + model_decoder_layers_30_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229] + model_decoder_layers_30_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1230] + model_decoder_layers_30_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1231] + model_decoder_layers_30_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1232] + model_decoder_layers_31_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233] + model_decoder_layers_31_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234] + model_decoder_layers_31_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1235] + model_decoder_layers_31_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236] + model_decoder_layers_31_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1237] + model_decoder_layers_31_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238] + 
model_decoder_layers_31_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1239] + model_decoder_layers_31_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1240] + model_decoder_layers_31_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1241] + model_decoder_layers_31_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245] + model_decoder_layers_31_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1246] + model_decoder_layers_31_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247] + model_decoder_layers_31_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1248] + model_decoder_layers_31_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1249] + model_decoder_layers_31_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1250] + model_decoder_layers_31_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251] + model_decoder_layers_31_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1252] + model_decoder_layers_31_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] + model_decoder_layers_31_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1254] + model_decoder_layers_31_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1255] + model_decoder_layers_31_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1256] + model_decoder_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1257] + model_decoder_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1258] + reshape1353 = R.call_tir(cls.reshape19, (input_ids,), out_sinfo=R.Tensor((1,), dtype="int32")) + take7 = R.call_tir(cls.take3, (model_decoder_embed_tokens_weight5, reshape1353), 
out_sinfo=R.Tensor((1, 1280), dtype="float16")) + lv264: R.Tensor((1,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((1,), dtype="int32"),)) + take8 = R.call_tir(cls.take4, (model_decoder_embed_positions_weight5, lv264), out_sinfo=R.Tensor((1, 1280), dtype="float16")) + lv40 = R.call_tir(cls.fused_reshape20_reshape20_add6, (take7, take8), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm356 = R.call_tir(cls.layer_norm3, (lv40, model_decoder_layers_0_self_attn_layer_norm_weight5, model_decoder_layers_0_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv41 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm356, model_decoder_layers_0_self_attn_q_proj_weight5, model_decoder_layers_0_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv1 = R.call_tir(cls.NT_matmul, (layer_norm356, model_decoder_layers_0_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv42 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm356, model_decoder_layers_0_self_attn_v_proj_weight5, model_decoder_layers_0_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv43 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv41, lv1, lv42), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv265 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), lv43), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv44 = R.call_tir(cls.fused_reshape23_reshape24, (lv265,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv45 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv44, model_decoder_layers_0_self_attn_out_proj_weight5, model_decoder_layers_0_self_attn_out_proj_bias5, lv40), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm357 = R.call_tir(cls.layer_norm3, (lv45, 
model_decoder_layers_0_encoder_attn_layer_norm_weight5, model_decoder_layers_0_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv46 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm357, model_decoder_layers_0_encoder_attn_q_proj_weight5, model_decoder_layers_0_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv47 = R.call_tir(cls.fused_reshape21_reshape25, (lv46,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv266 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), lv47), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv48 = R.call_tir(cls.fused_reshape23_reshape24, (lv266,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv49 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv48, model_decoder_layers_0_encoder_attn_out_proj_weight5, model_decoder_layers_0_encoder_attn_out_proj_bias5, lv45), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm358 = R.call_tir(cls.layer_norm3, (lv49, model_decoder_layers_0_final_layer_norm_weight5, model_decoder_layers_0_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv50 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm358, model_decoder_layers_0_fc1_weight5, model_decoder_layers_0_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv51 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv50, model_decoder_layers_0_fc2_weight5, model_decoder_layers_0_fc2_bias5, lv49), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm359 = R.call_tir(cls.layer_norm3, (lv51, model_decoder_layers_1_self_attn_layer_norm_weight5, model_decoder_layers_1_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv52 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm359, model_decoder_layers_1_self_attn_q_proj_weight5, model_decoder_layers_1_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), 
dtype="float16")) + lv9 = R.call_tir(cls.NT_matmul, (layer_norm359, model_decoder_layers_1_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv53 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm359, model_decoder_layers_1_self_attn_v_proj_weight5, model_decoder_layers_1_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv54 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv52, lv9, lv53), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv267 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), lv54), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv55 = R.call_tir(cls.fused_reshape23_reshape24, (lv267,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv56 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv55, model_decoder_layers_1_self_attn_out_proj_weight5, model_decoder_layers_1_self_attn_out_proj_bias5, lv51), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm360 = R.call_tir(cls.layer_norm3, (lv56, model_decoder_layers_1_encoder_attn_layer_norm_weight5, model_decoder_layers_1_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv57 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm360, model_decoder_layers_1_encoder_attn_q_proj_weight5, model_decoder_layers_1_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv58 = R.call_tir(cls.fused_reshape21_reshape25, (lv57,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv268 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), lv58), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv59 = R.call_tir(cls.fused_reshape23_reshape24, (lv268,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv60 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv59, 
model_decoder_layers_1_encoder_attn_out_proj_weight5, model_decoder_layers_1_encoder_attn_out_proj_bias5, lv56), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm361 = R.call_tir(cls.layer_norm3, (lv60, model_decoder_layers_1_final_layer_norm_weight5, model_decoder_layers_1_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv61 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm361, model_decoder_layers_1_fc1_weight5, model_decoder_layers_1_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv62 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv61, model_decoder_layers_1_fc2_weight5, model_decoder_layers_1_fc2_bias5, lv60), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm362 = R.call_tir(cls.layer_norm3, (lv62, model_decoder_layers_2_self_attn_layer_norm_weight5, model_decoder_layers_2_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv63 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm362, model_decoder_layers_2_self_attn_q_proj_weight5, model_decoder_layers_2_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv17 = R.call_tir(cls.NT_matmul, (layer_norm362, model_decoder_layers_2_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv64 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm362, model_decoder_layers_2_self_attn_v_proj_weight5, model_decoder_layers_2_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv65 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv63, lv17, lv64), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv269 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), lv65), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv66 = R.call_tir(cls.fused_reshape23_reshape24, (lv269,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + 
lv67 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv66, model_decoder_layers_2_self_attn_out_proj_weight5, model_decoder_layers_2_self_attn_out_proj_bias5, lv62), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm363 = R.call_tir(cls.layer_norm3, (lv67, model_decoder_layers_2_encoder_attn_layer_norm_weight5, model_decoder_layers_2_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv68 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm363, model_decoder_layers_2_encoder_attn_q_proj_weight5, model_decoder_layers_2_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv69 = R.call_tir(cls.fused_reshape21_reshape25, (lv68,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv270 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), lv69), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv70 = R.call_tir(cls.fused_reshape23_reshape24, (lv270,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv71 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv70, model_decoder_layers_2_encoder_attn_out_proj_weight5, model_decoder_layers_2_encoder_attn_out_proj_bias5, lv67), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm364 = R.call_tir(cls.layer_norm3, (lv71, model_decoder_layers_2_final_layer_norm_weight5, model_decoder_layers_2_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv72 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm364, model_decoder_layers_2_fc1_weight5, model_decoder_layers_2_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv73 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv72, model_decoder_layers_2_fc2_weight5, model_decoder_layers_2_fc2_bias5, lv71), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm365 = R.call_tir(cls.layer_norm3, (lv73, model_decoder_layers_3_self_attn_layer_norm_weight5, 
model_decoder_layers_3_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv74 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm365, model_decoder_layers_3_self_attn_q_proj_weight5, model_decoder_layers_3_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv25 = R.call_tir(cls.NT_matmul, (layer_norm365, model_decoder_layers_3_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv75 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm365, model_decoder_layers_3_self_attn_v_proj_weight5, model_decoder_layers_3_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv76 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv74, lv25, lv75), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv271 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), lv76), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv77 = R.call_tir(cls.fused_reshape23_reshape24, (lv271,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv78 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv77, model_decoder_layers_3_self_attn_out_proj_weight5, model_decoder_layers_3_self_attn_out_proj_bias5, lv73), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm366 = R.call_tir(cls.layer_norm3, (lv78, model_decoder_layers_3_encoder_attn_layer_norm_weight5, model_decoder_layers_3_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv79 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm366, model_decoder_layers_3_encoder_attn_q_proj_weight5, model_decoder_layers_3_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv80 = R.call_tir(cls.fused_reshape21_reshape25, (lv79,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv272 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, 
R.prim_value(3), R.prim_value(T.float32(1)), lv80), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv81 = R.call_tir(cls.fused_reshape23_reshape24, (lv272,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv82 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv81, model_decoder_layers_3_encoder_attn_out_proj_weight5, model_decoder_layers_3_encoder_attn_out_proj_bias5, lv78), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm367 = R.call_tir(cls.layer_norm3, (lv82, model_decoder_layers_3_final_layer_norm_weight5, model_decoder_layers_3_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv83 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm367, model_decoder_layers_3_fc1_weight5, model_decoder_layers_3_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv84 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv83, model_decoder_layers_3_fc2_weight5, model_decoder_layers_3_fc2_bias5, lv82), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm368 = R.call_tir(cls.layer_norm3, (lv84, model_decoder_layers_4_self_attn_layer_norm_weight5, model_decoder_layers_4_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv85 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm368, model_decoder_layers_4_self_attn_q_proj_weight5, model_decoder_layers_4_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv33 = R.call_tir(cls.NT_matmul, (layer_norm368, model_decoder_layers_4_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv86 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm368, model_decoder_layers_4_self_attn_v_proj_weight5, model_decoder_layers_4_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv87 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv85, lv33, lv86), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv273 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), lv87), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv88 = R.call_tir(cls.fused_reshape23_reshape24, (lv273,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv89 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv88, model_decoder_layers_4_self_attn_out_proj_weight5, model_decoder_layers_4_self_attn_out_proj_bias5, lv84), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm369 = R.call_tir(cls.layer_norm3, (lv89, model_decoder_layers_4_encoder_attn_layer_norm_weight5, model_decoder_layers_4_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv90 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm369, model_decoder_layers_4_encoder_attn_q_proj_weight5, model_decoder_layers_4_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv91 = R.call_tir(cls.fused_reshape21_reshape25, (lv90,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv274 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), lv91), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv92 = R.call_tir(cls.fused_reshape23_reshape24, (lv274,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv93 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv92, model_decoder_layers_4_encoder_attn_out_proj_weight5, model_decoder_layers_4_encoder_attn_out_proj_bias5, lv89), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm370 = R.call_tir(cls.layer_norm3, (lv93, model_decoder_layers_4_final_layer_norm_weight5, model_decoder_layers_4_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv94 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm370, model_decoder_layers_4_fc1_weight5, model_decoder_layers_4_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv95 = 
R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv94, model_decoder_layers_4_fc2_weight5, model_decoder_layers_4_fc2_bias5, lv93), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm371 = R.call_tir(cls.layer_norm3, (lv95, model_decoder_layers_5_self_attn_layer_norm_weight5, model_decoder_layers_5_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv96 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm371, model_decoder_layers_5_self_attn_q_proj_weight5, model_decoder_layers_5_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv41_1 = R.call_tir(cls.NT_matmul, (layer_norm371, model_decoder_layers_5_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv97 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm371, model_decoder_layers_5_self_attn_v_proj_weight5, model_decoder_layers_5_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv98 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv96, lv41_1, lv97), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv275 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), lv98), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv99 = R.call_tir(cls.fused_reshape23_reshape24, (lv275,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv100 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv99, model_decoder_layers_5_self_attn_out_proj_weight5, model_decoder_layers_5_self_attn_out_proj_bias5, lv95), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm372 = R.call_tir(cls.layer_norm3, (lv100, model_decoder_layers_5_encoder_attn_layer_norm_weight5, model_decoder_layers_5_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv101 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm372, model_decoder_layers_5_encoder_attn_q_proj_weight5, 
model_decoder_layers_5_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv102 = R.call_tir(cls.fused_reshape21_reshape25, (lv101,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv276 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), lv102), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv103 = R.call_tir(cls.fused_reshape23_reshape24, (lv276,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv104 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv103, model_decoder_layers_5_encoder_attn_out_proj_weight5, model_decoder_layers_5_encoder_attn_out_proj_bias5, lv100), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm373 = R.call_tir(cls.layer_norm3, (lv104, model_decoder_layers_5_final_layer_norm_weight5, model_decoder_layers_5_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv105 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm373, model_decoder_layers_5_fc1_weight5, model_decoder_layers_5_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv106 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv105, model_decoder_layers_5_fc2_weight5, model_decoder_layers_5_fc2_bias5, lv104), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm374 = R.call_tir(cls.layer_norm3, (lv106, model_decoder_layers_6_self_attn_layer_norm_weight5, model_decoder_layers_6_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv107 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm374, model_decoder_layers_6_self_attn_q_proj_weight5, model_decoder_layers_6_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv49_1 = R.call_tir(cls.NT_matmul, (layer_norm374, model_decoder_layers_6_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv108 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm374, 
model_decoder_layers_6_self_attn_v_proj_weight5, model_decoder_layers_6_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv109 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv107, lv49_1, lv108), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv277 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), lv109), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv110 = R.call_tir(cls.fused_reshape23_reshape24, (lv277,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv111 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv110, model_decoder_layers_6_self_attn_out_proj_weight5, model_decoder_layers_6_self_attn_out_proj_bias5, lv106), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm375 = R.call_tir(cls.layer_norm3, (lv111, model_decoder_layers_6_encoder_attn_layer_norm_weight5, model_decoder_layers_6_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv112 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm375, model_decoder_layers_6_encoder_attn_q_proj_weight5, model_decoder_layers_6_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv113 = R.call_tir(cls.fused_reshape21_reshape25, (lv112,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv278 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), lv113), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv114 = R.call_tir(cls.fused_reshape23_reshape24, (lv278,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv115 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv114, model_decoder_layers_6_encoder_attn_out_proj_weight5, model_decoder_layers_6_encoder_attn_out_proj_bias5, lv111), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm376 = R.call_tir(cls.layer_norm3, (lv115, 
model_decoder_layers_6_final_layer_norm_weight5, model_decoder_layers_6_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv116 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm376, model_decoder_layers_6_fc1_weight5, model_decoder_layers_6_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv117 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv116, model_decoder_layers_6_fc2_weight5, model_decoder_layers_6_fc2_bias5, lv115), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm377 = R.call_tir(cls.layer_norm3, (lv117, model_decoder_layers_7_self_attn_layer_norm_weight5, model_decoder_layers_7_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv118 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm377, model_decoder_layers_7_self_attn_q_proj_weight5, model_decoder_layers_7_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv57_1 = R.call_tir(cls.NT_matmul, (layer_norm377, model_decoder_layers_7_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv119 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm377, model_decoder_layers_7_self_attn_v_proj_weight5, model_decoder_layers_7_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv120 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv118, lv57_1, lv119), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv279 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), lv120), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv121 = R.call_tir(cls.fused_reshape23_reshape24, (lv279,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv122 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv121, model_decoder_layers_7_self_attn_out_proj_weight5, model_decoder_layers_7_self_attn_out_proj_bias5, lv117), out_sinfo=R.Tensor((1, 1, 1280), 
dtype="float16")) + layer_norm378 = R.call_tir(cls.layer_norm3, (lv122, model_decoder_layers_7_encoder_attn_layer_norm_weight5, model_decoder_layers_7_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv123 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm378, model_decoder_layers_7_encoder_attn_q_proj_weight5, model_decoder_layers_7_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv124 = R.call_tir(cls.fused_reshape21_reshape25, (lv123,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv280 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), lv124), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv125 = R.call_tir(cls.fused_reshape23_reshape24, (lv280,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv126 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv125, model_decoder_layers_7_encoder_attn_out_proj_weight5, model_decoder_layers_7_encoder_attn_out_proj_bias5, lv122), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm379 = R.call_tir(cls.layer_norm3, (lv126, model_decoder_layers_7_final_layer_norm_weight5, model_decoder_layers_7_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv127 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm379, model_decoder_layers_7_fc1_weight5, model_decoder_layers_7_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv128 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv127, model_decoder_layers_7_fc2_weight5, model_decoder_layers_7_fc2_bias5, lv126), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm380 = R.call_tir(cls.layer_norm3, (lv128, model_decoder_layers_8_self_attn_layer_norm_weight5, model_decoder_layers_8_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv129 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm380, 
model_decoder_layers_8_self_attn_q_proj_weight5, model_decoder_layers_8_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv65_1 = R.call_tir(cls.NT_matmul, (layer_norm380, model_decoder_layers_8_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv130 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm380, model_decoder_layers_8_self_attn_v_proj_weight5, model_decoder_layers_8_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv131 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv129, lv65_1, lv130), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv281 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), lv131), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv132 = R.call_tir(cls.fused_reshape23_reshape24, (lv281,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv133 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv132, model_decoder_layers_8_self_attn_out_proj_weight5, model_decoder_layers_8_self_attn_out_proj_bias5, lv128), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm381 = R.call_tir(cls.layer_norm3, (lv133, model_decoder_layers_8_encoder_attn_layer_norm_weight5, model_decoder_layers_8_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv134 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm381, model_decoder_layers_8_encoder_attn_q_proj_weight5, model_decoder_layers_8_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv135 = R.call_tir(cls.fused_reshape21_reshape25, (lv134,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv282 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), lv135), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv136 = 
R.call_tir(cls.fused_reshape23_reshape24, (lv282,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv137 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv136, model_decoder_layers_8_encoder_attn_out_proj_weight5, model_decoder_layers_8_encoder_attn_out_proj_bias5, lv133), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm382 = R.call_tir(cls.layer_norm3, (lv137, model_decoder_layers_8_final_layer_norm_weight5, model_decoder_layers_8_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv138 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm382, model_decoder_layers_8_fc1_weight5, model_decoder_layers_8_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv139 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv138, model_decoder_layers_8_fc2_weight5, model_decoder_layers_8_fc2_bias5, lv137), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm383 = R.call_tir(cls.layer_norm3, (lv139, model_decoder_layers_9_self_attn_layer_norm_weight5, model_decoder_layers_9_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv140 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm383, model_decoder_layers_9_self_attn_q_proj_weight5, model_decoder_layers_9_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv73_1 = R.call_tir(cls.NT_matmul, (layer_norm383, model_decoder_layers_9_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv141 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm383, model_decoder_layers_9_self_attn_v_proj_weight5, model_decoder_layers_9_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv142 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv140, lv73_1, lv141), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv283 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), 
R.prim_value(T.float32(1)), lv142), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv143 = R.call_tir(cls.fused_reshape23_reshape24, (lv283,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv144 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv143, model_decoder_layers_9_self_attn_out_proj_weight5, model_decoder_layers_9_self_attn_out_proj_bias5, lv139), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm384 = R.call_tir(cls.layer_norm3, (lv144, model_decoder_layers_9_encoder_attn_layer_norm_weight5, model_decoder_layers_9_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv145 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm384, model_decoder_layers_9_encoder_attn_q_proj_weight5, model_decoder_layers_9_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv146 = R.call_tir(cls.fused_reshape21_reshape25, (lv145,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), lv146), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv147 = R.call_tir(cls.fused_reshape23_reshape24, (lv284,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv148 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv147, model_decoder_layers_9_encoder_attn_out_proj_weight5, model_decoder_layers_9_encoder_attn_out_proj_bias5, lv144), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm385 = R.call_tir(cls.layer_norm3, (lv148, model_decoder_layers_9_final_layer_norm_weight5, model_decoder_layers_9_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv149 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm385, model_decoder_layers_9_fc1_weight5, model_decoder_layers_9_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv150 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv149, model_decoder_layers_9_fc2_weight5, 
model_decoder_layers_9_fc2_bias5, lv148), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm386 = R.call_tir(cls.layer_norm3, (lv150, model_decoder_layers_10_self_attn_layer_norm_weight5, model_decoder_layers_10_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv151 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm386, model_decoder_layers_10_self_attn_q_proj_weight5, model_decoder_layers_10_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv81_1 = R.call_tir(cls.NT_matmul, (layer_norm386, model_decoder_layers_10_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv152 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm386, model_decoder_layers_10_self_attn_v_proj_weight5, model_decoder_layers_10_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv153 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv151, lv81_1, lv152), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv285 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), lv153), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv154 = R.call_tir(cls.fused_reshape23_reshape24, (lv285,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv155 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv154, model_decoder_layers_10_self_attn_out_proj_weight5, model_decoder_layers_10_self_attn_out_proj_bias5, lv150), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm387 = R.call_tir(cls.layer_norm3, (lv155, model_decoder_layers_10_encoder_attn_layer_norm_weight5, model_decoder_layers_10_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv156 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm387, model_decoder_layers_10_encoder_attn_q_proj_weight5, model_decoder_layers_10_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 
1280), dtype="float16")) + lv157 = R.call_tir(cls.fused_reshape21_reshape25, (lv156,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv286 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), lv157), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv158 = R.call_tir(cls.fused_reshape23_reshape24, (lv286,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv159 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv158, model_decoder_layers_10_encoder_attn_out_proj_weight5, model_decoder_layers_10_encoder_attn_out_proj_bias5, lv155), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm388 = R.call_tir(cls.layer_norm3, (lv159, model_decoder_layers_10_final_layer_norm_weight5, model_decoder_layers_10_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv160 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm388, model_decoder_layers_10_fc1_weight5, model_decoder_layers_10_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv161 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv160, model_decoder_layers_10_fc2_weight5, model_decoder_layers_10_fc2_bias5, lv159), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm389 = R.call_tir(cls.layer_norm3, (lv161, model_decoder_layers_11_self_attn_layer_norm_weight5, model_decoder_layers_11_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv162 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm389, model_decoder_layers_11_self_attn_q_proj_weight5, model_decoder_layers_11_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv89_1 = R.call_tir(cls.NT_matmul, (layer_norm389, model_decoder_layers_11_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv163 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm389, model_decoder_layers_11_self_attn_v_proj_weight5, 
model_decoder_layers_11_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv164 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv162, lv89_1, lv163), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv287 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), lv164), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv165 = R.call_tir(cls.fused_reshape23_reshape24, (lv287,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv166 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv165, model_decoder_layers_11_self_attn_out_proj_weight5, model_decoder_layers_11_self_attn_out_proj_bias5, lv161), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm390 = R.call_tir(cls.layer_norm3, (lv166, model_decoder_layers_11_encoder_attn_layer_norm_weight5, model_decoder_layers_11_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv167 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm390, model_decoder_layers_11_encoder_attn_q_proj_weight5, model_decoder_layers_11_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv168 = R.call_tir(cls.fused_reshape21_reshape25, (lv167,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv288 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), lv168), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv169 = R.call_tir(cls.fused_reshape23_reshape24, (lv288,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv170 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv169, model_decoder_layers_11_encoder_attn_out_proj_weight5, model_decoder_layers_11_encoder_attn_out_proj_bias5, lv166), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm391 = R.call_tir(cls.layer_norm3, (lv170, model_decoder_layers_11_final_layer_norm_weight5, 
model_decoder_layers_11_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv171 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm391, model_decoder_layers_11_fc1_weight5, model_decoder_layers_11_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv172 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv171, model_decoder_layers_11_fc2_weight5, model_decoder_layers_11_fc2_bias5, lv170), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm392 = R.call_tir(cls.layer_norm3, (lv172, model_decoder_layers_12_self_attn_layer_norm_weight5, model_decoder_layers_12_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv173 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm392, model_decoder_layers_12_self_attn_q_proj_weight5, model_decoder_layers_12_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv97_1 = R.call_tir(cls.NT_matmul, (layer_norm392, model_decoder_layers_12_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv174 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm392, model_decoder_layers_12_self_attn_v_proj_weight5, model_decoder_layers_12_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv175 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv173, lv97_1, lv174), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv289 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), lv175), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv176 = R.call_tir(cls.fused_reshape23_reshape24, (lv289,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv177 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv176, model_decoder_layers_12_self_attn_out_proj_weight5, model_decoder_layers_12_self_attn_out_proj_bias5, lv172), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm393 = 
R.call_tir(cls.layer_norm3, (lv177, model_decoder_layers_12_encoder_attn_layer_norm_weight5, model_decoder_layers_12_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv178 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm393, model_decoder_layers_12_encoder_attn_q_proj_weight5, model_decoder_layers_12_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv179 = R.call_tir(cls.fused_reshape21_reshape25, (lv178,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv290 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), lv179), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv180 = R.call_tir(cls.fused_reshape23_reshape24, (lv290,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv181 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv180, model_decoder_layers_12_encoder_attn_out_proj_weight5, model_decoder_layers_12_encoder_attn_out_proj_bias5, lv177), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm394 = R.call_tir(cls.layer_norm3, (lv181, model_decoder_layers_12_final_layer_norm_weight5, model_decoder_layers_12_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv182 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm394, model_decoder_layers_12_fc1_weight5, model_decoder_layers_12_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv183 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv182, model_decoder_layers_12_fc2_weight5, model_decoder_layers_12_fc2_bias5, lv181), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm395 = R.call_tir(cls.layer_norm3, (lv183, model_decoder_layers_13_self_attn_layer_norm_weight5, model_decoder_layers_13_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv184 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm395, model_decoder_layers_13_self_attn_q_proj_weight5, 
model_decoder_layers_13_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv105_1 = R.call_tir(cls.NT_matmul, (layer_norm395, model_decoder_layers_13_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv185 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm395, model_decoder_layers_13_self_attn_v_proj_weight5, model_decoder_layers_13_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv186 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv184, lv105_1, lv185), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv291 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), lv186), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv187 = R.call_tir(cls.fused_reshape23_reshape24, (lv291,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv188 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv187, model_decoder_layers_13_self_attn_out_proj_weight5, model_decoder_layers_13_self_attn_out_proj_bias5, lv183), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm396 = R.call_tir(cls.layer_norm3, (lv188, model_decoder_layers_13_encoder_attn_layer_norm_weight5, model_decoder_layers_13_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv189 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm396, model_decoder_layers_13_encoder_attn_q_proj_weight5, model_decoder_layers_13_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv190 = R.call_tir(cls.fused_reshape21_reshape25, (lv189,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv292 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), lv190), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv191 = R.call_tir(cls.fused_reshape23_reshape24, (lv292,), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv192 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv191, model_decoder_layers_13_encoder_attn_out_proj_weight5, model_decoder_layers_13_encoder_attn_out_proj_bias5, lv188), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm397 = R.call_tir(cls.layer_norm3, (lv192, model_decoder_layers_13_final_layer_norm_weight5, model_decoder_layers_13_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv193 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm397, model_decoder_layers_13_fc1_weight5, model_decoder_layers_13_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv194 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv193, model_decoder_layers_13_fc2_weight5, model_decoder_layers_13_fc2_bias5, lv192), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm398 = R.call_tir(cls.layer_norm3, (lv194, model_decoder_layers_14_self_attn_layer_norm_weight5, model_decoder_layers_14_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv195 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm398, model_decoder_layers_14_self_attn_q_proj_weight5, model_decoder_layers_14_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv113_1 = R.call_tir(cls.NT_matmul, (layer_norm398, model_decoder_layers_14_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv196 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm398, model_decoder_layers_14_self_attn_v_proj_weight5, model_decoder_layers_14_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv197 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv195, lv113_1, lv196), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv293 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), lv197), 
out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv198 = R.call_tir(cls.fused_reshape23_reshape24, (lv293,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv199 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv198, model_decoder_layers_14_self_attn_out_proj_weight5, model_decoder_layers_14_self_attn_out_proj_bias5, lv194), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm399 = R.call_tir(cls.layer_norm3, (lv199, model_decoder_layers_14_encoder_attn_layer_norm_weight5, model_decoder_layers_14_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv200 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm399, model_decoder_layers_14_encoder_attn_q_proj_weight5, model_decoder_layers_14_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv201 = R.call_tir(cls.fused_reshape21_reshape25, (lv200,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv294 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), lv201), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv202 = R.call_tir(cls.fused_reshape23_reshape24, (lv294,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv203 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv202, model_decoder_layers_14_encoder_attn_out_proj_weight5, model_decoder_layers_14_encoder_attn_out_proj_bias5, lv199), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm400 = R.call_tir(cls.layer_norm3, (lv203, model_decoder_layers_14_final_layer_norm_weight5, model_decoder_layers_14_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv204 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm400, model_decoder_layers_14_fc1_weight5, model_decoder_layers_14_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv205 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv204, model_decoder_layers_14_fc2_weight5, 
model_decoder_layers_14_fc2_bias5, lv203), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm401 = R.call_tir(cls.layer_norm3, (lv205, model_decoder_layers_15_self_attn_layer_norm_weight5, model_decoder_layers_15_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv206 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm401, model_decoder_layers_15_self_attn_q_proj_weight5, model_decoder_layers_15_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv121_1 = R.call_tir(cls.NT_matmul, (layer_norm401, model_decoder_layers_15_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv207 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm401, model_decoder_layers_15_self_attn_v_proj_weight5, model_decoder_layers_15_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv208 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv206, lv121_1, lv207), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv295 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), lv208), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv209 = R.call_tir(cls.fused_reshape23_reshape24, (lv295,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv210 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv209, model_decoder_layers_15_self_attn_out_proj_weight5, model_decoder_layers_15_self_attn_out_proj_bias5, lv205), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm402 = R.call_tir(cls.layer_norm3, (lv210, model_decoder_layers_15_encoder_attn_layer_norm_weight5, model_decoder_layers_15_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv211 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm402, model_decoder_layers_15_encoder_attn_q_proj_weight5, model_decoder_layers_15_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 
1, 1280), dtype="float16")) + lv212 = R.call_tir(cls.fused_reshape21_reshape25, (lv211,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv296 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), lv212), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv213 = R.call_tir(cls.fused_reshape23_reshape24, (lv296,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv214 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv213, model_decoder_layers_15_encoder_attn_out_proj_weight5, model_decoder_layers_15_encoder_attn_out_proj_bias5, lv210), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm403 = R.call_tir(cls.layer_norm3, (lv214, model_decoder_layers_15_final_layer_norm_weight5, model_decoder_layers_15_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv215 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm403, model_decoder_layers_15_fc1_weight5, model_decoder_layers_15_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv216 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv215, model_decoder_layers_15_fc2_weight5, model_decoder_layers_15_fc2_bias5, lv214), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm404 = R.call_tir(cls.layer_norm3, (lv216, model_decoder_layers_16_self_attn_layer_norm_weight5, model_decoder_layers_16_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv217 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm404, model_decoder_layers_16_self_attn_q_proj_weight5, model_decoder_layers_16_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv129_1 = R.call_tir(cls.NT_matmul, (layer_norm404, model_decoder_layers_16_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv218 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm404, model_decoder_layers_16_self_attn_v_proj_weight5, 
model_decoder_layers_16_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv219 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv217, lv129_1, lv218), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv297 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), lv219), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv220 = R.call_tir(cls.fused_reshape23_reshape24, (lv297,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv221 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv220, model_decoder_layers_16_self_attn_out_proj_weight5, model_decoder_layers_16_self_attn_out_proj_bias5, lv216), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm405 = R.call_tir(cls.layer_norm3, (lv221, model_decoder_layers_16_encoder_attn_layer_norm_weight5, model_decoder_layers_16_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv222 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm405, model_decoder_layers_16_encoder_attn_q_proj_weight5, model_decoder_layers_16_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv223 = R.call_tir(cls.fused_reshape21_reshape25, (lv222,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv298 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), lv223), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv224 = R.call_tir(cls.fused_reshape23_reshape24, (lv298,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv225 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv224, model_decoder_layers_16_encoder_attn_out_proj_weight5, model_decoder_layers_16_encoder_attn_out_proj_bias5, lv221), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm406 = R.call_tir(cls.layer_norm3, (lv225, model_decoder_layers_16_final_layer_norm_weight5, 
model_decoder_layers_16_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv226 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm406, model_decoder_layers_16_fc1_weight5, model_decoder_layers_16_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv227 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv226, model_decoder_layers_16_fc2_weight5, model_decoder_layers_16_fc2_bias5, lv225), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm407 = R.call_tir(cls.layer_norm3, (lv227, model_decoder_layers_17_self_attn_layer_norm_weight5, model_decoder_layers_17_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv228 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm407, model_decoder_layers_17_self_attn_q_proj_weight5, model_decoder_layers_17_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv137_1 = R.call_tir(cls.NT_matmul, (layer_norm407, model_decoder_layers_17_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv229 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm407, model_decoder_layers_17_self_attn_v_proj_weight5, model_decoder_layers_17_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv230 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv228, lv137_1, lv229), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv299 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), lv230), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv231 = R.call_tir(cls.fused_reshape23_reshape24, (lv299,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv232 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv231, model_decoder_layers_17_self_attn_out_proj_weight5, model_decoder_layers_17_self_attn_out_proj_bias5, lv227), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm408 = 
R.call_tir(cls.layer_norm3, (lv232, model_decoder_layers_17_encoder_attn_layer_norm_weight5, model_decoder_layers_17_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv233 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm408, model_decoder_layers_17_encoder_attn_q_proj_weight5, model_decoder_layers_17_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv234 = R.call_tir(cls.fused_reshape21_reshape25, (lv233,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv300 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), lv234), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv235 = R.call_tir(cls.fused_reshape23_reshape24, (lv300,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv236 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv235, model_decoder_layers_17_encoder_attn_out_proj_weight5, model_decoder_layers_17_encoder_attn_out_proj_bias5, lv232), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm409 = R.call_tir(cls.layer_norm3, (lv236, model_decoder_layers_17_final_layer_norm_weight5, model_decoder_layers_17_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv237 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm409, model_decoder_layers_17_fc1_weight5, model_decoder_layers_17_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv238 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv237, model_decoder_layers_17_fc2_weight5, model_decoder_layers_17_fc2_bias5, lv236), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm410 = R.call_tir(cls.layer_norm3, (lv238, model_decoder_layers_18_self_attn_layer_norm_weight5, model_decoder_layers_18_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv239 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm410, model_decoder_layers_18_self_attn_q_proj_weight5, 
model_decoder_layers_18_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv145_1 = R.call_tir(cls.NT_matmul, (layer_norm410, model_decoder_layers_18_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv240 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm410, model_decoder_layers_18_self_attn_v_proj_weight5, model_decoder_layers_18_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv241 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv239, lv145_1, lv240), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv301 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), lv241), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv242 = R.call_tir(cls.fused_reshape23_reshape24, (lv301,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv243 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv242, model_decoder_layers_18_self_attn_out_proj_weight5, model_decoder_layers_18_self_attn_out_proj_bias5, lv238), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm411 = R.call_tir(cls.layer_norm3, (lv243, model_decoder_layers_18_encoder_attn_layer_norm_weight5, model_decoder_layers_18_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv244 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm411, model_decoder_layers_18_encoder_attn_q_proj_weight5, model_decoder_layers_18_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv245 = R.call_tir(cls.fused_reshape21_reshape25, (lv244,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv302 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), lv245), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv246 = R.call_tir(cls.fused_reshape23_reshape24, (lv302,), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv247 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv246, model_decoder_layers_18_encoder_attn_out_proj_weight5, model_decoder_layers_18_encoder_attn_out_proj_bias5, lv243), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm412 = R.call_tir(cls.layer_norm3, (lv247, model_decoder_layers_18_final_layer_norm_weight5, model_decoder_layers_18_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv248 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm412, model_decoder_layers_18_fc1_weight5, model_decoder_layers_18_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv249 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv248, model_decoder_layers_18_fc2_weight5, model_decoder_layers_18_fc2_bias5, lv247), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm413 = R.call_tir(cls.layer_norm3, (lv249, model_decoder_layers_19_self_attn_layer_norm_weight5, model_decoder_layers_19_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv250 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm413, model_decoder_layers_19_self_attn_q_proj_weight5, model_decoder_layers_19_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv153_1 = R.call_tir(cls.NT_matmul, (layer_norm413, model_decoder_layers_19_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv251 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm413, model_decoder_layers_19_self_attn_v_proj_weight5, model_decoder_layers_19_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv252 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv250, lv153_1, lv251), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv303 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), lv252), 
out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv253 = R.call_tir(cls.fused_reshape23_reshape24, (lv303,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv254 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv253, model_decoder_layers_19_self_attn_out_proj_weight5, model_decoder_layers_19_self_attn_out_proj_bias5, lv249), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm414 = R.call_tir(cls.layer_norm3, (lv254, model_decoder_layers_19_encoder_attn_layer_norm_weight5, model_decoder_layers_19_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv255 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm414, model_decoder_layers_19_encoder_attn_q_proj_weight5, model_decoder_layers_19_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv256 = R.call_tir(cls.fused_reshape21_reshape25, (lv255,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv304 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), lv256), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv257 = R.call_tir(cls.fused_reshape23_reshape24, (lv304,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv258 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv257, model_decoder_layers_19_encoder_attn_out_proj_weight5, model_decoder_layers_19_encoder_attn_out_proj_bias5, lv254), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm415 = R.call_tir(cls.layer_norm3, (lv258, model_decoder_layers_19_final_layer_norm_weight5, model_decoder_layers_19_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv259 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm415, model_decoder_layers_19_fc1_weight5, model_decoder_layers_19_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv260 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv259, model_decoder_layers_19_fc2_weight5, 
model_decoder_layers_19_fc2_bias5, lv258), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm416 = R.call_tir(cls.layer_norm3, (lv260, model_decoder_layers_20_self_attn_layer_norm_weight5, model_decoder_layers_20_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv261 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm416, model_decoder_layers_20_self_attn_q_proj_weight5, model_decoder_layers_20_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv161_1 = R.call_tir(cls.NT_matmul, (layer_norm416, model_decoder_layers_20_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv262 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm416, model_decoder_layers_20_self_attn_v_proj_weight5, model_decoder_layers_20_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv263 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv261, lv161_1, lv262), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv305 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), lv263), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv264_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv305,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv265_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv264_1, model_decoder_layers_20_self_attn_out_proj_weight5, model_decoder_layers_20_self_attn_out_proj_bias5, lv260), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm417 = R.call_tir(cls.layer_norm3, (lv265_1, model_decoder_layers_20_encoder_attn_layer_norm_weight5, model_decoder_layers_20_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv266_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm417, model_decoder_layers_20_encoder_attn_q_proj_weight5, model_decoder_layers_20_encoder_attn_q_proj_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv267_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv266_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv306 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), lv267_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv268_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv306,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv269_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv268_1, model_decoder_layers_20_encoder_attn_out_proj_weight5, model_decoder_layers_20_encoder_attn_out_proj_bias5, lv265_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm418 = R.call_tir(cls.layer_norm3, (lv269_1, model_decoder_layers_20_final_layer_norm_weight5, model_decoder_layers_20_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv270_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm418, model_decoder_layers_20_fc1_weight5, model_decoder_layers_20_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv271_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv270_1, model_decoder_layers_20_fc2_weight5, model_decoder_layers_20_fc2_bias5, lv269_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm419 = R.call_tir(cls.layer_norm3, (lv271_1, model_decoder_layers_21_self_attn_layer_norm_weight5, model_decoder_layers_21_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv272_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm419, model_decoder_layers_21_self_attn_q_proj_weight5, model_decoder_layers_21_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv169_1 = R.call_tir(cls.NT_matmul, (layer_norm419, model_decoder_layers_21_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv273_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm419, 
model_decoder_layers_21_self_attn_v_proj_weight5, model_decoder_layers_21_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv274_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv272_1, lv169_1, lv273_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv307 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), lv274_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv275_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv307,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv276_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv275_1, model_decoder_layers_21_self_attn_out_proj_weight5, model_decoder_layers_21_self_attn_out_proj_bias5, lv271_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm420 = R.call_tir(cls.layer_norm3, (lv276_1, model_decoder_layers_21_encoder_attn_layer_norm_weight5, model_decoder_layers_21_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv277_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm420, model_decoder_layers_21_encoder_attn_q_proj_weight5, model_decoder_layers_21_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv278_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv277_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv308 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), lv278_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv279_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv308,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv280_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv279_1, model_decoder_layers_21_encoder_attn_out_proj_weight5, model_decoder_layers_21_encoder_attn_out_proj_bias5, lv276_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm421 = 
R.call_tir(cls.layer_norm3, (lv280_1, model_decoder_layers_21_final_layer_norm_weight5, model_decoder_layers_21_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv281_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm421, model_decoder_layers_21_fc1_weight5, model_decoder_layers_21_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv282_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv281_1, model_decoder_layers_21_fc2_weight5, model_decoder_layers_21_fc2_bias5, lv280_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm422 = R.call_tir(cls.layer_norm3, (lv282_1, model_decoder_layers_22_self_attn_layer_norm_weight5, model_decoder_layers_22_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv283_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm422, model_decoder_layers_22_self_attn_q_proj_weight5, model_decoder_layers_22_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv177_1 = R.call_tir(cls.NT_matmul, (layer_norm422, model_decoder_layers_22_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv284_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm422, model_decoder_layers_22_self_attn_v_proj_weight5, model_decoder_layers_22_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv285_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv283_1, lv177_1, lv284_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv309 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), lv285_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv286_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv309,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv287_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv286_1, model_decoder_layers_22_self_attn_out_proj_weight5, 
model_decoder_layers_22_self_attn_out_proj_bias5, lv282_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm423 = R.call_tir(cls.layer_norm3, (lv287_1, model_decoder_layers_22_encoder_attn_layer_norm_weight5, model_decoder_layers_22_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv288_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm423, model_decoder_layers_22_encoder_attn_q_proj_weight5, model_decoder_layers_22_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv289_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv288_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv310 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), lv289_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv290_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv310,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv291_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv290_1, model_decoder_layers_22_encoder_attn_out_proj_weight5, model_decoder_layers_22_encoder_attn_out_proj_bias5, lv287_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm424 = R.call_tir(cls.layer_norm3, (lv291_1, model_decoder_layers_22_final_layer_norm_weight5, model_decoder_layers_22_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv292_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm424, model_decoder_layers_22_fc1_weight5, model_decoder_layers_22_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv293_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv292_1, model_decoder_layers_22_fc2_weight5, model_decoder_layers_22_fc2_bias5, lv291_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm425 = R.call_tir(cls.layer_norm3, (lv293_1, model_decoder_layers_23_self_attn_layer_norm_weight5, model_decoder_layers_23_self_attn_layer_norm_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv294_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm425, model_decoder_layers_23_self_attn_q_proj_weight5, model_decoder_layers_23_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv185_1 = R.call_tir(cls.NT_matmul, (layer_norm425, model_decoder_layers_23_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv295_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm425, model_decoder_layers_23_self_attn_v_proj_weight5, model_decoder_layers_23_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv296_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv294_1, lv185_1, lv295_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv311 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), lv296_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv297_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv311,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv298_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv297_1, model_decoder_layers_23_self_attn_out_proj_weight5, model_decoder_layers_23_self_attn_out_proj_bias5, lv293_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm426 = R.call_tir(cls.layer_norm3, (lv298_1, model_decoder_layers_23_encoder_attn_layer_norm_weight5, model_decoder_layers_23_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv299_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm426, model_decoder_layers_23_encoder_attn_q_proj_weight5, model_decoder_layers_23_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv300_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv299_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv312 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", 
(paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), lv300_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv301_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv312,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv302_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv301_1, model_decoder_layers_23_encoder_attn_out_proj_weight5, model_decoder_layers_23_encoder_attn_out_proj_bias5, lv298_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm427 = R.call_tir(cls.layer_norm3, (lv302_1, model_decoder_layers_23_final_layer_norm_weight5, model_decoder_layers_23_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv303_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm427, model_decoder_layers_23_fc1_weight5, model_decoder_layers_23_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv304_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv303_1, model_decoder_layers_23_fc2_weight5, model_decoder_layers_23_fc2_bias5, lv302_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm428 = R.call_tir(cls.layer_norm3, (lv304_1, model_decoder_layers_24_self_attn_layer_norm_weight5, model_decoder_layers_24_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv305_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm428, model_decoder_layers_24_self_attn_q_proj_weight5, model_decoder_layers_24_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv193_1 = R.call_tir(cls.NT_matmul, (layer_norm428, model_decoder_layers_24_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv306_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm428, model_decoder_layers_24_self_attn_v_proj_weight5, model_decoder_layers_24_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv307_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv305_1, lv193_1, lv306_1), 
out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv313 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), lv307_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv308_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv313,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv309_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv308_1, model_decoder_layers_24_self_attn_out_proj_weight5, model_decoder_layers_24_self_attn_out_proj_bias5, lv304_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm429 = R.call_tir(cls.layer_norm3, (lv309_1, model_decoder_layers_24_encoder_attn_layer_norm_weight5, model_decoder_layers_24_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv310_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm429, model_decoder_layers_24_encoder_attn_q_proj_weight5, model_decoder_layers_24_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv311_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv310_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv314 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), lv311_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv312_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv314,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv313_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv312_1, model_decoder_layers_24_encoder_attn_out_proj_weight5, model_decoder_layers_24_encoder_attn_out_proj_bias5, lv309_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm430 = R.call_tir(cls.layer_norm3, (lv313_1, model_decoder_layers_24_final_layer_norm_weight5, model_decoder_layers_24_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv314_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm430, 
model_decoder_layers_24_fc1_weight5, model_decoder_layers_24_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv315 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv314_1, model_decoder_layers_24_fc2_weight5, model_decoder_layers_24_fc2_bias5, lv313_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm431 = R.call_tir(cls.layer_norm3, (lv315, model_decoder_layers_25_self_attn_layer_norm_weight5, model_decoder_layers_25_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv316 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm431, model_decoder_layers_25_self_attn_q_proj_weight5, model_decoder_layers_25_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv201_1 = R.call_tir(cls.NT_matmul, (layer_norm431, model_decoder_layers_25_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv317 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm431, model_decoder_layers_25_self_attn_v_proj_weight5, model_decoder_layers_25_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv318 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv316, lv201_1, lv317), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv315_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), lv318), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv319 = R.call_tir(cls.fused_reshape23_reshape24, (lv315_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv320 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv319, model_decoder_layers_25_self_attn_out_proj_weight5, model_decoder_layers_25_self_attn_out_proj_bias5, lv315), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm432 = R.call_tir(cls.layer_norm3, (lv320, model_decoder_layers_25_encoder_attn_layer_norm_weight5, model_decoder_layers_25_encoder_attn_layer_norm_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv321 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm432, model_decoder_layers_25_encoder_attn_q_proj_weight5, model_decoder_layers_25_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv322 = R.call_tir(cls.fused_reshape21_reshape25, (lv321,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv316_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), lv322), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv323 = R.call_tir(cls.fused_reshape23_reshape24, (lv316_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv324 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv323, model_decoder_layers_25_encoder_attn_out_proj_weight5, model_decoder_layers_25_encoder_attn_out_proj_bias5, lv320), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm433 = R.call_tir(cls.layer_norm3, (lv324, model_decoder_layers_25_final_layer_norm_weight5, model_decoder_layers_25_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv325 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm433, model_decoder_layers_25_fc1_weight5, model_decoder_layers_25_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv326 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv325, model_decoder_layers_25_fc2_weight5, model_decoder_layers_25_fc2_bias5, lv324), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm434 = R.call_tir(cls.layer_norm3, (lv326, model_decoder_layers_26_self_attn_layer_norm_weight5, model_decoder_layers_26_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv327 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm434, model_decoder_layers_26_self_attn_q_proj_weight5, model_decoder_layers_26_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv209_1 = R.call_tir(cls.NT_matmul, (layer_norm434, 
model_decoder_layers_26_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv328 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm434, model_decoder_layers_26_self_attn_v_proj_weight5, model_decoder_layers_26_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv329 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv327, lv209_1, lv328), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv317_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), lv329), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv330 = R.call_tir(cls.fused_reshape23_reshape24, (lv317_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv331 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv330, model_decoder_layers_26_self_attn_out_proj_weight5, model_decoder_layers_26_self_attn_out_proj_bias5, lv326), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm435 = R.call_tir(cls.layer_norm3, (lv331, model_decoder_layers_26_encoder_attn_layer_norm_weight5, model_decoder_layers_26_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv332 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm435, model_decoder_layers_26_encoder_attn_q_proj_weight5, model_decoder_layers_26_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv333 = R.call_tir(cls.fused_reshape21_reshape25, (lv332,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv318_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), lv333), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv334 = R.call_tir(cls.fused_reshape23_reshape24, (lv318_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv335 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv334, model_decoder_layers_26_encoder_attn_out_proj_weight5, 
model_decoder_layers_26_encoder_attn_out_proj_bias5, lv331), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm436 = R.call_tir(cls.layer_norm3, (lv335, model_decoder_layers_26_final_layer_norm_weight5, model_decoder_layers_26_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv336 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm436, model_decoder_layers_26_fc1_weight5, model_decoder_layers_26_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv337 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv336, model_decoder_layers_26_fc2_weight5, model_decoder_layers_26_fc2_bias5, lv335), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm437 = R.call_tir(cls.layer_norm3, (lv337, model_decoder_layers_27_self_attn_layer_norm_weight5, model_decoder_layers_27_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv338 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm437, model_decoder_layers_27_self_attn_q_proj_weight5, model_decoder_layers_27_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv217_1 = R.call_tir(cls.NT_matmul, (layer_norm437, model_decoder_layers_27_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv339 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm437, model_decoder_layers_27_self_attn_v_proj_weight5, model_decoder_layers_27_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv340 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv338, lv217_1, lv339), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv319_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), lv340), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv341 = R.call_tir(cls.fused_reshape23_reshape24, (lv319_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv342 = 
R.call_tir(cls.fused_NT_matmul_add7_add6, (lv341, model_decoder_layers_27_self_attn_out_proj_weight5, model_decoder_layers_27_self_attn_out_proj_bias5, lv337), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm438 = R.call_tir(cls.layer_norm3, (lv342, model_decoder_layers_27_encoder_attn_layer_norm_weight5, model_decoder_layers_27_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv343 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm438, model_decoder_layers_27_encoder_attn_q_proj_weight5, model_decoder_layers_27_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv344 = R.call_tir(cls.fused_reshape21_reshape25, (lv343,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv320_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), lv344), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv345 = R.call_tir(cls.fused_reshape23_reshape24, (lv320_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv346 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv345, model_decoder_layers_27_encoder_attn_out_proj_weight5, model_decoder_layers_27_encoder_attn_out_proj_bias5, lv342), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm439 = R.call_tir(cls.layer_norm3, (lv346, model_decoder_layers_27_final_layer_norm_weight5, model_decoder_layers_27_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv347 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm439, model_decoder_layers_27_fc1_weight5, model_decoder_layers_27_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv348 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv347, model_decoder_layers_27_fc2_weight5, model_decoder_layers_27_fc2_bias5, lv346), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm440 = R.call_tir(cls.layer_norm3, (lv348, model_decoder_layers_28_self_attn_layer_norm_weight5, 
model_decoder_layers_28_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv349 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm440, model_decoder_layers_28_self_attn_q_proj_weight5, model_decoder_layers_28_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv225_1 = R.call_tir(cls.NT_matmul, (layer_norm440, model_decoder_layers_28_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv350 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm440, model_decoder_layers_28_self_attn_v_proj_weight5, model_decoder_layers_28_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv351 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv349, lv225_1, lv350), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv321_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), lv351), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv352 = R.call_tir(cls.fused_reshape23_reshape24, (lv321_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv353 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv352, model_decoder_layers_28_self_attn_out_proj_weight5, model_decoder_layers_28_self_attn_out_proj_bias5, lv348), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm441 = R.call_tir(cls.layer_norm3, (lv353, model_decoder_layers_28_encoder_attn_layer_norm_weight5, model_decoder_layers_28_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv354 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm441, model_decoder_layers_28_encoder_attn_q_proj_weight5, model_decoder_layers_28_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv355 = R.call_tir(cls.fused_reshape21_reshape25, (lv354,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv322_1 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), lv355), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv356 = R.call_tir(cls.fused_reshape23_reshape24, (lv322_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv357 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv356, model_decoder_layers_28_encoder_attn_out_proj_weight5, model_decoder_layers_28_encoder_attn_out_proj_bias5, lv353), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm442 = R.call_tir(cls.layer_norm3, (lv357, model_decoder_layers_28_final_layer_norm_weight5, model_decoder_layers_28_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv358 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm442, model_decoder_layers_28_fc1_weight5, model_decoder_layers_28_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv359 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv358, model_decoder_layers_28_fc2_weight5, model_decoder_layers_28_fc2_bias5, lv357), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm443 = R.call_tir(cls.layer_norm3, (lv359, model_decoder_layers_29_self_attn_layer_norm_weight5, model_decoder_layers_29_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv360 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm443, model_decoder_layers_29_self_attn_q_proj_weight5, model_decoder_layers_29_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv233_1 = R.call_tir(cls.NT_matmul, (layer_norm443, model_decoder_layers_29_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv361 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm443, model_decoder_layers_29_self_attn_v_proj_weight5, model_decoder_layers_29_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv362 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, 
(lv360, lv233_1, lv361), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv323_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), lv362), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv363 = R.call_tir(cls.fused_reshape23_reshape24, (lv323_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv364 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv363, model_decoder_layers_29_self_attn_out_proj_weight5, model_decoder_layers_29_self_attn_out_proj_bias5, lv359), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm444 = R.call_tir(cls.layer_norm3, (lv364, model_decoder_layers_29_encoder_attn_layer_norm_weight5, model_decoder_layers_29_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv365 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm444, model_decoder_layers_29_encoder_attn_q_proj_weight5, model_decoder_layers_29_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv366 = R.call_tir(cls.fused_reshape21_reshape25, (lv365,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv324_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), lv366), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv367 = R.call_tir(cls.fused_reshape23_reshape24, (lv324_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv368 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv367, model_decoder_layers_29_encoder_attn_out_proj_weight5, model_decoder_layers_29_encoder_attn_out_proj_bias5, lv364), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm445 = R.call_tir(cls.layer_norm3, (lv368, model_decoder_layers_29_final_layer_norm_weight5, model_decoder_layers_29_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv369 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm445, 
model_decoder_layers_29_fc1_weight5, model_decoder_layers_29_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv370 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv369, model_decoder_layers_29_fc2_weight5, model_decoder_layers_29_fc2_bias5, lv368), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm446 = R.call_tir(cls.layer_norm3, (lv370, model_decoder_layers_30_self_attn_layer_norm_weight5, model_decoder_layers_30_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv371 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm446, model_decoder_layers_30_self_attn_q_proj_weight5, model_decoder_layers_30_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv241_1 = R.call_tir(cls.NT_matmul, (layer_norm446, model_decoder_layers_30_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv372 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm446, model_decoder_layers_30_self_attn_v_proj_weight5, model_decoder_layers_30_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv373 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv371, lv241_1, lv372), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv325_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), lv373), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv374 = R.call_tir(cls.fused_reshape23_reshape24, (lv325_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv375 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv374, model_decoder_layers_30_self_attn_out_proj_weight5, model_decoder_layers_30_self_attn_out_proj_bias5, lv370), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm447 = R.call_tir(cls.layer_norm3, (lv375, model_decoder_layers_30_encoder_attn_layer_norm_weight5, model_decoder_layers_30_encoder_attn_layer_norm_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv376 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm447, model_decoder_layers_30_encoder_attn_q_proj_weight5, model_decoder_layers_30_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv377 = R.call_tir(cls.fused_reshape21_reshape25, (lv376,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv326_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), lv377), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv378 = R.call_tir(cls.fused_reshape23_reshape24, (lv326_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv379 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv378, model_decoder_layers_30_encoder_attn_out_proj_weight5, model_decoder_layers_30_encoder_attn_out_proj_bias5, lv375), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm448 = R.call_tir(cls.layer_norm3, (lv379, model_decoder_layers_30_final_layer_norm_weight5, model_decoder_layers_30_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv380 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm448, model_decoder_layers_30_fc1_weight5, model_decoder_layers_30_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv381 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv380, model_decoder_layers_30_fc2_weight5, model_decoder_layers_30_fc2_bias5, lv379), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm449 = R.call_tir(cls.layer_norm3, (lv381, model_decoder_layers_31_self_attn_layer_norm_weight5, model_decoder_layers_31_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv382 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm449, model_decoder_layers_31_self_attn_q_proj_weight5, model_decoder_layers_31_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv249_1 = R.call_tir(cls.NT_matmul, (layer_norm449, 
model_decoder_layers_31_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv383 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm449, model_decoder_layers_31_self_attn_v_proj_weight5, model_decoder_layers_31_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv384 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv382, lv249_1, lv383), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) + lv327_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), lv384), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv385 = R.call_tir(cls.fused_reshape23_reshape24, (lv327_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv386 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv385, model_decoder_layers_31_self_attn_out_proj_weight5, model_decoder_layers_31_self_attn_out_proj_bias5, lv381), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm450 = R.call_tir(cls.layer_norm3, (lv386, model_decoder_layers_31_encoder_attn_layer_norm_weight5, model_decoder_layers_31_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv387 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm450, model_decoder_layers_31_encoder_attn_q_proj_weight5, model_decoder_layers_31_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv388 = R.call_tir(cls.fused_reshape21_reshape25, (lv387,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv328_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), lv388), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) + lv389 = R.call_tir(cls.fused_reshape23_reshape24, (lv328_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv390 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv389, model_decoder_layers_31_encoder_attn_out_proj_weight5, 
model_decoder_layers_31_encoder_attn_out_proj_bias5, lv386), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm451 = R.call_tir(cls.layer_norm3, (lv390, model_decoder_layers_31_final_layer_norm_weight5, model_decoder_layers_31_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + lv391 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm451, model_decoder_layers_31_fc1_weight5, model_decoder_layers_31_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) + lv392 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv391, model_decoder_layers_31_fc2_weight5, model_decoder_layers_31_fc2_bias5, lv390), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + layer_norm452 = R.call_tir(cls.layer_norm3, (lv392, model_decoder_layer_norm_weight5, model_decoder_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + gv5 = R.call_tir(cls.NT_matmul3, (layer_norm452, model_decoder_embed_tokens_weight5), out_sinfo=R.Tensor((1, 1, 51866), dtype="float32")) + R.output(gv5) + return gv5 + + @R.function + def multinomial_from_uniform(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32")) -> R.Tensor(("num_samples",), dtype="int32"): + num_samples = T.int64() + batch_size = T.int64() + vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + uniform_samples_1: R.Tensor((num_samples, 1), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", uniform_samples, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),)) + sample_indices_1: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", sample_indices, R.shape([num_samples, 1]), 
sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),)) + nn_multinomial_from_uniform = R.call_tir(cls.parallel_sampling_from_prob, (probs, uniform_samples_1, sample_indices_1), out_sinfo=R.Tensor((num_samples, 1), dtype="int32")) + gv: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", nn_multinomial_from_uniform, R.shape([num_samples]), sinfo_args=(R.Tensor((num_samples,), dtype="int32"),)) + R.output(gv) + return gv + + @R.function + def prefill(input_ids: R.Tensor((1, "seq_len"), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), 
dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, 1, 51866), dtype="float32"): + seq_len = T.int64() + R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + model_decoder_embed_tokens_weight4: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] + model_decoder_embed_positions_weight4: R.Tensor((448, 1280), dtype="float16") = packed_params[488] + model_decoder_layers_0_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] + model_decoder_layers_0_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] + model_decoder_layers_0_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[491] + model_decoder_layers_0_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] + model_decoder_layers_0_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[493] + model_decoder_layers_0_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] + model_decoder_layers_0_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[495] + model_decoder_layers_0_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[496] + 
model_decoder_layers_0_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[497] + model_decoder_layers_0_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] + model_decoder_layers_0_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[502] + model_decoder_layers_0_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] + model_decoder_layers_0_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[504] + model_decoder_layers_0_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[505] + model_decoder_layers_0_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[506] + model_decoder_layers_0_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] + model_decoder_layers_0_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[508] + model_decoder_layers_0_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] + model_decoder_layers_0_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[510] + model_decoder_layers_0_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[511] + model_decoder_layers_0_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[512] + model_decoder_layers_1_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] + model_decoder_layers_1_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] + model_decoder_layers_1_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[515] + model_decoder_layers_1_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] + model_decoder_layers_1_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[517] + model_decoder_layers_1_self_attn_out_proj_weight4: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[518] + model_decoder_layers_1_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[519] + model_decoder_layers_1_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[520] + model_decoder_layers_1_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[521] + model_decoder_layers_1_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] + model_decoder_layers_1_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[526] + model_decoder_layers_1_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] + model_decoder_layers_1_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[528] + model_decoder_layers_1_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[529] + model_decoder_layers_1_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[530] + model_decoder_layers_1_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] + model_decoder_layers_1_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[532] + model_decoder_layers_1_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] + model_decoder_layers_1_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[534] + model_decoder_layers_1_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[535] + model_decoder_layers_1_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[536] + model_decoder_layers_2_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] + model_decoder_layers_2_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] + model_decoder_layers_2_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[539] + 
model_decoder_layers_2_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] + model_decoder_layers_2_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[541] + model_decoder_layers_2_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] + model_decoder_layers_2_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[543] + model_decoder_layers_2_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[544] + model_decoder_layers_2_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[545] + model_decoder_layers_2_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] + model_decoder_layers_2_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[550] + model_decoder_layers_2_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] + model_decoder_layers_2_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[552] + model_decoder_layers_2_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[553] + model_decoder_layers_2_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[554] + model_decoder_layers_2_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] + model_decoder_layers_2_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[556] + model_decoder_layers_2_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] + model_decoder_layers_2_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[558] + model_decoder_layers_2_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[559] + model_decoder_layers_2_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[560] + model_decoder_layers_3_self_attn_k_proj_weight4: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[561] + model_decoder_layers_3_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] + model_decoder_layers_3_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[563] + model_decoder_layers_3_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] + model_decoder_layers_3_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[565] + model_decoder_layers_3_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] + model_decoder_layers_3_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[567] + model_decoder_layers_3_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[568] + model_decoder_layers_3_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[569] + model_decoder_layers_3_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] + model_decoder_layers_3_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[574] + model_decoder_layers_3_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] + model_decoder_layers_3_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[576] + model_decoder_layers_3_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[577] + model_decoder_layers_3_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[578] + model_decoder_layers_3_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] + model_decoder_layers_3_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[580] + model_decoder_layers_3_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] + model_decoder_layers_3_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[582] + 
model_decoder_layers_3_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[583] + model_decoder_layers_3_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[584] + model_decoder_layers_4_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] + model_decoder_layers_4_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] + model_decoder_layers_4_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[587] + model_decoder_layers_4_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] + model_decoder_layers_4_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[589] + model_decoder_layers_4_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] + model_decoder_layers_4_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[591] + model_decoder_layers_4_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[592] + model_decoder_layers_4_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[593] + model_decoder_layers_4_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] + model_decoder_layers_4_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[598] + model_decoder_layers_4_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] + model_decoder_layers_4_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[600] + model_decoder_layers_4_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[601] + model_decoder_layers_4_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[602] + model_decoder_layers_4_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] + 
model_decoder_layers_4_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[604] + model_decoder_layers_4_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] + model_decoder_layers_4_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[606] + model_decoder_layers_4_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[607] + model_decoder_layers_4_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[608] + model_decoder_layers_5_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] + model_decoder_layers_5_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] + model_decoder_layers_5_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[611] + model_decoder_layers_5_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] + model_decoder_layers_5_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[613] + model_decoder_layers_5_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] + model_decoder_layers_5_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[615] + model_decoder_layers_5_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[616] + model_decoder_layers_5_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[617] + model_decoder_layers_5_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[621] + model_decoder_layers_5_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[622] + model_decoder_layers_5_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] + model_decoder_layers_5_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[624] + model_decoder_layers_5_encoder_attn_layer_norm_weight4: 
R.Tensor((1280,), dtype="float16") = packed_params[625] + model_decoder_layers_5_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[626] + model_decoder_layers_5_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] + model_decoder_layers_5_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[628] + model_decoder_layers_5_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] + model_decoder_layers_5_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[630] + model_decoder_layers_5_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[631] + model_decoder_layers_5_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[632] + model_decoder_layers_6_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] + model_decoder_layers_6_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] + model_decoder_layers_6_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[635] + model_decoder_layers_6_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[636] + model_decoder_layers_6_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[637] + model_decoder_layers_6_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] + model_decoder_layers_6_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[639] + model_decoder_layers_6_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[640] + model_decoder_layers_6_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[641] + model_decoder_layers_6_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] + model_decoder_layers_6_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[646] + 
model_decoder_layers_6_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] + model_decoder_layers_6_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[648] + model_decoder_layers_6_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[649] + model_decoder_layers_6_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[650] + model_decoder_layers_6_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] + model_decoder_layers_6_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[652] + model_decoder_layers_6_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] + model_decoder_layers_6_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[654] + model_decoder_layers_6_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[655] + model_decoder_layers_6_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[656] + model_decoder_layers_7_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] + model_decoder_layers_7_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[658] + model_decoder_layers_7_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[659] + model_decoder_layers_7_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] + model_decoder_layers_7_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[661] + model_decoder_layers_7_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] + model_decoder_layers_7_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[663] + model_decoder_layers_7_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[664] + model_decoder_layers_7_self_attn_layer_norm_bias4: R.Tensor((1280,), 
dtype="float16") = packed_params[665] + model_decoder_layers_7_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] + model_decoder_layers_7_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[670] + model_decoder_layers_7_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] + model_decoder_layers_7_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[672] + model_decoder_layers_7_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[673] + model_decoder_layers_7_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[674] + model_decoder_layers_7_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] + model_decoder_layers_7_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[676] + model_decoder_layers_7_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] + model_decoder_layers_7_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[678] + model_decoder_layers_7_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[679] + model_decoder_layers_7_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[680] + model_decoder_layers_8_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] + model_decoder_layers_8_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[682] + model_decoder_layers_8_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[683] + model_decoder_layers_8_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] + model_decoder_layers_8_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[685] + model_decoder_layers_8_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] + 
model_decoder_layers_8_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[687] + model_decoder_layers_8_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[688] + model_decoder_layers_8_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[689] + model_decoder_layers_8_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] + model_decoder_layers_8_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[694] + model_decoder_layers_8_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] + model_decoder_layers_8_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[696] + model_decoder_layers_8_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[697] + model_decoder_layers_8_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[698] + model_decoder_layers_8_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] + model_decoder_layers_8_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[700] + model_decoder_layers_8_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[701] + model_decoder_layers_8_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[702] + model_decoder_layers_8_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[703] + model_decoder_layers_8_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[704] + model_decoder_layers_9_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] + model_decoder_layers_9_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] + model_decoder_layers_9_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[707] + model_decoder_layers_9_self_attn_q_proj_weight4: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[708] + model_decoder_layers_9_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[709] + model_decoder_layers_9_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] + model_decoder_layers_9_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[711] + model_decoder_layers_9_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[712] + model_decoder_layers_9_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[713] + model_decoder_layers_9_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] + model_decoder_layers_9_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[718] + model_decoder_layers_9_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] + model_decoder_layers_9_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[720] + model_decoder_layers_9_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[721] + model_decoder_layers_9_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[722] + model_decoder_layers_9_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] + model_decoder_layers_9_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[724] + model_decoder_layers_9_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] + model_decoder_layers_9_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[726] + model_decoder_layers_9_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[727] + model_decoder_layers_9_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[728] + model_decoder_layers_10_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] + 
model_decoder_layers_10_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] + model_decoder_layers_10_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[731] + model_decoder_layers_10_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] + model_decoder_layers_10_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[733] + model_decoder_layers_10_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] + model_decoder_layers_10_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[735] + model_decoder_layers_10_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[736] + model_decoder_layers_10_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[737] + model_decoder_layers_10_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] + model_decoder_layers_10_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[742] + model_decoder_layers_10_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[743] + model_decoder_layers_10_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[744] + model_decoder_layers_10_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[745] + model_decoder_layers_10_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[746] + model_decoder_layers_10_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] + model_decoder_layers_10_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[748] + model_decoder_layers_10_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] + model_decoder_layers_10_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[750] + 
model_decoder_layers_10_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[751] + model_decoder_layers_10_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[752] + model_decoder_layers_11_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] + model_decoder_layers_11_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] + model_decoder_layers_11_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[755] + model_decoder_layers_11_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] + model_decoder_layers_11_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[757] + model_decoder_layers_11_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[758] + model_decoder_layers_11_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[759] + model_decoder_layers_11_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[760] + model_decoder_layers_11_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[761] + model_decoder_layers_11_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[765] + model_decoder_layers_11_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[766] + model_decoder_layers_11_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] + model_decoder_layers_11_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[768] + model_decoder_layers_11_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[769] + model_decoder_layers_11_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[770] + model_decoder_layers_11_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] + 
model_decoder_layers_11_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[772] + model_decoder_layers_11_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] + model_decoder_layers_11_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[774] + model_decoder_layers_11_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[775] + model_decoder_layers_11_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[776] + model_decoder_layers_12_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] + model_decoder_layers_12_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] + model_decoder_layers_12_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[779] + model_decoder_layers_12_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] + model_decoder_layers_12_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[781] + model_decoder_layers_12_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] + model_decoder_layers_12_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[783] + model_decoder_layers_12_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[784] + model_decoder_layers_12_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[785] + model_decoder_layers_12_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] + model_decoder_layers_12_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[790] + model_decoder_layers_12_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] + model_decoder_layers_12_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[792] + 
model_decoder_layers_12_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[793] + model_decoder_layers_12_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[794] + model_decoder_layers_12_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] + model_decoder_layers_12_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[796] + model_decoder_layers_12_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] + model_decoder_layers_12_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[798] + model_decoder_layers_12_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[799] + model_decoder_layers_12_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[800] + model_decoder_layers_13_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[801] + model_decoder_layers_13_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[802] + model_decoder_layers_13_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[803] + model_decoder_layers_13_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[804] + model_decoder_layers_13_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[805] + model_decoder_layers_13_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[806] + model_decoder_layers_13_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[807] + model_decoder_layers_13_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[808] + model_decoder_layers_13_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[809] + model_decoder_layers_13_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[813] + model_decoder_layers_13_encoder_attn_q_proj_bias4: 
R.Tensor((1280,), dtype="float16") = packed_params[814] + model_decoder_layers_13_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[815] + model_decoder_layers_13_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[816] + model_decoder_layers_13_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[817] + model_decoder_layers_13_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[818] + model_decoder_layers_13_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[819] + model_decoder_layers_13_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[820] + model_decoder_layers_13_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[821] + model_decoder_layers_13_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[822] + model_decoder_layers_13_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[823] + model_decoder_layers_13_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[824] + model_decoder_layers_14_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[825] + model_decoder_layers_14_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[826] + model_decoder_layers_14_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[827] + model_decoder_layers_14_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[828] + model_decoder_layers_14_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[829] + model_decoder_layers_14_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[830] + model_decoder_layers_14_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[831] + model_decoder_layers_14_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[832] 
+ model_decoder_layers_14_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[833] + model_decoder_layers_14_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[837] + model_decoder_layers_14_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[838] + model_decoder_layers_14_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[839] + model_decoder_layers_14_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[840] + model_decoder_layers_14_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[841] + model_decoder_layers_14_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[842] + model_decoder_layers_14_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[843] + model_decoder_layers_14_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[844] + model_decoder_layers_14_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[845] + model_decoder_layers_14_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[846] + model_decoder_layers_14_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[847] + model_decoder_layers_14_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[848] + model_decoder_layers_15_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[849] + model_decoder_layers_15_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[850] + model_decoder_layers_15_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[851] + model_decoder_layers_15_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[852] + model_decoder_layers_15_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[853] + model_decoder_layers_15_self_attn_out_proj_weight4: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[854] + model_decoder_layers_15_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[855] + model_decoder_layers_15_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[856] + model_decoder_layers_15_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[857] + model_decoder_layers_15_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[861] + model_decoder_layers_15_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[862] + model_decoder_layers_15_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[863] + model_decoder_layers_15_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[864] + model_decoder_layers_15_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[865] + model_decoder_layers_15_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[866] + model_decoder_layers_15_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[867] + model_decoder_layers_15_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[868] + model_decoder_layers_15_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[869] + model_decoder_layers_15_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[870] + model_decoder_layers_15_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[871] + model_decoder_layers_15_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[872] + model_decoder_layers_16_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[873] + model_decoder_layers_16_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[874] + model_decoder_layers_16_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = 
packed_params[875] + model_decoder_layers_16_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[876] + model_decoder_layers_16_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[877] + model_decoder_layers_16_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[878] + model_decoder_layers_16_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[879] + model_decoder_layers_16_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[880] + model_decoder_layers_16_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[881] + model_decoder_layers_16_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[885] + model_decoder_layers_16_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[886] + model_decoder_layers_16_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[887] + model_decoder_layers_16_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[888] + model_decoder_layers_16_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[889] + model_decoder_layers_16_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[890] + model_decoder_layers_16_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[891] + model_decoder_layers_16_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[892] + model_decoder_layers_16_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[893] + model_decoder_layers_16_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[894] + model_decoder_layers_16_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[895] + model_decoder_layers_16_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[896] + 
model_decoder_layers_17_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[897] + model_decoder_layers_17_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[898] + model_decoder_layers_17_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[899] + model_decoder_layers_17_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[900] + model_decoder_layers_17_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[901] + model_decoder_layers_17_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[902] + model_decoder_layers_17_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[903] + model_decoder_layers_17_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[904] + model_decoder_layers_17_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[905] + model_decoder_layers_17_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[909] + model_decoder_layers_17_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[910] + model_decoder_layers_17_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[911] + model_decoder_layers_17_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[912] + model_decoder_layers_17_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[913] + model_decoder_layers_17_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[914] + model_decoder_layers_17_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[915] + model_decoder_layers_17_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[916] + model_decoder_layers_17_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[917] + 
model_decoder_layers_17_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[918] + model_decoder_layers_17_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[919] + model_decoder_layers_17_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[920] + model_decoder_layers_18_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[921] + model_decoder_layers_18_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[922] + model_decoder_layers_18_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[923] + model_decoder_layers_18_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[924] + model_decoder_layers_18_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[925] + model_decoder_layers_18_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[926] + model_decoder_layers_18_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[927] + model_decoder_layers_18_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[928] + model_decoder_layers_18_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[929] + model_decoder_layers_18_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[933] + model_decoder_layers_18_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[934] + model_decoder_layers_18_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[935] + model_decoder_layers_18_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[936] + model_decoder_layers_18_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[937] + model_decoder_layers_18_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[938] + 
model_decoder_layers_18_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[939] + model_decoder_layers_18_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[940] + model_decoder_layers_18_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[941] + model_decoder_layers_18_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[942] + model_decoder_layers_18_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[943] + model_decoder_layers_18_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[944] + model_decoder_layers_19_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[945] + model_decoder_layers_19_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[946] + model_decoder_layers_19_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[947] + model_decoder_layers_19_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[948] + model_decoder_layers_19_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[949] + model_decoder_layers_19_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[950] + model_decoder_layers_19_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[951] + model_decoder_layers_19_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[952] + model_decoder_layers_19_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[953] + model_decoder_layers_19_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[957] + model_decoder_layers_19_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[958] + model_decoder_layers_19_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[959] + model_decoder_layers_19_encoder_attn_out_proj_bias4: 
R.Tensor((1280,), dtype="float16") = packed_params[960] + model_decoder_layers_19_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[961] + model_decoder_layers_19_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[962] + model_decoder_layers_19_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[963] + model_decoder_layers_19_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[964] + model_decoder_layers_19_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[965] + model_decoder_layers_19_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[966] + model_decoder_layers_19_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[967] + model_decoder_layers_19_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[968] + model_decoder_layers_20_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[969] + model_decoder_layers_20_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[970] + model_decoder_layers_20_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[971] + model_decoder_layers_20_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[972] + model_decoder_layers_20_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[973] + model_decoder_layers_20_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[974] + model_decoder_layers_20_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[975] + model_decoder_layers_20_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[976] + model_decoder_layers_20_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[977] + model_decoder_layers_20_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[981] + 
model_decoder_layers_20_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[982] + model_decoder_layers_20_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[983] + model_decoder_layers_20_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[984] + model_decoder_layers_20_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[985] + model_decoder_layers_20_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[986] + model_decoder_layers_20_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[987] + model_decoder_layers_20_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[988] + model_decoder_layers_20_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[989] + model_decoder_layers_20_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[990] + model_decoder_layers_20_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[991] + model_decoder_layers_20_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[992] + model_decoder_layers_21_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[993] + model_decoder_layers_21_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[994] + model_decoder_layers_21_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[995] + model_decoder_layers_21_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[996] + model_decoder_layers_21_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[997] + model_decoder_layers_21_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[998] + model_decoder_layers_21_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[999] + model_decoder_layers_21_self_attn_layer_norm_weight4: 
R.Tensor((1280,), dtype="float16") = packed_params[1000] + model_decoder_layers_21_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1001] + model_decoder_layers_21_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005] + model_decoder_layers_21_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1006] + model_decoder_layers_21_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007] + model_decoder_layers_21_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1008] + model_decoder_layers_21_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1009] + model_decoder_layers_21_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1010] + model_decoder_layers_21_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011] + model_decoder_layers_21_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1012] + model_decoder_layers_21_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013] + model_decoder_layers_21_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1014] + model_decoder_layers_21_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1015] + model_decoder_layers_21_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1016] + model_decoder_layers_22_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017] + model_decoder_layers_22_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018] + model_decoder_layers_22_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1019] + model_decoder_layers_22_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020] + model_decoder_layers_22_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = 
packed_params[1021] + model_decoder_layers_22_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022] + model_decoder_layers_22_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1023] + model_decoder_layers_22_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1024] + model_decoder_layers_22_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1025] + model_decoder_layers_22_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029] + model_decoder_layers_22_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1030] + model_decoder_layers_22_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031] + model_decoder_layers_22_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1032] + model_decoder_layers_22_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1033] + model_decoder_layers_22_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1034] + model_decoder_layers_22_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035] + model_decoder_layers_22_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1036] + model_decoder_layers_22_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037] + model_decoder_layers_22_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1038] + model_decoder_layers_22_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1039] + model_decoder_layers_22_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1040] + model_decoder_layers_23_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041] + model_decoder_layers_23_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042] + 
model_decoder_layers_23_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1043] + model_decoder_layers_23_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044] + model_decoder_layers_23_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1045] + model_decoder_layers_23_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046] + model_decoder_layers_23_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1047] + model_decoder_layers_23_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1048] + model_decoder_layers_23_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1049] + model_decoder_layers_23_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053] + model_decoder_layers_23_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1054] + model_decoder_layers_23_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055] + model_decoder_layers_23_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1056] + model_decoder_layers_23_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1057] + model_decoder_layers_23_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1058] + model_decoder_layers_23_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059] + model_decoder_layers_23_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1060] + model_decoder_layers_23_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061] + model_decoder_layers_23_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1062] + model_decoder_layers_23_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1063] + 
model_decoder_layers_23_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1064] + model_decoder_layers_24_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065] + model_decoder_layers_24_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066] + model_decoder_layers_24_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1067] + model_decoder_layers_24_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068] + model_decoder_layers_24_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1069] + model_decoder_layers_24_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070] + model_decoder_layers_24_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1071] + model_decoder_layers_24_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1072] + model_decoder_layers_24_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1073] + model_decoder_layers_24_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077] + model_decoder_layers_24_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1078] + model_decoder_layers_24_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079] + model_decoder_layers_24_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1080] + model_decoder_layers_24_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1081] + model_decoder_layers_24_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1082] + model_decoder_layers_24_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083] + model_decoder_layers_24_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1084] + 
model_decoder_layers_24_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085] + model_decoder_layers_24_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1086] + model_decoder_layers_24_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1087] + model_decoder_layers_24_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1088] + model_decoder_layers_25_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089] + model_decoder_layers_25_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090] + model_decoder_layers_25_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1091] + model_decoder_layers_25_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092] + model_decoder_layers_25_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1093] + model_decoder_layers_25_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094] + model_decoder_layers_25_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1095] + model_decoder_layers_25_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1096] + model_decoder_layers_25_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1097] + model_decoder_layers_25_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101] + model_decoder_layers_25_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1102] + model_decoder_layers_25_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103] + model_decoder_layers_25_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1104] + model_decoder_layers_25_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1105] + 
model_decoder_layers_25_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1106] + model_decoder_layers_25_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107] + model_decoder_layers_25_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1108] + model_decoder_layers_25_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109] + model_decoder_layers_25_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1110] + model_decoder_layers_25_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1111] + model_decoder_layers_25_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1112] + model_decoder_layers_26_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113] + model_decoder_layers_26_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114] + model_decoder_layers_26_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1115] + model_decoder_layers_26_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116] + model_decoder_layers_26_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1117] + model_decoder_layers_26_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118] + model_decoder_layers_26_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1119] + model_decoder_layers_26_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1120] + model_decoder_layers_26_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1121] + model_decoder_layers_26_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125] + model_decoder_layers_26_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1126] + 
model_decoder_layers_26_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127] + model_decoder_layers_26_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1128] + model_decoder_layers_26_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1129] + model_decoder_layers_26_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1130] + model_decoder_layers_26_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131] + model_decoder_layers_26_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1132] + model_decoder_layers_26_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133] + model_decoder_layers_26_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1134] + model_decoder_layers_26_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1135] + model_decoder_layers_26_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1136] + model_decoder_layers_27_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137] + model_decoder_layers_27_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138] + model_decoder_layers_27_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1139] + model_decoder_layers_27_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140] + model_decoder_layers_27_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1141] + model_decoder_layers_27_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142] + model_decoder_layers_27_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1143] + model_decoder_layers_27_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1144] + 
model_decoder_layers_27_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1145] + model_decoder_layers_27_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149] + model_decoder_layers_27_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1150] + model_decoder_layers_27_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151] + model_decoder_layers_27_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1152] + model_decoder_layers_27_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1153] + model_decoder_layers_27_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1154] + model_decoder_layers_27_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155] + model_decoder_layers_27_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1156] + model_decoder_layers_27_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157] + model_decoder_layers_27_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1158] + model_decoder_layers_27_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1159] + model_decoder_layers_27_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1160] + model_decoder_layers_28_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161] + model_decoder_layers_28_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162] + model_decoder_layers_28_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1163] + model_decoder_layers_28_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164] + model_decoder_layers_28_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1165] + 
model_decoder_layers_28_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166] + model_decoder_layers_28_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1167] + model_decoder_layers_28_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1168] + model_decoder_layers_28_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1169] + model_decoder_layers_28_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173] + model_decoder_layers_28_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1174] + model_decoder_layers_28_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175] + model_decoder_layers_28_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1176] + model_decoder_layers_28_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1177] + model_decoder_layers_28_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1178] + model_decoder_layers_28_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179] + model_decoder_layers_28_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1180] + model_decoder_layers_28_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181] + model_decoder_layers_28_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1182] + model_decoder_layers_28_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1183] + model_decoder_layers_28_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1184] + model_decoder_layers_29_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185] + model_decoder_layers_29_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186] + 
model_decoder_layers_29_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1187] + model_decoder_layers_29_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188] + model_decoder_layers_29_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1189] + model_decoder_layers_29_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190] + model_decoder_layers_29_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1191] + model_decoder_layers_29_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1192] + model_decoder_layers_29_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1193] + model_decoder_layers_29_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197] + model_decoder_layers_29_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1198] + model_decoder_layers_29_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199] + model_decoder_layers_29_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1200] + model_decoder_layers_29_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1201] + model_decoder_layers_29_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1202] + model_decoder_layers_29_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203] + model_decoder_layers_29_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1204] + model_decoder_layers_29_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205] + model_decoder_layers_29_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1206] + model_decoder_layers_29_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1207] + 
model_decoder_layers_29_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1208] + model_decoder_layers_30_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209] + model_decoder_layers_30_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210] + model_decoder_layers_30_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1211] + model_decoder_layers_30_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212] + model_decoder_layers_30_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1213] + model_decoder_layers_30_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214] + model_decoder_layers_30_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1215] + model_decoder_layers_30_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1216] + model_decoder_layers_30_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1217] + model_decoder_layers_30_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221] + model_decoder_layers_30_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1222] + model_decoder_layers_30_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223] + model_decoder_layers_30_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1224] + model_decoder_layers_30_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1225] + model_decoder_layers_30_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1226] + model_decoder_layers_30_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227] + model_decoder_layers_30_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1228] + 
model_decoder_layers_30_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229] + model_decoder_layers_30_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1230] + model_decoder_layers_30_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1231] + model_decoder_layers_30_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1232] + model_decoder_layers_31_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233] + model_decoder_layers_31_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234] + model_decoder_layers_31_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1235] + model_decoder_layers_31_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236] + model_decoder_layers_31_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1237] + model_decoder_layers_31_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238] + model_decoder_layers_31_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1239] + model_decoder_layers_31_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1240] + model_decoder_layers_31_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1241] + model_decoder_layers_31_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245] + model_decoder_layers_31_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1246] + model_decoder_layers_31_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247] + model_decoder_layers_31_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1248] + model_decoder_layers_31_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1249] + 
model_decoder_layers_31_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1250] + model_decoder_layers_31_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251] + model_decoder_layers_31_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1252] + model_decoder_layers_31_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] + model_decoder_layers_31_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1254] + model_decoder_layers_31_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1255] + model_decoder_layers_31_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1256] + model_decoder_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1257] + model_decoder_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1258] + reshape1030 = R.call_tir(cls.reshape12, (input_ids,), out_sinfo=R.Tensor((seq_len,), dtype="int32")) + take5 = R.call_tir(cls.take, (model_decoder_embed_tokens_weight4, reshape1030), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) + reshape1031 = R.call_tir(cls.reshape13, (take5,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv198: R.Tensor((seq_len,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((seq_len,), dtype="int32"),)) + take6 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight4, lv198), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) + reshape1032 = R.call_tir(cls.reshape13, (take6,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add899 = R.call_tir(cls.add5, (reshape1031, reshape1032), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm259 = R.call_tir(cls.layer_norm2, (add899, model_decoder_layers_0_self_attn_layer_norm_weight4, model_decoder_layers_0_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv32 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_q_proj_weight4, layer_norm259, model_decoder_layers_0_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1033 = R.call_tir(cls.reshape14, (lv32,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv32_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_0_self_attn_k_proj_weight4, layer_norm259), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1034 = R.call_tir(cls.reshape14, (lv32_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv33 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_v_proj_weight4, layer_norm259, model_decoder_layers_0_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1035 = R.call_tir(cls.reshape14, (lv33,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat64 = R.call_tir(cls.concatenate1, (reshape1033, reshape1034, reshape1035), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1036 = R.call_tir(cls.reshape15, (concat64,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape1036), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1037 = R.call_tir(cls.reshape16, (lv199,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1038 = R.call_tir(cls.reshape17, (reshape1037,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv34 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_out_proj_weight4, reshape1038, model_decoder_layers_0_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + add903 = R.call_tir(cls.add5, (add899, lv34), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm260 = R.call_tir(cls.layer_norm2, (add903, model_decoder_layers_0_encoder_attn_layer_norm_weight4, model_decoder_layers_0_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv35 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight4, layer_norm260, model_decoder_layers_0_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1039 = R.call_tir(cls.reshape14, (lv35,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1040 = R.call_tir(cls.reshape18, (reshape1039,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv200 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape1040), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1041 = R.call_tir(cls.reshape16, (lv200,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1042 = R.call_tir(cls.reshape17, (reshape1041,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv36 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight4, reshape1042, model_decoder_layers_0_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add906 = R.call_tir(cls.add5, (add903, lv36), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm261 = R.call_tir(cls.layer_norm2, (add906, model_decoder_layers_0_final_layer_norm_weight4, model_decoder_layers_0_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_0_fc1_weight4, layer_norm261, 
model_decoder_layers_0_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv37 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_0_fc2_weight4, lv, model_decoder_layers_0_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add909 = R.call_tir(cls.add5, (add906, lv37), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm262 = R.call_tir(cls.layer_norm2, (add909, model_decoder_layers_1_self_attn_layer_norm_weight4, model_decoder_layers_1_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv38 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_q_proj_weight4, layer_norm262, model_decoder_layers_1_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1043 = R.call_tir(cls.reshape14, (lv38,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv33_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_1_self_attn_k_proj_weight4, layer_norm262), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1044 = R.call_tir(cls.reshape14, (lv33_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv39 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_v_proj_weight4, layer_norm262, model_decoder_layers_1_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1045 = R.call_tir(cls.reshape14, (lv39,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat65 = R.call_tir(cls.concatenate1, (reshape1043, reshape1044, reshape1045), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1046 = R.call_tir(cls.reshape15, (concat65,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv201 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape1046), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1047 = R.call_tir(cls.reshape16, (lv201,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1048 = R.call_tir(cls.reshape17, (reshape1047,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv40 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_out_proj_weight4, reshape1048, model_decoder_layers_1_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add913 = R.call_tir(cls.add5, (add909, lv40), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm263 = R.call_tir(cls.layer_norm2, (add913, model_decoder_layers_1_encoder_attn_layer_norm_weight4, model_decoder_layers_1_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv41 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight4, layer_norm263, model_decoder_layers_1_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1049 = R.call_tir(cls.reshape14, (lv41,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1050 = R.call_tir(cls.reshape18, (reshape1049,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv202 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape1050), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1051 = R.call_tir(cls.reshape16, (lv202,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1052 = R.call_tir(cls.reshape17, (reshape1051,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv42 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight4, reshape1052, model_decoder_layers_1_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add916 = R.call_tir(cls.add5, (add913, lv42), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm264 = R.call_tir(cls.layer_norm2, (add916, model_decoder_layers_1_final_layer_norm_weight4, model_decoder_layers_1_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_1_fc1_weight4, layer_norm264, model_decoder_layers_1_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv43 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_1_fc2_weight4, lv1, model_decoder_layers_1_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add919 = R.call_tir(cls.add5, (add916, lv43), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm265 = R.call_tir(cls.layer_norm2, (add919, model_decoder_layers_2_self_attn_layer_norm_weight4, model_decoder_layers_2_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv44 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_q_proj_weight4, layer_norm265, model_decoder_layers_2_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1053 = R.call_tir(cls.reshape14, (lv44,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv34_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_2_self_attn_k_proj_weight4, layer_norm265), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1054 = R.call_tir(cls.reshape14, (lv34_1,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv45 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_v_proj_weight4, layer_norm265, model_decoder_layers_2_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1055 = R.call_tir(cls.reshape14, (lv45,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat66 = R.call_tir(cls.concatenate1, (reshape1053, reshape1054, reshape1055), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1056 = R.call_tir(cls.reshape15, (concat66,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv203 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape1056), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1057 = R.call_tir(cls.reshape16, (lv203,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1058 = R.call_tir(cls.reshape17, (reshape1057,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv46 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_out_proj_weight4, reshape1058, model_decoder_layers_2_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add923 = R.call_tir(cls.add5, (add919, lv46), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm266 = R.call_tir(cls.layer_norm2, (add923, model_decoder_layers_2_encoder_attn_layer_norm_weight4, model_decoder_layers_2_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv47 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight4, layer_norm266, model_decoder_layers_2_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1059 = R.call_tir(cls.reshape14, (lv47,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1060 = R.call_tir(cls.reshape18, (reshape1059,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape1060), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1061 = R.call_tir(cls.reshape16, (lv204,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1062 = R.call_tir(cls.reshape17, (reshape1061,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv48 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight4, reshape1062, model_decoder_layers_2_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add926 = R.call_tir(cls.add5, (add923, lv48), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm267 = R.call_tir(cls.layer_norm2, (add926, model_decoder_layers_2_final_layer_norm_weight4, model_decoder_layers_2_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_2_fc1_weight4, layer_norm267, model_decoder_layers_2_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv49 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_2_fc2_weight4, lv2, model_decoder_layers_2_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add929 = R.call_tir(cls.add5, (add926, lv49), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm268 = R.call_tir(cls.layer_norm2, (add929, model_decoder_layers_3_self_attn_layer_norm_weight4, model_decoder_layers_3_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv50 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_q_proj_weight4, layer_norm268, model_decoder_layers_3_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1063 = R.call_tir(cls.reshape14, (lv50,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv35_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_3_self_attn_k_proj_weight4, layer_norm268), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1064 = R.call_tir(cls.reshape14, (lv35_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv51 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_v_proj_weight4, layer_norm268, model_decoder_layers_3_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1065 = R.call_tir(cls.reshape14, (lv51,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat67 = R.call_tir(cls.concatenate1, (reshape1063, reshape1064, reshape1065), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1066 = R.call_tir(cls.reshape15, (concat67,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv205 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape1066), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1067 = R.call_tir(cls.reshape16, (lv205,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1068 = R.call_tir(cls.reshape17, (reshape1067,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv52 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_out_proj_weight4, reshape1068, model_decoder_layers_3_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add933 
= R.call_tir(cls.add5, (add929, lv52), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm269 = R.call_tir(cls.layer_norm2, (add933, model_decoder_layers_3_encoder_attn_layer_norm_weight4, model_decoder_layers_3_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv53 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight4, layer_norm269, model_decoder_layers_3_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1069 = R.call_tir(cls.reshape14, (lv53,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1070 = R.call_tir(cls.reshape18, (reshape1069,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv206 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape1070), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1071 = R.call_tir(cls.reshape16, (lv206,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1072 = R.call_tir(cls.reshape17, (reshape1071,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv54 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight4, reshape1072, model_decoder_layers_3_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add936 = R.call_tir(cls.add5, (add933, lv54), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm270 = R.call_tir(cls.layer_norm2, (add936, model_decoder_layers_3_final_layer_norm_weight4, model_decoder_layers_3_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_3_fc1_weight4, layer_norm270, 
model_decoder_layers_3_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv55 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_3_fc2_weight4, lv3, model_decoder_layers_3_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add939 = R.call_tir(cls.add5, (add936, lv55), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm271 = R.call_tir(cls.layer_norm2, (add939, model_decoder_layers_4_self_attn_layer_norm_weight4, model_decoder_layers_4_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv56 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_q_proj_weight4, layer_norm271, model_decoder_layers_4_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1073 = R.call_tir(cls.reshape14, (lv56,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv36_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_4_self_attn_k_proj_weight4, layer_norm271), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1074 = R.call_tir(cls.reshape14, (lv36_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv57 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_v_proj_weight4, layer_norm271, model_decoder_layers_4_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1075 = R.call_tir(cls.reshape14, (lv57,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat68 = R.call_tir(cls.concatenate1, (reshape1073, reshape1074, reshape1075), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1076 = R.call_tir(cls.reshape15, (concat68,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv207 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape1076), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1077 = R.call_tir(cls.reshape16, (lv207,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1078 = R.call_tir(cls.reshape17, (reshape1077,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv58 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_out_proj_weight4, reshape1078, model_decoder_layers_4_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add943 = R.call_tir(cls.add5, (add939, lv58), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm272 = R.call_tir(cls.layer_norm2, (add943, model_decoder_layers_4_encoder_attn_layer_norm_weight4, model_decoder_layers_4_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv59 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight4, layer_norm272, model_decoder_layers_4_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1079 = R.call_tir(cls.reshape14, (lv59,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1080 = R.call_tir(cls.reshape18, (reshape1079,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv208 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape1080), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1081 = R.call_tir(cls.reshape16, (lv208,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1082 = R.call_tir(cls.reshape17, (reshape1081,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv60 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight4, reshape1082, model_decoder_layers_4_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add946 = R.call_tir(cls.add5, (add943, lv60), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm273 = R.call_tir(cls.layer_norm2, (add946, model_decoder_layers_4_final_layer_norm_weight4, model_decoder_layers_4_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv4 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_4_fc1_weight4, layer_norm273, model_decoder_layers_4_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv61 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_4_fc2_weight4, lv4, model_decoder_layers_4_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add949 = R.call_tir(cls.add5, (add946, lv61), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm274 = R.call_tir(cls.layer_norm2, (add949, model_decoder_layers_5_self_attn_layer_norm_weight4, model_decoder_layers_5_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv62 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_q_proj_weight4, layer_norm274, model_decoder_layers_5_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1083 = R.call_tir(cls.reshape14, (lv62,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv37_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_5_self_attn_k_proj_weight4, layer_norm274), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1084 = R.call_tir(cls.reshape14, (lv37_1,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv63 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_v_proj_weight4, layer_norm274, model_decoder_layers_5_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1085 = R.call_tir(cls.reshape14, (lv63,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat69 = R.call_tir(cls.concatenate1, (reshape1083, reshape1084, reshape1085), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1086 = R.call_tir(cls.reshape15, (concat69,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape1086), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1087 = R.call_tir(cls.reshape16, (lv209,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1088 = R.call_tir(cls.reshape17, (reshape1087,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv64 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_out_proj_weight4, reshape1088, model_decoder_layers_5_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add953 = R.call_tir(cls.add5, (add949, lv64), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm275 = R.call_tir(cls.layer_norm2, (add953, model_decoder_layers_5_encoder_attn_layer_norm_weight4, model_decoder_layers_5_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight4, layer_norm275, model_decoder_layers_5_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1089 = R.call_tir(cls.reshape14, (lv65,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1090 = R.call_tir(cls.reshape18, (reshape1089,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv210 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape1090), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1091 = R.call_tir(cls.reshape16, (lv210,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1092 = R.call_tir(cls.reshape17, (reshape1091,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight4, reshape1092, model_decoder_layers_5_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add956 = R.call_tir(cls.add5, (add953, lv66), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm276 = R.call_tir(cls.layer_norm2, (add956, model_decoder_layers_5_final_layer_norm_weight4, model_decoder_layers_5_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv5 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_5_fc1_weight4, layer_norm276, model_decoder_layers_5_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_5_fc2_weight4, lv5, model_decoder_layers_5_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add959 = R.call_tir(cls.add5, (add956, lv67), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm277 = R.call_tir(cls.layer_norm2, (add959, model_decoder_layers_6_self_attn_layer_norm_weight4, model_decoder_layers_6_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv68 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_q_proj_weight4, layer_norm277, model_decoder_layers_6_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1093 = R.call_tir(cls.reshape14, (lv68,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv38_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_6_self_attn_k_proj_weight4, layer_norm277), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1094 = R.call_tir(cls.reshape14, (lv38_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv69 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_v_proj_weight4, layer_norm277, model_decoder_layers_6_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1095 = R.call_tir(cls.reshape14, (lv69,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat70 = R.call_tir(cls.concatenate1, (reshape1093, reshape1094, reshape1095), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1096 = R.call_tir(cls.reshape15, (concat70,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv211 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape1096), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1097 = R.call_tir(cls.reshape16, (lv211,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1098 = R.call_tir(cls.reshape17, (reshape1097,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv70 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_out_proj_weight4, reshape1098, model_decoder_layers_6_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add963 
= R.call_tir(cls.add5, (add959, lv70), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm278 = R.call_tir(cls.layer_norm2, (add963, model_decoder_layers_6_encoder_attn_layer_norm_weight4, model_decoder_layers_6_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv71 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight4, layer_norm278, model_decoder_layers_6_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1099 = R.call_tir(cls.reshape14, (lv71,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1100 = R.call_tir(cls.reshape18, (reshape1099,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv212 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape1100), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1101 = R.call_tir(cls.reshape16, (lv212,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1102 = R.call_tir(cls.reshape17, (reshape1101,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv72 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight4, reshape1102, model_decoder_layers_6_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add966 = R.call_tir(cls.add5, (add963, lv72), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm279 = R.call_tir(cls.layer_norm2, (add966, model_decoder_layers_6_final_layer_norm_weight4, model_decoder_layers_6_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv6 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_6_fc1_weight4, layer_norm279, 
model_decoder_layers_6_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv73 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_6_fc2_weight4, lv6, model_decoder_layers_6_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add969 = R.call_tir(cls.add5, (add966, lv73), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm280 = R.call_tir(cls.layer_norm2, (add969, model_decoder_layers_7_self_attn_layer_norm_weight4, model_decoder_layers_7_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv74 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_q_proj_weight4, layer_norm280, model_decoder_layers_7_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1103 = R.call_tir(cls.reshape14, (lv74,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv39_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_7_self_attn_k_proj_weight4, layer_norm280), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1104 = R.call_tir(cls.reshape14, (lv39_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv75 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_v_proj_weight4, layer_norm280, model_decoder_layers_7_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1105 = R.call_tir(cls.reshape14, (lv75,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat71 = R.call_tir(cls.concatenate1, (reshape1103, reshape1104, reshape1105), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1106 = R.call_tir(cls.reshape15, (concat71,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv213 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape1106), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1107 = R.call_tir(cls.reshape16, (lv213,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1108 = R.call_tir(cls.reshape17, (reshape1107,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv76 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_out_proj_weight4, reshape1108, model_decoder_layers_7_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add973 = R.call_tir(cls.add5, (add969, lv76), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm281 = R.call_tir(cls.layer_norm2, (add973, model_decoder_layers_7_encoder_attn_layer_norm_weight4, model_decoder_layers_7_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv77 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight4, layer_norm281, model_decoder_layers_7_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1109 = R.call_tir(cls.reshape14, (lv77,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1110 = R.call_tir(cls.reshape18, (reshape1109,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape1110), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1111 = R.call_tir(cls.reshape16, (lv214,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1112 = R.call_tir(cls.reshape17, (reshape1111,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv78 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight4, reshape1112, model_decoder_layers_7_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add976 = R.call_tir(cls.add5, (add973, lv78), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm282 = R.call_tir(cls.layer_norm2, (add976, model_decoder_layers_7_final_layer_norm_weight4, model_decoder_layers_7_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv7 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_7_fc1_weight4, layer_norm282, model_decoder_layers_7_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv79 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_7_fc2_weight4, lv7, model_decoder_layers_7_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add979 = R.call_tir(cls.add5, (add976, lv79), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm283 = R.call_tir(cls.layer_norm2, (add979, model_decoder_layers_8_self_attn_layer_norm_weight4, model_decoder_layers_8_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv80 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_q_proj_weight4, layer_norm283, model_decoder_layers_8_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1113 = R.call_tir(cls.reshape14, (lv80,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv40_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_8_self_attn_k_proj_weight4, layer_norm283), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1114 = R.call_tir(cls.reshape14, (lv40_1,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv81 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_v_proj_weight4, layer_norm283, model_decoder_layers_8_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1115 = R.call_tir(cls.reshape14, (lv81,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat72 = R.call_tir(cls.concatenate1, (reshape1113, reshape1114, reshape1115), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1116 = R.call_tir(cls.reshape15, (concat72,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv215 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape1116), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1117 = R.call_tir(cls.reshape16, (lv215,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1118 = R.call_tir(cls.reshape17, (reshape1117,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv82 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_out_proj_weight4, reshape1118, model_decoder_layers_8_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add983 = R.call_tir(cls.add5, (add979, lv82), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm284 = R.call_tir(cls.layer_norm2, (add983, model_decoder_layers_8_encoder_attn_layer_norm_weight4, model_decoder_layers_8_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv83 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight4, layer_norm284, model_decoder_layers_8_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1119 = R.call_tir(cls.reshape14, (lv83,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1120 = R.call_tir(cls.reshape18, (reshape1119,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv216 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape1120), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1121 = R.call_tir(cls.reshape16, (lv216,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1122 = R.call_tir(cls.reshape17, (reshape1121,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv84 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_out_proj_weight4, reshape1122, model_decoder_layers_8_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add986 = R.call_tir(cls.add5, (add983, lv84), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm285 = R.call_tir(cls.layer_norm2, (add986, model_decoder_layers_8_final_layer_norm_weight4, model_decoder_layers_8_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv8 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_8_fc1_weight4, layer_norm285, model_decoder_layers_8_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv85 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_8_fc2_weight4, lv8, model_decoder_layers_8_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add989 = R.call_tir(cls.add5, (add986, lv85), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm286 = R.call_tir(cls.layer_norm2, (add989, model_decoder_layers_9_self_attn_layer_norm_weight4, model_decoder_layers_9_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv86 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_q_proj_weight4, layer_norm286, model_decoder_layers_9_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1123 = R.call_tir(cls.reshape14, (lv86,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv41_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_9_self_attn_k_proj_weight4, layer_norm286), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1124 = R.call_tir(cls.reshape14, (lv41_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv87 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_v_proj_weight4, layer_norm286, model_decoder_layers_9_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1125 = R.call_tir(cls.reshape14, (lv87,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat73 = R.call_tir(cls.concatenate1, (reshape1123, reshape1124, reshape1125), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1126 = R.call_tir(cls.reshape15, (concat73,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv217 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape1126), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1127 = R.call_tir(cls.reshape16, (lv217,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1128 = R.call_tir(cls.reshape17, (reshape1127,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv88 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_out_proj_weight4, reshape1128, model_decoder_layers_9_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add993 
= R.call_tir(cls.add5, (add989, lv88), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm287 = R.call_tir(cls.layer_norm2, (add993, model_decoder_layers_9_encoder_attn_layer_norm_weight4, model_decoder_layers_9_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv89 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight4, layer_norm287, model_decoder_layers_9_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1129 = R.call_tir(cls.reshape14, (lv89,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1130 = R.call_tir(cls.reshape18, (reshape1129,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv218 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape1130), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1131 = R.call_tir(cls.reshape16, (lv218,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1132 = R.call_tir(cls.reshape17, (reshape1131,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv90 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight4, reshape1132, model_decoder_layers_9_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add996 = R.call_tir(cls.add5, (add993, lv90), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm288 = R.call_tir(cls.layer_norm2, (add996, model_decoder_layers_9_final_layer_norm_weight4, model_decoder_layers_9_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv9 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_9_fc1_weight4, layer_norm288, 
model_decoder_layers_9_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv91 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_9_fc2_weight4, lv9, model_decoder_layers_9_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add999 = R.call_tir(cls.add5, (add996, lv91), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm289 = R.call_tir(cls.layer_norm2, (add999, model_decoder_layers_10_self_attn_layer_norm_weight4, model_decoder_layers_10_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv92 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_q_proj_weight4, layer_norm289, model_decoder_layers_10_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1133 = R.call_tir(cls.reshape14, (lv92,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv42_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_10_self_attn_k_proj_weight4, layer_norm289), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1134 = R.call_tir(cls.reshape14, (lv42_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv93 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_v_proj_weight4, layer_norm289, model_decoder_layers_10_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1135 = R.call_tir(cls.reshape14, (lv93,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat74 = R.call_tir(cls.concatenate1, (reshape1133, reshape1134, reshape1135), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1136 = R.call_tir(cls.reshape15, (concat74,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv219 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape1136), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1137 = R.call_tir(cls.reshape16, (lv219,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1138 = R.call_tir(cls.reshape17, (reshape1137,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv94 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_out_proj_weight4, reshape1138, model_decoder_layers_10_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1003 = R.call_tir(cls.add5, (add999, lv94), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm290 = R.call_tir(cls.layer_norm2, (add1003, model_decoder_layers_10_encoder_attn_layer_norm_weight4, model_decoder_layers_10_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv95 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight4, layer_norm290, model_decoder_layers_10_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1139 = R.call_tir(cls.reshape14, (lv95,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1140 = R.call_tir(cls.reshape18, (reshape1139,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv220 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape1140), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1141 = R.call_tir(cls.reshape16, (lv220,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1142 = R.call_tir(cls.reshape17, (reshape1141,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv96 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight4, reshape1142, model_decoder_layers_10_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1006 = R.call_tir(cls.add5, (add1003, lv96), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm291 = R.call_tir(cls.layer_norm2, (add1006, model_decoder_layers_10_final_layer_norm_weight4, model_decoder_layers_10_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv10 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_10_fc1_weight4, layer_norm291, model_decoder_layers_10_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv97 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_10_fc2_weight4, lv10, model_decoder_layers_10_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1009 = R.call_tir(cls.add5, (add1006, lv97), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm292 = R.call_tir(cls.layer_norm2, (add1009, model_decoder_layers_11_self_attn_layer_norm_weight4, model_decoder_layers_11_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_q_proj_weight4, layer_norm292, model_decoder_layers_11_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1143 = R.call_tir(cls.reshape14, (lv98,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv43_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_11_self_attn_k_proj_weight4, layer_norm292), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1144 = R.call_tir(cls.reshape14, (lv43_1,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_v_proj_weight4, layer_norm292, model_decoder_layers_11_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1145 = R.call_tir(cls.reshape14, (lv99,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat75 = R.call_tir(cls.concatenate1, (reshape1143, reshape1144, reshape1145), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1146 = R.call_tir(cls.reshape15, (concat75,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv221 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape1146), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1147 = R.call_tir(cls.reshape16, (lv221,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1148 = R.call_tir(cls.reshape17, (reshape1147,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_out_proj_weight4, reshape1148, model_decoder_layers_11_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1013 = R.call_tir(cls.add5, (add1009, lv100), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm293 = R.call_tir(cls.layer_norm2, (add1013, model_decoder_layers_11_encoder_attn_layer_norm_weight4, model_decoder_layers_11_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight4, layer_norm293, model_decoder_layers_11_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1149 
= R.call_tir(cls.reshape14, (lv101,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1150 = R.call_tir(cls.reshape18, (reshape1149,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv222 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape1150), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1151 = R.call_tir(cls.reshape16, (lv222,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1152 = R.call_tir(cls.reshape17, (reshape1151,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight4, reshape1152, model_decoder_layers_11_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1016 = R.call_tir(cls.add5, (add1013, lv102), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm294 = R.call_tir(cls.layer_norm2, (add1016, model_decoder_layers_11_final_layer_norm_weight4, model_decoder_layers_11_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv11 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_11_fc1_weight4, layer_norm294, model_decoder_layers_11_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_11_fc2_weight4, lv11, model_decoder_layers_11_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1019 = R.call_tir(cls.add5, (add1016, lv103), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm295 = R.call_tir(cls.layer_norm2, (add1019, model_decoder_layers_12_self_attn_layer_norm_weight4, model_decoder_layers_12_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) + lv104 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_q_proj_weight4, layer_norm295, model_decoder_layers_12_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1153 = R.call_tir(cls.reshape14, (lv104,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv44_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_12_self_attn_k_proj_weight4, layer_norm295), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1154 = R.call_tir(cls.reshape14, (lv44_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_v_proj_weight4, layer_norm295, model_decoder_layers_12_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1155 = R.call_tir(cls.reshape14, (lv105,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat76 = R.call_tir(cls.concatenate1, (reshape1153, reshape1154, reshape1155), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1156 = R.call_tir(cls.reshape15, (concat76,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv223 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape1156), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1157 = R.call_tir(cls.reshape16, (lv223,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1158 = R.call_tir(cls.reshape17, (reshape1157,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv106 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_out_proj_weight4, reshape1158, model_decoder_layers_12_self_attn_out_proj_bias4), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1023 = R.call_tir(cls.add5, (add1019, lv106), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm296 = R.call_tir(cls.layer_norm2, (add1023, model_decoder_layers_12_encoder_attn_layer_norm_weight4, model_decoder_layers_12_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight4, layer_norm296, model_decoder_layers_12_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1159 = R.call_tir(cls.reshape14, (lv107,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1160 = R.call_tir(cls.reshape18, (reshape1159,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape1160), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1161 = R.call_tir(cls.reshape16, (lv224,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1162 = R.call_tir(cls.reshape17, (reshape1161,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight4, reshape1162, model_decoder_layers_12_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1026 = R.call_tir(cls.add5, (add1023, lv108), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm297 = R.call_tir(cls.layer_norm2, (add1026, model_decoder_layers_12_final_layer_norm_weight4, model_decoder_layers_12_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv12 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_12_fc1_weight4, layer_norm297, model_decoder_layers_12_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_12_fc2_weight4, lv12, model_decoder_layers_12_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1029 = R.call_tir(cls.add5, (add1026, lv109), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm298 = R.call_tir(cls.layer_norm2, (add1029, model_decoder_layers_13_self_attn_layer_norm_weight4, model_decoder_layers_13_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_q_proj_weight4, layer_norm298, model_decoder_layers_13_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1163 = R.call_tir(cls.reshape14, (lv110,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv45_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_13_self_attn_k_proj_weight4, layer_norm298), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1164 = R.call_tir(cls.reshape14, (lv45_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_v_proj_weight4, layer_norm298, model_decoder_layers_13_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1165 = R.call_tir(cls.reshape14, (lv111,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat77 = R.call_tir(cls.concatenate1, (reshape1163, reshape1164, reshape1165), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1166 = 
R.call_tir(cls.reshape15, (concat77,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv225 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape1166), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1167 = R.call_tir(cls.reshape16, (lv225,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1168 = R.call_tir(cls.reshape17, (reshape1167,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv112 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_out_proj_weight4, reshape1168, model_decoder_layers_13_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1033 = R.call_tir(cls.add5, (add1029, lv112), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm299 = R.call_tir(cls.layer_norm2, (add1033, model_decoder_layers_13_encoder_attn_layer_norm_weight4, model_decoder_layers_13_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight4, layer_norm299, model_decoder_layers_13_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1169 = R.call_tir(cls.reshape14, (lv113,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1170 = R.call_tir(cls.reshape18, (reshape1169,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv226 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape1170), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1171 = R.call_tir(cls.reshape16, (lv226,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1172 = R.call_tir(cls.reshape17, 
(reshape1171,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight4, reshape1172, model_decoder_layers_13_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1036 = R.call_tir(cls.add5, (add1033, lv114), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm300 = R.call_tir(cls.layer_norm2, (add1036, model_decoder_layers_13_final_layer_norm_weight4, model_decoder_layers_13_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv13 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_13_fc1_weight4, layer_norm300, model_decoder_layers_13_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_13_fc2_weight4, lv13, model_decoder_layers_13_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1039 = R.call_tir(cls.add5, (add1036, lv115), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm301 = R.call_tir(cls.layer_norm2, (add1039, model_decoder_layers_14_self_attn_layer_norm_weight4, model_decoder_layers_14_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_q_proj_weight4, layer_norm301, model_decoder_layers_14_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1173 = R.call_tir(cls.reshape14, (lv116,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv46_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_14_self_attn_k_proj_weight4, layer_norm301), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) + reshape1174 = R.call_tir(cls.reshape14, (lv46_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_v_proj_weight4, layer_norm301, model_decoder_layers_14_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1175 = R.call_tir(cls.reshape14, (lv117,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat78 = R.call_tir(cls.concatenate1, (reshape1173, reshape1174, reshape1175), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1176 = R.call_tir(cls.reshape15, (concat78,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv227 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape1176), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1177 = R.call_tir(cls.reshape16, (lv227,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1178 = R.call_tir(cls.reshape17, (reshape1177,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_out_proj_weight4, reshape1178, model_decoder_layers_14_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1043 = R.call_tir(cls.add5, (add1039, lv118), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm302 = R.call_tir(cls.layer_norm2, (add1043, model_decoder_layers_14_encoder_attn_layer_norm_weight4, model_decoder_layers_14_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight4, layer_norm302, 
model_decoder_layers_14_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1179 = R.call_tir(cls.reshape14, (lv119,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1180 = R.call_tir(cls.reshape18, (reshape1179,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv228 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape1180), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1181 = R.call_tir(cls.reshape16, (lv228,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1182 = R.call_tir(cls.reshape17, (reshape1181,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv120 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_out_proj_weight4, reshape1182, model_decoder_layers_14_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1046 = R.call_tir(cls.add5, (add1043, lv120), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm303 = R.call_tir(cls.layer_norm2, (add1046, model_decoder_layers_14_final_layer_norm_weight4, model_decoder_layers_14_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv14 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_14_fc1_weight4, layer_norm303, model_decoder_layers_14_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_14_fc2_weight4, lv14, model_decoder_layers_14_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1049 = R.call_tir(cls.add5, (add1046, lv121), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm304 = R.call_tir(cls.layer_norm2, (add1049, 
model_decoder_layers_15_self_attn_layer_norm_weight4, model_decoder_layers_15_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_q_proj_weight4, layer_norm304, model_decoder_layers_15_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1183 = R.call_tir(cls.reshape14, (lv122,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv47_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_15_self_attn_k_proj_weight4, layer_norm304), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1184 = R.call_tir(cls.reshape14, (lv47_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_v_proj_weight4, layer_norm304, model_decoder_layers_15_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1185 = R.call_tir(cls.reshape14, (lv123,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat79 = R.call_tir(cls.concatenate1, (reshape1183, reshape1184, reshape1185), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1186 = R.call_tir(cls.reshape15, (concat79,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv229 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape1186), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1187 = R.call_tir(cls.reshape16, (lv229,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1188 = R.call_tir(cls.reshape17, (reshape1187,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv124 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_out_proj_weight4, reshape1188, model_decoder_layers_15_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1053 = R.call_tir(cls.add5, (add1049, lv124), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm305 = R.call_tir(cls.layer_norm2, (add1053, model_decoder_layers_15_encoder_attn_layer_norm_weight4, model_decoder_layers_15_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight4, layer_norm305, model_decoder_layers_15_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1189 = R.call_tir(cls.reshape14, (lv125,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1190 = R.call_tir(cls.reshape18, (reshape1189,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv230 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape1190), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1191 = R.call_tir(cls.reshape16, (lv230,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1192 = R.call_tir(cls.reshape17, (reshape1191,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight4, reshape1192, model_decoder_layers_15_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1056 = R.call_tir(cls.add5, (add1053, lv126), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm306 = R.call_tir(cls.layer_norm2, (add1056, model_decoder_layers_15_final_layer_norm_weight4, 
model_decoder_layers_15_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv15 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_15_fc1_weight4, layer_norm306, model_decoder_layers_15_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv127 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_15_fc2_weight4, lv15, model_decoder_layers_15_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1059 = R.call_tir(cls.add5, (add1056, lv127), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm307 = R.call_tir(cls.layer_norm2, (add1059, model_decoder_layers_16_self_attn_layer_norm_weight4, model_decoder_layers_16_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv128 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_q_proj_weight4, layer_norm307, model_decoder_layers_16_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1193 = R.call_tir(cls.reshape14, (lv128,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv48_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_16_self_attn_k_proj_weight4, layer_norm307), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1194 = R.call_tir(cls.reshape14, (lv48_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv129 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_v_proj_weight4, layer_norm307, model_decoder_layers_16_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1195 = R.call_tir(cls.reshape14, (lv129,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat80 = R.call_tir(cls.concatenate1, 
(reshape1193, reshape1194, reshape1195), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1196 = R.call_tir(cls.reshape15, (concat80,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv231 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape1196), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1197 = R.call_tir(cls.reshape16, (lv231,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1198 = R.call_tir(cls.reshape17, (reshape1197,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv130 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_out_proj_weight4, reshape1198, model_decoder_layers_16_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1063 = R.call_tir(cls.add5, (add1059, lv130), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm308 = R.call_tir(cls.layer_norm2, (add1063, model_decoder_layers_16_encoder_attn_layer_norm_weight4, model_decoder_layers_16_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv131 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight4, layer_norm308, model_decoder_layers_16_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1199 = R.call_tir(cls.reshape14, (lv131,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1200 = R.call_tir(cls.reshape18, (reshape1199,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv232 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape1200), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1201 = R.call_tir(cls.reshape16, 
(lv232,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1202 = R.call_tir(cls.reshape17, (reshape1201,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv132 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_out_proj_weight4, reshape1202, model_decoder_layers_16_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1066 = R.call_tir(cls.add5, (add1063, lv132), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm309 = R.call_tir(cls.layer_norm2, (add1066, model_decoder_layers_16_final_layer_norm_weight4, model_decoder_layers_16_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv16 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_16_fc1_weight4, layer_norm309, model_decoder_layers_16_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv133 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_16_fc2_weight4, lv16, model_decoder_layers_16_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1069 = R.call_tir(cls.add5, (add1066, lv133), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm310 = R.call_tir(cls.layer_norm2, (add1069, model_decoder_layers_17_self_attn_layer_norm_weight4, model_decoder_layers_17_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv134 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_q_proj_weight4, layer_norm310, model_decoder_layers_17_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1203 = R.call_tir(cls.reshape14, (lv134,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv49_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_17_self_attn_k_proj_weight4, layer_norm310), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1204 = R.call_tir(cls.reshape14, (lv49_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv135 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_v_proj_weight4, layer_norm310, model_decoder_layers_17_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1205 = R.call_tir(cls.reshape14, (lv135,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat81 = R.call_tir(cls.concatenate1, (reshape1203, reshape1204, reshape1205), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1206 = R.call_tir(cls.reshape15, (concat81,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv233 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape1206), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1207 = R.call_tir(cls.reshape16, (lv233,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1208 = R.call_tir(cls.reshape17, (reshape1207,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv136 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_out_proj_weight4, reshape1208, model_decoder_layers_17_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1073 = R.call_tir(cls.add5, (add1069, lv136), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm311 = R.call_tir(cls.layer_norm2, (add1073, model_decoder_layers_17_encoder_attn_layer_norm_weight4, model_decoder_layers_17_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv137 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight4, layer_norm311, model_decoder_layers_17_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1209 = R.call_tir(cls.reshape14, (lv137,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1210 = R.call_tir(cls.reshape18, (reshape1209,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv234 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape1210), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1211 = R.call_tir(cls.reshape16, (lv234,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1212 = R.call_tir(cls.reshape17, (reshape1211,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv138 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight4, reshape1212, model_decoder_layers_17_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1076 = R.call_tir(cls.add5, (add1073, lv138), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm312 = R.call_tir(cls.layer_norm2, (add1076, model_decoder_layers_17_final_layer_norm_weight4, model_decoder_layers_17_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv17 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_17_fc1_weight4, layer_norm312, model_decoder_layers_17_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv139 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_17_fc2_weight4, lv17, model_decoder_layers_17_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1079 = 
R.call_tir(cls.add5, (add1076, lv139), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm313 = R.call_tir(cls.layer_norm2, (add1079, model_decoder_layers_18_self_attn_layer_norm_weight4, model_decoder_layers_18_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv140 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_q_proj_weight4, layer_norm313, model_decoder_layers_18_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1213 = R.call_tir(cls.reshape14, (lv140,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv50_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_18_self_attn_k_proj_weight4, layer_norm313), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1214 = R.call_tir(cls.reshape14, (lv50_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv141 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_v_proj_weight4, layer_norm313, model_decoder_layers_18_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1215 = R.call_tir(cls.reshape14, (lv141,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat82 = R.call_tir(cls.concatenate1, (reshape1213, reshape1214, reshape1215), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1216 = R.call_tir(cls.reshape15, (concat82,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv235 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape1216), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1217 = R.call_tir(cls.reshape16, (lv235,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1218 = R.call_tir(cls.reshape17, 
(reshape1217,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv142 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_out_proj_weight4, reshape1218, model_decoder_layers_18_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1083 = R.call_tir(cls.add5, (add1079, lv142), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm314 = R.call_tir(cls.layer_norm2, (add1083, model_decoder_layers_18_encoder_attn_layer_norm_weight4, model_decoder_layers_18_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv143 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_q_proj_weight4, layer_norm314, model_decoder_layers_18_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1219 = R.call_tir(cls.reshape14, (lv143,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1220 = R.call_tir(cls.reshape18, (reshape1219,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv236 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape1220), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1221 = R.call_tir(cls.reshape16, (lv236,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1222 = R.call_tir(cls.reshape17, (reshape1221,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv144 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight4, reshape1222, model_decoder_layers_18_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1086 = R.call_tir(cls.add5, (add1083, lv144), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm315 = 
R.call_tir(cls.layer_norm2, (add1086, model_decoder_layers_18_final_layer_norm_weight4, model_decoder_layers_18_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv18 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_18_fc1_weight4, layer_norm315, model_decoder_layers_18_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv145 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_18_fc2_weight4, lv18, model_decoder_layers_18_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1089 = R.call_tir(cls.add5, (add1086, lv145), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm316 = R.call_tir(cls.layer_norm2, (add1089, model_decoder_layers_19_self_attn_layer_norm_weight4, model_decoder_layers_19_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv146 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_q_proj_weight4, layer_norm316, model_decoder_layers_19_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1223 = R.call_tir(cls.reshape14, (lv146,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv51_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_19_self_attn_k_proj_weight4, layer_norm316), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1224 = R.call_tir(cls.reshape14, (lv51_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv147 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_v_proj_weight4, layer_norm316, model_decoder_layers_19_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1225 = R.call_tir(cls.reshape14, (lv147,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat83 = R.call_tir(cls.concatenate1, (reshape1223, reshape1224, reshape1225), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1226 = R.call_tir(cls.reshape15, (concat83,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv237 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape1226), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1227 = R.call_tir(cls.reshape16, (lv237,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1228 = R.call_tir(cls.reshape17, (reshape1227,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv148 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_out_proj_weight4, reshape1228, model_decoder_layers_19_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1093 = R.call_tir(cls.add5, (add1089, lv148), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm317 = R.call_tir(cls.layer_norm2, (add1093, model_decoder_layers_19_encoder_attn_layer_norm_weight4, model_decoder_layers_19_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv149 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight4, layer_norm317, model_decoder_layers_19_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1229 = R.call_tir(cls.reshape14, (lv149,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1230 = R.call_tir(cls.reshape18, (reshape1229,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv238 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape1230), 
out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1231 = R.call_tir(cls.reshape16, (lv238,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1232 = R.call_tir(cls.reshape17, (reshape1231,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv150 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight4, reshape1232, model_decoder_layers_19_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1096 = R.call_tir(cls.add5, (add1093, lv150), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm318 = R.call_tir(cls.layer_norm2, (add1096, model_decoder_layers_19_final_layer_norm_weight4, model_decoder_layers_19_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv19 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_19_fc1_weight4, layer_norm318, model_decoder_layers_19_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv151 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_19_fc2_weight4, lv19, model_decoder_layers_19_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1099 = R.call_tir(cls.add5, (add1096, lv151), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm319 = R.call_tir(cls.layer_norm2, (add1099, model_decoder_layers_20_self_attn_layer_norm_weight4, model_decoder_layers_20_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv152 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_q_proj_weight4, layer_norm319, model_decoder_layers_20_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1233 = R.call_tir(cls.reshape14, (lv152,), out_sinfo=R.Tensor((1, 
seq_len, 20, 64), dtype="float16")) + lv52_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_20_self_attn_k_proj_weight4, layer_norm319), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1234 = R.call_tir(cls.reshape14, (lv52_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv153 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_v_proj_weight4, layer_norm319, model_decoder_layers_20_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1235 = R.call_tir(cls.reshape14, (lv153,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat84 = R.call_tir(cls.concatenate1, (reshape1233, reshape1234, reshape1235), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1236 = R.call_tir(cls.reshape15, (concat84,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv239 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape1236), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1237 = R.call_tir(cls.reshape16, (lv239,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1238 = R.call_tir(cls.reshape17, (reshape1237,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv154 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_out_proj_weight4, reshape1238, model_decoder_layers_20_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1103 = R.call_tir(cls.add5, (add1099, lv154), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm320 = R.call_tir(cls.layer_norm2, (add1103, model_decoder_layers_20_encoder_attn_layer_norm_weight4, model_decoder_layers_20_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) + lv155 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight4, layer_norm320, model_decoder_layers_20_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1239 = R.call_tir(cls.reshape14, (lv155,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1240 = R.call_tir(cls.reshape18, (reshape1239,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv240 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape1240), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1241 = R.call_tir(cls.reshape16, (lv240,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1242 = R.call_tir(cls.reshape17, (reshape1241,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv156 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight4, reshape1242, model_decoder_layers_20_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1106 = R.call_tir(cls.add5, (add1103, lv156), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm321 = R.call_tir(cls.layer_norm2, (add1106, model_decoder_layers_20_final_layer_norm_weight4, model_decoder_layers_20_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv20 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_20_fc1_weight4, layer_norm321, model_decoder_layers_20_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv157 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_20_fc2_weight4, lv20, model_decoder_layers_20_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) 
+ add1109 = R.call_tir(cls.add5, (add1106, lv157), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm322 = R.call_tir(cls.layer_norm2, (add1109, model_decoder_layers_21_self_attn_layer_norm_weight4, model_decoder_layers_21_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv158 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_q_proj_weight4, layer_norm322, model_decoder_layers_21_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1243 = R.call_tir(cls.reshape14, (lv158,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv53_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_21_self_attn_k_proj_weight4, layer_norm322), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1244 = R.call_tir(cls.reshape14, (lv53_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv159 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_v_proj_weight4, layer_norm322, model_decoder_layers_21_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1245 = R.call_tir(cls.reshape14, (lv159,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat85 = R.call_tir(cls.concatenate1, (reshape1243, reshape1244, reshape1245), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1246 = R.call_tir(cls.reshape15, (concat85,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv241 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape1246), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1247 = R.call_tir(cls.reshape16, (lv241,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1248 = 
R.call_tir(cls.reshape17, (reshape1247,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv160 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_out_proj_weight4, reshape1248, model_decoder_layers_21_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1113 = R.call_tir(cls.add5, (add1109, lv160), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm323 = R.call_tir(cls.layer_norm2, (add1113, model_decoder_layers_21_encoder_attn_layer_norm_weight4, model_decoder_layers_21_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv161 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight4, layer_norm323, model_decoder_layers_21_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1249 = R.call_tir(cls.reshape14, (lv161,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1250 = R.call_tir(cls.reshape18, (reshape1249,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv242 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape1250), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1251 = R.call_tir(cls.reshape16, (lv242,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1252 = R.call_tir(cls.reshape17, (reshape1251,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv162 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight4, reshape1252, model_decoder_layers_21_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1116 = R.call_tir(cls.add5, (add1113, lv162), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
layer_norm324 = R.call_tir(cls.layer_norm2, (add1116, model_decoder_layers_21_final_layer_norm_weight4, model_decoder_layers_21_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv21 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_21_fc1_weight4, layer_norm324, model_decoder_layers_21_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv163 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_21_fc2_weight4, lv21, model_decoder_layers_21_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1119 = R.call_tir(cls.add5, (add1116, lv163), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm325 = R.call_tir(cls.layer_norm2, (add1119, model_decoder_layers_22_self_attn_layer_norm_weight4, model_decoder_layers_22_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv164 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_q_proj_weight4, layer_norm325, model_decoder_layers_22_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1253 = R.call_tir(cls.reshape14, (lv164,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv54_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_22_self_attn_k_proj_weight4, layer_norm325), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1254 = R.call_tir(cls.reshape14, (lv54_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv165 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_v_proj_weight4, layer_norm325, model_decoder_layers_22_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1255 = R.call_tir(cls.reshape14, 
(lv165,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat86 = R.call_tir(cls.concatenate1, (reshape1253, reshape1254, reshape1255), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1256 = R.call_tir(cls.reshape15, (concat86,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv243 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape1256), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1257 = R.call_tir(cls.reshape16, (lv243,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1258 = R.call_tir(cls.reshape17, (reshape1257,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv166 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_out_proj_weight4, reshape1258, model_decoder_layers_22_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1123 = R.call_tir(cls.add5, (add1119, lv166), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm326 = R.call_tir(cls.layer_norm2, (add1123, model_decoder_layers_22_encoder_attn_layer_norm_weight4, model_decoder_layers_22_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv167 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight4, layer_norm326, model_decoder_layers_22_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1259 = R.call_tir(cls.reshape14, (lv167,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1260 = R.call_tir(cls.reshape18, (reshape1259,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), 
reshape1260), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1261 = R.call_tir(cls.reshape16, (lv244,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1262 = R.call_tir(cls.reshape17, (reshape1261,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv168 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight4, reshape1262, model_decoder_layers_22_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1126 = R.call_tir(cls.add5, (add1123, lv168), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm327 = R.call_tir(cls.layer_norm2, (add1126, model_decoder_layers_22_final_layer_norm_weight4, model_decoder_layers_22_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv22 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_22_fc1_weight4, layer_norm327, model_decoder_layers_22_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv169 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_22_fc2_weight4, lv22, model_decoder_layers_22_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1129 = R.call_tir(cls.add5, (add1126, lv169), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm328 = R.call_tir(cls.layer_norm2, (add1129, model_decoder_layers_23_self_attn_layer_norm_weight4, model_decoder_layers_23_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv170 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_q_proj_weight4, layer_norm328, model_decoder_layers_23_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1263 = R.call_tir(cls.reshape14, (lv170,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv55_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_23_self_attn_k_proj_weight4, layer_norm328), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1264 = R.call_tir(cls.reshape14, (lv55_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv171 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_v_proj_weight4, layer_norm328, model_decoder_layers_23_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1265 = R.call_tir(cls.reshape14, (lv171,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat87 = R.call_tir(cls.concatenate1, (reshape1263, reshape1264, reshape1265), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1266 = R.call_tir(cls.reshape15, (concat87,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv245 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape1266), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1267 = R.call_tir(cls.reshape16, (lv245,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1268 = R.call_tir(cls.reshape17, (reshape1267,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv172 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_out_proj_weight4, reshape1268, model_decoder_layers_23_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1133 = R.call_tir(cls.add5, (add1129, lv172), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm329 = R.call_tir(cls.layer_norm2, (add1133, model_decoder_layers_23_encoder_attn_layer_norm_weight4, model_decoder_layers_23_encoder_attn_layer_norm_bias4), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv173 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight4, layer_norm329, model_decoder_layers_23_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1269 = R.call_tir(cls.reshape14, (lv173,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1270 = R.call_tir(cls.reshape18, (reshape1269,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv246 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape1270), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1271 = R.call_tir(cls.reshape16, (lv246,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1272 = R.call_tir(cls.reshape17, (reshape1271,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv174 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight4, reshape1272, model_decoder_layers_23_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1136 = R.call_tir(cls.add5, (add1133, lv174), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm330 = R.call_tir(cls.layer_norm2, (add1136, model_decoder_layers_23_final_layer_norm_weight4, model_decoder_layers_23_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv23 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_23_fc1_weight4, layer_norm330, model_decoder_layers_23_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv175 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_23_fc2_weight4, lv23, model_decoder_layers_23_fc2_bias4), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1139 = R.call_tir(cls.add5, (add1136, lv175), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm331 = R.call_tir(cls.layer_norm2, (add1139, model_decoder_layers_24_self_attn_layer_norm_weight4, model_decoder_layers_24_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv176 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_q_proj_weight4, layer_norm331, model_decoder_layers_24_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1273 = R.call_tir(cls.reshape14, (lv176,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv56_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_24_self_attn_k_proj_weight4, layer_norm331), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1274 = R.call_tir(cls.reshape14, (lv56_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv177 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_v_proj_weight4, layer_norm331, model_decoder_layers_24_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1275 = R.call_tir(cls.reshape14, (lv177,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat88 = R.call_tir(cls.concatenate1, (reshape1273, reshape1274, reshape1275), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1276 = R.call_tir(cls.reshape15, (concat88,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv247 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape1276), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1277 = R.call_tir(cls.reshape16, (lv247,), out_sinfo=R.Tensor((1, seq_len, 20, 
64), dtype="float16")) + reshape1278 = R.call_tir(cls.reshape17, (reshape1277,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv178 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_out_proj_weight4, reshape1278, model_decoder_layers_24_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1143 = R.call_tir(cls.add5, (add1139, lv178), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm332 = R.call_tir(cls.layer_norm2, (add1143, model_decoder_layers_24_encoder_attn_layer_norm_weight4, model_decoder_layers_24_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv179 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight4, layer_norm332, model_decoder_layers_24_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1279 = R.call_tir(cls.reshape14, (lv179,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1280 = R.call_tir(cls.reshape18, (reshape1279,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv248 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape1280), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1281 = R.call_tir(cls.reshape16, (lv248,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1282 = R.call_tir(cls.reshape17, (reshape1281,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv180 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight4, reshape1282, model_decoder_layers_24_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1146 = R.call_tir(cls.add5, (add1143, lv180), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) + layer_norm333 = R.call_tir(cls.layer_norm2, (add1146, model_decoder_layers_24_final_layer_norm_weight4, model_decoder_layers_24_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv24 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_24_fc1_weight4, layer_norm333, model_decoder_layers_24_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv181 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_24_fc2_weight4, lv24, model_decoder_layers_24_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1149 = R.call_tir(cls.add5, (add1146, lv181), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm334 = R.call_tir(cls.layer_norm2, (add1149, model_decoder_layers_25_self_attn_layer_norm_weight4, model_decoder_layers_25_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv182 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_q_proj_weight4, layer_norm334, model_decoder_layers_25_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1283 = R.call_tir(cls.reshape14, (lv182,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv57_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_25_self_attn_k_proj_weight4, layer_norm334), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1284 = R.call_tir(cls.reshape14, (lv57_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv183 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_v_proj_weight4, layer_norm334, model_decoder_layers_25_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1285 
= R.call_tir(cls.reshape14, (lv183,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat89 = R.call_tir(cls.concatenate1, (reshape1283, reshape1284, reshape1285), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1286 = R.call_tir(cls.reshape15, (concat89,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv249 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape1286), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1287 = R.call_tir(cls.reshape16, (lv249,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1288 = R.call_tir(cls.reshape17, (reshape1287,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv184 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_out_proj_weight4, reshape1288, model_decoder_layers_25_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1153 = R.call_tir(cls.add5, (add1149, lv184), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm335 = R.call_tir(cls.layer_norm2, (add1153, model_decoder_layers_25_encoder_attn_layer_norm_weight4, model_decoder_layers_25_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv185 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_q_proj_weight4, layer_norm335, model_decoder_layers_25_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1289 = R.call_tir(cls.reshape14, (lv185,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1290 = R.call_tir(cls.reshape18, (reshape1289,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv250 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), 
R.prim_value(T.float32(1)), reshape1290), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1291 = R.call_tir(cls.reshape16, (lv250,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1292 = R.call_tir(cls.reshape17, (reshape1291,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv186 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_out_proj_weight4, reshape1292, model_decoder_layers_25_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1156 = R.call_tir(cls.add5, (add1153, lv186), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm336 = R.call_tir(cls.layer_norm2, (add1156, model_decoder_layers_25_final_layer_norm_weight4, model_decoder_layers_25_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv25 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_25_fc1_weight4, layer_norm336, model_decoder_layers_25_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv187 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_25_fc2_weight4, lv25, model_decoder_layers_25_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1159 = R.call_tir(cls.add5, (add1156, lv187), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm337 = R.call_tir(cls.layer_norm2, (add1159, model_decoder_layers_26_self_attn_layer_norm_weight4, model_decoder_layers_26_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv188 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_q_proj_weight4, layer_norm337, model_decoder_layers_26_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1293 = 
R.call_tir(cls.reshape14, (lv188,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv58_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_26_self_attn_k_proj_weight4, layer_norm337), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1294 = R.call_tir(cls.reshape14, (lv58_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv189 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_v_proj_weight4, layer_norm337, model_decoder_layers_26_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1295 = R.call_tir(cls.reshape14, (lv189,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat90 = R.call_tir(cls.concatenate1, (reshape1293, reshape1294, reshape1295), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1296 = R.call_tir(cls.reshape15, (concat90,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv251 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape1296), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1297 = R.call_tir(cls.reshape16, (lv251,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1298 = R.call_tir(cls.reshape17, (reshape1297,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv190 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_out_proj_weight4, reshape1298, model_decoder_layers_26_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1163 = R.call_tir(cls.add5, (add1159, lv190), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm338 = R.call_tir(cls.layer_norm2, (add1163, model_decoder_layers_26_encoder_attn_layer_norm_weight4, 
model_decoder_layers_26_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv191 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight4, layer_norm338, model_decoder_layers_26_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1299 = R.call_tir(cls.reshape14, (lv191,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1300 = R.call_tir(cls.reshape18, (reshape1299,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv252 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape1300), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1301 = R.call_tir(cls.reshape16, (lv252,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1302 = R.call_tir(cls.reshape17, (reshape1301,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv192 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_out_proj_weight4, reshape1302, model_decoder_layers_26_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1166 = R.call_tir(cls.add5, (add1163, lv192), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm339 = R.call_tir(cls.layer_norm2, (add1166, model_decoder_layers_26_final_layer_norm_weight4, model_decoder_layers_26_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv26 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_26_fc1_weight4, layer_norm339, model_decoder_layers_26_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv193 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_26_fc2_weight4, 
lv26, model_decoder_layers_26_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1169 = R.call_tir(cls.add5, (add1166, lv193), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm340 = R.call_tir(cls.layer_norm2, (add1169, model_decoder_layers_27_self_attn_layer_norm_weight4, model_decoder_layers_27_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv194 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_q_proj_weight4, layer_norm340, model_decoder_layers_27_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1303 = R.call_tir(cls.reshape14, (lv194,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv59_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_27_self_attn_k_proj_weight4, layer_norm340), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1304 = R.call_tir(cls.reshape14, (lv59_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv195 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_v_proj_weight4, layer_norm340, model_decoder_layers_27_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1305 = R.call_tir(cls.reshape14, (lv195,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat91 = R.call_tir(cls.concatenate1, (reshape1303, reshape1304, reshape1305), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1306 = R.call_tir(cls.reshape15, (concat91,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv253 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape1306), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1307 = R.call_tir(cls.reshape16, 
(lv253,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1308 = R.call_tir(cls.reshape17, (reshape1307,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv196 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_out_proj_weight4, reshape1308, model_decoder_layers_27_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1173 = R.call_tir(cls.add5, (add1169, lv196), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm341 = R.call_tir(cls.layer_norm2, (add1173, model_decoder_layers_27_encoder_attn_layer_norm_weight4, model_decoder_layers_27_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv197 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight4, layer_norm341, model_decoder_layers_27_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1309 = R.call_tir(cls.reshape14, (lv197,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1310 = R.call_tir(cls.reshape18, (reshape1309,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv254 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape1310), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1311 = R.call_tir(cls.reshape16, (lv254,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1312 = R.call_tir(cls.reshape17, (reshape1311,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv198_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight4, reshape1312, model_decoder_layers_27_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1176 = 
R.call_tir(cls.add5, (add1173, lv198_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm342 = R.call_tir(cls.layer_norm2, (add1176, model_decoder_layers_27_final_layer_norm_weight4, model_decoder_layers_27_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv27 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_27_fc1_weight4, layer_norm342, model_decoder_layers_27_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv199_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_27_fc2_weight4, lv27, model_decoder_layers_27_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1179 = R.call_tir(cls.add5, (add1176, lv199_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm343 = R.call_tir(cls.layer_norm2, (add1179, model_decoder_layers_28_self_attn_layer_norm_weight4, model_decoder_layers_28_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv200_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_q_proj_weight4, layer_norm343, model_decoder_layers_28_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1313 = R.call_tir(cls.reshape14, (lv200_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv60_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_28_self_attn_k_proj_weight4, layer_norm343), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1314 = R.call_tir(cls.reshape14, (lv60_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv201_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_v_proj_weight4, layer_norm343, 
model_decoder_layers_28_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1315 = R.call_tir(cls.reshape14, (lv201_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat92 = R.call_tir(cls.concatenate1, (reshape1313, reshape1314, reshape1315), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1316 = R.call_tir(cls.reshape15, (concat92,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv255 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape1316), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1317 = R.call_tir(cls.reshape16, (lv255,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1318 = R.call_tir(cls.reshape17, (reshape1317,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv202_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_out_proj_weight4, reshape1318, model_decoder_layers_28_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1183 = R.call_tir(cls.add5, (add1179, lv202_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm344 = R.call_tir(cls.layer_norm2, (add1183, model_decoder_layers_28_encoder_attn_layer_norm_weight4, model_decoder_layers_28_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv203_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight4, layer_norm344, model_decoder_layers_28_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1319 = R.call_tir(cls.reshape14, (lv203_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1320 = R.call_tir(cls.reshape18, (reshape1319,), out_sinfo=R.Tensor((seq_len, 20, 64), 
dtype="float16")) + lv256 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape1320), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1321 = R.call_tir(cls.reshape16, (lv256,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1322 = R.call_tir(cls.reshape17, (reshape1321,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv204_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight4, reshape1322, model_decoder_layers_28_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1186 = R.call_tir(cls.add5, (add1183, lv204_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm345 = R.call_tir(cls.layer_norm2, (add1186, model_decoder_layers_28_final_layer_norm_weight4, model_decoder_layers_28_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv28 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_28_fc1_weight4, layer_norm345, model_decoder_layers_28_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv205_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_28_fc2_weight4, lv28, model_decoder_layers_28_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1189 = R.call_tir(cls.add5, (add1186, lv205_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm346 = R.call_tir(cls.layer_norm2, (add1189, model_decoder_layers_29_self_attn_layer_norm_weight4, model_decoder_layers_29_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv206_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_q_proj_weight4, 
layer_norm346, model_decoder_layers_29_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1323 = R.call_tir(cls.reshape14, (lv206_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv61_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_29_self_attn_k_proj_weight4, layer_norm346), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1324 = R.call_tir(cls.reshape14, (lv61_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv207_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_v_proj_weight4, layer_norm346, model_decoder_layers_29_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1325 = R.call_tir(cls.reshape14, (lv207_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat93 = R.call_tir(cls.concatenate1, (reshape1323, reshape1324, reshape1325), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1326 = R.call_tir(cls.reshape15, (concat93,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv257 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1326), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1327 = R.call_tir(cls.reshape16, (lv257,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1328 = R.call_tir(cls.reshape17, (reshape1327,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv208_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_out_proj_weight4, reshape1328, model_decoder_layers_29_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1193 = R.call_tir(cls.add5, (add1189, lv208_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + 
layer_norm347 = R.call_tir(cls.layer_norm2, (add1193, model_decoder_layers_29_encoder_attn_layer_norm_weight4, model_decoder_layers_29_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv209_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_q_proj_weight4, layer_norm347, model_decoder_layers_29_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1329 = R.call_tir(cls.reshape14, (lv209_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1330 = R.call_tir(cls.reshape18, (reshape1329,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv258 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1330), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1331 = R.call_tir(cls.reshape16, (lv258,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1332 = R.call_tir(cls.reshape17, (reshape1331,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv210_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight4, reshape1332, model_decoder_layers_29_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1196 = R.call_tir(cls.add5, (add1193, lv210_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm348 = R.call_tir(cls.layer_norm2, (add1196, model_decoder_layers_29_final_layer_norm_weight4, model_decoder_layers_29_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv29 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_29_fc1_weight4, layer_norm348, model_decoder_layers_29_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv211_1 
= R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_29_fc2_weight4, lv29, model_decoder_layers_29_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1199 = R.call_tir(cls.add5, (add1196, lv211_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm349 = R.call_tir(cls.layer_norm2, (add1199, model_decoder_layers_30_self_attn_layer_norm_weight4, model_decoder_layers_30_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv212_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_q_proj_weight4, layer_norm349, model_decoder_layers_30_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1333 = R.call_tir(cls.reshape14, (lv212_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv62_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_30_self_attn_k_proj_weight4, layer_norm349), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1334 = R.call_tir(cls.reshape14, (lv62_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv213_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_v_proj_weight4, layer_norm349, model_decoder_layers_30_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1335 = R.call_tir(cls.reshape14, (lv213_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat94 = R.call_tir(cls.concatenate1, (reshape1333, reshape1334, reshape1335), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1336 = R.call_tir(cls.reshape15, (concat94,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv259 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), 
R.prim_value(T.float32(1)), reshape1336), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1337 = R.call_tir(cls.reshape16, (lv259,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1338 = R.call_tir(cls.reshape17, (reshape1337,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv214_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_out_proj_weight4, reshape1338, model_decoder_layers_30_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1203 = R.call_tir(cls.add5, (add1199, lv214_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm350 = R.call_tir(cls.layer_norm2, (add1203, model_decoder_layers_30_encoder_attn_layer_norm_weight4, model_decoder_layers_30_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv215_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight4, layer_norm350, model_decoder_layers_30_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1339 = R.call_tir(cls.reshape14, (lv215_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1340 = R.call_tir(cls.reshape18, (reshape1339,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv260 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1340), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1341 = R.call_tir(cls.reshape16, (lv260,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1342 = R.call_tir(cls.reshape17, (reshape1341,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv216_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight4, 
reshape1342, model_decoder_layers_30_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1206 = R.call_tir(cls.add5, (add1203, lv216_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm351 = R.call_tir(cls.layer_norm2, (add1206, model_decoder_layers_30_final_layer_norm_weight4, model_decoder_layers_30_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv30 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_30_fc1_weight4, layer_norm351, model_decoder_layers_30_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv217_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_30_fc2_weight4, lv30, model_decoder_layers_30_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1209 = R.call_tir(cls.add5, (add1206, lv217_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm352 = R.call_tir(cls.layer_norm2, (add1209, model_decoder_layers_31_self_attn_layer_norm_weight4, model_decoder_layers_31_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv218_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_q_proj_weight4, layer_norm352, model_decoder_layers_31_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1343 = R.call_tir(cls.reshape14, (lv218_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv63_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_31_self_attn_k_proj_weight4, layer_norm352), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1344 = R.call_tir(cls.reshape14, (lv63_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + lv219_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_v_proj_weight4, layer_norm352, model_decoder_layers_31_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1345 = R.call_tir(cls.reshape14, (lv219_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + concat95 = R.call_tir(cls.concatenate1, (reshape1343, reshape1344, reshape1345), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) + reshape1346 = R.call_tir(cls.reshape15, (concat95,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) + lv261 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1346), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1347 = R.call_tir(cls.reshape16, (lv261,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1348 = R.call_tir(cls.reshape17, (reshape1347,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv220_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_out_proj_weight4, reshape1348, model_decoder_layers_31_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1213 = R.call_tir(cls.add5, (add1209, lv220_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm353 = R.call_tir(cls.layer_norm2, (add1213, model_decoder_layers_31_encoder_attn_layer_norm_weight4, model_decoder_layers_31_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv221_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight4, layer_norm353, model_decoder_layers_31_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + reshape1349 = R.call_tir(cls.reshape14, (lv221_1,), 
out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1350 = R.call_tir(cls.reshape18, (reshape1349,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + lv262 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1350), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) + reshape1351 = R.call_tir(cls.reshape16, (lv262,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) + reshape1352 = R.call_tir(cls.reshape17, (reshape1351,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv222_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight4, reshape1352, model_decoder_layers_31_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1216 = R.call_tir(cls.add5, (add1213, lv222_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm354 = R.call_tir(cls.layer_norm2, (add1216, model_decoder_layers_31_final_layer_norm_weight4, model_decoder_layers_31_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv31 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_31_fc1_weight4, layer_norm354, model_decoder_layers_31_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) + lv223_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_31_fc2_weight4, lv31, model_decoder_layers_31_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + add1219 = R.call_tir(cls.add5, (add1216, lv223_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + layer_norm355 = R.call_tir(cls.layer_norm2, (add1219, model_decoder_layer_norm_weight4, model_decoder_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) + lv263 = R.call_tir(cls.index, 
(layer_norm355,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) + gv4 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul2_cublas", (model_decoder_embed_tokens_weight4, lv263), out_sinfo=R.Tensor((1, 1, 51866), dtype="float32")) + R.output(gv4) + return gv4 + + @R.function + def renormalize_by_top_p(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), top_p: R.Tensor(("batch_size",), dtype="float32"), init_pivots: R.Tensor(("batch_size", 3), dtype="float32")) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"): + batch_size = T.int64() + vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + lv6 = R.call_tir(cls.top_p_pivot_cutoff, (probs, top_p, init_pivots), out_sinfo=[R.Tensor((batch_size,), dtype="float32"), R.Tensor((batch_size,), dtype="float32")]) + lv7: R.Tensor((batch_size,), dtype="float32") = lv6[0] + lv8: R.Tensor((batch_size,), dtype="float32") = lv6[1] + gv5 = R.call_tir(cls.top_p_renorm_after_cutoff, (probs, lv7, lv8), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32")) + R.output(gv5) + return gv5 + + @R.function + def sample_with_top_p(sorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), top_p: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("num_samples",), dtype="int32"): + num_samples = T.int64() + batch_size = T.int64() + vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + uniform_samples1: R.Tensor((num_samples, 1), 
dtype="float32") = R.call_pure_packed("vm.builtin.reshape", uniform_samples, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),)) + sample_indices1: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", sample_indices, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),)) + sample_indices2: R.Tensor((batch_size, 1), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", top_p, R.shape([batch_size, 1]), sinfo_args=(R.Tensor((batch_size, 1), dtype="float32"),)) + lv3 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((batch_size, 1), dtype="int32"), tir_vars=R.shape([vocab_size])) + lv1: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]), R.dtype("uint8"), R.prim_value(0), R.str("global")) + cumsum = R.call_tir(cls.cumsum, (sorted_probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32")) + lv4 = R.call_tir(cls.get_renorm_prob, (cumsum, sample_indices2, lv3), out_sinfo=R.Tensor((batch_size, 1), dtype="float32")) + lv5 = R.call_tir(cls.get_index_from_sorted, (cumsum, sorted_indices, lv4, uniform_samples1, sample_indices1), out_sinfo=R.Tensor((num_samples, 1), dtype="int32")) + gv2: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv5, R.shape([num_samples]), sinfo_args=(R.Tensor((num_samples,), dtype="int32"),)) + R.output(gv2) + return gv2 + + @R.function + def sampler_take_probs(unsorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), sampling_result: R.Tensor(("num_samples",), dtype="int32"), lobprob_offsets: R.Tensor(("num_positions",), dtype="int32")) -> R.Tuple(R.Tensor(("num_samples",), dtype="float32"), 
R.Tensor(("num_positions",), dtype="float32"), R.Tensor(("num_positions",), dtype="int32")): + num_samples = T.int64() + num_positions = T.int64() + batch_size = T.int64() + vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + gv3 = R.call_tir(cls.sampler_take_probs_tir, (unsorted_probs, sorted_indices, sample_indices, sampling_result, lobprob_offsets), out_sinfo=[R.Tensor((num_samples,), dtype="float32"), R.Tensor((num_positions,), dtype="float32"), R.Tensor((num_positions,), dtype="int32")]) + R.output(gv3) + return gv3 + + @R.function + def sampler_verify_draft_tokens(draft_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), draft_tokens: R.Tensor(("num_nodes",), dtype="int32"), model_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), token_tree_first_child: R.Tensor(("num_nodes",), dtype="int32"), token_tree_next_sibling: R.Tensor(("num_nodes",), dtype="int32"), uniform_samples: R.Tensor(("num_nodes",), dtype="float32"), token_tree_parent_ptr: R.Tensor(("nbatch",), dtype="int32")) -> R.Tuple(R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), R.Tensor(("nbatch",), dtype="int32")): + num_nodes = T.int64() + vocab_size = T.int64() + nbatch = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}}) + cls = Module + with R.dataflow(): + gv4: R.Tuple(R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")) = R.call_tir_inplace(cls.batch_verify_on_gpu_single_kernel, (draft_probs, draft_tokens, model_probs, token_tree_first_child, token_tree_next_sibling, uniform_samples, token_tree_parent_ptr), out_sinfo=[R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")], 
inplace_indices=[2, 6]) + R.output(gv4) + return gv4 + + @R.function + def softmax_with_temperature(logits: R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), temperature: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"): + batch_size = T.int64() + vocab_size = T.int64() + R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) + cls = Module + with R.dataflow(): + lv: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", logits, R.shape([batch_size, vocab_size]), sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),)) + lv1 = R.call_tir(cls.chunk_lse, (lv, temperature), out_sinfo=[R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32"), R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32")]) + lv2: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[0] + lv3: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[1] + lv4 = R.call_tir(cls.softmax_with_chunked_sum, (lv, temperature, lv2, lv3), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32")) + gv: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", lv4, R.shape([batch_size, 1, vocab_size]), sinfo_args=(R.Tensor((batch_size, 1, vocab_size), dtype="float32"),)) + R.output(gv) + return gv + +# Metadata omitted. Use show_meta=True in script() method to show it. \ No newline at end of file