import ast
import asyncio
import copy
import functools
import glob
import gzip
import inspect
import json
import os
import pathlib
import pickle
import shutil
import subprocess
import tempfile
import time
import traceback
import types
import typing
import urllib.error
import uuid
import zipfile
from collections import defaultdict
from datetime import datetime
from functools import reduce
from operator import concat
import filelock
import tabulate
import yaml
from joblib import delayed
from langchain.callbacks import streaming_stdout
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms.huggingface_pipeline import VALID_TASKS
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import LLMResult, Generation
from langchain.tools import PythonREPLTool
from langchain.tools.json.tool import JsonSpec
from tqdm import tqdm
from src.db_utils import length_db1, set_dbid, set_userid, get_dbid, get_userid_direct, get_username_direct, \
set_userid_direct
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_sha, get_short_name, \
get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent
from evaluate_params import gen_hyper, gen_hyper0
from gen import get_model, SEED, get_limited_prompt, get_docs_tokens
from prompter import non_hf_types, PromptType, Prompter
from src.serpapi import H2OSerpAPIWrapper
from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta
import_matplotlib()
import numpy as np
import pandas as pd
import requests
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
# , GCSDirectoryLoader, GCSFileLoader
# , OutlookMessageLoader # GPL3
# ImageCaptionLoader, # use our own wrapper
# ReadTheDocsLoader, # no special file, some path, so have to give as special option
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
UnstructuredExcelLoader, JSONLoader
from langchain.text_splitter import Language
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate, HuggingFaceTextGenInference, HuggingFacePipeline
from langchain.vectorstores import Chroma
from chromamig import ChromaMig
def split_list(input_list, split_size):
    """
    Yield consecutive slices of ``input_list`` holding at most ``split_size`` items each.

    The last slice may be shorter; an empty input yields nothing.

    :param input_list: any sliceable sequence (list, tuple, str, ...)
    :param split_size: maximum length of each slice; must be positive
    :raises ValueError: (on first iteration) if ``split_size`` is not positive.  Zero already
        failed inside ``range``; a negative size used to yield nothing, silently dropping data.
    """
    if split_size <= 0:
        raise ValueError("split_size must be positive, got %s" % split_size)
    for start in range(0, len(input_list), split_size):
        yield input_list[start:start + split_size]
def get_db(sources, use_openai_embedding=False, db_type='faiss',
           persist_directory=None, load_db_if_exists=True,
           langchain_mode='notset',
           langchain_mode_paths=None,
           langchain_mode_types=None,
           collection_name=None,
           hf_embedding_model=None,
           migrate_embedding_model=False,
           auto_migrate_db=False,
           n_jobs=-1):
    """
    Build a vector database from ``sources``, or add them to an existing persisted chroma db.

    :param sources: documents to index; if empty/None nothing is done and None is returned
    :param use_openai_embedding: embed with OpenAI instead of ``hf_embedding_model``
    :param db_type: one of 'faiss', 'weaviate', 'chroma', 'chroma_old'
    :param persist_directory: db location; derived via get_persist_directory() when None
    :param load_db_if_exists: chroma only, forwarded to get_existing_db()
    :param langchain_mode: collection name used when ``collection_name`` is None
    :param langchain_mode_paths: langchain_mode -> user path mapping, forwarded to get_existing_db()
    :param langchain_mode_types: langchain_mode -> type mapping; updated in place with the
        resolved type when ``persist_directory`` is derived here
    :param collection_name: explicit collection name; defaults to ``langchain_mode`` with
        spaces replaced by underscores
    :param hf_embedding_model: embedding model name/object (required)
    :param migrate_embedding_model: forwarded to get_existing_db()
    :param auto_migrate_db: forwarded to get_existing_db()
    :param n_jobs: thread-count hint for chroma's HNSW index; -1/None means automatic
    :return: the vector db, or None if there were no sources
    :raises RuntimeError: for an unknown ``db_type``
    """
    if not sources:
        return None
    # Fresh dicts per call: shared ``{}`` defaults would leak the write to
    # ``langchain_mode_types`` below into every later call relying on the default.
    if langchain_mode_paths is None:
        langchain_mode_paths = {}
    if langchain_mode_types is None:
        langchain_mode_types = {}

    if persist_directory is None:
        langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value)
        persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type)
        langchain_mode_types[langchain_mode] = langchain_type
    assert hf_embedding_model is not None

    # get freshly-determined embedding model
    embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
    assert collection_name is not None or langchain_mode != 'notset'
    if collection_name is None:
        collection_name = langchain_mode.replace(' ', '_')

    # Create vector database
    if db_type == 'faiss':
        from langchain.vectorstores import FAISS
        db = FAISS.from_documents(sources, embedding)
    elif db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions
        from langchain.vectorstores import Weaviate

        if os.getenv('WEAVIATE_URL', None):
            client = _create_local_weaviate_client()
        else:
            client = weaviate.Client(
                embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
            )
        index_name = collection_name.capitalize()
        db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
                                     index_name=index_name)
    elif db_type in ['chroma', 'chroma_old']:
        assert persist_directory is not None
        # use_base already handled when making persist_directory, unless was passed into get_db()
        makedirs(persist_directory, exist_ok=True)

        # see if already actually have persistent db, and deal with possible changes in embedding
        db, use_openai_embedding, hf_embedding_model = \
            get_existing_db(None, persist_directory, load_db_if_exists, db_type,
                            use_openai_embedding,
                            langchain_mode, langchain_mode_paths, langchain_mode_types,
                            hf_embedding_model, migrate_embedding_model, auto_migrate_db,
                            verbose=False,
                            n_jobs=n_jobs)
        if db is None:
            import logging
            logging.getLogger("chromadb").setLevel(logging.ERROR)
            if db_type == 'chroma':
                from chromadb.config import Settings
                settings_extra_kwargs = dict(is_persistent=True)
            else:
                from chromamigdb.config import Settings
                settings_extra_kwargs = dict(chroma_db_impl="duckdb+parquet")
            client_settings = Settings(anonymized_telemetry=False,
                                       persist_directory=persist_directory,
                                       **settings_extra_kwargs)
            if n_jobs in [None, -1]:
                # automatic: OMP_NUM_THREADS if set, else half the CPUs; clamped to 1..8.
                # os.cpu_count() may return None, which previously raised TypeError.
                default_threads = (os.cpu_count() or 1) // 2
                n_jobs = int(os.getenv('OMP_NUM_THREADS', str(default_threads)))
                num_threads = max(1, min(n_jobs, 8))
            else:
                num_threads = max(1, n_jobs)
            collection_metadata = {"hnsw:num_threads": num_threads}
            from_kwargs = dict(embedding=embedding,
                               persist_directory=persist_directory,
                               collection_name=collection_name,
                               client_settings=client_settings,
                               collection_metadata=collection_metadata)
            if db_type == 'chroma':
                import chromadb
                # insert in chunks of the client's max_batch_size
                # (presumably chroma rejects larger single adds -- confirm for the pinned version)
                api = chromadb.PersistentClient(path=persist_directory)
                max_batch_size = api._producer.max_batch_size
                for sources_batch in split_list(sources, max_batch_size):
                    db = Chroma.from_documents(documents=sources_batch, **from_kwargs)
                    db.persist()
            else:
                db = ChromaMig.from_documents(documents=sources, **from_kwargs)
            clear_embedding(db)
            save_embed(db, use_openai_embedding, hf_embedding_model)
        else:
            # then just add
            # doesn't check or change embedding, just saves it in case not saved yet, after persisting
            db, _, _ = add_to_db(db, sources, db_type=db_type,
                                 use_openai_embedding=use_openai_embedding,
                                 hf_embedding_model=hf_embedding_model)
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    # once here, db is not changing and embedding choices in calling functions does not matter
    return db
def _get_unique_sources_in_weaviate(db):
    """
    Collect the distinct ``source`` property of every object in the db's Weaviate class.

    Pages through ``db._client.data_object.get`` 100 objects at a time, passing the id of
    the previous page's last object as ``after``, until an empty page comes back.

    :param db: langchain Weaviate vectorstore (uses its ``_client`` and ``_index_name``)
    :return: set of source strings
    """
    page_size = 100
    data_object = db._client.data_object
    found_sources = set()
    page = data_object.get(class_name=db._index_name, limit=page_size)
    while page['objects']:
        cursor = None
        for obj in page['objects']:
            cursor = obj['id']
            found_sources.add(obj['properties']['source'])
        page = data_object.get(class_name=db._index_name, limit=page_size, after=cursor)
    return found_sources
def del_from_db(db, sources, db_type=None):
    """
    Remove every stored chunk whose metadata ``source`` is in ``sources`` (chroma dbs only).

    Does nothing when ``db`` is None or ``db_type`` is not 'chroma'/'chroma_old'.

    :param db: chroma vectorstore (uses ``_client`` and ``_collection``)
    :param sources: one source string, or a list/tuple/generator of them
        (values of ``Document.metadata['source']``)
    :param db_type: database type string
    """
    is_chroma = db_type in ['chroma', 'chroma_old']
    if not is_chroma or db is None:
        return
    if isinstance(sources, str):
        wanted = {sources}
    else:
        assert isinstance(sources, (list, tuple, types.GeneratorType))
        wanted = set(sources)
    chroma_collection = db._client.get_collection(name=db._collection.name,
                                                  embedding_function=db._collection._embedding_function)
    for source_name in wanted:
        try:
            chroma_collection.delete(where={'source': source_name})
        except KeyError:
            # best-effort delete: a KeyError from the client is ignored
            # (presumably raised when nothing matches -- confirm against chroma version)
            pass
def add_to_db(db, sources, db_type='faiss',
              avoid_dup_by_file=False,
              avoid_dup_by_content=True,
              use_openai_embedding=False,
              hf_embedding_model=None):
    """
    Add documents to an existing vector db, optionally skipping ones already stored.

    :param db: existing vectorstore object
    :param sources: langchain Documents; each needs ``metadata['source']`` for weaviate/chroma,
        and chroma de-duplication also reads ``metadata.get('hashid')``
    :param db_type: 'faiss', 'weaviate', 'chroma' or 'chroma_old'
    :param avoid_dup_by_file: weaviate: skip documents whose source is already stored;
        chroma: unsupported (raises RuntimeError)
    :param avoid_dup_by_content: weaviate: same as by-file (no hash check yet);
        chroma: skip documents whose hashid is already stored, and delete stored chunks of
        any file being re-ingested with new content
    :param use_openai_embedding: recorded via save_embed() (chroma only)
    :param hf_embedding_model: required; recorded via save_embed() (chroma only)
    :return: (db, number of documents added, list of their metadata dicts)
    :raises RuntimeError: unknown ``db_type``, or chroma with ``avoid_dup_by_file``
    """
    assert hf_embedding_model is not None
    num_new_sources = len(sources)
    if not sources:
        return db, num_new_sources, []
    if db_type == 'faiss':
        db.add_documents(sources)
    elif db_type == 'weaviate':
        # FIXME: only control by file name, not hash yet
        if avoid_dup_by_file or avoid_dup_by_content:
            unique_sources = _get_unique_sources_in_weaviate(db)
            sources = [x for x in sources if x.metadata['source'] not in unique_sources]
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []
        db.add_documents(documents=sources)
    elif db_type in ['chroma', 'chroma_old']:
        collection = get_documents(db)
        # files we already have:
        metadata_files = {x['source'] for x in collection['metadatas']}
        if avoid_dup_by_file:
            # Too weak in case file changed content, assume parent shouldn't pass true for this for now
            raise RuntimeError("Not desired code path")
        if avoid_dup_by_content:
            # look at hash, instead of page_content
            # migration: If no hash previously, avoid updating,
            # since don't know if need to update and may be expensive to redo all unhashed files
            metadata_hash_ids = {x['hashid'] for x in collection['metadatas']
                                 if 'hashid' in x and x['hashid'] not in ["None", None]}
            # avoid sources with same hash
            sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
            num_nohash = sum(1 for x in sources if not x.metadata.get('hashid'))
            print("Found %s new sources (%d have no hash in original source,"
                  " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
            # new file names that match existing file names: delete the existing files being overridden
            dup_metadata_files = {x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files}
            print("Removing %s duplicate files from db because ingesting those as new documents" % len(
                dup_metadata_files), flush=True)
            client_collection = db._client.get_collection(name=db._collection.name,
                                                          embedding_function=db._collection._embedding_function)
            for dup_file in dup_metadata_files:
                try:
                    client_collection.delete(where=dict(source=dup_file))
                except KeyError:
                    pass
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []

        if hasattr(db, '_persist_directory'):
            print("Existing db, adding to %s" % db._persist_directory, flush=True)
            # chroma only
            lock_file = get_db_lock_file(db)
            context = filelock.FileLock
        else:
            lock_file = None
            context = NullContext
        persist_directory = getattr(db, '_persist_directory', None)
        with context(lock_file):
            # this is place where add to db, but others maybe accessing db, so lock access.
            # else see RuntimeError: Index seems to be corrupted or unsupported
            if db_type == 'chroma' and persist_directory is not None:
                # Only the new-format chroma client knows max_batch_size.  Opening it on a
                # chroma_old (duckdb+parquet) directory, or without a persist directory, is wrong,
                # so those add everything at once, as get_db() does for chroma_old.
                import chromadb
                api = chromadb.PersistentClient(path=persist_directory)
                max_batch_size = api._producer.max_batch_size
                for sources_batch in split_list(sources, max_batch_size):
                    db.add_documents(documents=sources_batch)
                    db.persist()
            else:
                db.add_documents(documents=sources)
                db.persist()
            clear_embedding(db)
            # save here is for migration, in case old db directory without embedding saved
            save_embed(db, use_openai_embedding, hf_embedding_model)
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    new_sources_metadata = [x.metadata for x in sources]
    return db, num_new_sources, new_sources_metadata
def create_or_update_db(db_type, persist_directory, collection_name,
                        user_path, langchain_type,
                        sources, use_openai_embedding, add_if_exists, verbose,
                        hf_embedding_model, migrate_embedding_model, auto_migrate_db,
                        n_jobs=-1):
    """
    Create a fresh db, or update an existing one, for ``collection_name`` from ``sources``.

    If the persist directory is missing, or ``add_if_exists`` is False, any existing
    directory is deleted first.  For weaviate, the existing class is dropped when
    ``add_if_exists`` is False.  The actual build is delegated to get_db().

    :return: the db returned by get_db() (None if ``sources`` is empty)
    """
    start_fresh = not os.path.isdir(persist_directory) or not add_if_exists
    if start_fresh:
        if os.path.isdir(persist_directory):
            if verbose:
                print("Removing %s" % persist_directory, flush=True)
            remove(persist_directory)
        if verbose:
            print("Generating db", flush=True)

    if db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions

        if os.getenv('WEAVIATE_URL', None):
            weaviate_client = _create_local_weaviate_client()
        else:
            weaviate_client = weaviate.Client(
                embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
            )
        weaviate_index = collection_name.replace(' ', '_').capitalize()
        if weaviate_client.schema.exists(weaviate_index) and not add_if_exists:
            weaviate_client.schema.delete_class(weaviate_index)
            if verbose:
                print("Removing %s" % weaviate_index, flush=True)
    # chroma / chroma_old: nothing extra to reset here

    if verbose:
        status_message = "Loading and updating db" if add_if_exists else "Generating db"
        print(status_message, flush=True)

    return get_db(sources,
                  use_openai_embedding=use_openai_embedding,
                  db_type=db_type,
                  persist_directory=persist_directory,
                  langchain_mode=collection_name,
                  langchain_mode_paths={collection_name: user_path},
                  langchain_mode_types={collection_name: langchain_type},
                  hf_embedding_model=hf_embedding_model,
                  migrate_embedding_model=migrate_embedding_model,
                  auto_migrate_db=auto_migrate_db,
                  n_jobs=n_jobs)
from langchain.embeddings import FakeEmbeddings
class H2OFakeEmbeddings(FakeEmbeddings):
    """
    Fake embedding model, but constant instead of random.

    Every document and every query maps to the same vector of ``size`` ones, so results
    are fully deterministic.
    """

    # length of each produced embedding vector
    size: int

    def _get_embedding(self) -> typing.List[float]:
        return [1 for _ in range(self.size)]

    def embed_documents(self, texts: typing.List[str]) -> typing.List[typing.List[float]]:
        return list(map(lambda _text: self._get_embedding(), texts))

    def embed_query(self, text: str) -> typing.List[float]:
        return self._get_embedding()
def get_embedding(use_openai_embedding, hf_embedding_model=None, preload=False):
    """
    Return the embedding model to use for vector databases.

    :param use_openai_embedding: use OpenAIEmbeddings (requires OPENAI_API_KEY in env)
    :param hf_embedding_model: required.  'fake' gives H2OFakeEmbeddings(size=1); a dict is a
        preloaded bundle whose 'model' entry is returned; a str is a HF model name to load;
        any other object is returned unchanged
    :param preload: set on ``embedding.client.preload`` for freshly loaded HF models
    """
    assert hf_embedding_model is not None
    if use_openai_embedding:
        assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
        from langchain.embeddings import OpenAIEmbeddings
        return OpenAIEmbeddings(disallowed_special=())
    if hf_embedding_model == 'fake':
        return H2OFakeEmbeddings(size=1)
    if isinstance(hf_embedding_model, dict):
        # embedding itself preloaded globally
        return hf_embedding_model['model']
    if not isinstance(hf_embedding_model, str):
        # already an embedding object
        return hf_embedding_model

    # imported lazily (to ensure can fork without deadlock)
    from langchain.embeddings import HuggingFaceEmbeddings
    device, _torch_dtype, _context_class = get_device_dtype()
    model_kwargs = dict(device=device)
    if 'instructor' in hf_embedding_model:
        embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
                                                  model_kwargs=model_kwargs,
                                                  encode_kwargs={'normalize_embeddings': True})
    else:
        embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
    embedding.client.preload = preload
    return embedding
def get_answer_from_sources(chain, sources, question):
    """
    Run a question-answering ``chain`` over ``sources`` and return its ``output_text``.

    :param chain: callable langchain chain taking ``input_documents`` and ``question``
    :param sources: documents passed as ``input_documents``
    :param question: the question string
    """
    chain_inputs = {"input_documents": sources, "question": question}
    chain_outputs = chain(chain_inputs, return_only_outputs=True)
    return chain_outputs["output_text"]
