|
|
|
"""Generate images from text prompts using the Stability AI API.

Provides generate_image_with_stability(), which sends a prompt to the
Stability gRPC API, saves the resulting image as a .png file, and returns
the path of the saved file.
"""
|
|
|
import os |
|
import io |
|
import warnings |
|
from stability_sdk import client |
|
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation |
|
from PIL import Image |
|
from config import STABILITY_API_KEY |
|
|
|
|
|
def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024,
                                  samples=1, api_key=STABILITY_API_KEY,
                                  output_path="output_img/sd_generated_img.png"):
    """
    Generate an image from a text prompt with the Stability API and save it as a PNG.

    :param prompt: The prompt to generate the image from.
    :param seed: Seed for deterministic generation.
    :param steps: Number of inference steps.
    :param cfg_scale: CFG scale for prompt guidance.
    :param width: Width of the generated image.
    :param height: Height of the generated image.
    :param samples: Number of images to request. Only the first image artifact
        received is saved.
    :param api_key: Stability API key. Defaults to ``config.STABILITY_API_KEY``.
    :param output_path: Destination of the PNG file. Missing parent directories
        are created.
    :return: The path (str) of the saved PNG file, i.e. ``output_path``.
    :raises ValueError: If no API key is available, or if the API response
        contained no image artifact.
    """
    # Fail early with a clear message; assigning None into os.environ below
    # would otherwise raise an unhelpful TypeError.
    if not api_key:
        raise ValueError("A Stability API key is required (set STABILITY_API_KEY).")

    # Kept for compatibility with any code that reads these variables after a call.
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
    os.environ['STABILITY_KEY'] = api_key

    stability_api = client.StabilityInference(
        key=api_key,
        verbose=True,
        engine="stable-diffusion-xl-1024-v1-0",
    )

    print("Creating Stability Image...")
    answers = stability_api.generate(
        prompt=prompt,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        samples=samples,
    )

    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # NOTE(review): a filtered artifact may still have type
                # ARTIFACT_IMAGE, in which case it is saved and returned below
                # despite this warning — confirm that is intended.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                # Image.save does not create directories; make sure the target exists.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                # Context manager releases the decoder's resources after saving.
                with Image.open(io.BytesIO(artifact.binary)) as img:
                    img.save(output_path, format="PNG")
                print(f"Image saved in {output_path}")
                return output_path

    raise ValueError("No image was generated.")
|
|