Shape incompatibility when `past_key_values` are made persistent for context retention

#2
by bekatan - opened

I am running phi3-mini-4k-instruct-fp16 to build a chatbot. To give the LLM the context of the previous conversation, I used `pretrained_tokenizer.apply_chat_template(input)` to create the `input_ids`, where `input` is an array of all previous messages, and reset `past_key_values` to empty on every new `llm.run(...)` call. This let the model retain information about the previous messages.

However, I was wondering whether I could get the same effect by putting only the user's most recent prompt in `input_ids` and reusing the `past_key_values` from the previous response. I tried it here:

```typescript
async run(input: any, callback?: (output: string) => void) {
  // Encode only the most recent message; earlier turns are expected to live in the KV cache.
  const prompt = this.tokenizer.apply_chat_template([input[input.length - 1]], { tokenize: false });
  const tokens = await this.tokenizer.encode(prompt);
  const input_ids = new Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
  const output_tokens = [...input_ids.data];
  let last_token = 0n; // reassigned inside the loop, so it cannot be const
  this.feed['input_ids'] = input_ids;
  let seqlen = output_tokens.length;

  // Past sequence length carried over from the previous call (0 on the first call).
  const extra = this.feed['past_key_values.0.key'].dims[2];
  // Positions of the new prompt tokens continue after the cached tokens.
  this.feed['position_ids'] = new Tensor('int64', BigInt64Array.from({ length: seqlen }, (_, i) => BigInt(extra + i)), [1, seqlen]);
  this.feed['attention_mask'] = new Tensor(BigInt64Array.from({ length: seqlen + extra }, () => 1n), [1, seqlen + extra]);

  while (last_token !== this.eos && last_token !== 32007n && seqlen < this.max_seq_len) {
    seqlen = output_tokens.length;
    const outputs = await runSession(this.feed, this.onnxSession);
    last_token = BigInt(argmax(outputs.logits));
    output_tokens.push(last_token);

    if (callback) {
      // Stream the text generated so far, excluding the prompt and a trailing EOS.
      const endIndex = last_token === this.eos ? -1 : output_tokens.length;
      const newString = this.tokenizer.decode(output_tokens.slice(tokens.length, endIndex).map(t => Number(t)));
      callback(newString);
    }

    // Helper not shown here; expected to copy the present.* outputs into the past_key_values.* inputs.
    update_kv_cache(this.feed, outputs);
    // The next step feeds only the newly generated token.
    this.feed['input_ids'] = new Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
    this.feed['position_ids'] = new Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
    this.feed['attention_mask'] = new Tensor(BigInt64Array.from({ length: seqlen + 1 }, () => 1n), [1, seqlen + 1]);
  }

  // Final text, dropping the prompt tokens and the last (stop) token.
  const output = this.tokenizer.decode(output_tokens.slice(tokens.length, -1).map(t => Number(t)));

  return output;
}
```
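To see which lengths each `session.run()` step would need when the cache carries over between calls, the bookkeeping can be simulated without a model. This is a hypothetical sketch (`planSteps` and `StepFeed` are illustrative names, not part of onnxruntime-web or transformers.js), assuming that `attention_mask` must cover past plus current tokens and that `position_ids` continue from the past length, as the dims listed in this post suggest.

```typescript
// Hypothetical sketch: per-step feed lengths for a greedy decode loop whose KV cache
// persists across calls. Assumes attention_mask covers past + current tokens and
// position_ids continue from the past length. No model or ONNX Runtime is involved.
interface StepFeed {
  positionIds: number[]; // values for position_ids, shape [1, positionIds.length]
  maskLength: number;    // attention_mask shape is [1, maskLength]
  pastLenAfter: number;  // past sequence length of this step's present.* outputs
}

function planSteps(pastLen: number, promptLen: number, newTokens: number): StepFeed[] {
  const steps: StepFeed[] = [];
  let past = pastLen;  // dims[2] of past_key_values.*.key before this step
  let cur = promptLen; // the first step feeds the whole new prompt
  for (let step = 0; step < newTokens; step++) {
    steps.push({
      positionIds: Array.from({ length: cur }, (_, i) => past + i),
      maskLength: past + cur,
      pastLenAfter: past + cur,
    });
    past += cur; // the cache grows by the number of tokens just fed
    cur = 1;     // later steps feed only the last generated token
  }
  return steps;
}

// Second turn: 100 cached tokens, a 5-token prompt, 3 generated tokens.
for (const s of planSteps(100, 5, 3)) {
  console.log(s.positionIds, s.maskLength);
}
```

For that second turn it prints positions 100 to 104 with a mask length of 105, then position 105 with 106, then 106 with 107. Comparing these numbers with the lengths the loop above computes on each step is one way to see which feed drifts.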

The first call to `llm.run(...)`, where `past_key_values` is initially empty, works fine.
On the second call, the tensors in the feed have these dimensions:

```
input_ids dims: [1, seq_length]
position_ids dims: [1, seq_length]
attention_mask dims: [1, seq_length + past_sequence_length]
past_key_values.i.key dims: [1, 32, past_sequence_length, 96]
past_key_values.i.value dims: [1, 32, past_sequence_length, 96]
```

and the run fails with:

```
Error: [WebGPU] Kernel "[Expand] /model/attn_mask_reformat/input_ids_subgraph/Expand" failed. Error: Expand requires shape to be broadcastable to input
    at Object._OrtRun (:8888/node_modules/onnxruntime-web/lib/wasm/binding/ort-wasm-simd.jsep.js:9:401)
    at zd (:8888/node_modules/onnxruntime-web/lib/wasm/wasm-core-impl.ts:562:19)
    at fi.run (:8888/node_modules/onnxruntime-web/lib/wasm/session-handler-inference.ts:109:21)
    at e.run (:8888/node_modules/common/lib/inference-session-impl.ts:110:21)
```
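ONNX's `Expand` operator broadcasts its input to a requested shape with numpy-style rules: shapes are aligned from the right, and each pair of dimensions must be equal or one of them must be 1. The sketch below (`expandOutputShape` is an illustrative name) applies that rule to two shapes and returns `null` where ONNX Runtime would raise the error above. Which shapes the `attn_mask_reformat/input_ids_subgraph/Expand` node actually receives is not visible from the trace and is not assumed here.

```typescript
// Sketch of the numpy-style rule ONNX Expand applies between the input shape and
// the requested shape: align from the right; each pair must match or one side is 1.
function expandOutputShape(input: readonly number[], shape: readonly number[]): number[] | null {
  const out: number[] = [];
  const n = Math.max(input.length, shape.length);
  for (let i = 1; i <= n; i++) {
    const a = input[input.length - i] ?? 1; // missing leading dims behave like 1
    const b = shape[shape.length - i] ?? 1;
    if (a !== b && a !== 1 && b !== 1) {
      return null; // not broadcastable: "Expand requires shape to be broadcastable to input"
    }
    out.unshift(a === 1 ? b : a);
  }
  return out;
}

console.log(expandOutputShape([1, 1], [1, 7])); // [ 1, 7 ]
console.log(expandOutputShape([1, 5], [1, 7])); // null
```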

I believe this error indicates a shape incompatibility, but I don't understand which tensor is at fault. Again, if `past_key_values` is initially empty, everything works. Can I use `past_key_values` for context retention?
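The relationships between the dims listed above can be checked mechanically for any single feed. This is a hypothetical helper (`checkFeedDims` and `FeedDims` are illustrative names), assuming the shapes the post lists: `position_ids` matches `input_ids`, `attention_mask` spans past plus current tokens, and key and value caches share a shape. It only inspects dims, so it can be given the `.dims` of each tensor logged right before `runSession`.

```typescript
// Hypothetical helper: compares the dims of one feed against the relationships
// listed in the post. Shapes only; the tensor data is not needed.
type Dims = readonly number[];

interface FeedDims {
  input_ids: Dims;      // [1, seq_length]
  position_ids: Dims;   // [1, seq_length]
  attention_mask: Dims; // [1, seq_length + past_sequence_length]
  past_key: Dims;       // [1, 32, past_sequence_length, 96]
  past_value: Dims;     // same as past_key
}

const fmt = (d: Dims): string => `[${d.join(", ")}]`;

function checkFeedDims(f: FeedDims): string[] {
  const problems: string[] = [];
  const [batch, seqLen] = f.input_ids;
  const pastLen = f.past_key[2]; // past_sequence_length

  if (fmt(f.position_ids) !== fmt([batch, seqLen])) {
    problems.push(`position_ids ${fmt(f.position_ids)} should be ${fmt([batch, seqLen])}`);
  }
  const mask = [batch, pastLen + seqLen];
  if (fmt(f.attention_mask) !== fmt(mask)) {
    problems.push(`attention_mask ${fmt(f.attention_mask)} should be ${fmt(mask)}`);
  }
  if (fmt(f.past_value) !== fmt(f.past_key)) {
    problems.push(`past value ${fmt(f.past_value)} does not match past key ${fmt(f.past_key)}`);
  }
  return problems;
}

// Example: a single-token decode step with 105 cached tokens.
console.log(checkFeedDims({
  input_ids: [1, 1],
  position_ids: [1, 1],
  attention_mask: [1, 6],
  past_key: [1, 32, 105, 96],
  past_value: [1, 32, 105, 96],
}));
```

An empty array means the feed is consistent with those relationships; each string names the tensor whose dims disagree.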