"""Wrapper around Huggingface text generation inference API."""
from functools import partial
from typing import Any, Dict, List, Optional, Set, Iterable
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain.llms.base import LLM
class GradioInference(LLM):
    """
    Gradio generation inference API.

    Sends prompts to an h2oGPT gradio server through ``gradio_client`` (endpoint
    ``/submit_nochat_api``, request passed as a stringified dict), optionally streaming
    partial output to a langchain callback, and post-processes replies with
    ``prompter.get_response``.
    """

    inference_server_url: str = ""

    temperature: float = 0.8
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = None
    num_beams: Optional[int] = 1
    max_new_tokens: int = 512
    min_new_tokens: int = 1
    early_stopping: bool = False
    max_time: int = 180
    repetition_penalty: Optional[float] = None
    num_return_sequences: Optional[int] = 1
    do_sample: bool = False
    chat_client: bool = False

    return_full_text: bool = False
    stream_output: bool = False
    sanitize_bot_response: bool = False

    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    client: Any = None
    tokenizer: Any = None

    system_prompt: Any = None
    visible_models: Any = None
    h2ogpt_key: Any = None

    count_input_tokens: Any = 0
    count_output_tokens: Any = 0

    min_max_new_tokens: Any = 256

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create a ``gradio_client.Client`` for ``inference_server_url`` if no client was given."""
        try:
            if values['client'] is None:
                import gradio_client
                values["client"] = gradio_client.Client(
                    values["inference_server_url"]
                )
        except ImportError:
            raise ImportError(
                "Could not import gradio_client python package. "
                "Please install it with `pip install gradio_client`."
            )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gradio_inference"

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """
        Generate a reply for ``prompt`` on the gradio server and return the post-processed text.

        ``stop`` is not sent; stopping is left to the server.  Input and output token
        counts are accumulated on the instance.
        """
        # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
        # so server should get prompt_type or '', not plain
        # This is good, so gradio server can also handle stopping.py conditions
        # this is different than TGI server that uses prompter to inject prompt_type prompting
        stream_output = self.stream_output
        gr_client = self.client

        client_langchain_mode = 'Disabled'
        client_add_chat_history_to_context = True
        client_add_search_to_context = False
        client_chat_conversation = []
        client_langchain_action = LangChainAction.QUERY.value
        client_langchain_agents = []
        top_k_docs = 1
        chunk = True
        chunk_size = 512
        client_kwargs = dict(instruction=prompt if self.chat_client else '',  # only for chat=True
                             iinput=self.iinput if self.chat_client else '',  # only for chat=True
                             context=self.context,
                             # streaming output is supported, loops over and outputs each generation in streaming mode
                             # but leave stream_output=False for simple input/output mode
                             stream_output=stream_output,
                             prompt_type=self.prompter.prompt_type,
                             prompt_dict='',

                             temperature=self.temperature,
                             top_p=self.top_p,
                             top_k=self.top_k,
                             num_beams=self.num_beams,
                             max_new_tokens=self.max_new_tokens,
                             min_new_tokens=self.min_new_tokens,
                             early_stopping=self.early_stopping,
                             max_time=self.max_time,
                             repetition_penalty=self.repetition_penalty,
                             num_return_sequences=self.num_return_sequences,
                             do_sample=self.do_sample,
                             chat=self.chat_client,

                             instruction_nochat=prompt if not self.chat_client else '',
                             iinput_nochat=self.iinput if not self.chat_client else '',
                             langchain_mode=client_langchain_mode,
                             add_chat_history_to_context=client_add_chat_history_to_context,
                             langchain_action=client_langchain_action,
                             langchain_agents=client_langchain_agents,
                             top_k_docs=top_k_docs,
                             chunk=chunk,
                             chunk_size=chunk_size,
                             document_subset=DocumentSubset.Relevant.name,
                             document_choice=[DocumentChoice.ALL.value],
                             pre_prompt_query=None,
                             prompt_query=None,
                             pre_prompt_summary=None,
                             prompt_summary=None,
                             system_prompt=self.system_prompt,
                             image_loaders=None,  # don't need to further do doc specific things
                             pdf_loaders=None,  # don't need to further do doc specific things
                             url_loaders=None,  # don't need to further do doc specific things
                             jq_schema=None,  # don't need to further do doc specific things
                             visible_models=self.visible_models,
                             h2ogpt_key=self.h2ogpt_key,
                             add_search_to_context=client_add_search_to_context,
                             chat_conversation=client_chat_conversation,
                             text_context_list=None,
                             docs_ordering_type=None,
                             min_max_new_tokens=self.min_max_new_tokens,
                             )
        api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
        self.count_input_tokens += self.get_num_tokens(prompt)

        if not stream_output:
            res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
            res_dict = ast.literal_eval(res)
            text = res_dict['response']
            ret = self.prompter.get_response(prompt + text, prompt=prompt,
                                             sanitize_bot_response=self.sanitize_bot_response)
            self.count_output_tokens += self.get_num_tokens(ret)
            return ret

        text_callback = None
        if run_manager:
            text_callback = partial(
                run_manager.on_llm_new_token, verbose=self.verbose
            )

        job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
        text0 = ''  # post-processed text already delivered to the callback
        while not job.done():
            if job.communicator.job.latest_status.code.name == 'FINISHED':
                break
            e = job.future._exception
            if e is not None:
                break
            outputs_list = job.communicator.job.outputs
            if outputs_list:
                res_dict = ast.literal_eval(outputs_list[-1])
                text = self.prompter.get_response(prompt + res_dict['response'], prompt=prompt,
                                                  sanitize_bot_response=self.sanitize_bot_response)
                # FIXME: derive chunk from full for now
                text_chunk = text[len(text0):]
                if text_chunk:
                    text0 = text
                    if text_callback:
                        text_callback(text_chunk)
            # Always yield between polls: previously ``continue`` on an empty chunk skipped
            # this sleep and busy-spun the CPU while the server was still generating.
            time.sleep(0.01)

        # ensure get last output to avoid race
        res_all = job.outputs()
        if len(res_all) > 0:
            res_dict = ast.literal_eval(res_all[-1])
            # post-process exactly like streamed partials so the remaining chunk is sliced
            # against comparable text (raw vs. processed text was compared before)
            text = self.prompter.get_response(prompt + res_dict['response'], prompt=prompt,
                                              sanitize_bot_response=self.sanitize_bot_response)
        else:
            # go with last streamed text if no final output; it is already post-processed
            text = text0
        text_chunk = text[len(text0):]
        if text_callback:
            text_callback(text_chunk)
        self.count_output_tokens += self.get_num_tokens(text)
        return text

    def get_token_ids(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)
        # avoid base method that is not aware of how to properly tokenize (uses GPT2)
        # return _get_token_ids_default_method(text)
class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
    """
    HuggingFace text-generation-inference (TGI) client with h2oGPT prompting.

    TGI does not apply a prompt template, so each prompt is truncated to the tokenizer's
    limit and wrapped with ``prompter.generate_prompt`` before being sent.  Replies are cut
    at the first stop sequence and post-processed with ``prompter.get_response``.
    """

    max_new_tokens: int = 512
    do_sample: bool = False
    top_k: Optional[int] = None
    top_p: Optional[float] = 0.95
    typical_p: Optional[float] = 0.95
    temperature: float = 0.8
    repetition_penalty: Optional[float] = None
    return_full_text: bool = False
    stop_sequences: List[str] = Field(default_factory=list)
    seed: Optional[int] = None
    inference_server_url: str = ""
    timeout: int = 300
    headers: dict = None
    stream_output: bool = False
    sanitize_bot_response: bool = False
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    tokenizer: Any = None
    async_sem: Any = None
    count_input_tokens: Any = 0
    count_output_tokens: Any = 0

    def _merged_stop(self, stop):
        """
        Return ``stop`` followed by ``self.stop_sequences``, de-duplicated in first-seen order.

        Always builds a new list.  The caller's ``stop`` used to be extended in place with
        ``+=``, so a list reused across calls (e.g. by _agenerate for every prompt) kept growing.
        """
        merged = list(stop or []) + list(self.stop_sequences or [])
        return list(dict.fromkeys(merged))

    def _prepare_prompt(self, prompt):
        """Truncate ``prompt`` to the tokenizer's limit and wrap it with the prompt template."""
        # HF inference server needs control over input tokens
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

        # NOTE: TGI server does not add prompting, so must do here
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        return self.prompter.generate_prompt(data_point)

    @staticmethod
    def _truncate_at_stop(gen_text, stop):
        """Cut ``gen_text`` at the earliest occurrence of any stop sequence."""
        for stop_seq in stop:
            if stop_seq in gen_text:
                gen_text = gen_text[:gen_text.index(stop_seq)]
        return gen_text

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Generate a reply synchronously (optionally streaming tokens to ``run_manager``)."""
        stop = self._merged_stop(stop)
        prompt = self._prepare_prompt(prompt)
        self.count_input_tokens += self.get_num_tokens(prompt)

        gen_server_kwargs = dict(do_sample=self.do_sample,
                                 stop_sequences=stop,
                                 max_new_tokens=self.max_new_tokens,
                                 top_k=self.top_k,
                                 top_p=self.top_p,
                                 typical_p=self.typical_p,
                                 temperature=self.temperature,
                                 repetition_penalty=self.repetition_penalty,
                                 return_full_text=self.return_full_text,
                                 seed=self.seed,
                                 )
        gen_server_kwargs.update(kwargs)

        # lower bound because client is re-used if multi-threading
        self.client.timeout = max(300, self.timeout)

        if not self.stream_output:
            res = self.client.generate(
                prompt,
                **gen_server_kwargs,
            )
            if self.return_full_text:
                gen_text = res.generated_text[len(prompt):]
            else:
                gen_text = res.generated_text
            # remove stop sequences from the end of the generated text
            gen_text = self._truncate_at_stop(gen_text, stop)
            text = self.prompter.get_response(prompt + gen_text, prompt=prompt,
                                              sanitize_bot_response=self.sanitize_bot_response)
        else:
            text_callback = None
            if run_manager:
                text_callback = partial(
                    run_manager.on_llm_new_token, verbose=self.verbose
                )
            # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
            if text_callback:
                text_callback(prompt)
            text = ""
            # Note: Streaming ignores return_full_text=True
            for response in self.client.generate_stream(prompt, **gen_server_kwargs):
                text_chunk = response.token.text
                text += text_chunk
                text = self.prompter.get_response(prompt + text, prompt=prompt,
                                                  sanitize_bot_response=self.sanitize_bot_response)
                # stream part
                if any(stop_seq in text_chunk for stop_seq in stop):
                    break
                if not response.token.special and text_callback:
                    text_callback(text_chunk)
        self.count_output_tokens += self.get_num_tokens(text)
        return text

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Async generation via the parent's ``_acall`` with the same prompting/stop handling."""
        stop = self._merged_stop(stop)
        prompt = self._prepare_prompt(prompt)

        gen_text = await super()._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)

        # remove stop sequences from the end of the generated text
        gen_text = self._truncate_at_stop(gen_text, stop)
        return self.prompter.get_response(prompt + gen_text, prompt=prompt,
                                          sanitize_bot_response=self.sanitize_bot_response)

    async def _agenerate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts concurrently (bounded by ``async_sem``)."""
        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
        self.count_input_tokens += sum(self.get_num_tokens(prompt) for prompt in prompts)
        tasks = [
            asyncio.ensure_future(self._agenerate_one(prompt, stop=stop, run_manager=run_manager,
                                                      new_arg_supported=new_arg_supported, **kwargs))
            for prompt in prompts
        ]
        texts = await asyncio.gather(*tasks)
        self.count_output_tokens += sum(self.get_num_tokens(text) for text in texts)
        return LLMResult(generations=[[Generation(text=text)] for text in texts])

    async def _agenerate_one(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            new_arg_supported=None,
            **kwargs: Any,
    ) -> str:
        async with self.async_sem:  # semaphore limits num of simultaneous downloads
            return await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) \
                if new_arg_supported else \
                await self._acall(prompt, stop=stop, **kwargs)

    def get_token_ids(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)
        # avoid base method that is not aware of how to properly tokenize (uses GPT2)
        # return _get_token_ids_default_method(text)
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Replicate
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
update_token_usage
class H2OOpenAI(OpenAI):
    """
    New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here
    Handles prompting that OpenAI doesn't need, stopping as well
    """

    stop_sequences: Any = None
    sanitize_bot_response: bool = False
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    tokenizer: Any = None

    @classmethod
    def _all_required_field_names(cls) -> Set:
        _all_required_field_names = super(OpenAI, cls)._all_required_field_names()
        _all_required_field_names.update(
            {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
             'tokenizer', 'logit_bias'})
        return _all_required_field_names

    def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> LLMResult:
        """
        Truncate and template ``prompts``, then complete them on the OpenAI-compatible server.

        When streaming (single prompt only), the returned text is the prompter-processed
        response rather than the raw completion.
        """
        # stop_sequences first, then call-time stops, de-duplicated in first-seen order.
        # stop_sequences defaults to None, which previously raised TypeError here.
        stop = list(dict.fromkeys(list(self.stop_sequences or []) + list(stop or [])))

        # HF inference server needs control over input tokens
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        templated_prompts = []
        for prompt in prompts:
            prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
            # NOTE: OpenAI/vLLM server does not add prompting, so must do here
            data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
            templated_prompts.append(self.prompter.generate_prompt(data_point))
        # use a new list: the caller's ``prompts`` used to be overwritten in place
        prompts = templated_prompts

        params = {**self._invocation_params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        text = ''
        for _prompts in sub_prompts:
            if self.streaming:
                text_with_prompt = ""
                prompt = _prompts[0]
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                first = True
                for stream_resp in completion_with_retry(
                        self, prompt=_prompts, **params
                ):
                    if first:
                        # prompter.get_response expects the prompt in front of the generation
                        stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
                        first = False
                    text_chunk = stream_resp["choices"][0]["text"]
                    text_with_prompt += text_chunk
                    text = self.prompter.get_response(text_with_prompt, prompt=prompt,
                                                      sanitize_bot_response=self.sanitize_bot_response)
                    if run_manager:
                        run_manager.on_llm_new_token(
                            text_chunk,
                            verbose=self.verbose,
                            logprobs=stream_resp["choices"][0]["logprobs"],
                        )
                    _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = completion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        if self.streaming:
            choices[0]['text'] = text
        return self.create_llm_result(choices, prompts, token_usage)

    def get_token_ids(self, text: str) -> List[int]:
        if self.tokenizer is not None:
            return self.tokenizer.encode(text)
        else:
            # OpenAI uses tiktoken
            return super().get_token_ids(text)
class H2OReplicate(Replicate):
    """Replicate LLM whose prompts are truncated to the tokenizer's limit before sending."""

    stop_sequences: Any = None
    sanitize_bot_response: bool = False
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''
    tokenizer: Any = None

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Call to replicate endpoint."""
        # stop_sequences first, then call-time stops, de-duplicated in first-seen order.
        # stop_sequences defaults to None, which previously raised TypeError here.
        stop = list(dict.fromkeys(list(self.stop_sequences or []) + list(stop or [])))

        # HF inference server needs control over input tokens
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        prompt, _num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

        # Note Replicate handles the prompting of the specific model
        return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)

    def get_token_ids(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)
        # avoid base method that is not aware of how to properly tokenize (uses GPT2)
        # return _get_token_ids_default_method(text)
class H2OChatOpenAI(ChatOpenAI):
    """ChatOpenAI that also treats sampling parameters as declared (required) field names."""

    @classmethod
    def _all_required_field_names(cls) -> Set:
        # super(ChatOpenAI, cls) resolves the implementation after ChatOpenAI in the MRO,
        # bypassing any ChatOpenAI-level override.
        field_names = super(ChatOpenAI, cls)._all_required_field_names()
        field_names |= {'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'}
        return field_names
class H2OAzureChatOpenAI(AzureChatOpenAI):
    """AzureChatOpenAI that also treats common sampling parameters as required field names."""

    @classmethod
    def _all_required_field_names(cls) -> Set:
        # super(AzureChatOpenAI, cls) deliberately skips AzureChatOpenAI's own override.
        names = super(AzureChatOpenAI, cls)._all_required_field_names()
        names.update(('top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'))
        return names
class H2OAzureOpenAI(AzureOpenAI):
    """AzureOpenAI (completion) that also treats common sampling parameters as required field names."""

    @classmethod
    def _all_required_field_names(cls) -> Set:
        # super(AzureOpenAI, cls) deliberately skips AzureOpenAI's own override.
        names = super(AzureOpenAI, cls)._all_required_field_names()
        names.update(('top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'))
        return names
class H2OHuggingFacePipeline(HuggingFacePipeline):
    """HuggingFacePipeline whose ``_call`` forwards ``stop`` to the wrapped pipeline."""

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Run the pipeline on ``prompt`` and extract the generated text for its task.

        :raises ValueError: if the pipeline task is not one of the supported tasks
        """
        response = self.pipeline(prompt, stop=stop)
        task = self.pipeline.task
        # Which key of the first result holds the output, per task
        result_key = {
            "text-generation": "generated_text",
            "text2text-generation": "generated_text",
            "summarization": "summary_text",
        }
        if task not in result_key:
            raise ValueError(
                f"Got invalid task {task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        text = response[0][result_key[task]]
        if task == "text-generation":
            # Text generation return includes the starter text.
            text = text[len(prompt):]
        if not stop:
            return text
        # Post-process: cut the text at the first stop sequence in case generation ran past it.
        return enforce_stop_tokens(text, stop)
def get_llm(use_openai_model=False,
            model_name=None,
            model=None,
            tokenizer=None,
            inference_server=None,
            langchain_only_model=None,
            stream_output=False,
            async_output=True,
            num_async=3,
            do_sample=False,
            temperature=0.1,
            top_k=40,
            top_p=0.7,
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            early_stopping=False,
            max_time=180,
            repetition_penalty=1.0,
            num_return_sequences=1,
            prompt_type=None,
            prompt_dict=None,
            prompter=None,
            context=None,
            iinput=None,
            sanitize_bot_response=False,
            system_prompt='',
            visible_models=0,
            h2ogpt_key=None,
            min_max_new_tokens=None,
            n_jobs=None,
            cli=False,
            llamacpp_dict=None,
            verbose=False,
            ):
    """
    Build the LangChain LLM object for the requested backend.

    Backend selection, first match wins:
      - inference_server starting with 'replicate' -> H2OReplicate
      - use_openai_model, or inference_server starting with 'openai'/'vllm' -> OpenAI/Azure/vLLM wrappers
      - inference_server starting with 'sagemaker' -> H2OSagemakerEndpoint
      - any other non-empty inference_server (must start with 'http') -> GradioInference when ``model`` is a
        GradioClient, else H2OHuggingFaceTextGenInference (``model`` must be a text_generation Client)
      - model_name in non_hf_types -> get_llm_gpt4all
      - model with ``is_exlama()`` returning True -> Exllama
      - otherwise a local HF model wrapped in H2OTextGenerationPipeline + H2OHuggingFacePipeline

    :param n_jobs: thread count for local non-HF models; None or -1 means OMP_NUM_THREADS if set,
                   otherwise half the CPU count (at least 1)
    :param stream_output: if True, attach a streaming callback/streamer where the backend supports it
    :param async_output: requested async mode; forced to False for backends without async support
    :param prompter: Prompter instance; required by most branches (its stop_sequences are read)
    :param system_prompt: for Replicate, None/'None'/'auto' means use prompter.system_prompt
    :param cli: currently unused here
    :return: tuple (llm, model_name, streamer, prompt_type, async_output, only_new_text)
    """
    # currently all but h2oai_pipeline case return prompt + new text, but could change
    only_new_text = False

    if n_jobs in [None, -1]:
        # os.cpu_count() may return None; never fall to 0 threads on a single-CPU host
        default_jobs = max(1, (os.cpu_count() or 2) // 2)
        n_jobs = int(os.getenv('OMP_NUM_THREADS', str(default_jobs)))
    if inference_server is None:
        inference_server = ''

    if inference_server.startswith('replicate'):
        model_string = ':'.join(inference_server.split(':')[1:])
        if 'meta/llama' in model_string:
            # keep a strictly positive temperature for these endpoints (presumably 0 is rejected -- confirm)
            temperature = max(0.01, temperature if do_sample else 0)
        else:
            temperature = temperature if do_sample else 0
        gen_kwargs = dict(temperature=temperature,
                          seed=1234,
                          max_length=max_new_tokens,  # langchain
                          max_new_tokens=max_new_tokens,  # replicate docs
                          top_p=top_p if do_sample else 1,
                          top_k=top_k,  # not always supported
                          repetition_penalty=repetition_penalty)
        if system_prompt in [None, 'None', 'auto']:
            system_prompt = prompter.system_prompt or ''
        if system_prompt:
            gen_kwargs.update(dict(system_prompt=system_prompt))

        # replicate handles prompting, so avoid get_response() filter
        prompter.prompt_type = 'plain'
        replicate_kwargs = dict(model=model_string,
                                input=gen_kwargs,
                                stop=prompter.stop_sequences,
                                stop_sequences=prompter.stop_sequences,
                                sanitize_bot_response=sanitize_bot_response,
                                prompter=prompter,
                                context=context,
                                iinput=iinput,
                                tokenizer=tokenizer,
                                )
        if stream_output:
            callbacks = [StreamingGradioCallbackHandler()]
            streamer = callbacks[0]
            llm = H2OReplicate(streaming=True, callbacks=callbacks, **replicate_kwargs)
        else:
            streamer = None
            llm = H2OReplicate(**replicate_kwargs)
    elif use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
        if use_openai_model and model_name is None:
            model_name = "gpt-3.5-turbo"
        # FIXME: Will later import be ignored? I think so, so should be fine
        openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server)
        kwargs_extra = {}
        if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
            cls = H2OChatOpenAI
            # FIXME: Support context, iinput
            # if inf_type == 'vllm_chat':
            #     kwargs_extra.update(dict(tokenizer=tokenizer))
            openai_api_key = openai.api_key
        elif inf_type == 'openai_azure_chat':
            cls = H2OAzureChatOpenAI
            kwargs_extra.update(dict(openai_api_type='azure'))
            # FIXME: Support context, iinput
            if os.getenv('OPENAI_AZURE_KEY') is not None:
                openai_api_key = os.getenv('OPENAI_AZURE_KEY')
            else:
                openai_api_key = openai.api_key
        elif inf_type == 'openai_azure':
            cls = H2OAzureOpenAI
            kwargs_extra.update(dict(openai_api_type='azure'))
            # FIXME: Support context, iinput
            if os.getenv('OPENAI_AZURE_KEY') is not None:
                openai_api_key = os.getenv('OPENAI_AZURE_KEY')
            else:
                openai_api_key = openai.api_key
        else:
            cls = H2OOpenAI
            if inf_type == 'vllm':
                kwargs_extra.update(dict(stop_sequences=prompter.stop_sequences,
                                         sanitize_bot_response=sanitize_bot_response,
                                         prompter=prompter,
                                         context=context,
                                         iinput=iinput,
                                         tokenizer=tokenizer,
                                         openai_api_base=openai.api_base,
                                         client=None))
            else:
                assert inf_type == 'openai' or use_openai_model
            # defined for both vllm and openai so the cls(...) call below never hits an unbound name
            openai_api_key = openai.api_key

        if deployment_name:
            kwargs_extra.update(dict(deployment_name=deployment_name))
        if api_version:
            kwargs_extra.update(dict(openai_api_version=api_version))
        elif openai.api_version:
            kwargs_extra.update(dict(openai_api_version=openai.api_version))
        elif inf_type in ['openai_azure', 'openai_azure_chat']:
            kwargs_extra.update(dict(openai_api_version="2023-05-15"))
        if base_url:
            kwargs_extra.update(dict(openai_api_base=base_url))
        else:
            kwargs_extra.update(dict(openai_api_base=openai.api_base))

        callbacks = [StreamingGradioCallbackHandler()]
        llm = cls(model_name=model_name,
                  temperature=temperature if do_sample else 0,
                  # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
                  max_tokens=max_new_tokens,
                  top_p=top_p if do_sample else 1,
                  frequency_penalty=0,
                  presence_penalty=1.07 - repetition_penalty + 0.6,  # so good default
                  callbacks=callbacks if stream_output else None,
                  openai_api_key=openai_api_key,
                  logit_bias=None if inf_type == 'vllm' else {},
                  max_retries=6,
                  streaming=stream_output,
                  **kwargs_extra
                  )
        streamer = callbacks[0] if stream_output else None
        if inf_type in ['openai', 'openai_chat', 'openai_azure', 'openai_azure_chat']:
            prompt_type = inference_server
        else:
            # vllm goes here
            prompt_type = prompt_type or 'plain'
    elif inference_server and inference_server.startswith('sagemaker'):
        callbacks = [StreamingGradioCallbackHandler()]  # FIXME: streaming not wired up for sagemaker
        streamer = None

        endpoint_name = ':'.join(inference_server.split(':')[1:2])
        region_name = ':'.join(inference_server.split(':')[2:])

        from sagemaker import H2OSagemakerEndpoint, ChatContentHandler, BaseContentHandler
        if inference_server.startswith('sagemaker_chat'):
            content_handler = ChatContentHandler()
        else:
            content_handler = BaseContentHandler()
        model_kwargs = dict(temperature=temperature if do_sample else 1E-10,
                            return_full_text=False, top_p=top_p, max_new_tokens=max_new_tokens)
        llm = H2OSagemakerEndpoint(
            endpoint_name=endpoint_name,
            region_name=region_name,
            aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
            model_kwargs=model_kwargs,
            content_handler=content_handler,
            endpoint_kwargs={'CustomAttributes': 'accept_eula=true'},
        )
    elif inference_server:
        assert inference_server.startswith(
            'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server

        from gradio_utils.grclient import GradioClient
        from text_generation import Client as HFClient
        if isinstance(model, GradioClient):
            gr_client = model
            hf_client = None
        else:
            gr_client = None
            hf_client = model
            assert isinstance(hf_client, HFClient)

        inference_server, headers = get_hf_server(inference_server)

        # quick sanity check to avoid long timeouts, just see if can reach server
        requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
        callbacks = [StreamingGradioCallbackHandler()]

        if gr_client:
            async_output = False  # FIXME: not implemented yet
            chat_client = False
            llm = GradioInference(
                inference_server_url=inference_server,
                return_full_text=False,

                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                early_stopping=early_stopping,
                max_time=max_time,
                repetition_penalty=repetition_penalty,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                chat_client=chat_client,

                callbacks=callbacks if stream_output else None,
                stream_output=stream_output,
                prompter=prompter,
                context=context,
                iinput=iinput,
                client=gr_client,
                sanitize_bot_response=sanitize_bot_response,
                tokenizer=tokenizer,
                system_prompt=system_prompt,
                visible_models=visible_models,
                h2ogpt_key=h2ogpt_key,
                min_max_new_tokens=min_max_new_tokens,
            )
        elif hf_client:
            # no need to pass original client, no state and fast, so can use same validate_environment from base class
            async_sem = asyncio.Semaphore(num_async) if async_output else NullContext()
            llm = H2OHuggingFaceTextGenInference(
                inference_server_url=inference_server,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                return_full_text=False,  # this only controls internal behavior, still returns processed text
                seed=SEED,

                stop_sequences=prompter.stop_sequences,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                # typical_p=top_p,
                callbacks=callbacks if stream_output else None,
                stream_output=stream_output,
                prompter=prompter,
                context=context,
                iinput=iinput,
                tokenizer=tokenizer,
                timeout=max_time,
                sanitize_bot_response=sanitize_bot_response,
                async_sem=async_sem,
            )
        else:
            raise RuntimeError("No defined client")
        streamer = callbacks[0] if stream_output else None
    elif model_name in non_hf_types:
        async_output = False  # FIXME: not implemented yet
        assert langchain_only_model
        if model_name == 'llama':
            callbacks = [StreamingGradioCallbackHandler()]
            streamer = callbacks[0] if stream_output else None
        else:
            # stream_output = False
            # doesn't stream properly as generator, but at least
            callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
            streamer = None
        if prompter:
            prompt_type = prompter.prompt_type
        else:
            # assume inputted prompt_type is correct
            prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
        from gpt4all_llm import get_llm_gpt4all
        max_max_tokens = tokenizer.model_max_length
        llm = get_llm_gpt4all(model_name,
                              model=model,
                              max_new_tokens=max_new_tokens,
                              temperature=temperature,
                              repetition_penalty=repetition_penalty,
                              top_k=top_k,
                              top_p=top_p,
                              callbacks=callbacks,
                              n_jobs=n_jobs,
                              verbose=verbose,
                              streaming=stream_output,
                              prompter=prompter,
                              context=context,
                              iinput=iinput,
                              max_seq_len=max_max_tokens,
                              llamacpp_dict=llamacpp_dict,
                              )
    elif hasattr(model, 'is_exlama') and model.is_exlama():
        async_output = False  # FIXME: not implemented yet
        assert langchain_only_model
        callbacks = [StreamingGradioCallbackHandler()]
        streamer = callbacks[0] if stream_output else None
        max_max_tokens = tokenizer.model_max_length

        from src.llm_exllama import Exllama
        llm = Exllama(streaming=stream_output,
                      model_path=None,
                      model=model,
                      lora_path=None,
                      temperature=temperature,
                      top_k=top_k,
                      top_p=top_p,
                      typical=.7,
                      beams=1,
                      # beam_length = 40,
                      stop_sequences=prompter.stop_sequences,
                      callbacks=callbacks,
                      verbose=verbose,
                      max_seq_len=max_max_tokens,
                      fused_attn=False,
                      # alpha_value = 1.0, #For use with any models
                      # compress_pos_emb = 4.0, #For use with superhot
                      # set_auto_map = "3, 2" #Gpu split, this will split 3gigs/2gigs
                      prompter=prompter,
                      context=context,
                      iinput=iinput,
                      )
    else:
        async_output = False  # FIXME: not implemented yet
        if model is None:
            # only used if didn't pass model in
            assert tokenizer is None
            prompt_type = 'human_bot'
            if model_name is None:
                model_name = 'h2oai/h2ogpt-oasst1-512-12b'
                # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
                # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
            inference_server = ''
            model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
                                                 inference_server=inference_server, gpu_id=0)

        max_max_tokens = tokenizer.model_max_length
        only_new_text = True
        gen_kwargs = dict(do_sample=do_sample,
                          num_beams=num_beams,
                          max_new_tokens=max_new_tokens,
                          min_new_tokens=min_new_tokens,
                          early_stopping=early_stopping,
                          max_time=max_time,
                          repetition_penalty=repetition_penalty,
                          num_return_sequences=num_return_sequences,
                          return_full_text=not only_new_text,
                          handle_long_generation=None)
        if do_sample:
            gen_kwargs.update(dict(temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p))
            assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
        else:
            assert len(set(gen_hyper0).difference(gen_kwargs.keys())) == 0

        if stream_output:
            skip_prompt = only_new_text
            from gen import H2OTextIteratorStreamer
            decoder_kwargs = {}
            streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
            gen_kwargs.update(dict(streamer=streamer))
        else:
            streamer = None

        from h2oai_pipeline import H2OTextGenerationPipeline
        pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
                                         prompter=prompter,
                                         context=context,
                                         iinput=iinput,
                                         prompt_type=prompt_type,
                                         prompt_dict=prompt_dict,
                                         sanitize_bot_response=sanitize_bot_response,
                                         chat=False, stream_output=stream_output,
                                         tokenizer=tokenizer,
                                         # leave some room for 1 paragraph, even if min_new_tokens=0
                                         max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
                                         base_model=model_name,
                                         **gen_kwargs)
        # pipe.task = "text-generation"
        # below makes it listen only to our prompt removal,
        # not built in prompt removal that is less general and not specific for our model
        pipe.task = "text2text-generation"

        llm = H2OHuggingFacePipeline(pipeline=pipe)
    return llm, model_name, streamer, prompt_type, async_output, only_new_text
