andrewqian123 committed
Commit
1eff8ee
1 Parent(s): a0e05d2

Update modeling_minicpmv.py

Files changed (1)
  1. modeling_minicpmv.py +33 -29
modeling_minicpmv.py CHANGED
@@ -231,44 +231,48 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
     def generate(
         self,
-        model_inputs,
+        model_inputs_batch,
         tokenizer=None,
         vision_hidden_states=None,
         stream=False,
         **kwargs
     ):
-        bs = len(model_inputs["input_ids"])
-        img_list = model_inputs["pixel_values"]
-        tgt_sizes = model_inputs["tgt_sizes"]
-        if img_list is None:
-            img_list = [[] for i in range(bs)]
-        assert bs == len(img_list)
-        if vision_hidden_states is None:
-            pixel_values = []
-            for i in range(bs):
-                img_inps = []
-                for img in img_list[i]:
-                    img_inps.append(img.to(self.device))
-                if img_inps:
-                    pixel_values.append(img_inps)
-                else:
-                    pixel_values.append([])
-            model_inputs["pixel_values"] = pixel_values
-            model_inputs['tgt_sizes'] = tgt_sizes
-        else:
-            model_inputs["vision_hidden_states"] = vision_hidden_states
-
-        (
-            input_embeds,
-            vision_hidden_states,
-        ) = self.get_vllm_embedding(model_inputs)
+        batch = []
+        for model_inputs in model_inputs_batch:
+            bs = len(model_inputs["input_ids"])
+            img_list = model_inputs["pixel_values"]
+            tgt_sizes = model_inputs["tgt_sizes"]
+            if img_list is None:
+                img_list = [[] for i in range(bs)]
+            assert bs == len(img_list)
+            if vision_hidden_states is None:
+                pixel_values = []
+                for i in range(bs):
+                    img_inps = []
+                    for img in img_list[i]:
+                        img_inps.append(img.to(self.device))
+                    if img_inps:
+                        pixel_values.append(img_inps)
+                    else:
+                        pixel_values.append([])
+                model_inputs["pixel_values"] = pixel_values
+                model_inputs['tgt_sizes'] = tgt_sizes
+            else:
+                model_inputs["vision_hidden_states"] = vision_hidden_states
+
+            (
+                input_embeds,
+                vision_hidden_states,
+            ) = self.get_vllm_embedding(model_inputs)
+            batch.append(input_embeds)
+
 
         # output_ids = self._decode(input_embeds, tokenizer, **kwargs)
         if stream:
             kwargs.pop("decode_text")
-            result = self._decode_stream(input_embeds, tokenizer, **kwargs)
+            result = self._decode_stream(batch, tokenizer, **kwargs)
         else:
-            result = self._decode(input_embeds, tokenizer, **kwargs)
+            result = self._decode(batch, tokenizer, **kwargs)
 
         return result
 
@@ -366,5 +370,5 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
            return stream_gen()
 
        else:
-            answer = res[0]
+            answer = res
            return answer
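For context: this commit changes generate() from taking a single model_inputs dict to taking a list of them (model_inputs_batch). Each dict is embedded separately through get_vllm_embedding, and the collected embeddings are passed as a list to _decode or _decode_stream. A minimal, hypothetical call site might look like the sketch below; the repo id is a placeholder (not this commit's repo), and the text-only samples (pixel_values=None) lean on the "if img_list is None" branch visible in the diff.

import torch
from transformers import AutoModel, AutoTokenizer

repo = "openbmb/MiniCPM-Llama3-V-2_5"  # placeholder repo id, an assumption
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

def text_only_inputs(prompt):
    # One model_inputs dict of the shape generate() unpacks: "input_ids"
    # drives bs, while "pixel_values"/"tgt_sizes" feed the vision path.
    ids = tokenizer([prompt], return_tensors="pt").input_ids.to(model.device)
    return {"input_ids": ids, "pixel_values": None, "tgt_sizes": None}

model_inputs_batch = [
    text_only_inputs("Describe a sunset."),
    text_only_inputs("List three prime numbers."),
]

with torch.no_grad():
    # stream=False path: _decode() now receives the whole list of embeddings;
    # decode_text is forwarded via **kwargs as in the pre-existing code.
    results = model.generate(model_inputs_batch, tokenizer=tokenizer, decode_text=True)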