codelion committed on
Commit
f87f20f
1 Parent(s): c8fce50

Update app.py

Files changed (1)
  1. app.py +32 -31
app.py CHANGED
@@ -29,7 +29,7 @@ if torch.cuda.is_available():
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.padding_side = 'right'
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+    # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
     # tokenizer.use_default_system_prompt = False

 @spaces.GPU(duration=60)
@@ -50,38 +50,39 @@ def generate(
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})

-    prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p,
-                   eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
-
-    return outputs[0]['generated_text'][len(prompt):].strip()
-    # input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    # input_ids = input_ids.to(model.device)
-
-    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    # generate_kwargs = dict(
-    #     {"input_ids": input_ids},
-    #     streamer=streamer,
-    #     max_new_tokens=max_new_tokens,
-    #     do_sample=True,
-    #     top_p=top_p,
-    #     #top_k=top_k,
-    #     temperature=temperature,
-    #     eos_token_id=tokenizer.eos_token_id,
-    #     pad_token_id=tokenizer.pad_token_id,
+    # prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+    # outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p,
+    #                eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
+
+    # return outputs[0]['generated_text'][len(prompt):].strip()
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        #top_k=top_k,
+        temperature=temperature,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
         #num_beams=1,
         #repetition_penalty=1.2,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)

 example1='''You are a senior software engineer who is best in the world at fixing vulnerabilities.
 Users will give you vulnerable code and you will generate a fix based on the provided INSTRUCTION.
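
For reference, the net effect of this commit is to switch generate() from the one-shot pipeline("text-generation", ...) call to token streaming: the prompt is tokenized with tokenizer.apply_chat_template, trimmed to MAX_INPUT_TOKEN_LENGTH, model.generate runs on a background Thread, and partial text is yielded from a TextIteratorStreamer. The sketch below shows that pattern in isolation; the model id, token limit, and sampling defaults are placeholder assumptions, not the Space's actual configuration.

# Minimal sketch of the streaming pattern this commit enables.
# MODEL_ID and MAX_INPUT_TOKEN_LENGTH are placeholders, not the Space's values.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumption: any small model with a chat template
MAX_INPUT_TOKEN_LENGTH = 4096                    # assumption: illustrative limit only

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def stream_reply(conversation, max_new_tokens=256, temperature=0.7, top_p=0.9):
    """Yield the assistant reply incrementally as the model generates it."""
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    # Keep only the most recent tokens if the conversation exceeds the limit.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    # The streamer is an iterator that receives decoded text while generate() runs.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks until finished, so run it on a worker thread and consume here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    pieces = []
    for text in streamer:
        pieces.append(text)
        yield "".join(pieces)

if __name__ == "__main__":
    chat = [{"role": "user", "content": "Explain SQL injection in one sentence."}]
    reply = ""
    for reply in stream_reply(chat):
        pass  # a Gradio callback would yield each partial string to the UI instead
    print(reply)

Running generate() on a worker thread is what lets the caller iterate the streamer and yield a growing string, which is how the Space's Gradio chat box can update while tokens are still being produced.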