|
""" DalleBart processor """ |
|
|
|
import jax.numpy as jnp |
|
|
|
from .configuration import DalleBartConfig |
|
from .text import TextNormalizer |
|
from .tokenizer import DalleBartTokenizer |
|
from .utils import PretrainedFromWandbMixin |
|
|
|
|
|
class DalleBartProcessorBase:
    """Convert text prompts into tokenized model inputs.

    Besides the encoding of the prompts themselves, every call also returns
    the encoding of the empty prompt (``""``), repeated once per prompt, under
    the ``*_uncond`` keys.
    """

    def __init__(
        self,
        tokenizer: "DalleBartTokenizer",
        normalize_text: bool,
        max_text_length: int,
    ):
        """Store the configuration and pre-compute the empty-prompt encoding.

        Args:
            tokenizer: Callable tokenizer; its result must expose a ``.data``
                mapping with ``"input_ids"`` and ``"attention_mask"`` entries.
            normalize_text: When true, each prompt is passed through a
                ``TextNormalizer`` before tokenization.
            max_text_length: Length every sequence is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            self.text_processor = TextNormalizer()

        # The empty prompt never changes, so tokenize it once up front and
        # only repeat it per call.
        uncond = self._tokenize("")
        self.input_ids_uncond = uncond["input_ids"]
        self.attention_mask_uncond = uncond["attention_mask"]

    def _tokenize(self, text):
        """Tokenize ``text`` to ``max_text_length`` and return the ``.data`` mapping."""
        return self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data

    def __call__(self, text: "list[str] | None" = None):
        """Tokenize a batch of prompts.

        Args:
            text: List of prompt strings. A bare ``str`` is rejected so that a
                single prompt is never silently treated as a sequence of
                characters.

        Returns:
            The tokenizer's ``.data`` mapping for ``text``, extended with
            ``"input_ids_uncond"`` and ``"attention_mask_uncond"``: the
            empty-prompt encoding repeated ``len(text)`` times along axis 0.

        Raises:
            TypeError: If ``text`` is ``None``.
            AssertionError: If ``text`` is a single string.
        """
        if text is None:
            raise TypeError("text must be a list of strings")
        if isinstance(text, str):
            # Raised explicitly (rather than via ``assert``) so the check still
            # runs under ``python -O``; exception type and message are kept
            # identical to the previous assertion for existing callers.
            raise AssertionError("text must be a list of strings")

        if self.normalize_text:
            text = [self.text_processor(t) for t in text]
        res = self._tokenize(text)

        # presumably the cached empty-prompt arrays have a leading batch axis
        # of size 1, so repeating along axis 0 yields one row per prompt.
        n = len(text)
        res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
        res["attention_mask_uncond"] = jnp.repeat(
            self.attention_mask_uncond, n, axis=0
        )
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from a pretrained tokenizer and configuration.

        All arguments are forwarded unchanged to both
        ``DalleBartTokenizer.from_pretrained`` and
        ``DalleBartConfig.from_pretrained``; ``normalize_text`` and
        ``max_text_length`` are read from the loaded configuration.
        """
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
|
|
|
|
|
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    """``DalleBartProcessorBase`` combined with ``PretrainedFromWandbMixin``.

    Defines no members of its own; all behaviour comes from the two bases.
    """
|
|