# Spaces: Sleeping (page-status text captured together with the source; not part of the module)
import json
import logging
import os
import threading
import time
from typing import Literal, Optional, Union

import litellm
import requests
from litellm.utils import ModelResponse
class BudgetManager:
    """Track per-user LLM spend against a budget and persist it.

    State lives in ``self.user_dict``, keyed by user. Each entry holds
    ``total_budget`` and, once written by the methods below, ``current_cost``,
    ``model_cost`` (spend per model) and, for recurring budgets, ``duration``
    (in days), ``created_at`` and ``last_updated_at`` (``time.time()`` values).

    ``client_type`` selects the persistence backend:

    * ``"local"``  - reads/writes ``user_cost.json`` in the working directory.
    * ``"hosted"`` - POSTs to ``{api_base}/get_budget`` and
      ``{api_base}/set_budget``.
    * anything else - nothing is loaded or saved; state is in-memory only.
    """

    # Saves run on background threads that may overlap, and every "local"
    # instance writes the same file, so the lock is shared class-wide.
    _save_lock = threading.Lock()

    # Length of each budget window in days ("monthly" is a fixed 28 days).
    _DURATION_IN_DAYS = {"daily": 1, "weekly": 7, "monthly": 28, "yearly": 365}

    def __init__(
        self,
        project_name: str,
        client_type: str = "local",
        api_base: Optional[str] = None,
    ):
        """Store the configuration and load any previously persisted budgets.

        Args:
            project_name: Sent as ``project_name`` to the hosted backend.
            client_type: ``"local"`` or ``"hosted"`` (see class docstring).
            api_base: Base URL of the hosted backend; defaults to
                ``https://api.litellm.ai`` when falsy.
        """
        self.client_type = client_type
        self.project_name = project_name
        self.api_base = api_base or "https://api.litellm.ai"
        ## load the data or init the initial dictionaries
        self.load_data()

    def print_verbose(self, print_statement):
        """Log ``print_statement`` at INFO level when ``litellm.set_verbose`` is truthy.

        Logging is best-effort: a failure here must never break budget
        tracking, so ordinary exceptions are deliberately swallowed.
        """
        try:
            if litellm.set_verbose:
                logging.info(print_statement)
        except Exception:
            # Narrower than a bare ``except``: KeyboardInterrupt/SystemExit
            # still propagate.
            pass

    def load_data(self):
        """Populate ``self.user_dict`` from the configured backend.

        An unrecognised ``client_type`` yields an empty dict, so the other
        methods do not fail with ``AttributeError`` on a missing attribute.
        """
        if self.client_type == "local":
            try:
                with open("user_cost.json", "r") as json_file:
                    self.user_dict = json.load(json_file)
            except FileNotFoundError:
                self.print_verbose("User Dictionary not found!")
                self.user_dict = {}
            self.print_verbose(f"user dict from local: {self.user_dict}")
        elif self.client_type == "hosted":
            # Load the user_dict from hosted db
            # NOTE(review): no timeout is set, so this call can block
            # indefinitely - consider passing ``timeout=``.
            url = self.api_base + "/get_budget"
            headers = {"Content-Type": "application/json"}
            data = {"project_name": self.project_name}
            response = requests.post(url, headers=headers, json=data).json()
            if response.get("status") == "error":
                # assume this means the user dict hasn't been stored yet
                self.user_dict = {}
            else:
                self.user_dict = response["data"]
        else:
            self.print_verbose(
                f"Unrecognized client_type {self.client_type!r}; starting with an empty user dict"
            )
            self.user_dict = {}

    def create_budget(
        self,
        total_budget: float,
        user: str,
        duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
        created_at: Optional[float] = None,
    ):
        """Create (or overwrite) the budget entry for ``user``.

        Args:
            total_budget: Spend limit stored under ``total_budget``.
            user: Key of the entry in ``self.user_dict``.
            duration: Optional recurrence window. When given, the entry also
                records ``duration`` (days), ``created_at`` and
                ``last_updated_at``.
            created_at: Timestamp for a recurring budget. Defaults to
                ``time.time()`` evaluated at call time (a ``time.time()``
                default in the signature would be frozen at definition time).

        Returns:
            The new entry stored in ``self.user_dict[user]``.

        Raises:
            ValueError: If ``duration`` is not one of the supported names.
                Nothing is stored in that case.
        """
        budget = {"total_budget": total_budget}
        if duration is not None:
            # Validate before touching user_dict so a bad call leaves no
            # half-created entry behind.
            days = (
                self._DURATION_IN_DAYS.get(duration)
                if isinstance(duration, str)
                else None
            )
            if days is None:
                raise ValueError(
                    """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
                )
            if created_at is None:
                created_at = time.time()
            budget["duration"] = days
            budget["created_at"] = created_at
            budget["last_updated_at"] = created_at

        self.user_dict[user] = budget
        # [Non-Blocking] Persist one-off budgets too, not only recurring ones.
        self._save_data_thread()
        return self.user_dict[user]

    def projected_cost(self, model: str, messages: list, user: str):
        """Return ``user``'s current cost plus the prompt cost of ``messages``.

        The ``content`` values of all messages are concatenated and
        token-counted for ``model``; completion tokens are taken as 0. A
        ``content`` of ``None`` is treated as empty text.
        """
        text = "".join((message["content"] or "") for message in messages)
        prompt_tokens = litellm.token_counter(model=model, text=text)
        prompt_cost, _ = litellm.cost_per_token(
            model=model, prompt_tokens=prompt_tokens, completion_tokens=0
        )
        current_cost = self.user_dict[user].get("current_cost", 0)
        return prompt_cost + current_cost

    def get_total_budget(self, user: str):
        """Return the ``total_budget`` stored for ``user`` (KeyError if unknown)."""
        return self.user_dict[user]["total_budget"]

    def update_cost(
        self,
        user: str,
        completion_obj: "Optional[ModelResponse]" = None,
        model: Optional[str] = None,
        input_text: Optional[str] = None,
        output_text: Optional[str] = None,
    ):
        """Add the cost of one call to ``user``'s running totals and persist.

        Either pass ``model`` + ``input_text`` + ``output_text`` (cost is
        derived from token counts) or pass ``completion_obj`` (cost comes from
        ``litellm.completion_cost`` and the model from ``completion_obj["model"]``).
        The text form wins when both are supplied.

        Returns:
            ``{"user": <updated entry>}``.

        Raises:
            ValueError: If neither form of input is supplied.
        """
        # Compare against None rather than truthiness so an empty completion
        # ("" output) is still costed instead of being rejected.
        has_text = (
            model is not None and input_text is not None and output_text is not None
        )
        if has_text:
            prompt_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": input_text}]
            )
            completion_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": output_text}]
            )
            (
                prompt_tokens_cost_usd_dollar,
                completion_tokens_cost_usd_dollar,
            ) = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
        elif completion_obj is not None:
            cost = litellm.completion_cost(completion_response=completion_obj)
            model = completion_obj["model"]
        else:
            raise ValueError(
                "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
            )

        entry = self.user_dict[user]
        entry["current_cost"] = cost + entry.get("current_cost", 0)
        model_costs = entry.setdefault("model_cost", {})
        model_costs[model] = cost + model_costs.get(model, 0)

        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return {"user": self.user_dict[user]}

    def get_current_cost(self, user):
        """Return ``user``'s accumulated cost, or 0 if none is recorded."""
        return self.user_dict[user].get("current_cost", 0)

    def get_model_cost(self, user):
        """Return ``user``'s per-model cost dict, or 0 if none is recorded.

        The ``0`` fallback (rather than ``{}``) is kept for backward
        compatibility with existing callers.
        """
        return self.user_dict[user].get("model_cost", 0)

    def is_valid_user(self, user: str) -> bool:
        """Return True if ``user`` has an entry in ``self.user_dict``."""
        return user in self.user_dict

    def get_users(self):
        """Return a list of all users that have an entry."""
        return list(self.user_dict.keys())

    def reset_cost(self, user):
        """Zero ``user``'s spend in memory and return ``{"user": <entry>}``.

        This does not persist by itself; ``reset_on_duration`` saves after
        calling it.
        """
        self.user_dict[user]["current_cost"] = 0
        self.user_dict[user]["model_cost"] = {}
        return {"user": self.user_dict[user]}

    def reset_on_duration(self, user: str):
        """Reset ``user``'s spend if their budget window has elapsed.

        Requires the entry to have ``duration`` and ``last_updated_at``
        (i.e. a recurring budget); otherwise a KeyError is raised.
        """
        entry = self.user_dict[user]
        last_updated_at = entry["last_updated_at"]
        current_time = time.time()

        # Convert duration from days to seconds
        duration_in_seconds = entry["duration"] * 24 * 60 * 60

        if current_time - last_updated_at >= duration_in_seconds:
            # Window elapsed: clear the spend and start a new window from now.
            self.reset_cost(user)
            self.user_dict[user]["last_updated_at"] = current_time
            self._save_data_thread()  # Save the data

    def update_budget_all_users(self):
        """Run ``reset_on_duration`` for every user with a recurring budget."""
        for user in self.get_users():
            if "duration" in self.user_dict[user]:
                self.reset_on_duration(user)

    def _save_data_thread(self):
        """Start a background thread running ``save_data`` and return at once."""
        thread = threading.Thread(
            target=self.save_data
        )  # [Non-Blocking]: saves data without blocking execution
        thread.start()

    def save_data(self):
        """Persist ``self.user_dict`` to the configured backend.

        Returns:
            ``{"status": "success"}`` for ``"local"``, the decoded JSON
            response for ``"hosted"``, and ``None`` for any other
            ``client_type`` (nothing is saved).
        """
        if self.client_type == "local":
            # Serialise before opening the file: opening with "w" truncates
            # it, so a failed dump would otherwise destroy the saved data.
            payload = json.dumps(self.user_dict, indent=4)
            # Keep overlapping save threads from interleaving their writes.
            with self._save_lock:
                with open("user_cost.json", "w") as json_file:
                    json_file.write(payload)
            return {"status": "success"}
        elif self.client_type == "hosted":
            # NOTE(review): no timeout is set on this request either.
            url = self.api_base + "/set_budget"
            headers = {"Content-Type": "application/json"}
            data = {"project_name": self.project_name, "user_dict": self.user_dict}
            response = requests.post(url, headers=headers, json=data)
            return response.json()
        return None