pcuenq committed
Commit 7fbd1fa
Parent: beaba43

Custom device map to reduce memory consumption


1920x1080 images now cause the demo to OOM, and we can't downscale more
because the text location features work best at that resolution.
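The fix offloads the memory-heavy embedding, LM-head, and vision projection weights to CPU, keeping only the transformer layers on GPU 0. When a model is loaded with a device_map, accelerate attaches dispatch hooks that move activations between devices during the forward pass, so inputs no longer need an explicit .to(device=model.device). A minimal sketch of the loading step (module names taken from the diff below; they are specific to fuyu-8b):

import torch
from transformers import FuyuForCausalLM

# Keep the decoder blocks on GPU 0; offload the large embedding, LM-head,
# and vision projection weights to CPU (module names are fuyu-8b specific).
device_map = {
    "language_model.model.embed_tokens": "cpu",
    "language_model.model.layers": 0,
    "language_model.model.final_layernorm": 0,
    "language_model.lm_head": "cpu",
    "vision_embed_tokens": "cpu",
}
model = FuyuForCausalLM.from_pretrained(
    "adept/fuyu-8b", device_map=device_map, torch_dtype=torch.bfloat16
)
print(model.hf_device_map)  # verify where each module was placed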

Files changed (1): app.py (+11, -3)
app.py CHANGED
@@ -7,7 +7,15 @@ from transformers import FuyuForCausalLM, FuyuProcessor
 model_id = "adept/fuyu-8b"
 dtype = torch.bfloat16
 
-model = FuyuForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=dtype)
+device_map = {
+    "language_model.model.embed_tokens": "cpu",
+    "language_model.model.layers": 0,
+    "language_model.model.final_layernorm": 0,
+    "language_model.lm_head": "cpu",
+    "vision_embed_tokens": "cpu",
+}
+
+model = FuyuForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=dtype)
 processor = FuyuProcessor.from_pretrained(model_id)
 
 CAPTION_PROMPT = "Generate a coco-style caption.\n"
@@ -36,7 +44,7 @@ def pad_to_size(image, canvas_width=1920, canvas_height=1080):
 
 def predict(image, prompt):
     # image = image.convert('RGB')
-    model_inputs = processor(text=prompt, images=[image]).to(device=model.device)
+    model_inputs = processor(text=prompt, images=[image])
 
     generation_output = model.generate(**model_inputs, max_new_tokens=50)
     prompt_len = model_inputs["input_ids"].shape[-1]
@@ -71,7 +79,7 @@ def localize(image, query):
     padded = resize_to_max(image)
     padded = pad_to_size(padded)
 
-    model_inputs = processor(text=prompt, images=[padded]).to(device=model.device)
+    model_inputs = processor(text=prompt, images=[padded])
 
     outputs = model.generate(**model_inputs, max_new_tokens=40)
     post_processed_bbox_tokens = processor.post_process_box_coordinates(outputs)[0]
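
With the model dispatched this way, processor outputs can stay on CPU: the hooks move each tensor to the right device as the forward pass reaches it, which is why the .to(...) calls were dropped. A quick smoke test of the new code path (a sketch: it assumes the model, processor, and CAPTION_PROMPT defined in app.py, a CUDA device, and uses a blank canvas instead of a real image):

from PIL import Image

image = Image.new("RGB", (1920, 1080))  # full-resolution blank canvas
model_inputs = processor(text=CAPTION_PROMPT, images=[image])  # no .to(...) needed
output = model.generate(**model_inputs, max_new_tokens=50)
prompt_len = model_inputs["input_ids"].shape[-1]
print(processor.batch_decode(output[:, prompt_len:], skip_special_tokens=True)[0])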