Update modeling_mplug_owl2.py
modeling_mplug_owl2.py (+5 -0)
@@ -270,6 +270,7 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
     def score(self, images,
               task_: str = "quality",
               input_: str = "image",
+              return_dict=False,
               ):
         if not hasattr(self, "weight_tensor"):
             self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
@@ -281,6 +282,8 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
             image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
             output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
                                  images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+            if return_dict:
+                return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
             return torch.softmax(output_logits, -1) @ self.weight_tensor
         else:
             video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
@@ -289,6 +292,8 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
             video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
             output_logits = self(input_ids.repeat(len(video_tensors), 1),
                                  images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+            if return_dict:
+                return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
             return torch.softmax(output_logits, -1) @ self.weight_tensor
 
     def forward(
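For reference, a minimal usage sketch of the new return_dict flag. The checkpoint id "q-future/one-align" and the trust_remote_code loading path are assumptions for illustration, not part of this commit; only the score(..., return_dict=True) call reflects the change above.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

# Assumed loading path for the mPLUG-Owl2-based scorer; adjust the checkpoint id as needed.
model = AutoModelForCausalLM.from_pretrained(
    "q-future/one-align",          # assumed checkpoint id
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

images = [Image.open("example.jpg")]  # hypothetical input image

# Default behaviour (unchanged): a tensor of weighted-softmax scores, one per image.
scores = model.score(images, task_="quality", input_="image")

# New behaviour added in this commit: also expose the raw preferential logits.
out = model.score(images, task_="quality", input_="image", return_dict=True)
print(out["logits"].shape, out["scores"])

With return_dict=False the method still returns only the score tensor, so existing callers are unaffected.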