"""Minimal client for the BFL image generation API (``API_ENDPOINT``)."""

import io
import os
import time
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image

API_ENDPOINT = "https://api.bfl.ml"

# Seconds ``requests`` waits to connect / between received bytes before giving
# up. Without a timeout a stalled connection would block the caller forever.
REQUEST_TIMEOUT = 30


class ApiException(Exception):
    """Raised when the API returns an error or an unexpected response."""

    def __init__(self, status_code: int, detail: str = None):
        """
        Args:
            status_code: HTTP status code of the offending response
            detail: Error detail, usually the response's "detail" field. May be
                a string, a list of entries (dicts with a "msg" key are
                rendered by that key), any other value, or None.
        """
        # Forward both values so ``args`` is populated; unpickling re-creates
        # the exception as ``ApiException(*args)``, which needs status_code.
        super().__init__(status_code, detail)
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        if self.detail is None:
            message = None
        elif isinstance(self.detail, str):
            message = self.detail
        elif isinstance(self.detail, (list, tuple)):
            message = "[" + ",".join(self._entry_message(d) for d in self.detail) + "]"
        else:
            # Unexpected shape (e.g. a dict): __repr__ must never raise itself,
            # or it would hide the error being reported.
            message = str(self.detail)
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"

    @staticmethod
    def _entry_message(entry) -> str:
        """Render one entry of a list-valued ``detail``."""
        if isinstance(entry, dict) and "msg" in entry:
            return str(entry["msg"])
        return str(entry)


def _decode_json(response: requests.Response):
    """Return the decoded JSON body of ``response``.

    Raises:
        ApiException: If the body is not valid JSON (e.g. an HTML error page),
            carrying the response status code and raw text as detail.
    """
    try:
        return response.json()
    except ValueError as err:
        # requests' JSON decoding errors all subclass ValueError.
        raise ApiException(status_code=response.status_code, detail=response.text or None) from err


def _error_detail(payload):
    """Extract the "detail" field from a decoded JSON body of any shape."""
    if isinstance(payload, dict):
        return payload.get("detail")
    return payload


class ImageRequest:
    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixel
            height: Height of the image in pixel
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment
                (falls back to the BFL_API_KEY environment variable)

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        # Only send a seed when one was given; otherwise the server chooses.
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str = None
        self.result: dict = None
        self._image_bytes: bytes = None
        self._url: str = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    def request(self):
        """
        Request to generate the image.

        Does nothing if a request has already been submitted.

        Raises:
            ApiException: If the API responds with a non-200 status or a
                non-JSON body
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
            timeout=REQUEST_TIMEOUT,
        )
        result = _decode_json(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=_error_detail(result))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if needed, then polls every 0.5 s for as
        long as the API reports "Pending". The result is cached.

        Returns:
            The "result" field of the API's final response.

        Raises:
            ApiException: If the response has no "status", is not JSON, or
                reports any status other than "Ready" / "Pending"
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
                timeout=REQUEST_TIMEOUT,
            )
            result = _decode_json(response)
            if not isinstance(result, dict) or "status" not in result:
                raise ApiException(status_code=response.status_code, detail=_error_detail(result))
            status = result["status"]
            if status == "Ready":
                self.result = result["result"]
            elif status == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(
                    status_code=response.status_code,
                    detail=f"API returned status '{status}'",
                )
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes (downloaded once from ``url``, then cached).

        Raises:
            ApiException: If the download does not return status 200
        """
        if self._image_bytes is None:
            response = requests.get(self.url, timeout=REQUEST_TIMEOUT)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from (the result's "sample" field)
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str) -> str:
        """
        Save the generated image to a local path.

        The file extension from the image URL is appended to ``path`` unless
        ``path`` already ends with it. Missing parent directories are created.

        Args:
            path: Destination path (str or os.PathLike)

        Returns:
            The path the image was actually written to.
        """
        path = os.fspath(path)
        # Take the suffix from the URL *path* only: a query string (e.g. a
        # signed-URL token containing "/" or ".") would otherwise be mistaken
        # for part of the file name.
        suffix = Path(urlparse(self.url).path).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)
        return path


if __name__ == "__main__":
    from fire import Fire

    Fire(ImageRequest)