lucidmorto committed on
Commit
3f7af4c
1 Parent(s): b7588d3

feat: Improve text generation with advanced parameters

Browse files

Enhanced the text generation function to preprocess input text (stripping whitespace), switch the task prefix from "summarize:" to "humanize:", and optimize output generation with advanced sampling parameters: top-k sampling, top-p (nucleus) sampling, and temperature, plus a minimum output length. This increases the quality and variability of generated text while enforcing a stricter no-repeat n-gram constraint.

Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -6,9 +6,29 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
6
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
  def generate_text(input_text):
9
- input_ids = tokenizer("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True).input_ids
10
- outputs = model.generate(input_ids, max_length=300, num_return_sequences=1, no_repeat_ngram_size=2)
11
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  iface = gr.Interface(
14
  fn=generate_text,
 
6
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
def generate_text(input_text):
    """Generate a "humanized" rewrite of *input_text* with the seq2seq model.

    Parameters
    ----------
    input_text : str
        Raw user text; surrounding whitespace is stripped before use.

    Returns
    -------
    str
        The decoded, whitespace-stripped model output. An empty or
        whitespace-only input yields an empty string.
    """
    # Normalize the prompt; bail out early on empty input so we never
    # feed the model a bare task prefix.
    input_text = input_text.strip()
    if not input_text:
        return ""

    # Build the model input with the task prefix, truncating long
    # prompts to the 512-token encoder window.
    input_ids = tokenizer.encode(
        "humanize: " + input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Sample a single sequence. NOTE: the previous version also passed
    # early_stopping=True, but that flag only applies to beam search
    # (num_beams > 1); combined with do_sample=True it is ignored and
    # makes transformers emit a warning, so it is dropped here.
    outputs = model.generate(
        input_ids,
        max_length=300,
        min_length=30,            # avoid degenerate one-liner outputs
        num_return_sequences=1,
        no_repeat_ngram_size=3,   # forbid repeated trigrams
        top_k=50,                 # restrict sampling to 50 most likely tokens
        top_p=0.95,               # nucleus sampling probability mass
        temperature=0.8,          # slightly conservative randomness
        do_sample=True,
    )

    # Decode the first (only) sequence and clean up the generated text.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.strip()
32
 
33
  iface = gr.Interface(
34
  fn=generate_text,