# RegBot4.0/utils/customLLM.py
# Committed by Zwea Htet: "fixed some bugs" (3b7cf58)
# File size: 1.05 kB
from typing import Any, List, Mapping, Optional
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Hugging Face Hub checkpoint that backs the module-level pipeline.
model_name = "bigscience/bloom-560m"

# Tokenizer and causal-LM weights are loaded once, at import time, so every
# user of this module shares the same objects.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No explicit ``config`` is passed: the checkpoint's own configuration is
# used. The earlier ``config='T5Config'`` was a plain string (not a config
# object) and named an architecture unrelated to this BLOOM checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name)

# Sampling-based text-generation pipeline shared by the rest of the module.
pl = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    # device=0,  # GPU device number
    # max_length=512,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
)
class CustomLLM(LLM):
    """LangChain ``LLM`` wrapper around the module-level generation pipeline."""

    # Pipeline object built at module import time (``pl``).
    pipeline = pl

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text passed to the pipeline.
            stop: Optional stop sequences. When given, the completion is cut
                just before the earliest occurrence of any non-empty entry.

        Returns:
            The generated text with the leading ``len(prompt)`` characters
            removed, truncated at the first stop sequence if one occurs.
        """
        generated = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]

        # Only return newly generated text: the pipeline output is expected
        # to start with the prompt, so drop that many leading characters.
        completion = generated[len(prompt):]

        if stop:
            # Empty stop strings are skipped: str.find("") is always 0 and
            # would otherwise erase the whole completion.
            hits = [completion.find(marker) for marker in stop if marker]
            hits = [position for position in hits if position != -1]
            if hits:
                completion = completion[: min(hits)]

        return completion

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM."""
        # The class defines no ``model_name`` attribute, so the module-level
        # checkpoint name is reported instead of ``self.model_name``.
        return {"name_of_model": model_name}

    @property
    def _llm_type(self) -> str:
        """Return the type label of this LLM."""
        return "custom"