|
from .py_generate import PyGenerator |
|
from .rs_generate import RsGenerator |
|
from .go_generate import GoGenerator |
|
from .generator_types import Generator |
|
from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, Samba, GPT4o, GroqBase |
|
|
|
|
|
def generator_factory(lang: str) -> Generator:
    """Build the code generator that corresponds to a language identifier.

    Identifiers are matched exactly (case-sensitive):

    * ``"py"`` / ``"python"``  -> ``PyGenerator``
    * ``"rs"`` / ``"rust"``    -> ``RsGenerator``
    * ``"go"`` / ``"golang"``  -> ``GoGenerator``

    Args:
        lang: Language identifier, in short or long form.

    Returns:
        A newly constructed generator instance for the requested language.

    Raises:
        ValueError: If ``lang`` is not one of the identifiers listed above.
    """
    if lang in ("py", "python"):
        return PyGenerator()
    if lang in ("rs", "rust"):
        return RsGenerator()
    if lang in ("go", "golang"):
        return GoGenerator()
    # Nothing matched: reject rather than fall back to a default language.
    raise ValueError(f"Invalid language for generator: {lang}")
|
|
|
|
|
def model_factory(model_name: str) -> ModelBase:
    """Instantiate the model wrapper selected by ``model_name``.

    Exact, case-sensitive names:

    * ``"gpt-4"``               -> ``GPT4``
    * ``"gpt-4o"``              -> ``GPT4o``
    * ``"samba"``               -> ``Samba``
    * ``"groq"``                -> ``GroqBase``
    * ``"gpt-3.5-turbo-0613"``  -> ``GPT35``
    * ``"starchat"``            -> ``StarChat``

    Prefix-based names (checked only after the exact names above):

    * ``"codellama..."``    -> ``CodeLlama``; when the name contains a dash,
      the text between the first and second dash is passed as ``version``.
    * ``"text-davinci..."`` -> ``GPTDavinci``, given the full ``model_name``.

    Args:
        model_name: Name of the model to construct.

    Returns:
        A newly constructed model instance.

    Raises:
        ValueError: If ``model_name`` matches none of the names or prefixes.
    """
    if model_name == "gpt-4":
        return GPT4()
    if model_name == "gpt-4o":
        return GPT4o()
    if model_name == "samba":
        return Samba()
    if model_name == "groq":
        return GroqBase()
    if model_name == "gpt-3.5-turbo-0613":
        return GPT35()
    if model_name == "starchat":
        return StarChat()

    if model_name.startswith("codellama"):
        # Only the second dash-separated segment is forwarded as the version;
        # without any dash, CodeLlama is built with its own defaults.
        segments = model_name.split("-")
        if len(segments) > 1:
            return CodeLlama(version=segments[1])
        return CodeLlama()

    if model_name.startswith("text-davinci"):
        # The davinci wrapper receives the complete model name unchanged.
        return GPTDavinci(model_name)

    raise ValueError(f"Invalid model name: {model_name}")
|
|