arthrod commited on
Commit
447768c
1 Parent(s): 5eb3754

Create filehandler.py

Browse files
Files changed (1) hide show
  1. filehandler.py +41 -0
filehandler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+
3
+ class FileHandler:
4
+ def __init__(self, model_path):
5
+ self.model_path = model_path
6
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
7
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
8
+ self.model.eval()
9
+
10
+ def generate_text(self, prompt, max_length=100, num_return_sequences=1, temperature=0.7):
11
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
12
+
13
+ generated_ids = self.model.generate(
14
+ input_ids,
15
+ max_length=max_length,
16
+ num_return_sequences=num_return_sequences,
17
+ temperature=temperature,
18
+ pad_token_id=self.tokenizer.eos_token_id,
19
+ )
20
+
21
+ generated_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
22
+ return generated_texts
23
+
24
+ def __call__(self, request):
25
+ # Parse the request and extract the necessary information
26
+ prompt = request["prompt"]
27
+ max_length = request.get("max_length", 100)
28
+ num_return_sequences = request.get("num_return_sequences", 1)
29
+ temperature = request.get("temperature", 0.7)
30
+
31
+ # Generate text based on the prompt and parameters
32
+ generated_texts = self.generate_text(prompt, max_length, num_return_sequences, temperature)
33
+
34
+ # Prepare the response
35
+ response = {
36
+ "generated_texts": generated_texts
37
+ }
38
+
39
+ return response
40
+
41
+ handler = FileHandler(".")