Spaces:
Running
on
Zero
Running
on
Zero
ahmed-masry
committed on
Commit
•
9f9c2cc
1
Parent(s):
e8ad0b4
Create processing_utils.py
Browse files- processing_utils.py +121 -0
processing_utils.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import BatchEncoding, BatchFeature
|
7 |
+
|
8 |
+
def get_torch_device(device: str = "auto") -> str:
    """
    Returns the device (string) to be used by PyTorch.

    `device` arg defaults to "auto" which will use:
    - "cuda:0" if available
    - else "mps" if available
    - else "cpu".

    Any value other than "auto" is returned unchanged, without validation.

    Args:
        device: Either "auto" or an explicit device string chosen by the caller.

    Returns:
        The resolved device string.
    """
    # Function-scope import: this module has no top-level `logging` import and
    # defines no module-level `logger`, so referencing a bare `logger` here
    # raised NameError on every "auto" call.
    import logging

    if device == "auto":
        # `torch.backends.mps` is missing on older torch builds; treat a missing
        # backend the same as an unavailable one instead of raising.
        mps_backend = getattr(torch.backends, "mps", None)

        if torch.cuda.is_available():
            device = "cuda:0"
        elif mps_backend is not None and mps_backend.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
        logging.getLogger(__name__).info("Using device: %s", device)

    return device
|
28 |
+
|
29 |
+
class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.

    Subclasses must implement `process_images`, `process_queries` and `score`.
    Two scoring routines are provided as static methods: `score_single_vector`
    (dot product) and `score_multi_vector` (MaxSim).
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of images.

        Implementations return a `BatchFeature` or a `BatchEncoding`.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of text queries.

        The meaning of `max_length` and `suffix` is defined by the implementation.
        Implementations return a `BatchFeature` or a `BatchEncoding`.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Score query embeddings `qs` against passage embeddings `ps`.

        Implementations return a tensor of scores.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: Query embeddings, one 1-D tensor per query, all of the same size
                (they are stacked and contracted over that single dimension).
            ps: Passage embeddings, laid out like `qs` with the same embedding size.
            device: Device the scores are computed on. When falsy, the result of
                `get_torch_device("auto")` is used.

        Returns:
            A float32 tensor of shape `(len(qs), len(ps))`; entry `[i, j]` is the dot
            product of query `i` with passage `j`. The result is left on `device`.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # (num_queries, dim) x (num_passages, dim) -> (num_queries, num_passages)
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        For every (query, passage) pair, each query vector is matched with its highest
        dot product among the passage vectors, and those maxima are summed over the
        query vectors.

        Args:
            qs: Query embeddings, one 2-D tensor `(num_vectors, dim)` per query. The
                number of vectors may differ between queries; `dim` must not.
            ps: Passage embeddings, laid out like `qs` with the same `dim`.
            batch_size: Number of queries, and of passages, scored together in one
                einsum call. Must be positive.
            device: Device the scores are computed on. When falsy, the result of
                `get_torch_device("auto")` is used.

        Returns:
            A float32 tensor of shape `(len(qs), len(ps))`, located on the CPU.

        Raises:
            ValueError: If `qs` or `ps` is empty, or if `batch_size` is not positive.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")
        # A zero step makes `range` raise an unrelated-looking error and a negative
        # step silently skips the loop, leaving `torch.cat` with an empty list.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            # Sequences in a batch are zero-padded to the longest one. A padded query
            # vector has dot product 0 with everything, so it adds 0 to the sum.
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
                device
            )
            for j in range(0, len(ps), batch_size):
                # NOTE: a padded passage vector also scores exactly 0, so it wins the
                # max whenever all real passage vectors score below 0 for a query vector.
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size], batch_first=True, padding_value=0
                ).to(device)
                # b: queries, c: passages, n: query vectors, s: passage vectors.
                # Max over passage vectors (dim 3), then sum over query vectors (dim 2).
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            # Move each finished row block to the CPU before the next query batch.
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
|