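// Spaces: Sleeping Sleeping
// (The line above appears to be a status banner captured from the hosting
// page's file viewer; it is not part of this module.)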
import { HfInference } from "@huggingface/inference";
/**
 * Static provider configuration used by the LLM helpers in this module.
 *
 * One provider section is active at a time; the remaining sections are kept
 * below as commented-out templates.
 */
export const LLM_CONFIG = {
  /* Hugging Face config: */
  ollama: false,
  huggingface: true,
  url: "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
  chatModel: "meta-llama/Meta-Llama-3-8B-Instruct",
  // With the Hugging Face config this is a full URL: the embedding request is
  // POSTed directly to this value.
  embeddingModel:
    "https://api-inference.huggingface.co/models/mixedbread-ai/mxbai-embed-large-v1",
  embeddingDimension: 1024,

  /* Ollama (local) config: */
  // ollama: true,
  // url: 'http://127.0.0.1:11434',
  // chatModel: 'llama3' as const,
  // embeddingModel: 'mxbai-embed-large',
  // embeddingDimension: 1024,
  // embeddingModel: 'llama3',
  // embeddingDimension: 4096,

  /* Together.ai config:
  ollama: false,
  url: 'https://api.together.xyz',
  chatModel: 'meta-llama/Llama-3-8b-chat-hf',
  embeddingModel: 'togethercomputer/m2-bert-80M-8k-retrieval',
  embeddingDimension: 768,
  */

  /* OpenAI config:
  ollama: false,
  url: 'https://api.openai.com',
  chatModel: 'gpt-3.5-turbo-16k',
  embeddingModel: 'text-embedding-ada-002',
  embeddingDimension: 1536,
  */
};
/**
 * Builds the URL for `path` on the configured LLM host.
 *
 * The host is the first defined value among the `LLM_API_URL`, `OLLAMA_HOST`
 * and `OPENAI_API_BASE` environment variables, falling back to
 * `LLM_CONFIG.url`. At the seam between host and path, one duplicate "/" is
 * dropped or one missing "/" is inserted.
 *
 * @param path Path to append, with or without a leading "/".
 * @returns The host joined with `path`.
 */
function apiUrl(path: string): string {
  // OPENAI_API_BASE and OLLAMA_HOST are legacy
  const host =
    process.env.LLM_API_URL ??
    process.env.OLLAMA_HOST ??
    process.env.OPENAI_API_BASE ??
    LLM_CONFIG.url;
  const hostEndsWithSlash = host.endsWith("/");
  const pathStartsWithSlash = path.startsWith("/");
  if (hostEndsWithSlash && pathStartsWithSlash) {
    // Both sides supply the separator: keep only one.
    return host + path.slice(1);
  }
  if (!hostEndsWithSlash && !pathStartsWithSlash) {
    // Neither side supplies the separator: add it.
    return host + "/" + path;
  }
  return host + path;
}
/**
 * Reads the API key from the environment.
 *
 * @returns `LLM_API_KEY` if defined, otherwise `OPENAI_API_KEY`, otherwise
 *   `undefined`.
 */
function apiKey(): string | undefined {
  return process.env.LLM_API_KEY ?? process.env.OPENAI_API_KEY;
}
/**
 * Builds the authorization header for outgoing requests.
 *
 * @returns `{ Authorization: "Bearer <key>" }` when `apiKey()` yields a
 *   non-empty key, otherwise an empty object.
 */
const AuthHeaders = (): Record<string, string> => {
  // Read the key once so the check and the header value cannot disagree.
  const key = apiKey();
  return key ? { Authorization: "Bearer " + key } : {};
};
// Overload for non-streaming
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  } & {
    stream?: false | null | undefined;
  }
): Promise<{ content: string; retries: number; ms: number }>;
// Overload for streaming
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  } & {
    stream?: true;
  }
): Promise<{ content: ChatCompletionContent; retries: number; ms: number }>;
/**
 * Requests a chat completion through the Hugging Face inference client.
 *
 * The model defaults to `LLM_MODEL`, then the legacy `OLLAMA_MODEL`
 * environment variable, then `LLM_CONFIG.chatModel`; the chosen value is
 * written back to `body.model`.
 *
 * @param body Chat completion request; `model` is optional.
 * @returns `content` (a string, or a `ChatCompletionContent` wrapper when
 *   `body.stream` is truthy) plus the `retries` and `ms` statistics reported
 *   by `retryWithBackoff`.
 * @throws If `assertApiKey()` rejects the configuration, or if a
 *   non-streaming response carries no message content.
 */
export async function chatCompletion(
  body: Omit<CreateChatCompletionRequest, "model"> & {
    model?: CreateChatCompletionRequest["model"];
  }
) {
  assertApiKey();
  // OLLAMA_MODEL is legacy
  body.model =
    body.model ??
    process.env.LLM_MODEL ??
    process.env.OLLAMA_MODEL ??
    LLM_CONFIG.chatModel;
  // Always work on a fresh array: pushing onto the caller's `stop` array
  // would make it grow every time the same request object is reused.
  const stopWords: string[] = !body.stop
    ? []
    : typeof body.stop === "string"
    ? [body.stop]
    : [...body.stop];
  if (LLM_CONFIG.ollama || LLM_CONFIG.huggingface) stopWords.push("<|eot_id|>");
  const {
    result: content,
    retries,
    ms,
  } = await retryWithBackoff(async () => {
    const hf = new HfInference(apiKey());
    const model = hf.endpoint(apiUrl("/v1/chat/completions"));
    if (body.stream) {
      const completion = model.chatCompletionStream({
        ...body,
      });
      // Stop words are enforced on our side while reading the stream.
      return new ChatCompletionContent(completion, stopWords);
    } else {
      const completion = await model.chatCompletion({
        ...body,
      });
      // `?.` on the first choice so an empty `choices` array reports the
      // descriptive error below instead of a bare TypeError.
      const content = completion.choices[0]?.message?.content;
      if (content === undefined) {
        throw new Error(
          "Unexpected result from OpenAI: " + JSON.stringify(completion)
        );
      }
      return content;
    }
  });
  return {
    content,
    retries,
    ms,
  };
}
/**
 * Asks the Ollama host to pull `model` when `error` indicates it is missing.
 *
 * Does nothing unless `error` contains the text "try pulling". Otherwise it
 * POSTs to `/api/pull`, logs the response body, and then throws a retryable
 * error object so that a surrounding `retryWithBackoff` tries again.
 *
 * @param model Name of the model to pull.
 * @param error Error text returned by the failed request.
 * @throws `{ retry: true, error: string }` after a pull has been requested.
 */
export async function tryPullOllama(model: string, error: string) {
  if (!error.includes("try pulling")) {
    return;
  }
  console.error("Embedding model not found, pulling from Ollama");
  const pullResp = await fetch(apiUrl("/api/pull"), {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({ name: model }),
  });
  console.log("Pull response", await pullResp.text());
  // Deliberately not an Error: `retryWithBackoff` inspects the `retry` flag.
  throw {
    retry: true,
    error: `Dynamically pulled model. Original error: ${error}`,
  };
}
/**
 * Fetches one embedding per input text from the configured provider.
 *
 * - Ollama: one request per text via `ollamaFetchEmbedding`, run in parallel.
 * - Hugging Face: a single POST to `LLM_CONFIG.embeddingModel`.
 * - Otherwise: a single POST to the OpenAI-style `/v1/embeddings` endpoint.
 *
 * For the two batched providers, newlines in each text are replaced with
 * spaces before sending.
 *
 * @param texts Texts to embed.
 * @returns The embeddings in input order. The OpenAI-style branch also
 *   returns `usage`, `retries` and `ms`.
 * @throws If the API key is missing (non-Ollama), the request fails, or the
 *   number of embeddings does not match `texts.length`.
 */
export async function fetchEmbeddingBatch(texts: string[]) {
  if (LLM_CONFIG.ollama) {
    return {
      ollama: true as const,
      embeddings: await Promise.all(
        texts.map(async (t) => (await ollamaFetchEmbedding(t)).embedding)
      ),
    };
  }
  assertApiKey();
  if (LLM_CONFIG.huggingface) {
    const { result: embeddings } = await retryWithBackoff(async () => {
      const result = await fetch(LLM_CONFIG.embeddingModel, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          "X-Wait-For-Model": "true",
          ...AuthHeaders(),
        },
        body: JSON.stringify({
          inputs: texts.map((text) => text.replace(/\n/g, " ")),
        }),
      });
      // Without this check an error payload would be handed back to the
      // caller as if it were the list of embeddings.
      if (!result.ok) {
        throw {
          retry: result.status === 429 || result.status >= 500,
          error: new Error(
            `Embedding failed with code ${result.status}: ${await result.text()}`
          ),
        };
      }
      return (await result.json()) as number[][];
    });
    if (!Array.isArray(embeddings) || embeddings.length !== texts.length) {
      console.error(embeddings);
      throw new Error("Unexpected number of embeddings");
    }
    return {
      // NOTE(review): the Hugging Face branch reports `ollama: true`, i.e. the
      // result shape without usage/retry statistics. Kept as-is for callers
      // that discriminate on this flag — confirm this is intended.
      ollama: true as const,
      embeddings: embeddings,
    };
  }
  const {
    result: json,
    retries,
    ms,
  } = await retryWithBackoff(async () => {
    const result = await fetch(apiUrl("/v1/embeddings"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        ...AuthHeaders(),
      },
      body: JSON.stringify({
        model: LLM_CONFIG.embeddingModel,
        input: texts.map((text) => text.replace(/\n/g, " ")),
      }),
    });
    if (!result.ok) {
      throw {
        // Rate limiting and server-side failures are worth retrying.
        retry: result.status === 429 || result.status >= 500,
        error: new Error(
          `Embedding failed with code ${result.status}: ${await result.text()}`
        ),
      };
    }
    return (await result.json()) as CreateEmbeddingResponse;
  });
  if (json.data.length !== texts.length) {
    console.error(json);
    throw new Error("Unexpected number of embeddings");
  }
  // Order by the reported index so embeddings line up with `texts`.
  const allembeddings = json.data;
  allembeddings.sort((a, b) => a.index - b.index);
  return {
    ollama: false as const,
    embeddings: allembeddings.map(({ embedding }) => embedding),
    usage: json.usage?.total_tokens,
    retries,
    ms,
  };
}
/**
 * Fetches the embedding for a single text.
 *
 * @param text Text to embed.
 * @returns `embedding` (the first entry of a one-element batch) together with
 *   whatever other fields `fetchEmbeddingBatch` returned.
 */
export async function fetchEmbedding(text: string) {
  const { embeddings, ...stats } = await fetchEmbeddingBatch([text]);
  return { embedding: embeddings[0], ...stats };
}
/**
 * Sends `content` to the `/v1/moderations` endpoint.
 *
 * @param content Text to check.
 * @returns The parsed response body, containing a `results` array whose
 *   entries carry a `flagged` boolean.
 * @throws If the API key is missing or the request fails. Status 429 and
 *   5xx responses are retried first.
 */
export async function fetchModeration(content: string) {
  assertApiKey();
  const { result: moderation } = await retryWithBackoff(async () => {
    const result = await fetch(apiUrl("/v1/moderations"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        ...AuthHeaders(),
      },
      body: JSON.stringify({
        input: content,
      }),
    });
    if (!result.ok) {
      throw {
        retry: result.status === 429 || result.status >= 500,
        error: new Error(
          `Moderation failed with code ${result.status}: ${await result.text()}`
        ),
      };
    }
    return (await result.json()) as { results: { flagged: boolean }[] };
  });
  return moderation;
}
/**
 * Ensures an API key is available when the configured provider needs one.
 *
 * No key is required when `LLM_CONFIG.ollama` is set.
 *
 * @throws If a key is required and `apiKey()` returns nothing.
 */
export function assertApiKey() {
  if (LLM_CONFIG.ollama || apiKey()) {
    return;
  }
  // This point is only reached for non-Ollama configs, so the command prefix
  // is always "npx".
  throw new Error(
    "\n Missing LLM_API_KEY in environment variables.\n\n" +
      "npx" +
      " convex env set LLM_API_KEY 'your-key'"
  );
}
// Retry after this much time, based on the retry number.
const RETRY_BACKOFF = [1000, 10_000, 20_000]; // In ms
const RETRY_JITTER = 100; // In ms
// Shape callers throw to control retrying: `retry` requests another attempt,
// `error` is what is ultimately rethrown.
type RetryError = { retry: boolean; error: unknown };

/**
 * Runs `fn`, retrying with backoff when it throws a retryable error.
 *
 * A thrown value is retried only if it is an object with a truthy `retry`
 * property and backoff slots remain (at most `RETRY_BACKOFF.length` retries).
 * Otherwise the thrown value's `error` property is rethrown if truthy, else
 * the thrown value itself.
 *
 * @param fn Operation to run.
 * @returns The result, the number of retries performed, and the duration in
 *   milliseconds of the successful attempt only.
 */
export async function retryWithBackoff<T>(
  fn: () => Promise<T>
): Promise<{ retries: number; result: T; ms: number }> {
  for (let attempt = 0; attempt <= RETRY_BACKOFF.length; attempt++) {
    try {
      const start = Date.now();
      const result = await fn();
      const ms = Date.now() - start;
      return { result, retries: attempt, ms };
    } catch (e: unknown) {
      // Anything may be thrown (including null/undefined), so only read the
      // retry fields from real objects.
      const details: Partial<RetryError> =
        typeof e === "object" && e !== null ? (e as Partial<RetryError>) : {};
      const backoff = RETRY_BACKOFF[attempt];
      if (backoff !== undefined && details.retry) {
        console.log(
          `Attempt ${attempt + 1} failed, waiting ${backoff}ms to retry...`,
          Date.now()
        );
        await new Promise((resolve) =>
          setTimeout(resolve, backoff + RETRY_JITTER * Math.random())
        );
        continue;
      }
      if (details.error) throw details.error;
      throw e;
    }
  }
  throw new Error("Unreachable");
}
// Lifted from openai's package
/** A single message in a chat conversation. */
export interface LLMMessage {
  /**
   * The contents of the message. `content` is required for all messages, and may be
   * null for assistant messages with function calls.
   */
  content: string | null;
  /**
   * The role of the message's author. One of `system`, `user`, `assistant`, or
   * `function`.
   */
  role: "system" | "user" | "assistant" | "function";
  /**
   * The name of the author of this message. `name` is required if role is
   * `function`, and it should be the name of the function whose response is in the
   * `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of
   * 64 characters.
   */
  name?: string;
  /**
   * The name and arguments of a function that should be called, as generated by the model.
   */
  function_call?: {
    /** The name of the function to call. */
    name: string;
    /**
     * The arguments to call the function with, as generated by the model in
     * JSON format. Note that the model does not always generate valid JSON,
     * and may hallucinate parameters not defined by your function schema.
     * Validate the arguments in your code before calling your function.
     */
    arguments: string;
  };
}
/** Non-streaming chat completion response. */
interface CreateChatCompletionResponse {
  id: string;
  object: string;
  created: number;
  model: string;
  /** Generated completions; each `message` may be absent. */
  choices: {
    index?: number;
    message?: {
      role: "system" | "user" | "assistant";
      content: string;
    };
    finish_reason?: string;
  }[];
  /** Token accounting, when the provider reports it. */
  usage?: {
    completion_tokens: number;
    prompt_tokens: number;
    total_tokens: number;
  };
}
/** Response body of the OpenAI-style `/v1/embeddings` endpoint. */
interface CreateEmbeddingResponse {
  /** One entry per input; `index` is used to restore input order. */
  data: {
    index: number;
    object: string;
    embedding: number[];
  }[];
  model: string;
  object: string;
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
/** Request body for an OpenAI-style chat completion call. */
export interface CreateChatCompletionRequest {
  /**
   * ID of the model to use.
   */
  model: string;
  // | 'gpt-4'
  // | 'gpt-4-0613'
  // | 'gpt-4-32k'
  // | 'gpt-4-32k-0613'
  // | 'gpt-3.5-turbo'
  // | 'gpt-3.5-turbo-0613'
  // | 'gpt-3.5-turbo-16k' // <- our default
  // | 'gpt-3.5-turbo-16k-0613';
  /**
   * The messages to generate chat completions for, in the chat format:
   * https://platform.openai.com/docs/guides/chat/introduction
   */
  messages: LLMMessage[];
  /**
   * What sampling temperature to use, between 0 and 2. Higher values like 0.8
   * will make the output more random, while lower values like 0.2 will make
   * it more focused and deterministic. We generally recommend altering this
   * or `top_p` but not both.
   */
  temperature?: number | null;
  /**
   * An alternative to sampling with temperature, called nucleus sampling,
   * where the model considers the results of the tokens with top_p
   * probability mass. So 0.1 means only the tokens comprising the top 10%
   * probability mass are considered. We generally recommend altering this or
   * `temperature` but not both.
   */
  top_p?: number | null;
  /**
   * How many chat completion choices to generate for each input message.
   */
  n?: number | null;
  /**
   * If set, partial message deltas will be sent, like in ChatGPT. Tokens will
   * be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
   * as they become available, with the stream terminated by a `data: [DONE]`
   * message.
   */
  stream?: boolean | null;
  /**
   * One stop sequence, or a list of them.
   */
  stop?: Array<string> | string;
  /**
   * The maximum number of tokens allowed for the generated answer. By default,
   * the number of tokens the model can return will be (4096 - prompt tokens).
   */
  max_tokens?: number;
  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on
   * whether they appear in the text so far, increasing the model's likelihood
   * to talk about new topics. See more information about frequency and
   * presence penalties:
   * https://platform.openai.com/docs/api-reference/parameter-details
   */
  presence_penalty?: number | null;
  /**
   * Number between -2.0 and 2.0. Positive values penalize new tokens based on
   * their existing frequency in the text so far, decreasing the model's
   * likelihood to repeat the same line verbatim. See more information about
   * presence penalties:
   * https://platform.openai.com/docs/api-reference/parameter-details
   */
  frequency_penalty?: number | null;
  /**
   * Modify the likelihood of specified tokens appearing in the completion.
   * Accepts a json object that maps tokens (specified by their token ID in the
   * tokenizer) to an associated bias value from -100 to 100. Mathematically,
   * the bias is added to the logits generated by the model prior to sampling.
   * The exact effect will vary per model, but values between -1 and 1 should
   * decrease or increase likelihood of selection; values like -100 or 100
   * should result in a ban or exclusive selection of the relevant token.
   */
  logit_bias?: object | null;
  /**
   * A unique identifier representing your end-user, which can help OpenAI to
   * monitor and detect abuse. Learn more:
   * https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
   */
  user?: string;
  /** Tools the model may call. */
  tools?: {
    // The type of the tool. Currently, only function is supported.
    type: "function";
    function: {
      /**
       * The name of the function to be called. Must be a-z, A-Z, 0-9, or
       * contain underscores and dashes, with a maximum length of 64.
       */
      name: string;
      /**
       * A description of what the function does, used by the model to choose
       * when and how to call the function.
       */
      description?: string;
      /**
       * The parameters the functions accepts, described as a JSON Schema
       * object. See the guide[1] for examples, and the JSON Schema reference[2]
       * for documentation about the format.
       * [1]: https://platform.openai.com/docs/guides/gpt/function-calling
       * [2]: https://json-schema.org/understanding-json-schema/
       * To describe a function that accepts no parameters, provide the value
       * {"type": "object", "properties": {}}.
       */
      parameters: object;
    };
  }[];
  /**
   * Controls which (if any) function is called by the model. `none` means the
   * model will not call a function and instead generates a message.
   * `auto` means the model can pick between generating a message or calling a
   * function. Specifying a particular function via
   * {"type": "function", "function": {"name": "my_function"}} forces the model
   * to call that function.
   *
   * `none` is the default when no functions are present.
   * `auto` is the default if functions are present.
   */
  tool_choice?:
    | "none" // none means the model will not call a function and instead generates a message.
    | "auto" // auto means the model can pick between generating a message or calling a function.
    // Specifies a tool the model should use. Use to force the model to call
    // a specific function.
    | {
        // The type of the tool. Currently, only function is supported.
        type: "function";
        function: { name: string };
      };
  // Replaced by "tools"
  // functions?: {
  //   /**
  //    * The name of the function to be called. Must be a-z, A-Z, 0-9, or
  //    * contain underscores and dashes, with a maximum length of 64.
  //    */
  //   name: string;
  //   /**
  //    * A description of what the function does, used by the model to choose
  //    * when and how to call the function.
  //    */
  //   description?: string;
  //   /**
  //    * The parameters the functions accepts, described as a JSON Schema
  //    * object. See the guide[1] for examples, and the JSON Schema reference[2]
  //    * for documentation about the format.
  //    * [1]: https://platform.openai.com/docs/guides/gpt/function-calling
  //    * [2]: https://json-schema.org/understanding-json-schema/
  //    * To describe a function that accepts no parameters, provide the value
  //    * {"type": "object", "properties": {}}.
  //    */
  //   parameters: object;
  // }[];
  // /**
  //  * Controls how the model responds to function calls. "none" means the model
  //  * does not call a function, and responds to the end-user. "auto" means the
  //  * model can pick between an end-user or calling a function. Specifying a
  //  * particular function via {"name": "my_function"} forces the model to call
  //  * that function.
  //  * - "none" is the default when no functions are present.
  //  * - "auto" is the default if functions are present.
  //  */
  // function_call?: 'none' | 'auto' | { name: string };
  /**
   * An object specifying the format that the model must output.
   *
   * Setting to { "type": "json_object" } enables JSON mode, which guarantees
   * the message the model generates is valid JSON.
   * *Important*: when using JSON mode, you must also instruct the model to
   * produce JSON yourself via a system or user message. Without this, the model
   * may generate an unending stream of whitespace until the generation reaches
   * the token limit, resulting in a long-running and seemingly "stuck" request.
   * Also note that the message content may be partially cut off if
   * finish_reason="length", which indicates the generation exceeded max_tokens
   * or the conversation exceeded the max context length.
   */
  response_format?: { type: "text" | "json_object" };
}
// Checks whether a non-empty suffix of s1 is a prefix of s2. For example,
// ('Hello', 'Kira:') -> false
// ('Hello Kira', 'Kira:') -> true
const suffixOverlapsPrefix = (s1: string, s2: string): boolean => {
  // An overlap can be no longer than the shorter of the two strings.
  const longest = Math.min(s1.length, s2.length);
  for (let len = 1; len <= longest; len++) {
    if (s1.endsWith(s2.slice(0, len))) {
      return true;
    }
  }
  return false;
};
// Minimal shape of a streamed chat completion chunk: only the fields this
// class actually reads.
type StreamedChunk = {
  choices: { delta?: { content?: string | null } }[];
};

/**
 * Wraps a streamed chat completion and exposes its text, truncated at the
 * first stop word.
 */
export class ChatCompletionContent {
  private readonly completion: AsyncIterable<StreamedChunk>;
  private readonly stopWords: string[];

  /**
   * @param completion Stream of completion chunks.
   * @param stopWords Strings at which the output is cut off.
   */
  constructor(completion: AsyncIterable<StreamedChunk>, stopWords: string[]) {
    this.completion = completion;
    this.stopWords = stopWords;
  }

  /** Yields the raw text of each chunk's first choice, without truncation. */
  async *readInner() {
    for await (const chunk of this.completion) {
      const piece = chunk.choices[0]?.delta?.content;
      // Chunks without text would otherwise be concatenated downstream as the
      // literal string "undefined".
      if (typeof piece === "string") {
        yield piece;
      }
    }
  }

  // stop words in OpenAI api don't always work.
  // So we have to truncate on our side.
  /**
   * Yields text fragments until a stop word appears. Text that might be the
   * beginning of a stop word is held back until it can be decided.
   */
  async *read() {
    let lastFragment = "";
    for await (const data of this.readInner()) {
      lastFragment += data;
      let hasOverlap = false;
      for (const stopWord of this.stopWords) {
        const idx = lastFragment.indexOf(stopWord);
        if (idx >= 0) {
          // Emit only the text before the stop word, then end the stream.
          yield lastFragment.substring(0, idx);
          return;
        }
        if (suffixOverlapsPrefix(lastFragment, stopWord)) {
          hasOverlap = true;
        }
      }
      // The tail may be the start of a stop word: wait for more data.
      if (hasOverlap) continue;
      yield lastFragment;
      lastFragment = "";
    }
    // The stream ended without completing a stop word: flush what was held.
    yield lastFragment;
  }

  /** Reads the whole stream and returns the truncated text as one string. */
  async readAll() {
    let allContent = "";
    for await (const chunk of this.read()) {
      allContent += chunk;
    }
    return allContent;
  }
}
/**
 * Fetches the embedding of `text` from the Ollama `/api/embeddings` endpoint
 * using `LLM_CONFIG.embeddingModel`.
 *
 * On a 404 the model may be missing, so `tryPullOllama` gets a chance to pull
 * it and trigger a retry.
 *
 * @param text Text to embed.
 * @returns An object whose `embedding` is the vector from the response.
 * @throws If the request fails and is not (or no longer) retried.
 */
export async function ollamaFetchEmbedding(text: string) {
  const { result } = await retryWithBackoff(async () => {
    const resp = await fetch(apiUrl("/api/embeddings"), {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({ model: LLM_CONFIG.embeddingModel, prompt: text }),
    });
    if (resp.status === 404) {
      const error = await resp.text();
      // Throws a retryable error itself when it requested a pull.
      await tryPullOllama(LLM_CONFIG.embeddingModel, error);
      throw new Error(`Failed to fetch embeddings: ${resp.status}`);
    }
    // Any other failure status must not fall through to JSON parsing, where
    // it would surface as an `undefined` embedding.
    if (!resp.ok) {
      throw {
        retry: resp.status === 429 || resp.status >= 500,
        error: new Error(
          `Embedding failed with code ${resp.status}: ${await resp.text()}`
        ),
      };
    }
    return (await resp.json()).embedding as number[];
  });
  return { embedding: result };
}