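"""lm-evaluation-harness wrapper for Mamba models.

Registers a ``mamba`` model type with EleutherAI's lm-eval harness so that
loglikelihood-style tasks can be run against MambaLMHeadModel checkpoints
through the standard ``lm_eval`` CLI.
"""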
import torch

import transformers
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate

# Register under the name "mamba" so lm_eval's --model flag can select it.
@register_model("mamba")
class MambaEvalWrapper(HFLM):

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
    def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
                 dtype=torch.float16):
        LM.__init__(self)
        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
        # Mamba checkpoints ship without a tokenizer; the models were trained
        # with the GPT-NeoX-20B tokenizer, so load it explicitly.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size
        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length
        self._device = torch.device(device)

    @property
    def batch_size(self):
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # Generation-based tasks are not supported by this wrapper; only
        # loglikelihood-style tasks, handled by the inherited HFLM methods,
        # can be evaluated.
        raise NotImplementedError()
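    # A possible implementation sketch, assuming the `generate` method that
    # mamba_ssm's GenerationMixin provides (untested; argument names may
    # differ across mamba_ssm versions, and `stop` strings are not handled):
    #
    #   def _model_generate(self, context, max_length, stop, **generation_kwargs):
    #       return self._model.generate(
    #           input_ids=context,
    #           max_length=max_length,
    #           eos_token_id=self.tokenizer.eos_token_id,
    #           **generation_kwargs,
    #       )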


if __name__ == "__main__":
    cli_evaluate()
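# Typical invocation (a sketch; the task list and checkpoint size are
# illustrative, not prescriptive):
#
#   python lm_harness_eval.py --model mamba \
#       --model_args pretrained=state-spaces/mamba-2.8b \
#       --tasks lambada_openai,hellaswag,piqa,arc_easy \
#       --device cuda --batch_size 64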