import os
import shutil
import asyncio
from urllib.parse import quote
from dotenv import load_dotenv
from io import BufferedIOBase
from typing import List, Optional, Union
from pathlib import Path

from botocore.exceptions import ClientError
from botocore.config import Config
from boto3.session import Session
from pydantic import PrivateAttr

from llama_index.core.async_utils import run_jobs
from llama_parse import LlamaParse
from llama_parse.utils import (
    nest_asyncio_err,
    nest_asyncio_msg,
)

load_dotenv()

FileInput = Union[str, bytes, BufferedIOBase]


class S3ImageSaver:
    """Uploads local image files to an S3 bucket and returns the resulting object URLs."""

    def __init__(self, bucket_name, access_key=None, secret_key=None, region_name=None):
        self.bucket_name = bucket_name
        self.region_name = region_name
        # Credentials may be passed explicitly or resolved by boto3 from the
        # environment (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY).
        self.session = Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=self.region_name,
        )
        self.s3_client = self.session.client(
            "s3", config=Config(signature_version="s3v4", region_name=self.region_name)
        )

    def save_image(self, image_path, title):
        """Saves an image to the S3 bucket and returns its URL, or None on failure."""
        try:
            print("---Saving Images---")
            # URL-encode the title for the object URL; the raw title is used in the key.
            title_encoded = quote(title)
            s3_key = f"images/{title}/{os.path.basename(image_path)}"
            with open(image_path, "rb") as file:
                self.s3_client.upload_fileobj(file, self.bucket_name, s3_key)
            s3_url = (
                f"https://{self.bucket_name}.s3.{self.region_name}.amazonaws.com/"
                f"images/{title_encoded}/{os.path.basename(image_path)}"
            )
            print(f"Image saved to S3 bucket: {s3_url}")
            return s3_url
        except ClientError as e:
            print(f"Error saving image to S3: {e}")
            return None
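

# Sketch of standalone use of S3ImageSaver (bucket name, file path, and title
# below are hypothetical placeholders, not part of the code above):
#
#   saver = S3ImageSaver(
#       bucket_name="my-bucket",
#       access_key=os.getenv("AWS_ACCESS_KEY_ID"),
#       secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
#       region_name="us-west-2",
#   )
#   url = saver.save_image("page_1.png", title="My Report")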


class LlamaParseWithS3(LlamaParse):
    """LlamaParse subclass that uploads extracted page images to S3."""

    _s3_image_saver: S3ImageSaver = PrivateAttr()

    def __init__(self, *args, s3_image_saver=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to environment-based configuration when no saver is supplied.
        self._s3_image_saver = s3_image_saver or S3ImageSaver(
            bucket_name=os.getenv("S3_BUCKET_NAME"),
            access_key=os.getenv("AWS_ACCESS_KEY_ID"),
            secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            region_name="us-west-2",
        )

    async def aget_images(
        self, json_result: List[dict], download_path: str
    ) -> List[dict]:
        """Download images from the parsed result."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        # make the download path
        if not os.path.exists(download_path):
            os.makedirs(download_path)
        try:
            images = []
            for result in json_result:
                job_id = result["job_id"]
                for page in result["pages"]:
                    if self.verbose:
                        print(f"> Image for page {page['page']}: {page['images']}")
                    for image in page["images"]:
                        image_name = image["name"]
                        # get the full path
                        image_path = os.path.join(download_path, f"{image_name}")
                        # get a valid image path
                        if not image_path.endswith(".png") and not image_path.endswith(".jpg"):
                            image_path += ".png"
                        image["path"] = image_path
                        image["job_id"] = job_id
                        image["original_file_path"] = result.get("file_path", None)
                        image["page_number"] = page["page"]
                        with open(image_path, "wb") as f:
                            image_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{image_name}"
                            async with self.client_context() as client:
                                res = await client.get(
                                    image_url, headers=headers, timeout=self.max_timeout
                                )
                                res.raise_for_status()
                                f.write(res.content)
                        images.append(image)
            return images
        except Exception as e:
            print("Error while downloading images from the parsed result:", e)
            if self.ignore_errors:
                return []
            else:
                raise e

    async def aget_images_s3(self, json_result: List[dict], title) -> List[dict]:
        """Download images to a temporary folder, upload them to S3, then clean up."""
        images = await self.aget_images(
            json_result, download_path="tmp/"
        )  # Download to temporary location
        # Process each image and upload to S3
        for image in images:
            image_path = image["path"]
            try:
                s3_url = self._s3_image_saver.save_image(image_path, title)
                if s3_url:
                    image["image_link"] = s3_url
            except Exception as e:
                print(f"Error saving image to S3: {image_path} - {e}")
        # After processing all images, delete the tmp folder
        folder_path = "tmp/"
        try:
            shutil.rmtree(folder_path)  # Deletes the folder and all its contents
            print(f"Folder {folder_path} and all its contents were deleted successfully.")
        except Exception as e:
            print(f"Error deleting folder {folder_path}: {e}")
        return images

    def get_images(self, json_result: List[dict], title) -> List[dict]:
        """Download images from the parsed result and save them to S3."""
        try:
            return asyncio.run(self.aget_images_s3(json_result, title))
        except RuntimeError as e:
            if nest_asyncio_err in str(e):
                raise RuntimeError(nest_asyncio_msg)
            else:
                raise e
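

# Minimal end-to-end sketch (assumptions: LLAMA_CLOUD_API_KEY, S3_BUCKET_NAME,
# and AWS credentials are set in the environment; "report.pdf" and the title
# are illustrative placeholders, not part of the code above).
if __name__ == "__main__":
    parser = LlamaParseWithS3(
        api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
        result_type="markdown",
        verbose=True,
    )
    # get_json_result() returns the raw parsed pages per input file; get_images()
    # then walks that JSON, downloads each extracted image, and uploads it to S3.
    json_result = parser.get_json_result("report.pdf")
    images = parser.get_images(json_result, title="report")
    for img in images:
        print(img.get("image_link"))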