def get_device_dtype():
    """
    Pick the compute device name, default dtype, and device-context class for this host.

    :return: tuple (device, torch_dtype, context_class) where device is 'cuda' or 'cpu',
             torch_dtype is float16 on cuda and float32 on cpu, and context_class is torch.device
    """
    import torch
    # is_available must be *called*: the bare function object is always truthy
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    # use the plain 'cuda' string: torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU
    device = 'cpu' if n_gpus == 0 else 'cuda'
    context_class = torch.device
    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
    return device, torch_dtype, context_class
def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
    """
    Get wikipedia data from online, caching the raw API JSON in the current directory.
    :param title: Wikipedia page title, e.g. 'Unix'
    :param first_paragraph_only: if True, request only the intro section (exintro=1)
    :param text_limit: if set (and take_head is not None), keep at most this many characters
    :param take_head: True keeps the first text_limit characters, False keeps the last ones
    :return: Document with the page extract and the page URL as "source"
    """
    import urllib.parse
    filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
    # quote the title so characters such as '&' or '#' cannot corrupt the query string
    url = ("https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles=%s"
           % urllib.parse.quote(str(title)))
    if first_paragraph_only:
        url += "&exintro=1"
    if not os.path.isfile(filename):
        data = requests.get(url).json()
        with open(filename, 'wt') as f:
            json.dump(data, f)
    else:
        with open(filename, "rt") as f:
            data = json.load(f)
    # the query returns a dict keyed by page id; take the first (only) page
    page_content = next(iter(data["query"]["pages"].values()))["extract"]
    if take_head is not None and text_limit is not None:
        page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
    title_url = str(title).replace(' ', '_')
    return Document(
        page_content=str(page_content),
        metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
    )
def get_wiki_sources(first_para=True, text_limit=None):
    """
    Get specific named sources from wikipedia
    :param first_para: if True, fetch only each page's intro section
    :param text_limit: optional character limit applied to each page's text
    :return: list of Documents, one per page title
    Page titles come from the WIKI_SOURCES env var as a comma-separated list
    (e.g. "Unix,Linux"), defaulting to Unix, Microsoft_Windows and Linux.
    """
    default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
    env_sources = os.getenv('WIKI_SOURCES')
    if env_sources is None:
        wiki_sources = default_wiki_sources
    else:
        # previously list(env_string) split the value into single characters
        wiki_sources = [x.strip() for x in env_sources.split(',') if x.strip()]
    return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
def get_github_docs(repo_owner, repo_name):
    """
    Access github from specific repo: shallow-clone it and yield its markdown files.
    Only *.md / *.mdx files exactly one directory below the repo root are matched ("*/*.md").
    :param repo_owner: GitHub user/organization name
    :param repo_name: repository name
    :return: generator of Documents whose "source" is the file's GitHub URL at the cloned commit
    """
    with tempfile.TemporaryDirectory() as d:
        repo_url = f"https://github.com/{repo_owner}/{repo_name}.git"
        # argv lists (no shell) so owner/repo strings cannot inject shell syntax
        subprocess.check_call(["git", "clone", "--depth", "1", repo_url, "."], cwd=d)
        git_sha = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        markdown_files = list(repo_path.glob("*/*.md")) + list(repo_path.glob("*/*.mdx"))
        for markdown_file in markdown_files:
            relative_path = markdown_file.relative_to(repo_path).as_posix()
            github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
            with open(markdown_file, "r") as f:
                content = f.read()
            # yield after closing the file so it is not held open while the consumer runs
            yield Document(page_content=str(content), metadata={"source": github_url})
def get_dai_pickle(dest="."):
    """Download dai_docs.pickle from the h2oai/dai_docs HF dataset and copy it to ``dest``."""
    from huggingface_hub import hf_hub_download
    # True for case when locally already logged in with correct token, so don't have to set key
    hub_token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
    downloaded_path = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle',
                                      token=hub_token, repo_type='dataset')
    shutil.copy(downloaded_path, dest)
def get_dai_docs(from_hf=False, get_pickle=True):
    """
    Consume DAI documentation, or consume from public pickle
    :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
    :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
    :return: list of Documents, one per (text, file) pair, with "source" set to the file name.
             Side effect: (re)creates symlinks in the current directory pointing into the docs dir.
    """
    import glob
    import pickle
    if get_pickle:
        get_dai_pickle()

    dai_store = 'dai_docs.pickle'
    dst = "working_dir_docs"
    if not os.path.isfile(dai_store):
        from create_data import setup_dai_docs, rst_to_outputs
        dst = setup_dai_docs(dst=dst, from_hf=from_hf)
        files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))

        basedir = os.path.abspath(os.getcwd())
        try:
            new_outputs = rst_to_outputs(files)
        finally:
            # restore the working directory even if conversion fails
            os.chdir(basedir)
        with open(dai_store, 'wb') as f:
            pickle.dump(new_outputs, f)
    else:
        # NOTE(security): pickle.load can execute arbitrary code; only load a trusted dai_docs.pickle
        with open(dai_store, 'rb') as f:
            new_outputs = pickle.load(f)

    sources = []
    for line, file in new_outputs:
        # gradio requires any linked file to be with app.py
        sym_src = os.path.abspath(os.path.join(dst, file))
        sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
        if os.path.lexists(sym_dst):
            os.remove(sym_dst)
        os.symlink(sym_src, sym_dst)
        itm = Document(page_content=str(line), metadata={"source": file})
        # NOTE: yield has issues when going into db, loses metadata
        sources.append(itm)
    return sources
def get_supported_types():
    """
    List the file extensions recognized for document ingestion (lower-case, no leading dot).
    :return: tuple (non_image_types, image_types, video_types), each a list of extensions;
             image_types comes from Pillow's openable formats (empty without Pillow) minus
             any extension already listed as a video or non-image type
    """
    non_image_types0 = ["pdf", "txt", "csv", "toml", "py", "rst", "xml", "rtf",
                        "md",
                        "html", "mhtml", "htm",
                        "enex", "eml", "epub", "odt", "pptx", "ppt",
                        "zip",
                        "gz",
                        "gzip",
                        "urls",
                        ]
    # "msg", GPL3

    video_types0 = ['WEBM',
                    'MPG', 'MP2', 'MPEG', 'MPE', 'MPV',
                    'OGG',
                    'MP4', 'M4P', 'M4V',
                    'AVI', 'WMV',
                    'MOV', 'QT',
                    'FLV', 'SWF',
                    'AVCHD']
    video_types0 = [x.lower() for x in video_types0]
    if have_pillow:
        from PIL import Image
        exts = Image.registered_extensions()
        excluded = set(video_types0 + non_image_types0)
        # Pillow reports extensions with a leading dot ('.png'); strip it *before* comparing
        # against the dot-less lists above, otherwise nothing would ever be excluded
        openable = {ex[1:] if ex.startswith('.') else ex
                    for ex, fmt in exts.items() if fmt in Image.OPEN}
        image_types0 = sorted(openable - excluded)
    else:
        image_types0 = []
    return non_image_types0, image_types0, video_types0
# Extension lists consumed by the ingestion code (lower-case, without leading dot).
non_image_types, image_types, video_types = get_supported_types()
set_image_types = set(image_types)
# "or True" so it tries to load, e.g. on MAC/Windows, even without libreoffice,
# since the Word/Excel loaders still work without it.
if have_libreoffice or True:
    non_image_types += ["docx", "doc", "xls", "xlsx"]
if have_jq:
    # JSON types are only offered when jq is available
    non_image_types += ["json", "jsonl"]
file_types = [*non_image_types, *image_types]
def try_as_html(file):
    """
    Try treating a file as HTML, as occurs when scraping websites.
    :param file: path to a local file
    :return: list of non-empty Documents loaded via UnstructuredURLLoader, or [] if the
             file cannot be read or contains no HTML elements
    """
    from bs4 import BeautifulSoup
    with open(file, "rt") as f:
        try:
            is_html = bool(BeautifulSoup(f.read(), "html.parser").find())
        except Exception:
            # e.g. UnicodeDecodeError on binary content: treat as not-HTML (best effort),
            # but no longer swallow KeyboardInterrupt/SystemExit like the bare except did
            is_html = False
    if not is_html:
        return []
    # absolute path so a relative name does not become the "host" part of the file:// URL
    file_url = 'file://' + os.path.abspath(file)
    doc1 = UnstructuredURLLoader(urls=[file_url]).load()
    return [x for x in doc1 if x.page_content]
def json_metadata_func(record: dict, metadata: dict) -> dict:
    """Metadata extraction callback for JSON loading; mutates and returns ``metadata``.

    Copies ``sender_name``/``timestamp_ms`` from a dict record (None when absent), and
    mirrors existing ``source``/``seq_num`` entries into ``source_json``/``seq_num_json``.
    """
    if isinstance(record, dict):
        for field in ("sender_name", "timestamp_ms"):
            metadata[field] = record.get(field)
    for src_key, dst_key in (("source", "source_json"), ("seq_num", "seq_num_json")):
        if src_key in metadata:
            metadata[dst_key] = metadata[src_key]
    return metadata
