codelion committed
Commit 904cc64
1 Parent(s): 518cdca

Update app.py

Files changed (1):
  app.py (+34, -28)
app.py CHANGED
@@ -28,9 +28,10 @@ if torch.cuda.is_available():
     model_id = "patched-codes/patched-mix-4x7B"
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
-
+    tokenizer.padding_side = 'right'
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+    # tokenizer.use_default_system_prompt = False
+
 @spaces.GPU(duration=60)
 def generate(
     message: str,
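
A side note on the unchanged load line: newer transformers releases deprecate passing load_in_4bit directly to from_pretrained in favor of an explicit BitsAndBytesConfig. A minimal sketch of the equivalent setup (requires bitsandbytes; model id taken from the diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "patched-codes/patched-mix-4x7B"

# Equivalent to load_in_4bit=True, expressed as an explicit quantization config
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",              # spread layers across available devices
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"    # matches the setting added in this commit
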
@@ -49,33 +50,38 @@ def generate(
     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        #top_k=top_k,
-        temperature=temperature,
-        eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.pad_token_id,
+    prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+    outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p,
+                   eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
+
+    return outputs[0]['generated_text'][len(prompt):].strip()
+    # input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    # input_ids = input_ids.to(model.device)
+
+    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    # generate_kwargs = dict(
+    #     {"input_ids": input_ids},
+    #     streamer=streamer,
+    #     max_new_tokens=max_new_tokens,
+    #     do_sample=True,
+    #     top_p=top_p,
+    #     #top_k=top_k,
+    #     temperature=temperature,
+    #     eos_token_id=tokenizer.eos_token_id,
+    #     pad_token_id=tokenizer.pad_token_id,
         #num_beams=1,
         #repetition_penalty=1.2,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # )
+    # t = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t.start()
+
+    # outputs = []
+    # for text in streamer:
+    #     outputs.append(text)
+    #     yield "".join(outputs)
 
 example1='''You are a senior software engineer who is best in the world at fixing vulnerabilities.
 Users will give you vulnerable code and you will generate a fix based on the provided INSTRUCTION.
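
For reference, the streaming path this commit comments out follows the standard TextIteratorStreamer pattern: model.generate runs in a worker thread while the caller iterates over decoded chunks. A minimal sketch assuming model and tokenizer are already loaded; stream_generate is a hypothetical helper name:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, input_ids, max_new_tokens=256):
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # model.generate blocks until finished, so run it in a background thread
    # and consume decoded chunks from the streamer on this thread
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:  # yields text pieces as tokens are produced
        outputs.append(text)
        yield "".join(outputs)  # cumulative text, as Gradio expects when streaming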
 
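Putting the added lines together, the replacement path is a single non-streaming pipeline call: the chat history is rendered to a prompt string via the model's chat template, generated in one shot, and the echoed prompt is sliced off the output. A minimal sketch assuming model and tokenizer are loaded as in the diff; the conversation content is illustrative:

from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

conversation = [
    {"role": "user", "content": "Fix the SQL injection in this query."},  # illustrative turn
]

# Render the chat to one prompt string; add_generation_prompt=True appends
# the assistant header so the model continues as the assistant.
prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

outputs = pipe(
    prompt,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=pipe.tokenizer.eos_token_id,
    pad_token_id=pipe.tokenizer.pad_token_id,
)

# The pipeline returns prompt + completion; keep only the completion.
reply = outputs[0]["generated_text"][len(prompt):].strip()
print(reply)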