# pylint: disable=line-too-long
"""Generate an image using the Stability AI API.

The entry point, :func:`generate_image_with_stability`, sends a text prompt to
the Stability gRPC API and saves the first returned image as a file (PNG by
default), returning the path it was written to.
"""
import os
import io
import warnings
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from PIL import Image
from config import STABILITY_API_KEY


def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024, samples=1, api_key=STABILITY_API_KEY, output_path="output_img/sd_generated_img.png"):
    """Generate an image from *prompt* with the Stability API and save it to disk.

    :param prompt: The prompt to generate the image from.
    :param seed: Seed for deterministic generation.
    :param steps: Number of inference steps.
    :param cfg_scale: CFG scale for prompt guidance.
    :param width: Width of the generated image.
    :param height: Height of the generated image.
    :param samples: Number of images to request. Only the first image artifact
        received is saved; any further images are ignored.
    :param api_key: Stability API key. It is also exported to the process
        environment as ``STABILITY_KEY``.
    :param output_path: File path the image is written to. Its parent
        directory is created if it does not already exist.
    :return: The path of the saved image file (``output_path``).
    :raises ValueError: If the API responses contain no image artifact.
    """
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
    os.environ['STABILITY_KEY'] = api_key

    # Set up our connection to the Stability API.
    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_KEY'],
        verbose=True,
        engine="stable-diffusion-xl-1024-v1-0",
    )

    print("Creating Stability Image...")
    # No ``sampler`` is passed, so the API chooses one automatically
    # (e.g. ``sampler=generation.SAMPLER_K_DPMPP_2M`` could be supplied).
    answers = stability_api.generate(
        prompt=prompt,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        samples=samples,
    )

    # Walk the streamed responses and save the first image artifact found.
    for resp in answers:
        for artifact in resp.artifacts:
            # A filtered artifact is only warned about here; if it is still an
            # image artifact it falls through and is saved below.
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
                # Create the target directory so saving does not fail with
                # FileNotFoundError on a fresh checkout.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                img.save(output_path)
                print(f"Image saved in {output_path}")
                return output_path

    raise ValueError("No image was generated.")