def file_to_doc(file,
filei=0,
base_path=None, verbose=False, fail_any_exception=False,
chunk=True, chunk_size=512, n_jobs=-1,
is_url=False, is_txt=False,
# urls
use_unstructured=True,
use_playwright=False,
use_selenium=False,
# pdfs
use_pymupdf='auto',
use_unstructured_pdf='auto',
use_pypdf='auto',
enable_pdf_ocr='auto',
try_pdf_as_html='auto',
enable_pdf_doctr='auto',
# images
enable_ocr=False,
enable_doctr=False,
enable_pix2struct=False,
enable_captions=True,
captions_model=None,
model_loaders=None,
# json
jq_schema='.[]',
headsize=50, # see also H2OSerpAPIWrapper
db_type=None,
selected_file_types=None):
assert isinstance(model_loaders, dict)
if selected_file_types is not None:
set_image_types1 = set_image_types.intersection(set(selected_file_types))
else:
set_image_types1 = set_image_types
assert db_type is not None
chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
add_meta = functools.partial(_add_meta, headsize=headsize, filei=filei)
# FIXME: if zip, file index order will not be correct if other files involved
path_to_docs_func = functools.partial(path_to_docs,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
chunk=chunk, chunk_size=chunk_size,
# url=file if is_url else None,
# text=file if is_txt else None,
# urls
use_unstructured=use_unstructured,
use_playwright=use_playwright,
use_selenium=use_selenium,
# pdfs
use_pymupdf=use_pymupdf,
use_unstructured_pdf=use_unstructured_pdf,
use_pypdf=use_pypdf,
enable_pdf_ocr=enable_pdf_ocr,
enable_pdf_doctr=enable_pdf_doctr,
try_pdf_as_html=try_pdf_as_html,
# images
enable_ocr=enable_ocr,
enable_doctr=enable_doctr,
enable_pix2struct=enable_pix2struct,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=model_loaders['caption'],
doctr_loader=model_loaders['doctr'],
pix2struct_loader=model_loaders['pix2struct'],
# json
jq_schema=jq_schema,
db_type=db_type,
)
if file is None:
if fail_any_exception:
raise RuntimeError("Unexpected None file")
else:
return []
doc1 = [] # in case no support, or disabled support
if base_path is None and not is_txt and not is_url:
# then assume want to persist but don't care which path used
# can't be in base_path
dir_name = os.path.dirname(file)
base_name = os.path.basename(file)
# if from gradio, will have its own temp uuid too, but that's ok
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
base_path = os.path.join(dir_name, base_name)
if is_url:
file = file.strip() # in case accidental spaces in front or at end
file_lower = file.lower()
case1 = file_lower.startswith('arxiv:') and len(file_lower.split('arxiv:')) == 2
case2 = file_lower.startswith('https://arxiv.org/abs') and len(file_lower.split('https://arxiv.org/abs')) == 2
case3 = file_lower.startswith('http://arxiv.org/abs') and len(file_lower.split('http://arxiv.org/abs')) == 2
case4 = file_lower.startswith('arxiv.org/abs/') and len(file_lower.split('arxiv.org/abs/')) == 2
if case1 or case2 or case3 or case4:
if case1:
query = file.lower().split('arxiv:')[1].strip()
elif case2:
query = file.lower().split('https://arxiv.org/abs/')[1].strip()
elif case2:
query = file.lower().split('http://arxiv.org/abs/')[1].strip()
elif case3:
query = file.lower().split('arxiv.org/abs/')[1].strip()
else:
raise RuntimeError("Unexpected arxiv error for %s" % file)
if have_arxiv:
trials = 3
docs1 = []
for trial in range(trials):
try:
docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
break
except urllib.error.URLError:
pass
if not docs1:
print("Failed to get arxiv %s" % query, flush=True)
# ensure string, sometimes None
[[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
query_url = f"https://arxiv.org/abs/{query}"
[x.metadata.update(
dict(source=x.metadata.get('entry_id', query_url), query=query_url,
input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now))) for x in
docs1]
else:
docs1 = []
else:
if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
file = 'http://' + file
docs1 = []
do_unstructured = only_unstructured_urls or use_unstructured
if only_selenium or only_playwright:
do_unstructured = False
do_playwright = have_playwright and (use_playwright or only_playwright)
if only_unstructured_urls or only_selenium:
do_playwright = False
do_selenium = have_selenium and (use_selenium or only_selenium)
if only_unstructured_urls or only_playwright:
do_selenium = False
if do_unstructured or use_unstructured:
docs1a = UnstructuredURLLoader(urls=[file]).load()
docs1a = [x for x in docs1a if x.page_content]
add_parser(docs1a, 'UnstructuredURLLoader')
docs1.extend(docs1a)
if len(docs1) == 0 and have_playwright or do_playwright:
# then something went wrong, try another loader:
from langchain.document_loaders import PlaywrightURLLoader
docs1a = asyncio.run(PlaywrightURLLoader(urls=[file]).aload())
# docs1 = PlaywrightURLLoader(urls=[file]).load()
docs1a = [x for x in docs1a if x.page_content]
add_parser(docs1a, 'PlaywrightURLLoader')
docs1.extend(docs1a)
if len(docs1) == 0 and have_selenium or do_selenium:
# then something went wrong, try another loader:
# but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException:
# Message: unknown error: cannot find Chrome binary
from langchain.document_loaders import SeleniumURLLoader
from selenium.common.exceptions import WebDriverException
try:
docs1a = SeleniumURLLoader(urls=[file]).load()
docs1a = [x for x in docs1a if x.page_content]
add_parser(docs1a, 'SeleniumURLLoader')
docs1.extend(docs1a)
except WebDriverException as e:
print("No web driver: %s" % str(e), flush=True)
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
add_meta(docs1, file, parser="is_url")
docs1 = clean_doc(docs1)
doc1 = chunk_sources(docs1)
elif is_txt:
base_path = "user_paste"
base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
with open(source_file, "wt") as f:
f.write(file)
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
doc1 = Document(page_content=str(file), metadata=metadata)
add_meta(doc1, file, parser="f.write")
# Bit odd to change if was original text
# doc1 = clean_doc(doc1)
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml') or file.lower().endswith('.htm'):
docs1 = UnstructuredHTMLLoader(file_path=file).load()
add_meta(docs1, file, parser='UnstructuredHTMLLoader')
docs1 = clean_doc(docs1)
doc1 = chunk_sources(docs1, language=Language.HTML)
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
add_meta(docs1, file, parser='UnstructuredWordDocumentLoader')
doc1 = chunk_sources(docs1)
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
docs1 = UnstructuredExcelLoader(file_path=file).load()
add_meta(docs1, file, parser='UnstructuredExcelLoader')
doc1 = chunk_sources(docs1)
elif file.lower().endswith('.odt'):
docs1 = UnstructuredODTLoader(file_path=file).load()
add_meta(docs1, file, parser='UnstructuredODTLoader')
doc1 = chunk_sources(docs1)
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
add_meta(docs1, file, parser='UnstructuredPowerPointLoader')
docs1 = clean_doc(docs1)
doc1 = chunk_sources(docs1)
elif file.lower().endswith('.txt'):
# use UnstructuredFileLoader ?
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
# makes just one, but big one
doc1 = chunk_sources(docs1)
# Bit odd to change if was original text
# doc1 = clean_doc(doc1)
add_meta(doc1, file, parser='TextLoader')
elif file.lower().endswith('.rtf'):
docs1 = UnstructuredRTFLoader(file).load()
add_meta(docs1, file, parser='UnstructuredRTFLoader')
doc1 = chunk_sources(docs1)
elif file.lower().endswith('.md'):
docs1 = UnstructuredMarkdownLoader(file).load()
add_meta(docs1, file, parser='UnstructuredMarkdownLoader')
docs1 = clean_doc(docs1)
doc1 = chunk_sources(docs1, language=Language.MARKDOWN)
elif file.lower().endswith('.enex'):
docs1 = EverNoteLoader(file).load()
add_meta(doc1, file, parser='EverNoteLoader')
doc1 = chunk_sources(docs1)
elif file.lower().endswith('.epub'):
docs1 = UnstructuredEPubLoader(file).load()
add_meta(docs1, file, parser='UnstructuredEPubLoader')
doc1 = chunk_sources(docs1)
elif any(file.lower().endswith(x) for x in set_image_types1):
docs1 = []
if verbose:
print("BEGIN: Tesseract", flush=True)
if have_tesseract and enable_ocr:
# OCR, somewhat works, but not great
docs1a = UnstructuredImageLoader(file, strategy='ocr_only').load()
# docs1a = UnstructuredImageLoader(file, strategy='hi_res').load()
docs1a = [x for x in docs1a if x.page_content]
add_meta(docs1a, file, parser='UnstructuredImageLoader')
docs1.extend(docs1a)
if verbose:
print("END: Tesseract", flush=True)
if have_doctr and enable_doctr:
if verbose:
print("BEGIN: DocTR", flush=True)
if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
if verbose:
print("Reuse DocTR", flush=True)
model_loaders['doctr'].load_model()
else:
if verbose:
print("Fresh DocTR", flush=True)
from image_doctr import H2OOCRLoader
model_loaders['doctr'] = H2OOCRLoader()
model_loaders['doctr'].set_document_paths([file])
docs1c = model_loaders['doctr'].load()
docs1c = [x for x in docs1c if x.page_content]
add_meta(docs1c, file, parser='H2OOCRLoader: %s' % 'DocTR')
# caption didn't set source, so fix-up meta
for doci in docs1c:
doci.metadata['source'] = doci.metadata.get('document_path', file)
doci.metadata['hashid'] = hash_file(doci.metadata['source'])
docs1.extend(docs1c)
if verbose:
print("END: DocTR", flush=True)
if enable_captions:
# BLIP
if verbose:
print("BEGIN: BLIP", flush=True)
if model_loaders['caption'] is not None and not isinstance(model_loaders['caption'], (str, bool)):
# assumes didn't fork into this process with joblib, else can deadlock
if verbose:
print("Reuse BLIP", flush=True)
model_loaders['caption'].load_model()
else:
if verbose:
print("Fresh BLIP", flush=True)
from image_captions import H2OImageCaptionLoader
model_loaders['caption'] = H2OImageCaptionLoader(caption_gpu=model_loaders['caption'] == 'gpu',
blip_model=captions_model,
blip_processor=captions_model)
model_loaders['caption'].set_image_paths([file])
docs1c = model_loaders['caption'].load()
docs1c = [x for x in docs1c if x.page_content]
add_meta(docs1c, file, parser='H2OImageCaptionLoader: %s' % captions_model)
# caption didn't set source, so fix-up meta
for doci in docs1c:
doci.metadata['source'] = doci.metadata.get('image_path', file)
doci.metadata['hashid'] = hash_file(doci.metadata['source'])
docs1.extend(docs1c)
if verbose:
print("END: BLIP", flush=True)
if enable_pix2struct:
# BLIP
if verbose:
print("BEGIN: Pix2Struct", flush=True)
if model_loaders['pix2struct'] is not None and not isinstance(model_loaders['pix2struct'], (str, bool)):
if verbose:
print("Reuse pix2struct", flush=True)
model_loaders['pix2struct'].load_model()
else:
if verbose:
print("Fresh pix2struct", flush=True)
from image_pix2struct import H2OPix2StructLoader
model_loaders['pix2struct'] = H2OPix2StructLoader()
model_loaders['pix2struct'].set_image_paths([file])
docs1c = model_loaders['pix2struct'].load()
docs1c = [x for x in docs1c if x.page_content]
add_meta(docs1c, file, parser='H2OPix2StructLoader: %s' % model_loaders['pix2struct'])
# caption didn't set source, so fix-up meta
for doci in docs1c:
doci.metadata['source'] = doci.metadata.get('image_path', file)
doci.metadata['hashid'] = hash_file(doci.metadata['source'])
docs1.extend(docs1c)
if verbose:
print("END: Pix2Struct", flush=True)
doc1 = chunk_sources(docs1)
elif file.lower().endswith('.msg'):
raise RuntimeError("Not supported, GPL3 license")
# docs1 = OutlookMessageLoader(file).load()
# docs1[0].metadata['source'] = file
elif file.lower().endswith('.eml'):
try:
docs1 = UnstructuredEmailLoader(file).load()
add_meta(docs1, file, parser='UnstructuredEmailLoader')
doc1 = chunk_sources(docs1)
except ValueError as e:
if 'text/html content not found in email' in str(e):
pass
else:
raise
doc1 = [x for x in doc1 if x.page_content]
if len(doc1) == 0:
# e.g. plain/text dict key exists, but not
# doc1 = TextLoader(file, encoding="utf8").load()
docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
docs1 = [x for x in docs1 if x.page_content]
add_meta(docs1, file, parser='UnstructuredEmailLoader text/plain')
doc1 = chunk_sources(docs1)
# elif file.lower().endswith('.gcsdir'):
# doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
# elif file.lower().endswith('.gcsfile'):
# doc1 = GCSFileLoader(project_name, bucket, blob).load()
elif file.lower().endswith('.rst'):
with open(file, "r") as f:
doc1 = Document(page_content=str(f.read()), metadata={"source": file})
add_meta(doc1, file, parser='f.read()')
doc1 = chunk_sources(doc1, language=Language.RST)
elif file.lower().endswith('.json'):
# 10k rows, 100 columns-like parts 4 bytes each
JSON_SIZE_LIMIT = int(os.getenv('JSON_SIZE_LIMIT', str(10 * 10 * 1024 * 10 * 4)))
if os.path.getsize(file) > JSON_SIZE_LIMIT:
raise ValueError(
"JSON file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % JSON_SIZE_LIMIT)
loader = JSONLoader(
file_path=file,
# jq_schema='.messages[].content',
jq_schema=jq_schema,
text_content=False,
metadata_func=json_metadata_func)
doc1 = loader.load()
add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
fix_json_meta(doc1)
elif file.lower().endswith('.jsonl'):
loader = JSONLoader(
file_path=file,
# jq_schema='.messages[].content',
jq_schema=jq_schema,
json_lines=True,
text_content=False,
metadata_func=json_metadata_func)
doc1 = loader.load()
add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
fix_json_meta(doc1)
elif file.lower().endswith('.pdf'):
# migration
if isinstance(use_pymupdf, bool):
if use_pymupdf == False:
use_pymupdf = 'off'
if use_pymupdf == True:
use_pymupdf = 'on'
if isinstance(use_unstructured_pdf, bool):
if use_unstructured_pdf == False:
use_unstructured_pdf = 'off'
if use_unstructured_pdf == True:
use_unstructured_pdf = 'on'
if isinstance(use_pypdf, bool):
if use_pypdf == False:
use_pypdf = 'off'
if use_pypdf == True:
use_pypdf = 'on'
if isinstance(enable_pdf_ocr, bool):
if enable_pdf_ocr == False:
enable_pdf_ocr = 'off'
if enable_pdf_ocr == True:
enable_pdf_ocr = 'on'
if isinstance(try_pdf_as_html, bool):
if try_pdf_as_html == False:
try_pdf_as_html = 'off'
if try_pdf_as_html == True:
try_pdf_as_html = 'on'
doc1 = []
tried_others = False
handled = False
did_pymupdf = False
did_unstructured = False
e = None
if have_pymupdf and (len(doc1) == 0 and use_pymupdf == 'auto' or use_pymupdf == 'on'):
# GPL, only use if installed
from langchain.document_loaders import PyMuPDFLoader
# load() still chunks by pages, but every page has title at start to help
try:
doc1a = PyMuPDFLoader(file).load()
did_pymupdf = True
except BaseException as e0:
doc1a = []
print("PyMuPDFLoader: %s" % str(e0), flush=True)
e = e0
# remove empty documents
handled |= len(doc1a) > 0
doc1a = [x for x in doc1a if x.page_content]
doc1a = clean_doc(doc1a)
add_parser(doc1a, 'PyMuPDFLoader')
doc1.extend(doc1a)
if len(doc1) == 0 and use_unstructured_pdf == 'auto' or use_unstructured_pdf == 'on':
tried_others = True
try:
doc1a = UnstructuredPDFLoader(file).load()
did_unstructured = True
except BaseException as e0:
doc1a = []
print("UnstructuredPDFLoader: %s" % str(e0), flush=True)
e = e0
handled |= len(doc1a) > 0
# remove empty documents
doc1a = [x for x in doc1a if x.page_content]
add_parser(doc1a, 'UnstructuredPDFLoader')
# seems to not need cleaning in most cases
doc1.extend(doc1a)
if len(doc1) == 0 and use_pypdf == 'auto' or use_pypdf == 'on':
tried_others = True
# open-source fallback
# load() still chunks by pages, but every page has title at start to help
try:
doc1a = PyPDFLoader(file).load()
except BaseException as e0:
doc1a = []
print("PyPDFLoader: %s" % str(e0), flush=True)
e = e0
handled |= len(doc1a) > 0
# remove empty documents
doc1a = [x for x in doc1a if x.page_content]
doc1a = clean_doc(doc1a)
add_parser(doc1a, 'PyPDFLoader')
doc1.extend(doc1a)
if not did_pymupdf and ((have_pymupdf and len(doc1) == 0) and tried_others):
# try again in case only others used, but only if didn't already try (2nd part of and)
# GPL, only use if installed
from langchain.document_loaders import PyMuPDFLoader
# load() still chunks by pages, but every page has title at start to help
try:
doc1a = PyMuPDFLoader(file).load()
except BaseException as e0:
doc1a = []
print("PyMuPDFLoader: %s" % str(e0), flush=True)
e = e0
handled |= len(doc1a) > 0
# remove empty documents
doc1a = [x for x in doc1a if x.page_content]
doc1a = clean_doc(doc1a)
add_parser(doc1a, 'PyMuPDFLoader2')
doc1.extend(doc1a)
did_pdf_ocr = False
if len(doc1) == 0 and (enable_pdf_ocr == 'auto' and enable_pdf_doctr != 'on') or enable_pdf_ocr == 'on':
did_pdf_ocr = True
# no did_unstructured condition here because here we do OCR, and before we did not
# try OCR in end since slowest, but works on pure image pages well
doc1a = UnstructuredPDFLoader(file, strategy='ocr_only').load()
handled |= len(doc1a) > 0
# remove empty documents
doc1a = [x for x in doc1a if x.page_content]
add_parser(doc1a, 'UnstructuredPDFLoader ocr_only')
# seems to not need cleaning in most cases
doc1.extend(doc1a)
# Some PDFs return nothing or junk from PDFMinerLoader
if len(doc1) == 0 and enable_pdf_doctr == 'auto' or enable_pdf_doctr == 'on':
if verbose:
print("BEGIN: DocTR", flush=True)
if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
model_loaders['doctr'].load_model()
else:
from image_doctr import H2OOCRLoader
model_loaders['doctr'] = H2OOCRLoader()
model_loaders['doctr'].set_document_paths([file])
doc1a = model_loaders['doctr'].load()
doc1a = [x for x in doc1a if x.page_content]
add_meta(doc1a, file, parser='H2OOCRLoader: %s' % 'DocTR')
handled |= len(doc1a) > 0
# caption didn't set source, so fix-up meta
for doci in doc1a:
doci.metadata['source'] = doci.metadata.get('document_path', file)
doci.metadata['hashid'] = hash_file(doci.metadata['source'])
doc1.extend(doc1a)
if verbose:
print("END: DocTR", flush=True)
if try_pdf_as_html in ['auto', 'on']:
doc1a = try_as_html(file)
add_parser(doc1a, 'try_as_html')
doc1.extend(doc1a)
if len(doc1) == 0:
# if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
if handled:
raise ValueError("%s had no valid text, but meta data was parsed" % file)
else:
raise ValueError("%s had no valid text and no meta data was parsed: %s" % (file, str(e)))
add_meta(doc1, file, parser='pdf')
doc1 = chunk_sources(doc1)
elif file.lower().endswith('.csv'):
CSV_SIZE_LIMIT = int(os.getenv('CSV_SIZE_LIMIT', str(10 * 1024 * 10 * 4)))
if os.path.getsize(file) > CSV_SIZE_LIMIT:
raise ValueError(
"CSV file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % CSV_SIZE_LIMIT)
doc1 = CSVLoader(file).load()
add_meta(doc1, file, parser='CSVLoader')
if isinstance(doc1, list):
# each row is a Document, identify
[x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(doc1)]
if db_type in ['chroma', 'chroma_old']:
# then separate summarize list
sdoc1 = clone_documents(doc1)
[x.metadata.update(dict(chunk_id=-1)) for chunk_id, x in enumerate(sdoc1)]
doc1 = sdoc1 + doc1
elif file.lower().endswith('.py'):
doc1 = PythonLoader(file).load()
add_meta(doc1, file, parser='PythonLoader')
doc1 = chunk_sources(doc1, language=Language.PYTHON)
elif file.lower().endswith('.toml'):
doc1 = TomlLoader(file).load()
add_meta(doc1, file, parser='TomlLoader')
doc1 = chunk_sources(doc1)
elif file.lower().endswith('.xml'):
from langchain.document_loaders import UnstructuredXMLLoader
loader = UnstructuredXMLLoader(file_path=file)
doc1 = loader.load()
add_meta(doc1, file, parser='UnstructuredXMLLoader')
elif file.lower().endswith('.urls'):
with open(file, "r") as f:
urls = f.readlines()
# recurse
doc1 = path_to_docs_func(None, url=urls)
elif file.lower().endswith('.zip'):
with zipfile.ZipFile(file, 'r') as zip_ref:
# don't put into temporary path, since want to keep references to docs inside zip
# so just extract in path where
zip_ref.extractall(base_path)
# recurse
doc1 = path_to_docs_func(base_path)
elif file.lower().endswith('.gz') or file.lower().endswith('.gzip'):
if file.lower().endswith('.gz'):
de_file = file.lower().replace('.gz', '')
else:
de_file = file.lower().replace('.gzip', '')
with gzip.open(file, 'rb') as f_in:
with open(de_file, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
# recurse
doc1 = file_to_doc(de_file,
filei=filei, # single file, same file index as outside caller
base_path=base_path, verbose=verbose, fail_any_exception=fail_any_exception,
chunk=chunk, chunk_size=chunk_size, n_jobs=n_jobs,
is_url=is_url, is_txt=is_txt,
# urls
use_unstructured=use_unstructured,
use_playwright=use_playwright,
use_selenium=use_selenium,
# pdfs
use_pymupdf=use_pymupdf,
use_unstructured_pdf=use_unstructured_pdf,
use_pypdf=use_pypdf,
enable_pdf_ocr=enable_pdf_ocr,
enable_pdf_doctr=enable_pdf_doctr,
try_pdf_as_html=try_pdf_as_html,
# images
enable_ocr=enable_ocr,
enable_doctr=enable_doctr,
enable_pix2struct=enable_pix2struct,
enable_captions=enable_captions,
captions_model=captions_model,
model_loaders=model_loaders,
# json
jq_schema=jq_schema,
headsize=headsize,
db_type=db_type,
selected_file_types=selected_file_types)
else:
raise RuntimeError("No file handler for %s" % os.path.basename(file))
# allow doc1 to be list or not.
if not isinstance(doc1, list):
# If not list, did not chunk yet, so chunk now
docs = chunk_sources([doc1])
elif isinstance(doc1, list) and len(doc1) == 1:
# if list of length one, don't trust and chunk it, chunk_id's will still be correct if repeat
docs = chunk_sources(doc1)
else:
docs = doc1
assert isinstance(docs, list)
return docs
def path_to_doc1(file,
                 filei=0,
                 verbose=False, fail_any_exception=False, return_file=True,
                 chunk=True, chunk_size=512,
                 n_jobs=-1,
                 is_url=False, is_txt=False,
                 # urls
                 use_unstructured=True,
                 use_playwright=False,
                 use_selenium=False,
                 # pdfs
                 use_pymupdf='auto',
                 use_unstructured_pdf='auto',
                 use_pypdf='auto',
                 enable_pdf_ocr='auto',
                 enable_pdf_doctr='auto',
                 try_pdf_as_html='auto',
                 # images
                 enable_ocr=False,
                 enable_doctr=False,
                 enable_pix2struct=False,
                 enable_captions=True,
                 captions_model=None,
                 model_loaders=None,
                 # json
                 jq_schema='.[]',
                 db_type=None,
                 selected_file_types=None):
    """Ingest a single file, URL, or text blob through ``file_to_doc``.

    Any exception raised by ``file_to_doc`` is printed together with its traceback.
    If ``fail_any_exception`` is set it is re-raised. Otherwise the result becomes a
    one-element list holding an empty-content Document whose metadata records the
    source, the exception text, and the traceback, so the caller can report the
    failure.

    :param return_file: when True, the resulting list is pickled into a fresh
        ``<uuid>.tmp.pickle`` under ``temp_path_to_doc1`` and that path is returned
        instead of the list (path_to_docs uses this to collect results from its
        possibly multiprocessing workers).
    :return: list of Documents, or the path of the pickle that holds that list.
    :raises AssertionError: if ``db_type`` is None.
    """
    assert db_type is not None
    # choose the begin/end progress messages for this kind of input once, up front
    if is_url:
        begin_msg, end_msg = "Ingesting URL: %s", "DONE Ingesting URL: %s"
    elif is_txt:
        begin_msg, end_msg = "Ingesting Text: %s", "DONE Ingesting Text: %s"
    else:
        begin_msg, end_msg = "Ingesting file: %s", "DONE Ingesting file: %s"
    if verbose:
        print(begin_msg % file, flush=True)

    try:
        # don't pass base_path=path, would infinitely recurse
        docs = file_to_doc(file,
                           filei=filei,
                           base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
                           chunk=chunk, chunk_size=chunk_size,
                           n_jobs=n_jobs,
                           is_url=is_url, is_txt=is_txt,
                           # urls
                           use_unstructured=use_unstructured,
                           use_playwright=use_playwright,
                           use_selenium=use_selenium,
                           # pdfs
                           use_pymupdf=use_pymupdf,
                           use_unstructured_pdf=use_unstructured_pdf,
                           use_pypdf=use_pypdf,
                           enable_pdf_ocr=enable_pdf_ocr,
                           enable_pdf_doctr=enable_pdf_doctr,
                           try_pdf_as_html=try_pdf_as_html,
                           # images
                           enable_ocr=enable_ocr,
                           enable_doctr=enable_doctr,
                           enable_pix2struct=enable_pix2struct,
                           enable_captions=enable_captions,
                           captions_model=captions_model,
                           model_loaders=model_loaders,
                           # json
                           jq_schema=jq_schema,
                           db_type=db_type,
                           selected_file_types=selected_file_types)
    except BaseException as exc:
        print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
        if fail_any_exception:
            raise
        failure_meta = {"source": file,
                        "exception": '%s Exception: %s' % (file, str(exc)),
                        "traceback": traceback.format_exc()}
        docs = [Document(page_content='', metadata=failure_meta)]

    if verbose:
        print(end_msg % file, flush=True)

    if not return_file:
        return docs
    tmp_dir = "temp_path_to_doc1"
    if not os.path.isdir(tmp_dir):
        tmp_dir = makedirs(tmp_dir, exist_ok=True, tmp_ok=True, use_base=True)
    pickle_path = os.path.join(tmp_dir, str(uuid.uuid4()) + ".tmp.pickle")
    with open(pickle_path, 'wb') as handle:
        pickle.dump(docs, handle)
    return pickle_path
def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
                 chunk=True, chunk_size=512,
                 url=None, text=None,
                 # urls
                 use_unstructured=True,
                 use_playwright=False,
                 use_selenium=False,
                 # pdfs
                 use_pymupdf='auto',
                 use_unstructured_pdf='auto',
                 use_pypdf='auto',
                 enable_pdf_ocr='auto',
                 enable_pdf_doctr='auto',
                 try_pdf_as_html='auto',
                 # images
                 enable_ocr=False,
                 enable_doctr=False,
                 enable_pix2struct=False,
                 enable_captions=True,
                 captions_model=None,
                 caption_loader=None,
                 doctr_loader=None,
                 pix2struct_loader=None,
                 # json
                 jq_schema='.[]',
                 existing_files=None,
                 existing_hash_ids=None,
                 db_type=None,
                 selected_file_types=None,
                 ):
    """Ingest files, URLs, or raw text into one flat list of Documents.

    Exactly one input source is consumed, checked in this order: ``url`` if truthy,
    else ``text`` if truthy, else ``path_or_paths``. ``path_or_paths`` may be:
    - a directory, which is globbed recursively for the allowed extensions; or
    - a file path, or a list / tuple / generator of file paths.
    If none of the three is given, ``[]`` is returned.

    Image files are processed separately from everything else. That keeps image
    work serial when a captioning/DocTR model is on a non-CPU device, or when
    ``get_device()`` is 'cuda'.

    :param selected_file_types: optional allow-list of extensions (without dot).
    :param existing_files: must be empty/None (asserted; the skip-by-name approach is unused).
    :param existing_hash_ids: optional ``{path: hash}``. Paths whose current ``hash_file``
        value equals the recorded one are skipped.
    :param caption_loader, doctr_loader, pix2struct_loader: optional loaded models,
        or str/bool placeholders. Real models that get created during this call are
        unloaded before returning.
    :return: list of Documents. Files that fail to ingest contribute a Document with an
        ``exception`` metadata entry, unless ``fail_any_exception`` is set.
    """
    if verbose:
        print("BEGIN Consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
    if selected_file_types is not None:
        non_image_types1 = [x for x in non_image_types if x in selected_file_types]
        image_types1 = [x for x in image_types if x in selected_file_types]
    else:
        non_image_types1 = non_image_types.copy()
        image_types1 = image_types.copy()
    assert db_type is not None
    # path_or_paths could be str, list, tuple, generator
    globs_image_types = []
    globs_non_image_types = []
    if not path_or_paths and not url and not text:
        return []
    elif url:
        url = get_list_or_str(url)
        # materialize generators: the list is len()'d and iterated more than once below
        globs_non_image_types = list(url) if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
    elif text:
        globs_non_image_types = list(text) if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
    elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
        # single path, only consume allowed files
        path = path_or_paths
        # Below globs should match patterns in file_to_doc()
        for ftype in image_types1:
            globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
        globs_image_types = [os.path.normpath(x) for x in globs_image_types]
        for ftype in non_image_types1:
            globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
        globs_non_image_types = [os.path.normpath(x) for x in globs_non_image_types]
    else:
        if isinstance(path_or_paths, str):
            if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
                path_or_paths = [path_or_paths]
            else:
                # path was deleted etc.
                return []
        # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
        assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
            "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
        # materialize once: a generator would otherwise be exhausted by the first pass below
        paths = list(path_or_paths)
        # reform out of allowed types (extension checked on the raw path, stored normalized)
        for ftype in image_types1:
            globs_image_types.extend(os.path.normpath(x) for x in paths if x.endswith(ftype))
        # Non-image files are deliberately NOT filtered by allowed type: unsupported ones
        # are allowed to fail in file_to_doc so they are reported back to the user.
        set_globs_image_types = set(globs_image_types)
        # compare normalized forms, else e.g. './a.png' lands in both lists and is ingested twice
        globs_non_image_types.extend(
            [os.path.normpath(x) for x in paths if os.path.normpath(x) not in set_globs_image_types])
    # filter out any files to skip (e.g. if already processed them)
    # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
    assert not existing_files, "DEV: assume not using this approach"
    if existing_files:
        set_skip_files = set(existing_files)
        globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
        globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
    if existing_hash_ids:
        # assume consistent with add_meta() use of hash_file(file)
        # also assume consistent with get_existing_hash_ids for dict creation
        # assume hashable values
        existing_hash_ids_set = set(existing_hash_ids.items())
        hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
        hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
        # don't use symmetric diff. If file is gone, ignore and don't remove or something
        # just consider existing files (key) having new hash or not (value)
        new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
        new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
        globs_image_types = [x for x in globs_image_types if x in new_files_image]
        globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]
    # could use generator, but messes up metadata handling in recursive case
    if caption_loader and not isinstance(caption_loader, (bool, str)) and caption_loader.device != 'cpu' or \
            get_device() == 'cuda':
        # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
        # get_device() == 'cuda' because presume faster to process image from (temporarily) preloaded model
        n_jobs_image = 1
    else:
        n_jobs_image = n_jobs
    if enable_doctr or enable_pdf_doctr in [True, 'auto', 'on']:
        if doctr_loader and not isinstance(doctr_loader, (bool, str)) and doctr_loader.device != 'cpu':
            # can't fork cuda context; DocTR also serves images, so keep both pools serial
            n_jobs = 1
            n_jobs_image = 1
    return_file = True  # local choice
    is_url = url is not None
    is_txt = text is not None
    model_loaders = dict(caption=caption_loader,
                         doctr=doctr_loader,
                         pix2struct=pix2struct_loader)
    # snapshot to tell which loaders were instantiated during this call (see unload below)
    model_loaders0 = model_loaders.copy()
    kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
                  return_file=return_file,
                  chunk=chunk, chunk_size=chunk_size,
                  n_jobs=n_jobs,
                  is_url=is_url,
                  is_txt=is_txt,
                  # urls
                  use_unstructured=use_unstructured,
                  use_playwright=use_playwright,
                  use_selenium=use_selenium,
                  # pdfs
                  use_pymupdf=use_pymupdf,
                  use_unstructured_pdf=use_unstructured_pdf,
                  use_pypdf=use_pypdf,
                  enable_pdf_ocr=enable_pdf_ocr,
                  enable_pdf_doctr=enable_pdf_doctr,
                  try_pdf_as_html=try_pdf_as_html,
                  # images
                  enable_ocr=enable_ocr,
                  enable_doctr=enable_doctr,
                  enable_pix2struct=enable_pix2struct,
                  enable_captions=enable_captions,
                  captions_model=captions_model,
                  model_loaders=model_loaders,
                  # json
                  jq_schema=jq_schema,
                  db_type=db_type,
                  selected_file_types=selected_file_types,
                  )
    if n_jobs != 1 and len(globs_non_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_non_image_types)
        )
    else:
        documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in
                     enumerate(tqdm(globs_non_image_types))]
    # do images separately since can't fork after cuda in parent, so can't be parallel
    if n_jobs_image != 1 and len(globs_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        image_documents = ProgressParallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0,
                                           backend='multiprocessing')(
            delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_image_types)
        )
    else:
        image_documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in
                           enumerate(tqdm(globs_image_types))]
    # unload loaders (image loaders, includes enable_pdf_doctr that uses same loader)
    for name, loader in model_loaders.items():
        loader0 = model_loaders0[name]
        real_model_initial = loader0 is not None and not isinstance(loader0, (str, bool))
        real_model_final = loader is not None and not isinstance(loader, (str, bool))
        if not real_model_initial and real_model_final:
            # clear off GPU newly added model
            loader.unload_model()
    # add image docs in
    documents += image_documents
    if return_file:
        # then documents really are files (temp pickles written by path_to_doc1 in this process tree)
        files = documents.copy()
        documents = []
        for fil in files:
            with open(fil, 'rb') as f:
                documents.extend(pickle.load(f))
            # remove temp pickle
            remove(fil)
    else:
        # initializer keeps this working when there were no inputs at all
        documents = reduce(concat, documents, [])
    if verbose:
        print("END consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
    return documents
def prep_langchain(persist_directory,
                   load_db_if_exists,
                   db_type, use_openai_embedding,
                   langchain_mode, langchain_mode_paths, langchain_mode_types,
                   hf_embedding_model,
                   migrate_embedding_model,
                   auto_migrate_db,
                   n_jobs=-1, kwargs_make_db={},
                   verbose=False):
    """
    Prepare the database for ``langchain_mode`` the first time; this may involve downloads.

    The outcome depends on the mode and on what already exists:
    - Intrinsic modes (``langchain_modes_intrinsic``) need no prep, and None is returned.
    - If ``persist_directory`` exists and no user path is configured for this mode,
      the existing db is loaded via ``get_existing_db``.
    - Otherwise, source files for the built-in 'DriverlessAI docs' / 'wiki' modes are
      fetched, and ``make_db`` creates the db or adds changed/new documents.

    NOTE: ``make_db`` receives ``kwargs_make_db`` overlaid with this function's
    ``locals()`` (make_db then keeps only the names ``_make_db`` accepts). Local
    variable names here are therefore part of the call contract; renaming or adding
    locals can change what gets passed. ``kwargs_make_db`` is only copied, never mutated.

    # FIXME: Add github caching then add here
    :return: the db, or None for intrinsic langchain modes
    """
    if os.getenv("HARD_ASSERTS"):
        assert langchain_mode not in ['MyData'], "Should not prep scratch/personal data"
    if langchain_mode in langchain_modes_intrinsic:
        # nothing on disk to prepare for intrinsic modes
        return None
    db_dir_exists = os.path.isdir(persist_directory)
    user_path = langchain_mode_paths.get(langchain_mode)
    if db_dir_exists and user_path is None:
        if verbose:
            print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
        db, use_openai_embedding, hf_embedding_model = \
            get_existing_db(None, persist_directory, load_db_if_exists,
                            db_type, use_openai_embedding,
                            langchain_mode, langchain_mode_paths, langchain_mode_types,
                            hf_embedding_model, migrate_embedding_model, auto_migrate_db,
                            n_jobs=n_jobs)
    else:
        if db_dir_exists and user_path is not None:
            if verbose:
                print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
                    persist_directory, user_path), flush=True)
        elif not db_dir_exists:
            if verbose:
                print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
        # db=None is picked up by locals() below and handed to make_db
        db = None
        if langchain_mode in ['DriverlessAI docs']:
            # FIXME: Could also just use dai_docs.pickle directly and upload that
            get_dai_docs(from_hf=True)
        if langchain_mode in ['wiki']:
            get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
        # locals() deliberately overrides any same-named entries in kwargs_make_db
        langchain_kwargs = kwargs_make_db.copy()
        langchain_kwargs.update(locals())
        db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
    return db
