demo_generative_img / stability_generate_img.py
joelorellana's picture
first commit for the project
1c681f7
raw
history blame
2.42 kB
# pylint: disable=line-too-long
"""Generate an image using the Stability AI API
Keyword arguments:
prompt -- The prompt to generate the image from
Return: The path of the generated image, which is saved as a .png file
"""
import os
import io
import warnings
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from PIL import Image
from config import STABILITY_API_KEY
def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024, samples=1, api_key=STABILITY_API_KEY):
    """
    Generate an image from a text prompt with the Stability API and save it as a PNG.

    :param prompt: The prompt to generate the image from.
    :param seed: Seed forwarded to the API for deterministic generation.
    :param steps: Number of inference steps.
    :param cfg_scale: CFG scale for prompt guidance.
    :param width: Width of the generated image.
    :param height: Height of the generated image.
    :param samples: Number of images requested from the API. Only the first
        image artifact received is saved; the function returns as soon as it
        has been written.
    :param api_key: Stability API key used to authenticate the request.
    :return: The path of the saved image ("output_img/sd_generated_img.png").
    :raises ValueError: If ``api_key`` is empty, or if the API response
        contains no image artifact (e.g. the prompt was rejected by the
        safety filter).
    """
    # Fail early with a clear message instead of an obscure TypeError from
    # os.environ (None) or an authentication failure later on (empty string).
    if not api_key:
        raise ValueError("A Stability API key is required.")

    output_path = "output_img/sd_generated_img.png"

    # Exported to the process environment as the original code did; the key is
    # also passed explicitly to the client below.
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
    os.environ['STABILITY_KEY'] = api_key

    # Set up our connection to the Stability API.
    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_KEY'],
        verbose=True,
        engine="stable-diffusion-xl-1024-v1-0",
    )

    print("Creating Stability Image...")
    answers = stability_api.generate(
        prompt=prompt,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        samples=samples,
    )

    # Walk the responses and persist the first image artifact found.
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # Note the trailing space: the two literals are concatenated.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                # Make sure the target directory exists before saving,
                # otherwise img.save raises FileNotFoundError.
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with Image.open(io.BytesIO(artifact.binary)) as img:
                    img.save(output_path)
                print("Image saved in output_img/sd_generated_img.png")
                return output_path

    # Reached only when no response carried an image artifact.
    raise ValueError("No image was generated.")