# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Prompt encoder."""

import torch
from torch import nn


class PromptEncoder(nn.Module):
    """Module to encode geometric prompts.

    Point and box prompts are normalized by the image size, projected through
    a random matrix (``coord_matrix``) into sin/cos features and summed with a
    learned per-label embedding. The same projection is applied to a regular
    grid to produce the positional encoding of the image embedding
    (``img_pos``), which is cached between calls.
    """

    # Label index that marks a padding point; its positional part is zeroed.
    _PAD_LABEL = 4

    def __init__(self, embed_dim, image_size):
        """Build the encoder.

        Args:
            embed_dim: Size of the label embedding. The positional features
                have ``2 * (embed_dim // 2)`` channels.
            image_size: Scalar used for both entries of ``img_size``, by
                which prompt coordinates are divided.
        """
        super().__init__()
        self.point_embed = nn.Embedding(5, embed_dim)  # [bg, fg, lt, rb, pad]
        # A box is encoded as two points labelled "lt" (2) and "rb" (3).
        self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
        self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
        self.img_pos, self.img_size = None, [image_size] * 2
        # Grid size the cached ``img_pos`` was computed for.
        self._img_pos_grid = None

    def as_tensor(self, input):
        """Convert input into a tensor on the device of ``coord_matrix``."""
        # NOTE: ``input`` shadows the builtin, but the name is kept so that
        # keyword callers continue to work.
        return torch.as_tensor(input, device=self.coord_matrix.device)

    def _normalize_coords(self, coords):
        """Map pixel coords to float32 tensor values clipped to [0, 1]."""
        # ``+`` allocates a fresh array, so the in-place division below never
        # writes into the caller's data.
        coords = coords + 0.5
        # Reversed ``img_size``; presumably (h, w) -> (w, h) to match (x, y)
        # coordinates -- both entries are equal as constructed in ``__init__``.
        coords /= self.img_size[::-1]
        return self.as_tensor(coords.clip(0, 1).astype("float32"))

    def to_points(self, points=None, boxes=None):
        """Convert points or boxes to point prompts.

        Args:
            points: Either a ``(coords, labels)`` tuple/list, or a single
                array indexed as ``[:, :, :2]`` for coords and ``[:, :, 2]``
                for labels. Arrays must provide ``clip``/``astype``
                (presumably ``numpy.ndarray``).
            boxes: Array reshaped to ``(-1, 2, 2)``, i.e. two corner points
                per box. Ignored when ``points`` is given.

        Returns:
            ``(coords, labels)`` tensors, or ``None`` if both arguments are
            ``None``.
        """
        if points is not None:
            if isinstance(points, (tuple, list)):
                coords, labels = points
            else:
                coords, labels = points[:, :, :2], points[:, :, 2]
            labels = self.as_tensor(labels.astype("int64"))
            return self._normalize_coords(coords), labels
        if boxes is not None:
            coords = self._normalize_coords(boxes.reshape((-1, 2, 2)))
            return coords, self.as_tensor(self.corner_labels)
        return None

    def encode_coords(self, coords):
        """Return the embedding for given coords.

        Args:
            coords: Tensor whose last dimension holds 2 normalized values.

        Returns:
            ``cat([sin, cos])`` features along the last dimension, cast to
            the dtype of ``point_embed.weight``.
        """
        # The truncated value of pi is kept on purpose: switching to
        # ``math.pi`` would alter the numerical output of the encoder.
        pi4, pi2 = 4 * 3.1415926, 2 * 3.1415926
        # The projection always runs in float32, even if the module was cast
        # to another dtype; the buffer is converted back once and kept.
        if self.coord_matrix.dtype != torch.float32:
            self.coord_matrix = self.coord_matrix.float()
        # Rescale [0, 1] to [-2*pi, 2*pi] before the projection.
        rad = coords.mul(pi4).sub_(pi2) @ self.coord_matrix
        dtype = self.point_embed.weight.dtype
        return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype)

    def encode_points(self, coords, labels):
        """Return the embedding for given points.

        The positional features of padding points (label 4) are zeroed, then
        the learned embedding of each label is added.
        """
        embed = self.encode_coords(coords)
        keep = labels.ne(self._PAD_LABEL).unsqueeze(-1).to(dtype=embed.dtype)
        embed.mul_(keep)
        return embed.add_(self.point_embed(labels))

    def encode_grid(self, grid_size):
        """Return the embedding for a grid of specified size.

        Args:
            grid_size: Two integers ``(rows, cols)``.

        Returns:
            Tensor of shape ``(rows, cols, channels)`` holding the encoding of
            every cell centre.
        """
        grid = torch.ones(*grid_size, dtype=torch.float32)
        # Cumulative sums give 1..N; subtracting 0.5 moves to cell centres.
        y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
        x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
        coords = self.as_tensor(torch.stack([x, y], dim=-1))
        return self.encode_coords(coords)

    def _img_pos_is_stale(self, grid_size):
        """Return True if the cached ``img_pos`` cannot be reused."""
        cached = self.img_pos
        return (
            cached is None
            # Comparing the full grid size (not just the element count)
            # distinguishes e.g. a 32x64 grid from a 64x32 one.
            or getattr(self, "_img_pos_grid", None) != grid_size
            # The module may have been moved or cast since the last call.
            or cached.device != self.coord_matrix.device
            or cached.dtype != self.point_embed.weight.dtype
        )

    def forward(self, inputs):
        """Encode the prompts and the image grid.

        Args:
            inputs: Mapping with optional ``"boxes"`` and/or ``"points"``
                prompts (see ``to_points``) and a required ``"img_embeds"``
                tensor whose ``shape[2:-1]`` is taken as the grid size
                (presumably ``(batch, ?, rows, cols, channels)`` -- confirm
                against the caller).

        Returns:
            Dict with ``"sparse_embeds"`` (box embeddings followed by point
            embeddings, concatenated on dim 1) and ``"img_pos"`` (the grid
            encoding flattened to ``(rows * cols, channels)``).

        Raises:
            ValueError: If neither ``"points"`` nor ``"boxes"`` is provided.
        """
        sparse_embeds = []
        # Boxes come first so that their embeddings precede the points.
        for key in ("boxes", "points"):
            prompt = inputs.get(key, None)
            if prompt is not None:
                coords, labels = self.to_points(**{key: prompt})
                sparse_embeds.append(self.encode_points(coords, labels))
        if not sparse_embeds:
            raise ValueError("Expected ``points`` or ``boxes`` prompts.")
        if len(sparse_embeds) > 1:
            sparse_embeds = [torch.cat(sparse_embeds, dim=1)]
        grid_size = tuple(inputs["img_embeds"].shape[2:-1])
        if self._img_pos_is_stale(grid_size):
            self.img_pos = self.encode_grid(grid_size).flatten(0, 1)
            self._img_pos_grid = grid_size
        return {"sparse_embeds": sparse_embeds[0], "img_pos": self.img_pos}