import posthog

# Turn posthog off, and swap its Consumer for an inert stand-in
# (presumably to keep any telemetry from being sent — confirm).
posthog.disabled = True


class FakeConsumer:
    """Do-nothing replacement assigned to ``posthog.Consumer``.

    The constructor accepts and ignores any arguments; every method below is a
    no-op that returns None.
    """

    def __init__(self, *args, **kwargs):
        return None

    def run(self):
        return None

    def pause(self):
        return None

    def upload(self):
        return None

    def next(self):
        return None

    def request(self, batch):
        return None


posthog.Consumer = FakeConsumer
def check_update_chroma_embedding(db,
                                  db_type,
                                  use_openai_embedding,
                                  hf_embedding_model, migrate_embedding_model, auto_migrate_db,
                                  langchain_mode, langchain_mode_paths, langchain_mode_types,
                                  n_jobs=-1):
    """Rebuild a Chroma db if its recorded embedding differs from the requested one.

    The recorded embedding comes from ``load_embed(db=db)``. If it does not match
    ``(use_openai_embedding, hf_embedding_model)``, the rebuild proceeds as follows:
    1. Every document and its metadata are read out of ``db``.
    2. The persist directory is renamed to ``<dir>_<uuid>.bak``.
    3. A fresh db is built in the original location with ``get_db`` and the new embedding.

    :return: ``(db, changed_db)``; ``changed_db`` is True only when a rebuild happened.
    :raises AssertionError: if a rebuild is needed but ``db_type`` is not 'chroma'/'chroma_old'.
        This is checked before anything on disk is touched.
    """
    changed_db = False
    embed_tuple = load_embed(db=db)
    if embed_tuple not in [(True, use_openai_embedding, hf_embedding_model),
                           (False, use_openai_embedding, hf_embedding_model)]:
        print("Detected new embedding %s vs. %s %s, updating db: %s" % (
            use_openai_embedding, hf_embedding_model, embed_tuple, langchain_mode), flush=True)
        # validate before moving the persist directory, so an unsupported db_type
        # cannot leave the db moved aside with nothing rebuilt in its place
        assert db_type in ['chroma', 'chroma_old']
        # handle embedding changes
        db_get = get_documents(db)
        sources = [Document(page_content=page_content, metadata=metadata or {})
                   for page_content, metadata in zip(db_get['documents'], db_get['metadatas'])]
        # delete index, has to be redone (old directory is kept as a .bak backup)
        persist_directory = db._persist_directory
        shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
        load_db_if_exists = False
        db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                    persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
                    langchain_mode=langchain_mode,
                    langchain_mode_paths=langchain_mode_paths,
                    langchain_mode_types=langchain_mode_types,
                    collection_name=None,
                    hf_embedding_model=hf_embedding_model,
                    migrate_embedding_model=migrate_embedding_model,
                    auto_migrate_db=auto_migrate_db,
                    n_jobs=n_jobs,
                    )
        changed_db = True
        print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
    return db, changed_db
def migrate_meta_func(db, langchain_mode):
    """Add a ``chunk_id`` metadata key (default 0) to every document of an old-format db.

    Only the first document's metadata is inspected to decide whether migration is
    needed. Metadata entries that are None are treated as empty dicts.

    :return: ``(db, changed_db)``; ``changed_db`` is True if metadata was rewritten.
    :raises AssertionError: if, after updating, the first document still lacks ``chunk_id``.
    """
    changed_db = False
    db_get = get_documents(db)
    metadatas = db_get['metadatas']
    # just check one doc
    if len(metadatas) > 0 and 'chunk_id' not in (metadatas[0] or {}):
        print("Detected old metadata, adding additional information", flush=True)
        t0 = time.time()
        # handle meta changes; guard against None entries, which would crash .setdefault
        metadatas = [meta if meta is not None else {} for meta in metadatas]
        for meta in metadatas:
            meta.setdefault('chunk_id', 0)
        client_collection = db._client.get_collection(name=db._collection.name,
                                                      embedding_function=db._collection._embedding_function)
        client_collection.update(ids=db_get['ids'], metadatas=metadatas)
        # check
        db_get = get_documents(db)
        assert 'chunk_id' in (db_get['metadatas'][0] or {}), "Failed to add meta"
        changed_db = True
        print("Done updating db for new meta: %s in %s seconds" % (langchain_mode, time.time() - t0), flush=True)
    return db, changed_db
def get_existing_db(db, persist_directory,
                    load_db_if_exists, db_type, use_openai_embedding,
                    langchain_mode, langchain_mode_paths, langchain_mode_types,
                    hf_embedding_model,
                    migrate_embedding_model,
                    auto_migrate_db=False,
                    verbose=False, check_embedding=True, migrate_meta=True,
                    n_jobs=-1):
    """Load (or reuse) a persisted Chroma db and reconcile its embedding and metadata.

    The inputs are returned unchanged unless all of these hold:
    - ``load_db_if_exists`` is set;
    - ``db_type`` is 'chroma' or 'chroma_old';
    - ``persist_directory`` contains either ``chroma.sqlite3``, or an ``index``
      directory (a chromadb<0.4 layout).

    An old layout is either auto-migrated (``auto_migrate_db``) or opened through
    the ``chromamigdb`` package if that is installed.

    When ``db`` is None it is opened from disk, preferring the embedding recorded by
    ``load_embed``. If the search probe fails with the known 768-vs-384 dimension
    mismatch, the db is reopened with all-MiniLM-L6-v2. A given ``db`` is reused as is.

    Afterwards, the db is rebuilt if the embedding changed (``check_embedding``), the
    embedding info is saved, and old metadata is upgraded (``migrate_meta``).

    :return: ``(db, use_openai_embedding, hf_embedding_model)`` as finally used.
    :raises ValueError: for an old-layout db that can neither be auto-migrated nor
        opened with ``chromamigdb``.
    """
    if not (load_db_if_exists and db_type in ['chroma', 'chroma_old'] and os.path.isdir(persist_directory)):
        return db, use_openai_embedding, hf_embedding_model

    if os.path.isfile(os.path.join(persist_directory, 'chroma.sqlite3')):
        must_migrate = False
    elif os.path.isdir(os.path.join(persist_directory, 'index')):
        must_migrate = True
    else:
        return db, use_openai_embedding, hf_embedding_model
    chroma_settings = dict(is_persistent=True)
    use_chromamigdb = False
    if must_migrate:
        if auto_migrate_db:
            print("Detected chromadb<0.4 database, require migration, doing now....", flush=True)
            from chroma_migrate.import_duckdb import migrate_from_duckdb
            import chromadb
            api = chromadb.PersistentClient(path=persist_directory)
            did_migration = migrate_from_duckdb(api, persist_directory)
            assert did_migration, "Failed to migrate chroma collection at %s, see https://docs.trychroma.com/migration for CLI tool" % persist_directory
        elif have_chromamigdb:
            print(
                "Detected chroma<0.4 database but --auto_migrate_db=False, but detected chromamigdb package, so using old database that still requires duckdb",
                flush=True)
            chroma_settings = dict(chroma_db_impl="duckdb+parquet")
            use_chromamigdb = True
        else:
            # reached only when auto_migrate_db is False, so the remedy is to set it True
            raise ValueError(
                "Detected chromadb<0.4 database, require migration, but did not detect chromamigdb package or did not choose auto_migrate_db=True (see FAQ.md)")

    if db is None:
        if verbose:
            print("DO Loading db: %s" % langchain_mode, flush=True)
        got_embedding, use_openai_embedding0, hf_embedding_model0 = load_embed(persist_directory=persist_directory)
        if got_embedding:
            use_openai_embedding, hf_embedding_model = use_openai_embedding0, hf_embedding_model0
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        import logging
        logging.getLogger("chromadb").setLevel(logging.ERROR)
        if use_chromamigdb:
            from chromamigdb.config import Settings
            chroma_class = ChromaMig
        else:
            from chromadb.config import Settings
            chroma_class = Chroma
        client_settings = Settings(anonymized_telemetry=False,
                                   **chroma_settings,
                                   persist_directory=persist_directory)
        db = chroma_class(persist_directory=persist_directory, embedding_function=embedding,
                          collection_name=langchain_mode.replace(' ', '_'),
                          client_settings=client_settings)
        try:
            # cheap probe that surfaces embedding/index dimension mismatches
            db.similarity_search('')
        except BaseException as e:
            # migration when no embed_info
            if 'Dimensionality of (768) does not match index dimensionality (384)' in str(e) or \
                    'Embedding dimension 768 does not match collection dimensionality 384' in str(e):
                hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
                embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
                db = chroma_class(persist_directory=persist_directory, embedding_function=embedding,
                                  collection_name=langchain_mode.replace(' ', '_'),
                                  client_settings=client_settings)
                # should work now, let fail if not
                db.similarity_search('')
                save_embed(db, use_openai_embedding, hf_embedding_model)
            else:
                raise
        if verbose:
            print("DONE Loading db: %s" % langchain_mode, flush=True)
    else:
        if not migrate_embedding_model:
            # OVERRIDE embedding choices if could load embedding info when not migrating
            _got_embedding, use_openai_embedding, hf_embedding_model = load_embed(db=db)
        if verbose:
            print("USING already-loaded db: %s" % langchain_mode, flush=True)
    if check_embedding:
        db_trial, changed_db = check_update_chroma_embedding(db,
                                                             db_type,
                                                             use_openai_embedding,
                                                             hf_embedding_model,
                                                             migrate_embedding_model,
                                                             auto_migrate_db,
                                                             langchain_mode,
                                                             langchain_mode_paths,
                                                             langchain_mode_types,
                                                             n_jobs=n_jobs)
        if changed_db:
            db = db_trial
            # only call persist if really changed db, else takes too long for large db
            if db is not None:
                db.persist()
                clear_embedding(db)
    save_embed(db, use_openai_embedding, hf_embedding_model)
    if migrate_meta and db is not None:
        db_trial, changed_db = migrate_meta_func(db, langchain_mode)
        if changed_db:
            db = db_trial
    return db, use_openai_embedding, hf_embedding_model
def clear_embedding(db):
    """Move a db's embedding model back to CPU to free GPU memory, if the db owns it.

    A no-op when ``db`` is None. Otherwise this requires all of:
    - the db's embedding function has a ``client``;
    - that client exposes ``cpu()``;
    - the client has a ``preload`` attribute that is falsy.
    Models not matching (e.g. a preloaded, shared one) are left where they are.
    A ``RuntimeError`` raised along the way is printed with its traceback rather than
    propagated.
    """
    if db is None:
        return
    # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
    try:
        embedding_function = db._embedding_function
        if not hasattr(embedding_function, 'client'):
            return
        client = embedding_function.client
        # only push back to CPU if each db/user has own embedding model, else if shared share on GPU
        owns_model = hasattr(client, 'cpu') and hasattr(client, 'preload') and not client.preload
        if owns_model:
            client.cpu()
            clear_torch_cache()
    except RuntimeError as e:
        print("clear_embedding error: %s" % ''.join(traceback.format_tb(e.__traceback__)), flush=True)
def make_db(**langchain_kwargs):
    """Call ``_make_db`` with exactly the keyword arguments it accepts.

    Parameters of ``_make_db`` absent from ``langchain_kwargs`` are filled from the
    defaults declared on ``run_qa_db``. Parameters that ``run_qa_db`` gives no
    default stay missing and trip the assertion below instead of being passed as
    the ``inspect.Parameter.empty`` sentinel. Keys ``_make_db`` does not accept are
    dropped.

    :return: whatever ``_make_db`` returns.
    :raises AssertionError: if some ``_make_db`` parameter cannot be supplied.
    """
    func_names = list(inspect.signature(_make_db).parameters)
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    # parameters without a default expose `inspect.Parameter.empty` as `.default`;
    # skip them so that sentinel is never forwarded as if it were a real value
    defaults_db = {k: v.default for k, v in inspect.signature(run_qa_db).parameters.items()
                   if v.default is not inspect.Parameter.empty}
    for k in missing_kwargs:
        if k in defaults_db:
            langchain_kwargs[k] = defaults_db[k]
    # final check for missing
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
    # only keep actual used
    func_name_set = set(func_names)
    langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_name_set}
    return _make_db(**langchain_kwargs)
# file name of the lock guarding each db's embed_info file
embed_lock_name = 'embed.lock'


def get_embed_lock_file(db, persist_directory=None):
    """Return the path of the lock file that guards a db's ``embed_info``.

    The persist directory is ``persist_directory`` when given, else
    ``db._persist_directory``. The lock lives under ``locks/<persist_directory>``
    (created via ``makedirs`` with ``use_base=True``).

    :return: the lock file path, or None when ``db`` has no ``_persist_directory``
        and no (truthy) ``persist_directory`` was passed.
    """
    if not (hasattr(db, '_persist_directory') or persist_directory):
        return None
    if persist_directory is None:
        persist_directory = db._persist_directory
    check_persist_directory(persist_directory)
    lock_dir = makedirs(os.path.join('locks', persist_directory), exist_ok=True, tmp_ok=True, use_base=True)
    lock_path = os.path.join(lock_dir, embed_lock_name)
    makedirs(os.path.dirname(lock_path))
    return lock_path
