File size: 4,790 Bytes
300a419 93bef3b 300a419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
####################################################################
## Built on top of Ultravox: https://github.com/fixie-ai/ultravox ##
####################################################################
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import transformers
# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .shuka_model import ShukaModel
from .shuka_processing import ShukaProcessor
class ShukaPipeline(transformers.Pipeline):
    """Text-generation pipeline that accepts a chat history plus an optional audio clip.

    The pipeline wraps a ``ShukaModel`` together with a ``ShukaProcessor`` that
    combines a tokenizer and an audio processor. Inputs are dictionaries with the
    optional keys ``"turns"``, ``"audio"``, ``"prompt"`` and ``"sampling_rate"``
    (see :meth:`preprocess`).
    """

    def __init__(
        self,
        model: ShukaModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the processor and initialise the underlying pipeline.

        Args:
            model: The model to run. Its ``config`` is used to locate a tokenizer
                and an audio processor when they are not supplied.
            tokenizer: Tokenizer to use. When ``None`` it is loaded from
                ``model.config._name_or_path``, falling back to the text model
                named in the config.
            audio_processor: Audio processor to use. When ``None`` it is loaded
                from the audio model named in the config.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback: the model location may not ship a
                # tokenizer. ``Exception`` (rather than a bare ``except``) keeps
                # the fallback but lets KeyboardInterrupt/SystemExit propagate.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        self.processor = ShukaProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty`` are
        recognised; they are routed to :meth:`_forward`. Any other keyword is
        dropped, and no kwargs are passed to preprocess or postprocess.
        """
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {
            key: value for key, value in kwargs.items() if key in generation_keys
        }
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Convert a raw input dictionary into model inputs.

        Args:
            inputs: Dictionary with the optional keys:
                ``"turns"`` - list of chat messages (dicts with a ``"role"`` key);
                ``"audio"`` - audio samples, converted to float32 when given as a
                float64/int16/int32 NumPy array;
                ``"prompt"`` - text for the user turn that is added when audio is
                present and the last turn is not a user turn
                (default ``"<|audio|>"``);
                ``"sampling_rate"`` - sampling rate of ``audio`` (default 16000).

        Returns:
            The output of ``self.processor(...)``, with ``"audio_values"`` (when
            present) cast to the model's dtype.
        """
        # Work on a copy so the caller's list is never mutated by the append
        # below; an explicit ``None`` is treated like a missing history.
        turns: list = list(inputs.get("turns") or [])

        audio = inputs.get("audio", None)
        # Convert to float32 if needed.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                # Scale by the magnitude of the most negative int16 value.
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                # Scale by the magnitude of the most negative int32 value.
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn to attach to; synthesize one if the history
        # does not already end with a user message.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids.

        Args:
            model_inputs: Output of :meth:`preprocess`; must contain ``"input_ids"``.
            temperature: Sampling temperature. ``None`` or ``0`` disables sampling.
            max_new_tokens: Forwarded to ``model.generate``.
            repetition_penalty: Forwarded to ``model.generate``.

        Returns:
            The first generated sequence with the prompt tokens removed.
            NOTE(review): this is whatever ``model.generate`` yields when sliced
            (likely a tensor rather than a plain list) - confirm.
        """
        # A falsy temperature (e.g. 0) is treated as "not set", which turns
        # sampling off below.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the tokenizer's EOS token and, when the tokenizer defines it,
        # on the "<|eot_id|>" token as well.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        # Drop the prompt so only the continuation is returned.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text
# Register the pipeline under the "shuka-pipeline" task name so it can be
# looked up through the transformers pipeline registry.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    task="shuka-pipeline",
    pipeline_class=ShukaPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)
|