"""Gradio demo: free-form text completion with a pretrained Mamba language model."""

import torch
import torch.nn.functional as F
from einops import rearrange
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Device on which the model weights and the prompt tensors are placed.
device = "cuda"

# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b-slimpj",
    device=device,
    dtype=torch.float16,
)

# Number of tokens added to the prompt length to form ``max_length``.
genlen = 500


def pred(text_in: str) -> str:
    """Generate a continuation of ``text_in`` with the Mamba model.

    Args:
        text_in: Prompt text entered in the Gradio text box.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``
        (decoded from ``out.sequences``, so presumably prompt plus
        continuation -- confirm against ``mamba_ssm``), or an empty string
        when the prompt is empty.
    """
    # NOTE(review): an empty prompt presumably tokenizes to zero tokens, which
    # leaves generation nothing to condition on; short-circuit rather than
    # pass a zero-length sequence to the model.
    if not text_in:
        return ""

    encoded = tokenizer(text_in, return_tensors="pt")
    input_ids = encoded.input_ids.to(device=device)

    # Sampling settings are fixed for the demo; ``max_length`` is the prompt
    # length plus ``genlen``.
    out = model.generate(
        input_ids=input_ids,
        max_length=input_ids.shape[1] + genlen,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.9,
        top_p=0.7,
    )

    text_out = tokenizer.batch_decode(out.sequences.tolist())
    # Only one prompt is submitted per call, so only the first decoded
    # sequence is returned.
    return text_out[0]


demo = gr.Interface(
    title="Mamba: Selective State Space Model",
    description="A demo for [Mamba](https://github.com/state-spaces/mamba) by Albert & Tri.",
    fn=pred,
    inputs="text",
    outputs="text",
)


if __name__ == "__main__":
    demo.launch()