def save_embed(db, use_openai_embedding, hf_embedding_model):
if hasattr(db, '_persist_directory'):
persist_directory = db._persist_directory
lock_file = get_embed_lock_file(db)
with filelock.FileLock(lock_file):
embed_info_file = os.path.join(persist_directory, 'embed_info')
with open(embed_info_file, 'wb') as f:
if isinstance(hf_embedding_model, str):
hf_embedding_model_save = hf_embedding_model
elif hasattr(hf_embedding_model, 'model_name'):
hf_embedding_model_save = hf_embedding_model.model_name
elif isinstance(hf_embedding_model, dict) and 'name' in hf_embedding_model:
hf_embedding_model_save = hf_embedding_model['name']
elif isinstance(hf_embedding_model, dict) and 'name' in hf_embedding_model:
if os.getenv('HARD_ASSERTS'):
# unexpected in testing or normally
raise RuntimeError("HERE")
hf_embedding_model_save = 'hkunlp/instructor-large'
pickle.dump((use_openai_embedding, hf_embedding_model_save), f)
return use_openai_embedding, hf_embedding_model
def load_embed(db=None, persist_directory=None):
    """
    Determine which embedding a db uses.

    If ``db.embeddings.model_name`` exists it is authoritative: it is saved via save_embed() and returned
    (a name containing 'openai' means OpenAI embeddings, reported with model None). Otherwise the pickled
    <persist_directory>/embed_info is read under its lock. A non-string stored model name is replaced with
    'hkunlp/instructor-large' and the file is rewritten. If no embed_info exists, legacy defaults are assumed.

    :param db: vector store (may be None when persist_directory is given)
    :param persist_directory: directory holding embed_info; defaults to ``db._persist_directory``
    :return: (got_embedding, use_openai_embedding, hf_embedding_model)
    """
    if hasattr(db, 'embeddings') and hasattr(db.embeddings, 'model_name'):
        hf_embedding_model = db.embeddings.model_name if 'openai' not in db.embeddings.model_name.lower() else None
        use_openai_embedding = hf_embedding_model is None
        save_embed(db, use_openai_embedding, hf_embedding_model)
        return True, use_openai_embedding, hf_embedding_model
    if persist_directory is None:
        persist_directory = db._persist_directory
    embed_info_file = os.path.join(persist_directory, 'embed_info')
    if os.path.isfile(embed_info_file):
        need_fix = False
        lock_file = get_embed_lock_file(db, persist_directory=persist_directory)
        with filelock.FileLock(lock_file):
            with open(embed_info_file, 'rb') as f:
                try:
                    use_openai_embedding, hf_embedding_model = pickle.load(f)
                    if not isinstance(hf_embedding_model, str):
                        # work-around bug introduced here: https://github.com/h2oai/h2ogpt/commit/54c4414f1ce3b5b7c938def651c0f6af081c66de
                        hf_embedding_model = 'hkunlp/instructor-large'
                        need_fix = True
                    got_embedding = True
                except EOFError:
                    use_openai_embedding, hf_embedding_model = False, 'hkunlp/instructor-large'
                    got_embedding = False
                    if os.getenv('HARD_ASSERTS'):
                        # unexpected in testing or normally
                        raise
        if need_fix:
            # fix file only after releasing the lock: save_embed() acquires the same lock file through a
            # new FileLock instance, which can block against the lock still held above
            save_embed(db, use_openai_embedding, hf_embedding_model)
    else:
        # migration, assume defaults
        use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
        got_embedding = False
    assert isinstance(hf_embedding_model, str)
    return got_embedding, use_openai_embedding, hf_embedding_model
def get_persist_directory(langchain_mode, langchain_type=None, db1s=None, dbs=None):
    """
    Resolve the on-disk directory and the effective LangChainTypes value for a collection.

    Resolution order (first match wins):
      1. Disabled/LLM modes: ('', langchain_type), a placeholder in which no db will be found.
      2. Personal: <USERS_BASE_DIR or 'users'>/<username or userid>/db_dir_<mode>, used when a userid is
         known AND (that directory exists, or ``langchain_mode`` is a key of ``db1s``, or a personal type
         was requested).
      3. Shared: db_dir_<mode>, used when that directory exists, ``langchain_mode`` is a key of ``dbs``,
         or a shared type was requested.
      4. Otherwise a fresh db_nonusers/db_dir_<uuid4> directory, typed as personal.

    For a personal request with no user id/name, random uuid4 ids are assigned via set_userid_direct()
    (into ``db1s``, or into a throwaway dict when ``db1s`` is None).

    In cases 2-4 the path is passed through makedirs(..., use_base=True), whose return value is used, and
    then validated with check_persist_directory().

    :param langchain_mode: collection name
    :param langchain_type: requested LangChainTypes value (may be None)
    :param db1s: per-user db state, read via get_userid_direct/get_username_direct (may be None)
    :param dbs: shared dbs mapping; only key membership is checked (may be None)
    :return: (persist_directory, langchain_type)
    :raises AssertionError: if the stored userid/username is the literal string 'None'
    """
    if langchain_mode in [LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
        # not None so join works but will fail to find db
        return '', langchain_type
    userid = get_userid_direct(db1s)
    username = get_username_direct(db1s)
    # sanity for bad code
    assert userid != 'None'
    assert username != 'None'
    # username takes precedence over userid for the per-user directory name
    dirid = username or userid
    if langchain_type == LangChainTypes.SHARED.value and not dirid:
        dirid = './'  # just to avoid error
    if langchain_type == LangChainTypes.PERSONAL.value and not dirid:
        # e.g. from client when doing transient calls with MyData
        if db1s is None:
            # just trick to get filled locally
            db1s = {LangChainMode.MY_DATA.value: [None, None, None]}
        set_userid_direct(db1s, str(uuid.uuid4()), str(uuid.uuid4()))
        # re-read the freshly assigned ids
        userid = get_userid_direct(db1s)
        username = get_username_direct(db1s)
        dirid = username or userid
        langchain_type = LangChainTypes.PERSONAL.value
    # deal with existing locations
    user_base_dir = os.getenv('USERS_BASE_DIR', 'users')
    persist_directory = os.path.join(user_base_dir, dirid, 'db_dir_%s' % langchain_mode)
    # note precedence: A or (B and C) or D
    if userid and \
            (os.path.isdir(persist_directory) or
             db1s is not None and langchain_mode in db1s or
             langchain_type == LangChainTypes.PERSONAL.value):
        langchain_type = LangChainTypes.PERSONAL.value
        persist_directory = makedirs(persist_directory, use_base=True)
        check_persist_directory(persist_directory)
        return persist_directory, langchain_type
    persist_directory = 'db_dir_%s' % langchain_mode
    if (os.path.isdir(persist_directory) or
            dbs is not None and langchain_mode in dbs or
            langchain_type == LangChainTypes.SHARED.value):
        # ensure consistent
        langchain_type = LangChainTypes.SHARED.value
        persist_directory = makedirs(persist_directory, use_base=True)
        check_persist_directory(persist_directory)
        return persist_directory, langchain_type
    # dummy return for prep_langchain() or full personal space
    base_others = 'db_nonusers'
    persist_directory = os.path.join(base_others, 'db_dir_%s' % str(uuid.uuid4()))
    persist_directory = makedirs(persist_directory, use_base=True)
    langchain_type = LangChainTypes.PERSONAL.value
    check_persist_directory(persist_directory)
    return persist_directory, langchain_type
def check_persist_directory(persist_directory):
    """
    Reject a persist directory that is the bare shared directory of an intrinsic LangChain mode.

    deal with some cases when see intrinsic names being used as shared

    :param persist_directory: candidate directory path
    :raises RuntimeError: if persist_directory equals 'db_dir_<mode>' for any intrinsic mode
    """
    is_intrinsic_shared = any(persist_directory == 'db_dir_%s' % mode for mode in langchain_modes_intrinsic)
    if is_intrinsic_shared:
        raise RuntimeError("Illegal access to %s" % persist_directory)
def _make_db(use_openai_embedding=False,
             hf_embedding_model=None,
             migrate_embedding_model=False,
             auto_migrate_db=False,
             first_para=False, text_limit=None,
             chunk=True, chunk_size=512,
             # urls
             use_unstructured=True,
             use_playwright=False,
             use_selenium=False,
             # pdfs
             use_pymupdf='auto',
             use_unstructured_pdf='auto',
             use_pypdf='auto',
             enable_pdf_ocr='auto',
             enable_pdf_doctr='auto',
             try_pdf_as_html='auto',
             # images
             enable_ocr=False,
             enable_doctr=False,
             enable_pix2struct=False,
             enable_captions=True,
             captions_model=None,
             caption_loader=None,
             doctr_loader=None,
             pix2struct_loader=None,
             # json
             jq_schema='.[]',
             langchain_mode=None,
             langchain_mode_paths=None,
             langchain_mode_types=None,
             db_type='faiss',
             load_db_if_exists=True,
             db=None,
             n_jobs=-1,
             verbose=False):
    """
    Load (or create) the database for ``langchain_mode`` and add any new sources to it.

    When no db could be obtained, built-in collections ('wiki_full', 'wiki', 'github h2oGPT',
    'DriverlessAI docs') are generated. When ``langchain_mode_paths[langchain_mode]`` is set, documents under
    that path are loaded via path_to_docs(), passing the hash ids already in the db as existing.

    Environment checks: NO_NEW_FILES set -> RuntimeError if any user-path sources are found;
    SHOULD_NEW_FILES set -> RuntimeError if none are found.

    :param langchain_mode_paths: dict langchain_mode -> user path (None is treated as empty)
    :param langchain_mode_types: dict langchain_mode -> LangChainTypes value; updated in place with the
        resolved type (None is treated as a fresh dict, so the update is not visible to the caller)
    :return: (db, number of new sources, list of metadata dicts of those sources)
    """
    assert hf_embedding_model is not None
    if langchain_mode_paths is None:
        langchain_mode_paths = {}
    if langchain_mode_types is None:
        langchain_mode_types = {}
    user_path = langchain_mode_paths.get(langchain_mode)
    langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value)
    persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type)
    # record the resolved type so the caller's mapping reflects where the db actually lives
    langchain_mode_types[langchain_mode] = langchain_type
    # see if can get persistent chroma db
    db_trial, use_openai_embedding, hf_embedding_model = \
        get_existing_db(db, persist_directory, load_db_if_exists, db_type,
                        use_openai_embedding,
                        langchain_mode, langchain_mode_paths, langchain_mode_types,
                        hf_embedding_model, migrate_embedding_model, auto_migrate_db, verbose=verbose,
                        n_jobs=n_jobs)
    if db_trial is not None:
        db = db_trial

    sources = []
    if not db:
        chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
        if langchain_mode in ['wiki_full']:
            from read_wiki_full import get_all_documents
            small_test = None
            print("Generating new wiki", flush=True)
            # os.cpu_count() may return None, and half of a single core would be 0 jobs: keep at least 1
            sources1 = get_all_documents(small_test=small_test, n_jobs=max(1, (os.cpu_count() or 2) // 2))
            print("Got new wiki", flush=True)
            sources1 = chunk_sources(sources1, chunk=chunk)
            print("Chunked new wiki", flush=True)
            sources.extend(sources1)
        elif langchain_mode in ['wiki']:
            sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
            sources1 = chunk_sources(sources1, chunk=chunk)
            sources.extend(sources1)
        elif langchain_mode in ['github h2oGPT']:
            # sources = get_github_docs("dagster-io", "dagster")
            sources1 = get_github_docs("h2oai", "h2ogpt")
            # FIXME: always chunk for now
            sources1 = chunk_sources(sources1)
            sources.extend(sources1)
        elif langchain_mode in ['DriverlessAI docs']:
            sources1 = get_dai_docs(from_hf=True)
            # FIXME: DAI docs are already chunked well, should only chunk more if over limit
            sources1 = chunk_sources(sources1, chunk=False)
            sources.extend(sources1)

    if user_path:
        # UserData or custom, which has to be from user's disk
        # NOTE: Ignore file names for now, only go by hash ids
        # existing_files = get_existing_files(db)
        existing_files = []
        # with no db, pretend no existing files so won't filter
        existing_hash_ids = get_existing_hash_ids(db) if db is not None else []
        # chunk internally for speed over multiple docs
        # FIXME: If first had old Hash=None and switch embeddings,
        #  then re-embed, and then hit here and reload so have hash, and then re-embed.
        sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
                                # urls
                                use_unstructured=use_unstructured,
                                use_playwright=use_playwright,
                                use_selenium=use_selenium,
                                # pdfs
                                use_pymupdf=use_pymupdf,
                                use_unstructured_pdf=use_unstructured_pdf,
                                use_pypdf=use_pypdf,
                                enable_pdf_ocr=enable_pdf_ocr,
                                enable_pdf_doctr=enable_pdf_doctr,
                                try_pdf_as_html=try_pdf_as_html,
                                # images
                                enable_ocr=enable_ocr,
                                enable_doctr=enable_doctr,
                                enable_pix2struct=enable_pix2struct,
                                enable_captions=enable_captions,
                                captions_model=captions_model,
                                caption_loader=caption_loader,
                                doctr_loader=doctr_loader,
                                pix2struct_loader=pix2struct_loader,
                                # json
                                jq_schema=jq_schema,
                                existing_files=existing_files, existing_hash_ids=existing_hash_ids,
                                db_type=db_type)
        new_metadata_sources = {x.metadata['source'] for x in sources1}
        if new_metadata_sources:
            if os.getenv('NO_NEW_FILES') is not None:
                raise RuntimeError("Expected no new files! %s" % new_metadata_sources)
            print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
                  flush=True)
            if verbose:
                print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
        sources.extend(sources1)
        if len(sources) > 0 and os.getenv('NO_NEW_FILES') is not None:
            raise RuntimeError("Expected no new files! %s" % langchain_mode)
        if len(sources) == 0 and os.getenv('SHOULD_NEW_FILES') is not None:
            raise RuntimeError("Expected new files! %s" % langchain_mode)
        print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)

    # see if got sources
    if not sources:
        if verbose:
            if db is not None:
                print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
            else:
                print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
        return db, 0, []
    if verbose:
        # an existing db is added to, a missing one is generated (messages were previously swapped)
        if db is not None:
            print("Adding to db", flush=True)
        else:
            print("Generating db", flush=True)
    if not db:
        # sources is non-empty here: the empty case returned above
        db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                    persist_directory=persist_directory,
                    langchain_mode=langchain_mode,
                    langchain_mode_paths=langchain_mode_paths,
                    langchain_mode_types=langchain_mode_types,
                    hf_embedding_model=hf_embedding_model,
                    migrate_embedding_model=migrate_embedding_model,
                    auto_migrate_db=auto_migrate_db,
                    n_jobs=n_jobs)
        if verbose:
            print("Generated db", flush=True)
        new_sources_metadata = [x.metadata for x in sources]
    elif user_path is not None:
        print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
        db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
                                                              use_openai_embedding=use_openai_embedding,
                                                              hf_embedding_model=hf_embedding_model)
        print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
    else:
        new_sources_metadata = [x.metadata for x in sources]

    return db, len(new_sources_metadata), new_sources_metadata
def get_metadatas(db):
    """
    Return the metadata dict of every document chunk stored in ``db``.

    FAISS: read from the in-memory docstore. Chroma/ChromaMig: via get_documents(). Any other non-None
    store: best-effort via an empty-query similarity search capped at 10000 hits. None: empty list.
    """
    from langchain.vectorstores import FAISS
    if isinstance(db, FAISS):
        return [stored.metadata for stored in db.docstore._dict.values()]
    if isinstance(db, (Chroma, ChromaMig)) or ChromaMig.__name__ in str(db):
        return get_documents(db)['metadatas']
    if db is None:
        return []
    # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
    # seems no way to get all metadata, so need to avoid this approach for weaviate
    return [hit.metadata for hit in db.similarity_search("", k=10000)]
def get_db_lock_file(db, lock_type='getdb'):
    """
    Return the path of the '<lock_type>.lock' file under locks/<db._persist_directory>/.

    :param db: vector store, possibly exposing ``_persist_directory``
    :param lock_type: base name of the lock file
    :return: lock file path, or None when ``db`` has no ``_persist_directory``
    :raises RuntimeError: via check_persist_directory() for illegal shared names
    """
    if not hasattr(db, '_persist_directory'):
        return None
    persist_directory = db._persist_directory
    check_persist_directory(persist_directory)
    lock_dir = makedirs(os.path.join('locks', persist_directory), exist_ok=True, tmp_ok=True, use_base=True)
    lock_path = os.path.join(lock_dir, "%s.lock" % lock_type)
    makedirs(os.path.dirname(lock_path))  # ensure made
    return lock_path
def get_documents(db):
    """
    Fetch all documents of ``db`` via _get_documents().

    Persisted dbs (those with ``_persist_directory``) are read under the 'getdb' file lock.
    """
    if not hasattr(db, '_persist_directory'):
        return _get_documents(db)
    # get segfaults and other errors when multiple threads access this
    with filelock.FileLock(get_db_lock_file(db)):
        return _get_documents(db)
def _get_documents(db):
    """
    Return the stored documents of ``db`` without locking.

    FAISS: ``{'documents': [Document, ...]}`` from the docstore. Chroma/ChromaMig: the raw ``db.get()``
    result. Others: ``{'documents': [...]}`` from an empty-query similarity search capped at 10000 hits.
    """
    from langchain.vectorstores import FAISS
    if isinstance(db, FAISS):
        return dict(documents=list(db.docstore._dict.values()))
    if isinstance(db, (Chroma, ChromaMig)) or ChromaMig.__name__ in str(db):
        return db.get()
    # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
    # seems no way to get all metadata, so need to avoid this approach for weaviate
    return dict(documents=list(db.similarity_search("", k=10000)))
def get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None):
    """
    Return (documents, metadatas) via _get_docs_and_meta(), holding the 'getdb' lock for persisted dbs.
    """
    if not hasattr(db, '_persist_directory'):
        return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list)
    with filelock.FileLock(get_db_lock_file(db)):
        return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list)
def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None):
    """
    Collect documents and their metadata from ``text_context_list`` followed by ``db`` (no locking).

    :param db: Chroma/ChromaMig, FAISS, other vector store, or None
    :param top_k_docs: FAISS only: number of leading docstore entries to take; -1 means all
    :param filter_kwargs: Chroma only: ``filter_kwargs.get('filter')`` is passed as the ``where`` clause
    :param text_context_list: extra items put first; ``page_content``/``metadata`` used when present
    :return: (documents, metadatas), two lists of equal length
    """
    if filter_kwargs is None:
        filter_kwargs = {}
    db_documents = []
    db_metadatas = []

    if text_context_list:
        db_documents += [x.page_content if hasattr(x, 'page_content') else x for x in text_context_list]
        db_metadatas += [x.metadata if hasattr(x, 'metadata') else {} for x in text_context_list]

    from langchain.vectorstores import FAISS
    if isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db):
        db_get = db._collection.get(where=filter_kwargs.get('filter'))
        db_metadatas += db_get['metadatas']
        db_documents += db_get['documents']
    elif isinstance(db, FAISS):
        import itertools
        # FIXME: FAISS has no filter
        faiss_docs = db.docstore._dict.values()
        if top_k_docs != -1:
            faiss_docs = itertools.islice(faiss_docs, top_k_docs)
        faiss_docs = list(faiss_docs)
        db_documents += faiss_docs
        # take metadata from the same (possibly truncated) docs so both lists stay aligned; previously
        # all metadata was returned while documents were truncated to top_k_docs
        db_metadatas += [doc.metadata for doc in faiss_docs]
    elif db is not None:
        db_metadatas += get_metadatas(db)
        # unlocked variant: get_docs_and_meta() already holds the db lock for persisted dbs, and
        # get_documents() would try to acquire it again through a new FileLock instance
        db_documents += _get_documents(db)['documents']

    return db_documents, db_metadatas
def get_existing_files(db):
    """Return the set of 'source' values across the metadata of all chunks in ``db``."""
    return {meta['source'] for meta in get_metadatas(db)}
def get_existing_hash_ids(db):
    """
    Map each normalized 'source' path in ``db`` to its 'hashid' metadata (None when absent).

    assume consistency, that any prior hashed source was single hashed file at the time among all source chunks;
    when several chunks share a source, the last one seen wins.
    """
    hash_ids = {}
    for meta in get_metadatas(db):
        hash_ids[os.path.normpath(meta['source'])] = meta.get('hashid')
    return hash_ids
def run_qa_db(**kwargs):
    """
    Convenience wrapper around _run_qa_db() that fills a few display defaults and validates kwargs.

    Keyword arguments not accepted by _run_qa_db() are dropped. The torch cache is cleared once the
    returned generator is exhausted or closed.

    :return: generator yielding the same dicts as _run_qa_db()
    :raises AssertionError: (eagerly) if a _run_qa_db() parameter is missing from kwargs
    """
    func_names = list(inspect.signature(_run_qa_db).parameters)
    # hard-coded defaults
    kwargs.setdefault('answer_with_sources', True)
    kwargs.setdefault('show_rank', False)
    kwargs.setdefault('show_accordions', True)
    kwargs.setdefault('show_link_in_sources', True)
    kwargs.setdefault('top_k_docs_max_show', 10)
    # shouldn't be required unless from test using _run_qa_db; keep a caller-supplied dict instead of
    # always overwriting it
    if kwargs.get('llamacpp_dict') is None:
        kwargs['llamacpp_dict'] = {}
    missing_kwargs = [x for x in func_names if x not in kwargs]
    assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
    # only keep actual used
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
    return _run_qa_db_and_clear_cache(kwargs)


def _run_qa_db_and_clear_cache(kwargs):
    # _run_qa_db() is a generator: its body only runs while iterated, so the cache must be cleared
    # after iteration ends, not when the generator object is created
    try:
        yield from _run_qa_db(**kwargs)
    finally:
        clear_torch_cache()
