teowu committed on
Commit
0cac618
1 Parent(s): 01730ee

Update modeling_mplug_owl2.py

Browse files
Files changed (1) hide show
  1. modeling_mplug_owl2.py +7 -4
modeling_mplug_owl2.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -271,20 +271,23 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
271
  task_: str = "quality",
272
  input_: str = "image",
273
  return_dict=False,
 
274
  ):
275
  if not hasattr(self, "weight_tensor"):
276
  self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
277
  prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
278
  if input_ == "image":
279
- images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
280
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
281
- with torch.inference_mode():
282
  image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
 
283
  output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
284
  images=image_tensor)["logits"][:,-1, self.preferential_ids_]
285
  if return_dict:
286
  return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
287
  return torch.softmax(output_logits, -1) @ self.weight_tensor
 
288
  else:
289
  video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
290
  input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
 
1
+ # Copyright 2023 Haotian Liu & Qinghao Ye & Haoning Wu (Modified from LLaVA, and mPLUG-Owl2)
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
271
  task_: str = "quality",
272
  input_: str = "image",
273
  return_dict=False,
274
+ image_tensor = None,
275
  ):
276
  if not hasattr(self, "weight_tensor"):
277
  self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
278
  prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
279
  if input_ == "image":
280
+ if image_tensor is None:
281
+ images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
282
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
283
  image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
284
+ with torch.inference_mode():
285
  output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
286
  images=image_tensor)["logits"][:,-1, self.preferential_ids_]
287
  if return_dict:
288
  return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
289
  return torch.softmax(output_logits, -1) @ self.weight_tensor
290
+
291
  else:
292
  video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
293
  input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)