ryanzhangfan committed on
Commit
66ecdd5
1 Parent(s): 0f8e8b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -21
app.py CHANGED
@@ -39,15 +39,6 @@ gen_model = AutoModelForCausalLM.from_pretrained(
39
  trust_remote_code=True,
40
  )
41
 
42
- gen_tokenizer = AutoTokenizer.from_pretrained(EMU_GEN_HUB, trust_remote_code=True)
43
- gen_image_processor = AutoImageProcessor.from_pretrained(
44
- VQ_HUB, trust_remote_code=True
45
- )
46
- gen_image_tokenizer = AutoModel.from_pretrained(
47
- VQ_HUB, device_map="cuda:0", trust_remote_code=True
48
- ).eval()
49
- gen_processor = Emu3Processor(gen_image_processor, gen_image_tokenizer, gen_tokenizer)
50
-
51
  # Emu3-Chat model and processor
52
  chat_model = AutoModelForCausalLM.from_pretrained(
53
  EMU_CHAT_HUB,
@@ -57,18 +48,18 @@ chat_model = AutoModelForCausalLM.from_pretrained(
57
  trust_remote_code=True,
58
  )
59
 
60
- chat_tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
61
- chat_image_processor = AutoImageProcessor.from_pretrained(
62
  VQ_HUB, trust_remote_code=True
63
  )
64
- chat_image_tokenizer = AutoModel.from_pretrained(
65
  VQ_HUB, device_map="cuda:0", trust_remote_code=True
66
  ).eval()
67
- chat_processor = Emu3Processor(
68
- chat_image_processor, chat_image_tokenizer, chat_tokenizer
69
  )
70
 
71
- @spaces.GPU(duration=120)
72
  def generate_image(prompt):
73
  POSITIVE_PROMPT = " masterpiece, film grained, best quality."
74
  NEGATIVE_PROMPT = (
@@ -86,8 +77,8 @@ def generate_image(prompt):
86
  image_area=gen_model.config.image_area,
87
  return_tensors="pt",
88
  )
89
- pos_inputs = gen_processor(text=full_prompt, **kwargs)
90
- neg_inputs = gen_processor(text=NEGATIVE_PROMPT, **kwargs)
91
 
92
  # Prepare hyperparameters
93
  GENERATION_CONFIG = GenerationConfig(
@@ -100,7 +91,7 @@ def generate_image(prompt):
100
  )
101
 
102
  h, w = pos_inputs.image_size[0]
103
- constrained_fn = gen_processor.build_prefix_constrained_fn(h, w)
104
  logits_processor = LogitsProcessorList(
105
  [
106
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
@@ -122,14 +113,14 @@ def generate_image(prompt):
122
  logits_processor=logits_processor,
123
  )
124
 
125
- mm_list = gen_processor.decode(outputs[0])
126
  for idx, im in enumerate(mm_list):
127
  if isinstance(im, Image.Image):
128
  return im
129
  return None
130
 
131
  def vision_language_understanding(image, text):
132
- inputs = chat_processor(
133
  text=text,
134
  image=image,
135
  mode="U",
@@ -154,7 +145,7 @@ def vision_language_understanding(image, text):
154
  )
155
 
156
  outputs = outputs[:, inputs.input_ids.shape[-1] :]
157
- response = chat_processor.batch_decode(outputs, skip_special_tokens=True)[0]
158
  return response
159
 
160
  def chat(history, user_input, user_image):
 
39
  trust_remote_code=True,
40
  )
41
 
 
 
 
 
 
 
 
 
 
42
  # Emu3-Chat model and processor
43
  chat_model = AutoModelForCausalLM.from_pretrained(
44
  EMU_CHAT_HUB,
 
48
  trust_remote_code=True,
49
  )
50
 
51
+ tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
52
+ image_processor = AutoImageProcessor.from_pretrained(
53
  VQ_HUB, trust_remote_code=True
54
  )
55
+ image_tokenizer = AutoModel.from_pretrained(
56
  VQ_HUB, device_map="cuda:0", trust_remote_code=True
57
  ).eval()
58
+ processor = Emu3Processor(
59
+ image_processor, image_tokenizer, tokenizer
60
  )
61
 
62
+ @spaces.GPU(duration=300)
63
  def generate_image(prompt):
64
  POSITIVE_PROMPT = " masterpiece, film grained, best quality."
65
  NEGATIVE_PROMPT = (
 
77
  image_area=gen_model.config.image_area,
78
  return_tensors="pt",
79
  )
80
+ pos_inputs = processor(text=full_prompt, **kwargs)
81
+ neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
82
 
83
  # Prepare hyperparameters
84
  GENERATION_CONFIG = GenerationConfig(
 
91
  )
92
 
93
  h, w = pos_inputs.image_size[0]
94
+ constrained_fn = processor.build_prefix_constrained_fn(h, w)
95
  logits_processor = LogitsProcessorList(
96
  [
97
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
 
113
  logits_processor=logits_processor,
114
  )
115
 
116
+ mm_list = processor.decode(outputs[0])
117
  for idx, im in enumerate(mm_list):
118
  if isinstance(im, Image.Image):
119
  return im
120
  return None
121
 
122
  def vision_language_understanding(image, text):
123
+ inputs = processor(
124
  text=text,
125
  image=image,
126
  mode="U",
 
145
  )
146
 
147
  outputs = outputs[:, inputs.input_ids.shape[-1] :]
148
+ response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
149
  return response
150
 
151
  def chat(history, user_input, user_image):