def _run_qa_db(query=None,
               iinput=None,
               context=None,
               use_openai_model=False, use_openai_embedding=False,
               first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
               # urls
               use_unstructured=True,
               use_playwright=False,
               use_selenium=False,
               # pdfs
               use_pymupdf='auto',
               use_unstructured_pdf='auto',
               use_pypdf='auto',
               enable_pdf_ocr='auto',
               enable_pdf_doctr='auto',
               try_pdf_as_html='auto',
               # images
               enable_ocr=False,
               enable_doctr=False,
               enable_pix2struct=False,
               enable_captions=True,
               captions_model=None,
               caption_loader=None,
               doctr_loader=None,
               pix2struct_loader=None,
               # json
               jq_schema='.[]',
               langchain_mode_paths={},
               langchain_mode_types={},
               detect_user_path_changes_every_query=False,
               db_type=None,
               model_name=None, model=None, tokenizer=None, inference_server=None,
               langchain_only_model=False,
               hf_embedding_model=None,
               migrate_embedding_model=False,
               auto_migrate_db=False,
               stream_output=False,
               async_output=True,
               num_async=3,
               prompter=None,
               prompt_type=None,
               prompt_dict=None,
               answer_with_sources=True,
               append_sources_to_answer=True,
               cut_distance=1.64,
               add_chat_history_to_context=True,
               add_search_to_context=False,
               keep_sources_in_context=False,
               memory_restriction_level=0,
               system_prompt='',
               sanitize_bot_response=False,
               show_rank=False,
               show_accordions=True,
               show_link_in_sources=True,
               top_k_docs_max_show=10,
               use_llm_if_no_docs=True,
               load_db_if_exists=False,
               db=None,
               do_sample=False,
               temperature=0.1,
               top_k=40,
               top_p=0.7,
               num_beams=1,
               max_new_tokens=512,
               min_new_tokens=1,
               early_stopping=False,
               max_time=180,
               repetition_penalty=1.0,
               num_return_sequences=1,
               langchain_mode=None,
               langchain_action=None,
               langchain_agents=None,
               document_subset=DocumentSubset.Relevant.name,
               document_choice=[DocumentChoice.ALL.value],
               pre_prompt_query=None,
               prompt_query=None,
               pre_prompt_summary=None,
               prompt_summary=None,
               text_context_list=None,
               chat_conversation=None,
               visible_models=None,
               h2ogpt_key=None,
               docs_ordering_type='reverse_ucurve_sort',
               min_max_new_tokens=256,
               n_jobs=-1,
               llamacpp_dict=None,
               verbose=False,
               cli=False,
               lora_weights='',
               auto_reduce_chunks=True,
               max_chunks=100,
               total_tokens_for_docs=None,
               headsize=50,
               ):
    """
    Generator that answers ``query`` with an LLM, optionally grounded on documents chosen by get_chain().

    Yields dicts with keys prompt, response, sources, num_prompt_tokens. When streaming (stream_output and
    a streamer from get_llm()), partial responses are yielded as text arrives, then a final dict.

    NOTE: the keyword arguments for get_chain() are collected from ``locals()``, so local variable names
    in this function matter: any local named like a get_chain() parameter is forwarded.

    :param query:
    :param use_openai_model:
    :param use_openai_embedding:
    :param first_para:
    :param text_limit:
    :param top_k_docs:
    :param chunk:
    :param chunk_size:
    :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
    :param db_type: 'faiss' for in-memory
                    'chroma' (for chroma >= 0.4)
                    'chroma_old' (for chroma < 0.4)
                    'weaviate' for persisted on disk
    :param model_name: model name, used to switch behaviors
    :param model: pre-initialized model, else will make new one
    :param tokenizer: pre-initialized tokenizer, else will make new one.  Required not None if model is not None
    :param answer_with_sources
    :param langchain_agents: list of LangChainAgent values (None treated as [])
    :param llamacpp_dict: llama.cpp options (None treated as {})
    :return:
    """
    t_run = time.time()
    if stream_output:
        # threads and asyncio don't mix
        async_output = False
    if langchain_action in [LangChainAction.QUERY.value]:
        # only summarization supported
        async_output = False
    # the declared None defaults would otherwise crash below (llamacpp_dict.get, `in langchain_agents`)
    if llamacpp_dict is None:
        llamacpp_dict = {}
    if langchain_agents is None:
        langchain_agents = []

    # in case None, e.g. lazy client, then set based upon actual model
    pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary = \
        get_langchain_prompts(pre_prompt_query, prompt_query,
                              pre_prompt_summary, prompt_summary,
                              model_name, inference_server,
                              llamacpp_dict.get('model_path_llama'))

    assert db_type is not None
    assert hf_embedding_model is not None
    assert langchain_mode_paths is not None
    assert langchain_mode_types is not None
    if model is not None:
        assert model_name is not None  # require so can make decisions
    assert query is not None
    assert prompter is not None or prompt_type is not None or model is None  # if model is None, then will generate
    if prompter is not None:
        prompt_type = prompter.prompt_type
        prompt_dict = prompter.prompt_dict
    if model is not None:
        assert prompt_type is not None
        if prompt_type == PromptType.custom.name:
            assert prompt_dict is not None  # should at least be {} or ''
        else:
            prompt_dict = ''

    # model_name may be None (e.g. no model given) and prompter may not exist yet; the overridden
    # system_prompt is still used when a Prompter is built below
    if LangChainAgent.SEARCH.value in langchain_agents and model_name and 'llama' in model_name.lower():
        system_prompt = """You are a zero shot react agent.
Consider to prompt of Question that was original query from the user.
Respond to prompt of Thought with a thought that may lead to a reasonable new action choice.
Respond to prompt of Action with an action to take out of the tools given, giving exactly single word for the tool name.
Respond to prompt of Action Input with an input to give the tool.
Consider to prompt of Observation that was response from the tool.
Repeat this Thought, Action, Action Input, Observation, Thought sequence several times with new and different thoughts and actions each time, do not repeat.
Once satisfied that the thoughts, responses are sufficient to answer the question, then respond to prompt of Thought with: I now know the final answer
Respond to prompt of Final Answer with your final high-quality bullet list answer to the original query.
"""
        if prompter is not None:
            prompter.system_prompt = system_prompt

    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
    # pass in context to LLM directly, since already has prompt_type structure
    # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
    llm, model_name, streamer, prompt_type_out, async_output, only_new_text = \
        get_llm(use_openai_model=use_openai_model, model_name=model_name,
                model=model,
                tokenizer=tokenizer,
                inference_server=inference_server,
                langchain_only_model=langchain_only_model,
                stream_output=stream_output,
                async_output=async_output,
                num_async=num_async,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                early_stopping=early_stopping,
                max_time=max_time,
                repetition_penalty=repetition_penalty,
                num_return_sequences=num_return_sequences,
                prompt_type=prompt_type,
                prompt_dict=prompt_dict,
                prompter=prompter,
                context=context,
                iinput=iinput,
                sanitize_bot_response=sanitize_bot_response,
                system_prompt=system_prompt,
                visible_models=visible_models,
                h2ogpt_key=h2ogpt_key,
                min_max_new_tokens=min_max_new_tokens,
                n_jobs=n_jobs,
                llamacpp_dict=llamacpp_dict,
                cli=cli,
                verbose=verbose,
                )
    # in case change, override original prompter
    if hasattr(llm, 'prompter'):
        prompter = llm.prompter
    if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'prompter'):
        prompter = llm.pipeline.prompter

    if prompter is None:
        if prompt_type is None:
            prompt_type = prompt_type_out
        # get prompter
        chat = True  # FIXME?
        prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
                            system_prompt=system_prompt)

    use_docs_planned = False
    scores = []
    chain = None

    # basic version of prompt without docs etc.
    data_point = dict(context=context, instruction=query, input=iinput)
    prompt_basic = prompter.generate_prompt(data_point)

    if isinstance(document_choice, str):
        # support string as well
        document_choice = [document_choice]

    # forward every local whose name matches a get_chain() parameter
    func_names = list(inspect.signature(get_chain).parameters)
    sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
    missing_kwargs = [x for x in func_names if x not in sim_kwargs]
    assert not missing_kwargs, "Missing: %s" % missing_kwargs
    docs, chain, scores, \
        use_docs_planned, num_docs_before_cut, \
        use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
        get_chain(**sim_kwargs)
    if document_subset in non_query_commands:
        formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
        if not formatted_doc_chunks and not use_llm_if_no_docs:
            yield dict(prompt=prompt_basic, response="No sources", sources='', num_prompt_tokens=0)
            return
        # if no souces, outside gpt_langchain, LLM will be used with '' input
        scores = [1] * len(docs)
        get_answer_args = tuple([query, docs, formatted_doc_chunks, scores, show_rank,
                                 answer_with_sources,
                                 append_sources_to_answer])
        get_answer_kwargs = dict(show_accordions=show_accordions,
                                 show_link_in_sources=show_link_in_sources,
                                 top_k_docs_max_show=top_k_docs_max_show,
                                 docs_ordering_type=docs_ordering_type,
                                 num_docs_before_cut=num_docs_before_cut,
                                 verbose=verbose)
        ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
        yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
        return
    if not use_llm_if_no_docs:
        if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
                                             LangChainAction.SUMMARIZE_ALL.value,
                                             LangChainAction.SUMMARIZE_REFINE.value]:
            ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
            extra = ''
            yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
            return
        if not docs and not llm_mode:
            ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
            extra = ''
            yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
            return

    if chain is None and not langchain_only_model:
        # here if no docs at all and not HF type
        # can only return if HF type
        return

    # context stuff similar to used in evaluate()
    import torch
    device, torch_dtype, context_class = get_device_dtype()
    conditional_type = hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'model') and hasattr(llm.pipeline.model,
                                                                                                   'conditional_type') and llm.pipeline.model.conditional_type
    with torch.no_grad():
        have_lora_weights = lora_weights not in [no_lora_str, '', None]
        context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
        if conditional_type:
            # issues when casting to float16, can mess up t5 model, e.g. only when not streaming, or other odd behaviors
            context_class_cast = NullContext
        with context_class_cast(device):
            if stream_output and streamer:
                answer = None
                import queue
                bucket = queue.Queue()
                # chain runs in a background thread and feeds the streamer consumed below
                thread = EThread(target=chain, streamer=streamer, bucket=bucket)
                thread.start()
                outputs = ""
                try:
                    for new_text in streamer:
                        # print("new_text: %s" % new_text, flush=True)
                        if bucket.qsize() > 0 or thread.exc:
                            thread.join()
                        outputs += new_text
                        if prompter:  # and False:  # FIXME: pipeline can already use prompter
                            if conditional_type:
                                if prompter.botstr:
                                    prompt = prompter.botstr
                                    output_with_prompt = prompt + outputs
                                    only_new_text = False
                                else:
                                    prompt = None
                                    output_with_prompt = outputs
                                    only_new_text = True
                            else:
                                prompt = None  # FIXME
                                output_with_prompt = outputs
                            # don't specify only_new_text here, use get_llm() value
                            output1 = prompter.get_response(output_with_prompt, prompt=prompt,
                                                            only_new_text=only_new_text,
                                                            sanitize_bot_response=sanitize_bot_response)
                            yield dict(prompt=prompt, response=output1, sources='', num_prompt_tokens=0)
                        else:
                            yield dict(prompt=prompt, response=outputs, sources='', num_prompt_tokens=0)
                except BaseException:
                    # if any exception, raise that exception if was from thread, first
                    if thread.exc:
                        raise thread.exc
                    raise
                finally:
                    # in case no exception and didn't join with thread yet, then join
                    if not thread.exc:
                        answer = thread.join()
                        if isinstance(answer, dict):
                            if 'output_text' in answer:
                                answer = answer['output_text']
                            elif 'output' in answer:
                                answer = answer['output']
                # in case raise StopIteration or broke queue loop in streamer, but still have exception
                if thread.exc:
                    raise thread.exc
            else:
                if async_output:
                    import asyncio
                    answer = asyncio.run(chain())
                else:
                    answer = chain()
                # chains may return a dict; unwrap the text the same way as the streaming path
                if isinstance(answer, dict):
                    if 'output_text' in answer:
                        answer = answer['output_text']
                    elif 'output' in answer:
                        answer = answer['output']

    get_answer_args = tuple([query, docs, answer, scores, show_rank,
                             answer_with_sources,
                             append_sources_to_answer])
    get_answer_kwargs = dict(show_accordions=show_accordions,
                             show_link_in_sources=show_link_in_sources,
                             top_k_docs_max_show=top_k_docs_max_show,
                             docs_ordering_type=docs_ordering_type,
                             num_docs_before_cut=num_docs_before_cut,
                             verbose=verbose,
                             t_run=t_run,
                             count_input_tokens=llm.count_input_tokens
                             if hasattr(llm, 'count_input_tokens') else None,
                             count_output_tokens=llm.count_output_tokens
                             if hasattr(llm, 'count_output_tokens') else None)

    # for final yield, get real prompt used
    if hasattr(llm, 'prompter') and llm.prompter.prompt is not None:
        prompt = llm.prompter.prompt
    else:
        prompt = prompt_basic
    num_prompt_tokens = get_token_count(prompt, tokenizer)

    if not use_docs_planned:
        ret = answer
        extra = ''
        yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
    elif answer is not None:
        ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
        yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
    return
def get_docs_with_score(query, k_db, filter_kwargs, db, db_type, text_context_list=None, verbose=False):
    """
    Gather (document, score) pairs from ``text_context_list`` followed by ``db``.

    - db with FakeEmbeddings: every stored doc (score 1.0), sorted by metadata (file_id, chunk_id).
    - chroma / chroma_old: similarity search, retried with smaller k on RuntimeError/AttributeError.
    - any other non-None db: plain similarity_search_with_score.
    Each document's metadata gets ``orig_index`` set to its position in the returned list.

    :param k_db: number of results requested from the db
    :param filter_kwargs: extra kwargs passed to the db search
    :param text_context_list: documents (with ``metadata``) put first, scored by metadata 'score' or 1.0
    :return: (docs_with_score, got_db_docs) where got_db_docs tells whether the db contributed anything
    :raises: the last chroma error once k_db has been reduced to 1
    """
    docs_with_score = []
    got_db_docs = False

    if text_context_list:
        docs_with_score += [(x, x.metadata.get('score', 1.0)) for x in text_context_list]

    if hasattr(db, '_embedding_function') and isinstance(db._embedding_function, FakeEmbeddings):
        top_k_docs = -1
        # don't add text_context_list twice
        db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
                                                       text_context_list=None)
        # stores may return None for a chunk's metadata; treat it as {} for sorting as well as for Document
        db_metadatas = [meta or {} for meta in db_metadatas]
        # sort by order given to parser (file_id) and any chunk_id if chunked
        doc_file_ids = [meta.get('file_id', 0) for meta in db_metadatas]
        doc_chunk_ids = [meta.get('chunk_id', 0) for meta in db_metadatas]
        docs_with_score_fake = [(Document(page_content=content, metadata=meta), 1.0)
                                for content, meta in zip(db_documents, db_metadatas)]
        docs_with_score_fake = [x for fx, cx, x in
                                sorted(zip(doc_file_ids, doc_chunk_ids, docs_with_score_fake),
                                       key=lambda x: (x[0], x[1]))
                                ]
        got_db_docs |= len(docs_with_score_fake) > 0
        docs_with_score += docs_with_score_fake
    elif db is not None and db_type in ['chroma', 'chroma_old']:
        # deal with bug in chroma where if (say) 234 doc chunks and ask for 233+ then fails due to reduction
        # misbehavior: retry with progressively smaller k
        while True:
            try:
                docs_with_score_chroma = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)
                break
            except (RuntimeError, AttributeError) as e:
                # AttributeError is for people with wrong version of langchain
                if verbose:
                    print("chroma bug: %s" % str(e), flush=True)
                if k_db == 1:
                    raise
                if k_db > 500:
                    k_db -= 200
                elif k_db > 100:
                    k_db -= 50
                elif k_db > 10:
                    k_db -= 5
                else:
                    k_db -= 1
                k_db = max(1, k_db)
        got_db_docs |= len(docs_with_score_chroma) > 0
        docs_with_score += docs_with_score_chroma
    elif db is not None:
        docs_with_score_other = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)
        got_db_docs |= len(docs_with_score_other) > 0
        docs_with_score += docs_with_score_other

    # set in metadata original order of docs
    for ii, (doc, _score) in enumerate(docs_with_score):
        doc.metadata.update(orig_index=ii)

    return docs_with_score, got_db_docs
