kenken999's picture
df
e91d22b
raw
history blame
8.38 kB
from langchain.chains.openai_functions import (
create_structured_output_runnable)
from utils.config import get_llm, load_prompt
from langchain_community.callbacks import get_openai_callback
import asyncio
from langchain.chains import LLMChain
import importlib
from pathlib import Path
from tqdm import trange, tqdm
import concurrent.futures
import logging
class DummyCallback:
    """
    No-op stand-in for a cost-tracking callback context manager.
    It lets calling code keep a single `with ... as cb` flow and read `cb.total_cost`
    even when there is no real callback to use.
    """

    def __enter__(self):
        # Entering (re)sets the cost to zero; this object never records any cost itself.
        self.total_cost = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) means an exception raised inside the with-body propagates.
        return None
def get_dummy_callback():
    """
    Factory for the no-op callback.
    :return: A newly created DummyCallback instance
    """
    dummy = DummyCallback()
    return dummy
class ChainWrapper:
    """
    A wrapper for a LLM chain.
    Builds the chain from an LLM config and a prompt file, runs it on single inputs or batches,
    and accumulates the `total_cost` reported by the callback in `accumulate_usage`.
    """

    def __init__(self, llm_config, prompt_path: str, json_schema: dict = None, parser_func=None):
        """
        Initialize a new instance of the ChainWrapper class.
        :param llm_config: The config for the LLM
        :param prompt_path: A path to the prompt file (text file)
        :param json_schema: A dict for the json schema, to get a structured output for the LLM
        :param parser_func: A function to parse the output of the LLM
        """
        self.llm_config = llm_config
        self.llm = get_llm(llm_config)
        self.json_schema = json_schema
        self.parser_func = parser_func
        self.prompt = load_prompt(prompt_path)
        self.build_chain()
        self.accumulate_usage = 0
        # Only the 'OpenAI' type uses the OpenAI callback; every other type gets the dummy one,
        # whose total_cost is always 0.
        if self.llm_config.type == 'OpenAI':
            self.callback = get_openai_callback
        else:
            self.callback = get_dummy_callback

    def invoke(self, chain_input: dict) -> dict:
        """
        Invoke the chain on a single input
        :param chain_input: The input for the chain
        :return: The chain output (passed through parser_func when one is set),
                 or None if the chain or the parser raised an exception
        """
        with self.callback() as cb:
            try:
                result = self.chain.invoke(chain_input)
                if self.parser_func is not None:
                    result = self.parser_func(result)
            except Exception as e:
                # Best effort: a failed sample must not abort a whole batch, but the failure
                # is logged (with traceback) instead of being dropped silently.
                logging.exception('Error in chain invoke: %s', e)
                result = None
            # NOTE(review): this read-modify-write is not synchronized, while batch_invoke
            # calls invoke from several worker threads.
            self.accumulate_usage += cb.total_cost
        return result

    async def retry_operation(self, tasks):
        """
        Wait for a group of awaitables to finish, polling every `retry_interval` seconds
        until `retry_interval * max_retries` seconds have passed (both read from
        llm_config.async_params). Nothing is re-submitted: unfinished tasks are simply
        waited on again, and those still unfinished at the deadline are cancelled.
        :param tasks: Coroutines, tasks or futures to wait for
        :return: The finished tasks, in the same order as the given awaitables
        """
        delay = self.llm_config.async_params.retry_interval
        timeout = delay * self.llm_config.async_params.max_retries
        # asyncio.wait() does not accept bare coroutines on Python 3.11+, so wrap them.
        # ensure_future returns existing tasks/futures unchanged.
        all_tasks = [asyncio.ensure_future(task) for task in tasks]
        if not all_tasks:
            # asyncio.wait() raises ValueError on an empty collection
            return []
        loop = asyncio.get_running_loop()
        end_time = loop.time() + timeout
        finished = set()
        pending = set(all_tasks)
        while True:
            if loop.time() >= end_time:
                print("Timeout reached. Operation incomplete.")
                break
            done, pending = await asyncio.wait(pending, timeout=delay)
            finished.update(done)
            if not pending:
                print("All tasks completed successfully.")
                break
        # Do not leave timed-out work running in the background.
        for task in pending:
            task.cancel()
        # `done` sets are unordered; rebuild the list in input order.
        return [task for task in all_tasks if task in finished]

    async def async_batch_invoke(self, inputs: list[dict]) -> list[dict]:
        """
        Invoke the chain on a batch of inputs in async mode
        :param inputs: A batch of inputs
        :return: The outputs of the inputs that finished before the timeout, in input order
                 (passed through parser_func when one is set). An exception raised by a
                 finished task, or by parser_func, propagates to the caller.
        """
        with self.callback() as cb:
            # The coroutines are wrapped into tasks inside retry_operation, i.e. while the
            # callback context is still active.
            tasks = [self.chain.ainvoke(chain_input) for chain_input in inputs]
            all_res = await self.retry_operation(tasks)
            self.accumulate_usage += cb.total_cost
        if self.parser_func is not None:
            return [self.parser_func(t.result()) for t in all_res]
        return [t.result() for t in all_res]

    def batch_invoke(self, inputs: list[dict], num_workers: int):
        """
        Invoke the chain on a batch of inputs either async or not
        :param inputs: The list of all inputs
        :param num_workers: The number of workers (thread count in the non async mode,
                            batch size in the async mode)
        :return: A list of results, without the None entries
        """
        if 'async_params' not in self.llm_config:  # non async mode, use regular workers
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                with tqdm(total=len(inputs), desc="Processing samples") as pbar:
                    def process_sample_with_progress(sample):
                        result = self.invoke(sample)
                        pbar.update(1)  # Update the progress bar
                        return result

                    all_results = list(executor.map(process_sample_with_progress, inputs))
        else:
            all_results = []
            # Each slice of num_workers inputs is run to completion (or timeout) before the next
            for i in trange(0, len(inputs), num_workers, desc='Predicting'):
                results = asyncio.run(self.async_batch_invoke(inputs[i:i + num_workers]))
                all_results += results
        # invoke() returns None for failed samples; drop them
        all_results = [res for res in all_results if res is not None]
        return all_results

    def build_chain(self):
        """
        Build the chain according to the LLM type.
        A structured-output runnable is used for the 'OpenAI' and 'Azure' types when a
        json schema was given; otherwise a plain LLMChain is used.
        """
        if self.llm_config.type in ('OpenAI', 'Azure') and self.json_schema is not None:
            self.chain = create_structured_output_runnable(self.json_schema, self.llm, self.prompt)
        else:
            self.chain = LLMChain(llm=self.llm, prompt=self.prompt)
