# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0], metadata["runtime.Module"][1], metadata["runtime.Module"][2], metadata["runtime.Module"][3], metadata["runtime.Module"][4], metadata["runtime.Module"][5], metadata["runtime.Module"][6], metadata["runtime.Module"][7], metadata["runtime.Module"][8], metadata["runtime.Module"][9], metadata["runtime.Module"][10], metadata["runtime.Module"][11], metadata["runtime.Module"][12], metadata["runtime.Module"][13], metadata["runtime.Module"][14]]})

    # Decode-time q_proj GEMV (1x1280 input against a 1280x1280 transposed weight), scheduled as a
    # two-stage rfactor reduction: 128 partial sums per output element, then a cross-thread combine.
    @T.prim_func(private=True)
    def NT_matmul(layer_norm356: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local")
        NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local")
        model_decoder_layers_0_self_attn_q_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local")
        layer_norm356_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(1)):
                                        with T.block("layer_norm356_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3)
                                            T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280))
                                            T.reads(layer_norm356[v0, v1, v2])
                                            T.writes(layer_norm356_shared[v0, v1, v2])
                                            layer_norm356_shared[v0, v1, v2] = layer_norm356[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
                    for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(T.int64(4)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(2)):
                                with T.block("model_decoder_layers_0_self_attn_q_proj_weight5_local"):
                                    v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1)
                                    T.reads(model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1])
                                    T.writes(model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1])
                                    model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2])
                                    T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)])
                                    T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                            v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul[T.int64(0), T.int64(0), v0] = T.float16(0)
                            NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    # Final logit projection against the tied token-embedding matrix (vocab size 51866),
    # same rfactor scheme as NT_matmul but accumulating in float32.
    @T.prim_func(private=True)
    def NT_matmul3(layer_norm452: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_embed_tokens_weight5: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(51866)), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(51866)), scope="local")
        NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(51866)), scope="local")
        model_decoder_embed_tokens_weight5_local = T.alloc_buffer((T.int64(51866), T.int64(1280)), "float16", scope="local")
        layer_norm452_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(12967), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(1)):
                                        with T.block("layer_norm452_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1 * T.int64(64) + ax2_2 + ax2_3)
                                            T.reads(layer_norm452[v0, v1, v2])
                                            T.writes(layer_norm452_shared[v0, v1, v2])
                                            layer_norm452_shared[v0, v1, v2] = layer_norm452[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init < T.int64(51866))
                                T.reads()
                                T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float32(0)
                    for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(T.int64(2)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(2)):
                                with T.block("model_decoder_embed_tokens_weight5_local"):
                                    v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1)
                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 < T.int64(51866))
                                    T.reads(model_decoder_embed_tokens_weight5[v0, v1])
                                    T.writes(model_decoder_embed_tokens_weight5_local[v0, v1])
                                    model_decoder_embed_tokens_weight5_local[v0, v1] = model_decoder_embed_tokens_weight5[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(1)):
                            for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_fused_u_fused_2, vax1_fused_u_fused_0 = T.axis.remap("RR", [ax1_fused_u_fused_2, ax1_fused_u_fused_0])
                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2 < T.int64(51866))
                                    T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm452_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused], model_decoder_embed_tokens_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused])
                                    T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + T.Cast("float32", layer_norm452_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) * T.Cast("float32", model_decoder_embed_tokens_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused])
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
                                v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax2_fused_0_ax2_fused_1_fused % T.int64(4) + (ax2_fused_2_0 + ax2_fused_2_1)) < T.int64(51866))
                                T.reads()
                                T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float32(0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax2_fused_0_ax2_fused_1_fused % T.int64(4) + (ax2_fused_2_0 + ax2_fused_2_1)) < T.int64(51866))
                                    T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
                            v0 = T.axis.spatial(T.int64(51866), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.where(u_fused_ax0_fused_fused_0 * T.int64(4) + (T.Mul(T.int64(0), T.int64(4)) + ax1_fused_0_ax1_fused_1_fused % T.int64(4) + ax1_fused_2) < T.int64(51866))
                            T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul[T.int64(0), T.int64(0), v0] = T.float32(0)
                            NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]

    # Elementwise residual add over (batch_size, 1, 1280).
    @T.prim_func(private=True)
    def add(var_reshape708: T.handle, var_reshape709: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape708 = T.match_buffer(var_reshape708, (batch_size, T.int64(1), T.int64(1280)), "float16")
        reshape709 = T.match_buffer(var_reshape709, (batch_size, T.int64(1), T.int64(1280)), "float16")
        T_add = T.match_buffer(var_T_add, (batch_size, T.int64(1), T.int64(1280)), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_add"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280))
                    v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280))
                    T.reads(reshape708[v0, T.int64(0), v1], reshape709[v0, T.int64(0), v1])
                    T.writes(T_add[v0, T.int64(0), v1])
                    T_add[v0, T.int64(0), v1] = reshape708[v0, T.int64(0), v1] + reshape709[v0, T.int64(0), v1]

    # Elementwise add over (batch_size, 1500, 1280) encoder-side activations.
    @T.prim_func(private=True)
    def add4(var_add: T.handle, var_lv610: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add = T.match_buffer(var_add, (batch_size, T.int64(1500), T.int64(1280)), "float16")
        lv610 = T.match_buffer(var_lv610, (batch_size, T.int64(1500), T.int64(1280)), "float16")
        T_add = T.match_buffer(var_T_add, (batch_size, T.int64(1500), T.int64(1280)), "float16")
        # with T.block("root"):
        for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_add"):
                    v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000))
                    v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280))
                    v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280))
                    T.reads(add[v0, v1, v2], lv610[v0, v1, v2])
                    T.writes(T_add[v0, v1, v2])
                    T_add[v0, v1, v2] = add[v0, v1, v2] + lv610[v0, v1, v2]

    # Elementwise residual add over (1, seq_len, 1280).
    @T.prim_func(private=True)
    def add5(var_reshape385: T.handle, var_reshape386: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape385 = T.match_buffer(var_reshape385, (T.int64(1), seq_len, T.int64(1280)), "float16")
        reshape386 = T.match_buffer(var_reshape386, (T.int64(1), seq_len, T.int64(1280)), "float16")
        T_add = T.match_buffer(var_T_add, (T.int64(1), seq_len, T.int64(1280)), "float16")
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_add"):
                    v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280))
                    v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280))
                    T.reads(reshape385[T.int64(0), v0, v1], reshape386[T.int64(0), v0, v1])
                    T.writes(T_add[T.int64(0), v0, v1])
                    T_add[T.int64(0), v0, v1] = reshape385[T.int64(0), v0, v1] + reshape386[T.int64(0), v0, v1]

    # Grammar-constrained sampling: keep a logit only if its bit is set in the bitmask,
    # otherwise overwrite it with the float32 lowest value.
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-3.4028234663852886e+38))

    # Scatter-add per-token logit biases in place.
    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    # Apply presence/frequency penalties (penalties[:, 0:2]) then the repetition penalty
    # (penalties[:, 2]), multiplying or dividing depending on the sign of the logit.
    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > T.float32(0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    # Row-wise argsort of the probability matrix, delegated to the Thrust external kernel.
    @T.prim_func(private=True)
    def argsort_thrust(var_probs: T.handle, var_lv: T.handle, var_topk_gpu_v1: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        data_buf = T.match_buffer(var_probs, (batch_size, vocab_size), align=8)
        workspace_buf = T.match_buffer(var_lv, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8)
        indices_buf = T.match_buffer(var_topk_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        # with T.block("root"):
        value_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        with T.block("topk_gpu"):
            T.reads()
            T.writes()
            T.call_packed("tvm.contrib.thrust.sort", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(value_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(indices_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, 0, T.int64(0)), 0, T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0)))

    # Single-query (decode-step) attention over the paged KV cache: 20 heads, head_dim 64,
    # page size 16, optional rotary embedding, online softmax with a cross-warp allreduce.
    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 20, 64), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 20, 16, 64), "float16")
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 20, 64), "float16")
        lse = T.match_buffer(lse_handle, (B, 20))
        # with T.block("root"):
        sm_scale: T.float32 = T.float32(0.18033688011112042)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(20, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(16, thread="threadIdx.x"):
                        for tz in T.thread_binding(32, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 20 + ty + fused_by_bz % 20, tx * 4 - 32:tx * 4 - 32 + 68])
                                T.writes(output[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((64, 64), "float16", scope="shared")
                                V_smem = T.alloc_buffer((64, 64), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((32, 1, 64), scope="shared")
                                md_allreduce = T.alloc_buffer((32, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 20
                                bz: T.int32 = fused_by_bz // 20
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                st_m[0] = T.float32(-50000)
                                st_d[0] = T.float32(1)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0)
                                for vec in T.vectorized(4):
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, Q[bx, by + bz + ty, tx * 4 + vec + 32] * T.float16(-1), Q[bx, by + bz + ty, tx * 4 + vec - 32]))), Q[bx, by + bz + ty, tx * 4 + vec])
                                for iterator in range((kv_chunk_len[0] + 63) // 64):
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 32 + tz + ty) * 2
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, pages[page_no, 0, by, page_offset, tx * 4 + vec + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, tx * 4 + vec - 32]))), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        S_local[j] = T.float32(-50000)
                                        if (iterator * 32 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000)
                                st_d[0] = T.float32(1)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0)
                                for j in range(32):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    # Decode attention variant with sliding-window bookkeeping: length_info carries three rows
    # (last-page length, window start, sink length) and KV positions are remapped accordingly.
    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 20, 64), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 20, 16, 64), "float16")
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 20, 64), "float16")
        lse = T.match_buffer(lse_handle, (B, 20))
        # with T.block("root"):
        sm_scale: T.float32 = T.float32(0.18033688011112042)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(20, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(16, thread="threadIdx.x"):
                        for tz in T.thread_binding(32, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 20 + ty + fused_by_bz % 20, tx * 4 - 32:tx * 4 - 32 + 68])
                                T.writes(output[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 20 + fused_by_bz // 20 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((64, 64), "float16", scope="shared")
                                V_smem = T.alloc_buffer((64, 64), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((32, 1, 64), scope="shared")
                                md_allreduce = T.alloc_buffer((32, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 20
                                bz: T.int32 = fused_by_bz // 20
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                st_m[0] = T.float32(-50000)
                                st_d[0] = T.float32(1)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0)
                                for vec in T.vectorized(4):
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, Q[bx, by + bz + ty, tx * 4 + vec + 32] * T.float16(-1), Q[bx, by + bz + ty, tx * 4 + vec - 32]))), Q[bx, by + bz + ty, tx * 4 + vec])
                                for iterator in range((kv_chunk_len[0] + 63) // 64):
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 32 + tz + ty) * 2
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 32, pages[page_no, 0, by, page_offset, tx * 4 + vec + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, tx * 4 + vec - 32]))), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0)
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        S_local[j] = T.float32(-50000)
                                        if (iterator * 32 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000)
                                st_d[0] = T.float32(1)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0)
                                for j in range(32):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    # Prefill attention over the paged KV cache: each block processes 32 query rows per tile
    # with an online-softmax (FlashAttention-style) loop over 16-row KV tiles.
    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 20, 64), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 20, 16, 64), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 20, 64), "float16")
        lse = T.match_buffer(var_lse, (total_len, 20))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(20, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared")
                            K_smem = T.alloc_buffer((16, 64), "float16", scope="shared")
                            V_smem = T.alloc_buffer((16, 64), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 16), scope="shared")
                            S_local = T.alloc_buffer((32, 16), scope="local")
                            O_local = T.alloc_buffer((32, 64), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000)
                                            d_smem[row] = T.float32(1)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 4):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(4):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64)
                                                        j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 15) // 16):
                                        L_kv_start: T.int32 = iterator_1 * 16
                                        for lz_ly_fused_0 in range(2):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64)
                                                            j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, pages[page_no, 0, by, page_offset, j + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, j - 32]))), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(2):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64)
                                                            j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64])
                                            T.writes(S_local[0:32, 0:16])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 2):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1)
                                                            k = T.axis.reduce(64, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.18033688011112042)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 2):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 32: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) T.writes(O_local[0:32, 0:64]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 4): with T.block("O_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): with T.block("O_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, 
thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 20, 64), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 20, 16, 64), "float16") page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 20, 64), "float16") lse = T.match_buffer(var_lse, (total_len, 20)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(20, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") S_smem = T.alloc_buffer((32, 16), scope="shared") S_local = T.alloc_buffer((32, 16), scope="local") O_local = T.alloc_buffer((32, 64), scope="local") m_smem = T.alloc_buffer((32,), scope="shared") m_prev_smem = T.alloc_buffer((32,), scope="shared") d_smem = T.alloc_buffer((32,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), 
scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 32 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: m_smem[row] = T.float32(-50000) d_smem[row] = T.float32(1) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(4): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 for lz_ly_fused_0 in range(2): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] 
page_offset: T.int32 = seq_offset % 16 K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, pages[page_no, 0, by, page_offset, j + 32] * T.float16(-1), pages[page_no, 0, by, page_offset, j - 32]))), pages[page_no, 0, by, page_offset, j]) else: K_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(2): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) T.writes(S_local[0:32, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k = T.axis.reduce(64, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.18033688011112042) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], 
d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 32: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) T.writes(O_local[0:32, 0:64]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 4): with T.block("O_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): with T.block("O_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: 
T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 20, 64), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 20, 64), "float16") v = T.match_buffer(var_v, (kv_len, 20, 64), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 20, 64), "float16") lse = T.match_buffer(var_lse, (qo_len, 20)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(20, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") S_smem = T.alloc_buffer((32, 16), scope="shared") S_local = T.alloc_buffer((32, 16), scope="local") O_local = T.alloc_buffer((32, 64), scope="local") m_smem = T.alloc_buffer((32,), scope="shared") m_prev_smem = T.alloc_buffer((32,), scope="shared") d_smem = T.alloc_buffer((32,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] q_indptr_val: T.int32 
= q_indptr[b_idx] LH_start: T.int32 = tile_id[0] * 32 kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: m_smem[row] = T.float32(-50000) d_smem[row] = T.float32(1) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(4): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]))), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_ly_fused_0 in range(2): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(j < 32, k[L_kv_base + cur_L, by, j + 32] * T.float16(-1), k[L_kv_base + cur_L, by, j - 32]))), k[L_kv_base + cur_L, by, j]) else: K_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(2): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 
+ lz_ly_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: V_smem[i, j] = v[L_kv_base + cur_L, by, j] else: V_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) T.writes(S_local[0:32, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k_1 = T.axis.reduce(64, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.18033688011112042) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 32: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) 
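# This block is the online-softmax rescale + accumulate step of the flash
# attention loop (the "O_gemm_init" name is inherited from scheduling; it
# rescales the running output rather than zero-filling it):
#     O <- O * 2^(m_prev - m_new)   (re-base the old partial sums)
#     O <- O + P @ V                (P = 2^(S - m_new), held in S_smem)
# Everything is kept in base 2 so the hardware exp2 can be used and the
# epilogue can emit lse = m + log2(d) directly. The constant
# 0.18033688011112042 folded into S_gemm_update above is log2(e) / sqrt(64),
# i.e. the 1/sqrt(head_dim) attention scale converted to the base-2 domain.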
T.writes(O_local[0:32, 0:64]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 4): with T.block("O_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): with T.block("O_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) k_1 = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 20, 64), "float16") q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 20, 64), "float16") v = T.match_buffer(var_v, (kv_len, 20, 64), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) 
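# batch_tree_attn replaces the causal predicate of the prefill kernels with an
# explicit 0/1 mask over token-tree positions (speculative decoding). For
# request b the kernel below reads entry (row, col) as
#     mask[mn_indptr[b] + row * (q_indptr[b+1] - q_indptr[b]) + col]
# so mn_indptr partitions the flattened per-request mask matrices. The row
# stride is the request's query length; for token trees the query and KV
# spans are taken to coincide (an assumption inferred from the indexing, not
# stated in this dump).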
q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1) tree_size = T.int32(is_size_var=True) mask = T.match_buffer(var_mask, (tree_size,), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 20, 64), "float16") lse = T.match_buffer(var_lse, (qo_len, 20)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(20, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((32, 64), "float16", scope="shared") K_smem = T.alloc_buffer((16, 64), "float16", scope="shared") V_smem = T.alloc_buffer((16, 64), "float16", scope="shared") S_smem = T.alloc_buffer((32, 16), scope="shared") S_local = T.alloc_buffer((32, 16), scope="local") O_local = T.alloc_buffer((32, 64), scope="local") m_smem = T.alloc_buffer((32,), scope="shared") m_prev_smem = T.alloc_buffer((32,), scope="shared") d_smem = T.alloc_buffer((32,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 32 q_indptr_val: T.int32 = q_indptr[b_idx] kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: m_smem[row] = T.float32(-50000) d_smem[row] = T.float32(1) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(4): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 64) j = T.axis.spatial(64, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L 
< q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * q[cur_L, cur_H_qo, j] + T.Cast("float16", T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * T.if_then_else(j < 32, q[cur_L, cur_H_qo, j + 32] * T.float16(-1), q[cur_L, cur_H_qo, j - 32]), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_ly_fused_0 in range(2): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("KV_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 64) j = T.axis.spatial(64, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 64) T.reads() T.writes() cur_L: T.int32 = L_kv_base + L_kv_start + i if L_kv_start + i < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Cast("float16", T.cos(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * k[cur_L, by, j] + T.Cast("float16", T.sin(T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 64) / T.float32(64)))) * T.if_then_else(j < 32, k[cur_L, by, j + 32] * T.float16(-1), k[cur_L, by, j - 32]), k[cur_L, by, j]) V_smem[i, j] = v[cur_L, by, j] else: K_smem[i, j] = T.float16(0) V_smem[i, j] = T.float16(0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:32, 0:64], K_smem[0:16, 0:64]) T.writes(S_local[0:32, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(8, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k_1 = T.axis.reduce(64, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.18033688011112042) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] 
T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], mask[mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start:mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start + 16], mn_indptr[b_idx], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if L_kv_start + j < kv_chunk_len[0] and mask[mn_indptr[b_idx] + row_ * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + (L_kv_start + j)] == 1: m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], mask[mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start:mn_indptr[b_idx] + (LH_start + row) * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + L_kv_start + 16], mn_indptr[b_idx], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 32: row_: T.int32 = LH_start + row if L_kv_start + j < kv_chunk_len[0] and mask[mn_indptr[b_idx] + row_ * (q_indptr[b_idx + 1] - q_indptr[b_idx]) + (L_kv_start + j)] == 1: S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 32: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:16], V_smem[0:16, 0:64]) T.writes(O_local[0:32, 0:64]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 4): with T.block("O_gemm_init"): i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init) j = T.axis.spatial(64, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 4 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 4): with T.block("O_gemm_update"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + lj_1) k_1 = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 4): with T.block("O_store"): i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1) j = T.axis.spatial(64, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 4 + 
lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 32) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func(private=True) def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) num_nodes, vocab_size = T.int32(is_size_var=True), T.int64() draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size)) draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32") model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size)) token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32") token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32") uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,)) nbatch = T.int32(is_size_var=True) token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32") # with T.block("root"): child_ptr = T.alloc_buffer((1,), "int32", scope="local") parent_ptr = T.alloc_buffer((1,), "int32", scope="local") child_token = T.alloc_buffer((1,), "int32", scope="local") done = T.alloc_buffer((1,), "bool", scope="local") psum = T.alloc_buffer((1,), scope="local") t0 = T.alloc_buffer((1,), scope="local") model_prob_local = T.alloc_buffer((1,), scope="local") draft_prob_local = T.alloc_buffer((1,), scope="local") p_child = T.alloc_buffer((1,), scope="local") q_child = T.alloc_buffer((1,), scope="local") uniform_sample = T.alloc_buffer((1,), scope="local") pred_shared = T.alloc_buffer((1,), "bool", scope="shared") pred_local = T.alloc_buffer((1,), "bool", scope="local") for _bx in T.thread_binding(nbatch, thread="blockIdx.x"): for _tx in T.thread_binding(1024, thread="threadIdx.x"): with T.block("CTA"): b, tx = T.axis.remap("SS", [_bx, _tx]) T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - 
T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]]) T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b]) parent_ptr[0] = token_tree_parent_ptr[b] child_ptr[0] = token_tree_first_child[parent_ptr[0]] done[0] = T.bool(False) while not done[0]: T.tvm_storage_sync("shared") if child_ptr[0] == -1: done[0] = T.bool(True) T.tvm_storage_sync("shared") else: if tx == 0: child_token[0] = draft_tokens[child_ptr[0]] p_child[0] = model_probs[parent_ptr[0], child_token[0]] q_child[0] = draft_probs[child_ptr[0], child_token[0]] uniform_sample[0] = uniform_samples[child_ptr[0]] pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0] T.tvm_storage_sync("shared") pred_local[0] = pred_shared[0] if pred_local[0]: parent_ptr[0] = child_ptr[0] child_ptr[0] = token_tree_first_child[child_ptr[0]] else: psum[0] = T.float32(0) for i in range((vocab_size + T.int64(1023)) // T.int64(1024)): if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size: model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0)) psum[0] = psum[0] + model_prob_local[0] with T.block("block_cross_thread"): T.reads(psum[0]) T.writes(t0[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx) if t0[0] < T.float32(9.9999999999999995e-08): parent_ptr[0] = child_ptr[0] child_ptr[0] = token_tree_first_child[child_ptr[0]] else: for i in range((vocab_size + T.int64(1023)) // T.int64(1024)): if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size: model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0)) model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0] child_ptr[0] = token_tree_next_sibling[child_ptr[0]] if tx == 0: token_tree_parent_ptr[b] = parent_ptr[0] @T.prim_func def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, 
"max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) # with T.block("root"): temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared") temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared") for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("max"): v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0) v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1) v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2]) T.writes(temp_max_shared[v0, v1]) with T.init(): temp_max_shared[v0, v1] = T.float32(-3.4028234663852886e+38) temp_max_shared[v0, v1] = T.max(temp_max_shared[v0, v1], T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38))) for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("sum_exp"): v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0) v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1) v2 = T.axis.reduce(T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1]) T.writes(temp_sum_shared[v0, v1]) with T.init(): temp_sum_shared[v0, v1] = T.float32(0) temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38)) - temp_max_shared[v0, v1]), T.Cast("float32", T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-3.4028234663852886e+38)) == temp_max_shared[v0, v1])), T.float32(0)) for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("log"): v0 = T.axis.spatial(batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks) v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1) 
T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1)) T.reads(temperature[v0], temp_sum_shared[v0, v1], temp_max_shared[v0, v1]) T.writes(chunked_sum[v0, v1], chunked_max[v0, v1]) chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum_shared[v0, v1]), temp_sum_shared[v0, v1]) chunked_max[v0, v1] = temp_max_shared[v0, v1] @T.prim_func def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) num_pages = T.int32() pages = T.match_buffer(var_pages, (num_pages, 2, 20, 16, 64), "float16") copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1) total_copy_length = T.int32() copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1) with T.block("root"): T.reads() T.writes() for bhd_o in T.thread_binding((batch_size * 1280 + 1023) // 1024, thread="blockIdx.x"): for bhd_i in T.thread_binding(1024, thread="threadIdx.x"): b: T.int32 = (bhd_o * 1024 + bhd_i) // 1280 h: T.int32 = (bhd_o * 1024 + bhd_i) // 64 % 20 d: T.int32 = (bhd_o * 1024 + bhd_i) % 64 if bhd_o * 1024 + bhd_i < batch_size * 20 * 64: for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d] pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d] @T.prim_func(private=True) def concatenate(var_reshape710: T.handle, var_reshape711: T.handle, var_reshape712: T.handle, var_T_concat: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape710 = T.match_buffer(var_reshape710, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") reshape711 = T.match_buffer(var_reshape711, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") reshape712 = T.match_buffer(var_reshape712, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") T_concat = T.match_buffer(var_T_concat, (batch_size, T.int64(1), T.int64(60), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_concat"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(3840)) T.reads(reshape712[v0, T.int64(0), v1 + T.int64(-40), v2], reshape711[v0, T.int64(0), v1 + T.int64(-20), v2], reshape710[v0, T.int64(0), v1, v2]) T.writes(T_concat[v0, T.int64(0), v1, v2]) T_concat[v0, T.int64(0), v1, v2] = 
T.if_then_else(T.int64(40) <= v1, reshape712[v0, T.int64(0), v1 - T.int64(40), v2], T.if_then_else(T.int64(20) <= v1, reshape711[v0, T.int64(0), v1 + T.int64(-20), v2], reshape710[v0, T.int64(0), v1, v2])) @T.prim_func(private=True) def concatenate1(var_reshape387: T.handle, var_reshape388: T.handle, var_reshape389: T.handle, var_T_concat: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() reshape387 = T.match_buffer(var_reshape387, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") reshape388 = T.match_buffer(var_reshape388, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") reshape389 = T.match_buffer(var_reshape389, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") T_concat = T.match_buffer(var_T_concat, (T.int64(1), seq_len, T.int64(60), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_concat"): v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(3840)) T.reads(reshape389[T.int64(0), v0, v1 + T.int64(-40), v2], reshape388[T.int64(0), v0, v1 + T.int64(-20), v2], reshape387[T.int64(0), v0, v1, v2]) T.writes(T_concat[T.int64(0), v0, v1, v2]) T_concat[T.int64(0), v0, v1, v2] = T.if_then_else(T.int64(40) <= v1, reshape389[T.int64(0), v0, v1 - T.int64(40), v2], T.if_then_else(T.int64(20) <= v1, reshape388[T.int64(0), v0, v1 + T.int64(-20), v2], reshape387[T.int64(0), v0, v1, v2])) @T.prim_func def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) num_pages, page_size = T.int32(), T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 20, page_size, 64), "float16") # with T.block("root"): for b in T.thread_binding((copy_length * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for t in T.thread_binding(1024, thread="threadIdx.x"): with T.block("copy"): vh = T.axis.spatial(20, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(64)))) vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(64)) // T.int64(64)) vd = T.axis.spatial(64, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(64))) T.reads(pages[src_page_id, 0:2, vh, vp, vd]) T.writes(pages[tgt_page_id, 0:2, vh, vp, vd]) pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] @T.prim_func(private=True) def cumsum(var_sorted_probs: T.handle, var_lv1: T.handle, var_exclusive_scan_thrust: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size, vocab_size 
= T.int64(), T.int64() data_buf = T.match_buffer(var_sorted_probs, (batch_size, vocab_size), align=8) workspace_buf = T.match_buffer(var_lv1, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8) output_buf = T.match_buffer(var_exclusive_scan_thrust, (batch_size, vocab_size), align=8) with T.block("exclusive_scan_thrust"): T.reads() T.writes() T.call_packed("tvm.contrib.thrust.sum_scan", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.tvm_stack_make_array(output_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0), T.int64(0)), T.bool(False), T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0))) @T.prim_func def full(var_result: T.handle, value: T.int32): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) batch_size = T.int32(is_size_var=True) result = T.match_buffer(var_result, (batch_size, 1), "int32") # with T.block("root"): for ax0_fused_0 in T.thread_binding((batch_size + 1023) // 1024, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.block("block"): v0 = T.axis.spatial(batch_size, ax0_fused_0 * 1024 + ax0_fused_1) T.where(ax0_fused_0 * 1024 + ax0_fused_1 < batch_size) T.reads() T.writes(result[v0, 0]) result[v0, 0] = value @T.prim_func(private=True) def fused_NT_matmul1_add8_gelu2(layer_norm358: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_fc1_weight5: T.Buffer((T.int64(5120), T.int64(1280)), "float16"), model_decoder_layers_0_fc1_bias5: T.Buffer((T.int64(5120),), "float16"), T_multiply_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="local") model_decoder_layers_0_fc1_weight5_local = T.alloc_buffer((T.int64(5120), T.int64(1280)), "float16", scope="local") layer_norm358_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(1280), thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"): for ax2_3 in T.vectorized(T.int64(1)): with 
T.block("layer_norm358_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1 * T.int64(64) + ax2_2 + ax2_3) T.reads(layer_norm358[v0, v1, v2]) T.writes(layer_norm358_shared[v0, v1, v2]) layer_norm358_shared[v0, v1, v2] = layer_norm358[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(T.int64(2)): for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): with T.block("model_decoder_layers_0_fc1_weight5_local"): v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) T.reads(model_decoder_layers_0_fc1_weight5[v0, v1]) T.writes(model_decoder_layers_0_fc1_weight5_local[v0, v1]) model_decoder_layers_0_fc1_weight5_local[v0, v1] = model_decoder_layers_0_fc1_weight5[v0, v1] for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(1)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_fused_u_fused_2, vax1_fused_u_fused_0 = T.axis.remap("RR", [ax1_fused_u_fused_2, ax1_fused_u_fused_0]) T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm358_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused], model_decoder_layers_0_fc1_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused]) T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm358_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused] * model_decoder_layers_0_fc1_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_2 * T.int64(4) + 
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused] for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(T.int64(1)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0) v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in range(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] for ax1_fused_2 in range(T.int64(1)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): with T.block("NT_matmul"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0) v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) with T.init(): NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"): for ax0_fused_2 in range(T.int64(1)): with T.block("T_multiply_2"): v0 = T.axis.spatial(T.int64(5120), u_fused_ax0_fused_fused_0 * T.int64(4) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) T.reads(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_fc1_bias5[v0]) T.writes(T_multiply_intermediate[T.int64(0), T.int64(0), v0]) T_multiply_intermediate[T.int64(0), T.int64(0), v0] = (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + 
model_decoder_layers_0_fc1_bias5[v0]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_fc1_bias5[v0]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) @T.prim_func(private=True) def fused_NT_matmul2_add7_add6(gelu130: T.Buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16"), model_decoder_layers_0_fc2_weight5: T.Buffer((T.int64(1280), T.int64(5120)), "float16"), model_decoder_layers_0_fc2_bias5: T.Buffer((T.int64(1280),), "float16"), add1227: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") model_decoder_layers_0_fc2_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(5120)), "float16", scope="local") gelu130_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(5120)), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_0 in T.serial(T.int64(5), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_3 in T.vectorized(T.int64(2)): with T.block("gelu130_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(5120), ax2_0 * T.int64(1024) + ax2_1 * T.int64(64) + ax2_2 * T.int64(2) + ax2_3) T.reads(gelu130[v0, v1, v2]) T.writes(gelu130_shared[v0, v1, v2]) gelu130_shared[v0, v1, v2] = gelu130[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1_fused_u_fused_0 in T.serial(T.int64(20), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(T.int64(4)): for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): with T.block("model_decoder_layers_0_fc2_weight5_local"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(T.int64(5120), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + 
ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) T.reads(model_decoder_layers_0_fc2_weight5[v0, v1]) T.writes(model_decoder_layers_0_fc2_weight5_local[v0, v1]) model_decoder_layers_0_fc2_weight5_local[v0, v1] = model_decoder_layers_0_fc2_weight5[v0, v1] for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], gelu130_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_fc2_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + gelu130_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_fc2_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(T.int64(1)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in range(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) 
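# Note: this second "NT_matmul_rf_update" stage folds the four vectorized
# partial sums held per thread in NT_matmul_intermediate_rf_local (128 lanes)
# into NT_matmul_intermediate_rf_local_1 (32 lanes, one per threadIdx.x);
# the "NT_matmul" block that follows then reduces across the 32 lanes.
# A minimal NumPy sketch of what the whole prim_func computes (the names
# x, w, b, residual are illustrative, not taken from the IR):
#
#   import numpy as np
#   def fused_nt_matmul2_add7_add6(x, w, b, residual):
#       # x: (1, 1, 5120) fp16, w: (1280, 5120) fp16 -> "NT" matmul is x @ w.T
#       return residual + (x @ w.T + b)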
T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] for ax1_fused_2 in range(T.int64(1)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("NT_matmul"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) with T.init(): NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_fused_2 in range(T.int64(1)): with T.block("T_add_1"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) T.reads(add1227[T.int64(0), T.int64(0), v0], NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_fc2_bias5[v0]) T.writes(T_add_intermediate_1[T.int64(0), T.int64(0), v0]) T_add_intermediate_1[T.int64(0), T.int64(0), v0] = add1227[T.int64(0), T.int64(0), v0] + (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_fc2_bias5[v0]) @T.prim_func(private=True) def fused_NT_matmul_add7(layer_norm356: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_q_proj_bias5: T.Buffer((T.int64(1280),), "float16"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") model_decoder_layers_0_self_attn_q_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local") layer_norm356_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in 
T.thread_binding(T.int64(80), thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_3 in T.vectorized(T.int64(1)): with T.block("layer_norm356_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280)) T.reads(layer_norm356[v0, v1, v2]) T.writes(layer_norm356_shared[v0, v1, v2]) layer_norm356_shared[v0, v1, v2] = layer_norm356[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(T.int64(4)): for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): with T.block("model_decoder_layers_0_self_attn_q_proj_weight5_local"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) T.reads(model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1]) T.writes(model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1]) model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_q_proj_weight5[v0, v1] for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % 
T.int64(4)], model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + layer_norm356_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_q_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(T.int64(1)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in range(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] for ax1_fused_2 in range(T.int64(1)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("NT_matmul"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + 
ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) with T.init(): NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_fused_2 in range(T.int64(1)): with T.block("T_add"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) T.reads(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_self_attn_q_proj_bias5[v0]) T.writes(T_add_intermediate[T.int64(0), T.int64(0), v0]) T_add_intermediate[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + model_decoder_layers_0_self_attn_q_proj_bias5[v0] @T.prim_func(private=True) def fused_NT_matmul_add7_add6(reshape1361: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_out_proj_weight5: T.Buffer((T.int64(1280), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_out_proj_bias5: T.Buffer((T.int64(1280),), "float16"), add1220: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="local") model_decoder_layers_0_self_attn_out_proj_weight5_local = T.alloc_buffer((T.int64(1280), T.int64(1280)), "float16", scope="local") reshape1361_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_0 in T.serial(T.int64(3), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_3 in T.vectorized(T.int64(1)): with T.block("reshape1361_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(1280)) T.reads(reshape1361[v0, v1, v2]) T.writes(reshape1361_shared[v0, v1, v2]) reshape1361_shared[v0, v1, v2] = reshape1361[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = 
T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1_fused_u_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(T.int64(4)): for ax0_ax1_fused_1 in T.vectorized(T.int64(2)): with T.block("model_decoder_layers_0_self_attn_out_proj_weight5_local"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(T.int64(1280), ax1_fused_u_fused_0 * T.int64(256) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1) T.reads(model_decoder_layers_0_self_attn_out_proj_weight5[v0, v1]) T.writes(model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, v1]) model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, v1] = model_decoder_layers_0_self_attn_out_proj_weight5[v0, v1] for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) T.reads(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], reshape1361_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]) T.writes(NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + reshape1361_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_decoder_layers_0_self_attn_out_proj_weight5_local[v0, vax1_fused_u_fused_0 * T.int64(256) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] for ax2_fused_0_ax2_fused_1_fused 
in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(T.int64(1)): with T.block("NT_matmul_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in range(T.int64(4)): with T.block("NT_matmul_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0] for ax1_fused_2 in range(T.int64(1)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("NT_matmul"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]) T.writes(NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) with T.init(): NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + NT_matmul_intermediate_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_fused_2 in range(T.int64(1)): with T.block("T_add_1"): v0 = T.axis.spatial(T.int64(1280), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) T.reads(add1220[T.int64(0), T.int64(0), v0], NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0], model_decoder_layers_0_self_attn_out_proj_bias5[v0]) T.writes(T_add_intermediate_1[T.int64(0), T.int64(0), v0]) T_add_intermediate_1[T.int64(0), T.int64(0), v0] = add1220[T.int64(0), T.int64(0), v0] + (NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + 
model_decoder_layers_0_self_attn_out_proj_bias5[v0]) @T.prim_func(private=True) def fused_add4_maximum_minimum(p_add4: T.handle, p_lv611: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() add4 = T.match_buffer(p_add4, (batch_size, T.int64(1500), T.int64(1280)), "float16") lv611 = T.match_buffer(p_lv611, (batch_size, T.int64(1500), T.int64(1280)), "float16") T_minimum_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1500), T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_minimum"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) T.reads(add4[v0, v1, v2], lv611[v0, v1, v2]) T.writes(T_minimum_intermediate[v0, v1, v2]) T_minimum_intermediate[v0, v1, v2] = T.min(T.max(add4[v0, v1, v2] + lv611[v0, v1, v2], T.float16(-65504)), T.float16(65504)) @T.prim_func(private=True) def fused_conv1d1_add2_gelu1(p_gelu: T.handle, model_encoder_conv2_weight: T.Buffer((T.int64(1280), T.int64(1280), T.int64(3)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16"), p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() gelu = T.match_buffer(p_gelu, (batch_size, T.int64(1280), T.int64(3000)), "float16") T_multiply_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1280), T.int64(1500)), "float16") # with T.block("root"): conv1d_ncw_intermediate_shared = T.alloc_buffer((batch_size, T.int64(1280), T.int64(1500)), "float16", scope="shared") for ax0_ax1_ax2_fused in T.thread_binding(batch_size * T.int64(1920000), thread="blockIdx.x"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): for ax3_ax4_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax3_ax4_fused_0 in T.serial(T.int64(15), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("conv1d_ncw"): v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(1920000) + ax0) v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(1920000) // T.int64(1500) + ax1) v2 = T.axis.spatial(T.int64(1500), ax0_ax1_ax2_fused % T.int64(1500) + ax2) v3 = T.axis.reduce(T.int64(1280), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) // T.int64(3)) v4 = T.axis.reduce(T.int64(3), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) % T.int64(3)) T.reads(gelu[v0, v3, v2 * T.int64(2) + v4 - T.int64(1)], model_encoder_conv2_weight[v1, v3, v4]) T.writes(conv1d_ncw_intermediate_shared[v0, v1, v2]) with T.init(): conv1d_ncw_intermediate_shared[v0, v1, v2] = T.float16(0) conv1d_ncw_intermediate_shared[v0, v1, v2] = conv1d_ncw_intermediate_shared[v0, v1, v2] + T.if_then_else(T.int64(1) <= v2 * T.int64(2) + v4 and v2 * T.int64(2) + v4 < T.int64(3001), gelu[v0, v3, v2 * T.int64(2) + v4 - T.int64(1)], T.float16(0)) * model_encoder_conv2_weight[v1, v3, v4] for ax3 in range(T.int64(1)): for ax4_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax4_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}): with T.block("T_multiply_2"): v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(1920000) // T.int64(1500)) v2 = T.axis.spatial(T.int64(1500), ax0_ax1_ax2_fused % T.int64(1500)) v3 = T.axis.spatial(T.int64(1), ax3) v4 = T.axis.spatial(T.int64(1), ax4_0 * T.int64(256) + ax4_1) T.where(ax4_0 * T.int64(256) + ax4_1 < T.int64(1)) T.reads(conv1d_ncw_intermediate_shared[v0, v1, v2], lv3[T.int64(0), v1, T.int64(0)]) T.writes(T_multiply_intermediate[v0, v1, v2]) T_multiply_intermediate[v0, v1, v2] = (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv3[T.int64(0), v1, T.int64(0)]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv3[T.int64(0), v1, T.int64(0)]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) @T.prim_func(private=True) def fused_conv1d_add1_gelu(p_input_features: T.handle, model_encoder_conv1_weight: T.Buffer((T.int64(1280), T.int64(128), T.int64(3)), "float16"), lv1: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16"), p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() input_features = T.match_buffer(p_input_features, (batch_size, T.int64(128), T.int64(3000)), "float16") T_multiply_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1280), T.int64(3000)), "float16") # with T.block("root"): conv1d_ncw_intermediate_shared = T.alloc_buffer((batch_size, T.int64(1280), T.int64(3000)), "float16", scope="shared") for ax0_ax1_ax2_fused in T.thread_binding(batch_size * T.int64(3840000), thread="blockIdx.x"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): for ax3_ax4_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax3_ax4_fused_0 in T.serial(T.int64(2), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("conv1d_ncw"): v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(3840000) + ax0) v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(3840000) // T.int64(3000) + ax1) v2 = T.axis.spatial(T.int64(3000), ax0_ax1_ax2_fused % T.int64(3000) + ax2) v3 = T.axis.reduce(T.int64(128), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) // T.int64(3)) v4 = T.axis.reduce(T.int64(3), (ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1) % T.int64(3)) T.where(ax3_ax4_fused_0 * T.int64(256) + ax3_ax4_fused_1 < T.int64(384)) T.reads(input_features[v0, v3, v2 + v4 - T.int64(1)], model_encoder_conv1_weight[v1, v3, v4]) T.writes(conv1d_ncw_intermediate_shared[v0, v1, v2]) with T.init(): conv1d_ncw_intermediate_shared[v0, v1, v2] = T.float16(0) conv1d_ncw_intermediate_shared[v0, v1, v2] = conv1d_ncw_intermediate_shared[v0, v1, v2] + T.if_then_else(T.int64(1) <= v2 + v4 and v2 + v4 < T.int64(3001), input_features[v0, v3, v2 + v4 - T.int64(1)], T.float16(0)) * model_encoder_conv1_weight[v1, v3, v4] for ax3 in range(T.int64(1)): for ax4_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax4_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("T_multiply_2"): v0 = T.axis.spatial(batch_size, ax0_ax1_ax2_fused // T.int64(3840000)) v1 = T.axis.spatial(T.int64(1280), ax0_ax1_ax2_fused % T.int64(3840000) // T.int64(3000)) v2 = T.axis.spatial(T.int64(3000), ax0_ax1_ax2_fused % T.int64(3000)) v3 = T.axis.spatial(T.int64(1), ax3) v4 = T.axis.spatial(T.int64(1), ax4_0 * T.int64(256) + ax4_1) 
T.where(ax4_0 * T.int64(256) + ax4_1 < T.int64(1)) T.reads(conv1d_ncw_intermediate_shared[v0, v1, v2], lv1[T.int64(0), v1, T.int64(0)]) T.writes(T_multiply_intermediate[v0, v1, v2]) T_multiply_intermediate[v0, v1, v2] = (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv1[T.int64(0), v1, T.int64(0)]) * (T.float16(0.5) + T.Cast("float16", T.erf(T.Cast("float32", (conv1d_ncw_intermediate_shared[v0, v1, v2] + lv1[T.int64(0), v1, T.int64(0)]) * T.float16(0.70710678118654757)))) * T.float16(0.5)) @T.prim_func(private=True) def fused_reshape20_reshape20_add6(take7: T.Buffer((T.int64(1), T.int64(1280)), "float16"), take8: T.Buffer((T.int64(1), T.int64(1280)), "float16"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_add"): v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) T.reads(take7[T.int64(0), v0], take8[T.int64(0), v0]) T.writes(T_add_intermediate[T.int64(0), T.int64(0), v0]) T_add_intermediate[T.int64(0), T.int64(0), v0] = take7[T.int64(0), v0] + take8[T.int64(0), v0] @T.prim_func(private=True) def fused_reshape21_reshape21_reshape21_concatenate2_reshape22(add1221: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), lv1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), add1222: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_reshape_intermediate_1_2_3: T.Buffer((T.int64(1), T.int64(60), T.int64(64)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(T.int64(4), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape_3"): v0 = T.axis.spatial(T.int64(60), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(64)) v1 = T.axis.spatial(T.int64(64), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(64)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(3840)) T.reads(add1222[T.int64(0), T.int64(0), (v0 - T.int64(40)) * T.int64(64) + v1], lv1[T.int64(0), T.int64(0), (v0 + T.int64(-20)) * T.int64(64) + v1], add1221[T.int64(0), T.int64(0), v0 * T.int64(64) + v1]) T.writes(T_reshape_intermediate_1_2_3[T.int64(0), v0, v1]) T_reshape_intermediate_1_2_3[T.int64(0), v0, v1] = T.if_then_else(T.int64(40) <= v0, add1222[T.int64(0), T.int64(0), (v0 - T.int64(40)) * T.int64(64) + v1], T.if_then_else(T.int64(20) <= v0, lv1[T.int64(0), T.int64(0), (v0 + T.int64(-20)) * T.int64(64) + v1], add1221[T.int64(0), T.int64(0), v0 * T.int64(64) + v1])) @T.prim_func(private=True) def fused_reshape21_reshape25(add1225: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(20), T.int64(64)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape_1"): v0 = T.axis.spatial(T.int64(20), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(64)) v1 = T.axis.spatial(T.int64(64), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % 
T.int64(64))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(1280))
                    T.reads(add1225[T.int64(0), T.int64(0), v0 * T.int64(64) + v1])
                    T.writes(T_reshape_intermediate_1[T.int64(0), v0, v1])
                    T_reshape_intermediate_1[T.int64(0), v0, v1] = add1225[T.int64(0), T.int64(0), v0 * T.int64(64) + v1]

    @T.prim_func(private=True)
    def fused_reshape23_reshape24(lv265: T.Buffer((T.int64(1), T.int64(20), T.int64(64)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape_1"):
                    v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280))
                    T.reads(lv265[T.int64(0), v0 // T.int64(64), v0 % T.int64(64)])
                    T.writes(T_reshape_intermediate_1[T.int64(0), T.int64(0), v0])
                    T_reshape_intermediate_1[T.int64(0), T.int64(0), v0] = lv265[T.int64(0), v0 // T.int64(64), v0 % T.int64(64)]

    @T.prim_func(private=True)
    def fused_reshape9(packed_params_1: T.Buffer((T.int64(1280),), "float16"), T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1280), T.int64(1)), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_reshape"):
                    v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280))
                    T.reads(packed_params_1[v0])
                    T.writes(T_reshape_intermediate[T.int64(0), v0, T.int64(0)])
                    T_reshape_intermediate[T.int64(0), v0, T.int64(0)] = packed_params_1[v0]

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        qkv = T.match_buffer(var_qkv, (seq_len, 60, 64), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 20, 64), "float16")
        k = T.match_buffer(var_k, (seq_len, 20, 64), "float16")
        v = T.match_buffer(var_v, (seq_len, 20, 64), "float16")
        # with T.block("root"):
        for iters_0_iters_1_iters_2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for iters_0_iters_1_iters_2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("llama_fused_rope"):
                    s = T.axis.spatial(seq_len, (iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) // T.int64(3840))
                    h = T.axis.spatial(60, T.Cast("int32", (iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) % T.int64(3840) // T.int64(64)))
                    d = T.axis.spatial(64, T.Cast("int32", (iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1) % T.int64(64)))
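                    # The guard and split below route the packed qkv (seq_len, 60 heads, 64 dims)
                    # into q/k/v of 20 heads each; when apply_rope > 0, q and k are rotated.
                    # Note the frequency base printed in this dump is T.pow(T.float32(1), ...),
                    # i.e. theta == 1. A hedged NumPy sketch (illustrative names):
                    #
                    #   import numpy as np
                    #   def fused_rope(qkv, pos, apply_rope):
                    #       q, k, v = qkv[:, :20].copy(), qkv[:, 20:40].copy(), qkv[:, 40:]
                    #       if apply_rope:
                    #           d = np.arange(64)
                    #           freq = pos[:, None, None] / np.power(1.0, (d * 2 % 64) / 64.0)
                    #           rot = np.concatenate([-qkv[..., 32:], qkv[..., :32]], axis=-1)
                    #           q = np.cos(freq) * q + np.sin(freq) * rot[:, :20]
                    #           k = np.cos(freq) * k + np.sin(freq) * rot[:, 20:40]
                    #       return q, k, v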
T.where(iters_0_iters_1_iters_2_fused_0 * T.int64(1024) + iters_0_iters_1_iters_2_fused_1 < seq_len * T.int64(3840)) T.reads(position_map[s], qkv[s, h, d - 32:d - 32 + 65]) T.writes(q[s, h, d], k[s, h - 20, d], v[s, h - 40, d]) if h < 20: q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Cast("float16", T.cos(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", qkv[s, h, d]) + T.sin(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1), qkv[s, h, d - 32]))), qkv[s, h, d]) else: if h < 40: k[s, h - 20, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Cast("float16", T.cos(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", qkv[s, h, d]) + T.sin(T.Cast("float32", position_map[s]) / T.pow(T.float32(1), T.Cast("float32", d * 2 % 64) / T.float32(64))) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1), qkv[s, h, d - 32]))), qkv[s, h, d]) else: v[s, h - 40, d] = qkv[s, h, d] @T.prim_func(private=True) def fused_transpose_add3(packed_params_4: T.Buffer((T.int64(1500), T.int64(1280)), "float16"), p_gelu1: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() gelu1 = T.match_buffer(p_gelu1, (batch_size, T.int64(1280), T.int64(1500)), "float16") T_add_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1500), T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_add"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) T.reads(gelu1[v0, v2, v1], packed_params_4[v1, v2]) T.writes(T_add_intermediate[v0, v1, v2]) T_add_intermediate[v0, v1, v2] = gelu1[v0, v2, v1] + packed_params_4[v1, v2] @T.prim_func def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) m, n = T.int32(is_size_var=True), T.int32(is_size_var=True) src = T.match_buffer(var_src, (m, n)) batch_size = T.int32(is_size_var=True) indices = T.match_buffer(var_indices, (batch_size,), "int32") dst = T.match_buffer(var_dst, (batch_size, n)) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.block("gather_2d"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n) v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n) T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < 
batch_size * n) T.reads(src[indices[v0], v1], indices[v0]) T.writes(dst[v0, v1]) dst[v0, v1] = src[indices[v0], v1] @T.prim_func(private=True) def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) batch, vocab_size = T.int64(), T.int64() cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) indices = T.match_buffer(B, (batch, vocab_size), "int32") renorm_prob = T.match_buffer(C, (batch, 1)) out_batch = T.int64() usample = T.match_buffer(D, (out_batch, 1)) sample_indices = T.match_buffer(E, (out_batch, 1), "int32") output_index = T.match_buffer(F, (out_batch, 1), "int32") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((out_batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_get_index_from_sorted"): v0 = T.axis.spatial(out_batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * out_batch) // vocab_size) v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < out_batch * vocab_size) T.reads(usample[v0, T.int64(0)], cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1):v1 - T.int64(1) + T.int64(2)], sample_indices[v0, T.int64(0)], renorm_prob[sample_indices[v0, T.int64(0)], 0], indices[sample_indices[v0, T.int64(0)], T.min(T.int64(0), v1):T.min(T.int64(0), v1) + (v1 + T.int64(1))]) T.writes(output_index[v0, 0]) if usample[v0, T.int64(0)] < cumsum_sorted[sample_indices[v0, T.int64(0)], v1] / renorm_prob[sample_indices[v0, T.int64(0)], 0] or v1 + T.int64(1) == vocab_size: if v1 == T.int64(0): output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], 0] else: if usample[v0, T.int64(0)] >= cumsum_sorted[sample_indices[v0, T.int64(0)], v1 - T.int64(1)] / renorm_prob[sample_indices[v0, T.int64(0)], 0]: output_index[v0, 0] = indices[sample_indices[v0, T.int64(0)], v1] @T.prim_func(private=True) def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) batch, vocab_size = T.int64(), T.int64() cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) top_p = T.match_buffer(B, (batch, 1)) top_k = T.match_buffer(C, (batch, 1), "int32") renorm_prob = T.match_buffer(D, (batch, 1)) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch * vocab_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_get_renorm_prob"): v0 = T.axis.spatial(batch, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size * batch) // vocab_size) v1 = T.axis.spatial(vocab_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch * vocab_size) T.reads(cumsum_sorted[v0, T.min(T.min(T.int64(0), v1), v1 + T.int64(1)):T.min(T.min(T.int64(0), v1), v1 + T.int64(1)) + (v1 + T.int64(2))], 
top_p[v0, 0], top_k[v0, 0])
                    T.writes(renorm_prob[v0, 0])
                    if not (cumsum_sorted[v0, 0] < top_p[v0, 0] and top_k[v0, 0] > 1):
                        renorm_prob[v0, 0] = cumsum_sorted[v0, 0]
                    else:
                        if cumsum_sorted[v0, v1] < top_p[v0, 0] and v1 + T.int64(1) < T.Cast("int64", top_k[v0, 0]):
                            if v1 + T.int64(1) == vocab_size:
                                renorm_prob[v0, 0] = cumsum_sorted[v0, v1]
                            else:
                                if not (cumsum_sorted[v0, v1 + T.int64(1)] < top_p[v0, 0] and v1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v0, 0])):
                                    renorm_prob[v0, 0] = cumsum_sorted[v0, v1 + T.int64(1)]

    @T.prim_func(private=True)
    def index(var_layer_norm355: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")):
        T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        layer_norm355 = T.match_buffer(var_layer_norm355, (T.int64(1), seq_len, T.int64(1280)), "float16")
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("index"):
                    v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1)
                    T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280))
                    T.reads(layer_norm355[T.int64(0), seq_len - T.int64(1), v0])
                    T.writes(index[T.int64(0), T.int64(0), v0])
                    index[T.int64(0), T.int64(0), v0] = layer_norm355[T.int64(0), seq_len - T.int64(1), v0]

    @T.prim_func(private=True)
    def layer_norm(var_add578: T.handle, model_decoder_layers_0_self_attn_layer_norm_weight3: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias3: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add578 = T.match_buffer(var_add578, (batch_size, T.int64(1), T.int64(1280)), "float16")
        T_layer_norm = T.match_buffer(var_T_layer_norm, (batch_size, T.int64(1), T.int64(1280)), "float16")
        # with T.block("root"):
        add578_red_temp_v0_shared = T.alloc_buffer((batch_size, T.int64(1)), scope="shared")
        add578_red_temp_v1_shared = T.alloc_buffer((batch_size, T.int64(1)), scope="shared")
        for ax0_fused in T.thread_binding(batch_size, thread="blockIdx.x"):
            for ax0 in range(T.int64(1)):
                for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        with T.block("add578_red_temp"):
                            v0 = T.axis.spatial(batch_size, ax0_fused + ax0)
                            v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1)
                            T.reads(add578[v0, T.int64(0), v1])
                            T.writes(add578_red_temp_v0_shared[v0, T.int64(0)], add578_red_temp_v1_shared[v0, T.int64(0)])
                            with T.init():
                                add578_red_temp_v0_shared[v0, T.int64(0)] = T.float32(0)
                                add578_red_temp_v1_shared[v0, T.int64(0)] = T.float32(0)
                            v_add578_red_temp_v0: T.float32 = add578_red_temp_v0_shared[v0, T.int64(0)] + T.Cast("float32", add578[v0, T.int64(0), v1])
                            v_add578_red_temp_v1: T.float32 = add578_red_temp_v1_shared[v0, T.int64(0)] + T.Cast("float32", add578[v0, T.int64(0), v1]) * T.Cast("float32", add578[v0, T.int64(0), v1])
                            add578_red_temp_v0_shared[v0, T.int64(0)] = v_add578_red_temp_v0
                            add578_red_temp_v1_shared[v0, T.int64(0)] = v_add578_red_temp_v1
            for ax1_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("T_layer_norm"): v0 = T.axis.spatial(batch_size, ax0_fused) v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) T.reads(add578[v0, T.int64(0), v1], add578_red_temp_v0_shared[v0, T.int64(0)], add578_red_temp_v1_shared[v0, T.int64(0)], model_decoder_layers_0_self_attn_layer_norm_weight3[v1], model_decoder_layers_0_self_attn_layer_norm_bias3[v1]) T.writes(T_layer_norm[v0, T.int64(0), v1]) T_layer_norm[v0, T.int64(0), v1] = T.Cast("float16", (T.Cast("float32", add578[v0, T.int64(0), v1]) - add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004)) * T.rsqrt(add578_red_temp_v1_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004) - add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004) * (add578_red_temp_v0_shared[v0, T.int64(0)] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_decoder_layers_0_self_attn_layer_norm_weight3[v1] + model_decoder_layers_0_self_attn_layer_norm_bias3[v1] @T.prim_func(private=True) def layer_norm1(var_add: T.handle, model_encoder_layers_0_self_attn_layer_norm_weight: T.Buffer((T.int64(1280),), "float16"), model_encoder_layers_0_self_attn_layer_norm_bias: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() add = T.match_buffer(var_add, (batch_size, T.int64(1500), T.int64(1280)), "float16") T_layer_norm = T.match_buffer(var_T_layer_norm, (batch_size, T.int64(1500), T.int64(1280)), "float16") # with T.block("root"): add_red_temp_v0_shared = T.alloc_buffer((batch_size, T.int64(1500)), scope="shared") add_red_temp_v1_shared = T.alloc_buffer((batch_size, T.int64(1500)), scope="shared") for ax0_ax1_fused in T.thread_binding(batch_size * T.int64(1500), thread="blockIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("add_red_temp"): v0 = T.axis.spatial(batch_size, ax0_ax1_fused // T.int64(1500) + ax0) v1 = T.axis.spatial(T.int64(1500), ax0_ax1_fused % T.int64(1500) + ax1) v2 = T.axis.reduce(T.int64(1280), ax2_fused_0 * T.int64(256) + ax2_fused_1) T.reads(add[v0, v1, v2]) T.writes(add_red_temp_v0_shared[v0, v1], add_red_temp_v1_shared[v0, v1]) with T.init(): add_red_temp_v0_shared[v0, v1] = T.float32(0) add_red_temp_v1_shared[v0, v1] = T.float32(0) v_add_red_temp_v0: T.float32 = add_red_temp_v0_shared[v0, v1] + T.Cast("float32", add[v0, v1, v2]) v_add_red_temp_v1: T.float32 = add_red_temp_v1_shared[v0, v1] + T.Cast("float32", add[v0, v1, v2]) * T.Cast("float32", add[v0, v1, v2]) add_red_temp_v0_shared[v0, v1] = v_add_red_temp_v0 add_red_temp_v1_shared[v0, v1] = v_add_red_temp_v1 for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("T_layer_norm"): v0 = T.axis.spatial(batch_size, ax0_ax1_fused // T.int64(1500)) v1 = T.axis.spatial(T.int64(1500), ax0_ax1_fused % T.int64(1500)) v2 = T.axis.spatial(T.int64(1280), ax2_0 * T.int64(256) + ax2_1) T.reads(add[v0, v1, v2], add_red_temp_v0_shared[v0, v1], add_red_temp_v1_shared[v0, v1], model_encoder_layers_0_self_attn_layer_norm_weight[v2], 
model_encoder_layers_0_self_attn_layer_norm_bias[v2]) T.writes(T_layer_norm[v0, v1, v2]) T_layer_norm[v0, v1, v2] = T.Cast("float16", (T.Cast("float32", add[v0, v1, v2]) - add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004)) * T.rsqrt(add_red_temp_v1_shared[v0, v1] * T.float32(0.00078125000000000004) - add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004) * (add_red_temp_v0_shared[v0, v1] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_encoder_layers_0_self_attn_layer_norm_weight[v2] + model_encoder_layers_0_self_attn_layer_norm_bias[v2] @T.prim_func(private=True) def layer_norm2(var_add257: T.handle, model_decoder_layers_0_self_attn_layer_norm_weight2: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias2: T.Buffer((T.int64(1280),), "float16"), var_T_layer_norm: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() add257 = T.match_buffer(var_add257, (T.int64(1), seq_len, T.int64(1280)), "float16") T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), seq_len, T.int64(1280)), "float16") # with T.block("root"): add257_red_temp_v0_shared = T.alloc_buffer((T.int64(1), seq_len), scope="shared") add257_red_temp_v1_shared = T.alloc_buffer((T.int64(1), seq_len), scope="shared") for ax0_fused in T.thread_binding(seq_len, thread="blockIdx.x"): for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("add257_red_temp"): v0 = T.axis.spatial(seq_len, ax0_fused + ax0) v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(add257[T.int64(0), v0, v1]) T.writes(add257_red_temp_v0_shared[T.int64(0), v0], add257_red_temp_v1_shared[T.int64(0), v0]) with T.init(): add257_red_temp_v0_shared[T.int64(0), v0] = T.float32(0) add257_red_temp_v1_shared[T.int64(0), v0] = T.float32(0) v_add257_red_temp_v0: T.float32 = add257_red_temp_v0_shared[T.int64(0), v0] + T.Cast("float32", add257[T.int64(0), v0, v1]) v_add257_red_temp_v1: T.float32 = add257_red_temp_v1_shared[T.int64(0), v0] + T.Cast("float32", add257[T.int64(0), v0, v1]) * T.Cast("float32", add257[T.int64(0), v0, v1]) add257_red_temp_v0_shared[T.int64(0), v0] = v_add257_red_temp_v0 add257_red_temp_v1_shared[T.int64(0), v0] = v_add257_red_temp_v1 for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("T_layer_norm"): v0 = T.axis.spatial(seq_len, ax0_fused) v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) T.reads(add257[T.int64(0), v0, v1], add257_red_temp_v0_shared[T.int64(0), v0], add257_red_temp_v1_shared[T.int64(0), v0], model_decoder_layers_0_self_attn_layer_norm_weight2[v1], model_decoder_layers_0_self_attn_layer_norm_bias2[v1]) T.writes(T_layer_norm[T.int64(0), v0, v1]) T_layer_norm[T.int64(0), v0, v1] = T.Cast("float16", (T.Cast("float32", add257[T.int64(0), v0, v1]) - add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004)) * T.rsqrt(add257_red_temp_v1_shared[T.int64(0), v0] * T.float32(0.00078125000000000004) - add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004) * (add257_red_temp_v0_shared[T.int64(0), v0] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * 
model_decoder_layers_0_self_attn_layer_norm_weight2[v1] + model_decoder_layers_0_self_attn_layer_norm_bias2[v1] @T.prim_func(private=True) def layer_norm3(add1220: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16"), model_decoder_layers_0_self_attn_layer_norm_weight5: T.Buffer((T.int64(1280),), "float16"), model_decoder_layers_0_self_attn_layer_norm_bias5: T.Buffer((T.int64(1280),), "float16"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): add1220_red_temp_v0_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") add1220_red_temp_v1_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("add1220_red_temp"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.reduce(T.int64(1280), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(add1220[T.int64(0), T.int64(0), v1]) T.writes(add1220_red_temp_v0_shared[T.int64(0), T.int64(0)], add1220_red_temp_v1_shared[T.int64(0), T.int64(0)]) with T.init(): add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] = T.float32(0) add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] = T.float32(0) v_add1220_red_temp_v0: T.float32 = add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] + T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) v_add1220_red_temp_v1: T.float32 = add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] + T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) * T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] = v_add1220_red_temp_v0 add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] = v_add1220_red_temp_v1 for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(5), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("T_layer_norm"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(1280), ax1_0 * T.int64(256) + ax1_1) T.reads(add1220[T.int64(0), T.int64(0), v1], add1220_red_temp_v0_shared[T.int64(0), T.int64(0)], add1220_red_temp_v1_shared[T.int64(0), T.int64(0)], model_decoder_layers_0_self_attn_layer_norm_weight5[v1], model_decoder_layers_0_self_attn_layer_norm_bias5[v1]) T.writes(T_layer_norm[T.int64(0), T.int64(0), v1]) T_layer_norm[T.int64(0), T.int64(0), v1] = T.Cast("float16", (T.Cast("float32", add1220[T.int64(0), T.int64(0), v1]) - add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004)) * T.rsqrt(add1220_red_temp_v1_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004) - add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004) * (add1220_red_temp_v0_shared[T.int64(0), T.int64(0)] * T.float32(0.00078125000000000004)) + T.float32(1.0000000000000001e-05))) * model_decoder_layers_0_self_attn_layer_norm_weight5[v1] + model_decoder_layers_0_self_attn_layer_norm_bias5[v1] @T.prim_func def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", 
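# Editorial note (hedged): merge_state_inplace below is the usual
# log-sum-exp merge of two partial attention outputs (V, S) and
# (V_other, S_other), with the score S kept in log2 space:
#   s_max   = max(S, S_other)
#   w       = 2**(S - s_max);  w_other = 2**(S_other - s_max)
#   V      <- (w * V + w_other * V_other) / (w + w_other)
#   S      <- log2(w + w_other) + s_max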
"libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True) V = T.match_buffer(v, (N, H, D), "float16") S = T.match_buffer(s, (N, H)) V_other = T.match_buffer(v_other, (N, H, D), "float16") S_other = T.match_buffer(s_other, (N, H)) # with T.block("root"): for bx in T.thread_binding(N, thread="blockIdx.x"): for by in T.thread_binding(1, thread="blockIdx.y"): for ty in T.thread_binding(20, thread="threadIdx.y"): for tx in T.thread_binding(16, thread="threadIdx.x"): with T.block("merge"): T.reads(S[bx, ty + by * 20], S_other[bx, ty + by * 20], V[bx, ty + by * 20, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 20, tx * 4:tx * 4 + 4]) T.writes(V[bx, ty + by * 20, tx * 4:tx * 4 + 4], S[bx, ty + by * 20]) s_val = T.alloc_buffer((1,), scope="local") s_other_val = T.alloc_buffer((1,), scope="local") s_max = T.alloc_buffer((1,), scope="local") scale = T.alloc_buffer((1,), scope="local") other_scale = T.alloc_buffer((1,), scope="local") v_vec = T.alloc_buffer((4,), "float16", scope="local") v_other_vec = T.alloc_buffer((4,), "float16", scope="local") s_val[0] = S[bx, ty + by * 20] s_other_val[0] = S_other[bx, ty + by * 20] s_max[0] = T.max(s_val[0], s_other_val[0]) s_val[0] = T.exp2(s_val[0] - s_max[0]) s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) for vec in T.vectorized(4): v_vec[vec] = V[bx, ty + by * 20, tx * 4 + vec] for vec in T.vectorized(4): v_other_vec[vec] = V_other[bx, ty + by * 20, tx * 4 + vec] for vec in range(4): v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0]) for vec in T.vectorized(4): V[bx, ty + by * 20, tx * 4 + vec] = v_vec[vec] S[bx, ty + by * 20] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] @T.prim_func def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle): T.func_attr({"tir.is_scheduled": 1}) n, vocab_size = T.int64(), T.int64() prob = T.match_buffer(var_prob, (n, vocab_size)) batch_size = T.int64() uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1)) row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32") token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32") # with T.block("root"): aggregate = T.alloc_buffer((), scope="local") sample_id_local = T.alloc_buffer((), "int32", scope="local") step_iter = T.alloc_buffer((), "int32", scope="local") for bx in T.thread_binding(batch_size, thread="blockIdx.x"): row_idx: T.int32 = row_indices[bx, 0] for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"): for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): u: T.float32 = uniform_samples[bx, 0] aggregate[()] = T.Cast("float32", 0) step_iter[()] = 0 while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)): with T.block(""): T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()]) T.writes(sample_id_local[()], aggregate[()]) 
prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local") cumsum = T.alloc_buffer((T.int64(512),), scope="shared") greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local") mask = T.alloc_buffer((T.int64(4),), "bool", scope="local") valid = T.alloc_buffer((T.int64(4),), "bool", scope="local") indices = T.alloc_buffer((T.int64(4),), "int32", scope="local") step_aggregate = T.alloc_buffer((), scope="local") for v in T.unroll(T.int64(4)): idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0)) prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0), prob_local, T.Cast("float32", 0)) valid[v] = prob_local > T.float32(0) and idx < vocab_size with T.block(""): T.reads(prob_gt_threshold[T.int64(0):T.int64(4)]) T.writes(step_aggregate[()]) local_sum = T.alloc_buffer((), scope="local") shared_buf = T.alloc_buffer((T.int64(128),), scope="shared") idx: T.int64 = ty * T.int64(32) + tx local_sum[()] = T.Cast("float32", 0) for i in T.unroll(T.int64(4)): local_sum[()] = local_sum[()] + prob_gt_threshold[i] shared_buf[idx] = local_sum[()] for i in T.unroll(T.int64(7)): if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)] step_aggregate[()] = shared_buf[0] if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)): for i in T.unroll(T.int64(1), T.int64(4)): prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)] for i in T.vectorized(T.int64(4)): cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i] for i in T.unroll(T.int64(5)): for j in T.vectorized(T.int64(4)): idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) if tx >= T.shift_left(T.int64(1), i): cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)] for i in T.unroll(T.int64(1), T.int64(4)): for j in T.vectorized(T.int64(4)): if ty == T.int64(0): idx: T.int64 = i * T.int64(128) + tx * T.int64(4) cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)] for v in T.unroll(T.int64(4)): greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07) with T.block(""): T.reads(greater_than_u[T.int64(0):T.int64(4)]) T.writes(mask[T.int64(0):T.int64(4)]) shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared") tx_idx: T.int64 = ty * T.int64(32) + tx shared_buf[tx_idx] = greater_than_u[T.int64(3)] mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0]) for i in T.unroll(T.int64(1), T.int64(4)): mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)]) for v in T.unroll(T.int64(4)): mask[v] = mask[v] and valid[v] indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v) with T.block(""): T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)]) T.writes(sample_id_local[()]) local_sum = T.alloc_buffer((), "int32", scope="local") shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared") idx: T.int64 = ty * T.int64(32) + tx local_sum[()] = T.Cast("int32", vocab_size - T.int64(1)) for i in T.unroll(T.int64(4)): if mask[i]: local_sum[()] = T.min(local_sum[()], 
indices[i]) shared_buf[idx] = local_sum[()] for i in T.unroll(T.int64(7)): if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)]) sample_id_local[()] = shared_buf[0] aggregate[()] = aggregate[()] + step_aggregate[()] step_iter[()] = step_iter[()] + 1 if tx == T.int64(0) and ty == T.int64(0): token_ids[bx, 0] = sample_id_local[()] @T.prim_func(private=True) def reshape(var_lv: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv = T.match_buffer(var_lv, (batch_size, T.int64(1500), T.int64(1280)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1920000) // T.int64(1280)) v2 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1280) // T.int64(64)) v3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(64)) T.reads(lv[v0, v1, v2 * T.int64(64) + v3]) T.writes(T_reshape[v0, v1, v2, v3]) T_reshape[v0, v1, v2, v3] = lv[v0, v1, v2 * T.int64(64) + v3] @T.prim_func(private=True) def reshape1(var_reshape256: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape256 = T.match_buffer(var_reshape256, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size * T.int64(1500), T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size * T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.reads(reshape256[v0 // T.int64(1500), v0 % T.int64(1500), v1, v2]) T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = reshape256[v0 // T.int64(1500), v0 % T.int64(1500), v1, v2] @T.prim_func(private=True) def reshape10(var_lv4: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv4 = T.match_buffer(var_lv4, (batch_size * T.int64(1500), T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + 
ax0_ax1_ax2_ax3_fused_1) // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1920000) // T.int64(1280)) v2 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(1280) // T.int64(64)) v3 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_ax3_fused_0 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_1) % T.int64(64)) T.reads(lv4[v0 * T.int64(1500) + v1, v2, v3]) T.writes(T_reshape[v0, v1, v2, v3]) T_reshape[v0, v1, v2, v3] = lv4[v0 * T.int64(1500) + v1, v2, v3] @T.prim_func(private=True) def reshape11(var_reshape6: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape6 = T.match_buffer(var_reshape6, (batch_size, T.int64(1500), T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1500), T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(batch_size * T.int64(1875), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1920000)) v1 = T.axis.spatial(T.int64(1500), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1920000) // T.int64(1280)) v2 = T.axis.spatial(T.int64(1280), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280)) T.reads(reshape6[v0, v1, v2 // T.int64(64), v2 % T.int64(64)]) T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = reshape6[v0, v1, v2 // T.int64(64), v2 % T.int64(64)] @T.prim_func(private=True) def reshape12(var_input_ids: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() input_ids = T.match_buffer(var_input_ids, (T.int64(1), seq_len), "int32") T_reshape = T.match_buffer(var_T_reshape, (seq_len,), "int32") # with T.block("root"): for ax0_fused_0 in T.thread_binding((seq_len + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, ax0_fused_0 * T.int64(1024) + ax0_fused_1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < seq_len) T.reads(input_ids[T.int64(0), v0]) T.writes(T_reshape[v0]) T_reshape[v0] = input_ids[T.int64(0), v0] @T.prim_func(private=True) def reshape13(var_take: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() take = T.match_buffer(var_take, (seq_len, T.int64(1280)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280)) T.reads(take[v0, v1]) T.writes(T_reshape[T.int64(0), v0, v1]) T_reshape[T.int64(0), v0, v1] = take[v0, v1] @T.prim_func(private=True) def reshape14(var_lv416: T.handle, var_T_reshape: 
T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() lv416 = T.match_buffer(var_lv416, (T.int64(1), seq_len, T.int64(1280)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(1280)) T.reads(lv416[T.int64(0), v0, v1 * T.int64(64) + v2]) T.writes(T_reshape[T.int64(0), v0, v1, v2]) T_reshape[T.int64(0), v0, v1, v2] = lv416[T.int64(0), v0, v1 * T.int64(64) + v2] @T.prim_func(private=True) def reshape15(var_concat: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() concat = T.match_buffer(var_concat, (T.int64(1), seq_len, T.int64(60), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(60), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(3840)) T.reads(concat[T.int64(0), v0, v1, v2]) T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = concat[T.int64(0), v0, v1, v2] @T.prim_func(private=True) def reshape16(var_lv69: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() lv69 = T.match_buffer(var_lv69, (seq_len, T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(1280)) T.reads(lv69[v0, v1, v2]) T.writes(T_reshape[T.int64(0), v0, v1, v2]) T_reshape[T.int64(0), v0, v1, v2] = lv69[v0, v1, v2] @T.prim_func(private=True) def 
reshape17(var_reshape391: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() reshape391 = T.match_buffer(var_reshape391, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < seq_len * T.int64(1280)) T.reads(reshape391[T.int64(0), v0, v1 // T.int64(64), v1 % T.int64(64)]) T.writes(T_reshape[T.int64(0), v0, v1]) T_reshape[T.int64(0), v0, v1] = reshape391[T.int64(0), v0, v1 // T.int64(64), v1 % T.int64(64)] @T.prim_func(private=True) def reshape18(var_reshape393: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() reshape393 = T.match_buffer(var_reshape393, (T.int64(1), seq_len, T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((seq_len * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(seq_len, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < seq_len * T.int64(1280)) T.reads(reshape393[T.int64(0), v0, v1, v2]) T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = reshape393[T.int64(0), v0, v1, v2] @T.prim_func(private=True) def reshape19(input_ids: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) T.reads(input_ids[T.int64(0), T.int64(0)]) T.writes(T_reshape[T.int64(0)]) T_reshape[T.int64(0)] = input_ids[T.int64(0), T.int64(0)] @T.prim_func(private=True) def reshape2(var_input_ids: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() input_ids = T.match_buffer(var_input_ids, (batch_size, T.int64(1)), "int32") T_reshape = T.match_buffer(var_T_reshape, (batch_size,), "int32") # with T.block("root"): for ax0_fused_0 in T.thread_binding((batch_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, ax0_fused_0 * T.int64(1024) + 
ax0_fused_1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < batch_size) T.reads(input_ids[v0, T.int64(0)]) T.writes(T_reshape[v0]) T_reshape[v0] = input_ids[v0, T.int64(0)] @T.prim_func(private=True) def reshape3(var_take3: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() take3 = T.match_buffer(var_take3, (batch_size, T.int64(1280)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) T.reads(take3[v0, v1]) T.writes(T_reshape[v0, T.int64(0), v1]) T_reshape[v0, T.int64(0), v1] = take3[v0, v1] @T.prim_func(private=True) def reshape4(var_lv224: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv224 = T.match_buffer(var_lv224, (batch_size, T.int64(1), T.int64(1280)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) T.reads(lv224[v0, T.int64(0), v1 * T.int64(64) + v2]) T.writes(T_reshape[v0, T.int64(0), v1, v2]) T_reshape[v0, T.int64(0), v1, v2] = lv224[v0, T.int64(0), v1 * T.int64(64) + v2] @T.prim_func(private=True) def reshape5(var_concat32: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() concat32 = T.match_buffer(var_concat32, (batch_size, T.int64(1), T.int64(60), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(60), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(3840) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(3840)) v1 = T.axis.spatial(T.int64(60), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(3840) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(3840)) T.reads(concat32[v0, T.int64(0), v1, v2]) 
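# Editorial note (hedged): every reshape kernel in this family flattens the
# output space onto ceil(extent / 1024) blocks of 1024 threads, guards the
# tail with T.where(flat < extent), and recovers the multi-index by div/mod.
# A plain-Python sketch of the (batch, 60, 64) index math used just above;
# the helper name is illustrative only:
def _reference_unravel_3840(flat_index):
    v0 = flat_index // 3840       # batch row: 3840 == 60 heads * 64 features
    v1 = flat_index % 3840 // 64  # head index in [0, 60)
    v2 = flat_index % 64          # feature index in [0, 64)
    return v0, v1, v2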
T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = concat32[v0, T.int64(0), v1, v2] @T.prim_func(private=True) def reshape6(var_lv134: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv134 = T.match_buffer(var_lv134, (batch_size, T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) T.reads(lv134[v0, v1, v2]) T.writes(T_reshape[v0, T.int64(0), v1, v2]) T_reshape[v0, T.int64(0), v1, v2] = lv134[v0, v1, v2] @T.prim_func(private=True) def reshape7(var_reshape714: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape714 = T.match_buffer(var_reshape714, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) T.reads(reshape714[v0, T.int64(0), v1 // T.int64(64), v1 % T.int64(64)]) T.writes(T_reshape[v0, T.int64(0), v1]) T_reshape[v0, T.int64(0), v1] = reshape714[v0, T.int64(0), v1 // T.int64(64), v1 % T.int64(64)] @T.prim_func(private=True) def reshape8(var_reshape716: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape716 = T.match_buffer(var_reshape716, (batch_size, T.int64(1), T.int64(20), T.int64(64)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(20), T.int64(64)), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(batch_size, (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(20), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(1280) // T.int64(64)) v2 = T.axis.spatial(T.int64(64), (ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1) % T.int64(64)) T.where(ax0_ax1_ax2_fused_0 * T.int64(1024) + ax0_ax1_ax2_fused_1 < batch_size * T.int64(1280)) T.reads(reshape716[v0, T.int64(0), v1, v2]) 
T.writes(T_reshape[v0, v1, v2]) T_reshape[v0, v1, v2] = reshape716[v0, T.int64(0), v1, v2] @T.prim_func def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size)) sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32") num_samples = T.int32(is_size_var=True) sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32") sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32") num_positions = T.int32(is_size_var=True) top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32") sampled_values = T.match_buffer(var_sampled_values, (num_samples,)) top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,)) top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32") # with T.block("root"): for ax0_fused_0 in T.thread_binding((num_positions + num_samples + 1023) // 1024, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.block("block"): v0 = T.axis.spatial(num_positions + num_samples, ax0_fused_0 * 1024 + ax0_fused_1) T.where(ax0_fused_0 * 1024 + ax0_fused_1 < num_positions + num_samples) T.reads(top_prob_offsets[v0], sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], unsorted_probs[T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]):T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 + (0 - num_positions)]) + (T.max(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions]) + 1 - T.min(top_prob_offsets[v0] // vocab_size, sample_indices[v0 - num_positions])), T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]):T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 + (0 - num_positions)]) + (T.max(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[v0] // vocab_size, top_prob_offsets[v0] % vocab_size], sampling_results[v0 - num_positions]))], sample_indices[v0 + (0 - num_positions)], sampling_results[v0 + (0 - num_positions)]) T.writes(top_prob_indices[v0], top_prob_probs[v0], sampled_values[v0 + (0 - num_positions)]) if v0 < num_positions: row: T.int32 = top_prob_offsets[v0] // vocab_size col: T.int32 = top_prob_offsets[v0] % vocab_size top_prob_indices[v0] = sorted_indices[row, col] top_prob_probs[v0] = unsorted_probs[row, sorted_indices[row, col]] else: vj: T.int32 = v0 - num_positions sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]] @T.prim_func def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): 
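# Editorial sketch (hedged, not part of the compiled module): scatter_probs
# copies each row of `src` into the row of `dst` selected by `indices`,
# assuming the indices are distinct; the TIR body below is this one
# fancy-indexed assignment flattened onto a 1024-thread grid:
import numpy as np

def _reference_scatter_probs(src: np.ndarray, indices: np.ndarray, dst: np.ndarray) -> None:
    # Rows of dst not named in `indices` are left untouched (dst may have
    # more rows than src).
    dst[indices] = src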
T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True) src = T.match_buffer(var_src, (batch_size, n)) indices = T.match_buffer(var_indices, (batch_size,), "int32") m = T.int32(is_size_var=True) dst = T.match_buffer(var_dst, (m, n)) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * n + 1023) // 1024, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.block("scatter_2d"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % (n * batch_size) // n) v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1) % n) T.where(ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 < batch_size * n) T.reads(src[v0, v1], indices[v0]) T.writes(dst[indices[v0], v1]) dst[indices[v0], v1] = src[v0, v1] @T.prim_func def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) softmax = T.match_buffer(var_softmax, (batch_size, vocab_size)) # with T.block("root"): temp_max_shared = T.alloc_buffer((batch_size,), scope="shared") temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared") for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("max"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(chunked_max[v0, v1]) T.writes(temp_max_shared[v0]) with T.init(): temp_max_shared[v0] = T.float32(-3.4028234663852886e+38) temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1]) for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("sum_exp"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(temperature[v0], chunked_sum[v0, v1], 
chunked_max[v0, v1], temp_max_shared[v0]) T.writes(temp_sum_shared[v0]) with T.init(): temp_sum_shared[v0] = T.float32(0) temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1]) for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"): for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("log_pad"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks) v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0]) T.writes(softmax[v0, v1 * T.int64(4096) + v2]) if v1 * T.int64(4096) + v2 < vocab_size: softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0]) @T.prim_func(private=True) def take(model_decoder_embed_tokens_weight3: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), var_reshape707: T.handle, var_T_take: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape707 = T.match_buffer(var_reshape707, (batch_size,), "int32") T_take = T.match_buffer(var_T_take, (batch_size, T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) T.reads(model_decoder_embed_tokens_weight3[reshape707[v0], v1], reshape707[v0]) T.writes(T_take[v0, v1]) T_take[v0, v1] = model_decoder_embed_tokens_weight3[reshape707[v0], v1] @T.prim_func(private=True) def take1(model_decoder_embed_positions_weight3: T.Buffer((T.int64(448), T.int64(1280)), "float16"), var_lv133: T.handle, var_T_take: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv133 = T.match_buffer(var_lv133, (batch_size,), "int32") T_take = T.match_buffer(var_T_take, (batch_size, T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) T.reads(model_decoder_embed_positions_weight3[lv133[v0], v1], lv133[v0]) T.writes(T_take[v0, v1]) 
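# Editorial sketch (hedged): softmax_with_chunked_sum above assembles a
# numerically stable softmax from per-chunk partials, assuming chunked_max
# holds each chunk's maximum of the temperature-scaled logits and chunked_sum
# the matching per-chunk log-sum-exp; when temperature is ~0 it instead emits
# a normalised argmax one-hot. A NumPy reference for one row (illustrative
# names, temperature > 0 path only):
import numpy as np

def _reference_softmax_row(a_row, temperature, chunked_sum, chunked_max):
    row_max = chunked_max.max()
    row_sum = np.exp(chunked_sum + chunked_max - row_max).sum()
    return np.exp(a_row / temperature - (np.log(row_sum) + row_max))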
T_take[v0, v1] = model_decoder_embed_positions_weight3[lv133[v0], v1] @T.prim_func(private=True) def take2(var_layer_norm161: T.handle, var_logit_positions: T.handle, var_T_take: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm161 = T.match_buffer(var_layer_norm161, (T.int64(1), seq_len, T.int64(1280)), "float16") batch_size = T.int64() logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32") T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(1280)), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(batch_size, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(1280)) v1 = T.axis.spatial(T.int64(1280), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(1280)) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size * T.int64(1280)) T.reads(layer_norm161[T.int64(0), logit_positions[v0], v1], logit_positions[v0]) T.writes(T_take[T.int64(0), v0, v1]) T_take[T.int64(0), v0, v1] = layer_norm161[T.int64(0), logit_positions[v0], v1] @T.prim_func(private=True) def take3(model_decoder_embed_tokens_weight5: T.Buffer((T.int64(51866), T.int64(1280)), "float16"), reshape1353: T.Buffer((T.int64(1),), "int32"), T_take: T.Buffer((T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) T.reads(model_decoder_embed_tokens_weight5[reshape1353[T.int64(0)], v0], reshape1353[T.int64(0)]) T.writes(T_take[T.int64(0), v0]) T_take[T.int64(0), v0] = model_decoder_embed_tokens_weight5[reshape1353[T.int64(0)], v0] @T.prim_func(private=True) def take4(model_decoder_embed_positions_weight5: T.Buffer((T.int64(448), T.int64(1280)), "float16"), lv264: T.Buffer((T.int64(1),), "int32"), T_take: T.Buffer((T.int64(1), T.int64(1280)), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(T.int64(1280), ax0_fused_0 * T.int64(1024) + ax0_fused_1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1280)) T.reads(model_decoder_embed_positions_weight5[lv264[T.int64(0)], v0], lv264[T.int64(0)]) T.writes(T_take[T.int64(0), v0]) T_take[T.int64(0), v0] = model_decoder_embed_positions_weight5[lv264[T.int64(0)], v0] @T.prim_func(private=True) def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(), T.int64() probs = T.match_buffer(var_probs, (batch_size, vocab_size)) lv1 = 
T.match_buffer(var_lv1, (batch_size, vocab_size), "int32") batch_size_1, vocab_size_1 = T.int64(), T.int64() take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1)) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((batch_size_1 * vocab_size_1 + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("take_sorted_probs"): v0 = T.axis.spatial(batch_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % (vocab_size_1 * batch_size_1) // vocab_size_1) v1 = T.axis.spatial(vocab_size_1, (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % vocab_size_1) T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < batch_size_1 * vocab_size_1) T.reads(probs[v0, lv1[v0, v1]], lv1[v0, v1]) T.writes(take_sorted_probs[v0, v1]) take_sorted_probs[v0, v1] = probs[v0, lv1[v0, v1]] @T.prim_func def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) num_pages, page_size = T.int64(), T.int64(is_size_var=True) pages = T.match_buffer(var_pages, (num_pages, 2, 20, page_size, 64), "float16") seqlen = T.int64(is_size_var=True) position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1) k_data = T.match_buffer(var_k_data, (32, seqlen, 20, 64), "float16") v_data = T.match_buffer(var_v_data, (32, seqlen, 20, 64), "float16") # with T.block("root"): for p_h_d_fused_0 in T.thread_binding((seqlen * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for p_h_d_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.block("copy0"): vp = T.axis.spatial(seqlen, (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) // T.int64(1280)) vh = T.axis.spatial(20, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(1280) // T.int64(64))) vd = T.axis.spatial(64, T.Cast("int32", (p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1) % T.int64(64))) T.where(p_h_d_fused_0 * T.int64(1024) + p_h_d_fused_1 < seqlen * T.int64(1280)) T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) position: T.int32 = position_map[vp] k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] @T.prim_func def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": 
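# Editorial note (hedged): tir_kv_cache_transpose_append below scatters fresh
# K/V vectors into a paged cache with 16-token pages: a token at global
# position p = position_map[token] lands in pages[p // 16, 0 or 1, head,
# p % 16, dim] (0 for keys, 1 for values), and tokens whose mapped position
# is -1 are skipped as padding.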
T.bool(True)}) num_pages = T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 20, 16, 64), "float16") ntoken = T.int64(is_size_var=True) k_data = T.match_buffer(var_k_data, (ntoken, 20, 64), "float16") v_data = T.match_buffer(var_v_data, (ntoken, 20, 64), "float16") position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1) # with T.block("root"): for global_pos_h_f_fused_0 in T.thread_binding((ntoken * T.int64(1280) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"): for global_pos_h_f_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): if position_map[(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)] != -1: with T.block("k_transpose_append"): vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)) vh = T.axis.spatial(20, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(1280) // T.int64(64))) vf = T.axis.spatial(64, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(64))) T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(1280)) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos = T.axis.spatial(ntoken, (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) // T.int64(1280)) vh = T.axis.spatial(20, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(1280) // T.int64(64))) vf = T.axis.spatial(64, T.Cast("int32", (global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1) % T.int64(64))) T.where(global_pos_h_f_fused_0 * T.int64(1024) + global_pos_h_f_fused_1 < ntoken * T.int64(1280)) T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] @T.prim_func(private=True) def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle): T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) B, N = T.int32(), T.int32() prob = T.match_buffer(var_prob, (B, N)) top_p_arr = T.match_buffer(var_top_p_arr, (B,)) init_pivots = T.match_buffer(var_init_pivots, (B, 3)) final_pivot = T.match_buffer(var_final_pivot, (B,)) final_lsum = T.match_buffer(var_final_lsum, (B,)) # with T.block("root"): pivot = T.alloc_buffer((3,), scope="local") top_p = T.alloc_buffer((1,), scope="local") L = T.alloc_buffer((1,), scope="shared") R_1 = T.alloc_buffer((1,), scope="shared") L_local = T.alloc_buffer((1,), scope="local") R_local = T.alloc_buffer((1,), scope="local") q = T.alloc_buffer((1,), scope="local") lsum = T.alloc_buffer((3,), scope="local") lmin_broadcast = T.alloc_buffer((1,), scope="shared") lmin_broadcast_local = T.alloc_buffer((1,), scope="local") lmin = T.alloc_buffer((3,), scope="local") cmin = T.alloc_buffer((3,), "int32", scope="local") total_sum = 
T.alloc_buffer((1,), scope="local") it = T.alloc_buffer((1,), "int32", scope="local") es_local = T.alloc_buffer((1,), "bool", scope="local") es = T.alloc_buffer((1,), "bool", scope="shared") find_pivot_local = T.alloc_buffer((1,), "bool", scope="local") find_pivot = T.alloc_buffer((1,), "bool", scope="shared") total_sum_reduce = T.alloc_buffer((1,), scope="local") lsum_reduce = T.alloc_buffer((1,), scope="local") lmin_reduce = T.alloc_buffer((1,), scope="local") cmin_reduce = T.alloc_buffer((1,), "int32", scope="local") for _bx in T.thread_binding(B, thread="blockIdx.x"): for _tx in T.thread_binding(1024, thread="threadIdx.x"): with T.block("CTA"): b, tx = T.axis.remap("SS", [_bx, _tx]) T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0]) T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0]) top_p[0] = top_p_arr[b] if tx == 0: L[0] = T.float32(1) - top_p[0] R_1[0] = T.float32(9.9999999999999995e-08) find_pivot[0] = T.bool(False) T.tvm_storage_sync("shared") L_local[0] = L[0] R_local[0] = R_1[0] for i in T.unroll(3): pivot[i] = init_pivots[b, i] find_pivot_local[0] = T.bool(False) if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08): if tx == 0: final_lsum[b] = T.float32(1) final_pivot[b] = T.float32(0) find_pivot_local[0] = T.bool(True) while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]): T.tvm_storage_sync("shared") for pidx in T.unroll(3): lsum[pidx] = T.float32(0) lmin[pidx] = T.float32(3.4028234663852886e+38) cmin[pidx] = 0 total_sum[0] = T.float32(0) it[0] = 0 es_local[0] = T.bool(False) while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]: q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0)) total_sum[0] = total_sum[0] + q[0] for pidx in T.unroll(3): if q[0] >= pivot[pidx]: lsum[pidx] = lsum[pidx] + q[0] if lmin[pidx] > q[0]: lmin[pidx] = q[0] cmin[pidx] = 1 else: if lmin[pidx] == q[0]: cmin[pidx] = cmin[pidx] + 1 it[0] = it[0] + 1 if it[0] % 32 == 0: with T.block("block_cross_thread"): T.reads(total_sum[0]) T.writes(total_sum_reduce[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx) if tx == 0: es[0] = T.float32(1) - total_sum_reduce[0] < pivot[2] T.tvm_storage_sync("shared") es_local[0] = es[0] T.tvm_storage_sync("shared") for pidx in range(3): with T.block("block_cross_thread"): T.reads(lsum[pidx]) T.writes(lsum_reduce[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) 
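# Editorial sketch (hedged, not part of the compiled module): the enclosing
# top_p_pivot_cutoff narrows [R, L] around a probability threshold ("pivot")
# so that the mass of entries >= pivot just reaches top_p; lsum, lmin and
# cmin track that mass, the smallest kept entry, and its multiplicity so
# boundary ties are resolved exactly. A NumPy model of the acceptance test
# applied to one candidate pivot (helper name illustrative):
import numpy as np

def _reference_pivot_check(prob_row: np.ndarray, pivot: float, top_p: float) -> bool:
    kept = prob_row[prob_row >= pivot]
    lsum = float(kept.sum())
    lmin = float(kept.min()) if kept.size else float("inf")
    cmin = int((kept == lmin).sum())
    # Accept when the kept mass reaches top_p but would fall below it
    # without the cmin copies of the smallest kept probability.
    return lsum >= top_p and top_p > lsum - cmin * lmin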
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            # A thread only keeps its count if it holds the CTA-wide minimum.
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        if tx == 0:
                            it[0] = 0
                            # Accept a candidate when its surviving mass covers top_p but dropping
                            # the tied-minimum entries would fall below it; otherwise shrink the
                            # bracket toward that candidate and retry.
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4)
                    if tx == 0:
                        if not find_pivot_local[0]:
                            # Bisection exhausted without an exact hit: fall back to the lower bound.
                            final_pivot[b] = R_local[0]
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    # Apply the pivot found by top_p_pivot_cutoff: zero out entries below it and rescale the
    # survivors by 1 / final_lsum so each row sums to one again.
    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_89", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(), T.int32()
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding((B + 511) // B, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx:T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx + (T.Select(0 <= (B + 511) // B, (N - 1) // ((B + 511) // B * 1024) * ((B + 511) // B), 0 - (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx:T.Select(0 <= (B + 511) // B, 0, (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + bx * 1024 + tx + (T.Select(0 <= (B + 511) // B, (N - 1) // ((B + 511) // B * 1024) * ((B + 511) // B), 0 - (((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024) - 1) * ((B + 511) // B)) * 1024 + 1)])
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        for i in range(((B + 511) // B * 1024 + N - 1) // ((B + 511) // B * 1024)):
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0))

    # Thrust-backed argsort of each probability row; lv is scratch workspace sized by the memory
    # planner, and the outputs are the reordered probabilities plus their original indices.
    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]), R.dtype("uint8"), R.prim_value(0), R.str("global"))
            lv1 = R.call_tir(cls.argsort_thrust, (probs, lv), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32"))
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    # Precompute the cross-attention K/V for every decoder layer from the encoder output and
    # push them into the paged KV cache, one push per layer.
    @R.function
    def batch_compute_cross_attn_kv(encoder_hidden_states: R.Tensor(("batch_size", 1500, 1280), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280),
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), 
dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
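# NOTE: by their shapes, the (51866, 1280) and (448, 1280) tensors a few entries above are most
# likely the decoder's token-embedding and position-embedding tables (Whisper uses a 51866-token
# vocabulary and up to 448 decoder positions); the entries that follow resume the repeating
# per-decoder-layer parameter pattern.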
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), 
dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Object: batch_size = T.int64() R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): model_decoder_layers_0_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[498] model_decoder_layers_0_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[499] model_decoder_layers_0_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[500] model_decoder_layers_1_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[522] model_decoder_layers_1_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[523] model_decoder_layers_1_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[524] model_decoder_layers_2_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[546] model_decoder_layers_2_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[547] model_decoder_layers_2_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[548] model_decoder_layers_3_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[570] model_decoder_layers_3_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[571] model_decoder_layers_3_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[572] model_decoder_layers_4_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[594] model_decoder_layers_4_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[595] model_decoder_layers_4_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[596] model_decoder_layers_5_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[618] model_decoder_layers_5_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[619] model_decoder_layers_5_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[620] model_decoder_layers_6_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[642] model_decoder_layers_6_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[643] model_decoder_layers_6_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[644] model_decoder_layers_7_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[666] model_decoder_layers_7_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[667] model_decoder_layers_7_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[668] model_decoder_layers_8_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[690] model_decoder_layers_8_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[691] model_decoder_layers_8_encoder_attn_v_proj_bias1: 
R.Tensor((1280,), dtype="float16") = packed_params[692] model_decoder_layers_9_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[714] model_decoder_layers_9_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[715] model_decoder_layers_9_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[716] model_decoder_layers_10_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[738] model_decoder_layers_10_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[739] model_decoder_layers_10_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[740] model_decoder_layers_11_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[762] model_decoder_layers_11_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[763] model_decoder_layers_11_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[764] model_decoder_layers_12_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[786] model_decoder_layers_12_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[787] model_decoder_layers_12_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[788] model_decoder_layers_13_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[810] model_decoder_layers_13_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[811] model_decoder_layers_13_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[812] model_decoder_layers_14_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[834] model_decoder_layers_14_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[835] model_decoder_layers_14_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[836] model_decoder_layers_15_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[858] model_decoder_layers_15_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[859] model_decoder_layers_15_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[860] model_decoder_layers_16_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[882] model_decoder_layers_16_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[883] model_decoder_layers_16_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[884] model_decoder_layers_17_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[906] model_decoder_layers_17_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[907] model_decoder_layers_17_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[908] model_decoder_layers_18_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[930] model_decoder_layers_18_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[931] model_decoder_layers_18_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[932] model_decoder_layers_19_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[954] 
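            # The aliases above and below follow a fixed layout: consecutive decoder layers sit
            # 24 entries apart in packed_params. A hypothetical index helper (illustrative only,
            # not part of this module):
            #     def enc_attn_kv_indices(layer: int) -> tuple[int, int, int]:
            #         base = 498 + 24 * layer
            #         return base, base + 1, base + 2  # k_proj weight, v_proj weight, v_proj bias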
model_decoder_layers_19_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[955] model_decoder_layers_19_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[956] model_decoder_layers_20_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[978] model_decoder_layers_20_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[979] model_decoder_layers_20_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[980] model_decoder_layers_21_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1002] model_decoder_layers_21_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1003] model_decoder_layers_21_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1004] model_decoder_layers_22_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1026] model_decoder_layers_22_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1027] model_decoder_layers_22_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1028] model_decoder_layers_23_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1050] model_decoder_layers_23_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1051] model_decoder_layers_23_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1052] model_decoder_layers_24_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1074] model_decoder_layers_24_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1075] model_decoder_layers_24_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1076] model_decoder_layers_25_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1098] model_decoder_layers_25_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1099] model_decoder_layers_25_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1100] model_decoder_layers_26_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1122] model_decoder_layers_26_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1123] model_decoder_layers_26_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1124] model_decoder_layers_27_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1146] model_decoder_layers_27_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1147] model_decoder_layers_27_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1148] model_decoder_layers_28_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1170] model_decoder_layers_28_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1171] model_decoder_layers_28_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1172] model_decoder_layers_29_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1194] model_decoder_layers_29_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1195] model_decoder_layers_29_encoder_attn_v_proj_bias1: 
R.Tensor((1280,), dtype="float16") = packed_params[1196] model_decoder_layers_30_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1218] model_decoder_layers_30_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1219] model_decoder_layers_30_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1220] model_decoder_layers_31_encoder_attn_k_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1242] model_decoder_layers_31_encoder_attn_v_proj_weight1: R.Tensor((1280, 1280), dtype="float16") = packed_params[1243] model_decoder_layers_31_encoder_attn_v_proj_bias1: R.Tensor((1280,), dtype="float16") = packed_params[1244] lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_0_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape256 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_0_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_0_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape257 = R.call_tir(cls.reshape, (lv_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape258 = R.call_tir(cls.reshape1, (reshape256,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape259 = R.call_tir(cls.reshape1, (reshape257,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv36: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", paged_kv_cache, R.prim_value(0), reshape258, reshape259, sinfo_args=(R.Object,)) lv1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_1_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape260 = R.call_tir(cls.reshape, (lv1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv1_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_1_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_1_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape261 = R.call_tir(cls.reshape, (lv1_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape262 = R.call_tir(cls.reshape1, (reshape260,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape263 = R.call_tir(cls.reshape1, (reshape261,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv37: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv36, R.prim_value(1), reshape262, reshape263, sinfo_args=(R.Object,)) lv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_2_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape264 = R.call_tir(cls.reshape, (lv2,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv2_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_2_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_2_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) 
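# Each of the 32 decoder layers repeats the same pattern: project
# encoder_hidden_states through the layer's k_proj (note the fused cuBLAS
# call without a bias add) and v_proj (bias added), reshape the
# (batch_size, 1500, 1280) result into 20 heads of width 64, flatten to
# (batch_size * 1500, 20, 64), and append the K/V pair for that layer index
# via vm.builtin.attention_kv_cache_push_cross_attention_kv. The cache
# object returned by each push (lv36, lv37, ...) is threaded into the next
# call so the updates stay pure inside the dataflow block; the final handle
# gv1 is the function's return value.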
reshape265 = R.call_tir(cls.reshape, (lv2_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape266 = R.call_tir(cls.reshape1, (reshape264,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape267 = R.call_tir(cls.reshape1, (reshape265,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv38: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv37, R.prim_value(2), reshape266, reshape267, sinfo_args=(R.Object,)) lv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_3_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape268 = R.call_tir(cls.reshape, (lv3,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv3_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_3_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_3_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape269 = R.call_tir(cls.reshape, (lv3_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape270 = R.call_tir(cls.reshape1, (reshape268,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape271 = R.call_tir(cls.reshape1, (reshape269,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv39: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv38, R.prim_value(3), reshape270, reshape271, sinfo_args=(R.Object,)) lv4 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_4_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape272 = R.call_tir(cls.reshape, (lv4,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv4_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_4_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_4_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape273 = R.call_tir(cls.reshape, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape274 = R.call_tir(cls.reshape1, (reshape272,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape275 = R.call_tir(cls.reshape1, (reshape273,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv40: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv39, R.prim_value(4), reshape274, reshape275, sinfo_args=(R.Object,)) lv5 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_5_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape276 = R.call_tir(cls.reshape, (lv5,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv5_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_5_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_5_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape277 = R.call_tir(cls.reshape, (lv5_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape278 = R.call_tir(cls.reshape1, (reshape276,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), 
dtype="float16")) reshape279 = R.call_tir(cls.reshape1, (reshape277,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv41: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv40, R.prim_value(5), reshape278, reshape279, sinfo_args=(R.Object,)) lv6 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_6_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape280 = R.call_tir(cls.reshape, (lv6,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv6_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_6_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_6_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape281 = R.call_tir(cls.reshape, (lv6_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape282 = R.call_tir(cls.reshape1, (reshape280,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape283 = R.call_tir(cls.reshape1, (reshape281,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv42: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv41, R.prim_value(6), reshape282, reshape283, sinfo_args=(R.Object,)) lv7 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_7_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape284 = R.call_tir(cls.reshape, (lv7,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv7_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_7_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_7_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape285 = R.call_tir(cls.reshape, (lv7_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape286 = R.call_tir(cls.reshape1, (reshape284,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape287 = R.call_tir(cls.reshape1, (reshape285,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv43: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv42, R.prim_value(7), reshape286, reshape287, sinfo_args=(R.Object,)) lv8 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_8_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape288 = R.call_tir(cls.reshape, (lv8,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv8_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_8_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_8_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape289 = R.call_tir(cls.reshape, (lv8_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape290 = R.call_tir(cls.reshape1, (reshape288,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape291 = R.call_tir(cls.reshape1, (reshape289,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv44: R.Object = 
R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv43, R.prim_value(8), reshape290, reshape291, sinfo_args=(R.Object,)) lv9 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_9_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape292 = R.call_tir(cls.reshape, (lv9,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv9_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_9_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_9_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape293 = R.call_tir(cls.reshape, (lv9_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape294 = R.call_tir(cls.reshape1, (reshape292,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape295 = R.call_tir(cls.reshape1, (reshape293,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv45: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv44, R.prim_value(9), reshape294, reshape295, sinfo_args=(R.Object,)) lv10 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_10_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape296 = R.call_tir(cls.reshape, (lv10,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv10_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_10_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_10_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape297 = R.call_tir(cls.reshape, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape298 = R.call_tir(cls.reshape1, (reshape296,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape299 = R.call_tir(cls.reshape1, (reshape297,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv46: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv45, R.prim_value(10), reshape298, reshape299, sinfo_args=(R.Object,)) lv11 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_11_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape300 = R.call_tir(cls.reshape, (lv11,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv11_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_11_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_11_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape301 = R.call_tir(cls.reshape, (lv11_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape302 = R.call_tir(cls.reshape1, (reshape300,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape303 = R.call_tir(cls.reshape1, (reshape301,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv47: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv46, R.prim_value(11), reshape302, reshape303, sinfo_args=(R.Object,)) lv12 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_12_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape304 = R.call_tir(cls.reshape, (lv12,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv12_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_12_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_12_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape305 = R.call_tir(cls.reshape, (lv12_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape306 = R.call_tir(cls.reshape1, (reshape304,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape307 = R.call_tir(cls.reshape1, (reshape305,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv48: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv47, R.prim_value(12), reshape306, reshape307, sinfo_args=(R.Object,)) lv13 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_13_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape308 = R.call_tir(cls.reshape, (lv13,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv13_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_13_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_13_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape309 = R.call_tir(cls.reshape, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape310 = R.call_tir(cls.reshape1, (reshape308,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape311 = R.call_tir(cls.reshape1, (reshape309,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv49: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv48, R.prim_value(13), reshape310, reshape311, sinfo_args=(R.Object,)) lv14 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_14_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape312 = R.call_tir(cls.reshape, (lv14,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv14_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_14_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_14_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape313 = R.call_tir(cls.reshape, (lv14_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape314 = R.call_tir(cls.reshape1, (reshape312,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape315 = R.call_tir(cls.reshape1, (reshape313,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv50: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv49, R.prim_value(14), reshape314, reshape315, sinfo_args=(R.Object,)) lv15 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_15_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
dtype="float16")) reshape316 = R.call_tir(cls.reshape, (lv15,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv15_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_15_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_15_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape317 = R.call_tir(cls.reshape, (lv15_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape318 = R.call_tir(cls.reshape1, (reshape316,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape319 = R.call_tir(cls.reshape1, (reshape317,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv51: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv50, R.prim_value(15), reshape318, reshape319, sinfo_args=(R.Object,)) lv16 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_16_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape320 = R.call_tir(cls.reshape, (lv16,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv16_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_16_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_16_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape321 = R.call_tir(cls.reshape, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape322 = R.call_tir(cls.reshape1, (reshape320,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape323 = R.call_tir(cls.reshape1, (reshape321,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv52: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv51, R.prim_value(16), reshape322, reshape323, sinfo_args=(R.Object,)) lv17 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_17_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape324 = R.call_tir(cls.reshape, (lv17,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv17_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_17_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_17_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape325 = R.call_tir(cls.reshape, (lv17_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape326 = R.call_tir(cls.reshape1, (reshape324,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape327 = R.call_tir(cls.reshape1, (reshape325,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv53: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv52, R.prim_value(17), reshape326, reshape327, sinfo_args=(R.Object,)) lv18 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_18_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape328 = R.call_tir(cls.reshape, (lv18,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv18_1 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_18_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_18_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape329 = R.call_tir(cls.reshape, (lv18_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape330 = R.call_tir(cls.reshape1, (reshape328,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape331 = R.call_tir(cls.reshape1, (reshape329,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv54: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv53, R.prim_value(18), reshape330, reshape331, sinfo_args=(R.Object,)) lv19 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_19_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape332 = R.call_tir(cls.reshape, (lv19,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv19_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_19_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_19_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape333 = R.call_tir(cls.reshape, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape334 = R.call_tir(cls.reshape1, (reshape332,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape335 = R.call_tir(cls.reshape1, (reshape333,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv55: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv54, R.prim_value(19), reshape334, reshape335, sinfo_args=(R.Object,)) lv20 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_20_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape336 = R.call_tir(cls.reshape, (lv20,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv20_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_20_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_20_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape337 = R.call_tir(cls.reshape, (lv20_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape338 = R.call_tir(cls.reshape1, (reshape336,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape339 = R.call_tir(cls.reshape1, (reshape337,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv56: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv55, R.prim_value(20), reshape338, reshape339, sinfo_args=(R.Object,)) lv21 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_21_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape340 = R.call_tir(cls.reshape, (lv21,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv21_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_21_encoder_attn_v_proj_weight1, encoder_hidden_states, 
model_decoder_layers_21_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape341 = R.call_tir(cls.reshape, (lv21_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape342 = R.call_tir(cls.reshape1, (reshape340,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape343 = R.call_tir(cls.reshape1, (reshape341,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv57: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv56, R.prim_value(21), reshape342, reshape343, sinfo_args=(R.Object,)) lv22 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_22_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape344 = R.call_tir(cls.reshape, (lv22,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv22_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_22_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_22_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape345 = R.call_tir(cls.reshape, (lv22_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape346 = R.call_tir(cls.reshape1, (reshape344,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape347 = R.call_tir(cls.reshape1, (reshape345,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv58: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv57, R.prim_value(22), reshape346, reshape347, sinfo_args=(R.Object,)) lv23 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_23_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape348 = R.call_tir(cls.reshape, (lv23,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv23_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_23_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_23_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape349 = R.call_tir(cls.reshape, (lv23_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape350 = R.call_tir(cls.reshape1, (reshape348,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape351 = R.call_tir(cls.reshape1, (reshape349,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv59: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv58, R.prim_value(23), reshape350, reshape351, sinfo_args=(R.Object,)) lv24 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_24_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape352 = R.call_tir(cls.reshape, (lv24,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv24_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_24_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_24_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape353 = R.call_tir(cls.reshape, (lv24_1,), out_sinfo=R.Tensor((batch_size, 1500, 
20, 64), dtype="float16")) reshape354 = R.call_tir(cls.reshape1, (reshape352,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape355 = R.call_tir(cls.reshape1, (reshape353,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv60: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv59, R.prim_value(24), reshape354, reshape355, sinfo_args=(R.Object,)) lv25 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_25_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape356 = R.call_tir(cls.reshape, (lv25,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv25_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_25_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_25_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape357 = R.call_tir(cls.reshape, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape358 = R.call_tir(cls.reshape1, (reshape356,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape359 = R.call_tir(cls.reshape1, (reshape357,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv61: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv60, R.prim_value(25), reshape358, reshape359, sinfo_args=(R.Object,)) lv26 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_26_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape360 = R.call_tir(cls.reshape, (lv26,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv26_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_26_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_26_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape361 = R.call_tir(cls.reshape, (lv26_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape362 = R.call_tir(cls.reshape1, (reshape360,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape363 = R.call_tir(cls.reshape1, (reshape361,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv62: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv61, R.prim_value(26), reshape362, reshape363, sinfo_args=(R.Object,)) lv27 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_27_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape364 = R.call_tir(cls.reshape, (lv27,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv27_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_27_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_27_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape365 = R.call_tir(cls.reshape, (lv27_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape366 = R.call_tir(cls.reshape1, (reshape364,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape367 = R.call_tir(cls.reshape1, (reshape365,), 
out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv63: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv62, R.prim_value(27), reshape366, reshape367, sinfo_args=(R.Object,)) lv28 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_28_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape368 = R.call_tir(cls.reshape, (lv28,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv28_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_28_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_28_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape369 = R.call_tir(cls.reshape, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape370 = R.call_tir(cls.reshape1, (reshape368,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape371 = R.call_tir(cls.reshape1, (reshape369,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv64: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv63, R.prim_value(28), reshape370, reshape371, sinfo_args=(R.Object,)) lv29 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_29_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape372 = R.call_tir(cls.reshape, (lv29,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv29_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_29_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_29_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape373 = R.call_tir(cls.reshape, (lv29_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape374 = R.call_tir(cls.reshape1, (reshape372,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape375 = R.call_tir(cls.reshape1, (reshape373,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv65: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv64, R.prim_value(29), reshape374, reshape375, sinfo_args=(R.Object,)) lv30 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_30_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape376 = R.call_tir(cls.reshape, (lv30,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv30_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_30_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_30_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape377 = R.call_tir(cls.reshape, (lv30_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape378 = R.call_tir(cls.reshape1, (reshape376,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape379 = R.call_tir(cls.reshape1, (reshape377,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv66: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv65, R.prim_value(30), reshape378, 
reshape379, sinfo_args=(R.Object,)) lv31 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_decoder_layers_31_encoder_attn_k_proj_weight1, encoder_hidden_states), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape380 = R.call_tir(cls.reshape, (lv31,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv31_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_decoder_layers_31_encoder_attn_v_proj_weight1, encoder_hidden_states, model_decoder_layers_31_encoder_attn_v_proj_bias1), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape381 = R.call_tir(cls.reshape, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape382 = R.call_tir(cls.reshape1, (reshape380,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape383 = R.call_tir(cls.reshape1, (reshape381,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) gv1: R.Object = R.call_pure_packed("vm.builtin.attention_kv_cache_push_cross_attention_kv", lv66, R.prim_value(31), reshape382, reshape383, sinfo_args=(R.Object,)) R.output(gv1) return gv1 @R.function def batch_decode(input_ids: R.Tensor(("batch_size", 1), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor(("batch_size", 1, 51866), dtype="float32"): batch_size = T.int64() R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): 
model_decoder_embed_tokens_weight3: R.Tensor((51866, 1280), dtype="float16") = packed_params[487]
model_decoder_embed_positions_weight3: R.Tensor((448, 1280), dtype="float16") = packed_params[488]
model_decoder_layers_0_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[489]
model_decoder_layers_0_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[490]
model_decoder_layers_0_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[491]
model_decoder_layers_0_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[492]
model_decoder_layers_0_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[493]
model_decoder_layers_0_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[494]
model_decoder_layers_0_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[495]
model_decoder_layers_0_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[496]
model_decoder_layers_0_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[497]
model_decoder_layers_0_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[501]
model_decoder_layers_0_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[502]
model_decoder_layers_0_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[503]
model_decoder_layers_0_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[504]
model_decoder_layers_0_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[505]
model_decoder_layers_0_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[506]
model_decoder_layers_0_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[507]
model_decoder_layers_0_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[508]
model_decoder_layers_0_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[509]
model_decoder_layers_0_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[510]
model_decoder_layers_0_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[511]
model_decoder_layers_0_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[512]
model_decoder_layers_1_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[513]
model_decoder_layers_1_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[514]
model_decoder_layers_1_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[515]
model_decoder_layers_1_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[516]
model_decoder_layers_1_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[517]
model_decoder_layers_1_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[518]
model_decoder_layers_1_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[519]
model_decoder_layers_1_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[520]
model_decoder_layers_1_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[521]
model_decoder_layers_1_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[525]
model_decoder_layers_1_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[526]
model_decoder_layers_1_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[527]
model_decoder_layers_1_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[528]
model_decoder_layers_1_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[529]
model_decoder_layers_1_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[530]
model_decoder_layers_1_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[531]
model_decoder_layers_1_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[532]
model_decoder_layers_1_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[533]
model_decoder_layers_1_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[534]
model_decoder_layers_1_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[535]
model_decoder_layers_1_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[536]
model_decoder_layers_2_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[537]
model_decoder_layers_2_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[538]
model_decoder_layers_2_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[539]
model_decoder_layers_2_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[540]
model_decoder_layers_2_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[541]
model_decoder_layers_2_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[542]
model_decoder_layers_2_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[543]
model_decoder_layers_2_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[544]
model_decoder_layers_2_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[545]
model_decoder_layers_2_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[549]
model_decoder_layers_2_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[550]
model_decoder_layers_2_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[551]
model_decoder_layers_2_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[552]
model_decoder_layers_2_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[553]
model_decoder_layers_2_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[554]
model_decoder_layers_2_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[555]
model_decoder_layers_2_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[556]
model_decoder_layers_2_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[557]
model_decoder_layers_2_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[558]
model_decoder_layers_2_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[559]
model_decoder_layers_2_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[560]
model_decoder_layers_3_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[561]
model_decoder_layers_3_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[562]
model_decoder_layers_3_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[563]
model_decoder_layers_3_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[564]
model_decoder_layers_3_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[565]
model_decoder_layers_3_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[566]
model_decoder_layers_3_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[567]
model_decoder_layers_3_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[568]
model_decoder_layers_3_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[569]
model_decoder_layers_3_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[573]
model_decoder_layers_3_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[574]
model_decoder_layers_3_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[575]
model_decoder_layers_3_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[576]
model_decoder_layers_3_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[577]
model_decoder_layers_3_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[578]
model_decoder_layers_3_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[579]
model_decoder_layers_3_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[580]
model_decoder_layers_3_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[581]
model_decoder_layers_3_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[582]
model_decoder_layers_3_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[583]
model_decoder_layers_3_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[584]
model_decoder_layers_4_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[585]
model_decoder_layers_4_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[586]
model_decoder_layers_4_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[587]
model_decoder_layers_4_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[588]
model_decoder_layers_4_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[589]
model_decoder_layers_4_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[590]
model_decoder_layers_4_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[591]
model_decoder_layers_4_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[592]
model_decoder_layers_4_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[593]
model_decoder_layers_4_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[597]
model_decoder_layers_4_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[598]
model_decoder_layers_4_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[599]
model_decoder_layers_4_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[600]
model_decoder_layers_4_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[601]
model_decoder_layers_4_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[602]
model_decoder_layers_4_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[603]
model_decoder_layers_4_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[604]
model_decoder_layers_4_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[605]
model_decoder_layers_4_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[606]
model_decoder_layers_4_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[607]
model_decoder_layers_4_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[608]
model_decoder_layers_5_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[609]
model_decoder_layers_5_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[610]
model_decoder_layers_5_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[611]
model_decoder_layers_5_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[612]
model_decoder_layers_5_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[613]
model_decoder_layers_5_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[614]
model_decoder_layers_5_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[615]
model_decoder_layers_5_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[616]
model_decoder_layers_5_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[617]
model_decoder_layers_5_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[621]
model_decoder_layers_5_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[622]
model_decoder_layers_5_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[623]
model_decoder_layers_5_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[624]
model_decoder_layers_5_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[625]
model_decoder_layers_5_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[626]
model_decoder_layers_5_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[627]
model_decoder_layers_5_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[628]
model_decoder_layers_5_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[629]
model_decoder_layers_5_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[630]
model_decoder_layers_5_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[631]
model_decoder_layers_5_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[632]
model_decoder_layers_6_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[633]
model_decoder_layers_6_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[634]
model_decoder_layers_6_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[635]
model_decoder_layers_6_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[636]
model_decoder_layers_6_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[637]
model_decoder_layers_6_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[638]
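# Annotation with a worked example (the formula is inferred from the bindings,
# not stated in the dump): the slot of a layer-i weight is 489 + 24 * i +
# offset, where offset is that weight's fixed position inside the per-layer
# block. self_attn_q_proj_weight has offset 3, so layer 5 resolves to
# 489 + 24 * 5 + 3 = 612 and layer 6 to 489 + 24 * 6 + 3 = 636, matching the
# packed_params indices bound above.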
model_decoder_layers_6_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[639]
model_decoder_layers_6_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[640]
model_decoder_layers_6_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[641]
model_decoder_layers_6_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[645]
model_decoder_layers_6_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[646]
model_decoder_layers_6_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[647]
model_decoder_layers_6_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[648]
model_decoder_layers_6_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[649]
model_decoder_layers_6_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[650]
model_decoder_layers_6_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[651]
model_decoder_layers_6_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[652]
model_decoder_layers_6_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[653]
model_decoder_layers_6_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[654]
model_decoder_layers_6_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[655]
model_decoder_layers_6_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[656]
model_decoder_layers_7_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[657]
model_decoder_layers_7_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[658]
model_decoder_layers_7_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[659]
model_decoder_layers_7_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[660]
model_decoder_layers_7_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[661]
model_decoder_layers_7_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[662]
model_decoder_layers_7_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[663]
model_decoder_layers_7_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[664]
model_decoder_layers_7_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[665]
model_decoder_layers_7_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[669]
model_decoder_layers_7_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[670]
model_decoder_layers_7_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[671]
model_decoder_layers_7_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[672]
model_decoder_layers_7_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[673]
model_decoder_layers_7_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[674]
model_decoder_layers_7_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[675]
model_decoder_layers_7_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[676]
model_decoder_layers_7_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[677]
model_decoder_layers_7_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[678]
model_decoder_layers_7_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[679]
model_decoder_layers_7_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[680]
model_decoder_layers_8_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[681]
model_decoder_layers_8_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[682]
model_decoder_layers_8_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[683]
model_decoder_layers_8_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[684]
model_decoder_layers_8_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[685]
model_decoder_layers_8_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[686]
model_decoder_layers_8_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[687]
model_decoder_layers_8_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[688]
model_decoder_layers_8_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[689]
model_decoder_layers_8_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[693]
model_decoder_layers_8_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[694]
model_decoder_layers_8_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[695]
model_decoder_layers_8_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[696]
model_decoder_layers_8_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[697]
model_decoder_layers_8_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[698]
model_decoder_layers_8_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[699]
model_decoder_layers_8_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[700]
model_decoder_layers_8_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[701]
model_decoder_layers_8_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[702]
model_decoder_layers_8_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[703]
model_decoder_layers_8_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[704]
model_decoder_layers_9_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[705]
model_decoder_layers_9_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[706]
model_decoder_layers_9_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[707]
model_decoder_layers_9_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[708]
model_decoder_layers_9_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[709]
model_decoder_layers_9_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[710]
model_decoder_layers_9_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[711]
model_decoder_layers_9_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[712]
model_decoder_layers_9_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[713]
model_decoder_layers_9_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[717]
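# Annotation: every projection weight is bound with shape (out_features,
# in_features), e.g. fc1 is (5120, 1280) while fc2 is (1280, 5120), so the
# matmuls that consume these tensors presumably read the weight in transposed
# form rather than materializing a separate transpose.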
model_decoder_layers_9_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[718]
model_decoder_layers_9_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[719]
model_decoder_layers_9_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[720]
model_decoder_layers_9_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[721]
model_decoder_layers_9_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[722]
model_decoder_layers_9_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[723]
model_decoder_layers_9_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[724]
model_decoder_layers_9_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[725]
model_decoder_layers_9_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[726]
model_decoder_layers_9_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[727]
model_decoder_layers_9_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[728]
model_decoder_layers_10_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[729]
model_decoder_layers_10_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[730]
model_decoder_layers_10_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[731]
model_decoder_layers_10_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[732]
model_decoder_layers_10_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[733]
model_decoder_layers_10_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[734]
model_decoder_layers_10_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[735]
model_decoder_layers_10_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[736]
model_decoder_layers_10_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[737]
model_decoder_layers_10_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[741]
model_decoder_layers_10_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[742]
model_decoder_layers_10_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[743]
model_decoder_layers_10_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[744]
model_decoder_layers_10_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[745]
model_decoder_layers_10_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[746]
model_decoder_layers_10_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[747]
model_decoder_layers_10_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[748]
model_decoder_layers_10_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[749]
model_decoder_layers_10_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[750]
model_decoder_layers_10_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[751]
model_decoder_layers_10_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[752]
model_decoder_layers_11_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[753]
model_decoder_layers_11_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[754]
model_decoder_layers_11_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[755]
model_decoder_layers_11_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[756]
model_decoder_layers_11_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[757]
model_decoder_layers_11_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[758]
model_decoder_layers_11_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[759]
model_decoder_layers_11_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[760]
model_decoder_layers_11_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[761]
model_decoder_layers_11_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[765]
model_decoder_layers_11_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[766]
model_decoder_layers_11_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[767]
model_decoder_layers_11_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[768]
model_decoder_layers_11_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[769]
model_decoder_layers_11_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[770]
model_decoder_layers_11_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[771]
model_decoder_layers_11_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[772]
model_decoder_layers_11_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[773]
model_decoder_layers_11_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[774]
model_decoder_layers_11_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[775]
model_decoder_layers_11_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[776]
model_decoder_layers_12_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[777]
model_decoder_layers_12_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[778]
model_decoder_layers_12_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[779]
model_decoder_layers_12_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[780]
model_decoder_layers_12_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[781]
model_decoder_layers_12_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[782]
model_decoder_layers_12_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[783]
model_decoder_layers_12_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[784]
model_decoder_layers_12_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[785]
model_decoder_layers_12_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[789]
model_decoder_layers_12_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[790]
model_decoder_layers_12_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[791]
model_decoder_layers_12_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[792]
model_decoder_layers_12_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[793]
model_decoder_layers_12_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[794]
model_decoder_layers_12_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[795]
model_decoder_layers_12_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[796]
model_decoder_layers_12_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[797]
model_decoder_layers_12_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[798]
model_decoder_layers_12_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[799]
model_decoder_layers_12_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[800]
model_decoder_layers_13_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[801]
model_decoder_layers_13_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[802]
model_decoder_layers_13_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[803]
model_decoder_layers_13_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[804]
model_decoder_layers_13_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[805]
model_decoder_layers_13_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[806]
model_decoder_layers_13_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[807]
model_decoder_layers_13_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[808]
model_decoder_layers_13_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[809]
model_decoder_layers_13_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[813]
model_decoder_layers_13_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[814]
model_decoder_layers_13_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[815]
model_decoder_layers_13_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[816]
model_decoder_layers_13_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[817]
model_decoder_layers_13_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[818]
model_decoder_layers_13_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[819]
model_decoder_layers_13_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[820]
model_decoder_layers_13_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[821]
model_decoder_layers_13_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[822]
model_decoder_layers_13_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[823]
model_decoder_layers_13_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[824]
model_decoder_layers_14_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[825]
model_decoder_layers_14_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[826]
model_decoder_layers_14_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[827]
model_decoder_layers_14_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[828]
model_decoder_layers_14_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[829]
model_decoder_layers_14_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[830]
model_decoder_layers_14_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[831]
model_decoder_layers_14_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[832]
model_decoder_layers_14_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[833]
model_decoder_layers_14_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[837]
model_decoder_layers_14_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[838]
model_decoder_layers_14_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[839]
model_decoder_layers_14_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[840]
model_decoder_layers_14_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[841]
model_decoder_layers_14_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[842]
model_decoder_layers_14_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[843]
model_decoder_layers_14_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[844]
model_decoder_layers_14_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[845]
model_decoder_layers_14_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[846]
model_decoder_layers_14_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[847]
model_decoder_layers_14_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[848]
model_decoder_layers_15_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[849]
model_decoder_layers_15_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[850]
model_decoder_layers_15_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[851]
model_decoder_layers_15_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[852]
model_decoder_layers_15_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[853]
model_decoder_layers_15_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[854]
model_decoder_layers_15_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[855]
model_decoder_layers_15_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[856]
model_decoder_layers_15_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[857]
model_decoder_layers_15_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[861]
model_decoder_layers_15_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[862]
model_decoder_layers_15_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[863]
model_decoder_layers_15_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[864]
model_decoder_layers_15_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[865]
model_decoder_layers_15_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[866]
model_decoder_layers_15_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[867]
model_decoder_layers_15_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[868]
model_decoder_layers_15_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[869]
model_decoder_layers_15_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[870]
model_decoder_layers_15_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[871]
model_decoder_layers_15_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[872]
model_decoder_layers_16_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[873]
model_decoder_layers_16_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[874]
model_decoder_layers_16_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[875]
model_decoder_layers_16_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[876]
model_decoder_layers_16_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[877]
model_decoder_layers_16_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[878]
model_decoder_layers_16_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[879]
model_decoder_layers_16_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[880]
model_decoder_layers_16_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[881]
model_decoder_layers_16_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[885]
model_decoder_layers_16_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[886]
model_decoder_layers_16_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[887]
model_decoder_layers_16_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[888]
model_decoder_layers_16_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[889]
model_decoder_layers_16_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[890]
model_decoder_layers_16_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[891]
model_decoder_layers_16_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[892]
model_decoder_layers_16_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[893]
model_decoder_layers_16_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[894]
model_decoder_layers_16_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[895]
model_decoder_layers_16_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[896]
model_decoder_layers_17_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[897]
model_decoder_layers_17_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[898]
model_decoder_layers_17_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[899]
model_decoder_layers_17_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[900]
model_decoder_layers_17_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[901]
model_decoder_layers_17_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[902]
model_decoder_layers_17_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[903]
model_decoder_layers_17_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[904]
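# NOTE (editorial annotation, not emitted by the TVMScript printer): each
# decoder layer occupies 24 consecutive packed_params slots, and every layer
# skips three slots (e.g. index 905 jumps to 909 below). Those gaps most
# likely hold the cross-attention k/v projection parameters, which this
# decode function never binds because encoder keys/values are served from
# the paged KV cache instead. A hedged sketch of the indexing, assuming
# layer 12 starts at slot 777 with a stride of 24:
#
#     def param_slot(layer: int, offset: int) -> int:
#         # offset 0..23 inside the layer's block; offsets 9-11 are the
#         # slots this function never reads (assumed cross-attn k/v params)
#         return 777 + (layer - 12) * 24 + offset
#
#     assert param_slot(12, 0) == 777   # self_attn_k_proj_weight3
#     assert param_slot(17, 7) == 904   # self_attn_layer_norm_weight3
#     assert param_slot(17, 12) == 909  # encoder_attn_q_proj_weight3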
model_decoder_layers_17_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[905]
model_decoder_layers_17_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[909]
model_decoder_layers_17_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[910]
model_decoder_layers_17_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[911]
model_decoder_layers_17_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[912]
model_decoder_layers_17_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[913]
model_decoder_layers_17_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[914]
model_decoder_layers_17_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[915]
model_decoder_layers_17_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[916]
model_decoder_layers_17_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[917]
model_decoder_layers_17_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[918]
model_decoder_layers_17_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[919]
model_decoder_layers_17_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[920]
model_decoder_layers_18_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[921]
model_decoder_layers_18_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[922]
model_decoder_layers_18_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[923]
model_decoder_layers_18_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[924]
model_decoder_layers_18_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[925]
model_decoder_layers_18_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[926]
model_decoder_layers_18_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[927]
model_decoder_layers_18_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[928]
model_decoder_layers_18_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[929]
model_decoder_layers_18_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[933]
model_decoder_layers_18_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[934]
model_decoder_layers_18_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[935]
model_decoder_layers_18_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[936]
model_decoder_layers_18_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[937]
model_decoder_layers_18_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[938]
model_decoder_layers_18_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[939]
model_decoder_layers_18_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[940]
model_decoder_layers_18_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[941]
model_decoder_layers_18_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[942]
model_decoder_layers_18_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[943]
model_decoder_layers_18_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[944]
model_decoder_layers_19_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[945]
model_decoder_layers_19_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[946]
model_decoder_layers_19_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[947]
model_decoder_layers_19_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[948]
model_decoder_layers_19_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[949]
model_decoder_layers_19_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[950]
model_decoder_layers_19_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[951]
model_decoder_layers_19_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[952]
model_decoder_layers_19_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[953]
model_decoder_layers_19_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[957]
model_decoder_layers_19_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[958]
model_decoder_layers_19_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[959]
model_decoder_layers_19_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[960]
model_decoder_layers_19_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[961]
model_decoder_layers_19_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[962]
model_decoder_layers_19_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[963]
model_decoder_layers_19_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[964]
model_decoder_layers_19_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[965]
model_decoder_layers_19_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[966]
model_decoder_layers_19_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[967]
model_decoder_layers_19_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[968]
model_decoder_layers_20_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[969]
model_decoder_layers_20_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[970]
model_decoder_layers_20_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[971]
model_decoder_layers_20_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[972]
model_decoder_layers_20_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[973]
model_decoder_layers_20_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[974]
model_decoder_layers_20_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[975]
model_decoder_layers_20_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[976]
model_decoder_layers_20_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[977]
model_decoder_layers_20_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[981]
model_decoder_layers_20_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[982]
model_decoder_layers_20_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[983]
model_decoder_layers_20_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[984]
model_decoder_layers_20_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[985]
model_decoder_layers_20_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[986]
model_decoder_layers_20_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[987]
model_decoder_layers_20_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[988]
model_decoder_layers_20_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[989]
model_decoder_layers_20_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[990]
model_decoder_layers_20_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[991]
model_decoder_layers_20_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[992]
model_decoder_layers_21_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[993]
model_decoder_layers_21_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[994]
model_decoder_layers_21_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[995]
model_decoder_layers_21_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[996]
model_decoder_layers_21_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[997]
model_decoder_layers_21_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[998]
model_decoder_layers_21_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[999]
model_decoder_layers_21_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1000]
model_decoder_layers_21_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1001]
model_decoder_layers_21_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005]
model_decoder_layers_21_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1006]
model_decoder_layers_21_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007]
model_decoder_layers_21_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1008]
model_decoder_layers_21_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1009]
model_decoder_layers_21_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1010]
model_decoder_layers_21_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011]
model_decoder_layers_21_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1012]
model_decoder_layers_21_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013]
model_decoder_layers_21_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1014]
model_decoder_layers_21_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1015]
model_decoder_layers_21_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1016]
model_decoder_layers_22_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017]
model_decoder_layers_22_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018]
model_decoder_layers_22_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1019]
model_decoder_layers_22_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020]
model_decoder_layers_22_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1021]
model_decoder_layers_22_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022]
model_decoder_layers_22_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1023]
model_decoder_layers_22_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1024]
model_decoder_layers_22_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1025]
model_decoder_layers_22_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029]
model_decoder_layers_22_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1030]
model_decoder_layers_22_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031]
model_decoder_layers_22_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1032]
model_decoder_layers_22_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1033]
model_decoder_layers_22_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1034]
model_decoder_layers_22_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035]
model_decoder_layers_22_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1036]
model_decoder_layers_22_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037]
model_decoder_layers_22_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1038]
model_decoder_layers_22_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1039]
model_decoder_layers_22_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1040]
model_decoder_layers_23_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041]
model_decoder_layers_23_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042]
model_decoder_layers_23_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1043]
model_decoder_layers_23_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044]
model_decoder_layers_23_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1045]
model_decoder_layers_23_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046]
model_decoder_layers_23_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1047]
model_decoder_layers_23_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1048]
model_decoder_layers_23_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1049]
model_decoder_layers_23_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053]
model_decoder_layers_23_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1054]
model_decoder_layers_23_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055]
model_decoder_layers_23_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1056]
model_decoder_layers_23_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1057]
model_decoder_layers_23_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1058]
model_decoder_layers_23_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059]
model_decoder_layers_23_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1060]
model_decoder_layers_23_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061]
model_decoder_layers_23_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1062]
model_decoder_layers_23_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1063]
model_decoder_layers_23_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1064]
model_decoder_layers_24_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065]
model_decoder_layers_24_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066]
model_decoder_layers_24_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1067]
model_decoder_layers_24_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068]
model_decoder_layers_24_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1069]
model_decoder_layers_24_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070]
model_decoder_layers_24_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1071]
model_decoder_layers_24_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1072]
model_decoder_layers_24_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1073]
model_decoder_layers_24_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077]
model_decoder_layers_24_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1078]
model_decoder_layers_24_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079]
model_decoder_layers_24_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1080]
model_decoder_layers_24_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1081]
model_decoder_layers_24_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1082]
model_decoder_layers_24_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083]
model_decoder_layers_24_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1084]
model_decoder_layers_24_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085]
model_decoder_layers_24_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1086]
model_decoder_layers_24_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1087]
model_decoder_layers_24_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1088]
model_decoder_layers_25_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089]
model_decoder_layers_25_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090]
model_decoder_layers_25_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1091]
model_decoder_layers_25_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092]
model_decoder_layers_25_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1093]
model_decoder_layers_25_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094]
model_decoder_layers_25_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1095]
model_decoder_layers_25_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1096]
model_decoder_layers_25_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1097]
model_decoder_layers_25_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101]
model_decoder_layers_25_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1102]
model_decoder_layers_25_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103]
model_decoder_layers_25_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1104]
model_decoder_layers_25_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1105]
model_decoder_layers_25_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1106]
model_decoder_layers_25_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107]
model_decoder_layers_25_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1108]
model_decoder_layers_25_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109]
model_decoder_layers_25_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1110]
model_decoder_layers_25_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1111]
model_decoder_layers_25_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1112]
model_decoder_layers_26_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113]
model_decoder_layers_26_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114]
model_decoder_layers_26_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1115]
model_decoder_layers_26_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116]
model_decoder_layers_26_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1117]
model_decoder_layers_26_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118]
model_decoder_layers_26_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1119]
model_decoder_layers_26_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1120]
model_decoder_layers_26_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1121]
model_decoder_layers_26_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125]
model_decoder_layers_26_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1126]
model_decoder_layers_26_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127]
model_decoder_layers_26_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1128]
model_decoder_layers_26_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1129]
model_decoder_layers_26_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1130]
model_decoder_layers_26_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131]
model_decoder_layers_26_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1132]
model_decoder_layers_26_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133]
model_decoder_layers_26_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1134]
model_decoder_layers_26_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1135]
model_decoder_layers_26_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1136]
model_decoder_layers_27_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137]
model_decoder_layers_27_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138]
model_decoder_layers_27_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1139]
model_decoder_layers_27_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140]
model_decoder_layers_27_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1141]
model_decoder_layers_27_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142]
model_decoder_layers_27_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1143]
model_decoder_layers_27_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1144]
model_decoder_layers_27_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1145]
model_decoder_layers_27_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149]
model_decoder_layers_27_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1150]
model_decoder_layers_27_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151]
model_decoder_layers_27_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1152]
model_decoder_layers_27_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1153]
model_decoder_layers_27_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1154]
model_decoder_layers_27_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155]
model_decoder_layers_27_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1156]
model_decoder_layers_27_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157]
model_decoder_layers_27_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1158]
model_decoder_layers_27_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1159]
model_decoder_layers_27_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1160]
model_decoder_layers_28_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161]
model_decoder_layers_28_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162]
model_decoder_layers_28_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1163]
model_decoder_layers_28_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164]
model_decoder_layers_28_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1165]
model_decoder_layers_28_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166]
model_decoder_layers_28_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1167]
model_decoder_layers_28_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1168]
model_decoder_layers_28_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1169]
model_decoder_layers_28_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173]
model_decoder_layers_28_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1174]
model_decoder_layers_28_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175]
model_decoder_layers_28_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1176]
model_decoder_layers_28_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1177]
model_decoder_layers_28_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1178]
model_decoder_layers_28_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179]
model_decoder_layers_28_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1180]
model_decoder_layers_28_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181]
model_decoder_layers_28_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1182]
model_decoder_layers_28_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1183]
model_decoder_layers_28_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1184]
model_decoder_layers_29_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185]
model_decoder_layers_29_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186]
model_decoder_layers_29_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1187]
model_decoder_layers_29_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188]
model_decoder_layers_29_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1189]
model_decoder_layers_29_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190]
model_decoder_layers_29_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1191]
model_decoder_layers_29_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1192]
model_decoder_layers_29_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1193]
model_decoder_layers_29_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197]
model_decoder_layers_29_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1198]
model_decoder_layers_29_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199]
model_decoder_layers_29_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1200]
model_decoder_layers_29_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1201]
model_decoder_layers_29_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1202]
model_decoder_layers_29_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203]
model_decoder_layers_29_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1204]
model_decoder_layers_29_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205]
model_decoder_layers_29_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1206]
model_decoder_layers_29_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1207]
model_decoder_layers_29_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1208]
model_decoder_layers_30_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209]
model_decoder_layers_30_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210]
model_decoder_layers_30_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1211]
model_decoder_layers_30_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212]
model_decoder_layers_30_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1213]
model_decoder_layers_30_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214]
model_decoder_layers_30_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1215]
model_decoder_layers_30_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1216]
model_decoder_layers_30_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1217]
model_decoder_layers_30_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221]
model_decoder_layers_30_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1222]
model_decoder_layers_30_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223]
model_decoder_layers_30_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1224]
model_decoder_layers_30_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1225]
model_decoder_layers_30_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1226]
model_decoder_layers_30_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227]
model_decoder_layers_30_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1228]
model_decoder_layers_30_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229]
model_decoder_layers_30_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1230]
model_decoder_layers_30_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1231]
model_decoder_layers_30_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1232]
model_decoder_layers_31_self_attn_k_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233]
model_decoder_layers_31_self_attn_v_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234]
model_decoder_layers_31_self_attn_v_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1235]
model_decoder_layers_31_self_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236]
model_decoder_layers_31_self_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1237]
model_decoder_layers_31_self_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238]
model_decoder_layers_31_self_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1239]
model_decoder_layers_31_self_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1240]
model_decoder_layers_31_self_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1241]
model_decoder_layers_31_encoder_attn_q_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245]
model_decoder_layers_31_encoder_attn_q_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1246]
model_decoder_layers_31_encoder_attn_out_proj_weight3: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247]
model_decoder_layers_31_encoder_attn_out_proj_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1248]
model_decoder_layers_31_encoder_attn_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1249]
model_decoder_layers_31_encoder_attn_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1250]
model_decoder_layers_31_fc1_weight3: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251]
model_decoder_layers_31_fc1_bias3: R.Tensor((5120,), dtype="float16") = packed_params[1252]
model_decoder_layers_31_fc2_weight3: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253]
model_decoder_layers_31_fc2_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1254]
model_decoder_layers_31_final_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1255]
model_decoder_layers_31_final_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1256]
model_decoder_layer_norm_weight3: R.Tensor((1280,), dtype="float16") = packed_params[1257]
model_decoder_layer_norm_bias3: R.Tensor((1280,), dtype="float16") = packed_params[1258]
reshape707 = R.call_tir(cls.reshape2, (input_ids,), out_sinfo=R.Tensor((batch_size,), dtype="int32"))
take3 = R.call_tir(cls.take, (model_decoder_embed_tokens_weight3, reshape707), out_sinfo=R.Tensor((batch_size, 1280), dtype="float16"))
reshape708 = R.call_tir(cls.reshape3, (take3,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv133: R.Tensor((batch_size,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((batch_size,), dtype="int32"),))
take4 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight3, lv133), out_sinfo=R.Tensor((batch_size, 1280), dtype="float16"))
reshape709 = R.call_tir(cls.reshape3, (take4,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add578 = R.call_tir(cls.add, (reshape708, reshape709), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm162 = R.call_tir(cls.layer_norm, (add578, model_decoder_layers_0_self_attn_layer_norm_weight3, model_decoder_layers_0_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv224 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_q_proj_weight3, layer_norm162, model_decoder_layers_0_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape710 = R.call_tir(cls.reshape4, (lv224,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_0_self_attn_k_proj_weight3, layer_norm162), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape711 = R.call_tir(cls.reshape4, (lv65,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv225 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_v_proj_weight3, layer_norm162, model_decoder_layers_0_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape712 = R.call_tir(cls.reshape4, (lv225,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
concat32 = R.call_tir(cls.concatenate, (reshape710, reshape711, reshape712), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16"))
reshape713 = R.call_tir(cls.reshape5, (concat32,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16"))
lv134 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape713), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape714 = R.call_tir(cls.reshape6, (lv134,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape715 = R.call_tir(cls.reshape7, (reshape714,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv226 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_self_attn_out_proj_weight3, reshape715, model_decoder_layers_0_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add582 = R.call_tir(cls.add, (add578, lv226), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm163 = R.call_tir(cls.layer_norm, (add582, model_decoder_layers_0_encoder_attn_layer_norm_weight3, model_decoder_layers_0_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv227 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight3, layer_norm163, model_decoder_layers_0_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape716 = R.call_tir(cls.reshape4, (lv227,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape717 = R.call_tir(cls.reshape8, (reshape716,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
lv135 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape717), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape718 = R.call_tir(cls.reshape6, (lv135,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape719 = R.call_tir(cls.reshape7, (reshape718,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv228 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight3, reshape719, model_decoder_layers_0_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add585 = R.call_tir(cls.add, (add582, lv228), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm164 = R.call_tir(cls.layer_norm, (add585, model_decoder_layers_0_final_layer_norm_weight3, model_decoder_layers_0_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv32 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_0_fc1_weight3, layer_norm164, model_decoder_layers_0_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16"))
lv229 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_0_fc2_weight3, lv32, model_decoder_layers_0_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add588 = R.call_tir(cls.add, (add585, lv229), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm165 = R.call_tir(cls.layer_norm, (add588, model_decoder_layers_1_self_attn_layer_norm_weight3, model_decoder_layers_1_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv230 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_q_proj_weight3, layer_norm165, model_decoder_layers_1_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape720 = R.call_tir(cls.reshape4, (lv230,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_1_self_attn_k_proj_weight3, layer_norm165), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape721 = R.call_tir(cls.reshape4, (lv66,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv231 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_v_proj_weight3, layer_norm165, model_decoder_layers_1_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape722 = R.call_tir(cls.reshape4, (lv231,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
concat33 = R.call_tir(cls.concatenate, (reshape720, reshape721, reshape722), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16"))
reshape723 = R.call_tir(cls.reshape5, (concat33,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16"))
lv136 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape723), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape724 = R.call_tir(cls.reshape6, (lv136,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape725 = R.call_tir(cls.reshape7, (reshape724,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv232 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_self_attn_out_proj_weight3, reshape725, model_decoder_layers_1_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add592 = R.call_tir(cls.add, (add588, lv232), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm166 = R.call_tir(cls.layer_norm, (add592, model_decoder_layers_1_encoder_attn_layer_norm_weight3, model_decoder_layers_1_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv233 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight3, layer_norm166, model_decoder_layers_1_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape726 = R.call_tir(cls.reshape4, (lv233,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape727 = R.call_tir(cls.reshape8, (reshape726,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
lv137 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape727), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape728 = R.call_tir(cls.reshape6, (lv137,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape729 = R.call_tir(cls.reshape7, (reshape728,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv234 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight3, reshape729, model_decoder_layers_1_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add595 = R.call_tir(cls.add, (add592, lv234), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm167 = R.call_tir(cls.layer_norm, (add595, model_decoder_layers_1_final_layer_norm_weight3, model_decoder_layers_1_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv33 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_1_fc1_weight3, layer_norm167, model_decoder_layers_1_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16"))
lv235 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_1_fc2_weight3, lv33, model_decoder_layers_1_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add598 = R.call_tir(cls.add, (add595, lv235), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm168 = R.call_tir(cls.layer_norm, (add598, model_decoder_layers_2_self_attn_layer_norm_weight3, model_decoder_layers_2_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv236 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_q_proj_weight3, layer_norm168, model_decoder_layers_2_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape730 = R.call_tir(cls.reshape4, (lv236,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_2_self_attn_k_proj_weight3, layer_norm168), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape731 = R.call_tir(cls.reshape4, (lv67,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv237 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_v_proj_weight3, layer_norm168, model_decoder_layers_2_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape732 = R.call_tir(cls.reshape4, (lv237,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
concat34 = R.call_tir(cls.concatenate, (reshape730, reshape731, reshape732), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16"))
reshape733 = R.call_tir(cls.reshape5, (concat34,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16"))
lv138 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape733), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape734 = R.call_tir(cls.reshape6, (lv138,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape735 = R.call_tir(cls.reshape7, (reshape734,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv238 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_self_attn_out_proj_weight3, reshape735, model_decoder_layers_2_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add602 = R.call_tir(cls.add, (add598, lv238), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm169 = R.call_tir(cls.layer_norm, (add602, model_decoder_layers_2_encoder_attn_layer_norm_weight3, model_decoder_layers_2_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv239 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight3, layer_norm169, model_decoder_layers_2_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape736 = R.call_tir(cls.reshape4, (lv239,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
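# NOTE (editorial annotation, not emitted by the TVMScript printer): every
# decoder layer repeats the same pre-LayerNorm residual schedule; for layer 1
# above the dataflow is:
#   add592 = add588 + self_attn(layer_norm165)       # self-attention branch
#   add595 = add592 + cross_attn(layer_norm166)      # encoder-attention branch
#   add598 = add595 + fc2(gelu(fc1(layer_norm167)))  # MLP branch, 1280 -> 5120 -> 1280
# Only the layer index passed to the KV-cache builtins and the bound weight
# names change from one layer to the next.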
reshape737 = R.call_tir(cls.reshape8, (reshape736,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
lv139 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape737), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape738 = R.call_tir(cls.reshape6, (lv139,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape739 = R.call_tir(cls.reshape7, (reshape738,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv240 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight3, reshape739, model_decoder_layers_2_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add605 = R.call_tir(cls.add, (add602, lv240), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm170 = R.call_tir(cls.layer_norm, (add605, model_decoder_layers_2_final_layer_norm_weight3, model_decoder_layers_2_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv34 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_2_fc1_weight3, layer_norm170, model_decoder_layers_2_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16"))
lv241 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_2_fc2_weight3, lv34, model_decoder_layers_2_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add608 = R.call_tir(cls.add, (add605, lv241), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm171 = R.call_tir(cls.layer_norm, (add608, model_decoder_layers_3_self_attn_layer_norm_weight3, model_decoder_layers_3_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv242 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_q_proj_weight3, layer_norm171, model_decoder_layers_3_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape740 = R.call_tir(cls.reshape4, (lv242,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv68 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_3_self_attn_k_proj_weight3, layer_norm171), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape741 = R.call_tir(cls.reshape4, (lv68,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv243 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_v_proj_weight3, layer_norm171, model_decoder_layers_3_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape742 = R.call_tir(cls.reshape4, (lv243,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
concat35 = R.call_tir(cls.concatenate, (reshape740, reshape741, reshape742), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16"))
reshape743 = R.call_tir(cls.reshape5, (concat35,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16"))
lv140 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape743), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape744 = R.call_tir(cls.reshape6, (lv140,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape745 = R.call_tir(cls.reshape7, (reshape744,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv244 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_self_attn_out_proj_weight3, reshape745, model_decoder_layers_3_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add612 = R.call_tir(cls.add, (add608, lv244), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm172 = R.call_tir(cls.layer_norm, (add612, model_decoder_layers_3_encoder_attn_layer_norm_weight3, model_decoder_layers_3_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv245 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight3, layer_norm172, model_decoder_layers_3_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape746 = R.call_tir(cls.reshape4, (lv245,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape747 = R.call_tir(cls.reshape8, (reshape746,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
lv141 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape747), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16"))
reshape748 = R.call_tir(cls.reshape6, (lv141,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
reshape749 = R.call_tir(cls.reshape7, (reshape748,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv246 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight3, reshape749, model_decoder_layers_3_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add615 = R.call_tir(cls.add, (add612, lv246), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm173 = R.call_tir(cls.layer_norm, (add615, model_decoder_layers_3_final_layer_norm_weight3, model_decoder_layers_3_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv35 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_3_fc1_weight3, layer_norm173, model_decoder_layers_3_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16"))
lv247 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_3_fc2_weight3, lv35, model_decoder_layers_3_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
add618 = R.call_tir(cls.add, (add615, lv247), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
layer_norm174 = R.call_tir(cls.layer_norm, (add618, model_decoder_layers_4_self_attn_layer_norm_weight3, model_decoder_layers_4_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
lv248 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_q_proj_weight3, layer_norm174, model_decoder_layers_4_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
reshape750 = R.call_tir(cls.reshape4, (lv248,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16"))
lv69 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_4_self_attn_k_proj_weight3, layer_norm174), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16"))
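# NOTE (editorial annotation, not emitted by the TVMScript printer):
# cross-attention in this decode function only projects the query (e.g.
# reshape747 for layer 3 above); vm.builtin.attention_kv_cache_cross_attention
# then attends against encoder keys/values that were committed to the paged
# cache ahead of time, which is presumably why the encoder_attn k/v projection
# weights are never bound from packed_params in this function.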
reshape751 = R.call_tir(cls.reshape4, (lv69,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv249 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_v_proj_weight3, layer_norm174, model_decoder_layers_4_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape752 = R.call_tir(cls.reshape4, (lv249,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat36 = R.call_tir(cls.concatenate, (reshape750, reshape751, reshape752), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape753 = R.call_tir(cls.reshape5, (concat36,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv142 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape753), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape754 = R.call_tir(cls.reshape6, (lv142,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape755 = R.call_tir(cls.reshape7, (reshape754,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv250 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_self_attn_out_proj_weight3, reshape755, model_decoder_layers_4_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add622 = R.call_tir(cls.add, (add618, lv250), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm175 = R.call_tir(cls.layer_norm, (add622, model_decoder_layers_4_encoder_attn_layer_norm_weight3, model_decoder_layers_4_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv251 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight3, layer_norm175, model_decoder_layers_4_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape756 = R.call_tir(cls.reshape4, (lv251,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape757 = R.call_tir(cls.reshape8, (reshape756,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv143 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape757), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape758 = R.call_tir(cls.reshape6, (lv143,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape759 = R.call_tir(cls.reshape7, (reshape758,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv252 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight3, reshape759, model_decoder_layers_4_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add625 = R.call_tir(cls.add, (add622, lv252), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm176 = R.call_tir(cls.layer_norm, (add625, model_decoder_layers_4_final_layer_norm_weight3, model_decoder_layers_4_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv36 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_4_fc1_weight3, layer_norm176, model_decoder_layers_4_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv253 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_4_fc2_weight3, lv36, model_decoder_layers_4_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add628 = R.call_tir(cls.add, (add625, lv253), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm177 = R.call_tir(cls.layer_norm, (add628, model_decoder_layers_5_self_attn_layer_norm_weight3, model_decoder_layers_5_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv254 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_q_proj_weight3, layer_norm177, model_decoder_layers_5_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape760 = R.call_tir(cls.reshape4, (lv254,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv70 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_5_self_attn_k_proj_weight3, layer_norm177), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape761 = R.call_tir(cls.reshape4, (lv70,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv255 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_v_proj_weight3, layer_norm177, model_decoder_layers_5_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape762 = R.call_tir(cls.reshape4, (lv255,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat37 = R.call_tir(cls.concatenate, (reshape760, reshape761, reshape762), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape763 = R.call_tir(cls.reshape5, (concat37,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv144 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape763), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape764 = R.call_tir(cls.reshape6, (lv144,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape765 = R.call_tir(cls.reshape7, (reshape764,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv256 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_self_attn_out_proj_weight3, reshape765, model_decoder_layers_5_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add632 = R.call_tir(cls.add, (add628, lv256), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm178 = R.call_tir(cls.layer_norm, (add632, model_decoder_layers_5_encoder_attn_layer_norm_weight3, model_decoder_layers_5_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv257 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight3, layer_norm178, model_decoder_layers_5_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape766 = R.call_tir(cls.reshape4, (lv257,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape767 = R.call_tir(cls.reshape8, (reshape766,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv145 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape767), out_sinfo=R.Tensor((batch_size, 20, 64), 
dtype="float16")) reshape768 = R.call_tir(cls.reshape6, (lv145,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape769 = R.call_tir(cls.reshape7, (reshape768,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv258 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight3, reshape769, model_decoder_layers_5_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add635 = R.call_tir(cls.add, (add632, lv258), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm179 = R.call_tir(cls.layer_norm, (add635, model_decoder_layers_5_final_layer_norm_weight3, model_decoder_layers_5_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv37 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_5_fc1_weight3, layer_norm179, model_decoder_layers_5_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv259 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_5_fc2_weight3, lv37, model_decoder_layers_5_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add638 = R.call_tir(cls.add, (add635, lv259), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm180 = R.call_tir(cls.layer_norm, (add638, model_decoder_layers_6_self_attn_layer_norm_weight3, model_decoder_layers_6_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv260 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_q_proj_weight3, layer_norm180, model_decoder_layers_6_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape770 = R.call_tir(cls.reshape4, (lv260,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv71 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_6_self_attn_k_proj_weight3, layer_norm180), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape771 = R.call_tir(cls.reshape4, (lv71,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv261 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_v_proj_weight3, layer_norm180, model_decoder_layers_6_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape772 = R.call_tir(cls.reshape4, (lv261,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat38 = R.call_tir(cls.concatenate, (reshape770, reshape771, reshape772), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape773 = R.call_tir(cls.reshape5, (concat38,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv146 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape773), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape774 = R.call_tir(cls.reshape6, (lv146,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape775 = R.call_tir(cls.reshape7, (reshape774,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv262 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_self_attn_out_proj_weight3, reshape775, model_decoder_layers_6_self_attn_out_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add642 = R.call_tir(cls.add, (add638, lv262), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm181 = R.call_tir(cls.layer_norm, (add642, model_decoder_layers_6_encoder_attn_layer_norm_weight3, model_decoder_layers_6_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv263 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight3, layer_norm181, model_decoder_layers_6_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape776 = R.call_tir(cls.reshape4, (lv263,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape777 = R.call_tir(cls.reshape8, (reshape776,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv147 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape777), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape778 = R.call_tir(cls.reshape6, (lv147,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape779 = R.call_tir(cls.reshape7, (reshape778,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv264 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight3, reshape779, model_decoder_layers_6_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add645 = R.call_tir(cls.add, (add642, lv264), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm182 = R.call_tir(cls.layer_norm, (add645, model_decoder_layers_6_final_layer_norm_weight3, model_decoder_layers_6_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv38 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_6_fc1_weight3, layer_norm182, model_decoder_layers_6_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv265 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_6_fc2_weight3, lv38, model_decoder_layers_6_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add648 = R.call_tir(cls.add, (add645, lv265), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm183 = R.call_tir(cls.layer_norm, (add648, model_decoder_layers_7_self_attn_layer_norm_weight3, model_decoder_layers_7_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv266 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_q_proj_weight3, layer_norm183, model_decoder_layers_7_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape780 = R.call_tir(cls.reshape4, (lv266,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv72 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_7_self_attn_k_proj_weight3, layer_norm183), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape781 = R.call_tir(cls.reshape4, (lv72,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv267 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_v_proj_weight3, layer_norm183, model_decoder_layers_7_self_attn_v_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape782 = R.call_tir(cls.reshape4, (lv267,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat39 = R.call_tir(cls.concatenate, (reshape780, reshape781, reshape782), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape783 = R.call_tir(cls.reshape5, (concat39,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv148 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape783), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape784 = R.call_tir(cls.reshape6, (lv148,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape785 = R.call_tir(cls.reshape7, (reshape784,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv268 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_self_attn_out_proj_weight3, reshape785, model_decoder_layers_7_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add652 = R.call_tir(cls.add, (add648, lv268), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm184 = R.call_tir(cls.layer_norm, (add652, model_decoder_layers_7_encoder_attn_layer_norm_weight3, model_decoder_layers_7_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv269 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight3, layer_norm184, model_decoder_layers_7_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape786 = R.call_tir(cls.reshape4, (lv269,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape787 = R.call_tir(cls.reshape8, (reshape786,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv149 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape787), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape788 = R.call_tir(cls.reshape6, (lv149,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape789 = R.call_tir(cls.reshape7, (reshape788,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv270 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight3, reshape789, model_decoder_layers_7_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add655 = R.call_tir(cls.add, (add652, lv270), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm185 = R.call_tir(cls.layer_norm, (add655, model_decoder_layers_7_final_layer_norm_weight3, model_decoder_layers_7_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv39 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_7_fc1_weight3, layer_norm185, model_decoder_layers_7_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv271 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_7_fc2_weight3, lv39, model_decoder_layers_7_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add658 = R.call_tir(cls.add, (add655, lv271), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm186 = 
R.call_tir(cls.layer_norm, (add658, model_decoder_layers_8_self_attn_layer_norm_weight3, model_decoder_layers_8_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv272 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_q_proj_weight3, layer_norm186, model_decoder_layers_8_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape790 = R.call_tir(cls.reshape4, (lv272,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv73 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_8_self_attn_k_proj_weight3, layer_norm186), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape791 = R.call_tir(cls.reshape4, (lv73,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv273 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_v_proj_weight3, layer_norm186, model_decoder_layers_8_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape792 = R.call_tir(cls.reshape4, (lv273,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat40 = R.call_tir(cls.concatenate, (reshape790, reshape791, reshape792), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape793 = R.call_tir(cls.reshape5, (concat40,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv150 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape793), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape794 = R.call_tir(cls.reshape6, (lv150,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape795 = R.call_tir(cls.reshape7, (reshape794,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv274 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_self_attn_out_proj_weight3, reshape795, model_decoder_layers_8_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add662 = R.call_tir(cls.add, (add658, lv274), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm187 = R.call_tir(cls.layer_norm, (add662, model_decoder_layers_8_encoder_attn_layer_norm_weight3, model_decoder_layers_8_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv275 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight3, layer_norm187, model_decoder_layers_8_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape796 = R.call_tir(cls.reshape4, (lv275,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape797 = R.call_tir(cls.reshape8, (reshape796,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv151 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape797), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape798 = R.call_tir(cls.reshape6, (lv151,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape799 = R.call_tir(cls.reshape7, (reshape798,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv276 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_8_encoder_attn_out_proj_weight3, reshape799, model_decoder_layers_8_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add665 = R.call_tir(cls.add, (add662, lv276), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm188 = R.call_tir(cls.layer_norm, (add665, model_decoder_layers_8_final_layer_norm_weight3, model_decoder_layers_8_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv40 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_8_fc1_weight3, layer_norm188, model_decoder_layers_8_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv277 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_8_fc2_weight3, lv40, model_decoder_layers_8_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add668 = R.call_tir(cls.add, (add665, lv277), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm189 = R.call_tir(cls.layer_norm, (add668, model_decoder_layers_9_self_attn_layer_norm_weight3, model_decoder_layers_9_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv278 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_q_proj_weight3, layer_norm189, model_decoder_layers_9_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape800 = R.call_tir(cls.reshape4, (lv278,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv74 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_9_self_attn_k_proj_weight3, layer_norm189), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape801 = R.call_tir(cls.reshape4, (lv74,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv279 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_v_proj_weight3, layer_norm189, model_decoder_layers_9_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape802 = R.call_tir(cls.reshape4, (lv279,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat41 = R.call_tir(cls.concatenate, (reshape800, reshape801, reshape802), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape803 = R.call_tir(cls.reshape5, (concat41,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv152 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape803), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape804 = R.call_tir(cls.reshape6, (lv152,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape805 = R.call_tir(cls.reshape7, (reshape804,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv280 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_self_attn_out_proj_weight3, reshape805, model_decoder_layers_9_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add672 = R.call_tir(cls.add, (add668, lv280), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm190 = R.call_tir(cls.layer_norm, (add672, model_decoder_layers_9_encoder_attn_layer_norm_weight3, model_decoder_layers_9_encoder_attn_layer_norm_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv281 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight3, layer_norm190, model_decoder_layers_9_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape806 = R.call_tir(cls.reshape4, (lv281,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape807 = R.call_tir(cls.reshape8, (reshape806,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv153 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape807), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape808 = R.call_tir(cls.reshape6, (lv153,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape809 = R.call_tir(cls.reshape7, (reshape808,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv282 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight3, reshape809, model_decoder_layers_9_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add675 = R.call_tir(cls.add, (add672, lv282), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm191 = R.call_tir(cls.layer_norm, (add675, model_decoder_layers_9_final_layer_norm_weight3, model_decoder_layers_9_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv41 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_9_fc1_weight3, layer_norm191, model_decoder_layers_9_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv283 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_9_fc2_weight3, lv41, model_decoder_layers_9_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add678 = R.call_tir(cls.add, (add675, lv283), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm192 = R.call_tir(cls.layer_norm, (add678, model_decoder_layers_10_self_attn_layer_norm_weight3, model_decoder_layers_10_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv284 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_q_proj_weight3, layer_norm192, model_decoder_layers_10_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape810 = R.call_tir(cls.reshape4, (lv284,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv75 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_10_self_attn_k_proj_weight3, layer_norm192), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape811 = R.call_tir(cls.reshape4, (lv75,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv285 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_v_proj_weight3, layer_norm192, model_decoder_layers_10_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape812 = R.call_tir(cls.reshape4, (lv285,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat42 = R.call_tir(cls.concatenate, (reshape810, reshape811, reshape812), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape813 = 
R.call_tir(cls.reshape5, (concat42,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv154 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape813), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape814 = R.call_tir(cls.reshape6, (lv154,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape815 = R.call_tir(cls.reshape7, (reshape814,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv286 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_self_attn_out_proj_weight3, reshape815, model_decoder_layers_10_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add682 = R.call_tir(cls.add, (add678, lv286), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm193 = R.call_tir(cls.layer_norm, (add682, model_decoder_layers_10_encoder_attn_layer_norm_weight3, model_decoder_layers_10_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv287 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight3, layer_norm193, model_decoder_layers_10_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape816 = R.call_tir(cls.reshape4, (lv287,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape817 = R.call_tir(cls.reshape8, (reshape816,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv155 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape817), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape818 = R.call_tir(cls.reshape6, (lv155,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape819 = R.call_tir(cls.reshape7, (reshape818,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv288 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight3, reshape819, model_decoder_layers_10_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add685 = R.call_tir(cls.add, (add682, lv288), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm194 = R.call_tir(cls.layer_norm, (add685, model_decoder_layers_10_final_layer_norm_weight3, model_decoder_layers_10_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv42 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_10_fc1_weight3, layer_norm194, model_decoder_layers_10_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv289 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_10_fc2_weight3, lv42, model_decoder_layers_10_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add688 = R.call_tir(cls.add, (add685, lv289), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm195 = R.call_tir(cls.layer_norm, (add688, model_decoder_layers_11_self_attn_layer_norm_weight3, model_decoder_layers_11_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv290 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_11_self_attn_q_proj_weight3, layer_norm195, model_decoder_layers_11_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape820 = R.call_tir(cls.reshape4, (lv290,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv76 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_11_self_attn_k_proj_weight3, layer_norm195), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape821 = R.call_tir(cls.reshape4, (lv76,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv291 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_self_attn_v_proj_weight3, layer_norm195, model_decoder_layers_11_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape822 = R.call_tir(cls.reshape4, (lv291,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat43 = R.call_tir(cls.concatenate, (reshape820, reshape821, reshape822), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape823 = R.call_tir(cls.reshape5, (concat43,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv156 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape823), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape824 = R.call_tir(cls.reshape6, (lv156,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape825 = R.call_tir(cls.reshape7, (reshape824,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv292 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_self_attn_out_proj_weight3, reshape825, model_decoder_layers_11_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add692 = R.call_tir(cls.add, (add688, lv292), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm196 = R.call_tir(cls.layer_norm, (add692, model_decoder_layers_11_encoder_attn_layer_norm_weight3, model_decoder_layers_11_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv293 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight3, layer_norm196, model_decoder_layers_11_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape826 = R.call_tir(cls.reshape4, (lv293,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape827 = R.call_tir(cls.reshape8, (reshape826,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv157 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape827), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape828 = R.call_tir(cls.reshape6, (lv157,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape829 = R.call_tir(cls.reshape7, (reshape828,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv294 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight3, reshape829, model_decoder_layers_11_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add695 = R.call_tir(cls.add, (add692, lv294), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) 
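# ----------------------------------------------------------------------
# Reader's note (not part of the traced module): the reshape helpers
# around every fused-QKV self-attention call above implement pure shape
# bookkeeping, read directly off the out_sinfo annotations. Q, K, and V
# are each reshaped to (B, 1, 20, 64), concatenated on the head axis to
# (B, 1, 60, 64), and flattened to (B, 60, 64) before being handed to
# "vm.builtin.attention_kv_cache_attention_with_fused_qkv" together with
# the layer index; the builtin returns (B, 20, 64), which reshape6 and
# reshape7 turn back into (B, 1, 1280). A NumPy walk-through of just the
# reshapes (the attention itself is stubbed; B is illustrative):
import numpy as np

B, NH, HD = 2, 20, 64                             # batch, heads, head dim
q = np.zeros((B, 1, NH * HD), dtype=np.float16)   # q_proj output, (B, 1, 1280)
k = np.zeros((B, 1, NH * HD), dtype=np.float16)
v = np.zeros((B, 1, NH * HD), dtype=np.float16)

q4 = q.reshape(B, 1, NH, HD)                  # cls.reshape4
k4 = k.reshape(B, 1, NH, HD)
v4 = v.reshape(B, 1, NH, HD)
qkv = np.concatenate([q4, k4, v4], axis=2)    # cls.concatenate -> (B, 1, 60, 64)
packed = qkv.reshape(B, 3 * NH, HD)           # cls.reshape5    -> (B, 60, 64)

# Inside the builtin the 60 rows are read back as 20 Q / 20 K / 20 V
# heads; the attention result is stubbed here.
q_h, k_h, v_h = np.split(packed, 3, axis=1)   # 3 x (B, 20, 64)
out = q_h                                     # stand-in for the attention output

out6 = out.reshape(B, 1, NH, HD)              # cls.reshape6 -> (B, 1, 20, 64)
out7 = out6.reshape(B, 1, NH * HD)            # cls.reshape7 -> (B, 1, 1280)
assert out7.shape == (B, 1, 1280)

# Cross-attention projects only Q each step (encoder K/V are static in
# the cache), so cls.reshape8 simply drops the length-1 sequence axis:
q_cross = np.zeros((B, 1, NH, HD), dtype=np.float16).reshape(B, NH, HD)
assert q_cross.shape == (B, 20, 64)
# ----------------------------------------------------------------------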
layer_norm197 = R.call_tir(cls.layer_norm, (add695, model_decoder_layers_11_final_layer_norm_weight3, model_decoder_layers_11_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv43 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_11_fc1_weight3, layer_norm197, model_decoder_layers_11_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv295 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_11_fc2_weight3, lv43, model_decoder_layers_11_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add698 = R.call_tir(cls.add, (add695, lv295), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm198 = R.call_tir(cls.layer_norm, (add698, model_decoder_layers_12_self_attn_layer_norm_weight3, model_decoder_layers_12_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv296 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_q_proj_weight3, layer_norm198, model_decoder_layers_12_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape830 = R.call_tir(cls.reshape4, (lv296,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv77 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_12_self_attn_k_proj_weight3, layer_norm198), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape831 = R.call_tir(cls.reshape4, (lv77,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv297 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_v_proj_weight3, layer_norm198, model_decoder_layers_12_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape832 = R.call_tir(cls.reshape4, (lv297,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat44 = R.call_tir(cls.concatenate, (reshape830, reshape831, reshape832), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape833 = R.call_tir(cls.reshape5, (concat44,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv158 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape833), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape834 = R.call_tir(cls.reshape6, (lv158,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape835 = R.call_tir(cls.reshape7, (reshape834,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv298 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_self_attn_out_proj_weight3, reshape835, model_decoder_layers_12_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add702 = R.call_tir(cls.add, (add698, lv298), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm199 = R.call_tir(cls.layer_norm, (add702, model_decoder_layers_12_encoder_attn_layer_norm_weight3, model_decoder_layers_12_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv299 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight3, layer_norm199, model_decoder_layers_12_encoder_attn_q_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape836 = R.call_tir(cls.reshape4, (lv299,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape837 = R.call_tir(cls.reshape8, (reshape836,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv159 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape837), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape838 = R.call_tir(cls.reshape6, (lv159,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape839 = R.call_tir(cls.reshape7, (reshape838,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv300 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight3, reshape839, model_decoder_layers_12_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add705 = R.call_tir(cls.add, (add702, lv300), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm200 = R.call_tir(cls.layer_norm, (add705, model_decoder_layers_12_final_layer_norm_weight3, model_decoder_layers_12_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv44 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_12_fc1_weight3, layer_norm200, model_decoder_layers_12_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv301 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_12_fc2_weight3, lv44, model_decoder_layers_12_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add708 = R.call_tir(cls.add, (add705, lv301), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm201 = R.call_tir(cls.layer_norm, (add708, model_decoder_layers_13_self_attn_layer_norm_weight3, model_decoder_layers_13_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv302 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_q_proj_weight3, layer_norm201, model_decoder_layers_13_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape840 = R.call_tir(cls.reshape4, (lv302,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv78 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_13_self_attn_k_proj_weight3, layer_norm201), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape841 = R.call_tir(cls.reshape4, (lv78,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv303 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_v_proj_weight3, layer_norm201, model_decoder_layers_13_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape842 = R.call_tir(cls.reshape4, (lv303,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat45 = R.call_tir(cls.concatenate, (reshape840, reshape841, reshape842), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape843 = R.call_tir(cls.reshape5, (concat45,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv160 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape843), 
out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape844 = R.call_tir(cls.reshape6, (lv160,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape845 = R.call_tir(cls.reshape7, (reshape844,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv304 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_self_attn_out_proj_weight3, reshape845, model_decoder_layers_13_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add712 = R.call_tir(cls.add, (add708, lv304), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm202 = R.call_tir(cls.layer_norm, (add712, model_decoder_layers_13_encoder_attn_layer_norm_weight3, model_decoder_layers_13_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv305 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight3, layer_norm202, model_decoder_layers_13_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape846 = R.call_tir(cls.reshape4, (lv305,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape847 = R.call_tir(cls.reshape8, (reshape846,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv161 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape847), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape848 = R.call_tir(cls.reshape6, (lv161,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape849 = R.call_tir(cls.reshape7, (reshape848,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv306 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight3, reshape849, model_decoder_layers_13_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add715 = R.call_tir(cls.add, (add712, lv306), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm203 = R.call_tir(cls.layer_norm, (add715, model_decoder_layers_13_final_layer_norm_weight3, model_decoder_layers_13_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv45 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_13_fc1_weight3, layer_norm203, model_decoder_layers_13_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv307 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_13_fc2_weight3, lv45, model_decoder_layers_13_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add718 = R.call_tir(cls.add, (add715, lv307), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm204 = R.call_tir(cls.layer_norm, (add718, model_decoder_layers_14_self_attn_layer_norm_weight3, model_decoder_layers_14_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv308 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_q_proj_weight3, layer_norm204, model_decoder_layers_14_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape850 = R.call_tir(cls.reshape4, (lv308,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv79 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_14_self_attn_k_proj_weight3, layer_norm204), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape851 = R.call_tir(cls.reshape4, (lv79,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv309 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_v_proj_weight3, layer_norm204, model_decoder_layers_14_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape852 = R.call_tir(cls.reshape4, (lv309,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat46 = R.call_tir(cls.concatenate, (reshape850, reshape851, reshape852), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape853 = R.call_tir(cls.reshape5, (concat46,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv162 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape853), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape854 = R.call_tir(cls.reshape6, (lv162,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape855 = R.call_tir(cls.reshape7, (reshape854,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv310 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_self_attn_out_proj_weight3, reshape855, model_decoder_layers_14_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add722 = R.call_tir(cls.add, (add718, lv310), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm205 = R.call_tir(cls.layer_norm, (add722, model_decoder_layers_14_encoder_attn_layer_norm_weight3, model_decoder_layers_14_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv311 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight3, layer_norm205, model_decoder_layers_14_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape856 = R.call_tir(cls.reshape4, (lv311,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape857 = R.call_tir(cls.reshape8, (reshape856,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv163 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape857), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape858 = R.call_tir(cls.reshape6, (lv163,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape859 = R.call_tir(cls.reshape7, (reshape858,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv312 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_14_encoder_attn_out_proj_weight3, reshape859, model_decoder_layers_14_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add725 = R.call_tir(cls.add, (add722, lv312), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm206 = R.call_tir(cls.layer_norm, (add725, model_decoder_layers_14_final_layer_norm_weight3, model_decoder_layers_14_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv46 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_14_fc1_weight3, layer_norm206, model_decoder_layers_14_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv313 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_14_fc2_weight3, lv46, model_decoder_layers_14_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add728 = R.call_tir(cls.add, (add725, lv313), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm207 = R.call_tir(cls.layer_norm, (add728, model_decoder_layers_15_self_attn_layer_norm_weight3, model_decoder_layers_15_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv314 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_q_proj_weight3, layer_norm207, model_decoder_layers_15_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape860 = R.call_tir(cls.reshape4, (lv314,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv80 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_15_self_attn_k_proj_weight3, layer_norm207), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape861 = R.call_tir(cls.reshape4, (lv80,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv315 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_v_proj_weight3, layer_norm207, model_decoder_layers_15_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape862 = R.call_tir(cls.reshape4, (lv315,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat47 = R.call_tir(cls.concatenate, (reshape860, reshape861, reshape862), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape863 = R.call_tir(cls.reshape5, (concat47,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv164 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape863), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape864 = R.call_tir(cls.reshape6, (lv164,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape865 = R.call_tir(cls.reshape7, (reshape864,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv316 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_self_attn_out_proj_weight3, reshape865, model_decoder_layers_15_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add732 = R.call_tir(cls.add, (add728, lv316), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm208 = R.call_tir(cls.layer_norm, (add732, model_decoder_layers_15_encoder_attn_layer_norm_weight3, model_decoder_layers_15_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv317 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight3, layer_norm208, model_decoder_layers_15_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape866 = R.call_tir(cls.reshape4, (lv317,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape867 = R.call_tir(cls.reshape8, (reshape866,), 
out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv165 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape867), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape868 = R.call_tir(cls.reshape6, (lv165,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape869 = R.call_tir(cls.reshape7, (reshape868,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv318 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight3, reshape869, model_decoder_layers_15_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add735 = R.call_tir(cls.add, (add732, lv318), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm209 = R.call_tir(cls.layer_norm, (add735, model_decoder_layers_15_final_layer_norm_weight3, model_decoder_layers_15_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv47 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_15_fc1_weight3, layer_norm209, model_decoder_layers_15_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv319 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_15_fc2_weight3, lv47, model_decoder_layers_15_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add738 = R.call_tir(cls.add, (add735, lv319), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm210 = R.call_tir(cls.layer_norm, (add738, model_decoder_layers_16_self_attn_layer_norm_weight3, model_decoder_layers_16_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv320 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_q_proj_weight3, layer_norm210, model_decoder_layers_16_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape870 = R.call_tir(cls.reshape4, (lv320,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv81 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_16_self_attn_k_proj_weight3, layer_norm210), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape871 = R.call_tir(cls.reshape4, (lv81,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv321 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_v_proj_weight3, layer_norm210, model_decoder_layers_16_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape872 = R.call_tir(cls.reshape4, (lv321,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat48 = R.call_tir(cls.concatenate, (reshape870, reshape871, reshape872), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape873 = R.call_tir(cls.reshape5, (concat48,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv166 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape873), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape874 = R.call_tir(cls.reshape6, (lv166,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape875 = R.call_tir(cls.reshape7, (reshape874,), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv322 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_self_attn_out_proj_weight3, reshape875, model_decoder_layers_16_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add742 = R.call_tir(cls.add, (add738, lv322), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm211 = R.call_tir(cls.layer_norm, (add742, model_decoder_layers_16_encoder_attn_layer_norm_weight3, model_decoder_layers_16_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv323 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight3, layer_norm211, model_decoder_layers_16_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape876 = R.call_tir(cls.reshape4, (lv323,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape877 = R.call_tir(cls.reshape8, (reshape876,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv167 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape877), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape878 = R.call_tir(cls.reshape6, (lv167,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape879 = R.call_tir(cls.reshape7, (reshape878,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv324 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_16_encoder_attn_out_proj_weight3, reshape879, model_decoder_layers_16_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add745 = R.call_tir(cls.add, (add742, lv324), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm212 = R.call_tir(cls.layer_norm, (add745, model_decoder_layers_16_final_layer_norm_weight3, model_decoder_layers_16_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv48 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_16_fc1_weight3, layer_norm212, model_decoder_layers_16_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv325 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_16_fc2_weight3, lv48, model_decoder_layers_16_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add748 = R.call_tir(cls.add, (add745, lv325), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm213 = R.call_tir(cls.layer_norm, (add748, model_decoder_layers_17_self_attn_layer_norm_weight3, model_decoder_layers_17_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv326 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_q_proj_weight3, layer_norm213, model_decoder_layers_17_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape880 = R.call_tir(cls.reshape4, (lv326,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv82 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_17_self_attn_k_proj_weight3, layer_norm213), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape881 = 
R.call_tir(cls.reshape4, (lv82,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv327 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_v_proj_weight3, layer_norm213, model_decoder_layers_17_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape882 = R.call_tir(cls.reshape4, (lv327,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat49 = R.call_tir(cls.concatenate, (reshape880, reshape881, reshape882), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape883 = R.call_tir(cls.reshape5, (concat49,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv168 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape883), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape884 = R.call_tir(cls.reshape6, (lv168,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape885 = R.call_tir(cls.reshape7, (reshape884,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv328 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_self_attn_out_proj_weight3, reshape885, model_decoder_layers_17_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add752 = R.call_tir(cls.add, (add748, lv328), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm214 = R.call_tir(cls.layer_norm, (add752, model_decoder_layers_17_encoder_attn_layer_norm_weight3, model_decoder_layers_17_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv329 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight3, layer_norm214, model_decoder_layers_17_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape886 = R.call_tir(cls.reshape4, (lv329,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape887 = R.call_tir(cls.reshape8, (reshape886,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv169 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape887), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape888 = R.call_tir(cls.reshape6, (lv169,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape889 = R.call_tir(cls.reshape7, (reshape888,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv330 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight3, reshape889, model_decoder_layers_17_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add755 = R.call_tir(cls.add, (add752, lv330), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm215 = R.call_tir(cls.layer_norm, (add755, model_decoder_layers_17_final_layer_norm_weight3, model_decoder_layers_17_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv49 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_17_fc1_weight3, layer_norm215, model_decoder_layers_17_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv331 = 
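# Self-attention uses the fused-QKV convention: concat49 packs the three
# (batch_size, 1, 20, 64) projections into (batch_size, 1, 60, 64),
# reshape883 flattens that to (batch_size, 60, 64), and a single builtin
# call both appends this step's K/V to the cache and returns the
# (batch_size, 20, 64) attention output.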
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_17_fc2_weight3, lv49, model_decoder_layers_17_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add758 = R.call_tir(cls.add, (add755, lv331), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm216 = R.call_tir(cls.layer_norm, (add758, model_decoder_layers_18_self_attn_layer_norm_weight3, model_decoder_layers_18_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv332 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_q_proj_weight3, layer_norm216, model_decoder_layers_18_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape890 = R.call_tir(cls.reshape4, (lv332,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv83 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_18_self_attn_k_proj_weight3, layer_norm216), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape891 = R.call_tir(cls.reshape4, (lv83,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv333 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_v_proj_weight3, layer_norm216, model_decoder_layers_18_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape892 = R.call_tir(cls.reshape4, (lv333,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat50 = R.call_tir(cls.concatenate, (reshape890, reshape891, reshape892), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape893 = R.call_tir(cls.reshape5, (concat50,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv170 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape893), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape894 = R.call_tir(cls.reshape6, (lv170,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape895 = R.call_tir(cls.reshape7, (reshape894,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv334 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_self_attn_out_proj_weight3, reshape895, model_decoder_layers_18_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add762 = R.call_tir(cls.add, (add758, lv334), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm217 = R.call_tir(cls.layer_norm, (add762, model_decoder_layers_18_encoder_attn_layer_norm_weight3, model_decoder_layers_18_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv335 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_encoder_attn_q_proj_weight3, layer_norm217, model_decoder_layers_18_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape896 = R.call_tir(cls.reshape4, (lv335,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape897 = R.call_tir(cls.reshape8, (reshape896,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv171 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape897), out_sinfo=R.Tensor((batch_size, 20, 64), 
dtype="float16")) reshape898 = R.call_tir(cls.reshape6, (lv171,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape899 = R.call_tir(cls.reshape7, (reshape898,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv336 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight3, reshape899, model_decoder_layers_18_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add765 = R.call_tir(cls.add, (add762, lv336), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm218 = R.call_tir(cls.layer_norm, (add765, model_decoder_layers_18_final_layer_norm_weight3, model_decoder_layers_18_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv50 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_18_fc1_weight3, layer_norm218, model_decoder_layers_18_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv337 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_18_fc2_weight3, lv50, model_decoder_layers_18_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add768 = R.call_tir(cls.add, (add765, lv337), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm219 = R.call_tir(cls.layer_norm, (add768, model_decoder_layers_19_self_attn_layer_norm_weight3, model_decoder_layers_19_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv338 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_q_proj_weight3, layer_norm219, model_decoder_layers_19_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape900 = R.call_tir(cls.reshape4, (lv338,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv84 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_19_self_attn_k_proj_weight3, layer_norm219), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape901 = R.call_tir(cls.reshape4, (lv84,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv339 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_v_proj_weight3, layer_norm219, model_decoder_layers_19_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape902 = R.call_tir(cls.reshape4, (lv339,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat51 = R.call_tir(cls.concatenate, (reshape900, reshape901, reshape902), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape903 = R.call_tir(cls.reshape5, (concat51,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv172 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape903), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape904 = R.call_tir(cls.reshape6, (lv172,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape905 = R.call_tir(cls.reshape7, (reshape904,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv340 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_self_attn_out_proj_weight3, reshape905, 
model_decoder_layers_19_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add772 = R.call_tir(cls.add, (add768, lv340), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm220 = R.call_tir(cls.layer_norm, (add772, model_decoder_layers_19_encoder_attn_layer_norm_weight3, model_decoder_layers_19_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv341 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight3, layer_norm220, model_decoder_layers_19_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape906 = R.call_tir(cls.reshape4, (lv341,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape907 = R.call_tir(cls.reshape8, (reshape906,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv173 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape907), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape908 = R.call_tir(cls.reshape6, (lv173,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape909 = R.call_tir(cls.reshape7, (reshape908,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv342 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight3, reshape909, model_decoder_layers_19_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add775 = R.call_tir(cls.add, (add772, lv342), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm221 = R.call_tir(cls.layer_norm, (add775, model_decoder_layers_19_final_layer_norm_weight3, model_decoder_layers_19_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv51 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_19_fc1_weight3, layer_norm221, model_decoder_layers_19_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv343 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_19_fc2_weight3, lv51, model_decoder_layers_19_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add778 = R.call_tir(cls.add, (add775, lv343), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm222 = R.call_tir(cls.layer_norm, (add778, model_decoder_layers_20_self_attn_layer_norm_weight3, model_decoder_layers_20_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv344 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_self_attn_q_proj_weight3, layer_norm222, model_decoder_layers_20_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape910 = R.call_tir(cls.reshape4, (lv344,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv85 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_20_self_attn_k_proj_weight3, layer_norm222), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape911 = R.call_tir(cls.reshape4, (lv85,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv345 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
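# In each fused_..._cublas extern the weight comes first; the permute_dims
# stage in the fused name indicates the (out, in)-layout weight is
# transposed, so the kernel presumably computes x @ W^T (+ bias, plus GELU
# for fc1) in one offloaded call.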
(model_decoder_layers_20_self_attn_v_proj_weight3, layer_norm222, model_decoder_layers_20_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape912 = R.call_tir(cls.reshape4, (lv345,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat52 = R.call_tir(cls.concatenate, (reshape910, reshape911, reshape912), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape913 = R.call_tir(cls.reshape5, (concat52,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv174 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape913), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape914 = R.call_tir(cls.reshape6, (lv174,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape915 = R.call_tir(cls.reshape7, (reshape914,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv346 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_self_attn_out_proj_weight3, reshape915, model_decoder_layers_20_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add782 = R.call_tir(cls.add, (add778, lv346), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm223 = R.call_tir(cls.layer_norm, (add782, model_decoder_layers_20_encoder_attn_layer_norm_weight3, model_decoder_layers_20_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv347 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight3, layer_norm223, model_decoder_layers_20_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape916 = R.call_tir(cls.reshape4, (lv347,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape917 = R.call_tir(cls.reshape8, (reshape916,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv175 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape917), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape918 = R.call_tir(cls.reshape6, (lv175,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape919 = R.call_tir(cls.reshape7, (reshape918,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv348 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight3, reshape919, model_decoder_layers_20_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add785 = R.call_tir(cls.add, (add782, lv348), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm224 = R.call_tir(cls.layer_norm, (add785, model_decoder_layers_20_final_layer_norm_weight3, model_decoder_layers_20_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv52 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_20_fc1_weight3, layer_norm224, model_decoder_layers_20_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv349 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_20_fc2_weight3, lv52, model_decoder_layers_20_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add788 = 
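# Reshape helpers, with shapes read off their out_sinfo (b = batch_size):
#   reshape4: (b, 1, 1280)   -> (b, 1, 20, 64)   split heads
#   reshape5: (b, 1, 60, 64) -> (b, 60, 64)      drop seq dim for fused QKV
#   reshape6: (b, 20, 64)    -> (b, 1, 20, 64)   restore seq dim
#   reshape7: (b, 1, 20, 64) -> (b, 1, 1280)     merge heads
#   reshape8: (b, 1, 20, 64) -> (b, 20, 64)      drop seq dim for cross-attn Q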
R.call_tir(cls.add, (add785, lv349), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm225 = R.call_tir(cls.layer_norm, (add788, model_decoder_layers_21_self_attn_layer_norm_weight3, model_decoder_layers_21_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv350 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_q_proj_weight3, layer_norm225, model_decoder_layers_21_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape920 = R.call_tir(cls.reshape4, (lv350,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv86 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_21_self_attn_k_proj_weight3, layer_norm225), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape921 = R.call_tir(cls.reshape4, (lv86,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv351 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_v_proj_weight3, layer_norm225, model_decoder_layers_21_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape922 = R.call_tir(cls.reshape4, (lv351,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat53 = R.call_tir(cls.concatenate, (reshape920, reshape921, reshape922), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape923 = R.call_tir(cls.reshape5, (concat53,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv176 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape923), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape924 = R.call_tir(cls.reshape6, (lv176,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape925 = R.call_tir(cls.reshape7, (reshape924,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv352 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_self_attn_out_proj_weight3, reshape925, model_decoder_layers_21_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add792 = R.call_tir(cls.add, (add788, lv352), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm226 = R.call_tir(cls.layer_norm, (add792, model_decoder_layers_21_encoder_attn_layer_norm_weight3, model_decoder_layers_21_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv353 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight3, layer_norm226, model_decoder_layers_21_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape926 = R.call_tir(cls.reshape4, (lv353,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape927 = R.call_tir(cls.reshape8, (reshape926,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv177 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape927), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape928 = R.call_tir(cls.reshape6, (lv177,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape929 = R.call_tir(cls.reshape7, (reshape928,), out_sinfo=R.Tensor((batch_size, 1, 1280), 
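# The addNNN tensors form the residual stream: each block folds in three
# branches, e.g. add792 (self-attn) -> add795 (cross-attn) -> add798 (FFN)
# for layer 21, with a LayerNorm ahead of every branch (pre-LN ordering).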
dtype="float16")) lv354 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight3, reshape929, model_decoder_layers_21_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add795 = R.call_tir(cls.add, (add792, lv354), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm227 = R.call_tir(cls.layer_norm, (add795, model_decoder_layers_21_final_layer_norm_weight3, model_decoder_layers_21_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv53 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_21_fc1_weight3, layer_norm227, model_decoder_layers_21_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv355 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_21_fc2_weight3, lv53, model_decoder_layers_21_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add798 = R.call_tir(cls.add, (add795, lv355), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm228 = R.call_tir(cls.layer_norm, (add798, model_decoder_layers_22_self_attn_layer_norm_weight3, model_decoder_layers_22_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv356 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_q_proj_weight3, layer_norm228, model_decoder_layers_22_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape930 = R.call_tir(cls.reshape4, (lv356,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv87 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_22_self_attn_k_proj_weight3, layer_norm228), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape931 = R.call_tir(cls.reshape4, (lv87,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv357 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_v_proj_weight3, layer_norm228, model_decoder_layers_22_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape932 = R.call_tir(cls.reshape4, (lv357,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat54 = R.call_tir(cls.concatenate, (reshape930, reshape931, reshape932), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape933 = R.call_tir(cls.reshape5, (concat54,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv178 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape933), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape934 = R.call_tir(cls.reshape6, (lv178,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape935 = R.call_tir(cls.reshape7, (reshape934,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv358 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_self_attn_out_proj_weight3, reshape935, model_decoder_layers_22_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add802 = R.call_tir(cls.add, (add798, lv358), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm229 = R.call_tir(cls.layer_norm, 
(add802, model_decoder_layers_22_encoder_attn_layer_norm_weight3, model_decoder_layers_22_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv359 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight3, layer_norm229, model_decoder_layers_22_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape936 = R.call_tir(cls.reshape4, (lv359,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape937 = R.call_tir(cls.reshape8, (reshape936,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv179 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape937), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape938 = R.call_tir(cls.reshape6, (lv179,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape939 = R.call_tir(cls.reshape7, (reshape938,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv360 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight3, reshape939, model_decoder_layers_22_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add805 = R.call_tir(cls.add, (add802, lv360), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm230 = R.call_tir(cls.layer_norm, (add805, model_decoder_layers_22_final_layer_norm_weight3, model_decoder_layers_22_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv54 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_22_fc1_weight3, layer_norm230, model_decoder_layers_22_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv361 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_22_fc2_weight3, lv54, model_decoder_layers_22_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add808 = R.call_tir(cls.add, (add805, lv361), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm231 = R.call_tir(cls.layer_norm, (add808, model_decoder_layers_23_self_attn_layer_norm_weight3, model_decoder_layers_23_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv362 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_q_proj_weight3, layer_norm231, model_decoder_layers_23_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape940 = R.call_tir(cls.reshape4, (lv362,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv88 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_23_self_attn_k_proj_weight3, layer_norm231), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape941 = R.call_tir(cls.reshape4, (lv88,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv363 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_v_proj_weight3, layer_norm231, model_decoder_layers_23_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape942 = R.call_tir(cls.reshape4, (lv363,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat55 = 
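# The sequence dimension is 1 in every activation ((batch_size, 1, ...)):
# this function is the single-token batched decode step, and all past
# context is supplied by paged_kv_cache rather than by the activations.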
R.call_tir(cls.concatenate, (reshape940, reshape941, reshape942), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape943 = R.call_tir(cls.reshape5, (concat55,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv180 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape943), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape944 = R.call_tir(cls.reshape6, (lv180,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape945 = R.call_tir(cls.reshape7, (reshape944,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv364 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_self_attn_out_proj_weight3, reshape945, model_decoder_layers_23_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add812 = R.call_tir(cls.add, (add808, lv364), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm232 = R.call_tir(cls.layer_norm, (add812, model_decoder_layers_23_encoder_attn_layer_norm_weight3, model_decoder_layers_23_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv365 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight3, layer_norm232, model_decoder_layers_23_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape946 = R.call_tir(cls.reshape4, (lv365,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape947 = R.call_tir(cls.reshape8, (reshape946,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv181 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape947), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape948 = R.call_tir(cls.reshape6, (lv181,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape949 = R.call_tir(cls.reshape7, (reshape948,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv366 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight3, reshape949, model_decoder_layers_23_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add815 = R.call_tir(cls.add, (add812, lv366), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm233 = R.call_tir(cls.layer_norm, (add815, model_decoder_layers_23_final_layer_norm_weight3, model_decoder_layers_23_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv55 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_23_fc1_weight3, layer_norm233, model_decoder_layers_23_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv367 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_23_fc2_weight3, lv55, model_decoder_layers_23_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add818 = R.call_tir(cls.add, (add815, lv367), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm234 = R.call_tir(cls.layer_norm, (add818, model_decoder_layers_24_self_attn_layer_norm_weight3, model_decoder_layers_24_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 
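# Value names are SSA-style with module-wide counters (reshapeNNN, addNNN,
# lvNNN keep increasing across layers rather than resetting per layer),
# which is why layer 23 is already at reshape94x/add81x.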
1280), dtype="float16")) lv368 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_q_proj_weight3, layer_norm234, model_decoder_layers_24_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape950 = R.call_tir(cls.reshape4, (lv368,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv89 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_24_self_attn_k_proj_weight3, layer_norm234), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape951 = R.call_tir(cls.reshape4, (lv89,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv369 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_v_proj_weight3, layer_norm234, model_decoder_layers_24_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape952 = R.call_tir(cls.reshape4, (lv369,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat56 = R.call_tir(cls.concatenate, (reshape950, reshape951, reshape952), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape953 = R.call_tir(cls.reshape5, (concat56,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv182 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape953), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape954 = R.call_tir(cls.reshape6, (lv182,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape955 = R.call_tir(cls.reshape7, (reshape954,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv370 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_self_attn_out_proj_weight3, reshape955, model_decoder_layers_24_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add822 = R.call_tir(cls.add, (add818, lv370), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm235 = R.call_tir(cls.layer_norm, (add822, model_decoder_layers_24_encoder_attn_layer_norm_weight3, model_decoder_layers_24_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv371 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight3, layer_norm235, model_decoder_layers_24_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape956 = R.call_tir(cls.reshape4, (lv371,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape957 = R.call_tir(cls.reshape8, (reshape956,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv183 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape957), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape958 = R.call_tir(cls.reshape6, (lv183,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape959 = R.call_tir(cls.reshape7, (reshape958,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv372 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight3, reshape959, model_decoder_layers_24_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add825 = 
R.call_tir(cls.add, (add822, lv372), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm236 = R.call_tir(cls.layer_norm, (add825, model_decoder_layers_24_final_layer_norm_weight3, model_decoder_layers_24_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv56 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_24_fc1_weight3, layer_norm236, model_decoder_layers_24_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv373 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_24_fc2_weight3, lv56, model_decoder_layers_24_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add828 = R.call_tir(cls.add, (add825, lv373), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm237 = R.call_tir(cls.layer_norm, (add828, model_decoder_layers_25_self_attn_layer_norm_weight3, model_decoder_layers_25_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv374 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_q_proj_weight3, layer_norm237, model_decoder_layers_25_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape960 = R.call_tir(cls.reshape4, (lv374,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv90 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_25_self_attn_k_proj_weight3, layer_norm237), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape961 = R.call_tir(cls.reshape4, (lv90,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv375 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_v_proj_weight3, layer_norm237, model_decoder_layers_25_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape962 = R.call_tir(cls.reshape4, (lv375,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat57 = R.call_tir(cls.concatenate, (reshape960, reshape961, reshape962), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape963 = R.call_tir(cls.reshape5, (concat57,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape963), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape964 = R.call_tir(cls.reshape6, (lv184,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape965 = R.call_tir(cls.reshape7, (reshape964,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv376 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_self_attn_out_proj_weight3, reshape965, model_decoder_layers_25_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add832 = R.call_tir(cls.add, (add828, lv376), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm238 = R.call_tir(cls.layer_norm, (add832, model_decoder_layers_25_encoder_attn_layer_norm_weight3, model_decoder_layers_25_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv377 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", 
(model_decoder_layers_25_encoder_attn_q_proj_weight3, layer_norm238, model_decoder_layers_25_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape966 = R.call_tir(cls.reshape4, (lv377,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape967 = R.call_tir(cls.reshape8, (reshape966,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv185 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape967), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape968 = R.call_tir(cls.reshape6, (lv185,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape969 = R.call_tir(cls.reshape7, (reshape968,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv378 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_25_encoder_attn_out_proj_weight3, reshape969, model_decoder_layers_25_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add835 = R.call_tir(cls.add, (add832, lv378), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm239 = R.call_tir(cls.layer_norm, (add835, model_decoder_layers_25_final_layer_norm_weight3, model_decoder_layers_25_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv57 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_25_fc1_weight3, layer_norm239, model_decoder_layers_25_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv379 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_25_fc2_weight3, lv57, model_decoder_layers_25_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add838 = R.call_tir(cls.add, (add835, lv379), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm240 = R.call_tir(cls.layer_norm, (add838, model_decoder_layers_26_self_attn_layer_norm_weight3, model_decoder_layers_26_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv380 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_q_proj_weight3, layer_norm240, model_decoder_layers_26_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape970 = R.call_tir(cls.reshape4, (lv380,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv91 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_26_self_attn_k_proj_weight3, layer_norm240), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape971 = R.call_tir(cls.reshape4, (lv91,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv381 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_v_proj_weight3, layer_norm240, model_decoder_layers_26_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape972 = R.call_tir(cls.reshape4, (lv381,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat58 = R.call_tir(cls.concatenate, (reshape970, reshape971, reshape972), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape973 = R.call_tir(cls.reshape5, (concat58,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv186 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape973), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape974 = R.call_tir(cls.reshape6, (lv186,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape975 = R.call_tir(cls.reshape7, (reshape974,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv382 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_self_attn_out_proj_weight3, reshape975, model_decoder_layers_26_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add842 = R.call_tir(cls.add, (add838, lv382), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm241 = R.call_tir(cls.layer_norm, (add842, model_decoder_layers_26_encoder_attn_layer_norm_weight3, model_decoder_layers_26_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv383 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight3, layer_norm241, model_decoder_layers_26_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape976 = R.call_tir(cls.reshape4, (lv383,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape977 = R.call_tir(cls.reshape8, (reshape976,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv187 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape977), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape978 = R.call_tir(cls.reshape6, (lv187,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape979 = R.call_tir(cls.reshape7, (reshape978,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv384 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_26_encoder_attn_out_proj_weight3, reshape979, model_decoder_layers_26_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add845 = R.call_tir(cls.add, (add842, lv384), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm242 = R.call_tir(cls.layer_norm, (add845, model_decoder_layers_26_final_layer_norm_weight3, model_decoder_layers_26_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv58 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_26_fc1_weight3, layer_norm242, model_decoder_layers_26_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv385 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_26_fc2_weight3, lv58, model_decoder_layers_26_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add848 = R.call_tir(cls.add, (add845, lv385), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm243 = R.call_tir(cls.layer_norm, (add848, model_decoder_layers_27_self_attn_layer_norm_weight3, model_decoder_layers_27_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv386 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_q_proj_weight3, layer_norm243, model_decoder_layers_27_self_attn_q_proj_bias3), 
out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape980 = R.call_tir(cls.reshape4, (lv386,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv92 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_27_self_attn_k_proj_weight3, layer_norm243), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape981 = R.call_tir(cls.reshape4, (lv92,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv387 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_v_proj_weight3, layer_norm243, model_decoder_layers_27_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape982 = R.call_tir(cls.reshape4, (lv387,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat59 = R.call_tir(cls.concatenate, (reshape980, reshape981, reshape982), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape983 = R.call_tir(cls.reshape5, (concat59,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv188 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape983), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape984 = R.call_tir(cls.reshape6, (lv188,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape985 = R.call_tir(cls.reshape7, (reshape984,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv388 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_self_attn_out_proj_weight3, reshape985, model_decoder_layers_27_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add852 = R.call_tir(cls.add, (add848, lv388), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm244 = R.call_tir(cls.layer_norm, (add852, model_decoder_layers_27_encoder_attn_layer_norm_weight3, model_decoder_layers_27_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv389 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight3, layer_norm244, model_decoder_layers_27_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape986 = R.call_tir(cls.reshape4, (lv389,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape987 = R.call_tir(cls.reshape8, (reshape986,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape987), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape988 = R.call_tir(cls.reshape6, (lv189,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape989 = R.call_tir(cls.reshape7, (reshape988,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv390 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight3, reshape989, model_decoder_layers_27_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add855 = R.call_tir(cls.add, (add852, lv390), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm245 = R.call_tir(cls.layer_norm, (add855, model_decoder_layers_27_final_layer_norm_weight3, 
model_decoder_layers_27_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv59 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_27_fc1_weight3, layer_norm245, model_decoder_layers_27_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv391 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_27_fc2_weight3, lv59, model_decoder_layers_27_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add858 = R.call_tir(cls.add, (add855, lv391), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm246 = R.call_tir(cls.layer_norm, (add858, model_decoder_layers_28_self_attn_layer_norm_weight3, model_decoder_layers_28_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv392 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_q_proj_weight3, layer_norm246, model_decoder_layers_28_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape990 = R.call_tir(cls.reshape4, (lv392,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv93 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_28_self_attn_k_proj_weight3, layer_norm246), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape991 = R.call_tir(cls.reshape4, (lv93,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv393 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_v_proj_weight3, layer_norm246, model_decoder_layers_28_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape992 = R.call_tir(cls.reshape4, (lv393,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat60 = R.call_tir(cls.concatenate, (reshape990, reshape991, reshape992), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape993 = R.call_tir(cls.reshape5, (concat60,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv190 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape993), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape994 = R.call_tir(cls.reshape6, (lv190,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape995 = R.call_tir(cls.reshape7, (reshape994,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv394 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_self_attn_out_proj_weight3, reshape995, model_decoder_layers_28_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add862 = R.call_tir(cls.add, (add858, lv394), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm247 = R.call_tir(cls.layer_norm, (add862, model_decoder_layers_28_encoder_attn_layer_norm_weight3, model_decoder_layers_28_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv395 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight3, layer_norm247, model_decoder_layers_28_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape996 = R.call_tir(cls.reshape4, (lv395,), 
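# batch_size is a symbolic (dynamic) dimension threaded through every
# out_sinfo, so this one compiled function serves any decode batch size
# without recompilation.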
out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape997 = R.call_tir(cls.reshape8, (reshape996,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv191 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape997), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape998 = R.call_tir(cls.reshape6, (lv191,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape999 = R.call_tir(cls.reshape7, (reshape998,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv396 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight3, reshape999, model_decoder_layers_28_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add865 = R.call_tir(cls.add, (add862, lv396), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm248 = R.call_tir(cls.layer_norm, (add865, model_decoder_layers_28_final_layer_norm_weight3, model_decoder_layers_28_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv60 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_28_fc1_weight3, layer_norm248, model_decoder_layers_28_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv397 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_28_fc2_weight3, lv60, model_decoder_layers_28_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add868 = R.call_tir(cls.add, (add865, lv397), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm249 = R.call_tir(cls.layer_norm, (add868, model_decoder_layers_29_self_attn_layer_norm_weight3, model_decoder_layers_29_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv398 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_q_proj_weight3, layer_norm249, model_decoder_layers_29_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1000 = R.call_tir(cls.reshape4, (lv398,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv94 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_29_self_attn_k_proj_weight3, layer_norm249), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1001 = R.call_tir(cls.reshape4, (lv94,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv399 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_v_proj_weight3, layer_norm249, model_decoder_layers_29_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1002 = R.call_tir(cls.reshape4, (lv399,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat61 = R.call_tir(cls.concatenate, (reshape1000, reshape1001, reshape1002), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape1003 = R.call_tir(cls.reshape5, (concat61,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv192 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1003), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1004 = R.call_tir(cls.reshape6, 
(lv192,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1005 = R.call_tir(cls.reshape7, (reshape1004,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv400 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_self_attn_out_proj_weight3, reshape1005, model_decoder_layers_29_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add872 = R.call_tir(cls.add, (add868, lv400), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm250 = R.call_tir(cls.layer_norm, (add872, model_decoder_layers_29_encoder_attn_layer_norm_weight3, model_decoder_layers_29_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv401 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_encoder_attn_q_proj_weight3, layer_norm250, model_decoder_layers_29_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1006 = R.call_tir(cls.reshape4, (lv401,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1007 = R.call_tir(cls.reshape8, (reshape1006,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv193 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1007), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1008 = R.call_tir(cls.reshape6, (lv193,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1009 = R.call_tir(cls.reshape7, (reshape1008,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv402 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight3, reshape1009, model_decoder_layers_29_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add875 = R.call_tir(cls.add, (add872, lv402), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm251 = R.call_tir(cls.layer_norm, (add875, model_decoder_layers_29_final_layer_norm_weight3, model_decoder_layers_29_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv61 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_29_fc1_weight3, layer_norm251, model_decoder_layers_29_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv403 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_29_fc2_weight3, lv61, model_decoder_layers_29_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add878 = R.call_tir(cls.add, (add875, lv403), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm252 = R.call_tir(cls.layer_norm, (add878, model_decoder_layers_30_self_attn_layer_norm_weight3, model_decoder_layers_30_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv404 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_q_proj_weight3, layer_norm252, model_decoder_layers_30_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1010 = R.call_tir(cls.reshape4, (lv404,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv95 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", 
(model_decoder_layers_30_self_attn_k_proj_weight3, layer_norm252), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1011 = R.call_tir(cls.reshape4, (lv95,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv405 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_v_proj_weight3, layer_norm252, model_decoder_layers_30_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1012 = R.call_tir(cls.reshape4, (lv405,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat62 = R.call_tir(cls.concatenate, (reshape1010, reshape1011, reshape1012), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape1013 = R.call_tir(cls.reshape5, (concat62,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1013), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1014 = R.call_tir(cls.reshape6, (lv194,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1015 = R.call_tir(cls.reshape7, (reshape1014,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv406 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_self_attn_out_proj_weight3, reshape1015, model_decoder_layers_30_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add882 = R.call_tir(cls.add, (add878, lv406), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm253 = R.call_tir(cls.layer_norm, (add882, model_decoder_layers_30_encoder_attn_layer_norm_weight3, model_decoder_layers_30_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv407 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight3, layer_norm253, model_decoder_layers_30_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1016 = R.call_tir(cls.reshape4, (lv407,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1017 = R.call_tir(cls.reshape8, (reshape1016,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv195 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1017), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1018 = R.call_tir(cls.reshape6, (lv195,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1019 = R.call_tir(cls.reshape7, (reshape1018,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv408 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight3, reshape1019, model_decoder_layers_30_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add885 = R.call_tir(cls.add, (add882, lv408), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm254 = R.call_tir(cls.layer_norm, (add885, model_decoder_layers_30_final_layer_norm_weight3, model_decoder_layers_30_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv62 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_30_fc1_weight3, 
layer_norm254, model_decoder_layers_30_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv409 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_30_fc2_weight3, lv62, model_decoder_layers_30_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add888 = R.call_tir(cls.add, (add885, lv409), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm255 = R.call_tir(cls.layer_norm, (add888, model_decoder_layers_31_self_attn_layer_norm_weight3, model_decoder_layers_31_self_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv410 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_q_proj_weight3, layer_norm255, model_decoder_layers_31_self_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1020 = R.call_tir(cls.reshape4, (lv410,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv96 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul3_cublas", (model_decoder_layers_31_self_attn_k_proj_weight3, layer_norm255), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1021 = R.call_tir(cls.reshape4, (lv96,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) lv411 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_v_proj_weight3, layer_norm255, model_decoder_layers_31_self_attn_v_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1022 = R.call_tir(cls.reshape4, (lv411,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) concat63 = R.call_tir(cls.concatenate, (reshape1020, reshape1021, reshape1022), out_sinfo=R.Tensor((batch_size, 1, 60, 64), dtype="float16")) reshape1023 = R.call_tir(cls.reshape5, (concat63,), out_sinfo=R.Tensor((batch_size, 60, 64), dtype="float16")) lv196 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1023), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1024 = R.call_tir(cls.reshape6, (lv196,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1025 = R.call_tir(cls.reshape7, (reshape1024,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv412 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_self_attn_out_proj_weight3, reshape1025, model_decoder_layers_31_self_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add892 = R.call_tir(cls.add, (add888, lv412), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm256 = R.call_tir(cls.layer_norm, (add892, model_decoder_layers_31_encoder_attn_layer_norm_weight3, model_decoder_layers_31_encoder_attn_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv413 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight3, layer_norm256, model_decoder_layers_31_encoder_attn_q_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) reshape1026 = R.call_tir(cls.reshape4, (lv413,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1027 = R.call_tir(cls.reshape8, (reshape1026,), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) lv197 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1027), out_sinfo=R.Tensor((batch_size, 20, 64), dtype="float16")) reshape1028 = R.call_tir(cls.reshape6, (lv197,), out_sinfo=R.Tensor((batch_size, 1, 20, 64), dtype="float16")) reshape1029 = R.call_tir(cls.reshape7, (reshape1028,), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv414 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add3_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight3, reshape1029, model_decoder_layers_31_encoder_attn_out_proj_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add895 = R.call_tir(cls.add, (add892, lv414), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm257 = R.call_tir(cls.layer_norm, (add895, model_decoder_layers_31_final_layer_norm_weight3, model_decoder_layers_31_final_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) lv63 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu1_cublas", (model_decoder_layers_31_fc1_weight3, layer_norm257, model_decoder_layers_31_fc1_bias3), out_sinfo=R.Tensor((batch_size, 1, 5120), dtype="float16")) lv415 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add4_cublas", (model_decoder_layers_31_fc2_weight3, lv63, model_decoder_layers_31_fc2_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) add898 = R.call_tir(cls.add, (add895, lv415), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) layer_norm258 = R.call_tir(cls.layer_norm, (add898, model_decoder_layer_norm_weight3, model_decoder_layer_norm_bias3), out_sinfo=R.Tensor((batch_size, 1, 1280), dtype="float16")) gv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul4_cublas", (model_decoder_embed_tokens_weight3, layer_norm258), out_sinfo=R.Tensor((batch_size, 1, 51866), dtype="float32")) R.output(gv3) return gv3 @R.function def batch_encode(input_features: R.Tensor(("batch_size", 128, 3000), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor(("batch_size", 1500, 1280), dtype="float16"): batch_size = T.int64() R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): model_encoder_conv1_weight: R.Tensor((1280, 128, 3), dtype="float16") = packed_params[0] lv: R.Tensor((1280,), dtype="float16") = packed_params[1] lv1 = R.call_tir(cls.fused_reshape9, (lv,), out_sinfo=R.Tensor((1, 1280, 1), dtype="float16")) model_encoder_conv2_weight: R.Tensor((1280, 1280, 3), dtype="float16") = packed_params[2] lv2: R.Tensor((1280,), dtype="float16") = packed_params[3] lv3 = R.call_tir(cls.fused_reshape9, (lv2,), out_sinfo=R.Tensor((1, 1280, 1), dtype="float16")) lv4 = R.call_tir(cls.fused_conv1d_add1_gelu, (input_features, model_encoder_conv1_weight, lv1), out_sinfo=R.Tensor((batch_size, 1280, 3000), dtype="float16")) lv5 = R.call_tir(cls.fused_conv1d1_add2_gelu1, (lv4, model_encoder_conv2_weight, lv3), out_sinfo=R.Tensor((batch_size, 1280, 1500), dtype="float16")) lv6: R.Tensor((1500, 1280), dtype="float16") = packed_params[4] lv7 = R.call_tir(cls.fused_transpose_add3, (lv6, lv5), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) model_encoder_layers_0_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[5] model_encoder_layers_0_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[6] model_encoder_layers_0_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[7] model_encoder_layers_0_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[8] model_encoder_layers_0_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[9] model_encoder_layers_0_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[10] model_encoder_layers_0_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[11] model_encoder_layers_0_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[12] model_encoder_layers_0_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[13] model_encoder_layers_0_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[14] model_encoder_layers_0_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[15] model_encoder_layers_0_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[16] model_encoder_layers_0_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[17] model_encoder_layers_0_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[18] model_encoder_layers_0_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[19] model_encoder_layers_1_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[20] model_encoder_layers_1_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[21] model_encoder_layers_1_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[22] 
model_encoder_layers_1_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[23] model_encoder_layers_1_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[24] model_encoder_layers_1_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[25] model_encoder_layers_1_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[26] model_encoder_layers_1_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[27] model_encoder_layers_1_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[28] model_encoder_layers_1_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[29] model_encoder_layers_1_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[30] model_encoder_layers_1_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[31] model_encoder_layers_1_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[32] model_encoder_layers_1_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[33] model_encoder_layers_1_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[34] model_encoder_layers_2_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[35] model_encoder_layers_2_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[36] model_encoder_layers_2_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[37] model_encoder_layers_2_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[38] model_encoder_layers_2_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[39] model_encoder_layers_2_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[40] model_encoder_layers_2_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[41] model_encoder_layers_2_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[42] model_encoder_layers_2_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[43] model_encoder_layers_2_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[44] model_encoder_layers_2_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[45] model_encoder_layers_2_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[46] model_encoder_layers_2_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[47] model_encoder_layers_2_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[48] model_encoder_layers_2_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[49] model_encoder_layers_3_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[50] model_encoder_layers_3_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[51] model_encoder_layers_3_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[52] model_encoder_layers_3_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[53] model_encoder_layers_3_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[54] model_encoder_layers_3_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[55] model_encoder_layers_3_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[56] model_encoder_layers_3_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = 
packed_params[57] model_encoder_layers_3_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[58] model_encoder_layers_3_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[59] model_encoder_layers_3_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[60] model_encoder_layers_3_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[61] model_encoder_layers_3_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[62] model_encoder_layers_3_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[63] model_encoder_layers_3_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[64] model_encoder_layers_4_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[65] model_encoder_layers_4_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[66] model_encoder_layers_4_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[67] model_encoder_layers_4_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[68] model_encoder_layers_4_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[69] model_encoder_layers_4_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[70] model_encoder_layers_4_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[71] model_encoder_layers_4_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[72] model_encoder_layers_4_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[73] model_encoder_layers_4_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[74] model_encoder_layers_4_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[75] model_encoder_layers_4_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[76] model_encoder_layers_4_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[77] model_encoder_layers_4_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[78] model_encoder_layers_4_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[79] model_encoder_layers_5_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[80] model_encoder_layers_5_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[81] model_encoder_layers_5_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[82] model_encoder_layers_5_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[83] model_encoder_layers_5_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[84] model_encoder_layers_5_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[85] model_encoder_layers_5_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[86] model_encoder_layers_5_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[87] model_encoder_layers_5_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[88] model_encoder_layers_5_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[89] model_encoder_layers_5_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[90] model_encoder_layers_5_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[91] model_encoder_layers_5_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[92] 
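# The per-layer unpacking repeats verbatim for all 32 layers; a compact
# rendering of the same pattern (illustrative Python only -- the module itself
# keeps the unrolled assignments):
#
#     PER_LAYER = [
#         "self_attn_k_proj_weight", "self_attn_v_proj_weight",
#         "self_attn_v_proj_bias", "self_attn_q_proj_weight",
#         "self_attn_q_proj_bias", "self_attn_out_proj_weight",
#         "self_attn_out_proj_bias", "self_attn_layer_norm_weight",
#         "self_attn_layer_norm_bias", "fc1_weight", "fc1_bias",
#         "fc2_weight", "fc2_bias", "final_layer_norm_weight",
#         "final_layer_norm_bias",
#     ]
#     params = {f"model_encoder_layers_{i}_{name}": packed_params[5 + 15 * i + j]
#               for i in range(32) for j, name in enumerate(PER_LAYER)}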
model_encoder_layers_5_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[93] model_encoder_layers_5_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[94] model_encoder_layers_6_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[95] model_encoder_layers_6_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[96] model_encoder_layers_6_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[97] model_encoder_layers_6_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[98] model_encoder_layers_6_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[99] model_encoder_layers_6_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[100] model_encoder_layers_6_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[101] model_encoder_layers_6_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[102] model_encoder_layers_6_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[103] model_encoder_layers_6_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[104] model_encoder_layers_6_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[105] model_encoder_layers_6_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[106] model_encoder_layers_6_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[107] model_encoder_layers_6_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[108] model_encoder_layers_6_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[109] model_encoder_layers_7_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[110] model_encoder_layers_7_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[111] model_encoder_layers_7_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[112] model_encoder_layers_7_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[113] model_encoder_layers_7_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[114] model_encoder_layers_7_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[115] model_encoder_layers_7_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[116] model_encoder_layers_7_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[117] model_encoder_layers_7_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[118] model_encoder_layers_7_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[119] model_encoder_layers_7_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[120] model_encoder_layers_7_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[121] model_encoder_layers_7_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[122] model_encoder_layers_7_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[123] model_encoder_layers_7_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[124] model_encoder_layers_8_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[125] model_encoder_layers_8_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[126] model_encoder_layers_8_self_attn_v_proj_bias: R.Tensor((1280,), 
dtype="float16") = packed_params[127] model_encoder_layers_8_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[128] model_encoder_layers_8_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[129] model_encoder_layers_8_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[130] model_encoder_layers_8_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[131] model_encoder_layers_8_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[132] model_encoder_layers_8_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[133] model_encoder_layers_8_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[134] model_encoder_layers_8_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[135] model_encoder_layers_8_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[136] model_encoder_layers_8_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[137] model_encoder_layers_8_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[138] model_encoder_layers_8_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[139] model_encoder_layers_9_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[140] model_encoder_layers_9_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[141] model_encoder_layers_9_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[142] model_encoder_layers_9_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[143] model_encoder_layers_9_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[144] model_encoder_layers_9_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[145] model_encoder_layers_9_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[146] model_encoder_layers_9_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[147] model_encoder_layers_9_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[148] model_encoder_layers_9_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[149] model_encoder_layers_9_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[150] model_encoder_layers_9_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[151] model_encoder_layers_9_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[152] model_encoder_layers_9_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[153] model_encoder_layers_9_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[154] model_encoder_layers_10_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[155] model_encoder_layers_10_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[156] model_encoder_layers_10_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[157] model_encoder_layers_10_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[158] model_encoder_layers_10_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[159] model_encoder_layers_10_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[160] model_encoder_layers_10_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[161] 
model_encoder_layers_10_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[162] model_encoder_layers_10_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[163] model_encoder_layers_10_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[164] model_encoder_layers_10_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[165] model_encoder_layers_10_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[166] model_encoder_layers_10_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[167] model_encoder_layers_10_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[168] model_encoder_layers_10_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[169] model_encoder_layers_11_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[170] model_encoder_layers_11_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[171] model_encoder_layers_11_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[172] model_encoder_layers_11_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[173] model_encoder_layers_11_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[174] model_encoder_layers_11_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[175] model_encoder_layers_11_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[176] model_encoder_layers_11_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[177] model_encoder_layers_11_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[178] model_encoder_layers_11_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[179] model_encoder_layers_11_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[180] model_encoder_layers_11_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[181] model_encoder_layers_11_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[182] model_encoder_layers_11_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[183] model_encoder_layers_11_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[184] model_encoder_layers_12_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[185] model_encoder_layers_12_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[186] model_encoder_layers_12_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[187] model_encoder_layers_12_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[188] model_encoder_layers_12_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[189] model_encoder_layers_12_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[190] model_encoder_layers_12_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[191] model_encoder_layers_12_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[192] model_encoder_layers_12_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[193] model_encoder_layers_12_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[194] model_encoder_layers_12_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[195] model_encoder_layers_12_fc2_weight: 
R.Tensor((1280, 5120), dtype="float16") = packed_params[196] model_encoder_layers_12_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[197] model_encoder_layers_12_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[198] model_encoder_layers_12_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[199] model_encoder_layers_13_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[200] model_encoder_layers_13_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[201] model_encoder_layers_13_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[202] model_encoder_layers_13_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[203] model_encoder_layers_13_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[204] model_encoder_layers_13_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[205] model_encoder_layers_13_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[206] model_encoder_layers_13_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[207] model_encoder_layers_13_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[208] model_encoder_layers_13_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[209] model_encoder_layers_13_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[210] model_encoder_layers_13_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[211] model_encoder_layers_13_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[212] model_encoder_layers_13_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[213] model_encoder_layers_13_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[214] model_encoder_layers_14_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[215] model_encoder_layers_14_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[216] model_encoder_layers_14_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[217] model_encoder_layers_14_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[218] model_encoder_layers_14_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[219] model_encoder_layers_14_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[220] model_encoder_layers_14_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[221] model_encoder_layers_14_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[222] model_encoder_layers_14_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[223] model_encoder_layers_14_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[224] model_encoder_layers_14_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[225] model_encoder_layers_14_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[226] model_encoder_layers_14_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[227] model_encoder_layers_14_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[228] model_encoder_layers_14_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[229] model_encoder_layers_15_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[230] model_encoder_layers_15_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[231] model_encoder_layers_15_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[232] model_encoder_layers_15_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[233] model_encoder_layers_15_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[234] model_encoder_layers_15_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[235] model_encoder_layers_15_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[236] model_encoder_layers_15_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[237] model_encoder_layers_15_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[238] model_encoder_layers_15_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[239] model_encoder_layers_15_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[240] model_encoder_layers_15_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[241] model_encoder_layers_15_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[242] model_encoder_layers_15_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[243] model_encoder_layers_15_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[244] model_encoder_layers_16_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[245] model_encoder_layers_16_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[246] model_encoder_layers_16_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[247] model_encoder_layers_16_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[248] model_encoder_layers_16_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[249] model_encoder_layers_16_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[250] model_encoder_layers_16_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[251] model_encoder_layers_16_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[252] model_encoder_layers_16_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[253] model_encoder_layers_16_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[254] model_encoder_layers_16_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[255] model_encoder_layers_16_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[256] model_encoder_layers_16_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[257] model_encoder_layers_16_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[258] model_encoder_layers_16_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[259] model_encoder_layers_17_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[260] model_encoder_layers_17_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[261] model_encoder_layers_17_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[262] model_encoder_layers_17_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[263] model_encoder_layers_17_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[264] 
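# Weight layout: every projection weight is stored (out_features,
# in_features), so the fused kernels below permute it first, i.e. they
# compute y = x @ W.T (+ b). This is why fc1 is (5120, 1280) and fc2 is
# (1280, 5120) even though the MLP maps 1280 -> 5120 -> 1280.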
model_encoder_layers_17_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[265] model_encoder_layers_17_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[266] model_encoder_layers_17_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[267] model_encoder_layers_17_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[268] model_encoder_layers_17_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[269] model_encoder_layers_17_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[270] model_encoder_layers_17_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[271] model_encoder_layers_17_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[272] model_encoder_layers_17_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[273] model_encoder_layers_17_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[274] model_encoder_layers_18_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[275] model_encoder_layers_18_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[276] model_encoder_layers_18_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[277] model_encoder_layers_18_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[278] model_encoder_layers_18_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[279] model_encoder_layers_18_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[280] model_encoder_layers_18_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[281] model_encoder_layers_18_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[282] model_encoder_layers_18_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[283] model_encoder_layers_18_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[284] model_encoder_layers_18_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[285] model_encoder_layers_18_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[286] model_encoder_layers_18_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[287] model_encoder_layers_18_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[288] model_encoder_layers_18_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[289] model_encoder_layers_19_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[290] model_encoder_layers_19_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[291] model_encoder_layers_19_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[292] model_encoder_layers_19_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[293] model_encoder_layers_19_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[294] model_encoder_layers_19_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[295] model_encoder_layers_19_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[296] model_encoder_layers_19_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[297] model_encoder_layers_19_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[298] 
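# The "*_cublas" targets invoked through R.call_dps_packed further down are
# BYOC offloads: each fused permute_dims/matmul(/add/gelu) pattern appears to
# be dispatched to an external cuBLAS kernel rather than to a compiled TIR
# PrimFunc.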
model_encoder_layers_19_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[299] model_encoder_layers_19_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[300] model_encoder_layers_19_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[301] model_encoder_layers_19_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[302] model_encoder_layers_19_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[303] model_encoder_layers_19_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[304] model_encoder_layers_20_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[305] model_encoder_layers_20_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[306] model_encoder_layers_20_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[307] model_encoder_layers_20_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[308] model_encoder_layers_20_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[309] model_encoder_layers_20_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[310] model_encoder_layers_20_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[311] model_encoder_layers_20_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[312] model_encoder_layers_20_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[313] model_encoder_layers_20_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[314] model_encoder_layers_20_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[315] model_encoder_layers_20_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[316] model_encoder_layers_20_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[317] model_encoder_layers_20_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[318] model_encoder_layers_20_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[319] model_encoder_layers_21_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[320] model_encoder_layers_21_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[321] model_encoder_layers_21_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[322] model_encoder_layers_21_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[323] model_encoder_layers_21_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[324] model_encoder_layers_21_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[325] model_encoder_layers_21_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[326] model_encoder_layers_21_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[327] model_encoder_layers_21_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[328] model_encoder_layers_21_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[329] model_encoder_layers_21_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[330] model_encoder_layers_21_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[331] model_encoder_layers_21_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[332] model_encoder_layers_21_final_layer_norm_weight: R.Tensor((1280,), 
dtype="float16") = packed_params[333] model_encoder_layers_21_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[334] model_encoder_layers_22_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[335] model_encoder_layers_22_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[336] model_encoder_layers_22_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[337] model_encoder_layers_22_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[338] model_encoder_layers_22_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[339] model_encoder_layers_22_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[340] model_encoder_layers_22_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[341] model_encoder_layers_22_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[342] model_encoder_layers_22_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[343] model_encoder_layers_22_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[344] model_encoder_layers_22_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[345] model_encoder_layers_22_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[346] model_encoder_layers_22_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[347] model_encoder_layers_22_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[348] model_encoder_layers_22_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[349] model_encoder_layers_23_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[350] model_encoder_layers_23_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[351] model_encoder_layers_23_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[352] model_encoder_layers_23_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[353] model_encoder_layers_23_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[354] model_encoder_layers_23_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[355] model_encoder_layers_23_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[356] model_encoder_layers_23_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[357] model_encoder_layers_23_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[358] model_encoder_layers_23_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[359] model_encoder_layers_23_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[360] model_encoder_layers_23_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[361] model_encoder_layers_23_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[362] model_encoder_layers_23_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[363] model_encoder_layers_23_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[364] model_encoder_layers_24_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[365] model_encoder_layers_24_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[366] model_encoder_layers_24_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = 
packed_params[367] model_encoder_layers_24_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[368] model_encoder_layers_24_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[369] model_encoder_layers_24_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[370] model_encoder_layers_24_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[371] model_encoder_layers_24_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[372] model_encoder_layers_24_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[373] model_encoder_layers_24_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[374] model_encoder_layers_24_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[375] model_encoder_layers_24_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[376] model_encoder_layers_24_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[377] model_encoder_layers_24_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[378] model_encoder_layers_24_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[379] model_encoder_layers_25_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[380] model_encoder_layers_25_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[381] model_encoder_layers_25_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[382] model_encoder_layers_25_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[383] model_encoder_layers_25_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[384] model_encoder_layers_25_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[385] model_encoder_layers_25_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[386] model_encoder_layers_25_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[387] model_encoder_layers_25_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[388] model_encoder_layers_25_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[389] model_encoder_layers_25_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[390] model_encoder_layers_25_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[391] model_encoder_layers_25_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[392] model_encoder_layers_25_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[393] model_encoder_layers_25_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[394] model_encoder_layers_26_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[395] model_encoder_layers_26_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[396] model_encoder_layers_26_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[397] model_encoder_layers_26_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[398] model_encoder_layers_26_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[399] model_encoder_layers_26_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[400] model_encoder_layers_26_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[401] 
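# Layers 27-31 unpack the same way below. Once all weights are bound, the
# encoder body applies one identical pre-norm transformer block per layer;
# in pseudocode (names hypothetical, comments only -- see the unrolled Relax
# calls that follow):
#
#     for i in range(32):                       # attention layer id = i
#         xn = layer_norm1(x, attn_ln_w[i], attn_ln_b[i])
#         q  = reshape(xn @ q_w[i].T + q_b[i])  # (batch_size*1500, 20, 64)
#         k  = reshape(xn @ k_w[i].T)           # k_proj has no bias
#         v  = reshape(xn @ v_w[i].T + v_b[i])
#         o  = attention_kv_cache_attention_no_append(paged_kv_cache, i,
#                                                     1.0, q, k, v)
#         x  = x + reshape(o) @ out_w[i].T + out_b[i]   # cls.add4 residual
#         h  = gelu(layer_norm1(x, final_ln_w[i], final_ln_b[i])
#                   @ fc1_w[i].T + fc1_b[i])
#         x  = clamp(x + h @ fc2_w[i].T + fc2_b[i])     # fused_add4_maximum_minimum,
#                                                       # apparently an fp16 range clamp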
model_encoder_layers_26_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[402] model_encoder_layers_26_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[403] model_encoder_layers_26_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[404] model_encoder_layers_26_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[405] model_encoder_layers_26_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[406] model_encoder_layers_26_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[407] model_encoder_layers_26_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[408] model_encoder_layers_26_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[409] model_encoder_layers_27_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[410] model_encoder_layers_27_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[411] model_encoder_layers_27_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[412] model_encoder_layers_27_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[413] model_encoder_layers_27_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[414] model_encoder_layers_27_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[415] model_encoder_layers_27_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[416] model_encoder_layers_27_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[417] model_encoder_layers_27_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[418] model_encoder_layers_27_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[419] model_encoder_layers_27_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[420] model_encoder_layers_27_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[421] model_encoder_layers_27_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[422] model_encoder_layers_27_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[423] model_encoder_layers_27_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[424] model_encoder_layers_28_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[425] model_encoder_layers_28_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[426] model_encoder_layers_28_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[427] model_encoder_layers_28_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[428] model_encoder_layers_28_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[429] model_encoder_layers_28_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[430] model_encoder_layers_28_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[431] model_encoder_layers_28_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[432] model_encoder_layers_28_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[433] model_encoder_layers_28_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[434] model_encoder_layers_28_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[435] model_encoder_layers_28_fc2_weight: 
R.Tensor((1280, 5120), dtype="float16") = packed_params[436] model_encoder_layers_28_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[437] model_encoder_layers_28_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[438] model_encoder_layers_28_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[439] model_encoder_layers_29_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[440] model_encoder_layers_29_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[441] model_encoder_layers_29_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[442] model_encoder_layers_29_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[443] model_encoder_layers_29_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[444] model_encoder_layers_29_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[445] model_encoder_layers_29_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[446] model_encoder_layers_29_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[447] model_encoder_layers_29_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[448] model_encoder_layers_29_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[449] model_encoder_layers_29_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[450] model_encoder_layers_29_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[451] model_encoder_layers_29_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[452] model_encoder_layers_29_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[453] model_encoder_layers_29_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[454] model_encoder_layers_30_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[455] model_encoder_layers_30_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[456] model_encoder_layers_30_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[457] model_encoder_layers_30_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[458] model_encoder_layers_30_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[459] model_encoder_layers_30_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[460] model_encoder_layers_30_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[461] model_encoder_layers_30_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[462] model_encoder_layers_30_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[463] model_encoder_layers_30_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[464] model_encoder_layers_30_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[465] model_encoder_layers_30_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[466] model_encoder_layers_30_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[467] model_encoder_layers_30_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[468] model_encoder_layers_30_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[469] model_encoder_layers_31_self_attn_k_proj_weight: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[470] model_encoder_layers_31_self_attn_v_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[471] model_encoder_layers_31_self_attn_v_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[472] model_encoder_layers_31_self_attn_q_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[473] model_encoder_layers_31_self_attn_q_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[474] model_encoder_layers_31_self_attn_out_proj_weight: R.Tensor((1280, 1280), dtype="float16") = packed_params[475] model_encoder_layers_31_self_attn_out_proj_bias: R.Tensor((1280,), dtype="float16") = packed_params[476] model_encoder_layers_31_self_attn_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[477] model_encoder_layers_31_self_attn_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[478] model_encoder_layers_31_fc1_weight: R.Tensor((5120, 1280), dtype="float16") = packed_params[479] model_encoder_layers_31_fc1_bias: R.Tensor((5120,), dtype="float16") = packed_params[480] model_encoder_layers_31_fc2_weight: R.Tensor((1280, 5120), dtype="float16") = packed_params[481] model_encoder_layers_31_fc2_bias: R.Tensor((1280,), dtype="float16") = packed_params[482] model_encoder_layers_31_final_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[483] model_encoder_layers_31_final_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[484] model_encoder_layer_norm_weight: R.Tensor((1280,), dtype="float16") = packed_params[485] model_encoder_layer_norm_bias: R.Tensor((1280,), dtype="float16") = packed_params[486] layer_norm = R.call_tir(cls.layer_norm1, (lv7, model_encoder_layers_0_self_attn_layer_norm_weight, model_encoder_layers_0_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv608 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_q_proj_weight, layer_norm, model_encoder_layers_0_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape = R.call_tir(cls.reshape, (lv608,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv131 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_0_self_attn_k_proj_weight, layer_norm), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape1 = R.call_tir(cls.reshape, (lv131,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv609 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_v_proj_weight, layer_norm, model_encoder_layers_0_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape2 = R.call_tir(cls.reshape, (lv609,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape3 = R.call_tir(cls.reshape1, (reshape,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape4 = R.call_tir(cls.reshape1, (reshape1,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape5 = R.call_tir(cls.reshape1, (reshape2,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv4_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape3, reshape4, reshape5), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape6 = R.call_tir(cls.reshape10, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1500, 
20, 64), dtype="float16")) reshape7 = R.call_tir(cls.reshape11, (reshape6,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv610 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_0_self_attn_out_proj_weight, reshape7, model_encoder_layers_0_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add4 = R.call_tir(cls.add4, (lv7, lv610), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm1 = R.call_tir(cls.layer_norm1, (add4, model_encoder_layers_0_final_layer_norm_weight, model_encoder_layers_0_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv96 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_0_fc1_weight, layer_norm1, model_encoder_layers_0_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv611 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_0_fc2_weight, lv96, model_encoder_layers_0_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv8 = R.call_tir(cls.fused_add4_maximum_minimum, (add4, lv611), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm2 = R.call_tir(cls.layer_norm1, (lv8, model_encoder_layers_1_self_attn_layer_norm_weight, model_encoder_layers_1_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv612 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_q_proj_weight, layer_norm2, model_encoder_layers_1_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape8 = R.call_tir(cls.reshape, (lv612,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv132 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_1_self_attn_k_proj_weight, layer_norm2), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape9 = R.call_tir(cls.reshape, (lv132,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv613 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_v_proj_weight, layer_norm2, model_encoder_layers_1_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape10 = R.call_tir(cls.reshape, (lv613,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape11 = R.call_tir(cls.reshape1, (reshape8,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape12 = R.call_tir(cls.reshape1, (reshape9,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape13 = R.call_tir(cls.reshape1, (reshape10,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv5_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape11, reshape12, reshape13), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape14 = R.call_tir(cls.reshape10, (lv5_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape15 = R.call_tir(cls.reshape11, (reshape14,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv614 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_1_self_attn_out_proj_weight, reshape15, 
model_encoder_layers_1_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add11 = R.call_tir(cls.add4, (lv8, lv614), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm3 = R.call_tir(cls.layer_norm1, (add11, model_encoder_layers_1_final_layer_norm_weight, model_encoder_layers_1_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv97 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_1_fc1_weight, layer_norm3, model_encoder_layers_1_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv615 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_1_fc2_weight, lv97, model_encoder_layers_1_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv9 = R.call_tir(cls.fused_add4_maximum_minimum, (add11, lv615), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm4 = R.call_tir(cls.layer_norm1, (lv9, model_encoder_layers_2_self_attn_layer_norm_weight, model_encoder_layers_2_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv616 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_q_proj_weight, layer_norm4, model_encoder_layers_2_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape16 = R.call_tir(cls.reshape, (lv616,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv133 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_2_self_attn_k_proj_weight, layer_norm4), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape17 = R.call_tir(cls.reshape, (lv133,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv617 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_v_proj_weight, layer_norm4, model_encoder_layers_2_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape18 = R.call_tir(cls.reshape, (lv617,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape19 = R.call_tir(cls.reshape1, (reshape16,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape20 = R.call_tir(cls.reshape1, (reshape17,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape21 = R.call_tir(cls.reshape1, (reshape18,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv6_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape19, reshape20, reshape21), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape22 = R.call_tir(cls.reshape10, (lv6_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape23 = R.call_tir(cls.reshape11, (reshape22,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv618 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_2_self_attn_out_proj_weight, reshape23, model_encoder_layers_2_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add18 = R.call_tir(cls.add4, (lv9, lv618), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm5 = R.call_tir(cls.layer_norm1, (add18, 
model_encoder_layers_2_final_layer_norm_weight, model_encoder_layers_2_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_2_fc1_weight, layer_norm5, model_encoder_layers_2_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv619 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_2_fc2_weight, lv98, model_encoder_layers_2_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv10 = R.call_tir(cls.fused_add4_maximum_minimum, (add18, lv619), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm6 = R.call_tir(cls.layer_norm1, (lv10, model_encoder_layers_3_self_attn_layer_norm_weight, model_encoder_layers_3_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv620 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_q_proj_weight, layer_norm6, model_encoder_layers_3_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape24 = R.call_tir(cls.reshape, (lv620,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv134 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_3_self_attn_k_proj_weight, layer_norm6), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape25 = R.call_tir(cls.reshape, (lv134,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv621 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_v_proj_weight, layer_norm6, model_encoder_layers_3_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape26 = R.call_tir(cls.reshape, (lv621,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape27 = R.call_tir(cls.reshape1, (reshape24,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape28 = R.call_tir(cls.reshape1, (reshape25,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape29 = R.call_tir(cls.reshape1, (reshape26,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv7_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape27, reshape28, reshape29), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape30 = R.call_tir(cls.reshape10, (lv7_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape31 = R.call_tir(cls.reshape11, (reshape30,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv622 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_3_self_attn_out_proj_weight, reshape31, model_encoder_layers_3_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add25 = R.call_tir(cls.add4, (lv10, lv622), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm7 = R.call_tir(cls.layer_norm1, (add25, model_encoder_layers_3_final_layer_norm_weight, model_encoder_layers_3_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_3_fc1_weight, layer_norm7, 
model_encoder_layers_3_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv623 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_3_fc2_weight, lv99, model_encoder_layers_3_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv11 = R.call_tir(cls.fused_add4_maximum_minimum, (add25, lv623), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm8 = R.call_tir(cls.layer_norm1, (lv11, model_encoder_layers_4_self_attn_layer_norm_weight, model_encoder_layers_4_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv624 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_q_proj_weight, layer_norm8, model_encoder_layers_4_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape32 = R.call_tir(cls.reshape, (lv624,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv135 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_4_self_attn_k_proj_weight, layer_norm8), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape33 = R.call_tir(cls.reshape, (lv135,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv625 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_v_proj_weight, layer_norm8, model_encoder_layers_4_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape34 = R.call_tir(cls.reshape, (lv625,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape35 = R.call_tir(cls.reshape1, (reshape32,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape36 = R.call_tir(cls.reshape1, (reshape33,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape37 = R.call_tir(cls.reshape1, (reshape34,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv8_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape35, reshape36, reshape37), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape38 = R.call_tir(cls.reshape10, (lv8_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape39 = R.call_tir(cls.reshape11, (reshape38,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv626 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_4_self_attn_out_proj_weight, reshape39, model_encoder_layers_4_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add32 = R.call_tir(cls.add4, (lv11, lv626), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm9 = R.call_tir(cls.layer_norm1, (add32, model_encoder_layers_4_final_layer_norm_weight, model_encoder_layers_4_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_4_fc1_weight, layer_norm9, model_encoder_layers_4_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv627 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_4_fc2_weight, lv100, model_encoder_layers_4_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), 
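# Every R.call_dps_packed / R.call_tir here is destination-passing style: out_sinfo
# declares the tensor the callee writes into. The cls.fused_add4_maximum_minimum that
# follows is the residual add fused with a clamp, presumably to keep the float16
# activations within finite range.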
dtype="float16")) lv12 = R.call_tir(cls.fused_add4_maximum_minimum, (add32, lv627), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm10 = R.call_tir(cls.layer_norm1, (lv12, model_encoder_layers_5_self_attn_layer_norm_weight, model_encoder_layers_5_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv628 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_q_proj_weight, layer_norm10, model_encoder_layers_5_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape40 = R.call_tir(cls.reshape, (lv628,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv136 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_5_self_attn_k_proj_weight, layer_norm10), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape41 = R.call_tir(cls.reshape, (lv136,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv629 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_v_proj_weight, layer_norm10, model_encoder_layers_5_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape42 = R.call_tir(cls.reshape, (lv629,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape43 = R.call_tir(cls.reshape1, (reshape40,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape44 = R.call_tir(cls.reshape1, (reshape41,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape45 = R.call_tir(cls.reshape1, (reshape42,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv9_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape43, reshape44, reshape45), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape46 = R.call_tir(cls.reshape10, (lv9_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape47 = R.call_tir(cls.reshape11, (reshape46,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv630 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_5_self_attn_out_proj_weight, reshape47, model_encoder_layers_5_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add39 = R.call_tir(cls.add4, (lv12, lv630), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm11 = R.call_tir(cls.layer_norm1, (add39, model_encoder_layers_5_final_layer_norm_weight, model_encoder_layers_5_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_5_fc1_weight, layer_norm11, model_encoder_layers_5_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv631 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_5_fc2_weight, lv101, model_encoder_layers_5_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv13 = R.call_tir(cls.fused_add4_maximum_minimum, (add39, lv631), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm12 = R.call_tir(cls.layer_norm1, (lv13, model_encoder_layers_6_self_attn_layer_norm_weight, 
model_encoder_layers_6_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv632 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_q_proj_weight, layer_norm12, model_encoder_layers_6_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape48 = R.call_tir(cls.reshape, (lv632,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv137 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_6_self_attn_k_proj_weight, layer_norm12), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape49 = R.call_tir(cls.reshape, (lv137,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv633 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_v_proj_weight, layer_norm12, model_encoder_layers_6_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape50 = R.call_tir(cls.reshape, (lv633,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape51 = R.call_tir(cls.reshape1, (reshape48,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape52 = R.call_tir(cls.reshape1, (reshape49,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape53 = R.call_tir(cls.reshape1, (reshape50,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv10_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape51, reshape52, reshape53), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape54 = R.call_tir(cls.reshape10, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape55 = R.call_tir(cls.reshape11, (reshape54,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv634 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_6_self_attn_out_proj_weight, reshape55, model_encoder_layers_6_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add46 = R.call_tir(cls.add4, (lv13, lv634), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm13 = R.call_tir(cls.layer_norm1, (add46, model_encoder_layers_6_final_layer_norm_weight, model_encoder_layers_6_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_6_fc1_weight, layer_norm13, model_encoder_layers_6_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv635 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_6_fc2_weight, lv102, model_encoder_layers_6_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv14 = R.call_tir(cls.fused_add4_maximum_minimum, (add46, lv635), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm14 = R.call_tir(cls.layer_norm1, (lv14, model_encoder_layers_7_self_attn_layer_norm_weight, model_encoder_layers_7_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv636 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_q_proj_weight, layer_norm14, 
model_encoder_layers_7_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape56 = R.call_tir(cls.reshape, (lv636,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv138 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_7_self_attn_k_proj_weight, layer_norm14), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape57 = R.call_tir(cls.reshape, (lv138,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv637 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_v_proj_weight, layer_norm14, model_encoder_layers_7_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape58 = R.call_tir(cls.reshape, (lv637,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape59 = R.call_tir(cls.reshape1, (reshape56,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape60 = R.call_tir(cls.reshape1, (reshape57,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape61 = R.call_tir(cls.reshape1, (reshape58,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv11_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape59, reshape60, reshape61), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape62 = R.call_tir(cls.reshape10, (lv11_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape63 = R.call_tir(cls.reshape11, (reshape62,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv638 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_7_self_attn_out_proj_weight, reshape63, model_encoder_layers_7_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add53 = R.call_tir(cls.add4, (lv14, lv638), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm15 = R.call_tir(cls.layer_norm1, (add53, model_encoder_layers_7_final_layer_norm_weight, model_encoder_layers_7_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_7_fc1_weight, layer_norm15, model_encoder_layers_7_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv639 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_7_fc2_weight, lv103, model_encoder_layers_7_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv15 = R.call_tir(cls.fused_add4_maximum_minimum, (add53, lv639), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm16 = R.call_tir(cls.layer_norm1, (lv15, model_encoder_layers_8_self_attn_layer_norm_weight, model_encoder_layers_8_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv640 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_q_proj_weight, layer_norm16, model_encoder_layers_8_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape64 = R.call_tir(cls.reshape, (lv640,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv139 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", 
(model_encoder_layers_8_self_attn_k_proj_weight, layer_norm16), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape65 = R.call_tir(cls.reshape, (lv139,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv641 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_v_proj_weight, layer_norm16, model_encoder_layers_8_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape66 = R.call_tir(cls.reshape, (lv641,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape67 = R.call_tir(cls.reshape1, (reshape64,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape68 = R.call_tir(cls.reshape1, (reshape65,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape69 = R.call_tir(cls.reshape1, (reshape66,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv12_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape67, reshape68, reshape69), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape70 = R.call_tir(cls.reshape10, (lv12_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape71 = R.call_tir(cls.reshape11, (reshape70,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv642 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_8_self_attn_out_proj_weight, reshape71, model_encoder_layers_8_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add60 = R.call_tir(cls.add4, (lv15, lv642), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm17 = R.call_tir(cls.layer_norm1, (add60, model_encoder_layers_8_final_layer_norm_weight, model_encoder_layers_8_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv104 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_8_fc1_weight, layer_norm17, model_encoder_layers_8_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv643 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_8_fc2_weight, lv104, model_encoder_layers_8_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv16 = R.call_tir(cls.fused_add4_maximum_minimum, (add60, lv643), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm18 = R.call_tir(cls.layer_norm1, (lv16, model_encoder_layers_9_self_attn_layer_norm_weight, model_encoder_layers_9_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv644 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_q_proj_weight, layer_norm18, model_encoder_layers_9_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape72 = R.call_tir(cls.reshape, (lv644,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv140 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_9_self_attn_k_proj_weight, layer_norm18), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape73 = R.call_tir(cls.reshape, (lv140,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv645 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_v_proj_weight, layer_norm18, model_encoder_layers_9_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape74 = R.call_tir(cls.reshape, (lv645,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape75 = R.call_tir(cls.reshape1, (reshape72,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape76 = R.call_tir(cls.reshape1, (reshape73,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape77 = R.call_tir(cls.reshape1, (reshape74,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv13_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape75, reshape76, reshape77), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape78 = R.call_tir(cls.reshape10, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape79 = R.call_tir(cls.reshape11, (reshape78,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv646 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_9_self_attn_out_proj_weight, reshape79, model_encoder_layers_9_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add67 = R.call_tir(cls.add4, (lv16, lv646), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm19 = R.call_tir(cls.layer_norm1, (add67, model_encoder_layers_9_final_layer_norm_weight, model_encoder_layers_9_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_9_fc1_weight, layer_norm19, model_encoder_layers_9_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv647 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_9_fc2_weight, lv105, model_encoder_layers_9_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv17 = R.call_tir(cls.fused_add4_maximum_minimum, (add67, lv647), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm20 = R.call_tir(cls.layer_norm1, (lv17, model_encoder_layers_10_self_attn_layer_norm_weight, model_encoder_layers_10_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv648 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_q_proj_weight, layer_norm20, model_encoder_layers_10_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape80 = R.call_tir(cls.reshape, (lv648,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv141 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_10_self_attn_k_proj_weight, layer_norm20), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape81 = R.call_tir(cls.reshape, (lv141,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv649 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_v_proj_weight, layer_norm20, model_encoder_layers_10_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape82 = R.call_tir(cls.reshape, 
(lv649,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape83 = R.call_tir(cls.reshape1, (reshape80,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape84 = R.call_tir(cls.reshape1, (reshape81,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape85 = R.call_tir(cls.reshape1, (reshape82,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv14_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape83, reshape84, reshape85), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape86 = R.call_tir(cls.reshape10, (lv14_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape87 = R.call_tir(cls.reshape11, (reshape86,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv650 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_10_self_attn_out_proj_weight, reshape87, model_encoder_layers_10_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add74 = R.call_tir(cls.add4, (lv17, lv650), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm21 = R.call_tir(cls.layer_norm1, (add74, model_encoder_layers_10_final_layer_norm_weight, model_encoder_layers_10_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv106 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_10_fc1_weight, layer_norm21, model_encoder_layers_10_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv651 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_10_fc2_weight, lv106, model_encoder_layers_10_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv18 = R.call_tir(cls.fused_add4_maximum_minimum, (add74, lv651), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm22 = R.call_tir(cls.layer_norm1, (lv18, model_encoder_layers_11_self_attn_layer_norm_weight, model_encoder_layers_11_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv652 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_q_proj_weight, layer_norm22, model_encoder_layers_11_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape88 = R.call_tir(cls.reshape, (lv652,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv142 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_11_self_attn_k_proj_weight, layer_norm22), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape89 = R.call_tir(cls.reshape, (lv142,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv653 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_v_proj_weight, layer_norm22, model_encoder_layers_11_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape90 = R.call_tir(cls.reshape, (lv653,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape91 = R.call_tir(cls.reshape1, (reshape88,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape92 = R.call_tir(cls.reshape1, (reshape89,), out_sinfo=R.Tensor((batch_size * 
1500, 20, 64), dtype="float16")) reshape93 = R.call_tir(cls.reshape1, (reshape90,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv15_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape91, reshape92, reshape93), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape94 = R.call_tir(cls.reshape10, (lv15_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape95 = R.call_tir(cls.reshape11, (reshape94,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv654 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_11_self_attn_out_proj_weight, reshape95, model_encoder_layers_11_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add81 = R.call_tir(cls.add4, (lv18, lv654), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm23 = R.call_tir(cls.layer_norm1, (add81, model_encoder_layers_11_final_layer_norm_weight, model_encoder_layers_11_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_11_fc1_weight, layer_norm23, model_encoder_layers_11_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv655 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_11_fc2_weight, lv107, model_encoder_layers_11_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv19 = R.call_tir(cls.fused_add4_maximum_minimum, (add81, lv655), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm24 = R.call_tir(cls.layer_norm1, (lv19, model_encoder_layers_12_self_attn_layer_norm_weight, model_encoder_layers_12_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv656 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_q_proj_weight, layer_norm24, model_encoder_layers_12_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape96 = R.call_tir(cls.reshape, (lv656,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv143 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_12_self_attn_k_proj_weight, layer_norm24), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape97 = R.call_tir(cls.reshape, (lv143,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv657 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_v_proj_weight, layer_norm24, model_encoder_layers_12_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape98 = R.call_tir(cls.reshape, (lv657,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape99 = R.call_tir(cls.reshape1, (reshape96,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape100 = R.call_tir(cls.reshape1, (reshape97,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape101 = R.call_tir(cls.reshape1, (reshape98,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv16_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(12), 
R.prim_value(T.float32(1)), reshape99, reshape100, reshape101), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape102 = R.call_tir(cls.reshape10, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape103 = R.call_tir(cls.reshape11, (reshape102,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv658 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_12_self_attn_out_proj_weight, reshape103, model_encoder_layers_12_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add88 = R.call_tir(cls.add4, (lv19, lv658), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm25 = R.call_tir(cls.layer_norm1, (add88, model_encoder_layers_12_final_layer_norm_weight, model_encoder_layers_12_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_12_fc1_weight, layer_norm25, model_encoder_layers_12_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv659 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_12_fc2_weight, lv108, model_encoder_layers_12_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv20 = R.call_tir(cls.fused_add4_maximum_minimum, (add88, lv659), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm26 = R.call_tir(cls.layer_norm1, (lv20, model_encoder_layers_13_self_attn_layer_norm_weight, model_encoder_layers_13_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv660 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_q_proj_weight, layer_norm26, model_encoder_layers_13_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape104 = R.call_tir(cls.reshape, (lv660,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv144 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_13_self_attn_k_proj_weight, layer_norm26), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape105 = R.call_tir(cls.reshape, (lv144,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv661 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_v_proj_weight, layer_norm26, model_encoder_layers_13_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape106 = R.call_tir(cls.reshape, (lv661,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape107 = R.call_tir(cls.reshape1, (reshape104,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape108 = R.call_tir(cls.reshape1, (reshape105,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape109 = R.call_tir(cls.reshape1, (reshape106,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv17_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape107, reshape108, reshape109), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape110 = R.call_tir(cls.reshape10, (lv17_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape111 = 
R.call_tir(cls.reshape11, (reshape110,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv662 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_13_self_attn_out_proj_weight, reshape111, model_encoder_layers_13_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add95 = R.call_tir(cls.add4, (lv20, lv662), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm27 = R.call_tir(cls.layer_norm1, (add95, model_encoder_layers_13_final_layer_norm_weight, model_encoder_layers_13_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_13_fc1_weight, layer_norm27, model_encoder_layers_13_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv663 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_13_fc2_weight, lv109, model_encoder_layers_13_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv21 = R.call_tir(cls.fused_add4_maximum_minimum, (add95, lv663), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm28 = R.call_tir(cls.layer_norm1, (lv21, model_encoder_layers_14_self_attn_layer_norm_weight, model_encoder_layers_14_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv664 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_q_proj_weight, layer_norm28, model_encoder_layers_14_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape112 = R.call_tir(cls.reshape, (lv664,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv145 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_14_self_attn_k_proj_weight, layer_norm28), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape113 = R.call_tir(cls.reshape, (lv145,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv665 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_v_proj_weight, layer_norm28, model_encoder_layers_14_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape114 = R.call_tir(cls.reshape, (lv665,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape115 = R.call_tir(cls.reshape1, (reshape112,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape116 = R.call_tir(cls.reshape1, (reshape113,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape117 = R.call_tir(cls.reshape1, (reshape114,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv18_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape115, reshape116, reshape117), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape118 = R.call_tir(cls.reshape10, (lv18_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape119 = R.call_tir(cls.reshape11, (reshape118,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv666 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_14_self_attn_out_proj_weight, reshape119, 
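# Output projection merges the attended heads back to 1280 features; the cls.add4 that
# follows adds the result onto the block input (lv21) -- the residual connection.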
model_encoder_layers_14_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add102 = R.call_tir(cls.add4, (lv21, lv666), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm29 = R.call_tir(cls.layer_norm1, (add102, model_encoder_layers_14_final_layer_norm_weight, model_encoder_layers_14_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_14_fc1_weight, layer_norm29, model_encoder_layers_14_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv667 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_14_fc2_weight, lv110, model_encoder_layers_14_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv22 = R.call_tir(cls.fused_add4_maximum_minimum, (add102, lv667), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm30 = R.call_tir(cls.layer_norm1, (lv22, model_encoder_layers_15_self_attn_layer_norm_weight, model_encoder_layers_15_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv668 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_q_proj_weight, layer_norm30, model_encoder_layers_15_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape120 = R.call_tir(cls.reshape, (lv668,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv146 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_15_self_attn_k_proj_weight, layer_norm30), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape121 = R.call_tir(cls.reshape, (lv146,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv669 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_v_proj_weight, layer_norm30, model_encoder_layers_15_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape122 = R.call_tir(cls.reshape, (lv669,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape123 = R.call_tir(cls.reshape1, (reshape120,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape124 = R.call_tir(cls.reshape1, (reshape121,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape125 = R.call_tir(cls.reshape1, (reshape122,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv19_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape123, reshape124, reshape125), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape126 = R.call_tir(cls.reshape10, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape127 = R.call_tir(cls.reshape11, (reshape126,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv670 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_15_self_attn_out_proj_weight, reshape127, model_encoder_layers_15_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add109 = R.call_tir(cls.add4, (lv22, lv670), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm31 = 
R.call_tir(cls.layer_norm1, (add109, model_encoder_layers_15_final_layer_norm_weight, model_encoder_layers_15_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_15_fc1_weight, layer_norm31, model_encoder_layers_15_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv671 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_15_fc2_weight, lv111, model_encoder_layers_15_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv23 = R.call_tir(cls.fused_add4_maximum_minimum, (add109, lv671), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm32 = R.call_tir(cls.layer_norm1, (lv23, model_encoder_layers_16_self_attn_layer_norm_weight, model_encoder_layers_16_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv672 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_q_proj_weight, layer_norm32, model_encoder_layers_16_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape128 = R.call_tir(cls.reshape, (lv672,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv147 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_16_self_attn_k_proj_weight, layer_norm32), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape129 = R.call_tir(cls.reshape, (lv147,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv673 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_v_proj_weight, layer_norm32, model_encoder_layers_16_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape130 = R.call_tir(cls.reshape, (lv673,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape131 = R.call_tir(cls.reshape1, (reshape128,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape132 = R.call_tir(cls.reshape1, (reshape129,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape133 = R.call_tir(cls.reshape1, (reshape130,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv20_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape131, reshape132, reshape133), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape134 = R.call_tir(cls.reshape10, (lv20_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape135 = R.call_tir(cls.reshape11, (reshape134,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv674 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_16_self_attn_out_proj_weight, reshape135, model_encoder_layers_16_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add116 = R.call_tir(cls.add4, (lv23, lv674), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm33 = R.call_tir(cls.layer_norm1, (add116, model_encoder_layers_16_final_layer_norm_weight, model_encoder_layers_16_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv112 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_16_fc1_weight, layer_norm33, model_encoder_layers_16_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv675 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_16_fc2_weight, lv112, model_encoder_layers_16_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv24 = R.call_tir(cls.fused_add4_maximum_minimum, (add116, lv675), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm34 = R.call_tir(cls.layer_norm1, (lv24, model_encoder_layers_17_self_attn_layer_norm_weight, model_encoder_layers_17_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv676 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_q_proj_weight, layer_norm34, model_encoder_layers_17_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape136 = R.call_tir(cls.reshape, (lv676,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv148 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_17_self_attn_k_proj_weight, layer_norm34), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape137 = R.call_tir(cls.reshape, (lv148,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv677 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_v_proj_weight, layer_norm34, model_encoder_layers_17_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape138 = R.call_tir(cls.reshape, (lv677,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape139 = R.call_tir(cls.reshape1, (reshape136,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape140 = R.call_tir(cls.reshape1, (reshape137,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape141 = R.call_tir(cls.reshape1, (reshape138,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv21_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape139, reshape140, reshape141), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape142 = R.call_tir(cls.reshape10, (lv21_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape143 = R.call_tir(cls.reshape11, (reshape142,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv678 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_17_self_attn_out_proj_weight, reshape143, model_encoder_layers_17_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add123 = R.call_tir(cls.add4, (lv24, lv678), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm35 = R.call_tir(cls.layer_norm1, (add123, model_encoder_layers_17_final_layer_norm_weight, model_encoder_layers_17_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_17_fc1_weight, layer_norm35, model_encoder_layers_17_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv679 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_17_fc2_weight, lv113, model_encoder_layers_17_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv25 = R.call_tir(cls.fused_add4_maximum_minimum, (add123, lv679), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm36 = R.call_tir(cls.layer_norm1, (lv25, model_encoder_layers_18_self_attn_layer_norm_weight, model_encoder_layers_18_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv680 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_q_proj_weight, layer_norm36, model_encoder_layers_18_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape144 = R.call_tir(cls.reshape, (lv680,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv149 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_18_self_attn_k_proj_weight, layer_norm36), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape145 = R.call_tir(cls.reshape, (lv149,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv681 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_v_proj_weight, layer_norm36, model_encoder_layers_18_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape146 = R.call_tir(cls.reshape, (lv681,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape147 = R.call_tir(cls.reshape1, (reshape144,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape148 = R.call_tir(cls.reshape1, (reshape145,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape149 = R.call_tir(cls.reshape1, (reshape146,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv22_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape147, reshape148, reshape149), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape150 = R.call_tir(cls.reshape10, (lv22_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape151 = R.call_tir(cls.reshape11, (reshape150,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv682 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_18_self_attn_out_proj_weight, reshape151, model_encoder_layers_18_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add130 = R.call_tir(cls.add4, (lv25, lv682), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm37 = R.call_tir(cls.layer_norm1, (add130, model_encoder_layers_18_final_layer_norm_weight, model_encoder_layers_18_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_18_fc1_weight, layer_norm37, model_encoder_layers_18_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv683 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_18_fc2_weight, lv114, model_encoder_layers_18_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv26 = 
R.call_tir(cls.fused_add4_maximum_minimum, (add130, lv683), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm38 = R.call_tir(cls.layer_norm1, (lv26, model_encoder_layers_19_self_attn_layer_norm_weight, model_encoder_layers_19_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv684 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_q_proj_weight, layer_norm38, model_encoder_layers_19_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape152 = R.call_tir(cls.reshape, (lv684,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv150 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_19_self_attn_k_proj_weight, layer_norm38), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape153 = R.call_tir(cls.reshape, (lv150,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv685 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_v_proj_weight, layer_norm38, model_encoder_layers_19_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape154 = R.call_tir(cls.reshape, (lv685,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape155 = R.call_tir(cls.reshape1, (reshape152,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape156 = R.call_tir(cls.reshape1, (reshape153,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape157 = R.call_tir(cls.reshape1, (reshape154,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv23_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape155, reshape156, reshape157), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape158 = R.call_tir(cls.reshape10, (lv23_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape159 = R.call_tir(cls.reshape11, (reshape158,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv686 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_19_self_attn_out_proj_weight, reshape159, model_encoder_layers_19_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add137 = R.call_tir(cls.add4, (lv26, lv686), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm39 = R.call_tir(cls.layer_norm1, (add137, model_encoder_layers_19_final_layer_norm_weight, model_encoder_layers_19_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_19_fc1_weight, layer_norm39, model_encoder_layers_19_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv687 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_19_fc2_weight, lv115, model_encoder_layers_19_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv27 = R.call_tir(cls.fused_add4_maximum_minimum, (add137, lv687), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm40 = R.call_tir(cls.layer_norm1, (lv27, model_encoder_layers_20_self_attn_layer_norm_weight, 
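# Layers repeat this pattern verbatim; only the parameter names
# (model_encoder_layers_N_*) and the KV-cache layer index advance from layer to layer.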
model_encoder_layers_20_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv688 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_q_proj_weight, layer_norm40, model_encoder_layers_20_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape160 = R.call_tir(cls.reshape, (lv688,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv151 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_20_self_attn_k_proj_weight, layer_norm40), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape161 = R.call_tir(cls.reshape, (lv151,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv689 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_v_proj_weight, layer_norm40, model_encoder_layers_20_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape162 = R.call_tir(cls.reshape, (lv689,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape163 = R.call_tir(cls.reshape1, (reshape160,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape164 = R.call_tir(cls.reshape1, (reshape161,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape165 = R.call_tir(cls.reshape1, (reshape162,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv24_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape163, reshape164, reshape165), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape166 = R.call_tir(cls.reshape10, (lv24_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape167 = R.call_tir(cls.reshape11, (reshape166,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv690 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_20_self_attn_out_proj_weight, reshape167, model_encoder_layers_20_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add144 = R.call_tir(cls.add4, (lv27, lv690), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm41 = R.call_tir(cls.layer_norm1, (add144, model_encoder_layers_20_final_layer_norm_weight, model_encoder_layers_20_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_20_fc1_weight, layer_norm41, model_encoder_layers_20_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv691 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_20_fc2_weight, lv116, model_encoder_layers_20_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv28 = R.call_tir(cls.fused_add4_maximum_minimum, (add144, lv691), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm42 = R.call_tir(cls.layer_norm1, (lv28, model_encoder_layers_21_self_attn_layer_norm_weight, model_encoder_layers_21_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv692 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_q_proj_weight, 
layer_norm42, model_encoder_layers_21_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape168 = R.call_tir(cls.reshape, (lv692,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv152 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_21_self_attn_k_proj_weight, layer_norm42), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape169 = R.call_tir(cls.reshape, (lv152,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv693 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_v_proj_weight, layer_norm42, model_encoder_layers_21_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape170 = R.call_tir(cls.reshape, (lv693,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape171 = R.call_tir(cls.reshape1, (reshape168,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape172 = R.call_tir(cls.reshape1, (reshape169,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape173 = R.call_tir(cls.reshape1, (reshape170,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv25_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape171, reshape172, reshape173), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape174 = R.call_tir(cls.reshape10, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape175 = R.call_tir(cls.reshape11, (reshape174,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv694 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_21_self_attn_out_proj_weight, reshape175, model_encoder_layers_21_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add151 = R.call_tir(cls.add4, (lv28, lv694), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm43 = R.call_tir(cls.layer_norm1, (add151, model_encoder_layers_21_final_layer_norm_weight, model_encoder_layers_21_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_21_fc1_weight, layer_norm43, model_encoder_layers_21_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv695 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_21_fc2_weight, lv117, model_encoder_layers_21_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv29 = R.call_tir(cls.fused_add4_maximum_minimum, (add151, lv695), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm44 = R.call_tir(cls.layer_norm1, (lv29, model_encoder_layers_22_self_attn_layer_norm_weight, model_encoder_layers_22_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv696 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_q_proj_weight, layer_norm44, model_encoder_layers_22_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape176 = R.call_tir(cls.reshape, (lv696,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv153 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_22_self_attn_k_proj_weight, layer_norm44), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape177 = R.call_tir(cls.reshape, (lv153,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv697 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_v_proj_weight, layer_norm44, model_encoder_layers_22_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape178 = R.call_tir(cls.reshape, (lv697,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape179 = R.call_tir(cls.reshape1, (reshape176,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape180 = R.call_tir(cls.reshape1, (reshape177,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape181 = R.call_tir(cls.reshape1, (reshape178,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv26_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape179, reshape180, reshape181), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape182 = R.call_tir(cls.reshape10, (lv26_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape183 = R.call_tir(cls.reshape11, (reshape182,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv698 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_22_self_attn_out_proj_weight, reshape183, model_encoder_layers_22_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add158 = R.call_tir(cls.add4, (lv29, lv698), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm45 = R.call_tir(cls.layer_norm1, (add158, model_encoder_layers_22_final_layer_norm_weight, model_encoder_layers_22_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_22_fc1_weight, layer_norm45, model_encoder_layers_22_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv699 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_22_fc2_weight, lv118, model_encoder_layers_22_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv30 = R.call_tir(cls.fused_add4_maximum_minimum, (add158, lv699), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm46 = R.call_tir(cls.layer_norm1, (lv30, model_encoder_layers_23_self_attn_layer_norm_weight, model_encoder_layers_23_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv700 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_q_proj_weight, layer_norm46, model_encoder_layers_23_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape184 = R.call_tir(cls.reshape, (lv700,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv154 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_23_self_attn_k_proj_weight, layer_norm46), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape185 = R.call_tir(cls.reshape, (lv154,), 
out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv701 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_v_proj_weight, layer_norm46, model_encoder_layers_23_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape186 = R.call_tir(cls.reshape, (lv701,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape187 = R.call_tir(cls.reshape1, (reshape184,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape188 = R.call_tir(cls.reshape1, (reshape185,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape189 = R.call_tir(cls.reshape1, (reshape186,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv27_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape187, reshape188, reshape189), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape190 = R.call_tir(cls.reshape10, (lv27_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape191 = R.call_tir(cls.reshape11, (reshape190,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv702 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_23_self_attn_out_proj_weight, reshape191, model_encoder_layers_23_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add165 = R.call_tir(cls.add4, (lv30, lv702), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm47 = R.call_tir(cls.layer_norm1, (add165, model_encoder_layers_23_final_layer_norm_weight, model_encoder_layers_23_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_23_fc1_weight, layer_norm47, model_encoder_layers_23_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv703 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_23_fc2_weight, lv119, model_encoder_layers_23_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv31 = R.call_tir(cls.fused_add4_maximum_minimum, (add165, lv703), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm48 = R.call_tir(cls.layer_norm1, (lv31, model_encoder_layers_24_self_attn_layer_norm_weight, model_encoder_layers_24_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv704 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_q_proj_weight, layer_norm48, model_encoder_layers_24_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape192 = R.call_tir(cls.reshape, (lv704,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv155 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_24_self_attn_k_proj_weight, layer_norm48), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape193 = R.call_tir(cls.reshape, (lv155,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv705 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_v_proj_weight, layer_norm48, 
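# Two dispatch mechanisms alternate throughout this function: R.call_dps_packed
# invokes externally compiled kernels (the "*_cublas" names) in
# destination-passing style, writing into a fresh buffer of the out_sinfo
# shape, while R.call_tir invokes TIR PrimFuncs compiled within this module
# (cls.reshape*, cls.layer_norm1, cls.add4, cls.fused_add4_maximum_minimum).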
model_encoder_layers_24_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape194 = R.call_tir(cls.reshape, (lv705,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape195 = R.call_tir(cls.reshape1, (reshape192,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape196 = R.call_tir(cls.reshape1, (reshape193,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape197 = R.call_tir(cls.reshape1, (reshape194,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv28_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape195, reshape196, reshape197), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape198 = R.call_tir(cls.reshape10, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape199 = R.call_tir(cls.reshape11, (reshape198,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv706 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_24_self_attn_out_proj_weight, reshape199, model_encoder_layers_24_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add172 = R.call_tir(cls.add4, (lv31, lv706), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm49 = R.call_tir(cls.layer_norm1, (add172, model_encoder_layers_24_final_layer_norm_weight, model_encoder_layers_24_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv120 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_24_fc1_weight, layer_norm49, model_encoder_layers_24_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv707 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_24_fc2_weight, lv120, model_encoder_layers_24_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv32 = R.call_tir(cls.fused_add4_maximum_minimum, (add172, lv707), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm50 = R.call_tir(cls.layer_norm1, (lv32, model_encoder_layers_25_self_attn_layer_norm_weight, model_encoder_layers_25_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv708 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_q_proj_weight, layer_norm50, model_encoder_layers_25_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape200 = R.call_tir(cls.reshape, (lv708,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv156 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_25_self_attn_k_proj_weight, layer_norm50), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape201 = R.call_tir(cls.reshape, (lv156,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv709 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_v_proj_weight, layer_norm50, model_encoder_layers_25_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape202 = R.call_tir(cls.reshape, (lv709,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape203 = 
R.call_tir(cls.reshape1, (reshape200,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape204 = R.call_tir(cls.reshape1, (reshape201,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape205 = R.call_tir(cls.reshape1, (reshape202,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv29_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape203, reshape204, reshape205), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape206 = R.call_tir(cls.reshape10, (lv29_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape207 = R.call_tir(cls.reshape11, (reshape206,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv710 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_25_self_attn_out_proj_weight, reshape207, model_encoder_layers_25_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add179 = R.call_tir(cls.add4, (lv32, lv710), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm51 = R.call_tir(cls.layer_norm1, (add179, model_encoder_layers_25_final_layer_norm_weight, model_encoder_layers_25_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_25_fc1_weight, layer_norm51, model_encoder_layers_25_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv711 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_25_fc2_weight, lv121, model_encoder_layers_25_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv33 = R.call_tir(cls.fused_add4_maximum_minimum, (add179, lv711), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm52 = R.call_tir(cls.layer_norm1, (lv33, model_encoder_layers_26_self_attn_layer_norm_weight, model_encoder_layers_26_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv712 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_q_proj_weight, layer_norm52, model_encoder_layers_26_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape208 = R.call_tir(cls.reshape, (lv712,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv157 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_26_self_attn_k_proj_weight, layer_norm52), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape209 = R.call_tir(cls.reshape, (lv157,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv713 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_v_proj_weight, layer_norm52, model_encoder_layers_26_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape210 = R.call_tir(cls.reshape, (lv713,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape211 = R.call_tir(cls.reshape1, (reshape208,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape212 = R.call_tir(cls.reshape1, (reshape209,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape213 = 
R.call_tir(cls.reshape1, (reshape210,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv30_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape211, reshape212, reshape213), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape214 = R.call_tir(cls.reshape10, (lv30_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape215 = R.call_tir(cls.reshape11, (reshape214,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv714 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_26_self_attn_out_proj_weight, reshape215, model_encoder_layers_26_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add186 = R.call_tir(cls.add4, (lv33, lv714), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm53 = R.call_tir(cls.layer_norm1, (add186, model_encoder_layers_26_final_layer_norm_weight, model_encoder_layers_26_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_26_fc1_weight, layer_norm53, model_encoder_layers_26_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv715 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_26_fc2_weight, lv122, model_encoder_layers_26_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv34 = R.call_tir(cls.fused_add4_maximum_minimum, (add186, lv715), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm54 = R.call_tir(cls.layer_norm1, (lv34, model_encoder_layers_27_self_attn_layer_norm_weight, model_encoder_layers_27_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv716 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_q_proj_weight, layer_norm54, model_encoder_layers_27_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape216 = R.call_tir(cls.reshape, (lv716,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv158 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_27_self_attn_k_proj_weight, layer_norm54), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape217 = R.call_tir(cls.reshape, (lv158,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv717 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_v_proj_weight, layer_norm54, model_encoder_layers_27_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape218 = R.call_tir(cls.reshape, (lv717,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape219 = R.call_tir(cls.reshape1, (reshape216,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape220 = R.call_tir(cls.reshape1, (reshape217,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape221 = R.call_tir(cls.reshape1, (reshape218,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv31_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), 
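# vm.builtin.attention_kv_cache_attention_no_append runs attention over the
# given q/k/v without appending k/v to the cache, which fits encoder
# self-attention: nothing needs to persist across calls. Its scalar arguments
# are the layer index (27 here) and what appears to be a score-scaling factor;
# it is 1.0 in every call, suggesting the usual 1/sqrt(64) softmax scaling is
# applied inside the kernel or folded elsewhere (not verifiable from this
# dump alone).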
reshape219, reshape220, reshape221), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape222 = R.call_tir(cls.reshape10, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape223 = R.call_tir(cls.reshape11, (reshape222,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv718 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_27_self_attn_out_proj_weight, reshape223, model_encoder_layers_27_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add193 = R.call_tir(cls.add4, (lv34, lv718), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm55 = R.call_tir(cls.layer_norm1, (add193, model_encoder_layers_27_final_layer_norm_weight, model_encoder_layers_27_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_27_fc1_weight, layer_norm55, model_encoder_layers_27_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv719 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_27_fc2_weight, lv123, model_encoder_layers_27_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv35 = R.call_tir(cls.fused_add4_maximum_minimum, (add193, lv719), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm56 = R.call_tir(cls.layer_norm1, (lv35, model_encoder_layers_28_self_attn_layer_norm_weight, model_encoder_layers_28_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv720 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_q_proj_weight, layer_norm56, model_encoder_layers_28_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape224 = R.call_tir(cls.reshape, (lv720,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv159 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_28_self_attn_k_proj_weight, layer_norm56), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape225 = R.call_tir(cls.reshape, (lv159,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv721 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_v_proj_weight, layer_norm56, model_encoder_layers_28_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape226 = R.call_tir(cls.reshape, (lv721,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape227 = R.call_tir(cls.reshape1, (reshape224,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape228 = R.call_tir(cls.reshape1, (reshape225,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape229 = R.call_tir(cls.reshape1, (reshape226,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv32_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape227, reshape228, reshape229), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape230 = R.call_tir(cls.reshape10, (lv32_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape231 = R.call_tir(cls.reshape11, 
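# cls.reshape10 / cls.reshape11 invert the head split once attention is done:
# (batch_size * 1500, 20, 64) -> (batch_size, 1500, 20, 64) ->
# (batch_size, 1500, 1280), restoring the layout the output projection expects.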
(reshape230,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv722 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_28_self_attn_out_proj_weight, reshape231, model_encoder_layers_28_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add200 = R.call_tir(cls.add4, (lv35, lv722), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm57 = R.call_tir(cls.layer_norm1, (add200, model_encoder_layers_28_final_layer_norm_weight, model_encoder_layers_28_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv124 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_28_fc1_weight, layer_norm57, model_encoder_layers_28_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv723 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_28_fc2_weight, lv124, model_encoder_layers_28_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv36 = R.call_tir(cls.fused_add4_maximum_minimum, (add200, lv723), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm58 = R.call_tir(cls.layer_norm1, (lv36, model_encoder_layers_29_self_attn_layer_norm_weight, model_encoder_layers_29_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv724 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_q_proj_weight, layer_norm58, model_encoder_layers_29_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape232 = R.call_tir(cls.reshape, (lv724,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv160 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_29_self_attn_k_proj_weight, layer_norm58), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape233 = R.call_tir(cls.reshape, (lv160,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv725 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_v_proj_weight, layer_norm58, model_encoder_layers_29_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape234 = R.call_tir(cls.reshape, (lv725,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape235 = R.call_tir(cls.reshape1, (reshape232,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape236 = R.call_tir(cls.reshape1, (reshape233,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape237 = R.call_tir(cls.reshape1, (reshape234,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv33_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape235, reshape236, reshape237), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape238 = R.call_tir(cls.reshape10, (lv33_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape239 = R.call_tir(cls.reshape11, (reshape238,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv726 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_29_self_attn_out_proj_weight, reshape239, 
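# Two residual styles alternate in every layer: cls.add4 after attention and
# cls.fused_add4_maximum_minimum after the FFN. The fused maximum/minimum pair
# reads as a clamp on the float16 activations, presumably to keep them finite;
# the clamp constants live inside the TIR kernel and are not visible here. A
# rough sketch of the operation (the 65504 limit, float16 max, is an
# assumption, not taken from this module):
#   import numpy as np
#   def clamped_residual(x, y, limit=np.float16(65504)):
#       return np.minimum(np.maximum(x + y, -limit), limit)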
model_encoder_layers_29_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add207 = R.call_tir(cls.add4, (lv36, lv726), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm59 = R.call_tir(cls.layer_norm1, (add207, model_encoder_layers_29_final_layer_norm_weight, model_encoder_layers_29_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_29_fc1_weight, layer_norm59, model_encoder_layers_29_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv727 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_29_fc2_weight, lv125, model_encoder_layers_29_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv37 = R.call_tir(cls.fused_add4_maximum_minimum, (add207, lv727), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm60 = R.call_tir(cls.layer_norm1, (lv37, model_encoder_layers_30_self_attn_layer_norm_weight, model_encoder_layers_30_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv728 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_q_proj_weight, layer_norm60, model_encoder_layers_30_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape240 = R.call_tir(cls.reshape, (lv728,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv161 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_30_self_attn_k_proj_weight, layer_norm60), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape241 = R.call_tir(cls.reshape, (lv161,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv729 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_v_proj_weight, layer_norm60, model_encoder_layers_30_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape242 = R.call_tir(cls.reshape, (lv729,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape243 = R.call_tir(cls.reshape1, (reshape240,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape244 = R.call_tir(cls.reshape1, (reshape241,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape245 = R.call_tir(cls.reshape1, (reshape242,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv34_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape243, reshape244, reshape245), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape246 = R.call_tir(cls.reshape10, (lv34_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape247 = R.call_tir(cls.reshape11, (reshape246,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv730 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_30_self_attn_out_proj_weight, reshape247, model_encoder_layers_30_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add214 = R.call_tir(cls.add4, (lv37, lv730), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm61 = 
R.call_tir(cls.layer_norm1, (add214, model_encoder_layers_30_final_layer_norm_weight, model_encoder_layers_30_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_30_fc1_weight, layer_norm61, model_encoder_layers_30_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv731 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_30_fc2_weight, lv126, model_encoder_layers_30_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv38 = R.call_tir(cls.fused_add4_maximum_minimum, (add214, lv731), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm62 = R.call_tir(cls.layer_norm1, (lv38, model_encoder_layers_31_self_attn_layer_norm_weight, model_encoder_layers_31_self_attn_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv732 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_q_proj_weight, layer_norm62, model_encoder_layers_31_self_attn_q_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape248 = R.call_tir(cls.reshape, (lv732,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv162 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_cublas", (model_encoder_layers_31_self_attn_k_proj_weight, layer_norm62), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape249 = R.call_tir(cls.reshape, (lv162,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) lv733 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_v_proj_weight, layer_norm62, model_encoder_layers_31_self_attn_v_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) reshape250 = R.call_tir(cls.reshape, (lv733,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape251 = R.call_tir(cls.reshape1, (reshape248,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape252 = R.call_tir(cls.reshape1, (reshape249,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape253 = R.call_tir(cls.reshape1, (reshape250,), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) lv35_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_no_append", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape251, reshape252, reshape253), out_sinfo=R.Tensor((batch_size * 1500, 20, 64), dtype="float16")) reshape254 = R.call_tir(cls.reshape10, (lv35_1,), out_sinfo=R.Tensor((batch_size, 1500, 20, 64), dtype="float16")) reshape255 = R.call_tir(cls.reshape11, (reshape254,), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv734 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_cublas", (model_encoder_layers_31_self_attn_out_proj_weight, reshape255, model_encoder_layers_31_self_attn_out_proj_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) add221 = R.call_tir(cls.add4, (lv38, lv734), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) layer_norm63 = R.call_tir(cls.layer_norm1, (add221, model_encoder_layers_31_final_layer_norm_weight, model_encoder_layers_31_final_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv127 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu2_cublas", (model_encoder_layers_31_fc1_weight, layer_norm63, model_encoder_layers_31_fc1_bias), out_sinfo=R.Tensor((batch_size, 1500, 5120), dtype="float16")) lv735 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add5_cublas", (model_encoder_layers_31_fc2_weight, lv127, model_encoder_layers_31_fc2_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) lv39 = R.call_tir(cls.fused_add4_maximum_minimum, (add221, lv735), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) gv = R.call_tir(cls.layer_norm1, (lv39, model_encoder_layer_norm_weight, model_encoder_layer_norm_bias), out_sinfo=R.Tensor((batch_size, 1500, 1280), dtype="float16")) R.output(gv) return gv @R.function def batch_prefill(input_ids: R.Tensor((1, "seq_len"), dtype="int32"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, "batch_size", 51866), dtype="float32"): batch_size = T.int64() seq_len = T.int64() R.func_attr({"num_input": 3, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): model_decoder_embed_tokens_weight2: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] model_decoder_embed_positions_weight2: R.Tensor((448, 1280), dtype="float16") = packed_params[488] model_decoder_layers_0_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] model_decoder_layers_0_self_attn_v_proj_weight2: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[490] model_decoder_layers_0_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[491] model_decoder_layers_0_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] model_decoder_layers_0_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[493] model_decoder_layers_0_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] model_decoder_layers_0_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[495] model_decoder_layers_0_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[496] model_decoder_layers_0_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[497] model_decoder_layers_0_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] model_decoder_layers_0_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[502] model_decoder_layers_0_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] model_decoder_layers_0_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[504] model_decoder_layers_0_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[505] model_decoder_layers_0_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[506] model_decoder_layers_0_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] model_decoder_layers_0_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[508] model_decoder_layers_0_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] model_decoder_layers_0_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[510] model_decoder_layers_0_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[511] model_decoder_layers_0_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[512] model_decoder_layers_1_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] model_decoder_layers_1_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] model_decoder_layers_1_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[515] model_decoder_layers_1_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] model_decoder_layers_1_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[517] model_decoder_layers_1_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] model_decoder_layers_1_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[519] model_decoder_layers_1_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[520] model_decoder_layers_1_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[521] model_decoder_layers_1_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] model_decoder_layers_1_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[526] model_decoder_layers_1_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] model_decoder_layers_1_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[528] model_decoder_layers_1_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = 
packed_params[529] model_decoder_layers_1_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[530] model_decoder_layers_1_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] model_decoder_layers_1_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[532] model_decoder_layers_1_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] model_decoder_layers_1_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[534] model_decoder_layers_1_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[535] model_decoder_layers_1_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[536] model_decoder_layers_2_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] model_decoder_layers_2_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] model_decoder_layers_2_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[539] model_decoder_layers_2_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] model_decoder_layers_2_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[541] model_decoder_layers_2_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] model_decoder_layers_2_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[543] model_decoder_layers_2_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[544] model_decoder_layers_2_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[545] model_decoder_layers_2_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] model_decoder_layers_2_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[550] model_decoder_layers_2_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] model_decoder_layers_2_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[552] model_decoder_layers_2_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[553] model_decoder_layers_2_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[554] model_decoder_layers_2_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] model_decoder_layers_2_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[556] model_decoder_layers_2_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] model_decoder_layers_2_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[558] model_decoder_layers_2_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[559] model_decoder_layers_2_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[560] model_decoder_layers_3_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] model_decoder_layers_3_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] model_decoder_layers_3_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[563] model_decoder_layers_3_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] model_decoder_layers_3_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[565] model_decoder_layers_3_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[566] model_decoder_layers_3_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[567] model_decoder_layers_3_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[568] model_decoder_layers_3_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[569] model_decoder_layers_3_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] model_decoder_layers_3_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[574] model_decoder_layers_3_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] model_decoder_layers_3_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[576] model_decoder_layers_3_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[577] model_decoder_layers_3_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[578] model_decoder_layers_3_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] model_decoder_layers_3_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[580] model_decoder_layers_3_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] model_decoder_layers_3_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[582] model_decoder_layers_3_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[583] model_decoder_layers_3_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[584] model_decoder_layers_4_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] model_decoder_layers_4_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] model_decoder_layers_4_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[587] model_decoder_layers_4_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] model_decoder_layers_4_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[589] model_decoder_layers_4_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] model_decoder_layers_4_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[591] model_decoder_layers_4_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[592] model_decoder_layers_4_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[593] model_decoder_layers_4_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] model_decoder_layers_4_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[598] model_decoder_layers_4_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] model_decoder_layers_4_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[600] model_decoder_layers_4_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[601] model_decoder_layers_4_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[602] model_decoder_layers_4_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] model_decoder_layers_4_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[604] model_decoder_layers_4_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] model_decoder_layers_4_fc2_bias2: R.Tensor((1280,), 
dtype="float16") = packed_params[606] model_decoder_layers_4_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[607] model_decoder_layers_4_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[608] model_decoder_layers_5_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] model_decoder_layers_5_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] model_decoder_layers_5_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[611] model_decoder_layers_5_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] model_decoder_layers_5_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[613] model_decoder_layers_5_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] model_decoder_layers_5_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[615] model_decoder_layers_5_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[616] model_decoder_layers_5_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[617] model_decoder_layers_5_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[621] model_decoder_layers_5_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[622] model_decoder_layers_5_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] model_decoder_layers_5_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[624] model_decoder_layers_5_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[625] model_decoder_layers_5_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[626] model_decoder_layers_5_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] model_decoder_layers_5_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[628] model_decoder_layers_5_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] model_decoder_layers_5_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[630] model_decoder_layers_5_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[631] model_decoder_layers_5_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[632] model_decoder_layers_6_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] model_decoder_layers_6_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] model_decoder_layers_6_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[635] model_decoder_layers_6_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[636] model_decoder_layers_6_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[637] model_decoder_layers_6_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] model_decoder_layers_6_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[639] model_decoder_layers_6_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[640] model_decoder_layers_6_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[641] model_decoder_layers_6_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] 
model_decoder_layers_6_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[646] model_decoder_layers_6_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] model_decoder_layers_6_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[648] model_decoder_layers_6_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[649] model_decoder_layers_6_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[650] model_decoder_layers_6_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] model_decoder_layers_6_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[652] model_decoder_layers_6_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] model_decoder_layers_6_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[654] model_decoder_layers_6_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[655] model_decoder_layers_6_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[656] model_decoder_layers_7_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] model_decoder_layers_7_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[658] model_decoder_layers_7_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[659] model_decoder_layers_7_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] model_decoder_layers_7_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[661] model_decoder_layers_7_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] model_decoder_layers_7_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[663] model_decoder_layers_7_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[664] model_decoder_layers_7_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[665] model_decoder_layers_7_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] model_decoder_layers_7_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[670] model_decoder_layers_7_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] model_decoder_layers_7_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[672] model_decoder_layers_7_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[673] model_decoder_layers_7_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[674] model_decoder_layers_7_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] model_decoder_layers_7_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[676] model_decoder_layers_7_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] model_decoder_layers_7_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[678] model_decoder_layers_7_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[679] model_decoder_layers_7_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[680] model_decoder_layers_8_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] model_decoder_layers_8_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[682] model_decoder_layers_8_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[683] model_decoder_layers_8_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] model_decoder_layers_8_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[685] model_decoder_layers_8_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] model_decoder_layers_8_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[687] model_decoder_layers_8_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[688] model_decoder_layers_8_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[689] model_decoder_layers_8_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] model_decoder_layers_8_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[694] model_decoder_layers_8_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] model_decoder_layers_8_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[696] model_decoder_layers_8_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[697] model_decoder_layers_8_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[698] model_decoder_layers_8_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] model_decoder_layers_8_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[700] model_decoder_layers_8_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[701] model_decoder_layers_8_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[702] model_decoder_layers_8_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[703] model_decoder_layers_8_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[704] model_decoder_layers_9_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] model_decoder_layers_9_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] model_decoder_layers_9_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[707] model_decoder_layers_9_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[708] model_decoder_layers_9_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[709] model_decoder_layers_9_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] model_decoder_layers_9_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[711] model_decoder_layers_9_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[712] model_decoder_layers_9_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[713] model_decoder_layers_9_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] model_decoder_layers_9_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[718] model_decoder_layers_9_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] model_decoder_layers_9_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[720] model_decoder_layers_9_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[721] 
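# Note (hedged): the trailing "2" on these binding names looks like a printer
# disambiguation suffix for this function's local aliases of the shared decoder
# weights -- an assumption; each name is simply a typed view of one
# packed_params entry.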
model_decoder_layers_9_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[722] model_decoder_layers_9_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] model_decoder_layers_9_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[724] model_decoder_layers_9_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] model_decoder_layers_9_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[726] model_decoder_layers_9_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[727] model_decoder_layers_9_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[728] model_decoder_layers_10_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] model_decoder_layers_10_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] model_decoder_layers_10_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[731] model_decoder_layers_10_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] model_decoder_layers_10_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[733] model_decoder_layers_10_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] model_decoder_layers_10_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[735] model_decoder_layers_10_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[736] model_decoder_layers_10_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[737] model_decoder_layers_10_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] model_decoder_layers_10_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[742] model_decoder_layers_10_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[743] model_decoder_layers_10_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[744] model_decoder_layers_10_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[745] model_decoder_layers_10_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[746] model_decoder_layers_10_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] model_decoder_layers_10_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[748] model_decoder_layers_10_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] model_decoder_layers_10_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[750] model_decoder_layers_10_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[751] model_decoder_layers_10_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[752] model_decoder_layers_11_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] model_decoder_layers_11_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] model_decoder_layers_11_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[755] model_decoder_layers_11_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] model_decoder_layers_11_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[757] model_decoder_layers_11_self_attn_out_proj_weight2: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[758] model_decoder_layers_11_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[759] model_decoder_layers_11_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[760] model_decoder_layers_11_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[761] model_decoder_layers_11_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[765] model_decoder_layers_11_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[766] model_decoder_layers_11_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] model_decoder_layers_11_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[768] model_decoder_layers_11_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[769] model_decoder_layers_11_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[770] model_decoder_layers_11_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] model_decoder_layers_11_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[772] model_decoder_layers_11_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] model_decoder_layers_11_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[774] model_decoder_layers_11_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[775] model_decoder_layers_11_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[776] model_decoder_layers_12_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] model_decoder_layers_12_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] model_decoder_layers_12_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[779] model_decoder_layers_12_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] model_decoder_layers_12_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[781] model_decoder_layers_12_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] model_decoder_layers_12_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[783] model_decoder_layers_12_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[784] model_decoder_layers_12_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[785] model_decoder_layers_12_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] model_decoder_layers_12_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[790] model_decoder_layers_12_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] model_decoder_layers_12_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[792] model_decoder_layers_12_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[793] model_decoder_layers_12_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[794] model_decoder_layers_12_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] model_decoder_layers_12_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[796] model_decoder_layers_12_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] 
model_decoder_layers_12_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[798] model_decoder_layers_12_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[799] model_decoder_layers_12_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[800] model_decoder_layers_13_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[801] model_decoder_layers_13_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[802] model_decoder_layers_13_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[803] model_decoder_layers_13_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[804] model_decoder_layers_13_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[805] model_decoder_layers_13_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[806] model_decoder_layers_13_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[807] model_decoder_layers_13_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[808] model_decoder_layers_13_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[809] model_decoder_layers_13_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[813] model_decoder_layers_13_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[814] model_decoder_layers_13_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[815] model_decoder_layers_13_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[816] model_decoder_layers_13_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[817] model_decoder_layers_13_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[818] model_decoder_layers_13_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[819] model_decoder_layers_13_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[820] model_decoder_layers_13_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[821] model_decoder_layers_13_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[822] model_decoder_layers_13_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[823] model_decoder_layers_13_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[824] model_decoder_layers_14_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[825] model_decoder_layers_14_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[826] model_decoder_layers_14_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[827] model_decoder_layers_14_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[828] model_decoder_layers_14_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[829] model_decoder_layers_14_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[830] model_decoder_layers_14_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[831] model_decoder_layers_14_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[832] model_decoder_layers_14_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[833] 
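# Note: all parameters bound in this function are float16, while the declared
# result is float32 of shape (1, batch_size, 51866); presumably the final
# vocabulary projection is computed or cast in float32 for accuracy -- the cast
# itself is not visible in this excerpt.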
model_decoder_layers_14_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[837] model_decoder_layers_14_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[838] model_decoder_layers_14_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[839] model_decoder_layers_14_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[840] model_decoder_layers_14_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[841] model_decoder_layers_14_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[842] model_decoder_layers_14_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[843] model_decoder_layers_14_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[844] model_decoder_layers_14_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[845] model_decoder_layers_14_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[846] model_decoder_layers_14_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[847] model_decoder_layers_14_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[848] model_decoder_layers_15_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[849] model_decoder_layers_15_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[850] model_decoder_layers_15_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[851] model_decoder_layers_15_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[852] model_decoder_layers_15_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[853] model_decoder_layers_15_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[854] model_decoder_layers_15_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[855] model_decoder_layers_15_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[856] model_decoder_layers_15_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[857] model_decoder_layers_15_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[861] model_decoder_layers_15_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[862] model_decoder_layers_15_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[863] model_decoder_layers_15_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[864] model_decoder_layers_15_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[865] model_decoder_layers_15_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[866] model_decoder_layers_15_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[867] model_decoder_layers_15_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[868] model_decoder_layers_15_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[869] model_decoder_layers_15_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[870] model_decoder_layers_15_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[871] model_decoder_layers_15_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[872] model_decoder_layers_16_self_attn_k_proj_weight2: 
R.Tensor((1280, 1280), dtype="float16") = packed_params[873] model_decoder_layers_16_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[874] model_decoder_layers_16_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[875] model_decoder_layers_16_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[876] model_decoder_layers_16_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[877] model_decoder_layers_16_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[878] model_decoder_layers_16_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[879] model_decoder_layers_16_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[880] model_decoder_layers_16_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[881] model_decoder_layers_16_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[885] model_decoder_layers_16_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[886] model_decoder_layers_16_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[887] model_decoder_layers_16_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[888] model_decoder_layers_16_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[889] model_decoder_layers_16_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[890] model_decoder_layers_16_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[891] model_decoder_layers_16_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[892] model_decoder_layers_16_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[893] model_decoder_layers_16_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[894] model_decoder_layers_16_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[895] model_decoder_layers_16_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[896] model_decoder_layers_17_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[897] model_decoder_layers_17_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[898] model_decoder_layers_17_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[899] model_decoder_layers_17_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[900] model_decoder_layers_17_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[901] model_decoder_layers_17_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[902] model_decoder_layers_17_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[903] model_decoder_layers_17_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[904] model_decoder_layers_17_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[905] model_decoder_layers_17_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[909] model_decoder_layers_17_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[910] model_decoder_layers_17_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[911] model_decoder_layers_17_encoder_attn_out_proj_bias2: 
R.Tensor((1280,), dtype="float16") = packed_params[912] model_decoder_layers_17_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[913] model_decoder_layers_17_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[914] model_decoder_layers_17_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[915] model_decoder_layers_17_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[916] model_decoder_layers_17_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[917] model_decoder_layers_17_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[918] model_decoder_layers_17_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[919] model_decoder_layers_17_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[920] model_decoder_layers_18_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[921] model_decoder_layers_18_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[922] model_decoder_layers_18_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[923] model_decoder_layers_18_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[924] model_decoder_layers_18_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[925] model_decoder_layers_18_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[926] model_decoder_layers_18_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[927] model_decoder_layers_18_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[928] model_decoder_layers_18_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[929] model_decoder_layers_18_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[933] model_decoder_layers_18_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[934] model_decoder_layers_18_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[935] model_decoder_layers_18_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[936] model_decoder_layers_18_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[937] model_decoder_layers_18_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[938] model_decoder_layers_18_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[939] model_decoder_layers_18_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[940] model_decoder_layers_18_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[941] model_decoder_layers_18_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[942] model_decoder_layers_18_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[943] model_decoder_layers_18_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[944] model_decoder_layers_19_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[945] model_decoder_layers_19_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[946] model_decoder_layers_19_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[947] model_decoder_layers_19_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[948] 
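# Note: the packed_params indices bound here are not contiguous. For every decoder
# layer, the three slots between self_attn_layer_norm_bias2 and
# encoder_attn_q_proj_weight2 are skipped (e.g. 906-908 for layer 17, 930-932 for
# layer 18). Those are most plausibly the cross-attention k_proj/v_proj parameters,
# which would be consumed by the separate encoder/cross-attention KV-cache prefill
# function rather than by this decode function. Likewise, no self_attn_k_proj_bias2
# is ever bound, consistent with a Whisper-style decoder whose k_proj has no bias.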
model_decoder_layers_19_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[949] model_decoder_layers_19_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[950] model_decoder_layers_19_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[951] model_decoder_layers_19_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[952] model_decoder_layers_19_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[953] model_decoder_layers_19_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[957] model_decoder_layers_19_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[958] model_decoder_layers_19_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[959] model_decoder_layers_19_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[960] model_decoder_layers_19_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[961] model_decoder_layers_19_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[962] model_decoder_layers_19_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[963] model_decoder_layers_19_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[964] model_decoder_layers_19_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[965] model_decoder_layers_19_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[966] model_decoder_layers_19_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[967] model_decoder_layers_19_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[968] model_decoder_layers_20_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[969] model_decoder_layers_20_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[970] model_decoder_layers_20_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[971] model_decoder_layers_20_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[972] model_decoder_layers_20_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[973] model_decoder_layers_20_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[974] model_decoder_layers_20_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[975] model_decoder_layers_20_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[976] model_decoder_layers_20_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[977] model_decoder_layers_20_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[981] model_decoder_layers_20_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[982] model_decoder_layers_20_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[983] model_decoder_layers_20_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[984] model_decoder_layers_20_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[985] model_decoder_layers_20_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[986] model_decoder_layers_20_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[987] 
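# The indices follow a fixed stride of 24 packed_params slots per decoder layer
# (layer 17 starts at 897, layer 18 at 921, layer 19 at 945, ...). A minimal
# sketch of that arithmetic, inferred from this dump alone -- the helper name and
# the base constant are hypothetical, not part of the module:
def decoder_param_index(layer: int, offset: int) -> int:
    # Layer 16's run opens this excerpt at slot 873 (inferred to be its
    # self_attn_k_proj_weight), and each layer occupies 24 consecutive slots,
    # so layer i starts at 489 + 24 * i.
    return 489 + 24 * layer + offset

# offset 0 = self_attn_k_proj_weight, 1 = self_attn_v_proj_weight, and so on.
assert decoder_param_index(17, 0) == 897    # matches packed_params[897] above
assert decoder_param_index(31, 0) == 1233   # matches layer 31 further below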
model_decoder_layers_20_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[988] model_decoder_layers_20_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[989] model_decoder_layers_20_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[990] model_decoder_layers_20_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[991] model_decoder_layers_20_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[992] model_decoder_layers_21_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[993] model_decoder_layers_21_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[994] model_decoder_layers_21_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[995] model_decoder_layers_21_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[996] model_decoder_layers_21_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[997] model_decoder_layers_21_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[998] model_decoder_layers_21_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[999] model_decoder_layers_21_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1000] model_decoder_layers_21_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1001] model_decoder_layers_21_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005] model_decoder_layers_21_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1006] model_decoder_layers_21_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007] model_decoder_layers_21_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1008] model_decoder_layers_21_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1009] model_decoder_layers_21_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1010] model_decoder_layers_21_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011] model_decoder_layers_21_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1012] model_decoder_layers_21_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013] model_decoder_layers_21_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1014] model_decoder_layers_21_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1015] model_decoder_layers_21_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1016] model_decoder_layers_22_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017] model_decoder_layers_22_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018] model_decoder_layers_22_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1019] model_decoder_layers_22_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020] model_decoder_layers_22_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1021] model_decoder_layers_22_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022] model_decoder_layers_22_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1023] 
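# The trailing "2" on each binding name is presumably just a disambiguating suffix
# added by the TVMScript printer: the module's other entry points (e.g. a prefill
# variant) would bind the same packed_params slots under their own suffixes, as the
# "...weight5" naming in the scheduled prim_funcs elsewhere in this dump suggests.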
model_decoder_layers_22_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1024] model_decoder_layers_22_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1025] model_decoder_layers_22_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029] model_decoder_layers_22_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1030] model_decoder_layers_22_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031] model_decoder_layers_22_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1032] model_decoder_layers_22_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1033] model_decoder_layers_22_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1034] model_decoder_layers_22_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035] model_decoder_layers_22_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1036] model_decoder_layers_22_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037] model_decoder_layers_22_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1038] model_decoder_layers_22_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1039] model_decoder_layers_22_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1040] model_decoder_layers_23_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041] model_decoder_layers_23_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042] model_decoder_layers_23_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1043] model_decoder_layers_23_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044] model_decoder_layers_23_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1045] model_decoder_layers_23_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046] model_decoder_layers_23_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1047] model_decoder_layers_23_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1048] model_decoder_layers_23_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1049] model_decoder_layers_23_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053] model_decoder_layers_23_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1054] model_decoder_layers_23_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055] model_decoder_layers_23_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1056] model_decoder_layers_23_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1057] model_decoder_layers_23_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1058] model_decoder_layers_23_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059] model_decoder_layers_23_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1060] model_decoder_layers_23_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061] model_decoder_layers_23_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1062] 
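# Per-layer census: each decoder layer binds 21 of its 24 parameter slots in this
# function -- nine self-attention tensors (q/v/out weights and biases, k weight,
# layer-norm weight and bias), six cross-attention tensors (q and out projections
# plus layer norm), four feed-forward tensors (fc1/fc2 weight and bias, with the
# usual 4x expansion from 1280 to 5120), and the two final_layer_norm tensors.
# All parameters are float16.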
model_decoder_layers_23_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1063] model_decoder_layers_23_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1064] model_decoder_layers_24_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065] model_decoder_layers_24_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066] model_decoder_layers_24_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1067] model_decoder_layers_24_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068] model_decoder_layers_24_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1069] model_decoder_layers_24_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070] model_decoder_layers_24_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1071] model_decoder_layers_24_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1072] model_decoder_layers_24_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1073] model_decoder_layers_24_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077] model_decoder_layers_24_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1078] model_decoder_layers_24_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079] model_decoder_layers_24_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1080] model_decoder_layers_24_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1081] model_decoder_layers_24_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1082] model_decoder_layers_24_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083] model_decoder_layers_24_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1084] model_decoder_layers_24_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085] model_decoder_layers_24_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1086] model_decoder_layers_24_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1087] model_decoder_layers_24_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1088] model_decoder_layers_25_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089] model_decoder_layers_25_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090] model_decoder_layers_25_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1091] model_decoder_layers_25_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092] model_decoder_layers_25_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1093] model_decoder_layers_25_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094] model_decoder_layers_25_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1095] model_decoder_layers_25_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1096] model_decoder_layers_25_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1097] model_decoder_layers_25_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[1101] model_decoder_layers_25_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1102] model_decoder_layers_25_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103] model_decoder_layers_25_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1104] model_decoder_layers_25_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1105] model_decoder_layers_25_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1106] model_decoder_layers_25_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107] model_decoder_layers_25_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1108] model_decoder_layers_25_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109] model_decoder_layers_25_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1110] model_decoder_layers_25_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1111] model_decoder_layers_25_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1112] model_decoder_layers_26_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113] model_decoder_layers_26_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114] model_decoder_layers_26_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1115] model_decoder_layers_26_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116] model_decoder_layers_26_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1117] model_decoder_layers_26_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118] model_decoder_layers_26_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1119] model_decoder_layers_26_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1120] model_decoder_layers_26_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1121] model_decoder_layers_26_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125] model_decoder_layers_26_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1126] model_decoder_layers_26_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127] model_decoder_layers_26_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1128] model_decoder_layers_26_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1129] model_decoder_layers_26_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1130] model_decoder_layers_26_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131] model_decoder_layers_26_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1132] model_decoder_layers_26_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133] model_decoder_layers_26_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1134] model_decoder_layers_26_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1135] model_decoder_layers_26_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1136] model_decoder_layers_27_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137] 
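# After the parameter bindings, the function body (beginning at reshape384 below)
# performs one decode step: a token-embedding lookup (take), position embeddings
# indexed by the query positions fetched from the paged KV cache, then 32
# structurally identical pre-LN decoder blocks. In each block the q/k/v
# projections are offloaded to cuBLAS (call_dps_packed on fused
# permute_dims+matmul kernels, i.e. x @ W.T with weights stored as
# (out_features, in_features)), fused into a (seq_len, 60, 64) QKV tensor
# (20 query + 20 key + 20 value heads of width 64) for
# vm.builtin.attention_kv_cache_attention_with_fused_qkv; cross-attention reuses
# encoder K/V already resident in the cache, so only q is projected here.
# Below is a self-contained numpy sketch of one such block -- illustrative only,
# with dense unmasked attention standing in for the paged-KV-cache builtins and
# hypothetical parameter-dict keys; it is not part of the module:
import numpy as np

def layer_norm(x, w, b, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * w + b

def gelu(x):
    # tanh approximation; the module's relax.nn.gelu may be the erf form
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def attention(q, k, v, n_heads=20, head_dim=64):
    # (seq, hidden) -> (heads, seq, head_dim); causal/padding masks that the
    # KV-cache builtins apply in the real module are omitted here
    def split(t):
        return t.reshape(t.shape[0], n_heads, head_dim).transpose(1, 0, 2)
    q, k, v = split(q), split(k), split(v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)
    out = (w @ v).transpose(1, 0, 2)              # (seq, heads, head_dim)
    return out.reshape(out.shape[0], n_heads * head_dim)

def decoder_block(h, p, enc_k, enc_v):
    # self-attention sub-block: residual around pre-LN attention
    x = layer_norm(h, p["sa_ln_w"], p["sa_ln_b"])
    q = x @ p["q_w"].T + p["q_b"]
    k = x @ p["k_w"].T                            # no k_proj bias, as bound above
    v = x @ p["v_w"].T + p["v_b"]
    h = h + attention(q, k, v) @ p["o_w"].T + p["o_b"]
    # cross-attention sub-block: K/V come from the encoder, cached by the module
    x = layer_norm(h, p["ca_ln_w"], p["ca_ln_b"])
    q = x @ p["ca_q_w"].T + p["ca_q_b"]
    h = h + attention(q, enc_k, enc_v) @ p["ca_o_w"].T + p["ca_o_b"]
    # feed-forward sub-block with GELU, 1280 -> 5120 -> 1280
    x = layer_norm(h, p["f_ln_w"], p["f_ln_b"])
    return h + gelu(x @ p["fc1_w"].T + p["fc1_b"]) @ p["fc2_w"].T + p["fc2_b"]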
model_decoder_layers_27_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138] model_decoder_layers_27_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1139] model_decoder_layers_27_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140] model_decoder_layers_27_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1141] model_decoder_layers_27_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142] model_decoder_layers_27_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1143] model_decoder_layers_27_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1144] model_decoder_layers_27_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1145] model_decoder_layers_27_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149] model_decoder_layers_27_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1150] model_decoder_layers_27_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151] model_decoder_layers_27_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1152] model_decoder_layers_27_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1153] model_decoder_layers_27_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1154] model_decoder_layers_27_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155] model_decoder_layers_27_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1156] model_decoder_layers_27_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157] model_decoder_layers_27_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1158] model_decoder_layers_27_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1159] model_decoder_layers_27_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1160] model_decoder_layers_28_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161] model_decoder_layers_28_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162] model_decoder_layers_28_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1163] model_decoder_layers_28_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164] model_decoder_layers_28_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1165] model_decoder_layers_28_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166] model_decoder_layers_28_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1167] model_decoder_layers_28_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1168] model_decoder_layers_28_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1169] model_decoder_layers_28_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173] model_decoder_layers_28_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1174] model_decoder_layers_28_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175] model_decoder_layers_28_encoder_attn_out_proj_bias2: R.Tensor((1280,), 
dtype="float16") = packed_params[1176] model_decoder_layers_28_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1177] model_decoder_layers_28_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1178] model_decoder_layers_28_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179] model_decoder_layers_28_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1180] model_decoder_layers_28_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181] model_decoder_layers_28_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1182] model_decoder_layers_28_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1183] model_decoder_layers_28_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1184] model_decoder_layers_29_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185] model_decoder_layers_29_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186] model_decoder_layers_29_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1187] model_decoder_layers_29_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188] model_decoder_layers_29_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1189] model_decoder_layers_29_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190] model_decoder_layers_29_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1191] model_decoder_layers_29_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1192] model_decoder_layers_29_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1193] model_decoder_layers_29_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197] model_decoder_layers_29_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1198] model_decoder_layers_29_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199] model_decoder_layers_29_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1200] model_decoder_layers_29_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1201] model_decoder_layers_29_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1202] model_decoder_layers_29_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203] model_decoder_layers_29_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1204] model_decoder_layers_29_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205] model_decoder_layers_29_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1206] model_decoder_layers_29_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1207] model_decoder_layers_29_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1208] model_decoder_layers_30_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209] model_decoder_layers_30_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210] model_decoder_layers_30_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1211] model_decoder_layers_30_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[1212] model_decoder_layers_30_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1213] model_decoder_layers_30_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214] model_decoder_layers_30_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1215] model_decoder_layers_30_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1216] model_decoder_layers_30_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1217] model_decoder_layers_30_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221] model_decoder_layers_30_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1222] model_decoder_layers_30_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223] model_decoder_layers_30_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1224] model_decoder_layers_30_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1225] model_decoder_layers_30_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1226] model_decoder_layers_30_fc1_weight2: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227] model_decoder_layers_30_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1228] model_decoder_layers_30_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229] model_decoder_layers_30_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1230] model_decoder_layers_30_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1231] model_decoder_layers_30_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1232] model_decoder_layers_31_self_attn_k_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233] model_decoder_layers_31_self_attn_v_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234] model_decoder_layers_31_self_attn_v_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1235] model_decoder_layers_31_self_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236] model_decoder_layers_31_self_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1237] model_decoder_layers_31_self_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238] model_decoder_layers_31_self_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1239] model_decoder_layers_31_self_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1240] model_decoder_layers_31_self_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1241] model_decoder_layers_31_encoder_attn_q_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245] model_decoder_layers_31_encoder_attn_q_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1246] model_decoder_layers_31_encoder_attn_out_proj_weight2: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247] model_decoder_layers_31_encoder_attn_out_proj_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1248] model_decoder_layers_31_encoder_attn_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1249] model_decoder_layers_31_encoder_attn_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1250] model_decoder_layers_31_fc1_weight2: R.Tensor((5120, 
1280), dtype="float16") = packed_params[1251] model_decoder_layers_31_fc1_bias2: R.Tensor((5120,), dtype="float16") = packed_params[1252] model_decoder_layers_31_fc2_weight2: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] model_decoder_layers_31_fc2_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1254] model_decoder_layers_31_final_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1255] model_decoder_layers_31_final_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1256] model_decoder_layer_norm_weight2: R.Tensor((1280,), dtype="float16") = packed_params[1257] model_decoder_layer_norm_bias2: R.Tensor((1280,), dtype="float16") = packed_params[1258] reshape384 = R.call_tir(cls.reshape12, (input_ids,), out_sinfo=R.Tensor((seq_len,), dtype="int32")) take = R.call_tir(cls.take, (model_decoder_embed_tokens_weight2, reshape384), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) reshape385 = R.call_tir(cls.reshape13, (take,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv68: R.Tensor((seq_len,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((seq_len,), dtype="int32"),)) take1 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight2, lv68), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16")) reshape386 = R.call_tir(cls.reshape13, (take1,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add257 = R.call_tir(cls.add5, (reshape385, reshape386), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm65 = R.call_tir(cls.layer_norm2, (add257, model_decoder_layers_0_self_attn_layer_norm_weight2, model_decoder_layers_0_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv416 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_q_proj_weight2, layer_norm65, model_decoder_layers_0_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape387 = R.call_tir(cls.reshape14, (lv416,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_0_self_attn_k_proj_weight2, layer_norm65), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape388 = R.call_tir(cls.reshape14, (lv98,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv417 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_v_proj_weight2, layer_norm65, model_decoder_layers_0_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape389 = R.call_tir(cls.reshape14, (lv417,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat = R.call_tir(cls.concatenate1, (reshape387, reshape388, reshape389), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape390 = R.call_tir(cls.reshape15, (concat,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv69 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape390), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape391 = R.call_tir(cls.reshape16, (lv69,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape392 = R.call_tir(cls.reshape17, (reshape391,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv418 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_out_proj_weight2, reshape392, model_decoder_layers_0_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add261 = R.call_tir(cls.add5, (add257, lv418), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm66 = R.call_tir(cls.layer_norm2, (add261, model_decoder_layers_0_encoder_attn_layer_norm_weight2, model_decoder_layers_0_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv419 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight2, layer_norm66, model_decoder_layers_0_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape393 = R.call_tir(cls.reshape14, (lv419,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape394 = R.call_tir(cls.reshape18, (reshape393,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv70 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape394), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape395 = R.call_tir(cls.reshape16, (lv70,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape396 = R.call_tir(cls.reshape17, (reshape395,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv420 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight2, reshape396, model_decoder_layers_0_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add264 = R.call_tir(cls.add5, (add261, lv420), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm67 = R.call_tir(cls.layer_norm2, (add264, model_decoder_layers_0_final_layer_norm_weight2, model_decoder_layers_0_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv64 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_0_fc1_weight2, layer_norm67, model_decoder_layers_0_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv421 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_0_fc2_weight2, lv64, model_decoder_layers_0_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add267 = R.call_tir(cls.add5, (add264, lv421), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm68 = R.call_tir(cls.layer_norm2, (add267, model_decoder_layers_1_self_attn_layer_norm_weight2, model_decoder_layers_1_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv422 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_q_proj_weight2, layer_norm68, model_decoder_layers_1_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape397 = R.call_tir(cls.reshape14, (lv422,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_1_self_attn_k_proj_weight2, layer_norm68), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape398 = R.call_tir(cls.reshape14, (lv99,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv423 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_v_proj_weight2, layer_norm68, model_decoder_layers_1_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape399 = R.call_tir(cls.reshape14, (lv423,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat1 = R.call_tir(cls.concatenate1, (reshape397, reshape398, reshape399), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape400 = R.call_tir(cls.reshape15, (concat1,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv71 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape400), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape401 = R.call_tir(cls.reshape16, (lv71,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape402 = R.call_tir(cls.reshape17, (reshape401,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv424 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_out_proj_weight2, reshape402, model_decoder_layers_1_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add271 = R.call_tir(cls.add5, (add267, lv424), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm69 = R.call_tir(cls.layer_norm2, (add271, model_decoder_layers_1_encoder_attn_layer_norm_weight2, model_decoder_layers_1_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv425 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight2, layer_norm69, model_decoder_layers_1_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape403 = R.call_tir(cls.reshape14, (lv425,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape404 = R.call_tir(cls.reshape18, (reshape403,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv72 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape404), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape405 = R.call_tir(cls.reshape16, (lv72,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape406 = R.call_tir(cls.reshape17, (reshape405,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv426 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight2, reshape406, model_decoder_layers_1_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add274 = R.call_tir(cls.add5, (add271, lv426), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm70 = R.call_tir(cls.layer_norm2, (add274, model_decoder_layers_1_final_layer_norm_weight2, model_decoder_layers_1_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_1_fc1_weight2, layer_norm70, model_decoder_layers_1_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv427 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_1_fc2_weight2, lv65, model_decoder_layers_1_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add277 = 
R.call_tir(cls.add5, (add274, lv427), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm71 = R.call_tir(cls.layer_norm2, (add277, model_decoder_layers_2_self_attn_layer_norm_weight2, model_decoder_layers_2_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv428 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_q_proj_weight2, layer_norm71, model_decoder_layers_2_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape407 = R.call_tir(cls.reshape14, (lv428,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_2_self_attn_k_proj_weight2, layer_norm71), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape408 = R.call_tir(cls.reshape14, (lv100,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv429 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_v_proj_weight2, layer_norm71, model_decoder_layers_2_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape409 = R.call_tir(cls.reshape14, (lv429,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat2 = R.call_tir(cls.concatenate1, (reshape407, reshape408, reshape409), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape410 = R.call_tir(cls.reshape15, (concat2,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv73 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape410), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape411 = R.call_tir(cls.reshape16, (lv73,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape412 = R.call_tir(cls.reshape17, (reshape411,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv430 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_out_proj_weight2, reshape412, model_decoder_layers_2_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add281 = R.call_tir(cls.add5, (add277, lv430), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm72 = R.call_tir(cls.layer_norm2, (add281, model_decoder_layers_2_encoder_attn_layer_norm_weight2, model_decoder_layers_2_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv431 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight2, layer_norm72, model_decoder_layers_2_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape413 = R.call_tir(cls.reshape14, (lv431,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape414 = R.call_tir(cls.reshape18, (reshape413,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv74 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape414), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape415 = R.call_tir(cls.reshape16, (lv74,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape416 = R.call_tir(cls.reshape17, (reshape415,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv432 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight2, reshape416, model_decoder_layers_2_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add284 = R.call_tir(cls.add5, (add281, lv432), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm73 = R.call_tir(cls.layer_norm2, (add284, model_decoder_layers_2_final_layer_norm_weight2, model_decoder_layers_2_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_2_fc1_weight2, layer_norm73, model_decoder_layers_2_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv433 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_2_fc2_weight2, lv66, model_decoder_layers_2_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add287 = R.call_tir(cls.add5, (add284, lv433), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm74 = R.call_tir(cls.layer_norm2, (add287, model_decoder_layers_3_self_attn_layer_norm_weight2, model_decoder_layers_3_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv434 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_q_proj_weight2, layer_norm74, model_decoder_layers_3_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape417 = R.call_tir(cls.reshape14, (lv434,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_3_self_attn_k_proj_weight2, layer_norm74), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape418 = R.call_tir(cls.reshape14, (lv101,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv435 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_v_proj_weight2, layer_norm74, model_decoder_layers_3_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape419 = R.call_tir(cls.reshape14, (lv435,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat3 = R.call_tir(cls.concatenate1, (reshape417, reshape418, reshape419), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape420 = R.call_tir(cls.reshape15, (concat3,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv75 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape420), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape421 = R.call_tir(cls.reshape16, (lv75,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape422 = R.call_tir(cls.reshape17, (reshape421,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv436 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_out_proj_weight2, reshape422, model_decoder_layers_3_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add291 = R.call_tir(cls.add5, (add287, lv436), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm75 = R.call_tir(cls.layer_norm2, (add291, model_decoder_layers_3_encoder_attn_layer_norm_weight2, 
model_decoder_layers_3_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv437 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight2, layer_norm75, model_decoder_layers_3_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape423 = R.call_tir(cls.reshape14, (lv437,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape424 = R.call_tir(cls.reshape18, (reshape423,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv76 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape424), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape425 = R.call_tir(cls.reshape16, (lv76,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape426 = R.call_tir(cls.reshape17, (reshape425,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv438 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight2, reshape426, model_decoder_layers_3_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add294 = R.call_tir(cls.add5, (add291, lv438), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm76 = R.call_tir(cls.layer_norm2, (add294, model_decoder_layers_3_final_layer_norm_weight2, model_decoder_layers_3_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_3_fc1_weight2, layer_norm76, model_decoder_layers_3_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv439 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_3_fc2_weight2, lv67, model_decoder_layers_3_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add297 = R.call_tir(cls.add5, (add294, lv439), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm77 = R.call_tir(cls.layer_norm2, (add297, model_decoder_layers_4_self_attn_layer_norm_weight2, model_decoder_layers_4_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv440 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_q_proj_weight2, layer_norm77, model_decoder_layers_4_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape427 = R.call_tir(cls.reshape14, (lv440,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_4_self_attn_k_proj_weight2, layer_norm77), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape428 = R.call_tir(cls.reshape14, (lv102,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv441 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_v_proj_weight2, layer_norm77, model_decoder_layers_4_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape429 = R.call_tir(cls.reshape14, (lv441,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat4 = R.call_tir(cls.concatenate1, (reshape427, reshape428, reshape429), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape430 = 
R.call_tir(cls.reshape15, (concat4,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv77 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape430), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape431 = R.call_tir(cls.reshape16, (lv77,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape432 = R.call_tir(cls.reshape17, (reshape431,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv442 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_out_proj_weight2, reshape432, model_decoder_layers_4_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add301 = R.call_tir(cls.add5, (add297, lv442), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm78 = R.call_tir(cls.layer_norm2, (add301, model_decoder_layers_4_encoder_attn_layer_norm_weight2, model_decoder_layers_4_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv443 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight2, layer_norm78, model_decoder_layers_4_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape433 = R.call_tir(cls.reshape14, (lv443,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape434 = R.call_tir(cls.reshape18, (reshape433,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv78 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape434), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape435 = R.call_tir(cls.reshape16, (lv78,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape436 = R.call_tir(cls.reshape17, (reshape435,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv444 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight2, reshape436, model_decoder_layers_4_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add304 = R.call_tir(cls.add5, (add301, lv444), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm79 = R.call_tir(cls.layer_norm2, (add304, model_decoder_layers_4_final_layer_norm_weight2, model_decoder_layers_4_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv68_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_4_fc1_weight2, layer_norm79, model_decoder_layers_4_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv445 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_4_fc2_weight2, lv68_1, model_decoder_layers_4_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add307 = R.call_tir(cls.add5, (add304, lv445), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm80 = R.call_tir(cls.layer_norm2, (add307, model_decoder_layers_5_self_attn_layer_norm_weight2, model_decoder_layers_5_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv446 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_q_proj_weight2, layer_norm80, 
model_decoder_layers_5_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape437 = R.call_tir(cls.reshape14, (lv446,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_5_self_attn_k_proj_weight2, layer_norm80), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape438 = R.call_tir(cls.reshape14, (lv103,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv447 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_v_proj_weight2, layer_norm80, model_decoder_layers_5_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape439 = R.call_tir(cls.reshape14, (lv447,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat5 = R.call_tir(cls.concatenate1, (reshape437, reshape438, reshape439), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape440 = R.call_tir(cls.reshape15, (concat5,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv79 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape440), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape441 = R.call_tir(cls.reshape16, (lv79,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape442 = R.call_tir(cls.reshape17, (reshape441,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv448 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_out_proj_weight2, reshape442, model_decoder_layers_5_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add311 = R.call_tir(cls.add5, (add307, lv448), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm81 = R.call_tir(cls.layer_norm2, (add311, model_decoder_layers_5_encoder_attn_layer_norm_weight2, model_decoder_layers_5_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv449 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight2, layer_norm81, model_decoder_layers_5_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape443 = R.call_tir(cls.reshape14, (lv449,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape444 = R.call_tir(cls.reshape18, (reshape443,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv80 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape444), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape445 = R.call_tir(cls.reshape16, (lv80,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape446 = R.call_tir(cls.reshape17, (reshape445,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv450 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight2, reshape446, model_decoder_layers_5_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add314 = R.call_tir(cls.add5, (add311, lv450), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm82 = R.call_tir(cls.layer_norm2, (add314, model_decoder_layers_5_final_layer_norm_weight2, model_decoder_layers_5_final_layer_norm_bias2), 
lv69_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_5_fc1_weight2, layer_norm82, model_decoder_layers_5_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv451 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_5_fc2_weight2, lv69_1, model_decoder_layers_5_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add317 = R.call_tir(cls.add5, (add314, lv451), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm83 = R.call_tir(cls.layer_norm2, (add317, model_decoder_layers_6_self_attn_layer_norm_weight2, model_decoder_layers_6_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv452 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_q_proj_weight2, layer_norm83, model_decoder_layers_6_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape447 = R.call_tir(cls.reshape14, (lv452,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv104 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_6_self_attn_k_proj_weight2, layer_norm83), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape448 = R.call_tir(cls.reshape14, (lv104,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv453 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_v_proj_weight2, layer_norm83, model_decoder_layers_6_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape449 = R.call_tir(cls.reshape14, (lv453,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat6 = R.call_tir(cls.concatenate1, (reshape447, reshape448, reshape449), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape450 = R.call_tir(cls.reshape15, (concat6,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv81 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape450), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape451 = R.call_tir(cls.reshape16, (lv81,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape452 = R.call_tir(cls.reshape17, (reshape451,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv454 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_out_proj_weight2, reshape452, model_decoder_layers_6_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add321 = R.call_tir(cls.add5, (add317, lv454), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm84 = R.call_tir(cls.layer_norm2, (add321, model_decoder_layers_6_encoder_attn_layer_norm_weight2, model_decoder_layers_6_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv455 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight2, layer_norm84, model_decoder_layers_6_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape453 = R.call_tir(cls.reshape14, (lv455,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape454 = R.call_tir(cls.reshape18, (reshape453,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
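# Layer 6 cross-attention: only the query is projected here; vm.builtin.attention_kv_cache_cross_attention
# looks up the encoder K/V already held in paged_kv_cache for layer index 6 (R.prim_value(6)).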
lv82 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape454), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape455 = R.call_tir(cls.reshape16, (lv82,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape456 = R.call_tir(cls.reshape17, (reshape455,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv456 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight2, reshape456, model_decoder_layers_6_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add324 = R.call_tir(cls.add5, (add321, lv456), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm85 = R.call_tir(cls.layer_norm2, (add324, model_decoder_layers_6_final_layer_norm_weight2, model_decoder_layers_6_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv70_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_6_fc1_weight2, layer_norm85, model_decoder_layers_6_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv457 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_6_fc2_weight2, lv70_1, model_decoder_layers_6_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add327 = R.call_tir(cls.add5, (add324, lv457), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm86 = R.call_tir(cls.layer_norm2, (add327, model_decoder_layers_7_self_attn_layer_norm_weight2, model_decoder_layers_7_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv458 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_q_proj_weight2, layer_norm86, model_decoder_layers_7_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape457 = R.call_tir(cls.reshape14, (lv458,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_7_self_attn_k_proj_weight2, layer_norm86), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape458 = R.call_tir(cls.reshape14, (lv105,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv459 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_v_proj_weight2, layer_norm86, model_decoder_layers_7_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape459 = R.call_tir(cls.reshape14, (lv459,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat7 = R.call_tir(cls.concatenate1, (reshape457, reshape458, reshape459), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape460 = R.call_tir(cls.reshape15, (concat7,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv83 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape460), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape461 = R.call_tir(cls.reshape16, (lv83,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape462 = R.call_tir(cls.reshape17, (reshape461,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv460 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_out_proj_weight2, reshape462, model_decoder_layers_7_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
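# R.prim_value(T.float32(1)) is the attention score scale passed to the KV-cache builtins; a scale
# of 1.0 suggests the usual 1/sqrt(head_dim) factor is folded in elsewhere, presumably into the weights.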
add331 = R.call_tir(cls.add5, (add327, lv460), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm87 = R.call_tir(cls.layer_norm2, (add331, model_decoder_layers_7_encoder_attn_layer_norm_weight2, model_decoder_layers_7_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv461 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight2, layer_norm87, model_decoder_layers_7_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape463 = R.call_tir(cls.reshape14, (lv461,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape464 = R.call_tir(cls.reshape18, (reshape463,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv84 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape464), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape465 = R.call_tir(cls.reshape16, (lv84,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape466 = R.call_tir(cls.reshape17, (reshape465,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv462 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight2, reshape466, model_decoder_layers_7_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add334 = R.call_tir(cls.add5, (add331, lv462), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm88 = R.call_tir(cls.layer_norm2, (add334, model_decoder_layers_7_final_layer_norm_weight2, model_decoder_layers_7_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv71_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_7_fc1_weight2, layer_norm88, model_decoder_layers_7_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv463 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_7_fc2_weight2, lv71_1, model_decoder_layers_7_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add337 = R.call_tir(cls.add5, (add334, lv463), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm89 = R.call_tir(cls.layer_norm2, (add337, model_decoder_layers_8_self_attn_layer_norm_weight2, model_decoder_layers_8_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv464 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_q_proj_weight2, layer_norm89, model_decoder_layers_8_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape467 = R.call_tir(cls.reshape14, (lv464,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv106 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_8_self_attn_k_proj_weight2, layer_norm89), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape468 = R.call_tir(cls.reshape14, (lv106,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv465 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_v_proj_weight2, layer_norm89, model_decoder_layers_8_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape469 = R.call_tir(cls.reshape14, (lv465,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat8 = R.call_tir(cls.concatenate1, (reshape467, reshape468, reshape469), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape470 = R.call_tir(cls.reshape15, (concat8,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv85 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape470), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape471 = R.call_tir(cls.reshape16, (lv85,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape472 = R.call_tir(cls.reshape17, (reshape471,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv466 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_out_proj_weight2, reshape472, model_decoder_layers_8_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add341 = R.call_tir(cls.add5, (add337, lv466), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm90 = R.call_tir(cls.layer_norm2, (add341, model_decoder_layers_8_encoder_attn_layer_norm_weight2, model_decoder_layers_8_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv467 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight2, layer_norm90, model_decoder_layers_8_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape473 = R.call_tir(cls.reshape14, (lv467,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape474 = R.call_tir(cls.reshape18, (reshape473,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv86 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape474), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape475 = R.call_tir(cls.reshape16, (lv86,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape476 = R.call_tir(cls.reshape17, (reshape475,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv468 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_out_proj_weight2, reshape476, model_decoder_layers_8_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add344 = R.call_tir(cls.add5, (add341, lv468), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm91 = R.call_tir(cls.layer_norm2, (add344, model_decoder_layers_8_final_layer_norm_weight2, model_decoder_layers_8_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv72_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_8_fc1_weight2, layer_norm91, model_decoder_layers_8_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv469 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_8_fc2_weight2, lv72_1, model_decoder_layers_8_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add347 = R.call_tir(cls.add5, (add344, lv469), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
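# Decoder layer 9. Each layer repeats the same pre-norm sequence: self-attention, cross-attention,
# then the MLP, with a residual add (cls.add5) closing every sub-block.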
layer_norm92 = R.call_tir(cls.layer_norm2, (add347, model_decoder_layers_9_self_attn_layer_norm_weight2, model_decoder_layers_9_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv470 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_q_proj_weight2, layer_norm92, model_decoder_layers_9_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape477 = R.call_tir(cls.reshape14, (lv470,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_9_self_attn_k_proj_weight2, layer_norm92), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape478 = R.call_tir(cls.reshape14, (lv107,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv471 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_v_proj_weight2, layer_norm92, model_decoder_layers_9_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape479 = R.call_tir(cls.reshape14, (lv471,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat9 = R.call_tir(cls.concatenate1, (reshape477, reshape478, reshape479), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape480 = R.call_tir(cls.reshape15, (concat9,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv87 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape480), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape481 = R.call_tir(cls.reshape16, (lv87,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape482 = R.call_tir(cls.reshape17, (reshape481,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv472 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_out_proj_weight2, reshape482, model_decoder_layers_9_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add351 = R.call_tir(cls.add5, (add347, lv472), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm93 = R.call_tir(cls.layer_norm2, (add351, model_decoder_layers_9_encoder_attn_layer_norm_weight2, model_decoder_layers_9_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv473 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight2, layer_norm93, model_decoder_layers_9_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape483 = R.call_tir(cls.reshape14, (lv473,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape484 = R.call_tir(cls.reshape18, (reshape483,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv88 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape484), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape485 = R.call_tir(cls.reshape16, (lv88,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape486 = R.call_tir(cls.reshape17, (reshape485,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv474 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight2, reshape486, model_decoder_layers_9_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
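# Residual add and final LayerNorm close layer 9's cross-attention before its MLP.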
add354 = R.call_tir(cls.add5, (add351, lv474), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm94 = R.call_tir(cls.layer_norm2, (add354, model_decoder_layers_9_final_layer_norm_weight2, model_decoder_layers_9_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv73_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_9_fc1_weight2, layer_norm94, model_decoder_layers_9_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv475 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_9_fc2_weight2, lv73_1, model_decoder_layers_9_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add357 = R.call_tir(cls.add5, (add354, lv475), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm95 = R.call_tir(cls.layer_norm2, (add357, model_decoder_layers_10_self_attn_layer_norm_weight2, model_decoder_layers_10_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv476 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_q_proj_weight2, layer_norm95, model_decoder_layers_10_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape487 = R.call_tir(cls.reshape14, (lv476,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_10_self_attn_k_proj_weight2, layer_norm95), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape488 = R.call_tir(cls.reshape14, (lv108,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv477 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_v_proj_weight2, layer_norm95, model_decoder_layers_10_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape489 = R.call_tir(cls.reshape14, (lv477,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat10 = R.call_tir(cls.concatenate1, (reshape487, reshape488, reshape489), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape490 = R.call_tir(cls.reshape15, (concat10,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv89 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape490), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape491 = R.call_tir(cls.reshape16, (lv89,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape492 = R.call_tir(cls.reshape17, (reshape491,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv478 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_out_proj_weight2, reshape492, model_decoder_layers_10_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add361 = R.call_tir(cls.add5, (add357, lv478), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm96 = R.call_tir(cls.layer_norm2, (add361, model_decoder_layers_10_encoder_attn_layer_norm_weight2, model_decoder_layers_10_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
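# reshape14 splits the 1280-wide hidden state into 20 heads of 64; reshape18 drops the batch
# dimension to the (seq_len, 20, 64) layout the cache builtins expect.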
lv479 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight2, layer_norm96, model_decoder_layers_10_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape493 = R.call_tir(cls.reshape14, (lv479,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape494 = R.call_tir(cls.reshape18, (reshape493,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv90 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape494), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape495 = R.call_tir(cls.reshape16, (lv90,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape496 = R.call_tir(cls.reshape17, (reshape495,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv480 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight2, reshape496, model_decoder_layers_10_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add364 = R.call_tir(cls.add5, (add361, lv480), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm97 = R.call_tir(cls.layer_norm2, (add364, model_decoder_layers_10_final_layer_norm_weight2, model_decoder_layers_10_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv74_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_10_fc1_weight2, layer_norm97, model_decoder_layers_10_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv481 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_10_fc2_weight2, lv74_1, model_decoder_layers_10_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add367 = R.call_tir(cls.add5, (add364, lv481), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm98 = R.call_tir(cls.layer_norm2, (add367, model_decoder_layers_11_self_attn_layer_norm_weight2, model_decoder_layers_11_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv482 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_q_proj_weight2, layer_norm98, model_decoder_layers_11_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape497 = R.call_tir(cls.reshape14, (lv482,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_11_self_attn_k_proj_weight2, layer_norm98), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape498 = R.call_tir(cls.reshape14, (lv109,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv483 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_v_proj_weight2, layer_norm98, model_decoder_layers_11_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape499 = R.call_tir(cls.reshape14, (lv483,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat11 = R.call_tir(cls.concatenate1, (reshape497, reshape498, reshape499), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
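# reshape15 flattens the fused QKV tensor from (1, seq_len, 60, 64) to (seq_len, 60, 64) before
# it is handed to the paged KV cache; the 60 heads are the concatenated 20 Q, 20 K and 20 V heads.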
reshape500 = R.call_tir(cls.reshape15, (concat11,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv91 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape500), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape501 = R.call_tir(cls.reshape16, (lv91,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape502 = R.call_tir(cls.reshape17, (reshape501,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv484 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_out_proj_weight2, reshape502, model_decoder_layers_11_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add371 = R.call_tir(cls.add5, (add367, lv484), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm99 = R.call_tir(cls.layer_norm2, (add371, model_decoder_layers_11_encoder_attn_layer_norm_weight2, model_decoder_layers_11_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv485 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight2, layer_norm99, model_decoder_layers_11_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape503 = R.call_tir(cls.reshape14, (lv485,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape504 = R.call_tir(cls.reshape18, (reshape503,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv92 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape504), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape505 = R.call_tir(cls.reshape16, (lv92,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape506 = R.call_tir(cls.reshape17, (reshape505,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv486 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight2, reshape506, model_decoder_layers_11_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add374 = R.call_tir(cls.add5, (add371, lv486), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm100 = R.call_tir(cls.layer_norm2, (add374, model_decoder_layers_11_final_layer_norm_weight2, model_decoder_layers_11_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv75_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_11_fc1_weight2, layer_norm100, model_decoder_layers_11_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv487 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_11_fc2_weight2, lv75_1, model_decoder_layers_11_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add377 = R.call_tir(cls.add5, (add374, lv487), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm101 = R.call_tir(cls.layer_norm2, (add377, model_decoder_layers_12_self_attn_layer_norm_weight2, model_decoder_layers_12_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv488 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_q_proj_weight2, layer_norm101, model_decoder_layers_12_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape507 = R.call_tir(cls.reshape14, (lv488,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_12_self_attn_k_proj_weight2, layer_norm101), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape508 = R.call_tir(cls.reshape14, (lv110,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv489 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_v_proj_weight2, layer_norm101, model_decoder_layers_12_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape509 = R.call_tir(cls.reshape14, (lv489,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat12 = R.call_tir(cls.concatenate1, (reshape507, reshape508, reshape509), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape510 = R.call_tir(cls.reshape15, (concat12,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv93 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape510), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape511 = R.call_tir(cls.reshape16, (lv93,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape512 = R.call_tir(cls.reshape17, (reshape511,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv490 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_out_proj_weight2, reshape512, model_decoder_layers_12_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add381 = R.call_tir(cls.add5, (add377, lv490), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm102 = R.call_tir(cls.layer_norm2, (add381, model_decoder_layers_12_encoder_attn_layer_norm_weight2, model_decoder_layers_12_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv491 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight2, layer_norm102, model_decoder_layers_12_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape513 = R.call_tir(cls.reshape14, (lv491,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape514 = R.call_tir(cls.reshape18, (reshape513,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv94 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape514), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape515 = R.call_tir(cls.reshape16, (lv94,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape516 = R.call_tir(cls.reshape17, (reshape515,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv492 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight2, reshape516, model_decoder_layers_12_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add384 = R.call_tir(cls.add5, (add381, lv492), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm103 = R.call_tir(cls.layer_norm2, (add384, model_decoder_layers_12_final_layer_norm_weight2, model_decoder_layers_12_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
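# Layer 12 MLP: the same 1280 -> 5120 -> 1280 GELU feed-forward pattern as every other decoder layer.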
lv76_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_12_fc1_weight2, layer_norm103, model_decoder_layers_12_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv493 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_12_fc2_weight2, lv76_1, model_decoder_layers_12_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add387 = R.call_tir(cls.add5, (add384, lv493), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm104 = R.call_tir(cls.layer_norm2, (add387, model_decoder_layers_13_self_attn_layer_norm_weight2, model_decoder_layers_13_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv494 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_q_proj_weight2, layer_norm104, model_decoder_layers_13_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape517 = R.call_tir(cls.reshape14, (lv494,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_13_self_attn_k_proj_weight2, layer_norm104), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape518 = R.call_tir(cls.reshape14, (lv111,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv495 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_v_proj_weight2, layer_norm104, model_decoder_layers_13_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape519 = R.call_tir(cls.reshape14, (lv495,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat13 = R.call_tir(cls.concatenate1, (reshape517, reshape518, reshape519), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape520 = R.call_tir(cls.reshape15, (concat13,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv95 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape520), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape521 = R.call_tir(cls.reshape16, (lv95,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape522 = R.call_tir(cls.reshape17, (reshape521,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv496 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_out_proj_weight2, reshape522, model_decoder_layers_13_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add391 = R.call_tir(cls.add5, (add387, lv496), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm105 = R.call_tir(cls.layer_norm2, (add391, model_decoder_layers_13_encoder_attn_layer_norm_weight2, model_decoder_layers_13_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv497 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight2, layer_norm105, model_decoder_layers_13_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape523 = R.call_tir(cls.reshape14, (lv497,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape524 = R.call_tir(cls.reshape18, (reshape523,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv96 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape524), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape525 = R.call_tir(cls.reshape16, (lv96,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape526 = R.call_tir(cls.reshape17, (reshape525,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv498 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight2, reshape526, model_decoder_layers_13_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add394 = R.call_tir(cls.add5, (add391, lv498), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm106 = R.call_tir(cls.layer_norm2, (add394, model_decoder_layers_13_final_layer_norm_weight2, model_decoder_layers_13_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv77_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_13_fc1_weight2, layer_norm106, model_decoder_layers_13_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv499 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_13_fc2_weight2, lv77_1, model_decoder_layers_13_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add397 = R.call_tir(cls.add5, (add394, lv499), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm107 = R.call_tir(cls.layer_norm2, (add397, model_decoder_layers_14_self_attn_layer_norm_weight2, model_decoder_layers_14_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv500 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_q_proj_weight2, layer_norm107, model_decoder_layers_14_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape527 = R.call_tir(cls.reshape14, (lv500,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv112 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_14_self_attn_k_proj_weight2, layer_norm107), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape528 = R.call_tir(cls.reshape14, (lv112,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv501 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_v_proj_weight2, layer_norm107, model_decoder_layers_14_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape529 = R.call_tir(cls.reshape14, (lv501,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat14 = R.call_tir(cls.concatenate1, (reshape527, reshape528, reshape529), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape530 = R.call_tir(cls.reshape15, (concat14,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv97 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape530), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape531 = R.call_tir(cls.reshape16, (lv97,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape532 = R.call_tir(cls.reshape17, (reshape531,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
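# From lv98_1 onward the attention temporaries carry a _1 suffix, apparently because the printer
# deduplicates them against the earlier lv98..lv118 bindings used for the K projections.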
lv502 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_out_proj_weight2, reshape532, model_decoder_layers_14_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add401 = R.call_tir(cls.add5, (add397, lv502), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm108 = R.call_tir(cls.layer_norm2, (add401, model_decoder_layers_14_encoder_attn_layer_norm_weight2, model_decoder_layers_14_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv503 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight2, layer_norm108, model_decoder_layers_14_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape533 = R.call_tir(cls.reshape14, (lv503,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape534 = R.call_tir(cls.reshape18, (reshape533,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv98_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape534), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape535 = R.call_tir(cls.reshape16, (lv98_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape536 = R.call_tir(cls.reshape17, (reshape535,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv504 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_out_proj_weight2, reshape536, model_decoder_layers_14_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add404 = R.call_tir(cls.add5, (add401, lv504), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm109 = R.call_tir(cls.layer_norm2, (add404, model_decoder_layers_14_final_layer_norm_weight2, model_decoder_layers_14_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv78_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_14_fc1_weight2, layer_norm109, model_decoder_layers_14_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv505 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_14_fc2_weight2, lv78_1, model_decoder_layers_14_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add407 = R.call_tir(cls.add5, (add404, lv505), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm110 = R.call_tir(cls.layer_norm2, (add407, model_decoder_layers_15_self_attn_layer_norm_weight2, model_decoder_layers_15_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv506 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_q_proj_weight2, layer_norm110, model_decoder_layers_15_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape537 = R.call_tir(cls.reshape14, (lv506,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_15_self_attn_k_proj_weight2, layer_norm110), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape538 = R.call_tir(cls.reshape14, (lv113,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv507 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_v_proj_weight2, layer_norm110, model_decoder_layers_15_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape539 = R.call_tir(cls.reshape14, (lv507,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat15 = R.call_tir(cls.concatenate1, (reshape537, reshape538, reshape539), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape540 = R.call_tir(cls.reshape15, (concat15,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv99_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape540), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape541 = R.call_tir(cls.reshape16, (lv99_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape542 = R.call_tir(cls.reshape17, (reshape541,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv508 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_out_proj_weight2, reshape542, model_decoder_layers_15_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add411 = R.call_tir(cls.add5, (add407, lv508), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm111 = R.call_tir(cls.layer_norm2, (add411, model_decoder_layers_15_encoder_attn_layer_norm_weight2, model_decoder_layers_15_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv509 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight2, layer_norm111, model_decoder_layers_15_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape543 = R.call_tir(cls.reshape14, (lv509,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape544 = R.call_tir(cls.reshape18, (reshape543,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv100_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape544), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape545 = R.call_tir(cls.reshape16, (lv100_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape546 = R.call_tir(cls.reshape17, (reshape545,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv510 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight2, reshape546, model_decoder_layers_15_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add414 = R.call_tir(cls.add5, (add411, lv510), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm112 = R.call_tir(cls.layer_norm2, (add414, model_decoder_layers_15_final_layer_norm_weight2, model_decoder_layers_15_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv79_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_15_fc1_weight2, layer_norm112, model_decoder_layers_15_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv511 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_15_fc2_weight2, lv79_1, model_decoder_layers_15_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
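# Residual add completing decoder layer 15; layer 16 begins with layer_norm113.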
add417 = R.call_tir(cls.add5, (add414, lv511), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm113 = R.call_tir(cls.layer_norm2, (add417, model_decoder_layers_16_self_attn_layer_norm_weight2, model_decoder_layers_16_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv512 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_q_proj_weight2, layer_norm113, model_decoder_layers_16_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape547 = R.call_tir(cls.reshape14, (lv512,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_16_self_attn_k_proj_weight2, layer_norm113), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape548 = R.call_tir(cls.reshape14, (lv114,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv513 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_v_proj_weight2, layer_norm113, model_decoder_layers_16_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape549 = R.call_tir(cls.reshape14, (lv513,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat16 = R.call_tir(cls.concatenate1, (reshape547, reshape548, reshape549), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape550 = R.call_tir(cls.reshape15, (concat16,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv101_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape550), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape551 = R.call_tir(cls.reshape16, (lv101_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape552 = R.call_tir(cls.reshape17, (reshape551,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv514 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_out_proj_weight2, reshape552, model_decoder_layers_16_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add421 = R.call_tir(cls.add5, (add417, lv514), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm114 = R.call_tir(cls.layer_norm2, (add421, model_decoder_layers_16_encoder_attn_layer_norm_weight2, model_decoder_layers_16_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv515 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight2, layer_norm114, model_decoder_layers_16_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape553 = R.call_tir(cls.reshape14, (lv515,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape554 = R.call_tir(cls.reshape18, (reshape553,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv102_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape554), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape555 = R.call_tir(cls.reshape16, (lv102_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape556 = R.call_tir(cls.reshape17, (reshape555,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv516 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_out_proj_weight2, reshape556, model_decoder_layers_16_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add424 = R.call_tir(cls.add5, (add421, lv516), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm115 = R.call_tir(cls.layer_norm2, (add424, model_decoder_layers_16_final_layer_norm_weight2, model_decoder_layers_16_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv80_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_16_fc1_weight2, layer_norm115, model_decoder_layers_16_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv517 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_16_fc2_weight2, lv80_1, model_decoder_layers_16_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add427 = R.call_tir(cls.add5, (add424, lv517), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm116 = R.call_tir(cls.layer_norm2, (add427, model_decoder_layers_17_self_attn_layer_norm_weight2, model_decoder_layers_17_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv518 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_q_proj_weight2, layer_norm116, model_decoder_layers_17_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape557 = R.call_tir(cls.reshape14, (lv518,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_17_self_attn_k_proj_weight2, layer_norm116), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape558 = R.call_tir(cls.reshape14, (lv115,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv519 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_v_proj_weight2, layer_norm116, model_decoder_layers_17_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape559 = R.call_tir(cls.reshape14, (lv519,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat17 = R.call_tir(cls.concatenate1, (reshape557, reshape558, reshape559), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape560 = R.call_tir(cls.reshape15, (concat17,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv103_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape560), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape561 = R.call_tir(cls.reshape16, (lv103_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape562 = R.call_tir(cls.reshape17, (reshape561,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv520 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_out_proj_weight2, reshape562, model_decoder_layers_17_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add431 = R.call_tir(cls.add5, (add427, lv520), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
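# Layer 17 cross-attention block, again keyed to paged_kv_cache by R.prim_value(17).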
layer_norm117 = R.call_tir(cls.layer_norm2, (add431, model_decoder_layers_17_encoder_attn_layer_norm_weight2, model_decoder_layers_17_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv521 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight2, layer_norm117, model_decoder_layers_17_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape563 = R.call_tir(cls.reshape14, (lv521,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape564 = R.call_tir(cls.reshape18, (reshape563,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv104_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape564), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape565 = R.call_tir(cls.reshape16, (lv104_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape566 = R.call_tir(cls.reshape17, (reshape565,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv522 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight2, reshape566, model_decoder_layers_17_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add434 = R.call_tir(cls.add5, (add431, lv522), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm118 = R.call_tir(cls.layer_norm2, (add434, model_decoder_layers_17_final_layer_norm_weight2, model_decoder_layers_17_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv81_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_17_fc1_weight2, layer_norm118, model_decoder_layers_17_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv523 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_17_fc2_weight2, lv81_1, model_decoder_layers_17_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add437 = R.call_tir(cls.add5, (add434, lv523), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm119 = R.call_tir(cls.layer_norm2, (add437, model_decoder_layers_18_self_attn_layer_norm_weight2, model_decoder_layers_18_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv524 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_q_proj_weight2, layer_norm119, model_decoder_layers_18_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape567 = R.call_tir(cls.reshape14, (lv524,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_18_self_attn_k_proj_weight2, layer_norm119), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape568 = R.call_tir(cls.reshape14, (lv116,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv525 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_v_proj_weight2, layer_norm119, model_decoder_layers_18_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape569 = R.call_tir(cls.reshape14, (lv525,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat18 = R.call_tir(cls.concatenate1, (reshape567, reshape568, reshape569), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape570 = R.call_tir(cls.reshape15, (concat18,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv105_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape570), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape571 = R.call_tir(cls.reshape16, (lv105_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape572 = R.call_tir(cls.reshape17, (reshape571,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv526 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_out_proj_weight2, reshape572, model_decoder_layers_18_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add441 = R.call_tir(cls.add5, (add437, lv526), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm120 = R.call_tir(cls.layer_norm2, (add441, model_decoder_layers_18_encoder_attn_layer_norm_weight2, model_decoder_layers_18_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv527 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_q_proj_weight2, layer_norm120, model_decoder_layers_18_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape573 = R.call_tir(cls.reshape14, (lv527,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape574 = R.call_tir(cls.reshape18, (reshape573,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv106_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape574), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape575 = R.call_tir(cls.reshape16, (lv106_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape576 = R.call_tir(cls.reshape17, (reshape575,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv528 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight2, reshape576, model_decoder_layers_18_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add444 = R.call_tir(cls.add5, (add441, lv528), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm121 = R.call_tir(cls.layer_norm2, (add444, model_decoder_layers_18_final_layer_norm_weight2, model_decoder_layers_18_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv82_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_18_fc1_weight2, layer_norm121, model_decoder_layers_18_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv529 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_18_fc2_weight2, lv82_1, model_decoder_layers_18_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add447 = R.call_tir(cls.add5, (add444, lv529), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm122 = R.call_tir(cls.layer_norm2, (add447, model_decoder_layers_19_self_attn_layer_norm_weight2, model_decoder_layers_19_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
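# Decoder layer 19 self-attention: Q/K/V projections, fused QKV concat, and paged-cache attention.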
dtype="float16")) lv530 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_q_proj_weight2, layer_norm122, model_decoder_layers_19_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape577 = R.call_tir(cls.reshape14, (lv530,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_19_self_attn_k_proj_weight2, layer_norm122), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape578 = R.call_tir(cls.reshape14, (lv117,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv531 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_v_proj_weight2, layer_norm122, model_decoder_layers_19_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape579 = R.call_tir(cls.reshape14, (lv531,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat19 = R.call_tir(cls.concatenate1, (reshape577, reshape578, reshape579), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape580 = R.call_tir(cls.reshape15, (concat19,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv107_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape580), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape581 = R.call_tir(cls.reshape16, (lv107_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape582 = R.call_tir(cls.reshape17, (reshape581,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv532 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_out_proj_weight2, reshape582, model_decoder_layers_19_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add451 = R.call_tir(cls.add5, (add447, lv532), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm123 = R.call_tir(cls.layer_norm2, (add451, model_decoder_layers_19_encoder_attn_layer_norm_weight2, model_decoder_layers_19_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv533 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight2, layer_norm123, model_decoder_layers_19_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape583 = R.call_tir(cls.reshape14, (lv533,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape584 = R.call_tir(cls.reshape18, (reshape583,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv108_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape584), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape585 = R.call_tir(cls.reshape16, (lv108_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape586 = R.call_tir(cls.reshape17, (reshape585,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv534 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight2, reshape586, model_decoder_layers_19_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add454 = R.call_tir(cls.add5, (add451, lv534), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm124 = R.call_tir(cls.layer_norm2, (add454, model_decoder_layers_19_final_layer_norm_weight2, model_decoder_layers_19_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv83_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_19_fc1_weight2, layer_norm124, model_decoder_layers_19_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv535 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_19_fc2_weight2, lv83_1, model_decoder_layers_19_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add457 = R.call_tir(cls.add5, (add454, lv535), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm125 = R.call_tir(cls.layer_norm2, (add457, model_decoder_layers_20_self_attn_layer_norm_weight2, model_decoder_layers_20_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv536 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_q_proj_weight2, layer_norm125, model_decoder_layers_20_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape587 = R.call_tir(cls.reshape14, (lv536,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_20_self_attn_k_proj_weight2, layer_norm125), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape588 = R.call_tir(cls.reshape14, (lv118,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv537 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_v_proj_weight2, layer_norm125, model_decoder_layers_20_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape589 = R.call_tir(cls.reshape14, (lv537,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat20 = R.call_tir(cls.concatenate1, (reshape587, reshape588, reshape589), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape590 = R.call_tir(cls.reshape15, (concat20,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv109_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape590), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape591 = R.call_tir(cls.reshape16, (lv109_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape592 = R.call_tir(cls.reshape17, (reshape591,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv538 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_out_proj_weight2, reshape592, model_decoder_layers_20_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add461 = R.call_tir(cls.add5, (add457, lv538), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm126 = R.call_tir(cls.layer_norm2, (add461, model_decoder_layers_20_encoder_attn_layer_norm_weight2, model_decoder_layers_20_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv539 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight2, layer_norm126, 
model_decoder_layers_20_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape593 = R.call_tir(cls.reshape14, (lv539,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape594 = R.call_tir(cls.reshape18, (reshape593,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv110_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape594), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape595 = R.call_tir(cls.reshape16, (lv110_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape596 = R.call_tir(cls.reshape17, (reshape595,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv540 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight2, reshape596, model_decoder_layers_20_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add464 = R.call_tir(cls.add5, (add461, lv540), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm127 = R.call_tir(cls.layer_norm2, (add464, model_decoder_layers_20_final_layer_norm_weight2, model_decoder_layers_20_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv84_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_20_fc1_weight2, layer_norm127, model_decoder_layers_20_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv541 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_20_fc2_weight2, lv84_1, model_decoder_layers_20_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add467 = R.call_tir(cls.add5, (add464, lv541), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm128 = R.call_tir(cls.layer_norm2, (add467, model_decoder_layers_21_self_attn_layer_norm_weight2, model_decoder_layers_21_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv542 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_q_proj_weight2, layer_norm128, model_decoder_layers_21_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape597 = R.call_tir(cls.reshape14, (lv542,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_21_self_attn_k_proj_weight2, layer_norm128), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape598 = R.call_tir(cls.reshape14, (lv119,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv543 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_v_proj_weight2, layer_norm128, model_decoder_layers_21_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape599 = R.call_tir(cls.reshape14, (lv543,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat21 = R.call_tir(cls.concatenate1, (reshape597, reshape598, reshape599), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape600 = R.call_tir(cls.reshape15, (concat21,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv111_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), 
reshape600), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape601 = R.call_tir(cls.reshape16, (lv111_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape602 = R.call_tir(cls.reshape17, (reshape601,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv544 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_out_proj_weight2, reshape602, model_decoder_layers_21_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add471 = R.call_tir(cls.add5, (add467, lv544), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm129 = R.call_tir(cls.layer_norm2, (add471, model_decoder_layers_21_encoder_attn_layer_norm_weight2, model_decoder_layers_21_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv545 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight2, layer_norm129, model_decoder_layers_21_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape603 = R.call_tir(cls.reshape14, (lv545,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape604 = R.call_tir(cls.reshape18, (reshape603,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv112_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape604), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape605 = R.call_tir(cls.reshape16, (lv112_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape606 = R.call_tir(cls.reshape17, (reshape605,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv546 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight2, reshape606, model_decoder_layers_21_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add474 = R.call_tir(cls.add5, (add471, lv546), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm130 = R.call_tir(cls.layer_norm2, (add474, model_decoder_layers_21_final_layer_norm_weight2, model_decoder_layers_21_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv85_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_21_fc1_weight2, layer_norm130, model_decoder_layers_21_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv547 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_21_fc2_weight2, lv85_1, model_decoder_layers_21_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add477 = R.call_tir(cls.add5, (add474, lv547), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm131 = R.call_tir(cls.layer_norm2, (add477, model_decoder_layers_22_self_attn_layer_norm_weight2, model_decoder_layers_22_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv548 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_q_proj_weight2, layer_norm131, model_decoder_layers_22_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape607 = R.call_tir(cls.reshape14, (lv548,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv120 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_22_self_attn_k_proj_weight2, layer_norm131), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape608 = R.call_tir(cls.reshape14, (lv120,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv549 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_v_proj_weight2, layer_norm131, model_decoder_layers_22_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape609 = R.call_tir(cls.reshape14, (lv549,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat22 = R.call_tir(cls.concatenate1, (reshape607, reshape608, reshape609), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape610 = R.call_tir(cls.reshape15, (concat22,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv113_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape610), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape611 = R.call_tir(cls.reshape16, (lv113_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape612 = R.call_tir(cls.reshape17, (reshape611,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv550 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_out_proj_weight2, reshape612, model_decoder_layers_22_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add481 = R.call_tir(cls.add5, (add477, lv550), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm132 = R.call_tir(cls.layer_norm2, (add481, model_decoder_layers_22_encoder_attn_layer_norm_weight2, model_decoder_layers_22_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv551 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight2, layer_norm132, model_decoder_layers_22_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape613 = R.call_tir(cls.reshape14, (lv551,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape614 = R.call_tir(cls.reshape18, (reshape613,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv114_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape614), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape615 = R.call_tir(cls.reshape16, (lv114_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape616 = R.call_tir(cls.reshape17, (reshape615,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv552 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight2, reshape616, model_decoder_layers_22_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add484 = R.call_tir(cls.add5, (add481, lv552), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm133 = R.call_tir(cls.layer_norm2, (add484, model_decoder_layers_22_final_layer_norm_weight2, model_decoder_layers_22_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv86_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", 
(model_decoder_layers_22_fc1_weight2, layer_norm133, model_decoder_layers_22_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv553 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_22_fc2_weight2, lv86_1, model_decoder_layers_22_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add487 = R.call_tir(cls.add5, (add484, lv553), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm134 = R.call_tir(cls.layer_norm2, (add487, model_decoder_layers_23_self_attn_layer_norm_weight2, model_decoder_layers_23_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv554 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_q_proj_weight2, layer_norm134, model_decoder_layers_23_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape617 = R.call_tir(cls.reshape14, (lv554,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_23_self_attn_k_proj_weight2, layer_norm134), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape618 = R.call_tir(cls.reshape14, (lv121,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv555 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_v_proj_weight2, layer_norm134, model_decoder_layers_23_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape619 = R.call_tir(cls.reshape14, (lv555,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat23 = R.call_tir(cls.concatenate1, (reshape617, reshape618, reshape619), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape620 = R.call_tir(cls.reshape15, (concat23,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv115_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape620), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape621 = R.call_tir(cls.reshape16, (lv115_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape622 = R.call_tir(cls.reshape17, (reshape621,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv556 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_out_proj_weight2, reshape622, model_decoder_layers_23_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add491 = R.call_tir(cls.add5, (add487, lv556), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm135 = R.call_tir(cls.layer_norm2, (add491, model_decoder_layers_23_encoder_attn_layer_norm_weight2, model_decoder_layers_23_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv557 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight2, layer_norm135, model_decoder_layers_23_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape623 = R.call_tir(cls.reshape14, (lv557,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape624 = R.call_tir(cls.reshape18, (reshape623,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv116_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", 
(paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape624), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape625 = R.call_tir(cls.reshape16, (lv116_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape626 = R.call_tir(cls.reshape17, (reshape625,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv558 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight2, reshape626, model_decoder_layers_23_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add494 = R.call_tir(cls.add5, (add491, lv558), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm136 = R.call_tir(cls.layer_norm2, (add494, model_decoder_layers_23_final_layer_norm_weight2, model_decoder_layers_23_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv87_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_23_fc1_weight2, layer_norm136, model_decoder_layers_23_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv559 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_23_fc2_weight2, lv87_1, model_decoder_layers_23_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add497 = R.call_tir(cls.add5, (add494, lv559), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm137 = R.call_tir(cls.layer_norm2, (add497, model_decoder_layers_24_self_attn_layer_norm_weight2, model_decoder_layers_24_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv560 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_q_proj_weight2, layer_norm137, model_decoder_layers_24_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape627 = R.call_tir(cls.reshape14, (lv560,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_24_self_attn_k_proj_weight2, layer_norm137), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape628 = R.call_tir(cls.reshape14, (lv122,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv561 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_v_proj_weight2, layer_norm137, model_decoder_layers_24_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape629 = R.call_tir(cls.reshape14, (lv561,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat24 = R.call_tir(cls.concatenate1, (reshape627, reshape628, reshape629), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape630 = R.call_tir(cls.reshape15, (concat24,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv117_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape630), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape631 = R.call_tir(cls.reshape16, (lv117_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape632 = R.call_tir(cls.reshape17, (reshape631,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv562 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", 
(model_decoder_layers_24_self_attn_out_proj_weight2, reshape632, model_decoder_layers_24_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add501 = R.call_tir(cls.add5, (add497, lv562), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm138 = R.call_tir(cls.layer_norm2, (add501, model_decoder_layers_24_encoder_attn_layer_norm_weight2, model_decoder_layers_24_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv563 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight2, layer_norm138, model_decoder_layers_24_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape633 = R.call_tir(cls.reshape14, (lv563,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape634 = R.call_tir(cls.reshape18, (reshape633,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv118_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape634), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape635 = R.call_tir(cls.reshape16, (lv118_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape636 = R.call_tir(cls.reshape17, (reshape635,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv564 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight2, reshape636, model_decoder_layers_24_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add504 = R.call_tir(cls.add5, (add501, lv564), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm139 = R.call_tir(cls.layer_norm2, (add504, model_decoder_layers_24_final_layer_norm_weight2, model_decoder_layers_24_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv88_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_24_fc1_weight2, layer_norm139, model_decoder_layers_24_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv565 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_24_fc2_weight2, lv88_1, model_decoder_layers_24_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add507 = R.call_tir(cls.add5, (add504, lv565), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm140 = R.call_tir(cls.layer_norm2, (add507, model_decoder_layers_25_self_attn_layer_norm_weight2, model_decoder_layers_25_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv566 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_q_proj_weight2, layer_norm140, model_decoder_layers_25_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape637 = R.call_tir(cls.reshape14, (lv566,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_25_self_attn_k_proj_weight2, layer_norm140), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape638 = R.call_tir(cls.reshape14, (lv123,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv567 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", 
(model_decoder_layers_25_self_attn_v_proj_weight2, layer_norm140, model_decoder_layers_25_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape639 = R.call_tir(cls.reshape14, (lv567,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat25 = R.call_tir(cls.concatenate1, (reshape637, reshape638, reshape639), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape640 = R.call_tir(cls.reshape15, (concat25,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv119_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape640), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape641 = R.call_tir(cls.reshape16, (lv119_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape642 = R.call_tir(cls.reshape17, (reshape641,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv568 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_out_proj_weight2, reshape642, model_decoder_layers_25_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add511 = R.call_tir(cls.add5, (add507, lv568), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm141 = R.call_tir(cls.layer_norm2, (add511, model_decoder_layers_25_encoder_attn_layer_norm_weight2, model_decoder_layers_25_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv569 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_q_proj_weight2, layer_norm141, model_decoder_layers_25_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape643 = R.call_tir(cls.reshape14, (lv569,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape644 = R.call_tir(cls.reshape18, (reshape643,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv120_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape644), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape645 = R.call_tir(cls.reshape16, (lv120_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape646 = R.call_tir(cls.reshape17, (reshape645,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv570 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_out_proj_weight2, reshape646, model_decoder_layers_25_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add514 = R.call_tir(cls.add5, (add511, lv570), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm142 = R.call_tir(cls.layer_norm2, (add514, model_decoder_layers_25_final_layer_norm_weight2, model_decoder_layers_25_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv89_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_25_fc1_weight2, layer_norm142, model_decoder_layers_25_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv571 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_25_fc2_weight2, lv89_1, model_decoder_layers_25_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add517 = R.call_tir(cls.add5, (add514, lv571), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm143 = R.call_tir(cls.layer_norm2, (add517, model_decoder_layers_26_self_attn_layer_norm_weight2, model_decoder_layers_26_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv572 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_q_proj_weight2, layer_norm143, model_decoder_layers_26_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape647 = R.call_tir(cls.reshape14, (lv572,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv124 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_26_self_attn_k_proj_weight2, layer_norm143), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape648 = R.call_tir(cls.reshape14, (lv124,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv573 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_v_proj_weight2, layer_norm143, model_decoder_layers_26_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape649 = R.call_tir(cls.reshape14, (lv573,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat26 = R.call_tir(cls.concatenate1, (reshape647, reshape648, reshape649), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape650 = R.call_tir(cls.reshape15, (concat26,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv121_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape650), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape651 = R.call_tir(cls.reshape16, (lv121_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape652 = R.call_tir(cls.reshape17, (reshape651,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv574 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_out_proj_weight2, reshape652, model_decoder_layers_26_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add521 = R.call_tir(cls.add5, (add517, lv574), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm144 = R.call_tir(cls.layer_norm2, (add521, model_decoder_layers_26_encoder_attn_layer_norm_weight2, model_decoder_layers_26_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv575 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight2, layer_norm144, model_decoder_layers_26_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape653 = R.call_tir(cls.reshape14, (lv575,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape654 = R.call_tir(cls.reshape18, (reshape653,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv122_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape654), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape655 = R.call_tir(cls.reshape16, (lv122_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape656 = R.call_tir(cls.reshape17, (reshape655,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv576 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_out_proj_weight2, reshape656, model_decoder_layers_26_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add524 = R.call_tir(cls.add5, (add521, lv576), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm145 = R.call_tir(cls.layer_norm2, (add524, model_decoder_layers_26_final_layer_norm_weight2, model_decoder_layers_26_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv90_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_26_fc1_weight2, layer_norm145, model_decoder_layers_26_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv577 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_26_fc2_weight2, lv90_1, model_decoder_layers_26_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add527 = R.call_tir(cls.add5, (add524, lv577), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm146 = R.call_tir(cls.layer_norm2, (add527, model_decoder_layers_27_self_attn_layer_norm_weight2, model_decoder_layers_27_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv578 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_q_proj_weight2, layer_norm146, model_decoder_layers_27_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape657 = R.call_tir(cls.reshape14, (lv578,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_27_self_attn_k_proj_weight2, layer_norm146), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape658 = R.call_tir(cls.reshape14, (lv125,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv579 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_v_proj_weight2, layer_norm146, model_decoder_layers_27_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape659 = R.call_tir(cls.reshape14, (lv579,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat27 = R.call_tir(cls.concatenate1, (reshape657, reshape658, reshape659), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape660 = R.call_tir(cls.reshape15, (concat27,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv123_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape660), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape661 = R.call_tir(cls.reshape16, (lv123_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape662 = R.call_tir(cls.reshape17, (reshape661,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv580 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_out_proj_weight2, reshape662, model_decoder_layers_27_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add531 = R.call_tir(cls.add5, (add527, lv580), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm147 = R.call_tir(cls.layer_norm2, (add531, model_decoder_layers_27_encoder_attn_layer_norm_weight2, 
model_decoder_layers_27_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv581 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight2, layer_norm147, model_decoder_layers_27_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape663 = R.call_tir(cls.reshape14, (lv581,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape664 = R.call_tir(cls.reshape18, (reshape663,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv124_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape664), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape665 = R.call_tir(cls.reshape16, (lv124_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape666 = R.call_tir(cls.reshape17, (reshape665,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv582 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight2, reshape666, model_decoder_layers_27_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add534 = R.call_tir(cls.add5, (add531, lv582), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm148 = R.call_tir(cls.layer_norm2, (add534, model_decoder_layers_27_final_layer_norm_weight2, model_decoder_layers_27_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv91_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_27_fc1_weight2, layer_norm148, model_decoder_layers_27_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv583 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_27_fc2_weight2, lv91_1, model_decoder_layers_27_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add537 = R.call_tir(cls.add5, (add534, lv583), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm149 = R.call_tir(cls.layer_norm2, (add537, model_decoder_layers_28_self_attn_layer_norm_weight2, model_decoder_layers_28_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv584 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_q_proj_weight2, layer_norm149, model_decoder_layers_28_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape667 = R.call_tir(cls.reshape14, (lv584,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_28_self_attn_k_proj_weight2, layer_norm149), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape668 = R.call_tir(cls.reshape14, (lv126,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv585 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_v_proj_weight2, layer_norm149, model_decoder_layers_28_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape669 = R.call_tir(cls.reshape14, (lv585,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat28 = R.call_tir(cls.concatenate1, (reshape667, reshape668, reshape669), out_sinfo=R.Tensor((1, seq_len, 60, 64), 
dtype="float16")) reshape670 = R.call_tir(cls.reshape15, (concat28,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv125_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape670), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape671 = R.call_tir(cls.reshape16, (lv125_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape672 = R.call_tir(cls.reshape17, (reshape671,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv586 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_out_proj_weight2, reshape672, model_decoder_layers_28_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add541 = R.call_tir(cls.add5, (add537, lv586), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm150 = R.call_tir(cls.layer_norm2, (add541, model_decoder_layers_28_encoder_attn_layer_norm_weight2, model_decoder_layers_28_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv587 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight2, layer_norm150, model_decoder_layers_28_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape673 = R.call_tir(cls.reshape14, (lv587,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape674 = R.call_tir(cls.reshape18, (reshape673,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv126_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape674), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape675 = R.call_tir(cls.reshape16, (lv126_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape676 = R.call_tir(cls.reshape17, (reshape675,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv588 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight2, reshape676, model_decoder_layers_28_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add544 = R.call_tir(cls.add5, (add541, lv588), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm151 = R.call_tir(cls.layer_norm2, (add544, model_decoder_layers_28_final_layer_norm_weight2, model_decoder_layers_28_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv92_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_28_fc1_weight2, layer_norm151, model_decoder_layers_28_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv589 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_28_fc2_weight2, lv92_1, model_decoder_layers_28_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add547 = R.call_tir(cls.add5, (add544, lv589), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm152 = R.call_tir(cls.layer_norm2, (add547, model_decoder_layers_29_self_attn_layer_norm_weight2, model_decoder_layers_29_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv590 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", 
(model_decoder_layers_29_self_attn_q_proj_weight2, layer_norm152, model_decoder_layers_29_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape677 = R.call_tir(cls.reshape14, (lv590,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv127 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_29_self_attn_k_proj_weight2, layer_norm152), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape678 = R.call_tir(cls.reshape14, (lv127,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv591 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_v_proj_weight2, layer_norm152, model_decoder_layers_29_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape679 = R.call_tir(cls.reshape14, (lv591,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat29 = R.call_tir(cls.concatenate1, (reshape677, reshape678, reshape679), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape680 = R.call_tir(cls.reshape15, (concat29,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv127_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape680), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape681 = R.call_tir(cls.reshape16, (lv127_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape682 = R.call_tir(cls.reshape17, (reshape681,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv592 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_out_proj_weight2, reshape682, model_decoder_layers_29_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add551 = R.call_tir(cls.add5, (add547, lv592), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm153 = R.call_tir(cls.layer_norm2, (add551, model_decoder_layers_29_encoder_attn_layer_norm_weight2, model_decoder_layers_29_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv593 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_q_proj_weight2, layer_norm153, model_decoder_layers_29_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape683 = R.call_tir(cls.reshape14, (lv593,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape684 = R.call_tir(cls.reshape18, (reshape683,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv128 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape684), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape685 = R.call_tir(cls.reshape16, (lv128,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape686 = R.call_tir(cls.reshape17, (reshape685,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv594 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight2, reshape686, model_decoder_layers_29_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add554 = R.call_tir(cls.add5, (add551, lv594), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm154 = R.call_tir(cls.layer_norm2, (add554, 
model_decoder_layers_29_final_layer_norm_weight2, model_decoder_layers_29_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv93_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_29_fc1_weight2, layer_norm154, model_decoder_layers_29_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv595 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_29_fc2_weight2, lv93_1, model_decoder_layers_29_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add557 = R.call_tir(cls.add5, (add554, lv595), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm155 = R.call_tir(cls.layer_norm2, (add557, model_decoder_layers_30_self_attn_layer_norm_weight2, model_decoder_layers_30_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv596 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_q_proj_weight2, layer_norm155, model_decoder_layers_30_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape687 = R.call_tir(cls.reshape14, (lv596,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv128_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_30_self_attn_k_proj_weight2, layer_norm155), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape688 = R.call_tir(cls.reshape14, (lv128_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv597 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_v_proj_weight2, layer_norm155, model_decoder_layers_30_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape689 = R.call_tir(cls.reshape14, (lv597,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat30 = R.call_tir(cls.concatenate1, (reshape687, reshape688, reshape689), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape690 = R.call_tir(cls.reshape15, (concat30,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv129 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape690), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape691 = R.call_tir(cls.reshape16, (lv129,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape692 = R.call_tir(cls.reshape17, (reshape691,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv598 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_out_proj_weight2, reshape692, model_decoder_layers_30_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add561 = R.call_tir(cls.add5, (add557, lv598), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm156 = R.call_tir(cls.layer_norm2, (add561, model_decoder_layers_30_encoder_attn_layer_norm_weight2, model_decoder_layers_30_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv599 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight2, layer_norm156, model_decoder_layers_30_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape693 = R.call_tir(cls.reshape14, 
(lv599,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape694 = R.call_tir(cls.reshape18, (reshape693,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
# cross-attention of decoder layer 30 over the cached encoder states
lv130 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape694), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape695 = R.call_tir(cls.reshape16, (lv130,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape696 = R.call_tir(cls.reshape17, (reshape695,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv600 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight2, reshape696, model_decoder_layers_30_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add564 = R.call_tir(cls.add5, (add561, lv600), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
# feed-forward block of decoder layer 30: LayerNorm -> fc1 + GELU -> fc2, with residual
layer_norm157 = R.call_tir(cls.layer_norm2, (add564, model_decoder_layers_30_final_layer_norm_weight2, model_decoder_layers_30_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv94_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_30_fc1_weight2, layer_norm157, model_decoder_layers_30_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv601 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_30_fc2_weight2, lv94_1, model_decoder_layers_30_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add567 = R.call_tir(cls.add5, (add564, lv601), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
# decoder layer 31 (the final layer): self-attention Q/K/V projections; note the
# K projection uses the bias-free "matmul1" kernel, matching the layers above
layer_norm158 = R.call_tir(cls.layer_norm2, (add567, model_decoder_layers_31_self_attn_layer_norm_weight2, model_decoder_layers_31_self_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv602 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_q_proj_weight2, layer_norm158, model_decoder_layers_31_self_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape697 = R.call_tir(cls.reshape14, (lv602,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv129_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_31_self_attn_k_proj_weight2, layer_norm158), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape698 = R.call_tir(cls.reshape14, (lv129_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv603 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_v_proj_weight2, layer_norm158, model_decoder_layers_31_self_attn_v_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape699 = R.call_tir(cls.reshape14, (lv603,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat31 = R.call_tir(cls.concatenate1, (reshape697, reshape698, reshape699), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape700 = R.call_tir(cls.reshape15, (concat31,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv131 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape700), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
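# Note: "vm.builtin.attention_kv_cache_attention_with_fused_qkv" consumes the
# fused (seq_len, 60, 64) tensor -- 20 query, 20 key, and 20 value heads packed
# along one axis -- for the given layer index, appears to append the new K/V
# entries into paged_kv_cache, and returns the (seq_len, 20, 64) attention
# output. The trailing R.prim_value(T.float32(1)) looks like an extra score
# scaling factor applied on top of the usual 1/sqrt(head_dim). Schematically,
# per head (illustrative pseudocode, not part of the compiled module):
#
#     scores = softmax((q @ k_cache.T) / sqrt(64) * 1.0)  # masked over cached keys
#     out = scores @ v_cache                              # (seq_len, 64) per head
#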
reshape701 = R.call_tir(cls.reshape16, (lv131,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape702 = R.call_tir(cls.reshape17, (reshape701,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv604 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_out_proj_weight2, reshape702, model_decoder_layers_31_self_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add571 = R.call_tir(cls.add5, (add567, lv604), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
# cross-attention of decoder layer 31
layer_norm159 = R.call_tir(cls.layer_norm2, (add571, model_decoder_layers_31_encoder_attn_layer_norm_weight2, model_decoder_layers_31_encoder_attn_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv605 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight2, layer_norm159, model_decoder_layers_31_encoder_attn_q_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape703 = R.call_tir(cls.reshape14, (lv605,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape704 = R.call_tir(cls.reshape18, (reshape703,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv132 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape704), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape705 = R.call_tir(cls.reshape16, (lv132,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape706 = R.call_tir(cls.reshape17, (reshape705,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv606 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight2, reshape706, model_decoder_layers_31_encoder_attn_out_proj_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add574 = R.call_tir(cls.add5, (add571, lv606), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
# feed-forward block of decoder layer 31
layer_norm160 = R.call_tir(cls.layer_norm2, (add574, model_decoder_layers_31_final_layer_norm_weight2, model_decoder_layers_31_final_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv95_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_31_fc1_weight2, layer_norm160, model_decoder_layers_31_fc1_bias2), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv607 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_31_fc2_weight2, lv95_1, model_decoder_layers_31_fc2_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add577 = R.call_tir(cls.add5, (add574, lv607), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
# final LayerNorm, gather of the requested logit positions, and projection onto the vocabulary
layer_norm161 = R.call_tir(cls.layer_norm2, (add577, model_decoder_layer_norm_weight2, model_decoder_layer_norm_bias2), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
take2 = R.call_tir(cls.take2, (layer_norm161, logit_positions), out_sinfo=R.Tensor((1, batch_size, 1280), dtype="float16"))
gv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul5_cublas", (model_decoder_embed_tokens_weight2, take2), out_sinfo=R.Tensor((1, batch_size, 51866), dtype="float32"))
R.output(gv2)
return gv2
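# ------------------------------------------------------------------
# Note: the dataflow body above unrolls all 32 decoder layers (indices
# 0-31); every layer repeats the same pre-LayerNorm pattern of paged-KV
# self-attention, cross-attention over the cached encoder states, and a
# GELU feed-forward block, each closed by a residual add (cls.add5).
# As a rough orientation sketch -- hand-written, with illustrative
# names, and not part of the compiled module:
#
#     def decoder_layer(x, layer):                  # x: (1, seq_len, 1280)
#         h = x + self_attn(layer_norm(x), layer)   # K/V appended to paged cache
#         h = h + cross_attn(layer_norm(h), layer)  # reads encoder K/V from cache
#         return h + fc2(gelu(fc1(layer_norm(h))))  # 1280 -> 5120 -> 1280
#
# After layer 31, a final LayerNorm (layer_norm161), a gather of the
# per-sequence logit positions (cls.take2), and a matmul against the
# tied token embedding (vocabulary size 51866) produce the float32
# logits returned as gv2.
# ------------------------------------------------------------------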
R.Shape(["support_sliding_window"])) -> R.Object: max_batch_size = T.int64() max_total_seq_len = T.int64() prefill_chunk_size = T.int64() page_size = T.int64() support_sliding_window = T.int64() R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.prim_value(32), R.prim_value(20), R.prim_value(20), R.prim_value(64), R.prim_value(0), R.prim_value(1), R.prim_value(1), R.const(0, "float16"), cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, sinfo_args=(R.Object,)) return paged_kv_cache @R.function def decode(input_ids: R.Tensor((1, 1), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), 
dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), 
R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, 1, 51866), dtype="float32"): R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): model_decoder_embed_tokens_weight5: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] model_decoder_embed_positions_weight5: R.Tensor((448, 1280), dtype="float16") = packed_params[488] 
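            # Editor's note: the decoder weights sit at a fixed stride in
            # packed_params -- indices 487/488 are the token and position
            # embedding tables, and each decoder layer then occupies 24 slots
            # starting at 489 (layer 1 starts at 513, layer 2 at 537, ...).
            # The three slots per layer that are never bound here (498-500,
            # 522-524, ...) appear to be the cross-attention k/v projections,
            # which decode does not read: cross-attention keys and values are
            # served from the cache via
            # vm.builtin.attention_kv_cache_cross_attention. A hypothetical
            # helper expressing the layout:
            #
            #     def param_index(layer: int, offset: int) -> int:
            #         # offset 0 == self_attn_k_proj_weight, ...,
            #         # offset 23 == final_layer_norm_bias
            #         return 489 + 24 * layer + offset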
model_decoder_layers_0_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] model_decoder_layers_0_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] model_decoder_layers_0_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[491] model_decoder_layers_0_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] model_decoder_layers_0_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[493] model_decoder_layers_0_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] model_decoder_layers_0_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[495] model_decoder_layers_0_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[496] model_decoder_layers_0_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[497] model_decoder_layers_0_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] model_decoder_layers_0_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[502] model_decoder_layers_0_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] model_decoder_layers_0_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[504] model_decoder_layers_0_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[505] model_decoder_layers_0_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[506] model_decoder_layers_0_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] model_decoder_layers_0_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[508] model_decoder_layers_0_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] model_decoder_layers_0_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[510] model_decoder_layers_0_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[511] model_decoder_layers_0_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[512] model_decoder_layers_1_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] model_decoder_layers_1_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] model_decoder_layers_1_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[515] model_decoder_layers_1_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] model_decoder_layers_1_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[517] model_decoder_layers_1_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] model_decoder_layers_1_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[519] model_decoder_layers_1_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[520] model_decoder_layers_1_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[521] model_decoder_layers_1_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] model_decoder_layers_1_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[526] model_decoder_layers_1_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] 
model_decoder_layers_1_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[528] model_decoder_layers_1_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[529] model_decoder_layers_1_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[530] model_decoder_layers_1_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] model_decoder_layers_1_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[532] model_decoder_layers_1_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] model_decoder_layers_1_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[534] model_decoder_layers_1_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[535] model_decoder_layers_1_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[536] model_decoder_layers_2_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] model_decoder_layers_2_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] model_decoder_layers_2_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[539] model_decoder_layers_2_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] model_decoder_layers_2_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[541] model_decoder_layers_2_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] model_decoder_layers_2_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[543] model_decoder_layers_2_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[544] model_decoder_layers_2_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[545] model_decoder_layers_2_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] model_decoder_layers_2_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[550] model_decoder_layers_2_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] model_decoder_layers_2_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[552] model_decoder_layers_2_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[553] model_decoder_layers_2_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[554] model_decoder_layers_2_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] model_decoder_layers_2_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[556] model_decoder_layers_2_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] model_decoder_layers_2_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[558] model_decoder_layers_2_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[559] model_decoder_layers_2_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[560] model_decoder_layers_3_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] model_decoder_layers_3_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] model_decoder_layers_3_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[563] model_decoder_layers_3_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[564] model_decoder_layers_3_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[565] model_decoder_layers_3_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] model_decoder_layers_3_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[567] model_decoder_layers_3_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[568] model_decoder_layers_3_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[569] model_decoder_layers_3_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] model_decoder_layers_3_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[574] model_decoder_layers_3_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] model_decoder_layers_3_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[576] model_decoder_layers_3_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[577] model_decoder_layers_3_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[578] model_decoder_layers_3_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] model_decoder_layers_3_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[580] model_decoder_layers_3_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] model_decoder_layers_3_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[582] model_decoder_layers_3_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[583] model_decoder_layers_3_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[584] model_decoder_layers_4_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] model_decoder_layers_4_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] model_decoder_layers_4_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[587] model_decoder_layers_4_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] model_decoder_layers_4_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[589] model_decoder_layers_4_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] model_decoder_layers_4_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[591] model_decoder_layers_4_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[592] model_decoder_layers_4_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[593] model_decoder_layers_4_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] model_decoder_layers_4_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[598] model_decoder_layers_4_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] model_decoder_layers_4_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[600] model_decoder_layers_4_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[601] model_decoder_layers_4_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[602] model_decoder_layers_4_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] 
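            # The fc1 (5120, 1280) / fc2 (1280, 5120) pairs are each layer's
            # GELU feed-forward block (1280 -> 5120 -> 1280); weights are stored
            # as (out_features, in_features) and transposed via permute_dims
            # inside the fused cuBLAS kernels, e.g.
            # fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas.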
model_decoder_layers_4_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[604]
model_decoder_layers_4_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[605]
model_decoder_layers_4_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[606]
model_decoder_layers_4_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[607]
model_decoder_layers_4_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[608]
model_decoder_layers_5_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[609]
model_decoder_layers_5_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[610]
model_decoder_layers_5_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[611]
model_decoder_layers_5_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[612]
model_decoder_layers_5_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[613]
model_decoder_layers_5_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[614]
model_decoder_layers_5_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[615]
model_decoder_layers_5_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[616]
model_decoder_layers_5_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[617]
model_decoder_layers_5_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[621]
model_decoder_layers_5_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[622]
model_decoder_layers_5_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[623]
model_decoder_layers_5_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[624]
model_decoder_layers_5_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[625]
model_decoder_layers_5_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[626]
model_decoder_layers_5_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[627]
model_decoder_layers_5_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[628]
model_decoder_layers_5_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[629]
model_decoder_layers_5_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[630]
model_decoder_layers_5_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[631]
model_decoder_layers_5_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[632]
model_decoder_layers_6_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[633]
model_decoder_layers_6_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[634]
model_decoder_layers_6_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[635]
model_decoder_layers_6_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[636]
model_decoder_layers_6_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[637]
model_decoder_layers_6_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[638]
model_decoder_layers_6_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[639]
model_decoder_layers_6_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[640]
model_decoder_layers_6_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[641]
model_decoder_layers_6_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[645]
model_decoder_layers_6_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[646]
model_decoder_layers_6_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[647]
model_decoder_layers_6_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[648]
model_decoder_layers_6_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[649]
model_decoder_layers_6_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[650]
model_decoder_layers_6_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[651]
model_decoder_layers_6_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[652]
model_decoder_layers_6_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[653]
model_decoder_layers_6_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[654]
model_decoder_layers_6_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[655]
model_decoder_layers_6_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[656]
model_decoder_layers_7_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[657]
model_decoder_layers_7_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[658]
model_decoder_layers_7_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[659]
model_decoder_layers_7_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[660]
model_decoder_layers_7_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[661]
model_decoder_layers_7_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[662]
model_decoder_layers_7_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[663]
model_decoder_layers_7_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[664]
model_decoder_layers_7_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[665]
model_decoder_layers_7_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[669]
model_decoder_layers_7_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[670]
model_decoder_layers_7_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[671]
model_decoder_layers_7_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[672]
model_decoder_layers_7_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[673]
model_decoder_layers_7_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[674]
model_decoder_layers_7_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[675]
model_decoder_layers_7_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[676]
model_decoder_layers_7_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[677]
model_decoder_layers_7_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[678]
model_decoder_layers_7_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[679]
model_decoder_layers_7_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[680]
model_decoder_layers_8_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[681]
model_decoder_layers_8_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[682]
model_decoder_layers_8_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[683]
model_decoder_layers_8_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[684]
model_decoder_layers_8_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[685]
model_decoder_layers_8_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[686]
model_decoder_layers_8_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[687]
model_decoder_layers_8_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[688]
model_decoder_layers_8_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[689]
model_decoder_layers_8_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[693]
model_decoder_layers_8_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[694]
model_decoder_layers_8_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[695]
model_decoder_layers_8_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[696]
model_decoder_layers_8_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[697]
model_decoder_layers_8_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[698]
model_decoder_layers_8_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[699]
model_decoder_layers_8_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[700]
model_decoder_layers_8_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[701]
model_decoder_layers_8_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[702]
model_decoder_layers_8_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[703]
model_decoder_layers_8_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[704]
model_decoder_layers_9_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[705]
model_decoder_layers_9_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[706]
model_decoder_layers_9_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[707]
model_decoder_layers_9_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[708]
model_decoder_layers_9_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[709]
model_decoder_layers_9_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[710]
model_decoder_layers_9_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[711]
model_decoder_layers_9_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[712]
model_decoder_layers_9_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[713]
model_decoder_layers_9_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[717]
model_decoder_layers_9_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[718]
model_decoder_layers_9_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[719]
model_decoder_layers_9_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[720]
model_decoder_layers_9_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[721]
model_decoder_layers_9_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[722]
model_decoder_layers_9_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[723]
model_decoder_layers_9_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[724]
model_decoder_layers_9_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[725]
model_decoder_layers_9_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[726]
model_decoder_layers_9_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[727]
model_decoder_layers_9_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[728]
model_decoder_layers_10_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[729]
model_decoder_layers_10_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[730]
model_decoder_layers_10_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[731]
model_decoder_layers_10_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[732]
model_decoder_layers_10_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[733]
model_decoder_layers_10_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[734]
model_decoder_layers_10_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[735]
model_decoder_layers_10_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[736]
model_decoder_layers_10_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[737]
model_decoder_layers_10_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[741]
model_decoder_layers_10_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[742]
model_decoder_layers_10_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[743]
model_decoder_layers_10_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[744]
model_decoder_layers_10_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[745]
model_decoder_layers_10_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[746]
model_decoder_layers_10_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[747]
model_decoder_layers_10_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[748]
model_decoder_layers_10_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[749]
model_decoder_layers_10_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[750]
model_decoder_layers_10_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[751]
model_decoder_layers_10_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[752]
model_decoder_layers_11_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[753]
model_decoder_layers_11_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[754]
model_decoder_layers_11_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[755]
model_decoder_layers_11_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[756]
model_decoder_layers_11_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[757]
model_decoder_layers_11_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[758]
model_decoder_layers_11_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[759]
model_decoder_layers_11_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[760]
model_decoder_layers_11_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[761]
model_decoder_layers_11_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[765]
model_decoder_layers_11_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[766]
model_decoder_layers_11_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[767]
model_decoder_layers_11_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[768]
model_decoder_layers_11_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[769]
model_decoder_layers_11_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[770]
model_decoder_layers_11_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[771]
model_decoder_layers_11_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[772]
model_decoder_layers_11_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[773]
model_decoder_layers_11_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[774]
model_decoder_layers_11_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[775]
model_decoder_layers_11_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[776]
model_decoder_layers_12_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[777]
model_decoder_layers_12_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[778]
model_decoder_layers_12_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[779]
model_decoder_layers_12_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[780]
model_decoder_layers_12_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[781]
model_decoder_layers_12_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[782]
model_decoder_layers_12_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[783]
model_decoder_layers_12_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[784]
model_decoder_layers_12_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[785]
model_decoder_layers_12_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[789]
model_decoder_layers_12_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[790]
model_decoder_layers_12_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[791]
model_decoder_layers_12_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[792]
model_decoder_layers_12_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[793]
model_decoder_layers_12_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[794]
model_decoder_layers_12_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[795]
model_decoder_layers_12_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[796]
model_decoder_layers_12_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[797]
model_decoder_layers_12_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[798]
model_decoder_layers_12_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[799]
model_decoder_layers_12_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[800]
model_decoder_layers_13_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[801]
model_decoder_layers_13_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[802]
model_decoder_layers_13_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[803]
model_decoder_layers_13_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[804]
model_decoder_layers_13_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[805]
model_decoder_layers_13_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[806]
model_decoder_layers_13_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[807]
model_decoder_layers_13_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[808]
model_decoder_layers_13_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[809]
model_decoder_layers_13_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[813]
model_decoder_layers_13_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[814]
model_decoder_layers_13_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[815]
model_decoder_layers_13_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[816]
model_decoder_layers_13_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[817]
model_decoder_layers_13_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[818]
model_decoder_layers_13_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[819]
model_decoder_layers_13_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[820]
model_decoder_layers_13_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[821]
model_decoder_layers_13_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[822]
model_decoder_layers_13_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[823]
model_decoder_layers_13_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[824]
model_decoder_layers_14_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[825]
model_decoder_layers_14_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[826]
model_decoder_layers_14_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[827]
model_decoder_layers_14_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[828]
model_decoder_layers_14_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[829]
model_decoder_layers_14_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[830]
model_decoder_layers_14_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[831]
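# Note: all projection weights above are stored as (out_features, in_features),
# so the NT_matmul kernels compute x @ W.T ("NT" = normal activations times
# transposed weight); fc1 (5120, 1280) and fc2 (1280, 5120) give the usual 4x
# MLP expansion around the 1280-wide residual stream. A NumPy sketch of the
# intended semantics (illustrative only, not the scheduled GPU kernel):
#
#     import numpy as np
#     x = np.zeros((1, 1, 1280), np.float16)  # single decode-step activations
#     w = np.zeros((5120, 1280), np.float16)  # e.g. an fc1 weight
#     b = np.zeros((5120,), np.float16)
#     y = x @ w.T + b                         # NT_matmul followed by bias add
#     assert y.shape == (1, 1, 5120)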
model_decoder_layers_14_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[832]
model_decoder_layers_14_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[833]
model_decoder_layers_14_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[837]
model_decoder_layers_14_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[838]
model_decoder_layers_14_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[839]
model_decoder_layers_14_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[840]
model_decoder_layers_14_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[841]
model_decoder_layers_14_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[842]
model_decoder_layers_14_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[843]
model_decoder_layers_14_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[844]
model_decoder_layers_14_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[845]
model_decoder_layers_14_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[846]
model_decoder_layers_14_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[847]
model_decoder_layers_14_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[848]
model_decoder_layers_15_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[849]
model_decoder_layers_15_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[850]
model_decoder_layers_15_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[851]
model_decoder_layers_15_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[852]
model_decoder_layers_15_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[853]
model_decoder_layers_15_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[854]
model_decoder_layers_15_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[855]
model_decoder_layers_15_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[856]
model_decoder_layers_15_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[857]
model_decoder_layers_15_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[861]
model_decoder_layers_15_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[862]
model_decoder_layers_15_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[863]
model_decoder_layers_15_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[864]
model_decoder_layers_15_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[865]
model_decoder_layers_15_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[866]
model_decoder_layers_15_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[867]
model_decoder_layers_15_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[868]
model_decoder_layers_15_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[869]
model_decoder_layers_15_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[870]
model_decoder_layers_15_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[871]
model_decoder_layers_15_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[872]
model_decoder_layers_16_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[873]
model_decoder_layers_16_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[874]
model_decoder_layers_16_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[875]
model_decoder_layers_16_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[876]
model_decoder_layers_16_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[877]
model_decoder_layers_16_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[878]
model_decoder_layers_16_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[879]
model_decoder_layers_16_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[880]
model_decoder_layers_16_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[881]
model_decoder_layers_16_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[885]
model_decoder_layers_16_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[886]
model_decoder_layers_16_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[887]
model_decoder_layers_16_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[888]
model_decoder_layers_16_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[889]
model_decoder_layers_16_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[890]
model_decoder_layers_16_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[891]
model_decoder_layers_16_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[892]
model_decoder_layers_16_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[893]
model_decoder_layers_16_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[894]
model_decoder_layers_16_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[895]
model_decoder_layers_16_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[896]
model_decoder_layers_17_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[897]
model_decoder_layers_17_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[898]
model_decoder_layers_17_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[899]
model_decoder_layers_17_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[900]
model_decoder_layers_17_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[901]
model_decoder_layers_17_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[902]
model_decoder_layers_17_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[903]
model_decoder_layers_17_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[904]
model_decoder_layers_17_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[905]
model_decoder_layers_17_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[909]
model_decoder_layers_17_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[910]
model_decoder_layers_17_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[911]
model_decoder_layers_17_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[912]
model_decoder_layers_17_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[913]
model_decoder_layers_17_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[914]
model_decoder_layers_17_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[915]
model_decoder_layers_17_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[916]
model_decoder_layers_17_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[917]
model_decoder_layers_17_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[918]
model_decoder_layers_17_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[919]
model_decoder_layers_17_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[920]
model_decoder_layers_18_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[921]
model_decoder_layers_18_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[922]
model_decoder_layers_18_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[923]
model_decoder_layers_18_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[924]
model_decoder_layers_18_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[925]
model_decoder_layers_18_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[926]
model_decoder_layers_18_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[927]
model_decoder_layers_18_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[928]
model_decoder_layers_18_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[929]
model_decoder_layers_18_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[933]
model_decoder_layers_18_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[934]
model_decoder_layers_18_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[935]
model_decoder_layers_18_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[936]
model_decoder_layers_18_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[937]
model_decoder_layers_18_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[938]
model_decoder_layers_18_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[939]
model_decoder_layers_18_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[940]
model_decoder_layers_18_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[941]
model_decoder_layers_18_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[942]
model_decoder_layers_18_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[943]
model_decoder_layers_18_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[944]
model_decoder_layers_19_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[945]
model_decoder_layers_19_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[946]
model_decoder_layers_19_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[947]
model_decoder_layers_19_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[948]
model_decoder_layers_19_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[949]
model_decoder_layers_19_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[950]
model_decoder_layers_19_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[951]
model_decoder_layers_19_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[952]
model_decoder_layers_19_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[953]
model_decoder_layers_19_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[957]
model_decoder_layers_19_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[958]
model_decoder_layers_19_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[959]
model_decoder_layers_19_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[960]
model_decoder_layers_19_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[961]
model_decoder_layers_19_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[962]
model_decoder_layers_19_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[963]
model_decoder_layers_19_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[964]
model_decoder_layers_19_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[965]
model_decoder_layers_19_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[966]
model_decoder_layers_19_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[967]
model_decoder_layers_19_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[968]
model_decoder_layers_20_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[969]
model_decoder_layers_20_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[970]
model_decoder_layers_20_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[971]
model_decoder_layers_20_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[972]
model_decoder_layers_20_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[973]
model_decoder_layers_20_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[974]
model_decoder_layers_20_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[975]
model_decoder_layers_20_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[976]
model_decoder_layers_20_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[977]
model_decoder_layers_20_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[981]
model_decoder_layers_20_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[982]
model_decoder_layers_20_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[983]
model_decoder_layers_20_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[984]
model_decoder_layers_20_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[985]
model_decoder_layers_20_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[986]
model_decoder_layers_20_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[987]
model_decoder_layers_20_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[988]
model_decoder_layers_20_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[989]
model_decoder_layers_20_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[990]
model_decoder_layers_20_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[991]
model_decoder_layers_20_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[992]
model_decoder_layers_21_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[993]
model_decoder_layers_21_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[994]
model_decoder_layers_21_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[995]
model_decoder_layers_21_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[996]
model_decoder_layers_21_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[997]
model_decoder_layers_21_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[998]
model_decoder_layers_21_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[999]
model_decoder_layers_21_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1000]
model_decoder_layers_21_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1001]
model_decoder_layers_21_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005]
model_decoder_layers_21_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1006]
model_decoder_layers_21_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007]
model_decoder_layers_21_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1008]
model_decoder_layers_21_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1009]
model_decoder_layers_21_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1010]
model_decoder_layers_21_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011]
model_decoder_layers_21_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1012]
model_decoder_layers_21_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013]
model_decoder_layers_21_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1014]
model_decoder_layers_21_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1015]
model_decoder_layers_21_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1016]
model_decoder_layers_22_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017]
model_decoder_layers_22_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018]
model_decoder_layers_22_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1019]
model_decoder_layers_22_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020]
model_decoder_layers_22_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1021]
model_decoder_layers_22_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022]
model_decoder_layers_22_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1023]
model_decoder_layers_22_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1024]
model_decoder_layers_22_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1025]
model_decoder_layers_22_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029]
model_decoder_layers_22_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1030]
model_decoder_layers_22_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031]
model_decoder_layers_22_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1032]
model_decoder_layers_22_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1033]
model_decoder_layers_22_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1034]
model_decoder_layers_22_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035]
model_decoder_layers_22_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1036]
model_decoder_layers_22_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037]
model_decoder_layers_22_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1038]
model_decoder_layers_22_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1039]
model_decoder_layers_22_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1040]
model_decoder_layers_23_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041]
model_decoder_layers_23_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042]
model_decoder_layers_23_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1043]
model_decoder_layers_23_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044]
model_decoder_layers_23_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1045]
model_decoder_layers_23_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046]
model_decoder_layers_23_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1047]
model_decoder_layers_23_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1048]
model_decoder_layers_23_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1049]
model_decoder_layers_23_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053]
model_decoder_layers_23_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1054]
model_decoder_layers_23_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055]
model_decoder_layers_23_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1056]
model_decoder_layers_23_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1057]
model_decoder_layers_23_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1058]
model_decoder_layers_23_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059]
model_decoder_layers_23_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1060]
model_decoder_layers_23_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061]
model_decoder_layers_23_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1062]
model_decoder_layers_23_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1063]
model_decoder_layers_23_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1064]
model_decoder_layers_24_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065]
model_decoder_layers_24_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066]
model_decoder_layers_24_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1067]
model_decoder_layers_24_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068]
model_decoder_layers_24_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1069]
model_decoder_layers_24_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070]
model_decoder_layers_24_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1071]
model_decoder_layers_24_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1072]
model_decoder_layers_24_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1073]
model_decoder_layers_24_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077]
model_decoder_layers_24_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1078]
model_decoder_layers_24_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079]
model_decoder_layers_24_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1080]
model_decoder_layers_24_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1081]
model_decoder_layers_24_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1082]
model_decoder_layers_24_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083]
model_decoder_layers_24_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1084]
model_decoder_layers_24_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085]
model_decoder_layers_24_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1086]
model_decoder_layers_24_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1087]
model_decoder_layers_24_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1088]
model_decoder_layers_25_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089]
model_decoder_layers_25_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090]
model_decoder_layers_25_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1091]
model_decoder_layers_25_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092]
model_decoder_layers_25_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1093]
model_decoder_layers_25_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094]
model_decoder_layers_25_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1095]
model_decoder_layers_25_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1096]
model_decoder_layers_25_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1097]
model_decoder_layers_25_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101]
model_decoder_layers_25_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1102]
model_decoder_layers_25_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103]
model_decoder_layers_25_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1104]
model_decoder_layers_25_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1105]
model_decoder_layers_25_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1106]
model_decoder_layers_25_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107]
model_decoder_layers_25_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1108]
model_decoder_layers_25_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109]
model_decoder_layers_25_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1110]
model_decoder_layers_25_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1111]
model_decoder_layers_25_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1112]
model_decoder_layers_26_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113]
model_decoder_layers_26_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114]
model_decoder_layers_26_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1115]
model_decoder_layers_26_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116]
model_decoder_layers_26_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1117]
model_decoder_layers_26_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118]
model_decoder_layers_26_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1119]
model_decoder_layers_26_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1120]
model_decoder_layers_26_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1121]
model_decoder_layers_26_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125]
model_decoder_layers_26_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1126]
model_decoder_layers_26_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127]
model_decoder_layers_26_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1128]
model_decoder_layers_26_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1129]
model_decoder_layers_26_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1130]
model_decoder_layers_26_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131]
model_decoder_layers_26_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1132]
model_decoder_layers_26_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133]
model_decoder_layers_26_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1134]
model_decoder_layers_26_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1135]
model_decoder_layers_26_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1136]
model_decoder_layers_27_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137]
model_decoder_layers_27_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138]
model_decoder_layers_27_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1139]
model_decoder_layers_27_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140]
model_decoder_layers_27_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1141]
model_decoder_layers_27_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142]
model_decoder_layers_27_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1143]
model_decoder_layers_27_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1144]
model_decoder_layers_27_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1145]
model_decoder_layers_27_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149]
model_decoder_layers_27_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1150]
model_decoder_layers_27_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151]
model_decoder_layers_27_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1152]
model_decoder_layers_27_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1153]
model_decoder_layers_27_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1154]
model_decoder_layers_27_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155]
model_decoder_layers_27_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1156]
model_decoder_layers_27_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157]
model_decoder_layers_27_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1158]
model_decoder_layers_27_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1159]
model_decoder_layers_27_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1160]
model_decoder_layers_28_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161]
model_decoder_layers_28_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162]
model_decoder_layers_28_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1163]
model_decoder_layers_28_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164]
model_decoder_layers_28_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1165]
model_decoder_layers_28_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166]
model_decoder_layers_28_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1167]
model_decoder_layers_28_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1168]
model_decoder_layers_28_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1169]
model_decoder_layers_28_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173]
model_decoder_layers_28_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1174]
model_decoder_layers_28_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175]
model_decoder_layers_28_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1176]
model_decoder_layers_28_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1177]
model_decoder_layers_28_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1178]
model_decoder_layers_28_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179]
model_decoder_layers_28_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1180]
model_decoder_layers_28_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181]
model_decoder_layers_28_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1182]
model_decoder_layers_28_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1183]
model_decoder_layers_28_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1184]
model_decoder_layers_29_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185]
model_decoder_layers_29_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186]
model_decoder_layers_29_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1187]
model_decoder_layers_29_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188]
model_decoder_layers_29_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1189]
model_decoder_layers_29_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190]
model_decoder_layers_29_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1191]
model_decoder_layers_29_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1192]
model_decoder_layers_29_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1193]
model_decoder_layers_29_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197]
model_decoder_layers_29_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1198]
model_decoder_layers_29_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199]
model_decoder_layers_29_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1200]
model_decoder_layers_29_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1201]
model_decoder_layers_29_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1202]
model_decoder_layers_29_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203]
model_decoder_layers_29_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1204]
model_decoder_layers_29_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205]
model_decoder_layers_29_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1206]
model_decoder_layers_29_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1207]
model_decoder_layers_29_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1208]
model_decoder_layers_30_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209]
model_decoder_layers_30_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210]
model_decoder_layers_30_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1211]
model_decoder_layers_30_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212]
model_decoder_layers_30_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1213]
model_decoder_layers_30_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214]
model_decoder_layers_30_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1215]
model_decoder_layers_30_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1216]
model_decoder_layers_30_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1217]
model_decoder_layers_30_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221]
model_decoder_layers_30_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1222]
model_decoder_layers_30_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223]
model_decoder_layers_30_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1224]
model_decoder_layers_30_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1225]
model_decoder_layers_30_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1226]
model_decoder_layers_30_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227]
model_decoder_layers_30_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1228]
model_decoder_layers_30_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229]
model_decoder_layers_30_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1230]
model_decoder_layers_30_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1231]
model_decoder_layers_30_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1232]
model_decoder_layers_31_self_attn_k_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233]
model_decoder_layers_31_self_attn_v_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234]
model_decoder_layers_31_self_attn_v_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1235]
model_decoder_layers_31_self_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236]
model_decoder_layers_31_self_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1237]
model_decoder_layers_31_self_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238]
model_decoder_layers_31_self_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1239]
model_decoder_layers_31_self_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1240]
model_decoder_layers_31_self_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1241]
model_decoder_layers_31_encoder_attn_q_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245]
model_decoder_layers_31_encoder_attn_q_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1246]
model_decoder_layers_31_encoder_attn_out_proj_weight5: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247]
model_decoder_layers_31_encoder_attn_out_proj_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1248]
model_decoder_layers_31_encoder_attn_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1249]
dtype="float16") = packed_params[1249] model_decoder_layers_31_encoder_attn_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1250] model_decoder_layers_31_fc1_weight5: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251] model_decoder_layers_31_fc1_bias5: R.Tensor((5120,), dtype="float16") = packed_params[1252] model_decoder_layers_31_fc2_weight5: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253] model_decoder_layers_31_fc2_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1254] model_decoder_layers_31_final_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1255] model_decoder_layers_31_final_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1256] model_decoder_layer_norm_weight5: R.Tensor((1280,), dtype="float16") = packed_params[1257] model_decoder_layer_norm_bias5: R.Tensor((1280,), dtype="float16") = packed_params[1258] reshape1353 = R.call_tir(cls.reshape19, (input_ids,), out_sinfo=R.Tensor((1,), dtype="int32")) take7 = R.call_tir(cls.take3, (model_decoder_embed_tokens_weight5, reshape1353), out_sinfo=R.Tensor((1, 1280), dtype="float16")) lv264: R.Tensor((1,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((1,), dtype="int32"),)) take8 = R.call_tir(cls.take4, (model_decoder_embed_positions_weight5, lv264), out_sinfo=R.Tensor((1, 1280), dtype="float16")) lv40 = R.call_tir(cls.fused_reshape20_reshape20_add6, (take7, take8), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm356 = R.call_tir(cls.layer_norm3, (lv40, model_decoder_layers_0_self_attn_layer_norm_weight5, model_decoder_layers_0_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv41 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm356, model_decoder_layers_0_self_attn_q_proj_weight5, model_decoder_layers_0_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv1 = R.call_tir(cls.NT_matmul, (layer_norm356, model_decoder_layers_0_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv42 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm356, model_decoder_layers_0_self_attn_v_proj_weight5, model_decoder_layers_0_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv43 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv41, lv1, lv42), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv265 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), lv43), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv44 = R.call_tir(cls.fused_reshape23_reshape24, (lv265,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv45 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv44, model_decoder_layers_0_self_attn_out_proj_weight5, model_decoder_layers_0_self_attn_out_proj_bias5, lv40), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm357 = R.call_tir(cls.layer_norm3, (lv45, model_decoder_layers_0_encoder_attn_layer_norm_weight5, model_decoder_layers_0_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv46 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm357, model_decoder_layers_0_encoder_attn_q_proj_weight5, model_decoder_layers_0_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv47 = R.call_tir(cls.fused_reshape21_reshape25, (lv46,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) 
lv266 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), lv47), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv48 = R.call_tir(cls.fused_reshape23_reshape24, (lv266,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv49 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv48, model_decoder_layers_0_encoder_attn_out_proj_weight5, model_decoder_layers_0_encoder_attn_out_proj_bias5, lv45), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm358 = R.call_tir(cls.layer_norm3, (lv49, model_decoder_layers_0_final_layer_norm_weight5, model_decoder_layers_0_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv50 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm358, model_decoder_layers_0_fc1_weight5, model_decoder_layers_0_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv51 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv50, model_decoder_layers_0_fc2_weight5, model_decoder_layers_0_fc2_bias5, lv49), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm359 = R.call_tir(cls.layer_norm3, (lv51, model_decoder_layers_1_self_attn_layer_norm_weight5, model_decoder_layers_1_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv52 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm359, model_decoder_layers_1_self_attn_q_proj_weight5, model_decoder_layers_1_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv9 = R.call_tir(cls.NT_matmul, (layer_norm359, model_decoder_layers_1_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv53 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm359, model_decoder_layers_1_self_attn_v_proj_weight5, model_decoder_layers_1_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv54 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv52, lv9, lv53), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv267 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), lv54), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv55 = R.call_tir(cls.fused_reshape23_reshape24, (lv267,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv56 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv55, model_decoder_layers_1_self_attn_out_proj_weight5, model_decoder_layers_1_self_attn_out_proj_bias5, lv51), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm360 = R.call_tir(cls.layer_norm3, (lv56, model_decoder_layers_1_encoder_attn_layer_norm_weight5, model_decoder_layers_1_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv57 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm360, model_decoder_layers_1_encoder_attn_q_proj_weight5, model_decoder_layers_1_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv58 = R.call_tir(cls.fused_reshape21_reshape25, (lv57,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv268 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), lv58), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv59 = R.call_tir(cls.fused_reshape23_reshape24, (lv268,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv60 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv59, model_decoder_layers_1_encoder_attn_out_proj_weight5, model_decoder_layers_1_encoder_attn_out_proj_bias5, lv56), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm361 = R.call_tir(cls.layer_norm3, (lv60, model_decoder_layers_1_final_layer_norm_weight5, model_decoder_layers_1_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv61 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm361, model_decoder_layers_1_fc1_weight5, model_decoder_layers_1_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv62 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv61, model_decoder_layers_1_fc2_weight5, model_decoder_layers_1_fc2_bias5, lv60), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm362 = R.call_tir(cls.layer_norm3, (lv62, model_decoder_layers_2_self_attn_layer_norm_weight5, model_decoder_layers_2_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv63 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm362, model_decoder_layers_2_self_attn_q_proj_weight5, model_decoder_layers_2_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv17 = R.call_tir(cls.NT_matmul, (layer_norm362, model_decoder_layers_2_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv64 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm362, model_decoder_layers_2_self_attn_v_proj_weight5, model_decoder_layers_2_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv65 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv63, lv17, lv64), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv269 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), lv65), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv66 = R.call_tir(cls.fused_reshape23_reshape24, (lv269,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv67 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv66, model_decoder_layers_2_self_attn_out_proj_weight5, model_decoder_layers_2_self_attn_out_proj_bias5, lv62), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm363 = R.call_tir(cls.layer_norm3, (lv67, model_decoder_layers_2_encoder_attn_layer_norm_weight5, model_decoder_layers_2_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv68 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm363, model_decoder_layers_2_encoder_attn_q_proj_weight5, model_decoder_layers_2_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv69 = R.call_tir(cls.fused_reshape21_reshape25, (lv68,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv270 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), lv69), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv70 = R.call_tir(cls.fused_reshape23_reshape24, (lv270,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv71 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv70, model_decoder_layers_2_encoder_attn_out_proj_weight5, model_decoder_layers_2_encoder_attn_out_proj_bias5, lv67), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm364 = R.call_tir(cls.layer_norm3, (lv71, model_decoder_layers_2_final_layer_norm_weight5, model_decoder_layers_2_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv72 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm364, model_decoder_layers_2_fc1_weight5, model_decoder_layers_2_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv73 = 
R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv72, model_decoder_layers_2_fc2_weight5, model_decoder_layers_2_fc2_bias5, lv71), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm365 = R.call_tir(cls.layer_norm3, (lv73, model_decoder_layers_3_self_attn_layer_norm_weight5, model_decoder_layers_3_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv74 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm365, model_decoder_layers_3_self_attn_q_proj_weight5, model_decoder_layers_3_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv25 = R.call_tir(cls.NT_matmul, (layer_norm365, model_decoder_layers_3_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv75 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm365, model_decoder_layers_3_self_attn_v_proj_weight5, model_decoder_layers_3_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv76 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv74, lv25, lv75), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv271 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), lv76), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv77 = R.call_tir(cls.fused_reshape23_reshape24, (lv271,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv78 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv77, model_decoder_layers_3_self_attn_out_proj_weight5, model_decoder_layers_3_self_attn_out_proj_bias5, lv73), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm366 = R.call_tir(cls.layer_norm3, (lv78, model_decoder_layers_3_encoder_attn_layer_norm_weight5, model_decoder_layers_3_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv79 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm366, model_decoder_layers_3_encoder_attn_q_proj_weight5, model_decoder_layers_3_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv80 = R.call_tir(cls.fused_reshape21_reshape25, (lv79,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv272 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), lv80), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv81 = R.call_tir(cls.fused_reshape23_reshape24, (lv272,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv82 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv81, model_decoder_layers_3_encoder_attn_out_proj_weight5, model_decoder_layers_3_encoder_attn_out_proj_bias5, lv78), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm367 = R.call_tir(cls.layer_norm3, (lv82, model_decoder_layers_3_final_layer_norm_weight5, model_decoder_layers_3_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv83 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm367, model_decoder_layers_3_fc1_weight5, model_decoder_layers_3_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv84 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv83, model_decoder_layers_3_fc2_weight5, model_decoder_layers_3_fc2_bias5, lv82), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm368 = R.call_tir(cls.layer_norm3, (lv84, model_decoder_layers_4_self_attn_layer_norm_weight5, model_decoder_layers_4_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv85 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm368, 
model_decoder_layers_4_self_attn_q_proj_weight5, model_decoder_layers_4_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv33 = R.call_tir(cls.NT_matmul, (layer_norm368, model_decoder_layers_4_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv86 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm368, model_decoder_layers_4_self_attn_v_proj_weight5, model_decoder_layers_4_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv87 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv85, lv33, lv86), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv273 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), lv87), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv88 = R.call_tir(cls.fused_reshape23_reshape24, (lv273,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv89 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv88, model_decoder_layers_4_self_attn_out_proj_weight5, model_decoder_layers_4_self_attn_out_proj_bias5, lv84), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm369 = R.call_tir(cls.layer_norm3, (lv89, model_decoder_layers_4_encoder_attn_layer_norm_weight5, model_decoder_layers_4_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv90 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm369, model_decoder_layers_4_encoder_attn_q_proj_weight5, model_decoder_layers_4_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv91 = R.call_tir(cls.fused_reshape21_reshape25, (lv90,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv274 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), lv91), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv92 = R.call_tir(cls.fused_reshape23_reshape24, (lv274,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv93 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv92, model_decoder_layers_4_encoder_attn_out_proj_weight5, model_decoder_layers_4_encoder_attn_out_proj_bias5, lv89), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm370 = R.call_tir(cls.layer_norm3, (lv93, model_decoder_layers_4_final_layer_norm_weight5, model_decoder_layers_4_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv94 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm370, model_decoder_layers_4_fc1_weight5, model_decoder_layers_4_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv95 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv94, model_decoder_layers_4_fc2_weight5, model_decoder_layers_4_fc2_bias5, lv93), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm371 = R.call_tir(cls.layer_norm3, (lv95, model_decoder_layers_5_self_attn_layer_norm_weight5, model_decoder_layers_5_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv96 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm371, model_decoder_layers_5_self_attn_q_proj_weight5, model_decoder_layers_5_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv41_1 = R.call_tir(cls.NT_matmul, (layer_norm371, model_decoder_layers_5_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv97 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm371, model_decoder_layers_5_self_attn_v_proj_weight5, model_decoder_layers_5_self_attn_v_proj_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv98 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv96, lv41_1, lv97), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv275 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), lv98), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv99 = R.call_tir(cls.fused_reshape23_reshape24, (lv275,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv100 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv99, model_decoder_layers_5_self_attn_out_proj_weight5, model_decoder_layers_5_self_attn_out_proj_bias5, lv95), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm372 = R.call_tir(cls.layer_norm3, (lv100, model_decoder_layers_5_encoder_attn_layer_norm_weight5, model_decoder_layers_5_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv101 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm372, model_decoder_layers_5_encoder_attn_q_proj_weight5, model_decoder_layers_5_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv102 = R.call_tir(cls.fused_reshape21_reshape25, (lv101,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv276 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), lv102), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv103 = R.call_tir(cls.fused_reshape23_reshape24, (lv276,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv104 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv103, model_decoder_layers_5_encoder_attn_out_proj_weight5, model_decoder_layers_5_encoder_attn_out_proj_bias5, lv100), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm373 = R.call_tir(cls.layer_norm3, (lv104, model_decoder_layers_5_final_layer_norm_weight5, model_decoder_layers_5_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv105 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm373, model_decoder_layers_5_fc1_weight5, model_decoder_layers_5_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv106 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv105, model_decoder_layers_5_fc2_weight5, model_decoder_layers_5_fc2_bias5, lv104), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm374 = R.call_tir(cls.layer_norm3, (lv106, model_decoder_layers_6_self_attn_layer_norm_weight5, model_decoder_layers_6_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv107 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm374, model_decoder_layers_6_self_attn_q_proj_weight5, model_decoder_layers_6_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv49_1 = R.call_tir(cls.NT_matmul, (layer_norm374, model_decoder_layers_6_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv108 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm374, model_decoder_layers_6_self_attn_v_proj_weight5, model_decoder_layers_6_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv109 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv107, lv49_1, lv108), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv277 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), lv109), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv110 = 
R.call_tir(cls.fused_reshape23_reshape24, (lv277,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv111 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv110, model_decoder_layers_6_self_attn_out_proj_weight5, model_decoder_layers_6_self_attn_out_proj_bias5, lv106), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm375 = R.call_tir(cls.layer_norm3, (lv111, model_decoder_layers_6_encoder_attn_layer_norm_weight5, model_decoder_layers_6_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv112 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm375, model_decoder_layers_6_encoder_attn_q_proj_weight5, model_decoder_layers_6_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv113 = R.call_tir(cls.fused_reshape21_reshape25, (lv112,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv278 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), lv113), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv114 = R.call_tir(cls.fused_reshape23_reshape24, (lv278,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv115 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv114, model_decoder_layers_6_encoder_attn_out_proj_weight5, model_decoder_layers_6_encoder_attn_out_proj_bias5, lv111), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm376 = R.call_tir(cls.layer_norm3, (lv115, model_decoder_layers_6_final_layer_norm_weight5, model_decoder_layers_6_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv116 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm376, model_decoder_layers_6_fc1_weight5, model_decoder_layers_6_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv117 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv116, model_decoder_layers_6_fc2_weight5, model_decoder_layers_6_fc2_bias5, lv115), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm377 = R.call_tir(cls.layer_norm3, (lv117, model_decoder_layers_7_self_attn_layer_norm_weight5, model_decoder_layers_7_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv118 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm377, model_decoder_layers_7_self_attn_q_proj_weight5, model_decoder_layers_7_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv57_1 = R.call_tir(cls.NT_matmul, (layer_norm377, model_decoder_layers_7_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv119 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm377, model_decoder_layers_7_self_attn_v_proj_weight5, model_decoder_layers_7_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv120 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv118, lv57_1, lv119), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv279 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), lv120), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv121 = R.call_tir(cls.fused_reshape23_reshape24, (lv279,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv122 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv121, model_decoder_layers_7_self_attn_out_proj_weight5, model_decoder_layers_7_self_attn_out_proj_bias5, lv117), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm378 = R.call_tir(cls.layer_norm3, (lv122, model_decoder_layers_7_encoder_attn_layer_norm_weight5, 
model_decoder_layers_7_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv123 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm378, model_decoder_layers_7_encoder_attn_q_proj_weight5, model_decoder_layers_7_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv124 = R.call_tir(cls.fused_reshape21_reshape25, (lv123,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv280 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), lv124), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv125 = R.call_tir(cls.fused_reshape23_reshape24, (lv280,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv126 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv125, model_decoder_layers_7_encoder_attn_out_proj_weight5, model_decoder_layers_7_encoder_attn_out_proj_bias5, lv122), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm379 = R.call_tir(cls.layer_norm3, (lv126, model_decoder_layers_7_final_layer_norm_weight5, model_decoder_layers_7_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv127 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm379, model_decoder_layers_7_fc1_weight5, model_decoder_layers_7_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv128 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv127, model_decoder_layers_7_fc2_weight5, model_decoder_layers_7_fc2_bias5, lv126), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm380 = R.call_tir(cls.layer_norm3, (lv128, model_decoder_layers_8_self_attn_layer_norm_weight5, model_decoder_layers_8_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv129 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm380, model_decoder_layers_8_self_attn_q_proj_weight5, model_decoder_layers_8_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv65_1 = R.call_tir(cls.NT_matmul, (layer_norm380, model_decoder_layers_8_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv130 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm380, model_decoder_layers_8_self_attn_v_proj_weight5, model_decoder_layers_8_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv131 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv129, lv65_1, lv130), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv281 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), lv131), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv132 = R.call_tir(cls.fused_reshape23_reshape24, (lv281,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv133 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv132, model_decoder_layers_8_self_attn_out_proj_weight5, model_decoder_layers_8_self_attn_out_proj_bias5, lv128), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm381 = R.call_tir(cls.layer_norm3, (lv133, model_decoder_layers_8_encoder_attn_layer_norm_weight5, model_decoder_layers_8_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv134 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm381, model_decoder_layers_8_encoder_attn_q_proj_weight5, model_decoder_layers_8_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv135 = R.call_tir(cls.fused_reshape21_reshape25, (lv134,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv282 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), lv135), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv136 = R.call_tir(cls.fused_reshape23_reshape24, (lv282,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv137 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv136, model_decoder_layers_8_encoder_attn_out_proj_weight5, model_decoder_layers_8_encoder_attn_out_proj_bias5, lv133), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm382 = R.call_tir(cls.layer_norm3, (lv137, model_decoder_layers_8_final_layer_norm_weight5, model_decoder_layers_8_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv138 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm382, model_decoder_layers_8_fc1_weight5, model_decoder_layers_8_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv139 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv138, model_decoder_layers_8_fc2_weight5, model_decoder_layers_8_fc2_bias5, lv137), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm383 = R.call_tir(cls.layer_norm3, (lv139, model_decoder_layers_9_self_attn_layer_norm_weight5, model_decoder_layers_9_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv140 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm383, model_decoder_layers_9_self_attn_q_proj_weight5, model_decoder_layers_9_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv73_1 = R.call_tir(cls.NT_matmul, (layer_norm383, model_decoder_layers_9_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv141 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm383, model_decoder_layers_9_self_attn_v_proj_weight5, model_decoder_layers_9_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv142 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv140, lv73_1, lv141), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv283 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), lv142), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv143 = R.call_tir(cls.fused_reshape23_reshape24, (lv283,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv144 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv143, model_decoder_layers_9_self_attn_out_proj_weight5, model_decoder_layers_9_self_attn_out_proj_bias5, lv139), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm384 = R.call_tir(cls.layer_norm3, (lv144, model_decoder_layers_9_encoder_attn_layer_norm_weight5, model_decoder_layers_9_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv145 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm384, model_decoder_layers_9_encoder_attn_q_proj_weight5, model_decoder_layers_9_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv146 = R.call_tir(cls.fused_reshape21_reshape25, (lv145,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), lv146), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv147 = R.call_tir(cls.fused_reshape23_reshape24, (lv284,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv148 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv147, model_decoder_layers_9_encoder_attn_out_proj_weight5, 
model_decoder_layers_9_encoder_attn_out_proj_bias5, lv144), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm385 = R.call_tir(cls.layer_norm3, (lv148, model_decoder_layers_9_final_layer_norm_weight5, model_decoder_layers_9_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv149 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm385, model_decoder_layers_9_fc1_weight5, model_decoder_layers_9_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv150 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv149, model_decoder_layers_9_fc2_weight5, model_decoder_layers_9_fc2_bias5, lv148), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm386 = R.call_tir(cls.layer_norm3, (lv150, model_decoder_layers_10_self_attn_layer_norm_weight5, model_decoder_layers_10_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv151 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm386, model_decoder_layers_10_self_attn_q_proj_weight5, model_decoder_layers_10_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv81_1 = R.call_tir(cls.NT_matmul, (layer_norm386, model_decoder_layers_10_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv152 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm386, model_decoder_layers_10_self_attn_v_proj_weight5, model_decoder_layers_10_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv153 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv151, lv81_1, lv152), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv285 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), lv153), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv154 = R.call_tir(cls.fused_reshape23_reshape24, (lv285,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv155 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv154, model_decoder_layers_10_self_attn_out_proj_weight5, model_decoder_layers_10_self_attn_out_proj_bias5, lv150), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm387 = R.call_tir(cls.layer_norm3, (lv155, model_decoder_layers_10_encoder_attn_layer_norm_weight5, model_decoder_layers_10_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv156 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm387, model_decoder_layers_10_encoder_attn_q_proj_weight5, model_decoder_layers_10_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv157 = R.call_tir(cls.fused_reshape21_reshape25, (lv156,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv286 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), lv157), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv158 = R.call_tir(cls.fused_reshape23_reshape24, (lv286,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv159 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv158, model_decoder_layers_10_encoder_attn_out_proj_weight5, model_decoder_layers_10_encoder_attn_out_proj_bias5, lv155), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm388 = R.call_tir(cls.layer_norm3, (lv159, model_decoder_layers_10_final_layer_norm_weight5, model_decoder_layers_10_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv160 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm388, model_decoder_layers_10_fc1_weight5, 
model_decoder_layers_10_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv161 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv160, model_decoder_layers_10_fc2_weight5, model_decoder_layers_10_fc2_bias5, lv159), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm389 = R.call_tir(cls.layer_norm3, (lv161, model_decoder_layers_11_self_attn_layer_norm_weight5, model_decoder_layers_11_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv162 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm389, model_decoder_layers_11_self_attn_q_proj_weight5, model_decoder_layers_11_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv89_1 = R.call_tir(cls.NT_matmul, (layer_norm389, model_decoder_layers_11_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv163 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm389, model_decoder_layers_11_self_attn_v_proj_weight5, model_decoder_layers_11_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv164 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv162, lv89_1, lv163), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv287 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), lv164), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv165 = R.call_tir(cls.fused_reshape23_reshape24, (lv287,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv166 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv165, model_decoder_layers_11_self_attn_out_proj_weight5, model_decoder_layers_11_self_attn_out_proj_bias5, lv161), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm390 = R.call_tir(cls.layer_norm3, (lv166, model_decoder_layers_11_encoder_attn_layer_norm_weight5, model_decoder_layers_11_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv167 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm390, model_decoder_layers_11_encoder_attn_q_proj_weight5, model_decoder_layers_11_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv168 = R.call_tir(cls.fused_reshape21_reshape25, (lv167,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv288 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), lv168), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv169 = R.call_tir(cls.fused_reshape23_reshape24, (lv288,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv170 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv169, model_decoder_layers_11_encoder_attn_out_proj_weight5, model_decoder_layers_11_encoder_attn_out_proj_bias5, lv166), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm391 = R.call_tir(cls.layer_norm3, (lv170, model_decoder_layers_11_final_layer_norm_weight5, model_decoder_layers_11_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv171 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm391, model_decoder_layers_11_fc1_weight5, model_decoder_layers_11_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv172 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv171, model_decoder_layers_11_fc2_weight5, model_decoder_layers_11_fc2_bias5, lv170), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm392 = R.call_tir(cls.layer_norm3, (lv172, model_decoder_layers_12_self_attn_layer_norm_weight5, 
model_decoder_layers_12_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv173 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm392, model_decoder_layers_12_self_attn_q_proj_weight5, model_decoder_layers_12_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv97_1 = R.call_tir(cls.NT_matmul, (layer_norm392, model_decoder_layers_12_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv174 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm392, model_decoder_layers_12_self_attn_v_proj_weight5, model_decoder_layers_12_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv175 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv173, lv97_1, lv174), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv289 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), lv175), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv176 = R.call_tir(cls.fused_reshape23_reshape24, (lv289,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv177 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv176, model_decoder_layers_12_self_attn_out_proj_weight5, model_decoder_layers_12_self_attn_out_proj_bias5, lv172), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm393 = R.call_tir(cls.layer_norm3, (lv177, model_decoder_layers_12_encoder_attn_layer_norm_weight5, model_decoder_layers_12_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv178 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm393, model_decoder_layers_12_encoder_attn_q_proj_weight5, model_decoder_layers_12_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv179 = R.call_tir(cls.fused_reshape21_reshape25, (lv178,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv290 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), lv179), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv180 = R.call_tir(cls.fused_reshape23_reshape24, (lv290,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv181 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv180, model_decoder_layers_12_encoder_attn_out_proj_weight5, model_decoder_layers_12_encoder_attn_out_proj_bias5, lv177), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm394 = R.call_tir(cls.layer_norm3, (lv181, model_decoder_layers_12_final_layer_norm_weight5, model_decoder_layers_12_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv182 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm394, model_decoder_layers_12_fc1_weight5, model_decoder_layers_12_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv183 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv182, model_decoder_layers_12_fc2_weight5, model_decoder_layers_12_fc2_bias5, lv181), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm395 = R.call_tir(cls.layer_norm3, (lv183, model_decoder_layers_13_self_attn_layer_norm_weight5, model_decoder_layers_13_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv184 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm395, model_decoder_layers_13_self_attn_q_proj_weight5, model_decoder_layers_13_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv105_1 = R.call_tir(cls.NT_matmul, (layer_norm395, model_decoder_layers_13_self_attn_k_proj_weight5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv185 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm395, model_decoder_layers_13_self_attn_v_proj_weight5, model_decoder_layers_13_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv186 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv184, lv105_1, lv185), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv291 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), lv186), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv187 = R.call_tir(cls.fused_reshape23_reshape24, (lv291,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv188 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv187, model_decoder_layers_13_self_attn_out_proj_weight5, model_decoder_layers_13_self_attn_out_proj_bias5, lv183), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm396 = R.call_tir(cls.layer_norm3, (lv188, model_decoder_layers_13_encoder_attn_layer_norm_weight5, model_decoder_layers_13_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv189 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm396, model_decoder_layers_13_encoder_attn_q_proj_weight5, model_decoder_layers_13_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv190 = R.call_tir(cls.fused_reshape21_reshape25, (lv189,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv292 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), lv190), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv191 = R.call_tir(cls.fused_reshape23_reshape24, (lv292,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv192 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv191, model_decoder_layers_13_encoder_attn_out_proj_weight5, model_decoder_layers_13_encoder_attn_out_proj_bias5, lv188), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm397 = R.call_tir(cls.layer_norm3, (lv192, model_decoder_layers_13_final_layer_norm_weight5, model_decoder_layers_13_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv193 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm397, model_decoder_layers_13_fc1_weight5, model_decoder_layers_13_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv194 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv193, model_decoder_layers_13_fc2_weight5, model_decoder_layers_13_fc2_bias5, lv192), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm398 = R.call_tir(cls.layer_norm3, (lv194, model_decoder_layers_14_self_attn_layer_norm_weight5, model_decoder_layers_14_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv195 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm398, model_decoder_layers_14_self_attn_q_proj_weight5, model_decoder_layers_14_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv113_1 = R.call_tir(cls.NT_matmul, (layer_norm398, model_decoder_layers_14_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv196 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm398, model_decoder_layers_14_self_attn_v_proj_weight5, model_decoder_layers_14_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv197 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv195, lv113_1, lv196), out_sinfo=R.Tensor((1, 60, 64), 
dtype="float16")) lv293 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), lv197), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv198 = R.call_tir(cls.fused_reshape23_reshape24, (lv293,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv199 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv198, model_decoder_layers_14_self_attn_out_proj_weight5, model_decoder_layers_14_self_attn_out_proj_bias5, lv194), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm399 = R.call_tir(cls.layer_norm3, (lv199, model_decoder_layers_14_encoder_attn_layer_norm_weight5, model_decoder_layers_14_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv200 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm399, model_decoder_layers_14_encoder_attn_q_proj_weight5, model_decoder_layers_14_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv201 = R.call_tir(cls.fused_reshape21_reshape25, (lv200,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv294 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), lv201), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv202 = R.call_tir(cls.fused_reshape23_reshape24, (lv294,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv203 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv202, model_decoder_layers_14_encoder_attn_out_proj_weight5, model_decoder_layers_14_encoder_attn_out_proj_bias5, lv199), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm400 = R.call_tir(cls.layer_norm3, (lv203, model_decoder_layers_14_final_layer_norm_weight5, model_decoder_layers_14_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv204 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm400, model_decoder_layers_14_fc1_weight5, model_decoder_layers_14_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv205 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv204, model_decoder_layers_14_fc2_weight5, model_decoder_layers_14_fc2_bias5, lv203), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm401 = R.call_tir(cls.layer_norm3, (lv205, model_decoder_layers_15_self_attn_layer_norm_weight5, model_decoder_layers_15_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv206 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm401, model_decoder_layers_15_self_attn_q_proj_weight5, model_decoder_layers_15_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv121_1 = R.call_tir(cls.NT_matmul, (layer_norm401, model_decoder_layers_15_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv207 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm401, model_decoder_layers_15_self_attn_v_proj_weight5, model_decoder_layers_15_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv208 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv206, lv121_1, lv207), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv295 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), lv208), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv209 = R.call_tir(cls.fused_reshape23_reshape24, (lv295,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv210 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv209, 
model_decoder_layers_15_self_attn_out_proj_weight5, model_decoder_layers_15_self_attn_out_proj_bias5, lv205), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm402 = R.call_tir(cls.layer_norm3, (lv210, model_decoder_layers_15_encoder_attn_layer_norm_weight5, model_decoder_layers_15_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv211 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm402, model_decoder_layers_15_encoder_attn_q_proj_weight5, model_decoder_layers_15_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv212 = R.call_tir(cls.fused_reshape21_reshape25, (lv211,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv296 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), lv212), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv213 = R.call_tir(cls.fused_reshape23_reshape24, (lv296,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv214 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv213, model_decoder_layers_15_encoder_attn_out_proj_weight5, model_decoder_layers_15_encoder_attn_out_proj_bias5, lv210), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm403 = R.call_tir(cls.layer_norm3, (lv214, model_decoder_layers_15_final_layer_norm_weight5, model_decoder_layers_15_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv215 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm403, model_decoder_layers_15_fc1_weight5, model_decoder_layers_15_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv216 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv215, model_decoder_layers_15_fc2_weight5, model_decoder_layers_15_fc2_bias5, lv214), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm404 = R.call_tir(cls.layer_norm3, (lv216, model_decoder_layers_16_self_attn_layer_norm_weight5, model_decoder_layers_16_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv217 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm404, model_decoder_layers_16_self_attn_q_proj_weight5, model_decoder_layers_16_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv129_1 = R.call_tir(cls.NT_matmul, (layer_norm404, model_decoder_layers_16_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv218 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm404, model_decoder_layers_16_self_attn_v_proj_weight5, model_decoder_layers_16_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv219 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv217, lv129_1, lv218), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv297 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), lv219), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv220 = R.call_tir(cls.fused_reshape23_reshape24, (lv297,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv221 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv220, model_decoder_layers_16_self_attn_out_proj_weight5, model_decoder_layers_16_self_attn_out_proj_bias5, lv216), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm405 = R.call_tir(cls.layer_norm3, (lv221, model_decoder_layers_16_encoder_attn_layer_norm_weight5, model_decoder_layers_16_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv222 = 
R.call_tir(cls.fused_NT_matmul_add7, (layer_norm405, model_decoder_layers_16_encoder_attn_q_proj_weight5, model_decoder_layers_16_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv223 = R.call_tir(cls.fused_reshape21_reshape25, (lv222,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv298 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), lv223), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv224 = R.call_tir(cls.fused_reshape23_reshape24, (lv298,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv225 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv224, model_decoder_layers_16_encoder_attn_out_proj_weight5, model_decoder_layers_16_encoder_attn_out_proj_bias5, lv221), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm406 = R.call_tir(cls.layer_norm3, (lv225, model_decoder_layers_16_final_layer_norm_weight5, model_decoder_layers_16_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv226 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm406, model_decoder_layers_16_fc1_weight5, model_decoder_layers_16_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv227 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv226, model_decoder_layers_16_fc2_weight5, model_decoder_layers_16_fc2_bias5, lv225), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm407 = R.call_tir(cls.layer_norm3, (lv227, model_decoder_layers_17_self_attn_layer_norm_weight5, model_decoder_layers_17_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv228 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm407, model_decoder_layers_17_self_attn_q_proj_weight5, model_decoder_layers_17_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv137_1 = R.call_tir(cls.NT_matmul, (layer_norm407, model_decoder_layers_17_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv229 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm407, model_decoder_layers_17_self_attn_v_proj_weight5, model_decoder_layers_17_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv230 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv228, lv137_1, lv229), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv299 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), lv230), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv231 = R.call_tir(cls.fused_reshape23_reshape24, (lv299,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv232 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv231, model_decoder_layers_17_self_attn_out_proj_weight5, model_decoder_layers_17_self_attn_out_proj_bias5, lv227), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm408 = R.call_tir(cls.layer_norm3, (lv232, model_decoder_layers_17_encoder_attn_layer_norm_weight5, model_decoder_layers_17_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv233 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm408, model_decoder_layers_17_encoder_attn_q_proj_weight5, model_decoder_layers_17_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv234 = R.call_tir(cls.fused_reshape21_reshape25, (lv233,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv300 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, 
R.prim_value(17), R.prim_value(T.float32(1)), lv234), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv235 = R.call_tir(cls.fused_reshape23_reshape24, (lv300,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv236 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv235, model_decoder_layers_17_encoder_attn_out_proj_weight5, model_decoder_layers_17_encoder_attn_out_proj_bias5, lv232), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm409 = R.call_tir(cls.layer_norm3, (lv236, model_decoder_layers_17_final_layer_norm_weight5, model_decoder_layers_17_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv237 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm409, model_decoder_layers_17_fc1_weight5, model_decoder_layers_17_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv238 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv237, model_decoder_layers_17_fc2_weight5, model_decoder_layers_17_fc2_bias5, lv236), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm410 = R.call_tir(cls.layer_norm3, (lv238, model_decoder_layers_18_self_attn_layer_norm_weight5, model_decoder_layers_18_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv239 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm410, model_decoder_layers_18_self_attn_q_proj_weight5, model_decoder_layers_18_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv145_1 = R.call_tir(cls.NT_matmul, (layer_norm410, model_decoder_layers_18_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv240 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm410, model_decoder_layers_18_self_attn_v_proj_weight5, model_decoder_layers_18_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv241 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv239, lv145_1, lv240), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv301 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), lv241), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv242 = R.call_tir(cls.fused_reshape23_reshape24, (lv301,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv243 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv242, model_decoder_layers_18_self_attn_out_proj_weight5, model_decoder_layers_18_self_attn_out_proj_bias5, lv238), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm411 = R.call_tir(cls.layer_norm3, (lv243, model_decoder_layers_18_encoder_attn_layer_norm_weight5, model_decoder_layers_18_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv244 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm411, model_decoder_layers_18_encoder_attn_q_proj_weight5, model_decoder_layers_18_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv245 = R.call_tir(cls.fused_reshape21_reshape25, (lv244,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv302 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), lv245), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv246 = R.call_tir(cls.fused_reshape23_reshape24, (lv302,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv247 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv246, model_decoder_layers_18_encoder_attn_out_proj_weight5, model_decoder_layers_18_encoder_attn_out_proj_bias5, lv243), out_sinfo=R.Tensor((1, 1, 
1280), dtype="float16")) layer_norm412 = R.call_tir(cls.layer_norm3, (lv247, model_decoder_layers_18_final_layer_norm_weight5, model_decoder_layers_18_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv248 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm412, model_decoder_layers_18_fc1_weight5, model_decoder_layers_18_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv249 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv248, model_decoder_layers_18_fc2_weight5, model_decoder_layers_18_fc2_bias5, lv247), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm413 = R.call_tir(cls.layer_norm3, (lv249, model_decoder_layers_19_self_attn_layer_norm_weight5, model_decoder_layers_19_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv250 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm413, model_decoder_layers_19_self_attn_q_proj_weight5, model_decoder_layers_19_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv153_1 = R.call_tir(cls.NT_matmul, (layer_norm413, model_decoder_layers_19_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv251 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm413, model_decoder_layers_19_self_attn_v_proj_weight5, model_decoder_layers_19_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv252 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv250, lv153_1, lv251), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv303 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), lv252), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv253 = R.call_tir(cls.fused_reshape23_reshape24, (lv303,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv254 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv253, model_decoder_layers_19_self_attn_out_proj_weight5, model_decoder_layers_19_self_attn_out_proj_bias5, lv249), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm414 = R.call_tir(cls.layer_norm3, (lv254, model_decoder_layers_19_encoder_attn_layer_norm_weight5, model_decoder_layers_19_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv255 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm414, model_decoder_layers_19_encoder_attn_q_proj_weight5, model_decoder_layers_19_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv256 = R.call_tir(cls.fused_reshape21_reshape25, (lv255,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv304 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), lv256), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv257 = R.call_tir(cls.fused_reshape23_reshape24, (lv304,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv258 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv257, model_decoder_layers_19_encoder_attn_out_proj_weight5, model_decoder_layers_19_encoder_attn_out_proj_bias5, lv254), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm415 = R.call_tir(cls.layer_norm3, (lv258, model_decoder_layers_19_final_layer_norm_weight5, model_decoder_layers_19_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv259 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm415, model_decoder_layers_19_fc1_weight5, model_decoder_layers_19_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), 
dtype="float16")) lv260 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv259, model_decoder_layers_19_fc2_weight5, model_decoder_layers_19_fc2_bias5, lv258), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm416 = R.call_tir(cls.layer_norm3, (lv260, model_decoder_layers_20_self_attn_layer_norm_weight5, model_decoder_layers_20_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv261 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm416, model_decoder_layers_20_self_attn_q_proj_weight5, model_decoder_layers_20_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv161_1 = R.call_tir(cls.NT_matmul, (layer_norm416, model_decoder_layers_20_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv262 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm416, model_decoder_layers_20_self_attn_v_proj_weight5, model_decoder_layers_20_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv263 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv261, lv161_1, lv262), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv305 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), lv263), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv264_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv305,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv265_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv264_1, model_decoder_layers_20_self_attn_out_proj_weight5, model_decoder_layers_20_self_attn_out_proj_bias5, lv260), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm417 = R.call_tir(cls.layer_norm3, (lv265_1, model_decoder_layers_20_encoder_attn_layer_norm_weight5, model_decoder_layers_20_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv266_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm417, model_decoder_layers_20_encoder_attn_q_proj_weight5, model_decoder_layers_20_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv267_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv266_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv306 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), lv267_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv268_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv306,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv269_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv268_1, model_decoder_layers_20_encoder_attn_out_proj_weight5, model_decoder_layers_20_encoder_attn_out_proj_bias5, lv265_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm418 = R.call_tir(cls.layer_norm3, (lv269_1, model_decoder_layers_20_final_layer_norm_weight5, model_decoder_layers_20_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv270_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm418, model_decoder_layers_20_fc1_weight5, model_decoder_layers_20_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv271_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv270_1, model_decoder_layers_20_fc2_weight5, model_decoder_layers_20_fc2_bias5, lv269_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm419 = R.call_tir(cls.layer_norm3, (lv271_1, model_decoder_layers_21_self_attn_layer_norm_weight5, model_decoder_layers_21_self_attn_layer_norm_bias5), 
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv272_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm419, model_decoder_layers_21_self_attn_q_proj_weight5, model_decoder_layers_21_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv169_1 = R.call_tir(cls.NT_matmul, (layer_norm419, model_decoder_layers_21_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv273_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm419, model_decoder_layers_21_self_attn_v_proj_weight5, model_decoder_layers_21_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv274_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv272_1, lv169_1, lv273_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv307 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), lv274_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv275_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv307,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv276_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv275_1, model_decoder_layers_21_self_attn_out_proj_weight5, model_decoder_layers_21_self_attn_out_proj_bias5, lv271_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm420 = R.call_tir(cls.layer_norm3, (lv276_1, model_decoder_layers_21_encoder_attn_layer_norm_weight5, model_decoder_layers_21_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv277_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm420, model_decoder_layers_21_encoder_attn_q_proj_weight5, model_decoder_layers_21_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv278_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv277_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv308 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), lv278_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv279_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv308,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv280_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv279_1, model_decoder_layers_21_encoder_attn_out_proj_weight5, model_decoder_layers_21_encoder_attn_out_proj_bias5, lv276_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm421 = R.call_tir(cls.layer_norm3, (lv280_1, model_decoder_layers_21_final_layer_norm_weight5, model_decoder_layers_21_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv281_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm421, model_decoder_layers_21_fc1_weight5, model_decoder_layers_21_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv282_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv281_1, model_decoder_layers_21_fc2_weight5, model_decoder_layers_21_fc2_bias5, lv280_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm422 = R.call_tir(cls.layer_norm3, (lv282_1, model_decoder_layers_22_self_attn_layer_norm_weight5, model_decoder_layers_22_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv283_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm422, model_decoder_layers_22_self_attn_q_proj_weight5, model_decoder_layers_22_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv177_1 = R.call_tir(cls.NT_matmul, (layer_norm422, model_decoder_layers_22_self_attn_k_proj_weight5), 
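# NOTE (editorial): k_proj calls the bias-free NT_matmul kernel, while q_proj
# and v_proj use fused_NT_matmul_add7 (matmul plus bias add); this is consistent
# with Whisper's k_proj having no bias term.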
out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv284_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm422, model_decoder_layers_22_self_attn_v_proj_weight5, model_decoder_layers_22_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv285_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv283_1, lv177_1, lv284_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv309 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), lv285_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv286_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv309,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv287_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv286_1, model_decoder_layers_22_self_attn_out_proj_weight5, model_decoder_layers_22_self_attn_out_proj_bias5, lv282_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm423 = R.call_tir(cls.layer_norm3, (lv287_1, model_decoder_layers_22_encoder_attn_layer_norm_weight5, model_decoder_layers_22_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv288_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm423, model_decoder_layers_22_encoder_attn_q_proj_weight5, model_decoder_layers_22_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv289_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv288_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv310 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), lv289_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv290_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv310,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv291_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv290_1, model_decoder_layers_22_encoder_attn_out_proj_weight5, model_decoder_layers_22_encoder_attn_out_proj_bias5, lv287_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm424 = R.call_tir(cls.layer_norm3, (lv291_1, model_decoder_layers_22_final_layer_norm_weight5, model_decoder_layers_22_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv292_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm424, model_decoder_layers_22_fc1_weight5, model_decoder_layers_22_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv293_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv292_1, model_decoder_layers_22_fc2_weight5, model_decoder_layers_22_fc2_bias5, lv291_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm425 = R.call_tir(cls.layer_norm3, (lv293_1, model_decoder_layers_23_self_attn_layer_norm_weight5, model_decoder_layers_23_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv294_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm425, model_decoder_layers_23_self_attn_q_proj_weight5, model_decoder_layers_23_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv185_1 = R.call_tir(cls.NT_matmul, (layer_norm425, model_decoder_layers_23_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv295_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm425, model_decoder_layers_23_self_attn_v_proj_weight5, model_decoder_layers_23_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv296_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv294_1, 
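# NOTE (editorial): fused_reshape21_reshape21_reshape21_concatenate2_reshape22
# packs q, k, v (each (1, 1, 1280) reshaped to (1, 20, 64), i.e. 20 heads of
# width 64) into one (1, 60, 64) tensor for the fused-QKV attention builtin.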
lv185_1, lv295_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv311 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), lv296_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv297_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv311,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv298_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv297_1, model_decoder_layers_23_self_attn_out_proj_weight5, model_decoder_layers_23_self_attn_out_proj_bias5, lv293_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm426 = R.call_tir(cls.layer_norm3, (lv298_1, model_decoder_layers_23_encoder_attn_layer_norm_weight5, model_decoder_layers_23_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv299_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm426, model_decoder_layers_23_encoder_attn_q_proj_weight5, model_decoder_layers_23_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv300_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv299_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv312 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), lv300_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv301_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv312,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv302_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv301_1, model_decoder_layers_23_encoder_attn_out_proj_weight5, model_decoder_layers_23_encoder_attn_out_proj_bias5, lv298_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm427 = R.call_tir(cls.layer_norm3, (lv302_1, model_decoder_layers_23_final_layer_norm_weight5, model_decoder_layers_23_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv303_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm427, model_decoder_layers_23_fc1_weight5, model_decoder_layers_23_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv304_1 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv303_1, model_decoder_layers_23_fc2_weight5, model_decoder_layers_23_fc2_bias5, lv302_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm428 = R.call_tir(cls.layer_norm3, (lv304_1, model_decoder_layers_24_self_attn_layer_norm_weight5, model_decoder_layers_24_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv305_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm428, model_decoder_layers_24_self_attn_q_proj_weight5, model_decoder_layers_24_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv193_1 = R.call_tir(cls.NT_matmul, (layer_norm428, model_decoder_layers_24_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv306_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm428, model_decoder_layers_24_self_attn_v_proj_weight5, model_decoder_layers_24_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv307_1 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv305_1, lv193_1, lv306_1), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv313 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), lv307_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv308_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv313,), out_sinfo=R.Tensor((1, 1, 
1280), dtype="float16")) lv309_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv308_1, model_decoder_layers_24_self_attn_out_proj_weight5, model_decoder_layers_24_self_attn_out_proj_bias5, lv304_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm429 = R.call_tir(cls.layer_norm3, (lv309_1, model_decoder_layers_24_encoder_attn_layer_norm_weight5, model_decoder_layers_24_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv310_1 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm429, model_decoder_layers_24_encoder_attn_q_proj_weight5, model_decoder_layers_24_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv311_1 = R.call_tir(cls.fused_reshape21_reshape25, (lv310_1,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv314 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), lv311_1), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv312_1 = R.call_tir(cls.fused_reshape23_reshape24, (lv314,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv313_1 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv312_1, model_decoder_layers_24_encoder_attn_out_proj_weight5, model_decoder_layers_24_encoder_attn_out_proj_bias5, lv309_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm430 = R.call_tir(cls.layer_norm3, (lv313_1, model_decoder_layers_24_final_layer_norm_weight5, model_decoder_layers_24_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv314_1 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm430, model_decoder_layers_24_fc1_weight5, model_decoder_layers_24_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv315 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv314_1, model_decoder_layers_24_fc2_weight5, model_decoder_layers_24_fc2_bias5, lv313_1), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm431 = R.call_tir(cls.layer_norm3, (lv315, model_decoder_layers_25_self_attn_layer_norm_weight5, model_decoder_layers_25_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv316 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm431, model_decoder_layers_25_self_attn_q_proj_weight5, model_decoder_layers_25_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv201_1 = R.call_tir(cls.NT_matmul, (layer_norm431, model_decoder_layers_25_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv317 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm431, model_decoder_layers_25_self_attn_v_proj_weight5, model_decoder_layers_25_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv318 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv316, lv201_1, lv317), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv315_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), lv318), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv319 = R.call_tir(cls.fused_reshape23_reshape24, (lv315_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv320 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv319, model_decoder_layers_25_self_attn_out_proj_weight5, model_decoder_layers_25_self_attn_out_proj_bias5, lv315), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm432 = R.call_tir(cls.layer_norm3, (lv320, model_decoder_layers_25_encoder_attn_layer_norm_weight5, 
model_decoder_layers_25_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv321 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm432, model_decoder_layers_25_encoder_attn_q_proj_weight5, model_decoder_layers_25_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv322 = R.call_tir(cls.fused_reshape21_reshape25, (lv321,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv316_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), lv322), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv323 = R.call_tir(cls.fused_reshape23_reshape24, (lv316_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv324 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv323, model_decoder_layers_25_encoder_attn_out_proj_weight5, model_decoder_layers_25_encoder_attn_out_proj_bias5, lv320), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm433 = R.call_tir(cls.layer_norm3, (lv324, model_decoder_layers_25_final_layer_norm_weight5, model_decoder_layers_25_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv325 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm433, model_decoder_layers_25_fc1_weight5, model_decoder_layers_25_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv326 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv325, model_decoder_layers_25_fc2_weight5, model_decoder_layers_25_fc2_bias5, lv324), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm434 = R.call_tir(cls.layer_norm3, (lv326, model_decoder_layers_26_self_attn_layer_norm_weight5, model_decoder_layers_26_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv327 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm434, model_decoder_layers_26_self_attn_q_proj_weight5, model_decoder_layers_26_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv209_1 = R.call_tir(cls.NT_matmul, (layer_norm434, model_decoder_layers_26_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv328 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm434, model_decoder_layers_26_self_attn_v_proj_weight5, model_decoder_layers_26_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv329 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv327, lv209_1, lv328), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv317_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), lv329), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv330 = R.call_tir(cls.fused_reshape23_reshape24, (lv317_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv331 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv330, model_decoder_layers_26_self_attn_out_proj_weight5, model_decoder_layers_26_self_attn_out_proj_bias5, lv326), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm435 = R.call_tir(cls.layer_norm3, (lv331, model_decoder_layers_26_encoder_attn_layer_norm_weight5, model_decoder_layers_26_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv332 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm435, model_decoder_layers_26_encoder_attn_q_proj_weight5, model_decoder_layers_26_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv333 = R.call_tir(cls.fused_reshape21_reshape25, (lv332,), out_sinfo=R.Tensor((1, 
20, 64), dtype="float16")) lv318_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), lv333), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv334 = R.call_tir(cls.fused_reshape23_reshape24, (lv318_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv335 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv334, model_decoder_layers_26_encoder_attn_out_proj_weight5, model_decoder_layers_26_encoder_attn_out_proj_bias5, lv331), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm436 = R.call_tir(cls.layer_norm3, (lv335, model_decoder_layers_26_final_layer_norm_weight5, model_decoder_layers_26_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv336 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm436, model_decoder_layers_26_fc1_weight5, model_decoder_layers_26_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv337 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv336, model_decoder_layers_26_fc2_weight5, model_decoder_layers_26_fc2_bias5, lv335), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm437 = R.call_tir(cls.layer_norm3, (lv337, model_decoder_layers_27_self_attn_layer_norm_weight5, model_decoder_layers_27_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv338 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm437, model_decoder_layers_27_self_attn_q_proj_weight5, model_decoder_layers_27_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv217_1 = R.call_tir(cls.NT_matmul, (layer_norm437, model_decoder_layers_27_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv339 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm437, model_decoder_layers_27_self_attn_v_proj_weight5, model_decoder_layers_27_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv340 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv338, lv217_1, lv339), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv319_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), lv340), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv341 = R.call_tir(cls.fused_reshape23_reshape24, (lv319_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv342 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv341, model_decoder_layers_27_self_attn_out_proj_weight5, model_decoder_layers_27_self_attn_out_proj_bias5, lv337), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm438 = R.call_tir(cls.layer_norm3, (lv342, model_decoder_layers_27_encoder_attn_layer_norm_weight5, model_decoder_layers_27_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv343 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm438, model_decoder_layers_27_encoder_attn_q_proj_weight5, model_decoder_layers_27_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv344 = R.call_tir(cls.fused_reshape21_reshape25, (lv343,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv320_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), lv344), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv345 = R.call_tir(cls.fused_reshape23_reshape24, (lv320_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv346 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv345, 
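# NOTE (editorial): fused_NT_matmul_add7_add6 is the attention output epilogue:
# out_proj matmul, plus bias (the add7 part), plus the residual passed as the
# last operand (here lv342, the hidden state entering this cross-attention block).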
model_decoder_layers_27_encoder_attn_out_proj_weight5, model_decoder_layers_27_encoder_attn_out_proj_bias5, lv342), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm439 = R.call_tir(cls.layer_norm3, (lv346, model_decoder_layers_27_final_layer_norm_weight5, model_decoder_layers_27_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv347 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm439, model_decoder_layers_27_fc1_weight5, model_decoder_layers_27_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv348 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv347, model_decoder_layers_27_fc2_weight5, model_decoder_layers_27_fc2_bias5, lv346), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm440 = R.call_tir(cls.layer_norm3, (lv348, model_decoder_layers_28_self_attn_layer_norm_weight5, model_decoder_layers_28_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv349 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm440, model_decoder_layers_28_self_attn_q_proj_weight5, model_decoder_layers_28_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv225_1 = R.call_tir(cls.NT_matmul, (layer_norm440, model_decoder_layers_28_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv350 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm440, model_decoder_layers_28_self_attn_v_proj_weight5, model_decoder_layers_28_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv351 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv349, lv225_1, lv350), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv321_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), lv351), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv352 = R.call_tir(cls.fused_reshape23_reshape24, (lv321_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv353 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv352, model_decoder_layers_28_self_attn_out_proj_weight5, model_decoder_layers_28_self_attn_out_proj_bias5, lv348), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm441 = R.call_tir(cls.layer_norm3, (lv353, model_decoder_layers_28_encoder_attn_layer_norm_weight5, model_decoder_layers_28_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv354 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm441, model_decoder_layers_28_encoder_attn_q_proj_weight5, model_decoder_layers_28_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv355 = R.call_tir(cls.fused_reshape21_reshape25, (lv354,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv322_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), lv355), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv356 = R.call_tir(cls.fused_reshape23_reshape24, (lv322_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv357 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv356, model_decoder_layers_28_encoder_attn_out_proj_weight5, model_decoder_layers_28_encoder_attn_out_proj_bias5, lv353), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm442 = R.call_tir(cls.layer_norm3, (lv357, model_decoder_layers_28_final_layer_norm_weight5, model_decoder_layers_28_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv358 = 
R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm442, model_decoder_layers_28_fc1_weight5, model_decoder_layers_28_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv359 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv358, model_decoder_layers_28_fc2_weight5, model_decoder_layers_28_fc2_bias5, lv357), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm443 = R.call_tir(cls.layer_norm3, (lv359, model_decoder_layers_29_self_attn_layer_norm_weight5, model_decoder_layers_29_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv360 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm443, model_decoder_layers_29_self_attn_q_proj_weight5, model_decoder_layers_29_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv233_1 = R.call_tir(cls.NT_matmul, (layer_norm443, model_decoder_layers_29_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv361 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm443, model_decoder_layers_29_self_attn_v_proj_weight5, model_decoder_layers_29_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv362 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv360, lv233_1, lv361), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv323_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), lv362), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv363 = R.call_tir(cls.fused_reshape23_reshape24, (lv323_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv364 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv363, model_decoder_layers_29_self_attn_out_proj_weight5, model_decoder_layers_29_self_attn_out_proj_bias5, lv359), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm444 = R.call_tir(cls.layer_norm3, (lv364, model_decoder_layers_29_encoder_attn_layer_norm_weight5, model_decoder_layers_29_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv365 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm444, model_decoder_layers_29_encoder_attn_q_proj_weight5, model_decoder_layers_29_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv366 = R.call_tir(cls.fused_reshape21_reshape25, (lv365,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv324_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), lv366), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv367 = R.call_tir(cls.fused_reshape23_reshape24, (lv324_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv368 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv367, model_decoder_layers_29_encoder_attn_out_proj_weight5, model_decoder_layers_29_encoder_attn_out_proj_bias5, lv364), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm445 = R.call_tir(cls.layer_norm3, (lv368, model_decoder_layers_29_final_layer_norm_weight5, model_decoder_layers_29_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv369 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm445, model_decoder_layers_29_fc1_weight5, model_decoder_layers_29_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv370 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv369, model_decoder_layers_29_fc2_weight5, model_decoder_layers_29_fc2_bias5, lv368), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm446 = 
R.call_tir(cls.layer_norm3, (lv370, model_decoder_layers_30_self_attn_layer_norm_weight5, model_decoder_layers_30_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv371 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm446, model_decoder_layers_30_self_attn_q_proj_weight5, model_decoder_layers_30_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv241_1 = R.call_tir(cls.NT_matmul, (layer_norm446, model_decoder_layers_30_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv372 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm446, model_decoder_layers_30_self_attn_v_proj_weight5, model_decoder_layers_30_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv373 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv371, lv241_1, lv372), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv325_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), lv373), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv374 = R.call_tir(cls.fused_reshape23_reshape24, (lv325_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv375 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv374, model_decoder_layers_30_self_attn_out_proj_weight5, model_decoder_layers_30_self_attn_out_proj_bias5, lv370), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm447 = R.call_tir(cls.layer_norm3, (lv375, model_decoder_layers_30_encoder_attn_layer_norm_weight5, model_decoder_layers_30_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv376 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm447, model_decoder_layers_30_encoder_attn_q_proj_weight5, model_decoder_layers_30_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv377 = R.call_tir(cls.fused_reshape21_reshape25, (lv376,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv326_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), lv377), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv378 = R.call_tir(cls.fused_reshape23_reshape24, (lv326_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv379 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv378, model_decoder_layers_30_encoder_attn_out_proj_weight5, model_decoder_layers_30_encoder_attn_out_proj_bias5, lv375), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm448 = R.call_tir(cls.layer_norm3, (lv379, model_decoder_layers_30_final_layer_norm_weight5, model_decoder_layers_30_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv380 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm448, model_decoder_layers_30_fc1_weight5, model_decoder_layers_30_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv381 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv380, model_decoder_layers_30_fc2_weight5, model_decoder_layers_30_fc2_bias5, lv379), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm449 = R.call_tir(cls.layer_norm3, (lv381, model_decoder_layers_31_self_attn_layer_norm_weight5, model_decoder_layers_31_self_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv382 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm449, model_decoder_layers_31_self_attn_q_proj_weight5, model_decoder_layers_31_self_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) 
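# NOTE (editorial): layers 16-31 above repeat a single decoder-layer pattern;
# an informal sketch follows (illustrative names, not part of the module):
#
#   h = layer_norm(x)                                    # self_attn_layer_norm
#   q = h @ Wq.T + bq;  k = h @ Wk.T;  v = h @ Wv.T + bv # k_proj has no bias
#   a = attn_with_fused_qkv(cache, layer_id, 1.0, pack(q, k, v))
#   x = unpack(a) @ Wo.T + bo + x                        # residual add (add6)
#   h = layer_norm(x)                                    # encoder_attn_layer_norm
#   a = cross_attn(cache, layer_id, 1.0, reshape(h @ Wq2.T + bq2))
#   x = unpack(a) @ Wo2.T + bo2 + x
#   h = layer_norm(x)                                    # final_layer_norm
#   x = gelu(h @ W1.T + b1) @ W2.T + b2 + x              # fc1 (5120) / fc2 (1280)
#
# After layer 31, a final layer_norm feeds NT_matmul3 against
# model_decoder_embed_tokens_weight5, reusing the token embedding as the output
# head to produce (1, 1, 51866) float32 logits.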
lv249_1 = R.call_tir(cls.NT_matmul, (layer_norm449, model_decoder_layers_31_self_attn_k_proj_weight5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv383 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm449, model_decoder_layers_31_self_attn_v_proj_weight5, model_decoder_layers_31_self_attn_v_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv384 = R.call_tir(cls.fused_reshape21_reshape21_reshape21_concatenate2_reshape22, (lv382, lv249_1, lv383), out_sinfo=R.Tensor((1, 60, 64), dtype="float16")) lv327_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), lv384), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv385 = R.call_tir(cls.fused_reshape23_reshape24, (lv327_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv386 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv385, model_decoder_layers_31_self_attn_out_proj_weight5, model_decoder_layers_31_self_attn_out_proj_bias5, lv381), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm450 = R.call_tir(cls.layer_norm3, (lv386, model_decoder_layers_31_encoder_attn_layer_norm_weight5, model_decoder_layers_31_encoder_attn_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv387 = R.call_tir(cls.fused_NT_matmul_add7, (layer_norm450, model_decoder_layers_31_encoder_attn_q_proj_weight5, model_decoder_layers_31_encoder_attn_q_proj_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv388 = R.call_tir(cls.fused_reshape21_reshape25, (lv387,), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv328_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), lv388), out_sinfo=R.Tensor((1, 20, 64), dtype="float16")) lv389 = R.call_tir(cls.fused_reshape23_reshape24, (lv328_1,), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv390 = R.call_tir(cls.fused_NT_matmul_add7_add6, (lv389, model_decoder_layers_31_encoder_attn_out_proj_weight5, model_decoder_layers_31_encoder_attn_out_proj_bias5, lv386), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm451 = R.call_tir(cls.layer_norm3, (lv390, model_decoder_layers_31_final_layer_norm_weight5, model_decoder_layers_31_final_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) lv391 = R.call_tir(cls.fused_NT_matmul1_add8_gelu2, (layer_norm451, model_decoder_layers_31_fc1_weight5, model_decoder_layers_31_fc1_bias5), out_sinfo=R.Tensor((1, 1, 5120), dtype="float16")) lv392 = R.call_tir(cls.fused_NT_matmul2_add7_add6, (lv391, model_decoder_layers_31_fc2_weight5, model_decoder_layers_31_fc2_bias5, lv390), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) layer_norm452 = R.call_tir(cls.layer_norm3, (lv392, model_decoder_layer_norm_weight5, model_decoder_layer_norm_bias5), out_sinfo=R.Tensor((1, 1, 1280), dtype="float16")) gv5 = R.call_tir(cls.NT_matmul3, (layer_norm452, model_decoder_embed_tokens_weight5), out_sinfo=R.Tensor((1, 1, 51866), dtype="float32")) R.output(gv5) return gv5 @R.function def multinomial_from_uniform(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32")) -> R.Tensor(("num_samples",), dtype="int32"): num_samples = T.int64() batch_size = T.int64() vocab_size = T.int64() R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 
48, "num_samples": 8}}) cls = Module with R.dataflow(): uniform_samples_1: R.Tensor((num_samples, 1), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", uniform_samples, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),)) sample_indices_1: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", sample_indices, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),)) nn_multinomial_from_uniform = R.call_tir(cls.parallel_sampling_from_prob, (probs, uniform_samples_1, sample_indices_1), out_sinfo=R.Tensor((num_samples, 1), dtype="int32")) gv: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", nn_multinomial_from_uniform, R.shape([num_samples]), sinfo_args=(R.Tensor((num_samples,), dtype="int32"),)) R.output(gv) return gv @R.function def prefill(input_ids: R.Tensor((1, "seq_len"), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((1280, 128, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280, 3), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1500, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 
1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), 
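# NOTE (editorial): packed_params begins with the audio front end (conv1/conv2
# weights and biases, then a (1500, 1280) positional table) and then appears to
# repeat a fixed 15-tensor group per encoder layer: four (1280, 1280) attention
# projections (one of them, presumably k_proj, without a bias), the attention
# LayerNorm, fc1/fc2 with biases, and the final LayerNorm. The decoder section
# starts further down, at the (51866, 1280) token embedding and the (448, 1280)
# positional table.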
R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), 
R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((51866, 1280), dtype="float16"), R.Tensor((448, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), 
dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), 
dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280, 1280), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((5120, 1280), dtype="float16"), R.Tensor((5120,), dtype="float16"), R.Tensor((1280, 5120), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"), R.Tensor((1280,), dtype="float16"))) -> R.Tensor((1, 1, 51866), dtype="float32"): seq_len = T.int64() R.func_attr({"num_input": 2, "relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}}) cls = Module with R.dataflow(): model_decoder_embed_tokens_weight4: R.Tensor((51866, 1280), dtype="float16") = packed_params[487] model_decoder_embed_positions_weight4: R.Tensor((448, 1280), dtype="float16") = packed_params[488] model_decoder_layers_0_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[489] model_decoder_layers_0_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[490] 
model_decoder_layers_0_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[491] model_decoder_layers_0_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[492] model_decoder_layers_0_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[493] model_decoder_layers_0_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[494] model_decoder_layers_0_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[495] model_decoder_layers_0_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[496] model_decoder_layers_0_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[497] model_decoder_layers_0_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[501] model_decoder_layers_0_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[502] model_decoder_layers_0_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[503] model_decoder_layers_0_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[504] model_decoder_layers_0_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[505] model_decoder_layers_0_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[506] model_decoder_layers_0_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[507] model_decoder_layers_0_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[508] model_decoder_layers_0_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[509] model_decoder_layers_0_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[510] model_decoder_layers_0_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[511] model_decoder_layers_0_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[512] model_decoder_layers_1_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[513] model_decoder_layers_1_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[514] model_decoder_layers_1_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[515] model_decoder_layers_1_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[516] model_decoder_layers_1_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[517] model_decoder_layers_1_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[518] model_decoder_layers_1_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[519] model_decoder_layers_1_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[520] model_decoder_layers_1_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[521] model_decoder_layers_1_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[525] model_decoder_layers_1_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[526] model_decoder_layers_1_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[527] model_decoder_layers_1_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[528] model_decoder_layers_1_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[529] 
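# NOTE (editorial): only the v_proj, q_proj, and out_proj projections carry
# bias slots here; there is no self_attn_k_proj_bias binding, consistent with
# Whisper's bias-free key projection.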
model_decoder_layers_1_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[530] model_decoder_layers_1_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[531] model_decoder_layers_1_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[532] model_decoder_layers_1_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[533] model_decoder_layers_1_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[534] model_decoder_layers_1_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[535] model_decoder_layers_1_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[536] model_decoder_layers_2_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[537] model_decoder_layers_2_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[538] model_decoder_layers_2_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[539] model_decoder_layers_2_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[540] model_decoder_layers_2_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[541] model_decoder_layers_2_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[542] model_decoder_layers_2_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[543] model_decoder_layers_2_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[544] model_decoder_layers_2_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[545] model_decoder_layers_2_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[549] model_decoder_layers_2_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[550] model_decoder_layers_2_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[551] model_decoder_layers_2_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[552] model_decoder_layers_2_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[553] model_decoder_layers_2_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[554] model_decoder_layers_2_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[555] model_decoder_layers_2_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[556] model_decoder_layers_2_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[557] model_decoder_layers_2_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[558] model_decoder_layers_2_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[559] model_decoder_layers_2_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[560] model_decoder_layers_3_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[561] model_decoder_layers_3_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[562] model_decoder_layers_3_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[563] model_decoder_layers_3_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[564] model_decoder_layers_3_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[565] model_decoder_layers_3_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[566] 
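# A hypothetical helper (editorial sketch, not part of the generated module)
# showing how the 24-slot-per-layer stride observed above maps a
# (layer, offset) pair to a packed_params index; base=489 is the first
# layer-0 slot in this dump:
#
#     def decoder_param_index(layer: int, offset: int, base: int = 489) -> int:
#         # offset 0 = self_attn_k_proj_weight, 3 = self_attn_q_proj_weight,
#         # 18 = fc1_weight, 23 = final_layer_norm_bias, etc.
#         return base + 24 * layer + offset
#
#     # e.g. decoder_param_index(5, 3) == 612, matching the binding
#     # model_decoder_layers_5_self_attn_q_proj_weight4 = packed_params[612]
#     # further down in this function.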
model_decoder_layers_3_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[567] model_decoder_layers_3_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[568] model_decoder_layers_3_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[569] model_decoder_layers_3_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[573] model_decoder_layers_3_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[574] model_decoder_layers_3_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[575] model_decoder_layers_3_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[576] model_decoder_layers_3_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[577] model_decoder_layers_3_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[578] model_decoder_layers_3_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[579] model_decoder_layers_3_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[580] model_decoder_layers_3_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[581] model_decoder_layers_3_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[582] model_decoder_layers_3_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[583] model_decoder_layers_3_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[584] model_decoder_layers_4_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[585] model_decoder_layers_4_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[586] model_decoder_layers_4_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[587] model_decoder_layers_4_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[588] model_decoder_layers_4_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[589] model_decoder_layers_4_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[590] model_decoder_layers_4_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[591] model_decoder_layers_4_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[592] model_decoder_layers_4_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[593] model_decoder_layers_4_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[597] model_decoder_layers_4_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[598] model_decoder_layers_4_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[599] model_decoder_layers_4_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[600] model_decoder_layers_4_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[601] model_decoder_layers_4_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[602] model_decoder_layers_4_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[603] model_decoder_layers_4_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[604] model_decoder_layers_4_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[605] model_decoder_layers_4_fc2_bias4: R.Tensor((1280,), dtype="float16") = 
packed_params[606] model_decoder_layers_4_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[607] model_decoder_layers_4_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[608] model_decoder_layers_5_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[609] model_decoder_layers_5_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[610] model_decoder_layers_5_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[611] model_decoder_layers_5_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[612] model_decoder_layers_5_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[613] model_decoder_layers_5_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[614] model_decoder_layers_5_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[615] model_decoder_layers_5_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[616] model_decoder_layers_5_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[617] model_decoder_layers_5_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[621] model_decoder_layers_5_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[622] model_decoder_layers_5_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[623] model_decoder_layers_5_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[624] model_decoder_layers_5_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[625] model_decoder_layers_5_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[626] model_decoder_layers_5_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[627] model_decoder_layers_5_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[628] model_decoder_layers_5_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[629] model_decoder_layers_5_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[630] model_decoder_layers_5_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[631] model_decoder_layers_5_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[632] model_decoder_layers_6_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[633] model_decoder_layers_6_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[634] model_decoder_layers_6_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[635] model_decoder_layers_6_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[636] model_decoder_layers_6_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[637] model_decoder_layers_6_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[638] model_decoder_layers_6_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[639] model_decoder_layers_6_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[640] model_decoder_layers_6_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[641] model_decoder_layers_6_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[645] 
model_decoder_layers_6_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[646] model_decoder_layers_6_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[647] model_decoder_layers_6_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[648] model_decoder_layers_6_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[649] model_decoder_layers_6_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[650] model_decoder_layers_6_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[651] model_decoder_layers_6_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[652] model_decoder_layers_6_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[653] model_decoder_layers_6_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[654] model_decoder_layers_6_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[655] model_decoder_layers_6_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[656] model_decoder_layers_7_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[657] model_decoder_layers_7_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[658] model_decoder_layers_7_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[659] model_decoder_layers_7_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[660] model_decoder_layers_7_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[661] model_decoder_layers_7_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[662] model_decoder_layers_7_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[663] model_decoder_layers_7_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[664] model_decoder_layers_7_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[665] model_decoder_layers_7_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[669] model_decoder_layers_7_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[670] model_decoder_layers_7_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[671] model_decoder_layers_7_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[672] model_decoder_layers_7_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[673] model_decoder_layers_7_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[674] model_decoder_layers_7_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[675] model_decoder_layers_7_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[676] model_decoder_layers_7_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[677] model_decoder_layers_7_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[678] model_decoder_layers_7_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[679] model_decoder_layers_7_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[680] model_decoder_layers_8_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[681] model_decoder_layers_8_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = 
packed_params[682] model_decoder_layers_8_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[683] model_decoder_layers_8_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[684] model_decoder_layers_8_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[685] model_decoder_layers_8_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[686] model_decoder_layers_8_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[687] model_decoder_layers_8_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[688] model_decoder_layers_8_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[689] model_decoder_layers_8_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[693] model_decoder_layers_8_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[694] model_decoder_layers_8_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[695] model_decoder_layers_8_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[696] model_decoder_layers_8_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[697] model_decoder_layers_8_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[698] model_decoder_layers_8_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[699] model_decoder_layers_8_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[700] model_decoder_layers_8_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[701] model_decoder_layers_8_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[702] model_decoder_layers_8_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[703] model_decoder_layers_8_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[704] model_decoder_layers_9_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[705] model_decoder_layers_9_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[706] model_decoder_layers_9_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[707] model_decoder_layers_9_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[708] model_decoder_layers_9_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[709] model_decoder_layers_9_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[710] model_decoder_layers_9_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[711] model_decoder_layers_9_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[712] model_decoder_layers_9_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[713] model_decoder_layers_9_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[717] model_decoder_layers_9_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[718] model_decoder_layers_9_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[719] model_decoder_layers_9_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[720] model_decoder_layers_9_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[721] 
model_decoder_layers_9_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[722] model_decoder_layers_9_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[723] model_decoder_layers_9_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[724] model_decoder_layers_9_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[725] model_decoder_layers_9_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[726] model_decoder_layers_9_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[727] model_decoder_layers_9_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[728] model_decoder_layers_10_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[729] model_decoder_layers_10_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[730] model_decoder_layers_10_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[731] model_decoder_layers_10_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[732] model_decoder_layers_10_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[733] model_decoder_layers_10_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[734] model_decoder_layers_10_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[735] model_decoder_layers_10_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[736] model_decoder_layers_10_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[737] model_decoder_layers_10_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[741] model_decoder_layers_10_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[742] model_decoder_layers_10_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[743] model_decoder_layers_10_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[744] model_decoder_layers_10_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[745] model_decoder_layers_10_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[746] model_decoder_layers_10_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[747] model_decoder_layers_10_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[748] model_decoder_layers_10_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[749] model_decoder_layers_10_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[750] model_decoder_layers_10_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[751] model_decoder_layers_10_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[752] model_decoder_layers_11_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[753] model_decoder_layers_11_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[754] model_decoder_layers_11_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[755] model_decoder_layers_11_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[756] model_decoder_layers_11_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[757] model_decoder_layers_11_self_attn_out_proj_weight4: R.Tensor((1280, 1280), 
dtype="float16") = packed_params[758] model_decoder_layers_11_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[759] model_decoder_layers_11_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[760] model_decoder_layers_11_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[761] model_decoder_layers_11_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[765] model_decoder_layers_11_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[766] model_decoder_layers_11_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[767] model_decoder_layers_11_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[768] model_decoder_layers_11_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[769] model_decoder_layers_11_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[770] model_decoder_layers_11_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[771] model_decoder_layers_11_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[772] model_decoder_layers_11_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[773] model_decoder_layers_11_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[774] model_decoder_layers_11_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[775] model_decoder_layers_11_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[776] model_decoder_layers_12_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[777] model_decoder_layers_12_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[778] model_decoder_layers_12_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[779] model_decoder_layers_12_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[780] model_decoder_layers_12_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[781] model_decoder_layers_12_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[782] model_decoder_layers_12_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[783] model_decoder_layers_12_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[784] model_decoder_layers_12_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[785] model_decoder_layers_12_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[789] model_decoder_layers_12_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[790] model_decoder_layers_12_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[791] model_decoder_layers_12_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[792] model_decoder_layers_12_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[793] model_decoder_layers_12_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[794] model_decoder_layers_12_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[795] model_decoder_layers_12_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[796] model_decoder_layers_12_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[797] 
model_decoder_layers_12_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[798]
model_decoder_layers_12_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[799]
model_decoder_layers_12_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[800]
model_decoder_layers_13_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[801]
model_decoder_layers_13_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[802]
model_decoder_layers_13_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[803]
model_decoder_layers_13_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[804]
model_decoder_layers_13_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[805]
model_decoder_layers_13_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[806]
model_decoder_layers_13_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[807]
model_decoder_layers_13_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[808]
model_decoder_layers_13_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[809]
model_decoder_layers_13_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[813]
model_decoder_layers_13_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[814]
model_decoder_layers_13_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[815]
model_decoder_layers_13_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[816]
model_decoder_layers_13_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[817]
model_decoder_layers_13_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[818]
model_decoder_layers_13_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[819]
model_decoder_layers_13_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[820]
model_decoder_layers_13_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[821]
model_decoder_layers_13_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[822]
model_decoder_layers_13_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[823]
model_decoder_layers_13_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[824]
model_decoder_layers_14_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[825]
model_decoder_layers_14_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[826]
model_decoder_layers_14_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[827]
model_decoder_layers_14_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[828]
model_decoder_layers_14_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[829]
model_decoder_layers_14_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[830]
model_decoder_layers_14_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[831]
model_decoder_layers_14_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[832]
model_decoder_layers_14_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[833]
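# NOTE (reader annotation): each decoder layer appears to occupy a stride of 24
# consecutive packed_params slots, of which only 21 are bound in this function. The
# skipped indices (e.g. 834-836, between layer 14's self_attn_layer_norm_bias above and
# its encoder_attn_q_proj_weight below) presumably hold the cross-attention k/v
# projection parameters, which this decode function never reads: cross-attention
# keys/values are served from the paged KV cache via
# vm.builtin.attention_kv_cache_cross_attention further below. A hypothetical index
# helper, assuming that 24-slot stride:
#
#     def self_attn_k_proj_index(layer: int) -> int:
#         # layer 12's self_attn_k_proj_weight sits at slot 777 (bound above)
#         return 777 + 24 * (layer - 12)
#
#     self_attn_k_proj_index(13)  # == 801, matching the binding above
#     self_attn_k_proj_index(31)  # == 1233, matching the binding further below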
model_decoder_layers_14_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[837]
model_decoder_layers_14_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[838]
model_decoder_layers_14_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[839]
model_decoder_layers_14_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[840]
model_decoder_layers_14_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[841]
model_decoder_layers_14_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[842]
model_decoder_layers_14_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[843]
model_decoder_layers_14_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[844]
model_decoder_layers_14_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[845]
model_decoder_layers_14_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[846]
model_decoder_layers_14_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[847]
model_decoder_layers_14_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[848]
model_decoder_layers_15_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[849]
model_decoder_layers_15_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[850]
model_decoder_layers_15_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[851]
model_decoder_layers_15_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[852]
model_decoder_layers_15_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[853]
model_decoder_layers_15_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[854]
model_decoder_layers_15_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[855]
model_decoder_layers_15_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[856]
model_decoder_layers_15_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[857]
model_decoder_layers_15_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[861]
model_decoder_layers_15_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[862]
model_decoder_layers_15_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[863]
model_decoder_layers_15_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[864]
model_decoder_layers_15_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[865]
model_decoder_layers_15_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[866]
model_decoder_layers_15_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[867]
model_decoder_layers_15_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[868]
model_decoder_layers_15_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[869]
model_decoder_layers_15_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[870]
model_decoder_layers_15_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[871]
model_decoder_layers_15_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[872]
model_decoder_layers_16_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[873]
model_decoder_layers_16_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[874]
model_decoder_layers_16_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[875]
model_decoder_layers_16_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[876]
model_decoder_layers_16_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[877]
model_decoder_layers_16_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[878]
model_decoder_layers_16_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[879]
model_decoder_layers_16_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[880]
model_decoder_layers_16_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[881]
model_decoder_layers_16_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[885]
model_decoder_layers_16_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[886]
model_decoder_layers_16_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[887]
model_decoder_layers_16_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[888]
model_decoder_layers_16_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[889]
model_decoder_layers_16_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[890]
model_decoder_layers_16_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[891]
model_decoder_layers_16_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[892]
model_decoder_layers_16_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[893]
model_decoder_layers_16_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[894]
model_decoder_layers_16_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[895]
model_decoder_layers_16_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[896]
model_decoder_layers_17_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[897]
model_decoder_layers_17_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[898]
model_decoder_layers_17_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[899]
model_decoder_layers_17_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[900]
model_decoder_layers_17_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[901]
model_decoder_layers_17_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[902]
model_decoder_layers_17_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[903]
model_decoder_layers_17_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[904]
model_decoder_layers_17_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[905]
model_decoder_layers_17_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[909]
model_decoder_layers_17_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[910]
model_decoder_layers_17_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[911]
model_decoder_layers_17_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[912]
model_decoder_layers_17_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[913]
model_decoder_layers_17_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[914]
model_decoder_layers_17_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[915]
model_decoder_layers_17_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[916]
model_decoder_layers_17_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[917]
model_decoder_layers_17_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[918]
model_decoder_layers_17_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[919]
model_decoder_layers_17_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[920]
model_decoder_layers_18_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[921]
model_decoder_layers_18_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[922]
model_decoder_layers_18_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[923]
model_decoder_layers_18_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[924]
model_decoder_layers_18_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[925]
model_decoder_layers_18_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[926]
model_decoder_layers_18_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[927]
model_decoder_layers_18_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[928]
model_decoder_layers_18_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[929]
model_decoder_layers_18_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[933]
model_decoder_layers_18_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[934]
model_decoder_layers_18_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[935]
model_decoder_layers_18_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[936]
model_decoder_layers_18_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[937]
model_decoder_layers_18_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[938]
model_decoder_layers_18_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[939]
model_decoder_layers_18_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[940]
model_decoder_layers_18_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[941]
model_decoder_layers_18_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[942]
model_decoder_layers_18_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[943]
model_decoder_layers_18_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[944]
model_decoder_layers_19_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[945]
model_decoder_layers_19_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[946]
model_decoder_layers_19_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[947]
model_decoder_layers_19_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[948]
model_decoder_layers_19_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[949]
model_decoder_layers_19_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[950]
model_decoder_layers_19_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[951]
model_decoder_layers_19_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[952]
model_decoder_layers_19_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[953]
model_decoder_layers_19_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[957]
model_decoder_layers_19_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[958]
model_decoder_layers_19_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[959]
model_decoder_layers_19_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[960]
model_decoder_layers_19_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[961]
model_decoder_layers_19_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[962]
model_decoder_layers_19_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[963]
model_decoder_layers_19_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[964]
model_decoder_layers_19_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[965]
model_decoder_layers_19_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[966]
model_decoder_layers_19_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[967]
model_decoder_layers_19_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[968]
model_decoder_layers_20_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[969]
model_decoder_layers_20_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[970]
model_decoder_layers_20_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[971]
model_decoder_layers_20_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[972]
model_decoder_layers_20_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[973]
model_decoder_layers_20_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[974]
model_decoder_layers_20_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[975]
model_decoder_layers_20_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[976]
model_decoder_layers_20_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[977]
model_decoder_layers_20_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[981]
model_decoder_layers_20_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[982]
model_decoder_layers_20_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[983]
model_decoder_layers_20_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[984]
model_decoder_layers_20_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[985]
model_decoder_layers_20_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[986]
model_decoder_layers_20_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[987]
model_decoder_layers_20_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[988]
model_decoder_layers_20_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[989]
model_decoder_layers_20_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[990]
model_decoder_layers_20_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[991]
model_decoder_layers_20_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[992]
model_decoder_layers_21_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[993]
model_decoder_layers_21_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[994]
model_decoder_layers_21_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[995]
model_decoder_layers_21_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[996]
model_decoder_layers_21_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[997]
model_decoder_layers_21_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[998]
model_decoder_layers_21_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[999]
model_decoder_layers_21_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1000]
model_decoder_layers_21_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1001]
model_decoder_layers_21_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1005]
model_decoder_layers_21_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1006]
model_decoder_layers_21_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1007]
model_decoder_layers_21_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1008]
model_decoder_layers_21_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1009]
model_decoder_layers_21_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1010]
model_decoder_layers_21_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1011]
model_decoder_layers_21_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1012]
model_decoder_layers_21_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1013]
model_decoder_layers_21_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1014]
model_decoder_layers_21_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1015]
model_decoder_layers_21_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1016]
model_decoder_layers_22_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1017]
model_decoder_layers_22_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1018]
model_decoder_layers_22_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1019]
model_decoder_layers_22_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1020]
model_decoder_layers_22_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1021]
model_decoder_layers_22_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1022]
model_decoder_layers_22_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1023]
model_decoder_layers_22_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1024]
model_decoder_layers_22_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1025]
model_decoder_layers_22_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1029]
model_decoder_layers_22_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1030]
model_decoder_layers_22_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1031]
model_decoder_layers_22_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1032]
model_decoder_layers_22_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1033]
model_decoder_layers_22_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1034]
model_decoder_layers_22_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1035]
model_decoder_layers_22_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1036]
model_decoder_layers_22_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1037]
model_decoder_layers_22_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1038]
model_decoder_layers_22_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1039]
model_decoder_layers_22_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1040]
model_decoder_layers_23_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1041]
model_decoder_layers_23_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1042]
model_decoder_layers_23_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1043]
model_decoder_layers_23_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1044]
model_decoder_layers_23_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1045]
model_decoder_layers_23_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1046]
model_decoder_layers_23_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1047]
model_decoder_layers_23_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1048]
model_decoder_layers_23_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1049]
model_decoder_layers_23_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1053]
model_decoder_layers_23_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1054]
model_decoder_layers_23_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1055]
model_decoder_layers_23_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1056]
model_decoder_layers_23_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1057]
model_decoder_layers_23_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1058]
model_decoder_layers_23_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1059]
model_decoder_layers_23_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1060]
model_decoder_layers_23_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1061]
model_decoder_layers_23_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1062]
model_decoder_layers_23_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1063]
model_decoder_layers_23_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1064]
model_decoder_layers_24_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1065]
model_decoder_layers_24_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1066]
model_decoder_layers_24_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1067]
model_decoder_layers_24_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1068]
model_decoder_layers_24_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1069]
model_decoder_layers_24_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1070]
model_decoder_layers_24_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1071]
model_decoder_layers_24_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1072]
model_decoder_layers_24_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1073]
model_decoder_layers_24_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1077]
model_decoder_layers_24_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1078]
model_decoder_layers_24_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1079]
model_decoder_layers_24_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1080]
model_decoder_layers_24_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1081]
model_decoder_layers_24_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1082]
model_decoder_layers_24_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1083]
model_decoder_layers_24_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1084]
model_decoder_layers_24_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1085]
model_decoder_layers_24_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1086]
model_decoder_layers_24_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1087]
model_decoder_layers_24_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1088]
model_decoder_layers_25_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1089]
model_decoder_layers_25_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1090]
model_decoder_layers_25_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1091]
model_decoder_layers_25_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1092]
model_decoder_layers_25_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1093]
model_decoder_layers_25_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1094]
model_decoder_layers_25_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1095]
model_decoder_layers_25_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1096]
model_decoder_layers_25_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1097]
model_decoder_layers_25_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1101]
model_decoder_layers_25_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1102]
model_decoder_layers_25_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1103]
model_decoder_layers_25_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1104]
model_decoder_layers_25_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1105]
model_decoder_layers_25_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1106]
model_decoder_layers_25_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1107]
model_decoder_layers_25_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1108]
model_decoder_layers_25_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1109]
model_decoder_layers_25_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1110]
model_decoder_layers_25_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1111]
model_decoder_layers_25_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1112]
model_decoder_layers_26_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1113]
model_decoder_layers_26_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1114]
model_decoder_layers_26_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1115]
model_decoder_layers_26_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1116]
model_decoder_layers_26_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1117]
model_decoder_layers_26_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1118]
model_decoder_layers_26_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1119]
model_decoder_layers_26_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1120]
model_decoder_layers_26_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1121]
model_decoder_layers_26_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1125]
model_decoder_layers_26_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1126]
model_decoder_layers_26_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1127]
model_decoder_layers_26_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1128]
model_decoder_layers_26_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1129]
model_decoder_layers_26_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1130]
model_decoder_layers_26_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1131]
model_decoder_layers_26_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1132]
model_decoder_layers_26_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1133]
model_decoder_layers_26_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1134]
model_decoder_layers_26_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1135]
model_decoder_layers_26_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1136]
model_decoder_layers_27_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1137]
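# NOTE (reader annotation): the bindings continue in the same layout through layer 31
# and end with the top-level model_decoder_layer_norm weight/bias, after which the
# decode computation begins: token embeddings are gathered for input_ids, position
# embeddings are gathered at the query positions reported by the paged KV cache, and
# each of the 32 decoder layers then repeats the same pre-norm pattern. A comment-only
# sketch (variable names are illustrative, not from this module):
#
#     h = embed_tokens[input_ids] + embed_positions[kv_cache_query_positions]
#     for n in range(32):
#         x = layer_norm(h)
#         q, k, v = q_proj(x), k_proj(x), v_proj(x)          # k_proj carries no bias
#         h = h + out_proj(paged_kv_attention(layer=n, fused_qkv=concat(q, k, v)))
#         x = layer_norm(h)
#         h = h + out_proj(cross_attention(layer=n, q=q_proj(x)))  # k/v read from cached encoder states
#         x = layer_norm(h)
#         h = h + fc2(gelu(fc1(x)))
#     # the final model_decoder_layer_norm bound below is presumably applied after the loop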
model_decoder_layers_27_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1138]
model_decoder_layers_27_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1139]
model_decoder_layers_27_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1140]
model_decoder_layers_27_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1141]
model_decoder_layers_27_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1142]
model_decoder_layers_27_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1143]
model_decoder_layers_27_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1144]
model_decoder_layers_27_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1145]
model_decoder_layers_27_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1149]
model_decoder_layers_27_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1150]
model_decoder_layers_27_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1151]
model_decoder_layers_27_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1152]
model_decoder_layers_27_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1153]
model_decoder_layers_27_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1154]
model_decoder_layers_27_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1155]
model_decoder_layers_27_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1156]
model_decoder_layers_27_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1157]
model_decoder_layers_27_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1158]
model_decoder_layers_27_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1159]
model_decoder_layers_27_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1160]
model_decoder_layers_28_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1161]
model_decoder_layers_28_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1162]
model_decoder_layers_28_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1163]
model_decoder_layers_28_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1164]
model_decoder_layers_28_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1165]
model_decoder_layers_28_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1166]
model_decoder_layers_28_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1167]
model_decoder_layers_28_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1168]
model_decoder_layers_28_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1169]
model_decoder_layers_28_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1173]
model_decoder_layers_28_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1174]
model_decoder_layers_28_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1175]
model_decoder_layers_28_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1176]
model_decoder_layers_28_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1177]
model_decoder_layers_28_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1178]
model_decoder_layers_28_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1179]
model_decoder_layers_28_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1180]
model_decoder_layers_28_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1181]
model_decoder_layers_28_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1182]
model_decoder_layers_28_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1183]
model_decoder_layers_28_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1184]
model_decoder_layers_29_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1185]
model_decoder_layers_29_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1186]
model_decoder_layers_29_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1187]
model_decoder_layers_29_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1188]
model_decoder_layers_29_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1189]
model_decoder_layers_29_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1190]
model_decoder_layers_29_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1191]
model_decoder_layers_29_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1192]
model_decoder_layers_29_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1193]
model_decoder_layers_29_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1197]
model_decoder_layers_29_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1198]
model_decoder_layers_29_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1199]
model_decoder_layers_29_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1200]
model_decoder_layers_29_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1201]
model_decoder_layers_29_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1202]
model_decoder_layers_29_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1203]
model_decoder_layers_29_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1204]
model_decoder_layers_29_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1205]
model_decoder_layers_29_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1206]
model_decoder_layers_29_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1207]
model_decoder_layers_29_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1208]
model_decoder_layers_30_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1209]
model_decoder_layers_30_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1210]
model_decoder_layers_30_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1211]
model_decoder_layers_30_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1212]
model_decoder_layers_30_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1213]
model_decoder_layers_30_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1214]
model_decoder_layers_30_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1215]
model_decoder_layers_30_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1216]
model_decoder_layers_30_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1217]
model_decoder_layers_30_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1221]
model_decoder_layers_30_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1222]
model_decoder_layers_30_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1223]
model_decoder_layers_30_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1224]
model_decoder_layers_30_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1225]
model_decoder_layers_30_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1226]
model_decoder_layers_30_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1227]
model_decoder_layers_30_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1228]
model_decoder_layers_30_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1229]
model_decoder_layers_30_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1230]
model_decoder_layers_30_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1231]
model_decoder_layers_30_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1232]
model_decoder_layers_31_self_attn_k_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1233]
model_decoder_layers_31_self_attn_v_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1234]
model_decoder_layers_31_self_attn_v_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1235]
model_decoder_layers_31_self_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1236]
model_decoder_layers_31_self_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1237]
model_decoder_layers_31_self_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1238]
model_decoder_layers_31_self_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1239]
model_decoder_layers_31_self_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1240]
model_decoder_layers_31_self_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1241]
model_decoder_layers_31_encoder_attn_q_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1245]
model_decoder_layers_31_encoder_attn_q_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1246]
model_decoder_layers_31_encoder_attn_out_proj_weight4: R.Tensor((1280, 1280), dtype="float16") = packed_params[1247]
model_decoder_layers_31_encoder_attn_out_proj_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1248]
model_decoder_layers_31_encoder_attn_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1249]
model_decoder_layers_31_encoder_attn_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1250]
model_decoder_layers_31_fc1_weight4: R.Tensor((5120, 1280), dtype="float16") = packed_params[1251]
model_decoder_layers_31_fc1_bias4: R.Tensor((5120,), dtype="float16") = packed_params[1252]
model_decoder_layers_31_fc2_weight4: R.Tensor((1280, 5120), dtype="float16") = packed_params[1253]
model_decoder_layers_31_fc2_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1254]
model_decoder_layers_31_final_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1255]
model_decoder_layers_31_final_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1256]
model_decoder_layer_norm_weight4: R.Tensor((1280,), dtype="float16") = packed_params[1257]
model_decoder_layer_norm_bias4: R.Tensor((1280,), dtype="float16") = packed_params[1258]
reshape1030 = R.call_tir(cls.reshape12, (input_ids,), out_sinfo=R.Tensor((seq_len,), dtype="int32"))
take5 = R.call_tir(cls.take, (model_decoder_embed_tokens_weight4, reshape1030), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16"))
reshape1031 = R.call_tir(cls.reshape13, (take5,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv198: R.Tensor((seq_len,), dtype="int32") = R.call_pure_packed("vm.builtin.attention_kv_cache_get_query_positions", paged_kv_cache, sinfo_args=(R.Tensor((seq_len,), dtype="int32"),))
take6 = R.call_tir(cls.take1, (model_decoder_embed_positions_weight4, lv198), out_sinfo=R.Tensor((seq_len, 1280), dtype="float16"))
reshape1032 = R.call_tir(cls.reshape13, (take6,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add899 = R.call_tir(cls.add5, (reshape1031, reshape1032), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm259 = R.call_tir(cls.layer_norm2, (add899, model_decoder_layers_0_self_attn_layer_norm_weight4, model_decoder_layers_0_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv32 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_q_proj_weight4, layer_norm259, model_decoder_layers_0_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1033 = R.call_tir(cls.reshape14, (lv32,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv32_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_0_self_attn_k_proj_weight4, layer_norm259), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1034 = R.call_tir(cls.reshape14, (lv32_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv33 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_v_proj_weight4, layer_norm259, model_decoder_layers_0_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1035 = R.call_tir(cls.reshape14, (lv33,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat64 = R.call_tir(cls.concatenate1, (reshape1033, reshape1034, reshape1035), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1036 = R.call_tir(cls.reshape15, (concat64,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape1036), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1037 = R.call_tir(cls.reshape16, (lv199,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1038 = R.call_tir(cls.reshape17, (reshape1037,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv34 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_self_attn_out_proj_weight4, reshape1038, model_decoder_layers_0_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add903 = R.call_tir(cls.add5, (add899, lv34), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm260 = R.call_tir(cls.layer_norm2, (add903, model_decoder_layers_0_encoder_attn_layer_norm_weight4, model_decoder_layers_0_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv35 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_q_proj_weight4, layer_norm260, model_decoder_layers_0_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1039 = R.call_tir(cls.reshape14, (lv35,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1040 = R.call_tir(cls.reshape18, (reshape1039,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv200 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape1040), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1041 = R.call_tir(cls.reshape16, (lv200,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1042 = R.call_tir(cls.reshape17, (reshape1041,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv36 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_0_encoder_attn_out_proj_weight4, reshape1042, model_decoder_layers_0_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add906 = R.call_tir(cls.add5, (add903, lv36), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm261 = R.call_tir(cls.layer_norm2, (add906, model_decoder_layers_0_final_layer_norm_weight4, model_decoder_layers_0_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_0_fc1_weight4, layer_norm261, model_decoder_layers_0_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv37 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_0_fc2_weight4, lv, model_decoder_layers_0_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add909 = R.call_tir(cls.add5, (add906, lv37), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm262 = R.call_tir(cls.layer_norm2, (add909, model_decoder_layers_1_self_attn_layer_norm_weight4, model_decoder_layers_1_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv38 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_q_proj_weight4, layer_norm262, model_decoder_layers_1_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1043 = R.call_tir(cls.reshape14, (lv38,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv33_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_1_self_attn_k_proj_weight4, layer_norm262), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1044 = R.call_tir(cls.reshape14, (lv33_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv39 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_v_proj_weight4, layer_norm262, model_decoder_layers_1_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1045 = R.call_tir(cls.reshape14, (lv39,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat65 = R.call_tir(cls.concatenate1, (reshape1043, reshape1044, reshape1045), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1046 = R.call_tir(cls.reshape15, (concat65,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv201 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape1046), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1047 = R.call_tir(cls.reshape16, (lv201,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1048 = R.call_tir(cls.reshape17, (reshape1047,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv40 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_self_attn_out_proj_weight4, reshape1048, model_decoder_layers_1_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add913 = R.call_tir(cls.add5, (add909, lv40), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm263 = R.call_tir(cls.layer_norm2, (add913, model_decoder_layers_1_encoder_attn_layer_norm_weight4, model_decoder_layers_1_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv41 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_q_proj_weight4, layer_norm263, model_decoder_layers_1_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1049 = R.call_tir(cls.reshape14, (lv41,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1050 = R.call_tir(cls.reshape18, (reshape1049,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv202 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1)), reshape1050), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1051 = R.call_tir(cls.reshape16, (lv202,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1052 = R.call_tir(cls.reshape17, (reshape1051,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv42 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_1_encoder_attn_out_proj_weight4, reshape1052, model_decoder_layers_1_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add916 = R.call_tir(cls.add5, (add913, lv42), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm264 = R.call_tir(cls.layer_norm2, (add916, model_decoder_layers_1_final_layer_norm_weight4, model_decoder_layers_1_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_1_fc1_weight4, layer_norm264, model_decoder_layers_1_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv43 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_1_fc2_weight4, lv1, model_decoder_layers_1_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add919 = R.call_tir(cls.add5, (add916, lv43), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm265 = R.call_tir(cls.layer_norm2, (add919, model_decoder_layers_2_self_attn_layer_norm_weight4, model_decoder_layers_2_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv44 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_q_proj_weight4, layer_norm265, model_decoder_layers_2_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1053 = R.call_tir(cls.reshape14, (lv44,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv34_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_2_self_attn_k_proj_weight4, layer_norm265), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1054 = R.call_tir(cls.reshape14, (lv34_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv45 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_v_proj_weight4, layer_norm265, model_decoder_layers_2_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1055 = R.call_tir(cls.reshape14, (lv45,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat66 = R.call_tir(cls.concatenate1, (reshape1053, reshape1054, reshape1055), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1056 = R.call_tir(cls.reshape15, (concat66,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv203 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape1056), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1057 = R.call_tir(cls.reshape16, (lv203,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1058 = R.call_tir(cls.reshape17, (reshape1057,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv46 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_self_attn_out_proj_weight4, reshape1058, model_decoder_layers_2_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add923 = R.call_tir(cls.add5, (add919, lv46), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm266 = R.call_tir(cls.layer_norm2, (add923, model_decoder_layers_2_encoder_attn_layer_norm_weight4, model_decoder_layers_2_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv47 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_q_proj_weight4, layer_norm266, model_decoder_layers_2_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1059 = R.call_tir(cls.reshape14, (lv47,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1060 = R.call_tir(cls.reshape18, (reshape1059,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1)), reshape1060), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1061 = R.call_tir(cls.reshape16, (lv204,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1062 = R.call_tir(cls.reshape17, (reshape1061,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv48 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_2_encoder_attn_out_proj_weight4, reshape1062, model_decoder_layers_2_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add926 = R.call_tir(cls.add5, (add923, lv48), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm267 = R.call_tir(cls.layer_norm2, (add926, model_decoder_layers_2_final_layer_norm_weight4, model_decoder_layers_2_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv2 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_2_fc1_weight4, layer_norm267, model_decoder_layers_2_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv49 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_2_fc2_weight4, lv2, model_decoder_layers_2_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add929 = R.call_tir(cls.add5, (add926, lv49), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm268 = R.call_tir(cls.layer_norm2, (add929, model_decoder_layers_3_self_attn_layer_norm_weight4, model_decoder_layers_3_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv50 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_q_proj_weight4, layer_norm268, model_decoder_layers_3_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1063 = R.call_tir(cls.reshape14, (lv50,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv35_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_3_self_attn_k_proj_weight4, layer_norm268), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1064 = R.call_tir(cls.reshape14, (lv35_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv51 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_v_proj_weight4, layer_norm268, model_decoder_layers_3_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1065 = R.call_tir(cls.reshape14, (lv51,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat67 = R.call_tir(cls.concatenate1, (reshape1063, reshape1064, reshape1065), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1066 = R.call_tir(cls.reshape15, (concat67,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv205 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape1066), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1067 = R.call_tir(cls.reshape16, (lv205,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1068 = R.call_tir(cls.reshape17, (reshape1067,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv52 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_self_attn_out_proj_weight4, reshape1068, model_decoder_layers_3_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add933 = R.call_tir(cls.add5, (add929, lv52), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm269 = R.call_tir(cls.layer_norm2, (add933, model_decoder_layers_3_encoder_attn_layer_norm_weight4, model_decoder_layers_3_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv53 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_q_proj_weight4, layer_norm269, model_decoder_layers_3_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1069 = R.call_tir(cls.reshape14, (lv53,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1070 = R.call_tir(cls.reshape18, (reshape1069,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv206 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1)), reshape1070), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1071 = R.call_tir(cls.reshape16, (lv206,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1072 = R.call_tir(cls.reshape17, (reshape1071,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv54 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_3_encoder_attn_out_proj_weight4, reshape1072, model_decoder_layers_3_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add936 = R.call_tir(cls.add5, (add933, lv54), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm270 = R.call_tir(cls.layer_norm2, (add936, model_decoder_layers_3_final_layer_norm_weight4, model_decoder_layers_3_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv3 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_3_fc1_weight4, layer_norm270, model_decoder_layers_3_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv55 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_3_fc2_weight4, lv3, model_decoder_layers_3_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add939 = R.call_tir(cls.add5, (add936, lv55), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm271 = R.call_tir(cls.layer_norm2, (add939, model_decoder_layers_4_self_attn_layer_norm_weight4, model_decoder_layers_4_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv56 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_q_proj_weight4, layer_norm271, model_decoder_layers_4_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1073 = R.call_tir(cls.reshape14, (lv56,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv36_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_4_self_attn_k_proj_weight4, layer_norm271), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1074 = R.call_tir(cls.reshape14, (lv36_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv57 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_v_proj_weight4, layer_norm271, model_decoder_layers_4_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1075 = R.call_tir(cls.reshape14, (lv57,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat68 = R.call_tir(cls.concatenate1, (reshape1073, reshape1074, reshape1075), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1076
= R.call_tir(cls.reshape15, (concat68,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv207 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape1076), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1077 = R.call_tir(cls.reshape16, (lv207,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1078 = R.call_tir(cls.reshape17, (reshape1077,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv58 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_self_attn_out_proj_weight4, reshape1078, model_decoder_layers_4_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add943 = R.call_tir(cls.add5, (add939, lv58), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm272 = R.call_tir(cls.layer_norm2, (add943, model_decoder_layers_4_encoder_attn_layer_norm_weight4, model_decoder_layers_4_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv59 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_q_proj_weight4, layer_norm272, model_decoder_layers_4_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1079 = R.call_tir(cls.reshape14, (lv59,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1080 = R.call_tir(cls.reshape18, (reshape1079,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv208 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1)), reshape1080), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1081 = R.call_tir(cls.reshape16, (lv208,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1082 = R.call_tir(cls.reshape17, (reshape1081,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv60 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_4_encoder_attn_out_proj_weight4, reshape1082, model_decoder_layers_4_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add946 = R.call_tir(cls.add5, (add943, lv60), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm273 = R.call_tir(cls.layer_norm2, (add946, model_decoder_layers_4_final_layer_norm_weight4, model_decoder_layers_4_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv4 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_4_fc1_weight4, layer_norm273, model_decoder_layers_4_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv61 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_4_fc2_weight4, lv4, model_decoder_layers_4_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add949 = R.call_tir(cls.add5, (add946, lv61), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm274 = R.call_tir(cls.layer_norm2, (add949, model_decoder_layers_5_self_attn_layer_norm_weight4, model_decoder_layers_5_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv62 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_q_proj_weight4, layer_norm274, 
model_decoder_layers_5_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1083 = R.call_tir(cls.reshape14, (lv62,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv37_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_5_self_attn_k_proj_weight4, layer_norm274), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1084 = R.call_tir(cls.reshape14, (lv37_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv63 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_v_proj_weight4, layer_norm274, model_decoder_layers_5_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1085 = R.call_tir(cls.reshape14, (lv63,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat69 = R.call_tir(cls.concatenate1, (reshape1083, reshape1084, reshape1085), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1086 = R.call_tir(cls.reshape15, (concat69,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape1086), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1087 = R.call_tir(cls.reshape16, (lv209,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1088 = R.call_tir(cls.reshape17, (reshape1087,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv64 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_self_attn_out_proj_weight4, reshape1088, model_decoder_layers_5_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add953 = R.call_tir(cls.add5, (add949, lv64), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm275 = R.call_tir(cls.layer_norm2, (add953, model_decoder_layers_5_encoder_attn_layer_norm_weight4, model_decoder_layers_5_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv65 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_q_proj_weight4, layer_norm275, model_decoder_layers_5_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1089 = R.call_tir(cls.reshape14, (lv65,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1090 = R.call_tir(cls.reshape18, (reshape1089,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv210 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1)), reshape1090), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1091 = R.call_tir(cls.reshape16, (lv210,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1092 = R.call_tir(cls.reshape17, (reshape1091,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv66 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_5_encoder_attn_out_proj_weight4, reshape1092, model_decoder_layers_5_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add956 = R.call_tir(cls.add5, (add953, lv66), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm276 = R.call_tir(cls.layer_norm2, (add956, model_decoder_layers_5_final_layer_norm_weight4, 
model_decoder_layers_5_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv5 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_5_fc1_weight4, layer_norm276, model_decoder_layers_5_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv67 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_5_fc2_weight4, lv5, model_decoder_layers_5_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add959 = R.call_tir(cls.add5, (add956, lv67), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm277 = R.call_tir(cls.layer_norm2, (add959, model_decoder_layers_6_self_attn_layer_norm_weight4, model_decoder_layers_6_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv68 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_q_proj_weight4, layer_norm277, model_decoder_layers_6_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1093 = R.call_tir(cls.reshape14, (lv68,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv38_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_6_self_attn_k_proj_weight4, layer_norm277), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1094 = R.call_tir(cls.reshape14, (lv38_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv69 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_v_proj_weight4, layer_norm277, model_decoder_layers_6_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1095 = R.call_tir(cls.reshape14, (lv69,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat70 = R.call_tir(cls.concatenate1, (reshape1093, reshape1094, reshape1095), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1096 = R.call_tir(cls.reshape15, (concat70,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv211 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape1096), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1097 = R.call_tir(cls.reshape16, (lv211,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1098 = R.call_tir(cls.reshape17, (reshape1097,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv70 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_self_attn_out_proj_weight4, reshape1098, model_decoder_layers_6_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add963 = R.call_tir(cls.add5, (add959, lv70), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm278 = R.call_tir(cls.layer_norm2, (add963, model_decoder_layers_6_encoder_attn_layer_norm_weight4, model_decoder_layers_6_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv71 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_q_proj_weight4, layer_norm278, model_decoder_layers_6_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1099 = R.call_tir(cls.reshape14, (lv71,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) 
reshape1100 = R.call_tir(cls.reshape18, (reshape1099,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv212 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1)), reshape1100), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1101 = R.call_tir(cls.reshape16, (lv212,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1102 = R.call_tir(cls.reshape17, (reshape1101,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv72 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_6_encoder_attn_out_proj_weight4, reshape1102, model_decoder_layers_6_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add966 = R.call_tir(cls.add5, (add963, lv72), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm279 = R.call_tir(cls.layer_norm2, (add966, model_decoder_layers_6_final_layer_norm_weight4, model_decoder_layers_6_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv6 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_6_fc1_weight4, layer_norm279, model_decoder_layers_6_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv73 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_6_fc2_weight4, lv6, model_decoder_layers_6_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add969 = R.call_tir(cls.add5, (add966, lv73), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm280 = R.call_tir(cls.layer_norm2, (add969, model_decoder_layers_7_self_attn_layer_norm_weight4, model_decoder_layers_7_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv74 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_q_proj_weight4, layer_norm280, model_decoder_layers_7_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1103 = R.call_tir(cls.reshape14, (lv74,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv39_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_7_self_attn_k_proj_weight4, layer_norm280), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1104 = R.call_tir(cls.reshape14, (lv39_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv75 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_v_proj_weight4, layer_norm280, model_decoder_layers_7_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1105 = R.call_tir(cls.reshape14, (lv75,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat71 = R.call_tir(cls.concatenate1, (reshape1103, reshape1104, reshape1105), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1106 = R.call_tir(cls.reshape15, (concat71,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv213 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape1106), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1107 = R.call_tir(cls.reshape16, (lv213,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1108 = R.call_tir(cls.reshape17, (reshape1107,), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv76 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_self_attn_out_proj_weight4, reshape1108, model_decoder_layers_7_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add973 = R.call_tir(cls.add5, (add969, lv76), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm281 = R.call_tir(cls.layer_norm2, (add973, model_decoder_layers_7_encoder_attn_layer_norm_weight4, model_decoder_layers_7_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv77 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_q_proj_weight4, layer_norm281, model_decoder_layers_7_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1109 = R.call_tir(cls.reshape14, (lv77,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1110 = R.call_tir(cls.reshape18, (reshape1109,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1)), reshape1110), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1111 = R.call_tir(cls.reshape16, (lv214,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1112 = R.call_tir(cls.reshape17, (reshape1111,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv78 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_7_encoder_attn_out_proj_weight4, reshape1112, model_decoder_layers_7_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add976 = R.call_tir(cls.add5, (add973, lv78), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm282 = R.call_tir(cls.layer_norm2, (add976, model_decoder_layers_7_final_layer_norm_weight4, model_decoder_layers_7_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv7 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_7_fc1_weight4, layer_norm282, model_decoder_layers_7_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv79 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_7_fc2_weight4, lv7, model_decoder_layers_7_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add979 = R.call_tir(cls.add5, (add976, lv79), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm283 = R.call_tir(cls.layer_norm2, (add979, model_decoder_layers_8_self_attn_layer_norm_weight4, model_decoder_layers_8_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv80 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_q_proj_weight4, layer_norm283, model_decoder_layers_8_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1113 = R.call_tir(cls.reshape14, (lv80,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv40_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_8_self_attn_k_proj_weight4, layer_norm283), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1114 = R.call_tir(cls.reshape14, (lv40_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), 
dtype="float16")) lv81 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_v_proj_weight4, layer_norm283, model_decoder_layers_8_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1115 = R.call_tir(cls.reshape14, (lv81,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat72 = R.call_tir(cls.concatenate1, (reshape1113, reshape1114, reshape1115), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1116 = R.call_tir(cls.reshape15, (concat72,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv215 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape1116), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1117 = R.call_tir(cls.reshape16, (lv215,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1118 = R.call_tir(cls.reshape17, (reshape1117,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv82 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_self_attn_out_proj_weight4, reshape1118, model_decoder_layers_8_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add983 = R.call_tir(cls.add5, (add979, lv82), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm284 = R.call_tir(cls.layer_norm2, (add983, model_decoder_layers_8_encoder_attn_layer_norm_weight4, model_decoder_layers_8_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv83 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_q_proj_weight4, layer_norm284, model_decoder_layers_8_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1119 = R.call_tir(cls.reshape14, (lv83,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1120 = R.call_tir(cls.reshape18, (reshape1119,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv216 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1)), reshape1120), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1121 = R.call_tir(cls.reshape16, (lv216,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1122 = R.call_tir(cls.reshape17, (reshape1121,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv84 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_8_encoder_attn_out_proj_weight4, reshape1122, model_decoder_layers_8_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add986 = R.call_tir(cls.add5, (add983, lv84), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm285 = R.call_tir(cls.layer_norm2, (add986, model_decoder_layers_8_final_layer_norm_weight4, model_decoder_layers_8_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv8 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_8_fc1_weight4, layer_norm285, model_decoder_layers_8_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv85 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_8_fc2_weight4, lv8, model_decoder_layers_8_fc2_bias4), out_sinfo=R.Tensor((1, 
seq_len, 1280), dtype="float16")) add989 = R.call_tir(cls.add5, (add986, lv85), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm286 = R.call_tir(cls.layer_norm2, (add989, model_decoder_layers_9_self_attn_layer_norm_weight4, model_decoder_layers_9_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv86 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_q_proj_weight4, layer_norm286, model_decoder_layers_9_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1123 = R.call_tir(cls.reshape14, (lv86,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv41_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_9_self_attn_k_proj_weight4, layer_norm286), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1124 = R.call_tir(cls.reshape14, (lv41_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv87 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_v_proj_weight4, layer_norm286, model_decoder_layers_9_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1125 = R.call_tir(cls.reshape14, (lv87,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat73 = R.call_tir(cls.concatenate1, (reshape1123, reshape1124, reshape1125), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1126 = R.call_tir(cls.reshape15, (concat73,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv217 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape1126), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1127 = R.call_tir(cls.reshape16, (lv217,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1128 = R.call_tir(cls.reshape17, (reshape1127,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv88 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_self_attn_out_proj_weight4, reshape1128, model_decoder_layers_9_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add993 = R.call_tir(cls.add5, (add989, lv88), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm287 = R.call_tir(cls.layer_norm2, (add993, model_decoder_layers_9_encoder_attn_layer_norm_weight4, model_decoder_layers_9_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv89 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_q_proj_weight4, layer_norm287, model_decoder_layers_9_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1129 = R.call_tir(cls.reshape14, (lv89,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1130 = R.call_tir(cls.reshape18, (reshape1129,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv218 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1)), reshape1130), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1131 = R.call_tir(cls.reshape16, (lv218,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1132 = R.call_tir(cls.reshape17, (reshape1131,), out_sinfo=R.Tensor((1, seq_len, 1280), 
dtype="float16")) lv90 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_9_encoder_attn_out_proj_weight4, reshape1132, model_decoder_layers_9_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add996 = R.call_tir(cls.add5, (add993, lv90), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm288 = R.call_tir(cls.layer_norm2, (add996, model_decoder_layers_9_final_layer_norm_weight4, model_decoder_layers_9_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv9 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_9_fc1_weight4, layer_norm288, model_decoder_layers_9_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv91 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_9_fc2_weight4, lv9, model_decoder_layers_9_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add999 = R.call_tir(cls.add5, (add996, lv91), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm289 = R.call_tir(cls.layer_norm2, (add999, model_decoder_layers_10_self_attn_layer_norm_weight4, model_decoder_layers_10_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv92 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_q_proj_weight4, layer_norm289, model_decoder_layers_10_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1133 = R.call_tir(cls.reshape14, (lv92,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv42_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_10_self_attn_k_proj_weight4, layer_norm289), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1134 = R.call_tir(cls.reshape14, (lv42_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv93 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_v_proj_weight4, layer_norm289, model_decoder_layers_10_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1135 = R.call_tir(cls.reshape14, (lv93,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat74 = R.call_tir(cls.concatenate1, (reshape1133, reshape1134, reshape1135), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1136 = R.call_tir(cls.reshape15, (concat74,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv219 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape1136), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1137 = R.call_tir(cls.reshape16, (lv219,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1138 = R.call_tir(cls.reshape17, (reshape1137,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv94 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_self_attn_out_proj_weight4, reshape1138, model_decoder_layers_10_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1003 = R.call_tir(cls.add5, (add999, lv94), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm290 = R.call_tir(cls.layer_norm2, (add1003, 
model_decoder_layers_10_encoder_attn_layer_norm_weight4, model_decoder_layers_10_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv95 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_q_proj_weight4, layer_norm290, model_decoder_layers_10_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1139 = R.call_tir(cls.reshape14, (lv95,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1140 = R.call_tir(cls.reshape18, (reshape1139,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv220 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1)), reshape1140), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1141 = R.call_tir(cls.reshape16, (lv220,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1142 = R.call_tir(cls.reshape17, (reshape1141,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv96 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_10_encoder_attn_out_proj_weight4, reshape1142, model_decoder_layers_10_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1006 = R.call_tir(cls.add5, (add1003, lv96), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm291 = R.call_tir(cls.layer_norm2, (add1006, model_decoder_layers_10_final_layer_norm_weight4, model_decoder_layers_10_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv10 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_10_fc1_weight4, layer_norm291, model_decoder_layers_10_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv97 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_10_fc2_weight4, lv10, model_decoder_layers_10_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1009 = R.call_tir(cls.add5, (add1006, lv97), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm292 = R.call_tir(cls.layer_norm2, (add1009, model_decoder_layers_11_self_attn_layer_norm_weight4, model_decoder_layers_11_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv98 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_q_proj_weight4, layer_norm292, model_decoder_layers_11_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1143 = R.call_tir(cls.reshape14, (lv98,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv43_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_11_self_attn_k_proj_weight4, layer_norm292), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1144 = R.call_tir(cls.reshape14, (lv43_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv99 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_v_proj_weight4, layer_norm292, model_decoder_layers_11_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1145 = R.call_tir(cls.reshape14, (lv99,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat75 = R.call_tir(cls.concatenate1, (reshape1143, reshape1144, 
reshape1145), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1146 = R.call_tir(cls.reshape15, (concat75,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv221 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape1146), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1147 = R.call_tir(cls.reshape16, (lv221,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1148 = R.call_tir(cls.reshape17, (reshape1147,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv100 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_self_attn_out_proj_weight4, reshape1148, model_decoder_layers_11_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1013 = R.call_tir(cls.add5, (add1009, lv100), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm293 = R.call_tir(cls.layer_norm2, (add1013, model_decoder_layers_11_encoder_attn_layer_norm_weight4, model_decoder_layers_11_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv101 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_q_proj_weight4, layer_norm293, model_decoder_layers_11_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1149 = R.call_tir(cls.reshape14, (lv101,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1150 = R.call_tir(cls.reshape18, (reshape1149,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv222 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1)), reshape1150), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1151 = R.call_tir(cls.reshape16, (lv222,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1152 = R.call_tir(cls.reshape17, (reshape1151,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv102 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_11_encoder_attn_out_proj_weight4, reshape1152, model_decoder_layers_11_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1016 = R.call_tir(cls.add5, (add1013, lv102), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm294 = R.call_tir(cls.layer_norm2, (add1016, model_decoder_layers_11_final_layer_norm_weight4, model_decoder_layers_11_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv11 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_11_fc1_weight4, layer_norm294, model_decoder_layers_11_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv103 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_11_fc2_weight4, lv11, model_decoder_layers_11_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1019 = R.call_tir(cls.add5, (add1016, lv103), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm295 = R.call_tir(cls.layer_norm2, (add1019, model_decoder_layers_12_self_attn_layer_norm_weight4, model_decoder_layers_12_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv104 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_q_proj_weight4, layer_norm295, model_decoder_layers_12_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1153 = R.call_tir(cls.reshape14, (lv104,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv44_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_12_self_attn_k_proj_weight4, layer_norm295), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1154 = R.call_tir(cls.reshape14, (lv44_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv105 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_v_proj_weight4, layer_norm295, model_decoder_layers_12_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1155 = R.call_tir(cls.reshape14, (lv105,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat76 = R.call_tir(cls.concatenate1, (reshape1153, reshape1154, reshape1155), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1156 = R.call_tir(cls.reshape15, (concat76,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv223 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape1156), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1157 = R.call_tir(cls.reshape16, (lv223,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1158 = R.call_tir(cls.reshape17, (reshape1157,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv106 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_self_attn_out_proj_weight4, reshape1158, model_decoder_layers_12_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1023 = R.call_tir(cls.add5, (add1019, lv106), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm296 = R.call_tir(cls.layer_norm2, (add1023, model_decoder_layers_12_encoder_attn_layer_norm_weight4, model_decoder_layers_12_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv107 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_q_proj_weight4, layer_norm296, model_decoder_layers_12_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1159 = R.call_tir(cls.reshape14, (lv107,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1160 = R.call_tir(cls.reshape18, (reshape1159,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1)), reshape1160), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1161 = R.call_tir(cls.reshape16, (lv224,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1162 = R.call_tir(cls.reshape17, (reshape1161,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv108 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_12_encoder_attn_out_proj_weight4, reshape1162, model_decoder_layers_12_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1026 = R.call_tir(cls.add5, (add1023, lv108), 
out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm297 = R.call_tir(cls.layer_norm2, (add1026, model_decoder_layers_12_final_layer_norm_weight4, model_decoder_layers_12_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv12 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_12_fc1_weight4, layer_norm297, model_decoder_layers_12_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv109 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_12_fc2_weight4, lv12, model_decoder_layers_12_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1029 = R.call_tir(cls.add5, (add1026, lv109), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm298 = R.call_tir(cls.layer_norm2, (add1029, model_decoder_layers_13_self_attn_layer_norm_weight4, model_decoder_layers_13_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv110 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_q_proj_weight4, layer_norm298, model_decoder_layers_13_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1163 = R.call_tir(cls.reshape14, (lv110,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv45_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_13_self_attn_k_proj_weight4, layer_norm298), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1164 = R.call_tir(cls.reshape14, (lv45_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv111 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_v_proj_weight4, layer_norm298, model_decoder_layers_13_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1165 = R.call_tir(cls.reshape14, (lv111,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat77 = R.call_tir(cls.concatenate1, (reshape1163, reshape1164, reshape1165), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1166 = R.call_tir(cls.reshape15, (concat77,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv225 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape1166), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1167 = R.call_tir(cls.reshape16, (lv225,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1168 = R.call_tir(cls.reshape17, (reshape1167,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv112 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_self_attn_out_proj_weight4, reshape1168, model_decoder_layers_13_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1033 = R.call_tir(cls.add5, (add1029, lv112), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm299 = R.call_tir(cls.layer_norm2, (add1033, model_decoder_layers_13_encoder_attn_layer_norm_weight4, model_decoder_layers_13_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv113 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_q_proj_weight4, layer_norm299, 
model_decoder_layers_13_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1169 = R.call_tir(cls.reshape14, (lv113,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1170 = R.call_tir(cls.reshape18, (reshape1169,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv226 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1)), reshape1170), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1171 = R.call_tir(cls.reshape16, (lv226,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1172 = R.call_tir(cls.reshape17, (reshape1171,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv114 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_13_encoder_attn_out_proj_weight4, reshape1172, model_decoder_layers_13_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1036 = R.call_tir(cls.add5, (add1033, lv114), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm300 = R.call_tir(cls.layer_norm2, (add1036, model_decoder_layers_13_final_layer_norm_weight4, model_decoder_layers_13_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv13 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_13_fc1_weight4, layer_norm300, model_decoder_layers_13_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv115 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_13_fc2_weight4, lv13, model_decoder_layers_13_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1039 = R.call_tir(cls.add5, (add1036, lv115), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm301 = R.call_tir(cls.layer_norm2, (add1039, model_decoder_layers_14_self_attn_layer_norm_weight4, model_decoder_layers_14_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv116 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_q_proj_weight4, layer_norm301, model_decoder_layers_14_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1173 = R.call_tir(cls.reshape14, (lv116,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv46_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_14_self_attn_k_proj_weight4, layer_norm301), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1174 = R.call_tir(cls.reshape14, (lv46_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv117 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_v_proj_weight4, layer_norm301, model_decoder_layers_14_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1175 = R.call_tir(cls.reshape14, (lv117,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat78 = R.call_tir(cls.concatenate1, (reshape1173, reshape1174, reshape1175), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1176 = R.call_tir(cls.reshape15, (concat78,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv227 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), 
R.prim_value(T.float32(1)), reshape1176), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1177 = R.call_tir(cls.reshape16, (lv227,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1178 = R.call_tir(cls.reshape17, (reshape1177,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv118 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_self_attn_out_proj_weight4, reshape1178, model_decoder_layers_14_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1043 = R.call_tir(cls.add5, (add1039, lv118), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm302 = R.call_tir(cls.layer_norm2, (add1043, model_decoder_layers_14_encoder_attn_layer_norm_weight4, model_decoder_layers_14_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv119 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_q_proj_weight4, layer_norm302, model_decoder_layers_14_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1179 = R.call_tir(cls.reshape14, (lv119,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1180 = R.call_tir(cls.reshape18, (reshape1179,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv228 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1)), reshape1180), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1181 = R.call_tir(cls.reshape16, (lv228,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1182 = R.call_tir(cls.reshape17, (reshape1181,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv120 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_14_encoder_attn_out_proj_weight4, reshape1182, model_decoder_layers_14_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1046 = R.call_tir(cls.add5, (add1043, lv120), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm303 = R.call_tir(cls.layer_norm2, (add1046, model_decoder_layers_14_final_layer_norm_weight4, model_decoder_layers_14_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv14 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_14_fc1_weight4, layer_norm303, model_decoder_layers_14_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv121 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_14_fc2_weight4, lv14, model_decoder_layers_14_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1049 = R.call_tir(cls.add5, (add1046, lv121), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm304 = R.call_tir(cls.layer_norm2, (add1049, model_decoder_layers_15_self_attn_layer_norm_weight4, model_decoder_layers_15_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv122 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_q_proj_weight4, layer_norm304, model_decoder_layers_15_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1183 = R.call_tir(cls.reshape14, (lv122,), out_sinfo=R.Tensor((1, seq_len, 20, 64), 
dtype="float16")) lv47_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_15_self_attn_k_proj_weight4, layer_norm304), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1184 = R.call_tir(cls.reshape14, (lv47_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv123 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_v_proj_weight4, layer_norm304, model_decoder_layers_15_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1185 = R.call_tir(cls.reshape14, (lv123,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat79 = R.call_tir(cls.concatenate1, (reshape1183, reshape1184, reshape1185), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1186 = R.call_tir(cls.reshape15, (concat79,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv229 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape1186), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1187 = R.call_tir(cls.reshape16, (lv229,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1188 = R.call_tir(cls.reshape17, (reshape1187,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv124 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_self_attn_out_proj_weight4, reshape1188, model_decoder_layers_15_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1053 = R.call_tir(cls.add5, (add1049, lv124), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm305 = R.call_tir(cls.layer_norm2, (add1053, model_decoder_layers_15_encoder_attn_layer_norm_weight4, model_decoder_layers_15_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv125 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_q_proj_weight4, layer_norm305, model_decoder_layers_15_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1189 = R.call_tir(cls.reshape14, (lv125,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1190 = R.call_tir(cls.reshape18, (reshape1189,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv230 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1)), reshape1190), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1191 = R.call_tir(cls.reshape16, (lv230,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1192 = R.call_tir(cls.reshape17, (reshape1191,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv126 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_15_encoder_attn_out_proj_weight4, reshape1192, model_decoder_layers_15_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1056 = R.call_tir(cls.add5, (add1053, lv126), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm306 = R.call_tir(cls.layer_norm2, (add1056, model_decoder_layers_15_final_layer_norm_weight4, model_decoder_layers_15_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv15 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_15_fc1_weight4, layer_norm306, model_decoder_layers_15_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv127 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_15_fc2_weight4, lv15, model_decoder_layers_15_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1059 = R.call_tir(cls.add5, (add1056, lv127), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm307 = R.call_tir(cls.layer_norm2, (add1059, model_decoder_layers_16_self_attn_layer_norm_weight4, model_decoder_layers_16_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv128 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_q_proj_weight4, layer_norm307, model_decoder_layers_16_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1193 = R.call_tir(cls.reshape14, (lv128,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv48_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_16_self_attn_k_proj_weight4, layer_norm307), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1194 = R.call_tir(cls.reshape14, (lv48_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv129 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_v_proj_weight4, layer_norm307, model_decoder_layers_16_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1195 = R.call_tir(cls.reshape14, (lv129,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat80 = R.call_tir(cls.concatenate1, (reshape1193, reshape1194, reshape1195), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1196 = R.call_tir(cls.reshape15, (concat80,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv231 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1)), reshape1196), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1197 = R.call_tir(cls.reshape16, (lv231,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1198 = R.call_tir(cls.reshape17, (reshape1197,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv130 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_self_attn_out_proj_weight4, reshape1198, model_decoder_layers_16_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1063 = R.call_tir(cls.add5, (add1059, lv130), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm308 = R.call_tir(cls.layer_norm2, (add1063, model_decoder_layers_16_encoder_attn_layer_norm_weight4, model_decoder_layers_16_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv131 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_16_encoder_attn_q_proj_weight4, layer_norm308, model_decoder_layers_16_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1199 = R.call_tir(cls.reshape14, (lv131,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1200 = R.call_tir(cls.reshape18, (reshape1199,), 
layer_norm310 = R.call_tir(cls.layer_norm2, (add1069, model_decoder_layers_17_self_attn_layer_norm_weight4, model_decoder_layers_17_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv134 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_q_proj_weight4, layer_norm310, model_decoder_layers_17_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1203 = R.call_tir(cls.reshape14, (lv134,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv49_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_17_self_attn_k_proj_weight4, layer_norm310), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1204 = R.call_tir(cls.reshape14, (lv49_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv135 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_v_proj_weight4, layer_norm310, model_decoder_layers_17_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1205 = R.call_tir(cls.reshape14, (lv135,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat81 = R.call_tir(cls.concatenate1, (reshape1203, reshape1204, reshape1205), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1206 = R.call_tir(cls.reshape15, (concat81,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv233 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape1206), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1207 = R.call_tir(cls.reshape16, (lv233,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1208 = R.call_tir(cls.reshape17, (reshape1207,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv136 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_self_attn_out_proj_weight4, reshape1208, model_decoder_layers_17_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1073 = R.call_tir(cls.add5, (add1069, lv136), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm311 = R.call_tir(cls.layer_norm2, (add1073, model_decoder_layers_17_encoder_attn_layer_norm_weight4, model_decoder_layers_17_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv137 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_q_proj_weight4, layer_norm311, model_decoder_layers_17_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1209 = R.call_tir(cls.reshape14, (lv137,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1210 = R.call_tir(cls.reshape18, (reshape1209,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv234 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1)), reshape1210), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1211 = R.call_tir(cls.reshape16, (lv234,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1212 = R.call_tir(cls.reshape17, (reshape1211,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv138 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_17_encoder_attn_out_proj_weight4, reshape1212, model_decoder_layers_17_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1076 = R.call_tir(cls.add5, (add1073, lv138), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm312 = R.call_tir(cls.layer_norm2, (add1076, model_decoder_layers_17_final_layer_norm_weight4, model_decoder_layers_17_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv17 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_17_fc1_weight4, layer_norm312, model_decoder_layers_17_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv139 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_17_fc2_weight4, lv17, model_decoder_layers_17_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1079 = R.call_tir(cls.add5, (add1076, lv139), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
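# Decoder layer 18.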
layer_norm313 = R.call_tir(cls.layer_norm2, (add1079, model_decoder_layers_18_self_attn_layer_norm_weight4, model_decoder_layers_18_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv140 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_q_proj_weight4, layer_norm313, model_decoder_layers_18_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1213 = R.call_tir(cls.reshape14, (lv140,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv50_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_18_self_attn_k_proj_weight4, layer_norm313), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1214 = R.call_tir(cls.reshape14, (lv50_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv141 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_v_proj_weight4, layer_norm313, model_decoder_layers_18_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1215 = R.call_tir(cls.reshape14, (lv141,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat82 = R.call_tir(cls.concatenate1, (reshape1213, reshape1214, reshape1215), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1216 = R.call_tir(cls.reshape15, (concat82,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv235 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape1216), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1217 = R.call_tir(cls.reshape16, (lv235,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1218 = R.call_tir(cls.reshape17, (reshape1217,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv142 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_self_attn_out_proj_weight4, reshape1218, model_decoder_layers_18_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1083 = R.call_tir(cls.add5, (add1079, lv142), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm314 = R.call_tir(cls.layer_norm2, (add1083, model_decoder_layers_18_encoder_attn_layer_norm_weight4, model_decoder_layers_18_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv143 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_q_proj_weight4, layer_norm314, model_decoder_layers_18_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1219 = R.call_tir(cls.reshape14, (lv143,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1220 = R.call_tir(cls.reshape18, (reshape1219,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv236 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1)), reshape1220), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1221 = R.call_tir(cls.reshape16, (lv236,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1222 = R.call_tir(cls.reshape17, (reshape1221,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv144 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_18_encoder_attn_out_proj_weight4, reshape1222, model_decoder_layers_18_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1086 = R.call_tir(cls.add5, (add1083, lv144), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm315 = R.call_tir(cls.layer_norm2, (add1086, model_decoder_layers_18_final_layer_norm_weight4, model_decoder_layers_18_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv18 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_18_fc1_weight4, layer_norm315, model_decoder_layers_18_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv145 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_18_fc2_weight4, lv18, model_decoder_layers_18_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1089 = R.call_tir(cls.add5, (add1086, lv145), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
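# Decoder layer 19.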
layer_norm316 = R.call_tir(cls.layer_norm2, (add1089, model_decoder_layers_19_self_attn_layer_norm_weight4, model_decoder_layers_19_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv146 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_q_proj_weight4, layer_norm316, model_decoder_layers_19_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1223 = R.call_tir(cls.reshape14, (lv146,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv51_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_19_self_attn_k_proj_weight4, layer_norm316), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1224 = R.call_tir(cls.reshape14, (lv51_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv147 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_v_proj_weight4, layer_norm316, model_decoder_layers_19_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1225 = R.call_tir(cls.reshape14, (lv147,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat83 = R.call_tir(cls.concatenate1, (reshape1223, reshape1224, reshape1225), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1226 = R.call_tir(cls.reshape15, (concat83,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv237 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape1226), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1227 = R.call_tir(cls.reshape16, (lv237,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1228 = R.call_tir(cls.reshape17, (reshape1227,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv148 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_self_attn_out_proj_weight4, reshape1228, model_decoder_layers_19_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1093 = R.call_tir(cls.add5, (add1089, lv148), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm317 = R.call_tir(cls.layer_norm2, (add1093, model_decoder_layers_19_encoder_attn_layer_norm_weight4, model_decoder_layers_19_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv149 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_q_proj_weight4, layer_norm317, model_decoder_layers_19_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1229 = R.call_tir(cls.reshape14, (lv149,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1230 = R.call_tir(cls.reshape18, (reshape1229,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv238 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1)), reshape1230), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1231 = R.call_tir(cls.reshape16, (lv238,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1232 = R.call_tir(cls.reshape17, (reshape1231,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv150 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_19_encoder_attn_out_proj_weight4, reshape1232, model_decoder_layers_19_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1096 = R.call_tir(cls.add5, (add1093, lv150), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm318 = R.call_tir(cls.layer_norm2, (add1096, model_decoder_layers_19_final_layer_norm_weight4, model_decoder_layers_19_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv19 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_19_fc1_weight4, layer_norm318, model_decoder_layers_19_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv151 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_19_fc2_weight4, lv19, model_decoder_layers_19_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1099 = R.call_tir(cls.add5, (add1096, lv151), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
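# Decoder layer 20.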
layer_norm319 = R.call_tir(cls.layer_norm2, (add1099, model_decoder_layers_20_self_attn_layer_norm_weight4, model_decoder_layers_20_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv152 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_q_proj_weight4, layer_norm319, model_decoder_layers_20_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1233 = R.call_tir(cls.reshape14, (lv152,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv52_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_20_self_attn_k_proj_weight4, layer_norm319), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1234 = R.call_tir(cls.reshape14, (lv52_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv153 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_v_proj_weight4, layer_norm319, model_decoder_layers_20_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1235 = R.call_tir(cls.reshape14, (lv153,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat84 = R.call_tir(cls.concatenate1, (reshape1233, reshape1234, reshape1235), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1236 = R.call_tir(cls.reshape15, (concat84,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv239 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape1236), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1237 = R.call_tir(cls.reshape16, (lv239,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1238 = R.call_tir(cls.reshape17, (reshape1237,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv154 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_self_attn_out_proj_weight4, reshape1238, model_decoder_layers_20_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1103 = R.call_tir(cls.add5, (add1099, lv154), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
dtype="float16")) layer_norm320 = R.call_tir(cls.layer_norm2, (add1103, model_decoder_layers_20_encoder_attn_layer_norm_weight4, model_decoder_layers_20_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv155 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_q_proj_weight4, layer_norm320, model_decoder_layers_20_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1239 = R.call_tir(cls.reshape14, (lv155,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1240 = R.call_tir(cls.reshape18, (reshape1239,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv240 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1)), reshape1240), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1241 = R.call_tir(cls.reshape16, (lv240,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1242 = R.call_tir(cls.reshape17, (reshape1241,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv156 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_20_encoder_attn_out_proj_weight4, reshape1242, model_decoder_layers_20_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1106 = R.call_tir(cls.add5, (add1103, lv156), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm321 = R.call_tir(cls.layer_norm2, (add1106, model_decoder_layers_20_final_layer_norm_weight4, model_decoder_layers_20_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv20 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_20_fc1_weight4, layer_norm321, model_decoder_layers_20_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv157 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_20_fc2_weight4, lv20, model_decoder_layers_20_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1109 = R.call_tir(cls.add5, (add1106, lv157), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm322 = R.call_tir(cls.layer_norm2, (add1109, model_decoder_layers_21_self_attn_layer_norm_weight4, model_decoder_layers_21_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv158 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_q_proj_weight4, layer_norm322, model_decoder_layers_21_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1243 = R.call_tir(cls.reshape14, (lv158,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv53_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_21_self_attn_k_proj_weight4, layer_norm322), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1244 = R.call_tir(cls.reshape14, (lv53_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv159 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_v_proj_weight4, layer_norm322, model_decoder_layers_21_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1245 = R.call_tir(cls.reshape14, (lv159,), out_sinfo=R.Tensor((1, seq_len, 20, 64), 
dtype="float16")) concat85 = R.call_tir(cls.concatenate1, (reshape1243, reshape1244, reshape1245), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1246 = R.call_tir(cls.reshape15, (concat85,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv241 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape1246), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1247 = R.call_tir(cls.reshape16, (lv241,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1248 = R.call_tir(cls.reshape17, (reshape1247,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv160 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_self_attn_out_proj_weight4, reshape1248, model_decoder_layers_21_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1113 = R.call_tir(cls.add5, (add1109, lv160), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm323 = R.call_tir(cls.layer_norm2, (add1113, model_decoder_layers_21_encoder_attn_layer_norm_weight4, model_decoder_layers_21_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv161 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_q_proj_weight4, layer_norm323, model_decoder_layers_21_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1249 = R.call_tir(cls.reshape14, (lv161,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1250 = R.call_tir(cls.reshape18, (reshape1249,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv242 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1)), reshape1250), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1251 = R.call_tir(cls.reshape16, (lv242,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1252 = R.call_tir(cls.reshape17, (reshape1251,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv162 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_21_encoder_attn_out_proj_weight4, reshape1252, model_decoder_layers_21_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1116 = R.call_tir(cls.add5, (add1113, lv162), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm324 = R.call_tir(cls.layer_norm2, (add1116, model_decoder_layers_21_final_layer_norm_weight4, model_decoder_layers_21_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv21 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_21_fc1_weight4, layer_norm324, model_decoder_layers_21_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv163 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_21_fc2_weight4, lv21, model_decoder_layers_21_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1119 = R.call_tir(cls.add5, (add1116, lv163), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm325 = R.call_tir(cls.layer_norm2, (add1119, model_decoder_layers_22_self_attn_layer_norm_weight4, model_decoder_layers_22_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 
1280), dtype="float16")) lv164 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_q_proj_weight4, layer_norm325, model_decoder_layers_22_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1253 = R.call_tir(cls.reshape14, (lv164,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv54_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_22_self_attn_k_proj_weight4, layer_norm325), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1254 = R.call_tir(cls.reshape14, (lv54_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv165 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_v_proj_weight4, layer_norm325, model_decoder_layers_22_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1255 = R.call_tir(cls.reshape14, (lv165,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat86 = R.call_tir(cls.concatenate1, (reshape1253, reshape1254, reshape1255), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1256 = R.call_tir(cls.reshape15, (concat86,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv243 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape1256), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1257 = R.call_tir(cls.reshape16, (lv243,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1258 = R.call_tir(cls.reshape17, (reshape1257,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv166 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_self_attn_out_proj_weight4, reshape1258, model_decoder_layers_22_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1123 = R.call_tir(cls.add5, (add1119, lv166), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm326 = R.call_tir(cls.layer_norm2, (add1123, model_decoder_layers_22_encoder_attn_layer_norm_weight4, model_decoder_layers_22_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv167 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_q_proj_weight4, layer_norm326, model_decoder_layers_22_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1259 = R.call_tir(cls.reshape14, (lv167,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1260 = R.call_tir(cls.reshape18, (reshape1259,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1)), reshape1260), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1261 = R.call_tir(cls.reshape16, (lv244,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1262 = R.call_tir(cls.reshape17, (reshape1261,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv168 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_22_encoder_attn_out_proj_weight4, reshape1262, model_decoder_layers_22_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1126 = R.call_tir(cls.add5, 
layer_norm328 = R.call_tir(cls.layer_norm2, (add1129, model_decoder_layers_23_self_attn_layer_norm_weight4, model_decoder_layers_23_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv170 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_q_proj_weight4, layer_norm328, model_decoder_layers_23_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1263 = R.call_tir(cls.reshape14, (lv170,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv55_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_23_self_attn_k_proj_weight4, layer_norm328), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1264 = R.call_tir(cls.reshape14, (lv55_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv171 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_v_proj_weight4, layer_norm328, model_decoder_layers_23_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1265 = R.call_tir(cls.reshape14, (lv171,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat87 = R.call_tir(cls.concatenate1, (reshape1263, reshape1264, reshape1265), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1266 = R.call_tir(cls.reshape15, (concat87,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv245 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape1266), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1267 = R.call_tir(cls.reshape16, (lv245,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1268 = R.call_tir(cls.reshape17, (reshape1267,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv172 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_self_attn_out_proj_weight4, reshape1268, model_decoder_layers_23_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1133 = R.call_tir(cls.add5, (add1129, lv172), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm329 = R.call_tir(cls.layer_norm2, (add1133, model_decoder_layers_23_encoder_attn_layer_norm_weight4, model_decoder_layers_23_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv173 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_q_proj_weight4, layer_norm329, model_decoder_layers_23_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1269 = R.call_tir(cls.reshape14, (lv173,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1270 = R.call_tir(cls.reshape18, (reshape1269,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv246 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1)), reshape1270), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1271 = R.call_tir(cls.reshape16, (lv246,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1272 = R.call_tir(cls.reshape17, (reshape1271,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv174 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_23_encoder_attn_out_proj_weight4, reshape1272, model_decoder_layers_23_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1136 = R.call_tir(cls.add5, (add1133, lv174), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm330 = R.call_tir(cls.layer_norm2, (add1136, model_decoder_layers_23_final_layer_norm_weight4, model_decoder_layers_23_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv23 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_23_fc1_weight4, layer_norm330, model_decoder_layers_23_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv175 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_23_fc2_weight4, lv23, model_decoder_layers_23_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1139 = R.call_tir(cls.add5, (add1136, lv175), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
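# Decoder layer 24.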
layer_norm331 = R.call_tir(cls.layer_norm2, (add1139, model_decoder_layers_24_self_attn_layer_norm_weight4, model_decoder_layers_24_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv176 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_q_proj_weight4, layer_norm331, model_decoder_layers_24_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1273 = R.call_tir(cls.reshape14, (lv176,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv56_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_24_self_attn_k_proj_weight4, layer_norm331), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1274 = R.call_tir(cls.reshape14, (lv56_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv177 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_v_proj_weight4, layer_norm331, model_decoder_layers_24_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1275 = R.call_tir(cls.reshape14, (lv177,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat88 = R.call_tir(cls.concatenate1, (reshape1273, reshape1274, reshape1275), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1276 = R.call_tir(cls.reshape15, (concat88,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv247 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape1276), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1277 = R.call_tir(cls.reshape16, (lv247,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1278 = R.call_tir(cls.reshape17, (reshape1277,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv178 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_self_attn_out_proj_weight4, reshape1278, model_decoder_layers_24_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1143 = R.call_tir(cls.add5, (add1139, lv178), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm332 = R.call_tir(cls.layer_norm2, (add1143, model_decoder_layers_24_encoder_attn_layer_norm_weight4, model_decoder_layers_24_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv179 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_q_proj_weight4, layer_norm332, model_decoder_layers_24_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1279 = R.call_tir(cls.reshape14, (lv179,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1280 = R.call_tir(cls.reshape18, (reshape1279,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv248 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1)), reshape1280), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1281 = R.call_tir(cls.reshape16, (lv248,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1282 = R.call_tir(cls.reshape17, (reshape1281,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv180 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_24_encoder_attn_out_proj_weight4, reshape1282, model_decoder_layers_24_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1146 = R.call_tir(cls.add5, (add1143, lv180), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm333 = R.call_tir(cls.layer_norm2, (add1146, model_decoder_layers_24_final_layer_norm_weight4, model_decoder_layers_24_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv24 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_24_fc1_weight4, layer_norm333, model_decoder_layers_24_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv181 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_24_fc2_weight4, lv24, model_decoder_layers_24_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1149 = R.call_tir(cls.add5, (add1146, lv181), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
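# Decoder layer 25.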
dtype="float16")) lv57_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_25_self_attn_k_proj_weight4, layer_norm334), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1284 = R.call_tir(cls.reshape14, (lv57_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv183 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_v_proj_weight4, layer_norm334, model_decoder_layers_25_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1285 = R.call_tir(cls.reshape14, (lv183,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat89 = R.call_tir(cls.concatenate1, (reshape1283, reshape1284, reshape1285), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1286 = R.call_tir(cls.reshape15, (concat89,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv249 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape1286), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1287 = R.call_tir(cls.reshape16, (lv249,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1288 = R.call_tir(cls.reshape17, (reshape1287,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv184 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_self_attn_out_proj_weight4, reshape1288, model_decoder_layers_25_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1153 = R.call_tir(cls.add5, (add1149, lv184), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm335 = R.call_tir(cls.layer_norm2, (add1153, model_decoder_layers_25_encoder_attn_layer_norm_weight4, model_decoder_layers_25_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv185 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_q_proj_weight4, layer_norm335, model_decoder_layers_25_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1289 = R.call_tir(cls.reshape14, (lv185,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1290 = R.call_tir(cls.reshape18, (reshape1289,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv250 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1)), reshape1290), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1291 = R.call_tir(cls.reshape16, (lv250,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1292 = R.call_tir(cls.reshape17, (reshape1291,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv186 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_25_encoder_attn_out_proj_weight4, reshape1292, model_decoder_layers_25_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1156 = R.call_tir(cls.add5, (add1153, lv186), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm336 = R.call_tir(cls.layer_norm2, (add1156, model_decoder_layers_25_final_layer_norm_weight4, model_decoder_layers_25_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv25 = 
R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_25_fc1_weight4, layer_norm336, model_decoder_layers_25_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv187 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_25_fc2_weight4, lv25, model_decoder_layers_25_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1159 = R.call_tir(cls.add5, (add1156, lv187), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm337 = R.call_tir(cls.layer_norm2, (add1159, model_decoder_layers_26_self_attn_layer_norm_weight4, model_decoder_layers_26_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv188 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_q_proj_weight4, layer_norm337, model_decoder_layers_26_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1293 = R.call_tir(cls.reshape14, (lv188,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv58_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_26_self_attn_k_proj_weight4, layer_norm337), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1294 = R.call_tir(cls.reshape14, (lv58_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) lv189 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_v_proj_weight4, layer_norm337, model_decoder_layers_26_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1295 = R.call_tir(cls.reshape14, (lv189,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat90 = R.call_tir(cls.concatenate1, (reshape1293, reshape1294, reshape1295), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1296 = R.call_tir(cls.reshape15, (concat90,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv251 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1)), reshape1296), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1297 = R.call_tir(cls.reshape16, (lv251,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1298 = R.call_tir(cls.reshape17, (reshape1297,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv190 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_self_attn_out_proj_weight4, reshape1298, model_decoder_layers_26_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1163 = R.call_tir(cls.add5, (add1159, lv190), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm338 = R.call_tir(cls.layer_norm2, (add1163, model_decoder_layers_26_encoder_attn_layer_norm_weight4, model_decoder_layers_26_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv191 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_26_encoder_attn_q_proj_weight4, layer_norm338, model_decoder_layers_26_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1299 = R.call_tir(cls.reshape14, (lv191,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1300 = R.call_tir(cls.reshape18, (reshape1299,), 
layer_norm340 = R.call_tir(cls.layer_norm2, (add1169, model_decoder_layers_27_self_attn_layer_norm_weight4, model_decoder_layers_27_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv194 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_q_proj_weight4, layer_norm340, model_decoder_layers_27_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1303 = R.call_tir(cls.reshape14, (lv194,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv59_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_27_self_attn_k_proj_weight4, layer_norm340), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1304 = R.call_tir(cls.reshape14, (lv59_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv195 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_v_proj_weight4, layer_norm340, model_decoder_layers_27_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1305 = R.call_tir(cls.reshape14, (lv195,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat91 = R.call_tir(cls.concatenate1, (reshape1303, reshape1304, reshape1305), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1306 = R.call_tir(cls.reshape15, (concat91,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv253 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape1306), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1307 = R.call_tir(cls.reshape16, (lv253,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1308 = R.call_tir(cls.reshape17, (reshape1307,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv196 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_self_attn_out_proj_weight4, reshape1308, model_decoder_layers_27_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1173 = R.call_tir(cls.add5, (add1169, lv196), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm341 = R.call_tir(cls.layer_norm2, (add1173, model_decoder_layers_27_encoder_attn_layer_norm_weight4, model_decoder_layers_27_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv197 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_q_proj_weight4, layer_norm341, model_decoder_layers_27_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1309 = R.call_tir(cls.reshape14, (lv197,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1310 = R.call_tir(cls.reshape18, (reshape1309,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv254 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1)), reshape1310), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1311 = R.call_tir(cls.reshape16, (lv254,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1312 = R.call_tir(cls.reshape17, (reshape1311,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv198_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_27_encoder_attn_out_proj_weight4, reshape1312, model_decoder_layers_27_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1176 = R.call_tir(cls.add5, (add1173, lv198_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm342 = R.call_tir(cls.layer_norm2, (add1176, model_decoder_layers_27_final_layer_norm_weight4, model_decoder_layers_27_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv27 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_27_fc1_weight4, layer_norm342, model_decoder_layers_27_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv199_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_27_fc2_weight4, lv27, model_decoder_layers_27_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1179 = R.call_tir(cls.add5, (add1176, lv199_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
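# Decoder layer 28.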
layer_norm343 = R.call_tir(cls.layer_norm2, (add1179, model_decoder_layers_28_self_attn_layer_norm_weight4, model_decoder_layers_28_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv200_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_q_proj_weight4, layer_norm343, model_decoder_layers_28_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1313 = R.call_tir(cls.reshape14, (lv200_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv60_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_28_self_attn_k_proj_weight4, layer_norm343), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1314 = R.call_tir(cls.reshape14, (lv60_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv201_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_v_proj_weight4, layer_norm343, model_decoder_layers_28_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1315 = R.call_tir(cls.reshape14, (lv201_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat92 = R.call_tir(cls.concatenate1, (reshape1313, reshape1314, reshape1315), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1316 = R.call_tir(cls.reshape15, (concat92,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv255 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape1316), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1317 = R.call_tir(cls.reshape16, (lv255,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1318 = R.call_tir(cls.reshape17, (reshape1317,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv202_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_self_attn_out_proj_weight4, reshape1318, model_decoder_layers_28_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1183 = R.call_tir(cls.add5, (add1179, lv202_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm344 = R.call_tir(cls.layer_norm2, (add1183, model_decoder_layers_28_encoder_attn_layer_norm_weight4, model_decoder_layers_28_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv203_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_q_proj_weight4, layer_norm344, model_decoder_layers_28_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1319 = R.call_tir(cls.reshape14, (lv203_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1320 = R.call_tir(cls.reshape18, (reshape1319,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv256 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1)), reshape1320), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1321 = R.call_tir(cls.reshape16, (lv256,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1322 = R.call_tir(cls.reshape17, (reshape1321,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv204_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_28_encoder_attn_out_proj_weight4, reshape1322, model_decoder_layers_28_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1186 = R.call_tir(cls.add5, (add1183, lv204_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm345 = R.call_tir(cls.layer_norm2, (add1186, model_decoder_layers_28_final_layer_norm_weight4, model_decoder_layers_28_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv28 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_28_fc1_weight4, layer_norm345, model_decoder_layers_28_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv205_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_28_fc2_weight4, lv28, model_decoder_layers_28_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1189 = R.call_tir(cls.add5, (add1186, lv205_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
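# Decoder layer 29.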
layer_norm346 = R.call_tir(cls.layer_norm2, (add1189, model_decoder_layers_29_self_attn_layer_norm_weight4, model_decoder_layers_29_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv206_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_q_proj_weight4, layer_norm346, model_decoder_layers_29_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1323 = R.call_tir(cls.reshape14, (lv206_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv61_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_29_self_attn_k_proj_weight4, layer_norm346), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1324 = R.call_tir(cls.reshape14, (lv61_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
lv207_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_v_proj_weight4, layer_norm346, model_decoder_layers_29_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1325 = R.call_tir(cls.reshape14, (lv207_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
concat93 = R.call_tir(cls.concatenate1, (reshape1323, reshape1324, reshape1325), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
reshape1326 = R.call_tir(cls.reshape15, (concat93,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
lv257 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1326), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1327 = R.call_tir(cls.reshape16, (lv257,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1328 = R.call_tir(cls.reshape17, (reshape1327,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv208_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_self_attn_out_proj_weight4, reshape1328, model_decoder_layers_29_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1193 = R.call_tir(cls.add5, (add1189, lv208_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm347 = R.call_tir(cls.layer_norm2, (add1193, model_decoder_layers_29_encoder_attn_layer_norm_weight4, model_decoder_layers_29_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv209_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_q_proj_weight4, layer_norm347, model_decoder_layers_29_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
reshape1329 = R.call_tir(cls.reshape14, (lv209_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1330 = R.call_tir(cls.reshape18, (reshape1329,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
lv258 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1)), reshape1330), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
reshape1331 = R.call_tir(cls.reshape16, (lv258,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
reshape1332 = R.call_tir(cls.reshape17, (reshape1331,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv210_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_29_encoder_attn_out_proj_weight4, reshape1332, model_decoder_layers_29_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1196 = R.call_tir(cls.add5, (add1193, lv210_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
layer_norm348 = R.call_tir(cls.layer_norm2, (add1196, model_decoder_layers_29_final_layer_norm_weight4, model_decoder_layers_29_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
lv29 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_29_fc1_weight4, layer_norm348, model_decoder_layers_29_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
lv211_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_29_fc2_weight4, lv29, model_decoder_layers_29_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
add1199 = R.call_tir(cls.add5, (add1196, lv211_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
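# Decoder layer 30.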
            layer_norm349 = R.call_tir(cls.layer_norm2, (add1199, model_decoder_layers_30_self_attn_layer_norm_weight4, model_decoder_layers_30_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv212_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_q_proj_weight4, layer_norm349, model_decoder_layers_30_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1333 = R.call_tir(cls.reshape14, (lv212_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            lv62_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_30_self_attn_k_proj_weight4, layer_norm349), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1334 = R.call_tir(cls.reshape14, (lv62_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            lv213_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_v_proj_weight4, layer_norm349, model_decoder_layers_30_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1335 = R.call_tir(cls.reshape14, (lv213_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            concat94 = R.call_tir(cls.concatenate1, (reshape1333, reshape1334, reshape1335), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
            reshape1336 = R.call_tir(cls.reshape15, (concat94,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
            lv259 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1336), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            reshape1337 = R.call_tir(cls.reshape16, (lv259,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1338 = R.call_tir(cls.reshape17, (reshape1337,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv214_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_self_attn_out_proj_weight4, reshape1338, model_decoder_layers_30_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1203 = R.call_tir(cls.add5, (add1199, lv214_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            layer_norm350 = R.call_tir(cls.layer_norm2, (add1203, model_decoder_layers_30_encoder_attn_layer_norm_weight4, model_decoder_layers_30_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv215_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_q_proj_weight4, layer_norm350, model_decoder_layers_30_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1339 = R.call_tir(cls.reshape14, (lv215_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1340 = R.call_tir(cls.reshape18, (reshape1339,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            lv260 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1)), reshape1340), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            reshape1341 = R.call_tir(cls.reshape16, (lv260,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1342 = R.call_tir(cls.reshape17, (reshape1341,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv216_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_30_encoder_attn_out_proj_weight4, reshape1342, model_decoder_layers_30_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1206 = R.call_tir(cls.add5, (add1203, lv216_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            layer_norm351 = R.call_tir(cls.layer_norm2, (add1206, model_decoder_layers_30_final_layer_norm_weight4, model_decoder_layers_30_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv30 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_30_fc1_weight4, layer_norm351, model_decoder_layers_30_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
            lv217_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_30_fc2_weight4, lv30, model_decoder_layers_30_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1209 = R.call_tir(cls.add5, (add1206, lv217_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            layer_norm352 = R.call_tir(cls.layer_norm2, (add1209, model_decoder_layers_31_self_attn_layer_norm_weight4, model_decoder_layers_31_self_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv218_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_q_proj_weight4, layer_norm352, model_decoder_layers_31_self_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1343 = R.call_tir(cls.reshape14, (lv218_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            lv63_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul1_cublas", (model_decoder_layers_31_self_attn_k_proj_weight4, layer_norm352), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1344 = R.call_tir(cls.reshape14, (lv63_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            lv219_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_v_proj_weight4, layer_norm352, model_decoder_layers_31_self_attn_v_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1345 = R.call_tir(cls.reshape14, (lv219_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            concat95 = R.call_tir(cls.concatenate1, (reshape1343, reshape1344, reshape1345), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16"))
            reshape1346 = R.call_tir(cls.reshape15, (concat95,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16"))
            lv261 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1346), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            reshape1347 = R.call_tir(cls.reshape16, (lv261,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1348 = R.call_tir(cls.reshape17, (reshape1347,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv220_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_out_proj_weight4, reshape1348, model_decoder_layers_31_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1213 = R.call_tir(cls.add5, (add1209, lv220_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            layer_norm353 = R.call_tir(cls.layer_norm2, (add1213, model_decoder_layers_31_encoder_attn_layer_norm_weight4, model_decoder_layers_31_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv221_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight4, layer_norm353, model_decoder_layers_31_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            reshape1349 = R.call_tir(cls.reshape14, (lv221_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1350 = R.call_tir(cls.reshape18, (reshape1349,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            lv262 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1350), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16"))
            reshape1351 = R.call_tir(cls.reshape16, (lv262,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16"))
            reshape1352 = R.call_tir(cls.reshape17, (reshape1351,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv222_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight4, reshape1352, model_decoder_layers_31_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1216 = R.call_tir(cls.add5, (add1213, lv222_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            layer_norm354 = R.call_tir(cls.layer_norm2, (add1216, model_decoder_layers_31_final_layer_norm_weight4, model_decoder_layers_31_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            lv31 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_31_fc1_weight4, layer_norm354, model_decoder_layers_31_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16"))
            lv223_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_31_fc2_weight4, lv31, model_decoder_layers_31_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
            add1219 = R.call_tir(cls.add5, (add1216, lv223_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16"))
dtype="float16")) reshape1345 = R.call_tir(cls.reshape14, (lv219_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) concat95 = R.call_tir(cls.concatenate1, (reshape1343, reshape1344, reshape1345), out_sinfo=R.Tensor((1, seq_len, 60, 64), dtype="float16")) reshape1346 = R.call_tir(cls.reshape15, (concat95,), out_sinfo=R.Tensor((seq_len, 60, 64), dtype="float16")) lv261 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1346), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1347 = R.call_tir(cls.reshape16, (lv261,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1348 = R.call_tir(cls.reshape17, (reshape1347,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv220_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_self_attn_out_proj_weight4, reshape1348, model_decoder_layers_31_self_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1213 = R.call_tir(cls.add5, (add1209, lv220_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm353 = R.call_tir(cls.layer_norm2, (add1213, model_decoder_layers_31_encoder_attn_layer_norm_weight4, model_decoder_layers_31_encoder_attn_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv221_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_q_proj_weight4, layer_norm353, model_decoder_layers_31_encoder_attn_q_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) reshape1349 = R.call_tir(cls.reshape14, (lv221_1,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1350 = R.call_tir(cls.reshape18, (reshape1349,), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) lv262 = R.call_dps_packed("vm.builtin.attention_kv_cache_cross_attention", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1)), reshape1350), out_sinfo=R.Tensor((seq_len, 20, 64), dtype="float16")) reshape1351 = R.call_tir(cls.reshape16, (lv262,), out_sinfo=R.Tensor((1, seq_len, 20, 64), dtype="float16")) reshape1352 = R.call_tir(cls.reshape17, (reshape1351,), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv222_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add1_cublas", (model_decoder_layers_31_encoder_attn_out_proj_weight4, reshape1352, model_decoder_layers_31_encoder_attn_out_proj_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1216 = R.call_tir(cls.add5, (add1213, lv222_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm354 = R.call_tir(cls.layer_norm2, (add1216, model_decoder_layers_31_final_layer_norm_weight4, model_decoder_layers_31_final_layer_norm_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) lv31 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_gelu_cublas", (model_decoder_layers_31_fc1_weight4, layer_norm354, model_decoder_layers_31_fc1_bias4), out_sinfo=R.Tensor((1, seq_len, 5120), dtype="float16")) lv223_1 = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add2_cublas", (model_decoder_layers_31_fc2_weight4, lv31, model_decoder_layers_31_fc2_bias4), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) add1219 = R.call_tir(cls.add5, (add1216, lv223_1), out_sinfo=R.Tensor((1, seq_len, 1280), dtype="float16")) layer_norm355 = R.call_tir(cls.layer_norm2, (add1219, 
    @R.function
    def renormalize_by_top_p(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), top_p: R.Tensor(("batch_size",), dtype="float32"), init_pivots: R.Tensor(("batch_size", 3), dtype="float32")) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}})
        cls = Module
        with R.dataflow():
            lv6 = R.call_tir(cls.top_p_pivot_cutoff, (probs, top_p, init_pivots), out_sinfo=[R.Tensor((batch_size,), dtype="float32"), R.Tensor((batch_size,), dtype="float32")])
            lv7: R.Tensor((batch_size,), dtype="float32") = lv6[0]
            lv8: R.Tensor((batch_size,), dtype="float32") = lv6[1]
            gv5 = R.call_tir(cls.top_p_renorm_after_cutoff, (probs, lv7, lv8), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            R.output(gv5)
        return gv5
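    # `sample_with_top_p` below draws tokens from rows of pre-sorted
    # probabilities: it reshapes the per-sample inputs to column vectors,
    # computes a cumulative sum over `sorted_probs` (with a scratch buffer from
    # `alloc_tensor`), derives each row's top-p renormalization constant
    # (`get_renorm_prob`), and maps the uniform samples through the cumulative
    # distribution to token ids (`get_index_from_sorted`). Note that the
    # generated name `sample_indices2` actually binds the reshaped `top_p`
    # values. A NumPy model of the top-p semantics appears at the end of this
    # file.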
    @R.function
    def sample_with_top_p(sorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), top_p: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("num_samples",), dtype="int32"):
        num_samples = T.int64()
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}})
        cls = Module
        with R.dataflow():
            uniform_samples1: R.Tensor((num_samples, 1), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", uniform_samples, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),))
            sample_indices1: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", sample_indices, R.shape([num_samples, 1]), sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),))
            sample_indices2: R.Tensor((batch_size, 1), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", top_p, R.shape([batch_size, 1]), sinfo_args=(R.Tensor((batch_size, 1), dtype="float32"),))
            lv3 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((batch_size, 1), dtype="int32"), tir_vars=R.shape([vocab_size]))
            lv1: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]), R.dtype("uint8"), R.prim_value(0), R.str("global"))
            cumsum = R.call_tir(cls.cumsum, (sorted_probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            lv4 = R.call_tir(cls.get_renorm_prob, (cumsum, sample_indices2, lv3), out_sinfo=R.Tensor((batch_size, 1), dtype="float32"))
            lv5 = R.call_tir(cls.get_index_from_sorted, (cumsum, sorted_indices, lv4, uniform_samples1, sample_indices1), out_sinfo=R.Tensor((num_samples, 1), dtype="int32"))
            gv2: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv5, R.shape([num_samples]), sinfo_args=(R.Tensor((num_samples,), dtype="int32"),))
            R.output(gv2)
        return gv2

    @R.function
    def sampler_take_probs(unsorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), sampling_result: R.Tensor(("num_samples",), dtype="int32"), lobprob_offsets: R.Tensor(("num_positions",), dtype="int32")) -> R.Tuple(R.Tensor(("num_samples",), dtype="float32"), R.Tensor(("num_positions",), dtype="float32"), R.Tensor(("num_positions",), dtype="int32")):
        num_samples = T.int64()
        num_positions = T.int64()
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}})
        cls = Module
        with R.dataflow():
            gv3 = R.call_tir(cls.sampler_take_probs_tir, (unsorted_probs, sorted_indices, sample_indices, sampling_result, lobprob_offsets), out_sinfo=[R.Tensor((num_samples,), dtype="float32"), R.Tensor((num_positions,), dtype="float32"), R.Tensor((num_positions,), dtype="int32")])
            R.output(gv3)
        return gv3

    @R.function
    def sampler_verify_draft_tokens(draft_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), draft_tokens: R.Tensor(("num_nodes",), dtype="int32"), model_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), token_tree_first_child: R.Tensor(("num_nodes",), dtype="int32"), token_tree_next_sibling: R.Tensor(("num_nodes",), dtype="int32"), uniform_samples: R.Tensor(("num_nodes",), dtype="float32"), token_tree_parent_ptr: R.Tensor(("nbatch",), dtype="int32")) -> R.Tuple(R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), R.Tensor(("nbatch",), dtype="int32")):
        num_nodes = T.int64()
        vocab_size = T.int64()
        nbatch = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "num_positions": 48, "num_samples": 8}})
        cls = Module
        with R.dataflow():
            gv4: R.Tuple(R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")) = R.call_tir_inplace(cls.batch_verify_on_gpu_single_kernel, (draft_probs, draft_tokens, model_probs, token_tree_first_child, token_tree_next_sibling, uniform_samples, token_tree_parent_ptr), out_sinfo=[R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")], inplace_indices=[2, 6])
            R.output(gv4)
        return gv4
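    # `softmax_with_temperature` below is a two-pass, numerically stable
    # softmax: `chunk_lse` reduces each row in 4096-wide chunks to two
    # (batch_size, ceil(vocab_size / 4096)) tensors of per-chunk statistics,
    # and `softmax_with_chunked_sum` combines them, so no kernel ever reduces a
    # full vocab_size row at once. See the NumPy reference sketch at the end of
    # this file.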
    @R.function
    def softmax_with_temperature(logits: R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), temperature: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"):
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"relax.memory_plan_dynamic_func_output": 1, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 8, "seq_len": 15000, "total_seq_len": 1500}})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", logits, R.shape([batch_size, vocab_size]), sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),))
            lv1 = R.call_tir(cls.chunk_lse, (lv, temperature), out_sinfo=[R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32"), R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32")])
            lv2: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[0]
            lv3: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[1]
            lv4 = R.call_tir(cls.softmax_with_chunked_sum, (lv, temperature, lv2, lv3), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            gv: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", lv4, R.shape([batch_size, 1, vocab_size]), sinfo_args=(R.Tensor((batch_size, 1, vocab_size), dtype="float32"),))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
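# ---------------------------------------------------------------------------
# Reference sketch (not part of the generated module). The NumPy code below is
# an illustrative model of what `renormalize_by_top_p` and
# `softmax_with_temperature` compute, written from the call structure above.
# It assumes temperature > 0 and a 4096-wide chunk (matching the `chunk_lse`
# out_sinfo); it is not the compiled TIR kernels themselves, and the kernels'
# pivot-search strategy may differ from the plain sort used here.
# ---------------------------------------------------------------------------
import numpy as np


def renormalize_by_top_p_ref(probs: np.ndarray, top_p: np.ndarray) -> np.ndarray:
    """Keep the smallest set of highest-probability tokens whose cumulative
    mass reaches top_p, zero the rest, and renormalize each row to sum to 1."""
    out = np.zeros_like(probs)
    for b in range(probs.shape[0]):
        order = np.argsort(-probs[b])            # token ids, descending prob
        csum = np.cumsum(probs[b][order])
        # First prefix whose cumulative mass reaches top_p (clamped so that
        # float rounding at top_p == 1.0 cannot overrun the row).
        k = min(int(np.searchsorted(csum, top_p[b])) + 1, probs.shape[1])
        keep = order[:k]
        out[b, keep] = probs[b, keep]
        out[b] /= out[b].sum()
    return out


def softmax_with_temperature_ref(logits: np.ndarray, temperature: np.ndarray, chunk: int = 4096) -> np.ndarray:
    """Two-pass chunked softmax: per-chunk max and sum-of-exp, then combine."""
    x = logits[:, 0, :].astype(np.float32) / temperature[:, None]
    b, v = x.shape
    nchunk = (v + chunk - 1) // chunk            # == (vocab_size + 4096 - 1) // 4096
    pad = np.full((b, nchunk * chunk), -np.inf, dtype=np.float32)
    pad[:, :v] = x
    pad = pad.reshape(b, nchunk, chunk)
    cmax = pad.max(axis=-1)                                  # pass 1: chunk maxima
    csum = np.exp(pad - cmax[:, :, None]).sum(axis=-1)       # pass 1: chunk sums
    gmax = cmax.max(axis=-1, keepdims=True)
    denom = (csum * np.exp(cmax - gmax)).sum(axis=-1, keepdims=True)
    return (np.exp(x - gmax) / denom)[:, None, :]            # pass 2: combine


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(2, 1, 10)).astype(np.float32)
    temperature = np.array([1.0, 0.8], dtype=np.float32)
    probs = softmax_with_temperature_ref(logits, temperature)
    renorm = renormalize_by_top_p_ref(probs[:, 0, :], np.array([0.9, 0.5], dtype=np.float32))
    assert np.allclose(probs.sum(-1), 1.0, atol=1e-5)
    assert np.allclose(renorm.sum(-1), 1.0, atol=1e-5)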