def get_chain(query=None,
iinput=None,
context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638
use_openai_model=False, use_openai_embedding=False,
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
# urls
use_unstructured=True,
use_playwright=False,
use_selenium=False,
# pdfs
use_pymupdf='auto',
use_unstructured_pdf='auto',
use_pypdf='auto',
enable_pdf_ocr='auto',
enable_pdf_doctr='auto',
try_pdf_as_html='auto',
# images
enable_ocr=False,
enable_doctr=False,
enable_pix2struct=False,
enable_captions=True,
captions_model=None,
caption_loader=None,
doctr_loader=None,
pix2struct_loader=None,
# json
jq_schema='.[]',
langchain_mode_paths=None,
langchain_mode_types=None,
detect_user_path_changes_every_query=False,
db_type='faiss',
model_name=None,
inference_server='',
max_new_tokens=None,
langchain_only_model=False,
hf_embedding_model=None,
migrate_embedding_model=False,
auto_migrate_db=False,
prompter=None,
prompt_type=None,
prompt_dict=None,
system_prompt=None,
cut_distance=1.1,
add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638
add_search_to_context=False,
keep_sources_in_context=False,
memory_restriction_level=0,
top_k_docs_max_show=10,
load_db_if_exists=False,
db=None,
langchain_mode=None,
langchain_action=None,
langchain_agents=None,
document_subset=DocumentSubset.Relevant.name,
document_choice=[DocumentChoice.ALL.value],
pre_prompt_query=None,
prompt_query=None,
pre_prompt_summary=None,
prompt_summary=None,
text_context_list=None,
chat_conversation=None,
n_jobs=-1,
# beyond run_db_query:
llm=None,
tokenizer=None,
verbose=False,
docs_ordering_type='reverse_ucurve_sort',
min_max_new_tokens=256,
stream_output=True,
async_output=True,
# local
auto_reduce_chunks=True,
max_chunks=100,
total_tokens_for_docs=None,
use_llm_if_no_docs=None,
headsize=50,
):
if inference_server is None:
inference_server = ''
assert hf_embedding_model is not None
assert langchain_agents is not None # should be at least []
if text_context_list is None:
text_context_list = []
# default value:
llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
query_action = langchain_action == LangChainAction.QUERY.value
summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
LangChainAction.SUMMARIZE_REFINE.value]
if len(text_context_list) > 0:
# turn into documents to make easy to manage and add meta
# try to account for summarization vs. query
chunk_id = 0 if query_action else -1
text_context_list = [
Document(page_content=x, metadata=dict(source='text_context_list', score=1.0, chunk_id=chunk_id)) for x
in text_context_list]
if add_search_to_context:
params = {
"engine": "duckduckgo",
"gl": "us",
"hl": "en",
}
search = H2OSerpAPIWrapper(params=params)
# if doing search, allow more docs
docs_search, top_k_docs = search.get_search_documents(query,
query_action=query_action,
chunk=chunk, chunk_size=chunk_size,
db_type=db_type,
headsize=headsize,
top_k_docs=top_k_docs)
text_context_list = docs_search + text_context_list
add_search_to_context &= len(docs_search) > 0
top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
if len(text_context_list) > 0:
llm_mode = False
use_llm_if_no_docs = True
from src.output_parser import H2OMRKLOutputParser
from langchain.agents import AgentType, load_tools, initialize_agent, create_vectorstore_agent, \
create_pandas_dataframe_agent, create_json_agent, create_csv_agent
from langchain.agents.agent_toolkits import VectorStoreInfo, VectorStoreToolkit, create_python_agent, JsonToolkit
if LangChainAgent.SEARCH.value in langchain_agents:
output_parser = H2OMRKLOutputParser()
tools = load_tools(["serpapi"], llm=llm, serpapi_api_key=os.environ.get('SERPAPI_API_KEY'))
if inference_server.startswith('openai'):
agent_type = AgentType.OPENAI_FUNCTIONS
agent_executor_kwargs = {"handle_parsing_errors": True, 'output_parser': output_parser}
else:
agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION
agent_executor_kwargs = {'output_parser': output_parser}
chain = initialize_agent(tools, llm, agent=agent_type,
agent_executor_kwargs=agent_executor_kwargs,
agent_kwargs=dict(output_parser=output_parser,
format_instructions=output_parser.get_format_instructions()),
output_parser=output_parser,
max_iterations=10,
verbose=True)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if LangChainAgent.COLLECTION.value in langchain_agents:
output_parser = H2OMRKLOutputParser()
vectorstore_info = VectorStoreInfo(
name=langchain_mode,
description="DataBase of text from PDFs, Image Captions, or web URL content",
vectorstore=db,
)
toolkit = VectorStoreToolkit(vectorstore_info=vectorstore_info)
chain = create_vectorstore_agent(llm=llm, toolkit=toolkit,
agent_executor_kwargs=dict(output_parser=output_parser),
verbose=True)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
chain = create_python_agent(
llm=llm,
tool=PythonREPLTool(),
verbose=True,
agent_type=AgentType.OPENAI_FUNCTIONS,
agent_executor_kwargs={"handle_parsing_errors": True},
)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
# FIXME: DATA
df = pd.DataFrame(None)
chain = create_pandas_dataframe_agent(
llm,
df,
verbose=True,
agent_type=AgentType.OPENAI_FUNCTIONS,
)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if isinstance(document_choice, str):
document_choice = [document_choice]
if document_choice and document_choice[0] == DocumentChoice.ALL.value:
document_choice_agent = document_choice[1:]
else:
document_choice_agent = document_choice
document_choice_agent = [x for x in document_choice_agent if x.endswith('.json')]
if LangChainAgent.JSON.value in \
langchain_agents and \
inference_server.startswith('openai_chat') and \
len(document_choice_agent) == 1 and \
document_choice_agent[0].endswith('.json'):
# with open('src/openai.yaml') as f:
# data = yaml.load(f, Loader=yaml.FullLoader)
with open(document_choice[0], 'rt') as f:
data = json.loads(f.read())
json_spec = JsonSpec(dict_=data, max_value_length=4000)
json_toolkit = JsonToolkit(spec=json_spec)
chain = create_json_agent(
llm=llm, toolkit=json_toolkit, verbose=True
)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if isinstance(document_choice, str):
document_choice = [document_choice]
if document_choice and document_choice[0] == DocumentChoice.ALL.value:
document_choice_agent = document_choice[1:]
else:
document_choice_agent = document_choice
document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
0].endswith(
'.csv'):
data_file = document_choice[0]
if inference_server.startswith('openai_chat'):
chain = create_csv_agent(
llm,
data_file,
verbose=True,
agent_type=AgentType.OPENAI_FUNCTIONS,
)
else:
chain = create_csv_agent(
llm,
data_file,
verbose=True,
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
chain_kwargs = dict(input=query)
target = wrapped_partial(chain, chain_kwargs)
docs = []
scores = []
use_docs_planned = False
num_docs_before_cut = 0
use_llm_if_no_docs = True
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
# determine whether use of context out of docs is planned
if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
if llm_mode:
use_docs_planned = False
else:
use_docs_planned = True
else:
use_docs_planned = True
# https://github.com/hwchase17/langchain/issues/1946
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
# Chroma collection MyData contains fewer than 4 elements.
# type logger error
if top_k_docs == -1:
k_db = 1000 if db_type in ['chroma', 'chroma_old'] else 100
else:
# top_k_docs=100 works ok too
k_db = 1000 if db_type in ['chroma', 'chroma_old'] else top_k_docs
# FIXME: For All just go over all dbs instead of a separate db for All
if not detect_user_path_changes_every_query and db is not None:
# avoid looking at user_path during similarity search db handling,
# if already have db and not updating from user_path every query
# but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
if langchain_mode_paths is None:
langchain_mode_paths = {}
langchain_mode_paths = langchain_mode_paths.copy()
langchain_mode_paths[langchain_mode] = None
# once use_openai_embedding, hf_embedding_model passed in, possibly changed,
# but that's ok as not used below or in calling functions
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
hf_embedding_model=hf_embedding_model,
migrate_embedding_model=migrate_embedding_model,
auto_migrate_db=auto_migrate_db,
first_para=first_para, text_limit=text_limit,
chunk=chunk, chunk_size=chunk_size,
# urls
use_unstructured=use_unstructured,
use_playwright=use_playwright,
use_selenium=use_selenium,
# pdfs
use_pymupdf=use_pymupdf,
use_unstructured_pdf=use_unstructured_pdf,
use_pypdf=use_pypdf,
enable_pdf_ocr=enable_pdf_ocr,
enable_pdf_doctr=enable_pdf_doctr,
try_pdf_as_html=try_pdf_as_html,
# images
enable_ocr=enable_ocr,
enable_doctr=enable_doctr,
enable_pix2struct=enable_pix2struct,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
doctr_loader=doctr_loader,
pix2struct_loader=pix2struct_loader,
# json
jq_schema=jq_schema,
langchain_mode=langchain_mode,
langchain_mode_paths=langchain_mode_paths,
langchain_mode_types=langchain_mode_types,
db_type=db_type,
load_db_if_exists=load_db_if_exists,
db=db,
n_jobs=n_jobs,
verbose=verbose)
num_docs_before_cut = 0
use_template = not use_openai_model and prompt_type not in ['plain'] or langchain_only_model
got_db_docs = False # not yet at least
template, template_if_no_docs, auto_reduce_chunks, query = \
get_template(query, iinput,
pre_prompt_query, prompt_query,
pre_prompt_summary, prompt_summary,
langchain_action,
llm_mode,
use_docs_planned,
auto_reduce_chunks,
got_db_docs,
add_search_to_context)
max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
model_name=model_name, max_new_tokens=max_new_tokens)
if (db or text_context_list) and use_docs_planned:
if hasattr(db, '_persist_directory'):
lock_file = get_db_lock_file(db, lock_type='sim')
else:
base_path = 'locks'
base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
name_path = "sim.lock"
lock_file = os.path.join(base_path, name_path)
if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
# only chroma supports filtering
filter_kwargs = {}
filter_kwargs_backup = {}
else:
import logging
logging.getLogger("chromadb").setLevel(logging.ERROR)
assert document_choice is not None, "Document choice was None"
if isinstance(db, Chroma):
filter_kwargs_backup = {} # shouldn't ever need backup
# chroma >= 0.4
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
0] == DocumentChoice.ALL.value:
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
{"filter": {"chunk_id": {"$eq": -1}}}
else:
if document_choice[0] == DocumentChoice.ALL.value:
document_choice = document_choice[1:]
if len(document_choice) == 0:
filter_kwargs = {}
elif len(document_choice) > 1:
or_filter = [
{"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
"$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
for x in document_choice]
filter_kwargs = dict(filter={"$or": or_filter})
else:
# still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
one_filter = \
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
"source": {"$eq": x},
"chunk_id": {
"$eq": -1}}
for x in document_choice][0]
filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
dict(chunk_id=one_filter['chunk_id'])]})
else:
# migration for chroma < 0.4
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
0] == DocumentChoice.ALL.value:
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
{"filter": {"chunk_id": {"$eq": -1}}}
filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
elif len(document_choice) >= 2:
if document_choice[0] == DocumentChoice.ALL.value:
document_choice = document_choice[1:]
or_filter = [
{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
"chunk_id": {
"$eq": -1}}
for x in document_choice]
filter_kwargs = dict(filter={"$or": or_filter})
or_filter_backup = [
{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
for x in document_choice]
filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
elif len(document_choice) == 1:
# degenerate UX bug in chroma
one_filter = \
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
"chunk_id": {
"$eq": -1}}
for x in document_choice][0]
filter_kwargs = dict(filter=one_filter)
one_filter_backup = \
[{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
for x in document_choice][0]
filter_kwargs_backup = dict(filter=one_filter_backup)
else:
# shouldn't reach
filter_kwargs = {}
filter_kwargs_backup = {}
if llm_mode:
docs = []
scores = []
elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
text_context_list=text_context_list)
if len(db_documents) == 0 and filter_kwargs_backup:
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
text_context_list=text_context_list)
if top_k_docs == -1:
top_k_docs = len(db_documents)
# similar to langchain's chroma's _results_to_docs_and_scores
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
for result in zip(db_documents, db_metadatas)]
# set in metadata original order of docs
[x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
# order documents
doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
if query_action:
doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
docs_with_score2 = [x for hx, cx, x in
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
if cx >= 0]
else:
assert summarize_action
doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
docs_with_score2 = [x for hx, cx, x in
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
if cx == -1
]
if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
# old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
# just do again and relax filter, let summarize operate on actual chunks if nothing else
docs_with_score2 = [x for hx, cx, x in
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
key=lambda x: (x[0], x[1]))
]
docs_with_score = docs_with_score2
docs_with_score = docs_with_score[:top_k_docs]
docs = [x[0] for x in docs_with_score]
scores = [x[1] for x in docs_with_score]
num_docs_before_cut = len(docs)
else:
with filelock.FileLock(lock_file):
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
text_context_list=text_context_list,
verbose=verbose)
if len(docs_with_score) == 0 and filter_kwargs_backup:
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
db_type,
text_context_list=text_context_list,
verbose=verbose)
tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
use_openai_model=use_openai_model,
db_type=db_type)
# NOTE: if map_reduce, then no need to auto reduce chunks
if query_action and (top_k_docs == -1 or auto_reduce_chunks):
top_k_docs_tokenize = 100
docs_with_score = docs_with_score[:top_k_docs_tokenize]
prompt_no_docs = template.format(context='', question=query)
model_max_length = tokenizer.model_max_length
chat = True # FIXME?
# first docs_with_score are most important with highest score
full_prompt, \
instruction, iinput, context, \
num_prompt_tokens, max_new_tokens, \
num_prompt_tokens0, num_prompt_tokens_actual, \
chat_index, top_k_docs_trial, one_doc_size = \
get_limited_prompt(prompt_no_docs,
iinput,
tokenizer,
prompter=prompter,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
chat=chat,
max_new_tokens=max_new_tokens,
system_prompt=system_prompt,
context=context,
chat_conversation=chat_conversation,
text_context_list=[x[0].page_content for x in docs_with_score],
keep_sources_in_context=keep_sources_in_context,
model_max_length=model_max_length,
memory_restriction_level=memory_restriction_level,
langchain_mode=langchain_mode,
add_chat_history_to_context=add_chat_history_to_context,
min_max_new_tokens=min_max_new_tokens,
)
# avoid craziness
if 0 < top_k_docs_trial < max_chunks:
# avoid craziness
if top_k_docs == -1:
top_k_docs = top_k_docs_trial
else:
top_k_docs = min(top_k_docs, top_k_docs_trial)
elif top_k_docs_trial >= max_chunks:
top_k_docs = max_chunks
if top_k_docs > 0:
docs_with_score = docs_with_score[:top_k_docs]
elif one_doc_size is not None:
docs_with_score = [docs_with_score[0][:one_doc_size]]
else:
docs_with_score = []
else:
if total_tokens_for_docs is not None:
# used to limit tokens for summarization, e.g. public instance
top_k_docs, one_doc_size, num_doc_tokens = \
get_docs_tokens(tokenizer,
text_context_list=[x[0].page_content for x in docs_with_score],
max_input_tokens=total_tokens_for_docs)
docs_with_score = docs_with_score[:top_k_docs]
# put most relevant chunks closest to question,
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
if docs_ordering_type in ['best_first']:
pass
elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
docs_with_score.reverse()
elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
docs_with_score = reverse_ucurve_list(docs_with_score)
else:
raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
# cut off so no high distance docs/sources considered
num_docs_before_cut = len(docs_with_score)
docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
if len(scores) > 0 and verbose:
print("Distance: min: %s max: %s mean: %s median: %s" %
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
else:
docs = []
scores = []
if not docs and use_docs_planned and not langchain_only_model:
# if HF type and have no docs, can bail out
return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
if document_subset in non_query_commands:
# no LLM use
return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
# FIXME: WIP
common_words_file = "data/NGSL_1.2_stats.csv.zip"
if False and os.path.isfile(common_words_file) and langchain_action == LangChainAction.QUERY.value:
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
import string
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
reduced_query_words = reduced_query.split(' ')
set_common = set(df['Lemma'].values.tolist())
num_common = len([x.lower() in set_common for x in reduced_query_words])
frac_common = num_common / len(reduced_query) if reduced_query else 0
# FIXME: report to user bad query that uses too many common words
if verbose:
print("frac_common: %s" % frac_common, flush=True)
if len(docs) == 0:
# avoid context == in prompt then
use_docs_planned = False
template = template_if_no_docs
got_db_docs = got_db_docs and len(text_context_list) < len(docs)
# update template in case situation changed or did get docs
# then no new documents from database or not used, redo template
# got template earlier as estimate of template token size, here is final used version
template, template_if_no_docs, auto_reduce_chunks, query = \
get_template(query, iinput,
pre_prompt_query, prompt_query,
pre_prompt_summary, prompt_summary,
langchain_action,
llm_mode,
use_docs_planned,
auto_reduce_chunks,
got_db_docs,
add_search_to_context)
if langchain_action == LangChainAction.QUERY.value:
if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
template=template,
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
else:
# only if use_openai_model = True, unused normally except in testing
chain = load_qa_with_sources_chain(llm)
if not use_docs_planned:
chain_kwargs = dict(input_documents=[], question=query)
else:
chain_kwargs = dict(input_documents=docs, question=query)
target = wrapped_partial(chain, chain_kwargs)
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_REFINE,
LangChainAction.SUMMARIZE_ALL.value]:
if async_output:
return_intermediate_steps = False
else:
return_intermediate_steps = True
from langchain.chains.summarize import load_summarize_chain
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
prompt = PromptTemplate(input_variables=["text"], template=template)
chain = load_summarize_chain(llm, chain_type="map_reduce",
map_prompt=prompt, combine_prompt=prompt,
return_intermediate_steps=return_intermediate_steps,
token_max=max_input_tokens, verbose=verbose)
if async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func, {"input_documents": docs}) # , return_only_outputs=True)
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
assert use_template
prompt = PromptTemplate(input_variables=["text"], template=template)
chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt,
return_intermediate_steps=return_intermediate_steps, verbose=verbose)
if async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func)
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
chain = load_summarize_chain(llm, chain_type="refine",
return_intermediate_steps=return_intermediate_steps, verbose=verbose)
if async_output:
chain_func = chain.arun
else:
chain_func = chain
target = wrapped_partial(chain_func)
else:
raise RuntimeError("No such langchain_action=%s" % langchain_action)
else:
raise RuntimeError("No such langchain_action=%s" % langchain_action)
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
    """Return the model's maximum context length in tokens.

    Resolution order:
      1. ``tokenizer.model_max_length`` whenever the tokenizer exposes that attribute.
      2. OpenAI completion servers ('openai', 'openai_azure'): ``llm.modelname_to_contextsize(model_name)``.
      3. OpenAI chat servers ('openai_chat', 'openai_azure_chat'): ``model_token_mapping[model_name]``
         (raises KeyError for an unmapped model name).
      4. ``FakeTokenizer`` instances (e.g. GGML): their ``model_max_length``.
      5. Otherwise a fallback of 2048.
    """
    if hasattr(tokenizer, 'model_max_length'):
        return tokenizer.model_max_length

    completion_servers = ('openai', 'openai_azure')
    chat_servers = ('openai_chat', 'openai_azure_chat')

    if inference_server in completion_servers:
        return llm.modelname_to_contextsize(model_name)
    if inference_server in chat_servers:
        return model_token_mapping[model_name]
    if isinstance(tokenizer, FakeTokenizer):
        # GGML-style fake tokenizer.
        # NOTE(review): step 1 already returns for any tokenizer exposing model_max_length,
        # so this branch looks unreachable in practice -- confirm before relying on it.
        return tokenizer.model_max_length
    return 2048
def get_max_input_tokens(llm=None, tokenizer=None, inference_server=None, model_name=None, max_new_tokens=None):
    """Return the token budget for the prompt, reserving room for generation.

    Starts from ``get_max_model_length(...)``. It then subtracts either the full ``max_new_tokens``
    or only ``min(256, max_new_tokens)``, depending on how strictly the backend enforces
    ``prompt + generation <= context length``.

    :param llm: LLM object, forwarded to ``get_max_model_length``.
    :param tokenizer: tokenizer, forwarded to ``get_max_model_length``. A ``FakeTokenizer`` gets
        the looser budget.
    :param inference_server: inference server string. ``None`` is treated as ``''``.
    :param model_name: model name. ``None`` is treated as ``''``.
    :param max_new_tokens: planned generation length. Required; must be an int.
    :return: maximum number of prompt tokens. There is no clamping, so the result can be negative
        when ``max_new_tokens`` exceeds the context length.
    """
    model_max_length = get_max_model_length(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
                                            model_name=model_name)
    # The defaults are None. Normalize to '' so the prefix/substring checks below cannot raise.
    server = inference_server or ''
    name = model_name or ''

    if server.startswith(('openai', 'vllm')):
        # The 'openai' prefix also covers openai_azure, openai_chat and openai_azure_chat.
        # OpenAI rejects prompt tokens + max_new_tokens > max_tokens even if those tokens are never
        # generated, and vllm uses the OpenAI API with the same limits.
        max_input_tokens = model_max_length - max_new_tokens
    elif isinstance(tokenizer, FakeTokenizer):
        # Fake tokenizer (e.g. GGML): don't assume the model will actually emit max_new_tokens,
        # so allow more input by reserving at most 256 tokens.
        max_input_tokens = model_max_length - min(256, max_new_tokens)
    elif 'falcon' in name or server.startswith('http'):
        # Allow more input for falcon, assuming outputs won't be as long as the default max_new_tokens.
        # Also for TGI/Gradio servers, which are told input may be as long as output,
        # even if the model can't actually handle it.
        max_input_tokens = model_max_length - min(256, max_new_tokens)
    else:
        # Trust that the model may generate the full max_new_tokens, so limit input accordingly.
        max_input_tokens = model_max_length - max_new_tokens
    return max_input_tokens
def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_openai_model=False,
                  db_type='chroma'):
    """Choose the tokenizer (or tokenize callable) used to count prompt tokens.

    Preference order:
      1. ``llm.pipeline.tokenizer``: the HF pipeline tokenizer, which is the most accurate.
      2. ``llm.tokenizer``, e.g. TGI client mode.
      3. The passed-in ``tokenizer`` for OpenAI-style servers, or when it is already a ``FakeTokenizer``.
      4. A fresh ``FakeTokenizer`` when ``use_openai_model`` is set.
      5. ``db._embedding_function.client.tokenize``, for when the model is not our pipeline
         with an HF tokenizer.
      6. A fresh ``FakeTokenizer`` as the fallback. If the HARD_ASSERTS env var is set,
         ``db_type`` must be 'faiss' or 'weaviate'.
    """
    hf_pipeline = getattr(llm, 'pipeline', None)
    if hasattr(hf_pipeline, 'tokenizer'):
        return hf_pipeline.tokenizer
    if hasattr(llm, 'tokenizer'):
        return llm.tokenizer

    openai_servers = ('openai', 'openai_chat', 'openai_azure', 'openai_azure_chat')
    if inference_server in openai_servers or isinstance(tokenizer, FakeTokenizer):
        return tokenizer
    if use_openai_model:
        return FakeTokenizer()

    embedding_client = getattr(getattr(db, '_embedding_function', None), 'client', None)
    if hasattr(embedding_client, 'tokenize'):
        return embedding_client.tokenize

    # Backup: tiktoken-style fake tokenizer (e.g. faiss, where the embedding is called differently).
    if os.getenv('HARD_ASSERTS'):
        assert db_type in ['faiss', 'weaviate']
    return FakeTokenizer()
def get_template(query, iinput,
                 pre_prompt_query, prompt_query,
                 pre_prompt_summary, prompt_summary,
                 langchain_action,
                 llm_mode,
                 use_docs_planned,
                 auto_reduce_chunks,
                 got_db_docs,
                 add_search_to_context):
    """Build the prompt template for the requested LangChain action.

    :param query: user question. For QUERY, ``iinput`` is appended after a newline when set.
    :param iinput: optional extra user input.
    :param pre_prompt_query: QUERY text placed before the quoted ``{context}`` block.
    :param prompt_query: QUERY text placed after the context block, right before ``{question}``.
    :param pre_prompt_summary: summarization text placed before the quoted document block.
    :param prompt_summary: summarization instruction placed after the document block.
    :param langchain_action: a ``LangChainAction`` value. SUMMARIZE_REFINE is accepted both as its
        ``.value`` and as the enum member.
    :param llm_mode: when true (or when ``use_docs_planned`` is false), QUERY uses the bare
        ``{context}{question}`` template.
    :param use_docs_planned: whether document context is planned to be inserted.
    :param auto_reduce_chunks: passed through unchanged, except summarization forces it to False.
    :param got_db_docs: whether documents came from the database. Used together with
        ``add_search_to_context`` to reword prompts mentioning "the document sources".
    :param add_search_to_context: whether web-search results were added to the context.
    :return: ``(template, template_if_no_docs, auto_reduce_chunks, query)``.
    :raises RuntimeError: if ``langchain_action`` is not recognized.
    """
    if add_search_to_context:
        # Modify prompts. This assumes patterns like those in the predefined prompts;
        # if the user customizes the prompts, they need to account for search results themselves.
        if got_db_docs:
            sources_phrase = 'information in the document and web search sources (and their source dates and website source)'
        else:
            sources_phrase = 'information in the web search sources (and their source dates and website source)'
        # Guard against None prompts, which have no .replace().
        if prompt_query is not None:
            prompt_query = prompt_query.replace('information in the document sources', sources_phrase)
        if prompt_summary is not None:
            prompt_summary = prompt_summary.replace('information in the document sources', sources_phrase)

    if langchain_action == LangChainAction.QUERY.value:
        if iinput:
            query = "%s\n%s" % (query, iinput)
        if llm_mode or not use_docs_planned:
            template_if_no_docs = template = """{context}{question}"""
        else:
            template = """%s
\"\"\"
{context}
\"\"\"
%s{question}""" % (pre_prompt_query, prompt_query)
            template_if_no_docs = """{context}{question}"""
    elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
        none = ['', '\n', None]
        # Modify prompt_summary if the user passes a query and/or iinput.
        if query not in none and iinput not in none:
            prompt_summary = "Focusing on %s, %s, %s" % (query, iinput, prompt_summary)
        elif query not in none:
            prompt_summary = "Focusing on %s, %s" % (query, prompt_summary)
        elif iinput not in none:
            # Previously iinput was silently dropped when no query was given.
            prompt_summary = "Focusing on %s, %s" % (iinput, prompt_summary)
        # don't auto reduce
        auto_reduce_chunks = False
        if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
            fstring = '{text}'
        else:
            fstring = '{input_documents}'
        template = """%s:
\"\"\"
%s
\"\"\"\n%s""" % (pre_prompt_summary, fstring, prompt_summary)
        template_if_no_docs = "Exactly only say: There are no documents to summarize."
    elif langchain_action in [LangChainAction.SUMMARIZE_REFINE.value, LangChainAction.SUMMARIZE_REFINE]:
        # Every other branch compares against .value. Accept the member too, for callers that pass it.
        template = ''  # unused
        template_if_no_docs = ''  # unused
    else:
        raise RuntimeError("No such langchain_action=%s" % langchain_action)
    return template, template_if_no_docs, auto_reduce_chunks, query
def get_sources_answer(query, docs, answer, scores, show_rank,
answer_with_sources, append_sources_to_answer,
show_accordions=True,
show_link_in_sources=True,
top_k_docs_max_show=10,
docs_ordering_type='reverse_ucurve_sort',
num_docs_before_cut=0,
verbose=False,
t_run=None,
count_input_tokens=None, count_output_tokens=None):
if verbose:
print("query: %s" % query, flush=True)
print("answer: %s" % answer, flush=True)
if len(docs) == 0:
extra = ''
ret = answer + extra
return ret, extra
if answer_with_sources == -1:
extra = [dict(score=score, content=get_doc(x), source=get_source(x), orig_index=x.metadata.get('orig_index', 0))
for score, x in zip(scores, docs)][
:top_k_docs_max_show]
if append_sources_to_answer:
extra_str = [str(x) for x in extra]
ret = answer + '\n\n' + '\n'.join(extra_str)
else:
ret = answer
return ret, extra
# link
answer_sources = [(max(0.0, 1.5 - score) / 1.5,
get_url(doc, font_size=font_size),
get_accordion(doc, font_size=font_size, head_acc=head_acc)) for score, doc in
zip(scores, docs)]
if not show_accordions:
answer_sources_dict = defaultdict(list)
[answer_sources_dict[url].append(score) for score, url in answer_sources]
answers_dict = {}
for url, scores_url in answer_sources_dict.items():
answers_dict[url] = np.max(scores_url)
answer_sources = [(score, url) for url, score in answers_dict.items()]
answer_sources.sort(key=lambda x: x[0], reverse=True)
if show_rank:
# answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
# sorted_sources_urls = "Sources [Rank | Link]:
" + "
".join(answer_sources)
answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
answer_sources = answer_sources[:top_k_docs_max_show]
sorted_sources_urls = "Ranked Sources:
" + "
".join(answer_sources)
else:
if show_accordions:
if show_link_in_sources:
answer_sources = ['" + "".join(answer_sources)
else:
sorted_sources_urls = f"{source_prefix}
" + "
".join( answer_sources) if verbose: if int(t_run): sorted_sources_urls += 'Total Time: %d [s]
' % t_run if count_input_tokens and count_output_tokens: sorted_sources_urls += 'Input Tokens: %s | Output Tokens: %d
' % ( count_input_tokens, count_output_tokens) sorted_sources_urls += f"
{source_postfix}" title_overall = "Sources" sorted_sources_urls = f"""
Sources:
{0}
{0}
Exceptions: