|
from .py_generate import PyGenerator |
|
from .generator_types import Generator |
|
from .model import ModelBase, GPT4, GPT35, GPTDavinci, Samba |
|
|
|
def generator_factory(lang: str) -> Generator:
    """Return the code generator for the requested language.

    Args:
        lang: Language identifier. Only ``"py"`` and ``"python"`` are
            accepted (case-sensitive).

    Returns:
        A new ``PyGenerator`` instance.

    Raises:
        ValueError: If ``lang`` is not one of the accepted identifiers.
    """
    if lang not in ("py", "python"):
        raise ValueError(f"Invalid language for generator: {lang}")
    return PyGenerator()
|
|
|
|
|
def model_factory(model_name: str) -> ModelBase:
    """Instantiate the model wrapper that matches ``model_name``.

    Args:
        model_name: One of ``"gpt-4"``, ``"samba"``, ``"gpt-3.5-turbo-0613"``,
            or any name starting with ``"text-davinci"``. Matching is exact
            (case-sensitive) except for the ``"text-davinci"`` prefix.

    Returns:
        A ``GPT4``, ``Samba``, ``GPT35`` or ``GPTDavinci`` instance. Only
        ``GPTDavinci`` is given ``model_name``; the other constructors are
        called with no arguments.

    Raises:
        ValueError: If ``model_name`` matches none of the names above.
    """
    import logging

    # Log at debug level instead of print(): an unconditional print wrote the
    # model name to stdout on every call, polluting output that callers may
    # consume.
    logging.getLogger(__name__).debug("model_factory: requested model %r", model_name)

    if model_name == "gpt-4":
        return GPT4()
    if model_name == "samba":
        return Samba()
    if model_name == "gpt-3.5-turbo-0613":
        return GPT35()
    if model_name.startswith("text-davinci"):
        # Davinci variants share one wrapper, which needs the concrete name.
        return GPTDavinci(model_name)
    raise ValueError(f"Invalid model name: {model_name}")
|
|