import os
import time

from dotenv import load_dotenv

# Note: distutils is deprecated and removed in Python 3.12; on newer Python,
# replace strtobool with a small helper that parses boolean strings.
from distutils.util import strtobool

from llama2_wrapper import LLAMA2_WRAPPER


def main():
    load_dotenv()

    # Read generation settings from the environment, falling back to defaults.
    DEFAULT_SYSTEM_PROMPT = (
        os.getenv("DEFAULT_SYSTEM_PROMPT")
        if os.getenv("DEFAULT_SYSTEM_PROMPT") is not None
        else ""
    )
    MAX_MAX_NEW_TOKENS = (
        int(os.getenv("MAX_MAX_NEW_TOKENS"))
        if os.getenv("MAX_MAX_NEW_TOKENS") is not None
        else 2048
    )
    DEFAULT_MAX_NEW_TOKENS = (
        int(os.getenv("DEFAULT_MAX_NEW_TOKENS"))
        if os.getenv("DEFAULT_MAX_NEW_TOKENS") is not None
        else 1024
    )
    MAX_INPUT_TOKEN_LENGTH = (
        int(os.getenv("MAX_INPUT_TOKEN_LENGTH"))
        if os.getenv("MAX_INPUT_TOKEN_LENGTH") is not None
        else 4000
    )

    MODEL_PATH = os.getenv("MODEL_PATH")
    assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"

    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))
    LOAD_IN_4BIT = bool(strtobool(os.getenv("LOAD_IN_4BIT", "True")))
    LLAMA_CPP = bool(strtobool(os.getenv("LLAMA_CPP", "True")))

    if LLAMA_CPP:
        print("Running on CPU with llama.cpp.")
    else:
        import torch

        if torch.cuda.is_available():
            print("Running on GPU with torch transformers.")
        else:
            print("CUDA not found.")

    config = {
        "model_name": MODEL_PATH,
        "load_in_8bit": LOAD_IN_8BIT,
        "load_in_4bit": LOAD_IN_4BIT,
        "llama_cpp": LLAMA_CPP,
        "MAX_INPUT_TOKEN_LENGTH": MAX_INPUT_TOKEN_LENGTH,
    }

    # Time model initialization.
    tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(config)
    llama2_wrapper.init_tokenizer()
    llama2_wrapper.init_model()
    toc = time.perf_counter()
    print(f"Initialized the model in {toc - tic:0.4f} seconds.")

    example = "Can you explain briefly to me what is the Python programming language?"
    # Positional arguments: message, chat history, system prompt, max new
    # tokens, then sampling parameters (temperature, top_p, top_k).
    generator = llama2_wrapper.run(
        example, [], DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 50
    )

    # Time generation. The script assumes each yield is the cumulative
    # response so far, so the last value drained from the generator is the
    # complete output.
    tic = time.perf_counter()
    response = ""
    try:
        response = next(generator)
    except StopIteration:
        pass
    for response in generator:
        pass
    print(response)
    toc = time.perf_counter()

    output_token_length = llama2_wrapper.get_token_length(response)
    print(f"Generated the output in {toc - tic:0.4f} seconds.")
    print(f"Speed: {output_token_length / (toc - tic):0.4f} tokens/sec.")


if __name__ == "__main__":
    main()
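
# Example .env consumed by this script via load_dotenv(). The values below are
# illustrative only; in particular, MODEL_PATH is a placeholder and must point
# at a model file (or Hugging Face model directory) on your machine.
#
#   MODEL_PATH=./models/llama-2-7b-chat.ggmlv3.q4_0.bin
#   LLAMA_CPP=True
#   LOAD_IN_8BIT=False
#   LOAD_IN_4BIT=False
#   DEFAULT_SYSTEM_PROMPT=You are a helpful assistant.
#   MAX_MAX_NEW_TOKENS=2048
#   DEFAULT_MAX_NEW_TOKENS=1024
#   MAX_INPUT_TOKEN_LENGTH=4000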