# CodeLATS/generators/factory.py
# Last commit 15d89f9 by Etash Guha: "added samba" (file size 1.39 kB)
from .py_generate import PyGenerator
from .rs_generate import RsGenerator
from .go_generate import GoGenerator
from .generator_types import Generator
from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, Samba, GPT4o, GroqBase
def generator_factory(lang: str) -> Generator:
    """Return a code generator for the language identified by ``lang``.

    Each language is accepted under a short and a long identifier:
    ``"py"``/``"python"``, ``"rs"``/``"rust"`` and ``"go"``/``"golang"``.
    Matching is exact string equality, so it is case-sensitive.

    Raises:
        ValueError: if ``lang`` is not one of the identifiers above.
    """
    if lang in ("py", "python"):
        return PyGenerator()
    if lang in ("rs", "rust"):
        return RsGenerator()
    if lang in ("go", "golang"):
        return GoGenerator()
    raise ValueError(f"Invalid language for generator: {lang}")
def model_factory(model_name: str) -> ModelBase:
    """Instantiate the model backend selected by ``model_name``.

    Exact names recognised: ``"gpt-4"``, ``"gpt-4o"``, ``"samba"``,
    ``"groq"``, ``"gpt-3.5-turbo-0613"`` and ``"starchat"``.

    Prefix families:

    * names starting with ``"codellama"`` build ``CodeLlama``. If the name
      contains a dash, the segment after the first dash (up to the next dash,
      if any) is passed as the ``version`` keyword. Otherwise ``CodeLlama``
      is constructed with no arguments.
    * names starting with ``"text-davinci"`` build ``GPTDavinci``, receiving
      the full name.

    Raises:
        ValueError: if ``model_name`` matches none of the above.
    """
    if model_name == "gpt-4":
        return GPT4()
    if model_name == "gpt-4o":
        return GPT4o()
    if model_name == "samba":
        return Samba()
    if model_name == "groq":
        return GroqBase()
    if model_name == "gpt-3.5-turbo-0613":
        return GPT35()
    if model_name == "starchat":
        return StarChat()

    if model_name.startswith("codellama"):
        if "-" not in model_name:
            # No version suffix given: construct CodeLlama with no arguments.
            return CodeLlama()
        # Only the first dash-separated segment after the prefix is used,
        # e.g. "codellama-7b-instruct" yields version "7b".
        version = model_name.split("-")[1]
        return CodeLlama(version=version)

    if model_name.startswith("text-davinci"):
        return GPTDavinci(model_name)

    raise ValueError(f"Invalid model name: {model_name}")