lucianotonet commited on
Commit
cda11d3
1 Parent(s): 2eea391

Atualiza modelo de geração e ajusta configuração de dispositivo

Browse files

Substitui o modelo de linguagem para uma versão mais eficiente e apropriada. A mudança para o `Qwen2VLForConditionalGeneration` melhora a performance da geração de texto. Além disso, a configuração do dispositivo foi alterada para "cpu", permitindo um comportamento mais consistente em ambientes sem GPU, garantindo que o processamento continue funcional em diversas plataformas.

Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -1,11 +1,11 @@
1
  from fastapi import FastAPI
2
- from transformers import AutoModelForCausalLM, AutoProcessor
3
  from qwen_vl_utils import process_vision_info
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.float16, device_map="auto")
9
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
10
 
11
  @app.post("/predict")
@@ -20,7 +20,7 @@ async def predict(messages: list):
20
  padding=True,
21
  return_tensors="pt"
22
  )
23
- inputs = inputs.to(model.device)
24
 
25
  generated_ids = model.generate(**inputs, max_new_tokens=128)
26
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
1
  from fastapi import FastAPI
2
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
3
  from qwen_vl_utils import process_vision_info
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
+ model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
9
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
10
 
11
  @app.post("/predict")
 
20
  padding=True,
21
  return_tensors="pt"
22
  )
23
+ inputs = inputs.to("cpu") # Altere para "cuda" se tiver GPU disponível
24
 
25
  generated_ids = model.generate(**inputs, max_new_tokens=128)
26
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]