def get_chain_metadata(prompt_fn: Path, retrieve_module: bool = False) -> dict:
    """
    Get the metadata of the chain.
    Loads `output_schemes.py` from the prompt's directory and looks up the attributes
    `<prompt stem>_schema` and `<prompt stem>_parser` in it. If the module cannot be loaded
    (missing file or a failing import), the error is printed and both values are None.
    :param prompt_fn: The path to the prompt file
    :param retrieve_module: If True, retrieve the module
    :return: A dict with the keys 'json_schema' and 'parser_func', plus 'module' (the loaded
             module, or None if loading failed) when retrieve_module is True
    """
    # `import importlib` alone does not guarantee that the `util` submodule is loaded.
    import importlib.util

    prompt_directory = str(prompt_fn.parent)
    prompt_name = str(prompt_fn.stem)
    schema_parser = None
    try:
        spec = importlib.util.spec_from_file_location('output_schemes', prompt_directory + '/output_schemes.py')
        schema_parser = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(schema_parser)
    except (ImportError, FileNotFoundError) as e:
        print(f"Error loading module {prompt_directory + '/output_schemes'}: {e}")
        # Do not expose a partially executed module
        schema_parser = None

    # getattr with a default also covers schema_parser being None
    json_schema = getattr(schema_parser, '{}_schema'.format(prompt_name), None)
    parser_func = getattr(schema_parser, '{}_parser'.format(prompt_name), None)
    result = {'json_schema': json_schema, 'parser_func': parser_func}
    if retrieve_module:
        result['module'] = schema_parser
    return result
class MetaChain:
"""
A wrapper for the meta-prompts chain
"""
def __init__(self, config):
"""
Initialize a new instance of the MetaChain class. Loading all the meta-prompts
:param config: An EasyDict configuration
"""
self.config = config
self.initial_chain = self.load_chain('initial')
self.step_prompt_chain = self.load_chain('step_prompt')
self.step_samples = self.load_chain('step_samples')
self.error_analysis = self.load_chain('error_analysis')
def load_chain(self, chain_name: str) -> ChainWrapper:
"""
Load a chain according to the chain name
:param chain_name: The name of the chain
"""
metadata = get_chain_metadata(self.config.meta_prompts.folder / '{}.prompt'.format(chain_name))
return ChainWrapper(self.config.llm, self.config.meta_prompts.folder / '{}.prompt'.format(chain_name),
metadata['json_schema'], metadata['parser_func'])
def calc_usage(self) -> float:
"""
Calculate the usage of all the meta-prompts
:return: The total usage value
"""
return self.initial_chain.accumulate_usage + self.step_prompt_chain.accumulate_usage \
+ self.step_samples.accumulate_usage