kiddobellamy committed on
Commit 03ffc4b
1 Parent(s): 2bace83

Update handler.py

Files changed (1): handler.py +29 -55
handler.py CHANGED
@@ -1,66 +1,40 @@
 import requests
 import torch
-from PIL import Image, UnidentifiedImageError
+from PIL import Image
 from transformers import MllamaForConditionalGeneration, AutoProcessor
-import logging
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        try:
-            # Initialize the model and processor from the directory
-            model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
-            self.model = MllamaForConditionalGeneration.from_pretrained(
-                model_id,
-                torch_dtype=torch.bfloat16,
-                device_map="auto"
-            )
-            self.processor = AutoProcessor.from_pretrained(model_id)
-            logging.info("Model and processor loaded successfully.")
-        except Exception as e:
-            logging.error(f"Error loading model or processor: {e}")
-            raise
-
+        # Initialize the model and processor from the directory
+        model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
+        self.model = MllamaForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+        self.processor = AutoProcessor.from_pretrained(model_id)
+
     def process(self, inputs):
         """
         Process the input data and return the output.
         Expecting inputs in the form of a dictionary containing 'image_url' and 'prompt'.
         """
-        try:
-            # Input validation
-            image_url = inputs.get("image_url")
-            if not image_url:
-                raise ValueError("No image URL provided in the input.")
-
-            prompt = inputs.get("prompt", "If I had to write a haiku for this one, it would be:")
-
-            # Process the image
-            try:
-                image = Image.open(requests.get(image_url, stream=True).raw)
-            except UnidentifiedImageError:
-                logging.error(f"Failed to identify the image from the URL: {image_url}")
-                raise
-            except Exception as e:
-                logging.error(f"Error downloading or processing the image: {e}")
-                raise
-
-            # Generate response
-            messages = [
-                {"role": "user", "content": [
-                    {"type": "image"},
-                    {"type": "text", "text": prompt}
-                ]}
-            ]
-            input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
-            model_inputs = self.processor(image, input_text, return_tensors="pt").to(self.model.device)
-            output = self.model.generate(**model_inputs, max_new_tokens=30)
-
-            # Return the output as a string
-            return self.processor.decode(output[0])
-
-        except Exception as e:
-            logging.error(f"Error during processing: {e}")
-            raise
-
+        image_url = inputs.get("image_url")
+        prompt = inputs.get("prompt", "If I had to write a haiku for this one, it would be:")
+
+        # Process the image
+        image = Image.open(requests.get(image_url, stream=True).raw)
+
+        # Generate response
+        messages = [
+            {"role": "user", "content": [
+                {"type": "image"},
+                {"type": "text", "text": prompt}
+            ]}
+        ]
+        input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        model_inputs = self.processor(image, input_text, return_tensors="pt").to(self.model.device)
+        output = self.model.generate(**model_inputs, max_new_tokens=30)
+
+        # Return the output as a string
+        return self.processor.decode(output[0])
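
For context, a minimal sketch of how the updated handler might be exercised by importing it directly. The model_dir value, image URL, and prompt below are illustrative placeholders, not values from this repository, and the Inference Endpoints runtime may construct and invoke the handler differently than shown here.

# usage_sketch.py -- illustrative only; model_dir, URL, and prompt are placeholders
from handler import EndpointHandler

# __init__ accepts model_dir but loads the hardcoded Llama-3.2-90B-Vision-Instruct checkpoint;
# device_map="auto" shards the weights across whatever GPUs are available.
handler = EndpointHandler(model_dir="/repository")  # hypothetical path

result = handler.process({
    "image_url": "https://example.com/sample.jpg",    # hypothetical image URL
    "prompt": "Describe this image in one sentence."  # optional; the haiku prompt is the default
})
print(result)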