kiddobellamy commited on
Commit
ebfa455
1 Parent(s): 03ffc4b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +77 -29
handler.py CHANGED
@@ -1,40 +1,88 @@
1
- import requests
 
2
  import torch
3
- from PIL import Image
4
  from transformers import MllamaForConditionalGeneration, AutoProcessor
 
 
 
 
 
 
 
 
5
 
6
- class EndpointHandler:
7
- def __init__(self, model_dir):
8
- # Initialize the model and processor from the directory
9
  model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
 
10
  self.model = MllamaForConditionalGeneration.from_pretrained(
11
  model_id,
12
- torch_dtype=torch.bfloat16,
13
- device_map="auto"
14
  )
15
  self.processor = AutoProcessor.from_pretrained(model_id)
16
-
17
- def process(self, inputs):
18
- """
19
- Process the input data and return the output.
20
- Expecting inputs in the form of a dictionary containing 'image_url' and 'prompt'.
21
- """
22
- image_url = inputs.get("image_url")
23
- prompt = inputs.get("prompt", "If I had to write a haiku for this one, it would be:")
24
-
25
- # Process the image
26
- image = Image.open(requests.get(image_url, stream=True).raw)
27
-
28
- # Generate response
 
 
 
 
 
 
 
 
 
 
 
29
  messages = [
30
- {"role": "user", "content": [
31
- {"type": "image"},
32
- {"type": "text", "text": prompt}
33
- ]}
 
 
 
34
  ]
 
 
35
  input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
36
- model_inputs = self.processor(image, input_text, return_tensors="pt").to(self.model.device)
37
- output = self.model.generate(**model_inputs, max_new_tokens=30)
38
-
39
- # Return the output as a string
40
- return self.processor.decode(output[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+
3
  import torch
 
4
  from transformers import MllamaForConditionalGeneration, AutoProcessor
5
+ from PIL import Image
6
+ import base64
7
+ import io
8
+
9
+ class Llama32VisionHandler:
10
+ def __init__(self):
11
+ self.model = None
12
+ self.processor = None
13
 
14
+ def initialize(self):
15
+ # Cargar el modelo y el procesador
 
16
  model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
17
+
18
  self.model = MllamaForConditionalGeneration.from_pretrained(
19
  model_id,
20
+ torch_dtype=torch.bfloat16, # Usar bfloat16 para eficiencia de memoria
21
+ device_map="auto", # Mapear automáticamente el modelo a los dispositivos disponibles
22
  )
23
  self.processor = AutoProcessor.from_pretrained(model_id)
24
+ self.model.eval()
25
+
26
+ def handle(self, request):
27
+ # Asegurarse de que el modelo esté cargado
28
+ if self.model is None:
29
+ self.initialize()
30
+
31
+ # Extraer imagen y texto de la solicitud
32
+ image_data = request.get('image', None)
33
+ text_input = request.get('text', '')
34
+
35
+ # Procesar la imagen
36
+ if image_data:
37
+ # Si los datos de imagen están en formato base64
38
+ if isinstance(image_data, str):
39
+ image_bytes = base64.b64decode(image_data)
40
+ image = Image.open(io.BytesIO(image_bytes))
41
+ else:
42
+ # Si los datos de imagen son bytes crudos
43
+ image = Image.open(io.BytesIO(image_data))
44
+ else:
45
+ image = None # Manejar casos donde no se proporciona imagen
46
+
47
+ # Preparar mensajes para el procesador
48
  messages = [
49
+ {
50
+ "role": "user",
51
+ "content": [
52
+ {"type": "image"},
53
+ {"type": "text", "text": text_input}
54
+ ]
55
+ }
56
  ]
57
+
58
+ # Aplicar la plantilla de chat a los mensajes
59
  input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
60
+ # Procesar las entradas
61
+ inputs = self.processor(image, input_text, return_tensors="pt").to(self.model.device)
62
+
63
+ # Generar salida
64
+ with torch.no_grad():
65
+ outputs = self.model.generate(**inputs, max_new_tokens=50)
66
+
67
+ # Decodificar la salida
68
+ response = self.processor.decode(outputs[0], skip_special_tokens=True)
69
+ return response
70
+
71
+ # Ejemplo de uso
72
+ if __name__ == '__main__':
73
+ handler = Llama32VisionHandler()
74
+ # Cargar una imagen de ejemplo y codificarla en base64
75
+ with open('ruta_a_tu_imagen.jpg', 'rb') as f:
76
+ image_bytes = f.read()
77
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
78
+
79
+ # Crear una solicitud de ejemplo
80
+ request = {
81
+ 'image': image_base64,
82
+ 'text': 'Por favor, describe esta imagen en detalle.'
83
+ }
84
+
85
+ # Obtener la respuesta del handler
86
+ response = handler.handle(request)
87
+ print(response)
88
+ #000