supersolar committed
Commit
105baa4
1 Parent(s): c2469bd

Create florencegpu2.py

Files changed (1)
  1. utils/florencegpu2.py +66 -0
utils/florencegpu2.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ from typing import Union, Any, Tuple, Dict
+ from unittest.mock import patch
+
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ from transformers.dynamic_module_utils import get_imports
+
+ FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
+ FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
+
+
+ def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+     """Workaround for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+     if not str(filename).endswith("/modeling_florence2.py"):
+         return get_imports(filename)
+     imports = get_imports(filename)
+     # Drop flash_attn so the remote modeling code loads without it installed.
+     if "flash_attn" in imports:
+         imports.remove("flash_attn")
+     return imports
+
+
+ def load_florence_model(
+     device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
+ ) -> Tuple[Any, Any]:
+     # Deliberately override the passed-in device: this variant pins Florence-2
+     # to the second GPU (hence the file name), falling back to CPU.
+     device = "cuda:1" if torch.cuda.is_available() else "cpu"
+     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+     # Apply the import workaround while loading, so flash_attn is not required.
+     with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+         model = AutoModelForCausalLM.from_pretrained(
+             checkpoint, torch_dtype=torch_dtype, trust_remote_code=True
+         ).to(device)
+         processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
+     return model, processor
+
+
+ def run_florence_inference(
+     model: Any,
+     processor: Any,
+     device: torch.device,
+     image: Image.Image,
+     task: str,
+     text: str = ""
+ ) -> Tuple[str, Dict]:
+     prompt = task + text
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         # Cast to the model's dtype (float16 on GPU) to avoid a dtype mismatch.
+         pixel_values=inputs["pixel_values"].to(model.dtype),
+         max_new_tokens=1024,
+         num_beams=3
+     )
+     generated_text = processor.batch_decode(
+         generated_ids, skip_special_tokens=False)[0]
+     response = processor.post_process_generation(
+         generated_text, task=task, image_size=image.size)
+     return generated_text, response
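
For reference, a minimal usage sketch (not part of the commit): it assumes a machine where cuda:1 exists when CUDA is available, and the image path and grounding phrase are placeholders.

import torch
from PIL import Image

from utils.florencegpu2 import (
    FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
    FLORENCE_OBJECT_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)

# load_florence_model pins itself to cuda:1 on CUDA machines, so inference
# inputs must be sent to that same device.
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(DEVICE)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
_, response = run_florence_inference(
    model=model,
    processor=processor,
    device=DEVICE,
    image=image,
    task=FLORENCE_OBJECT_DETECTION_TASK,
)
# The response is keyed by the task token,
# e.g. {'<OD>': {'bboxes': [...], 'labels': [...]}}.
print(response[FLORENCE_OBJECT_DETECTION_TASK])

# Tasks that take extra input (e.g. phrase grounding) pass it via text=.
_, grounded = run_florence_inference(
    model=model,
    processor=processor,
    device=DEVICE,
    image=image,
    task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
    text="a dog",  # hypothetical phrase to ground
)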