kiddobellamy committed on
Commit 2bace83
1 Parent(s): 8e761cc

Update handler.py

Files changed (1)
  handler.py  +57 -44
handler.py CHANGED
@@ -1,53 +1,66 @@
 import requests
 import torch
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
 from transformers import MllamaForConditionalGeneration, AutoProcessor
+import logging
 
-# Define the model ID and load the model and processor
-model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
-
-def load_model():
-    """Loads the Llama 3.2-90B Vision-Instruct model and processor."""
-    model = MllamaForConditionalGeneration.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-    )
-    processor = AutoProcessor.from_pretrained(model_id)
-    return model, processor
-
-def process_image(url):
-    """Processes the image from the given URL."""
-    image = Image.open(requests.get(url, stream=True).raw)
-    return image
-
-def generate_response(model, processor, image, prompt):
-    """Generates a text response based on the image and the prompt."""
-    messages = [
-        {"role": "user", "content": [
-            {"type": "image"},
-            {"type": "text", "text": prompt}
-        ]}
-    ]
-    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(image, input_text, return_tensors="pt").to(model.device)
-    output = model.generate(**inputs, max_new_tokens=30)
-    return processor.decode(output[0])
-
-def main():
-    # Load model and processor
-    model, processor = load_model()
-
-    # Sample image URL
-    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
-    image = process_image(url)
-
-    # Define a sample prompt
-    prompt = "If I had to write a haiku for this one, it would be:"
-
-    # Generate response
-    response = generate_response(model, processor, image, prompt)
-    print(response)
-
-if __name__ == "__main__":
-    main()
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+class EndpointHandler:
+    def __init__(self, model_dir):
+        try:
+            # Initialize the model and processor from the directory
+            model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
+            self.model = MllamaForConditionalGeneration.from_pretrained(
+                model_id,
+                torch_dtype=torch.bfloat16,
+                device_map="auto"
+            )
+            self.processor = AutoProcessor.from_pretrained(model_id)
+            logging.info("Model and processor loaded successfully.")
+        except Exception as e:
+            logging.error(f"Error loading model or processor: {e}")
+            raise
+
+    def process(self, inputs):
+        """
+        Process the input data and return the output.
+        Expecting inputs in the form of a dictionary containing 'image_url' and 'prompt'.
+        """
+        try:
+            # Input validation
+            image_url = inputs.get("image_url")
+            if not image_url:
+                raise ValueError("No image URL provided in the input.")
+
+            prompt = inputs.get("prompt", "If I had to write a haiku for this one, it would be:")
+
+            # Process the image
+            try:
+                image = Image.open(requests.get(image_url, stream=True).raw)
+            except UnidentifiedImageError:
+                logging.error(f"Failed to identify the image from the URL: {image_url}")
+                raise
+            except Exception as e:
+                logging.error(f"Error downloading or processing the image: {e}")
+                raise
+
+            # Generate response
+            messages = [
+                {"role": "user", "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt}
+                ]}
+            ]
+            input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+            model_inputs = self.processor(image, input_text, return_tensors="pt").to(self.model.device)
+            output = self.model.generate(**model_inputs, max_new_tokens=30)
+
+            # Return the output as a string
+            return self.processor.decode(output[0])
+
+        except Exception as e:
+            logging.error(f"Error during processing: {e}")
+            raise
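For reference, a minimal sketch of how the new handler could be exercised locally. This snippet is not part of the commit: the payload keys match the validation in `process`, the sample image URL and prompt are the ones the removed `main()` used, and it assumes `handler.py` is importable and that enough GPU memory is available for the 90B checkpoint. Note that `model_dir` is accepted but unused, since `__init__` hard-codes `model_id`.

# Illustrative smoke test (not part of the commit); assumes handler.py is
# on the import path and the hardware can hold the 90B model.
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")  # model_dir is ignored; __init__ hard-codes model_id

payload = {
    # Same sample inputs the removed main() used
    "image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",
    "prompt": "If I had to write a haiku for this one, it would be:",
}

print(handler.process(payload))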