Spaces:
Sleeping
Sleeping
File size: 1,287 Bytes
306b4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
import transformers
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate
@register_model("mamba")
class MambaEvalWrapper(HFLM):
    """lm-eval model wrapper around ``MambaLMHeadModel``.

    The class is registered with the harness under the name ``"mamba"``.
    It subclasses ``HFLM`` but loads the weights through ``mamba_ssm``
    and always uses the ``EleutherAI/gpt-neox-20b`` tokenizer.
    """

    # Class-level hook read by HFLM (presumably to select the model family;
    # confirm against the installed lm-eval version).
    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(
        self,
        pretrained="state-spaces/mamba-2.8b",
        max_length=2048,
        batch_size=None,
        device="cuda",
        dtype=torch.float16,
    ):
        """Load the Mamba checkpoint and its tokenizer.

        Args:
            pretrained: Checkpoint name or path handed to
                ``MambaLMHeadModel.from_pretrained``.
            max_length: Maximum sequence length stored on the wrapper.
                Numeric strings are converted to ``int``; ``None`` is
                kept as ``None``.
            batch_size: Evaluation batch size; falls back to 64 when
                ``None``. Converted with ``int()`` otherwise.
            device: Device specifier used both for loading the model and
                for ``torch.device``.
            dtype: A ``torch.dtype`` or its name as a string (for example
                ``"bfloat16"`` or ``"torch.bfloat16"``), as produced when
                the value comes from a model-args string.

        Raises:
            ValueError: If ``dtype`` is a string that does not name a
                ``torch.dtype``.
        """
        # Only the base LM initializer runs; HFLM.__init__ is deliberately
        # skipped (presumably because it would load the model through
        # transformers instead of mamba_ssm), so the attributes below are
        # assigned by hand.
        LM.__init__(self)

        # Arguments parsed from a command-line string arrive as str, so
        # resolve a dtype name to the actual torch.dtype object.
        if isinstance(dtype, str):
            resolved = getattr(torch, dtype.split(".")[-1], None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Unknown torch dtype: {dtype!r}")
            dtype = resolved

        self._model = MambaLMHeadModel.from_pretrained(
            pretrained, device=device, dtype=dtype
        )

        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        # Reuse the end-of-sequence token id as the padding token id.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = int(batch_size) if batch_size is not None else 64
        # Coerce like batch_size so a string value does not leak through.
        self._max_length = int(max_length) if max_length is not None else None
        self._device = torch.device(device)

    @property
    def batch_size(self):
        """Return the batch size chosen at construction time."""
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generation is not supported by this wrapper.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
if __name__ == "__main__":
    # Run as a script: hand control to the lm-eval command-line entry
    # point imported from lm_eval.__main__.
    cli_evaluate()
|