import os, types
import json
from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm, uuid
import httpx


class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class VertexAIConfig:
    """
    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts

    The class `VertexAIConfig` provides configuration for the VertexAI API interface. Below are the parameters:

    - `temperature` (float): This controls the degree of randomness in token selection.

    - `max_output_tokens` (integer): This sets the maximum number of tokens in the text output. The default value is 256.

    - `top_p` (float): Tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.

    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.

    Note: Please make sure to modify the default parameters as required for your use case.
    """

    temperature: Optional[float] = None
    max_output_tokens: Optional[int] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None

    def __init__(
        self,
        temperature: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
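
# Illustrative sketch, not part of the request flow above: `get_config()` returns only
# the explicitly-set, non-None class attributes, which callers then merge into their
# own params without overriding anything the user already set. The `user_params` dict
# below is hypothetical and only for illustration:
#
#   user_params = {"temperature": 0.2}
#   for k, v in VertexAIConfig(max_output_tokens=512).get_config().items():
#       user_params.setdefault(k, v)
#   # user_params -> {"temperature": 0.2, "max_output_tokens": 512}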


def _get_image_bytes_from_url(image_url: str) -> bytes:
    try:
        response = requests.get(image_url)
        response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx)
        image_bytes = response.content
        return image_bytes
    except requests.exceptions.RequestException as e:
        # Handle any request exceptions (e.g., connection error, timeout)
        return b""  # Return an empty bytes object or handle the error as needed


def _load_image_from_url(image_url: str):
    """
    Loads an image from a URL.

    Args:
        image_url (str): The URL of the image.

    Returns:
        Image: The loaded image.
    """
    from vertexai.preview.generative_models import (
        GenerativeModel,
        Part,
        GenerationConfig,
        Image,
    )

    image_bytes = _get_image_bytes_from_url(image_url)
    return Image.from_bytes(image_bytes)


def _gemini_vision_convert_messages(messages: list):
    """
    Converts given messages for GPT-4 Vision to Gemini format.

    Args:
        messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
            - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
            - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.

    Returns:
        tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).

    Raises:
        VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
        Exception: If any other exception occurs during the execution of the function.

    Note:
        This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
        The supported MIME types for images include 'image/png' and 'image/jpeg'.

    Examples:
        >>> messages = [
        ...     {"content": "Hello, world!"},
        ...     {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}]},
        ... ]
        >>> _gemini_vision_convert_messages(messages)
        ('Hello, world!This is a text message.', [<Image object>])
    """
    try:
        import vertexai
    except ImportError:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
        )
    try:
        from vertexai.preview.language_models import (
            ChatModel,
            CodeChatModel,
            InputOutputTextPair,
        )
        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
        from vertexai.preview.generative_models import (
            GenerativeModel,
            Part,
            GenerationConfig,
            Image,
        )

        # given messages for gpt-4 vision, convert them for gemini
        # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
        prompt = ""
        images = []
        for message in messages:
            if isinstance(message["content"], str):
                prompt += message["content"]
            elif isinstance(message["content"], list):
                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
                for element in message["content"]:
                    if isinstance(element, dict):
                        if element["type"] == "text":
                            prompt += element["text"]
                        elif element["type"] == "image_url":
                            image_url = element["image_url"]["url"]
                            images.append(image_url)

        # processing images passed to gemini
        processed_images = []
        for img in images:
            if "gs://" in img and ".mp4" in img:
                # Case 1: Videos with Cloud Storage URIs
                part_mime = "video/mp4"
                google_cloud_part = Part.from_uri(img, mime_type=part_mime)
                processed_images.append(google_cloud_part)
            elif "gs://" in img:
                # Case 2: Images with Cloud Storage URIs
                # The supported MIME types for images include image/png and image/jpeg.
                part_mime = "image/png" if "png" in img else "image/jpeg"
                google_cloud_part = Part.from_uri(img, mime_type=part_mime)
                processed_images.append(google_cloud_part)
            elif "https:/" in img:
                # Case 3: Images with direct links
                image = _load_image_from_url(img)
                processed_images.append(image)
        return prompt, processed_images
    except Exception as e:
        raise e


def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
):
    try:
        import vertexai
    except ImportError:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
        )
    try:
        from vertexai.preview.language_models import (
            ChatModel,
            CodeChatModel,
            InputOutputTextPair,
        )
        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
        from vertexai.preview.generative_models import (
            GenerativeModel,
            Part,
            GenerationConfig,
        )
        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types

        vertexai.init(project=vertex_project, location=vertex_location)

        ## Load Config
        config = litellm.VertexAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        ## Process safety settings into format expected by vertex AI
        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings = [
                gapic_content_types.SafetySetting(x) for x in safety_settings
            ]
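            # Illustrative (hedged) example of the shape this expects: a list of dicts
            # whose keys map onto the SafetySetting proto fields, e.g.
            #   safety_settings=[
            #       {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
            #   ]
            # The exact category/threshold names come from gapic_content_types and may
            # vary by library version.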
        # vertexai does not use an API key, it looks for credentials.json in the environment
        prompt = " ".join(
            [
                message["content"]
                for message in messages
                if isinstance(message["content"], str)
            ]
        )

        mode = ""
        request_str = ""
        response_obj = None
        if (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            llm_model = GenerativeModel(model)
            mode = "vision"
            request_str += f"llm_model = GenerativeModel({model})\n"
        elif model in litellm.vertex_chat_models:
            llm_model = ChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
        elif model in litellm.vertex_text_models:
            llm_model = TextGenerationModel.from_pretrained(model)
            mode = "text"
            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
        elif model in litellm.vertex_code_text_models:
            llm_model = CodeGenerationModel.from_pretrained(model)
            mode = "text"
            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
        else:  # vertex_code_llm_models
            llm_model = CodeChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        if acompletion == True:  # [TODO] expand support to vertex ai chat + text models
            if optional_params.get("stream", False) is True:
                # async streaming
                return async_streaming(
                    llm_model=llm_model,
                    mode=mode,
                    prompt=prompt,
                    logging_obj=logging_obj,
                    request_str=request_str,
                    model=model,
                    model_response=model_response,
                    messages=messages,
                    print_verbose=print_verbose,
                    **optional_params,
                )
            return async_completion(
                llm_model=llm_model,
                mode=mode,
                prompt=prompt,
                logging_obj=logging_obj,
                request_str=request_str,
                model=model,
                model_response=model_response,
                encoding=encoding,
                messages=messages,
                print_verbose=print_verbose,
                **optional_params,
            )

        if mode == "vision":
            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
            print_verbose(f"\nProcessing input messages = {messages}")
            tools = optional_params.pop("tools", None)
            prompt, images = _gemini_vision_convert_messages(messages=messages)
            content = [prompt] + images

            if "stream" in optional_params and optional_params["stream"] == True:
                stream = optional_params.pop("stream")
                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = llm_model.generate_content(
                    contents=content,
                    generation_config=GenerationConfig(**optional_params),
                    safety_settings=safety_settings,
                    stream=True,
                    tools=tools,
                )
                optional_params["stream"] = True
                return model_response

            request_str += f"response = llm_model.generate_content({content})\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )

            ## LLM Call
            response = llm_model.generate_content(
                contents=content,
                generation_config=GenerationConfig(**optional_params),
                safety_settings=safety_settings,
                tools=tools,
            )

            if tools is not None and hasattr(
                response.candidates[0].content.parts[0], "function_call"
            ):
                function_call = response.candidates[0].content.parts[0].function_call
                args_dict = {}
                for k, v in function_call.args.items():
                    args_dict[k] = v
                args_str = json.dumps(args_dict)
                message = litellm.Message(
                    content=None,
                    tool_calls=[
                        {
                            "id": f"call_{str(uuid.uuid4())}",
                            "function": {
                                "arguments": args_str,
                                "name": function_call.name,
                            },
                            "type": "function",
                        }
                    ],
                )
                completion_response = message
            else:
                completion_response = response.text
            response_obj = response._raw_response
            optional_params["tools"] = tools
        elif mode == "chat":
            chat = llm_model.start_chat()
            request_str += f"chat = llm_model.start_chat()\n"

            if "stream" in optional_params and optional_params["stream"] == True:
                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                # we handle this by removing 'stream' from optional params and sending the request
                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                optional_params.pop(
                    "stream", None
                )  # vertex ai raises an error when passing stream in optional params
                request_str += (
                    f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = chat.send_message_streaming(prompt, **optional_params)
                optional_params["stream"] = True
                return model_response

            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = chat.send_message(prompt, **optional_params).text
        elif mode == "text":
            if "stream" in optional_params and optional_params["stream"] == True:
                optional_params.pop(
                    "stream", None
                )  # See note above on handling streaming for vertex ai
                request_str += (
                    f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = llm_model.predict_streaming(prompt, **optional_params)
                optional_params["stream"] = True
                return model_response

            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = llm_model.predict(prompt, **optional_params).text

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response["choices"][0]["message"] = completion_response
        elif len(str(completion_response)) > 0:
            model_response["choices"][0]["message"]["content"] = str(
                completion_response
            )
        model_response["created"] = int(time.time())
        model_response["model"] = model

        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response["choices"][0].finish_reason = response_obj.candidates[
                0
            ].finish_reason.name
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # init prompt tokens
            # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
            prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
            if response_obj is not None:
                if hasattr(response_obj, "usage_metadata") and hasattr(
                    response_obj.usage_metadata, "prompt_token_count"
                ):
                    prompt_tokens = response_obj.usage_metadata.prompt_token_count
                    completion_tokens = (
                        response_obj.usage_metadata.candidates_token_count
                    )
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        model_response.usage = usage
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))
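
# Hedged usage sketch: this function is normally reached through litellm's public
# entrypoint rather than called directly. Assuming Application Default Credentials
# are configured (e.g. via GOOGLE_APPLICATION_CREDENTIALS), something like the
# following exercises the path above; exact parameter plumbing depends on the litellm
# version, and "my-project" is a hypothetical project id:
#
#   import litellm
#   litellm.vertex_project = "my-project"
#   litellm.vertex_location = "us-central1"
#   response = litellm.completion(
#       model="vertex_ai/gemini-pro",
#       messages=[{"role": "user", "content": "hello"}],
#   )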


async def async_completion(
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    logging_obj=None,
    request_str=None,
    encoding=None,
    messages=None,
    print_verbose=None,
    **optional_params,
):
    """
    Add support for acompletion calls for gemini-pro
    """
    try:
        from vertexai.preview.generative_models import GenerationConfig

        if mode == "vision":
            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
            print_verbose(f"\nProcessing input messages = {messages}")
            tools = optional_params.pop("tools", None)
            prompt, images = _gemini_vision_convert_messages(messages=messages)
            content = [prompt] + images

            request_str += f"response = llm_model.generate_content({content})\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )

            ## LLM Call
            response = await llm_model._generate_content_async(
                contents=content,
                generation_config=GenerationConfig(**optional_params),
                tools=tools,
            )

            if tools is not None and hasattr(
                response.candidates[0].content.parts[0], "function_call"
            ):
                function_call = response.candidates[0].content.parts[0].function_call
                args_dict = {}
                for k, v in function_call.args.items():
                    args_dict[k] = v
                args_str = json.dumps(args_dict)
                message = litellm.Message(
                    content=None,
                    tool_calls=[
                        {
                            "id": f"call_{str(uuid.uuid4())}",
                            "function": {
                                "arguments": args_str,
                                "name": function_call.name,
                            },
                            "type": "function",
                        }
                    ],
                )
                completion_response = message
            else:
                completion_response = response.text
            response_obj = response._raw_response
            optional_params["tools"] = tools
        elif mode == "chat":
            # chat-bison etc.
            chat = llm_model.start_chat()
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await chat.send_message_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "text":
            # gecko etc.
            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response["choices"][0]["message"] = completion_response
        elif len(str(completion_response)) > 0:
            model_response["choices"][0]["message"]["content"] = str(
                completion_response
            )
        model_response["created"] = int(time.time())
        model_response["model"] = model

        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response["choices"][0].finish_reason = response_obj.candidates[
                0
            ].finish_reason.name
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # init prompt tokens
            # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
            prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
            if response_obj is not None:
                if hasattr(response_obj, "usage_metadata") and hasattr(
                    response_obj.usage_metadata, "prompt_token_count"
                ):
                    prompt_tokens = response_obj.usage_metadata.prompt_token_count
                    completion_tokens = (
                        response_obj.usage_metadata.candidates_token_count
                    )
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            # set usage
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        model_response.usage = usage
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))


async def async_streaming(
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    logging_obj=None,
    request_str=None,
    messages=None,
    print_verbose=None,
    **optional_params,
):
    """
    Add support for async streaming calls for gemini-pro
    """
    from vertexai.preview.generative_models import GenerationConfig

    if mode == "vision":
        stream = optional_params.pop("stream")
        tools = optional_params.pop("tools", None)
        print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
        print_verbose(f"\nProcessing input messages = {messages}")

        prompt, images = _gemini_vision_convert_messages(messages=messages)
        content = [prompt] + images

        request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        response = await llm_model._generate_content_streaming_async(
            contents=content,
            generation_config=GenerationConfig(**optional_params),
            tools=tools,
        )
        optional_params["stream"] = True
        optional_params["tools"] = tools
    elif mode == "chat":
        chat = llm_model.start_chat()
        optional_params.pop(
            "stream", None
        )  # vertex ai raises an error when passing stream in optional params
        request_str += (
            f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = chat.send_message_streaming_async(prompt, **optional_params)
        optional_params["stream"] = True
    elif mode == "text":
        optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        request_str += (
            f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )
    async for transformed_chunk in streamwrapper:
        yield transformed_chunk


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass