mtc committed on
Commit: cadb0dd
Parent: e34bc0e

Update README.md

Files changed (1)
  README.md +5 -4
README.md CHANGED

@@ -58,13 +58,13 @@ def predict_with_vllm(prompts: List[str], model_name: str, max_context_length: i
     return predictions


-def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str, max_new_tokens: int = 256,
-                                        batch_size: int = 2):
+def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str, max_context_length: int = 4096,
+                                        batch_size: int = 2) -> List[str]:
     text_generation_pipeline = pipeline("text-generation", model=model_name,
                                         model_kwargs={"torch_dtype": torch.float16}, device_map="auto",
                                         batch_size=batch_size)

-    batch_output = text_generation_pipeline(prompts, truncation=True, max_new_tokens=max_new_tokens,
+    batch_output = text_generation_pipeline(prompts, truncation=True, max_length=max_context_length,
                                             return_full_text=False)
     predictions = [result[0]['generated_text'] for result in batch_output]
     return predictions
@@ -89,7 +89,8 @@ Satz: {sentence}
 ### Erklärung und Label:"""

 prompts = generate_prompts_for_generation(prompt_template=prompt_template, article=article, summary_sentences=summary_sentences)
-predictions = predict_with_hf_generation_pipeline(prompts=prompts, model_name=model_name, max_context_length=max_context_length)
+predictions = predict_with_hf_generation_pipeline(prompts=prompts, model_name=model_name,
+                                                  max_context_length=max_context_length, batch_size=batch_size)
 print(predictions)

 # Uncomment the following lines to use vllm for prediction
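In effect, the commit replaces the helper's `max_new_tokens` argument (which caps only the newly generated tokens) with `max_context_length`, forwarded to the pipeline as `max_length` (an overall budget for prompt plus generation, with `truncation=True` so over-long prompts are trimmed), and the call site now also forwards `batch_size`. Below is a minimal, self-contained sketch of the function as of this commit; the imports and the commented-out call are assumptions based on the surrounding README, not part of the diff itself.

```python
# Sketch of the updated helper; imports are assumed from the full README.
from typing import List

import torch
from transformers import pipeline


def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str,
                                         max_context_length: int = 4096,
                                         batch_size: int = 2) -> List[str]:
    # Build a Hugging Face text-generation pipeline with fp16 weights,
    # automatic device placement, and batched inference.
    text_generation_pipeline = pipeline("text-generation", model=model_name,
                                        model_kwargs={"torch_dtype": torch.float16},
                                        device_map="auto", batch_size=batch_size)
    # max_length bounds prompt + generated tokens; truncation trims long prompts.
    batch_output = text_generation_pipeline(prompts, truncation=True,
                                            max_length=max_context_length,
                                            return_full_text=False)
    return [result[0]["generated_text"] for result in batch_output]


# Hypothetical call, mirroring the updated call site in the README
# (prompts, model_name, max_context_length, batch_size defined elsewhere):
# predictions = predict_with_hf_generation_pipeline(prompts=prompts, model_name=model_name,
#                                                   max_context_length=max_context_length,
#                                                   batch_size=batch_size)
```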