Spaces:
Running
Running
File size: 2,062 Bytes
7c1eee1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import json
import base64
import requests
import numpy as np
import matplotlib.pyplot as plt
# MAX_LEN = 40
# STEP = 2
# x = np.arange(0, MAX_LEN, STEP)
# token_counts = [0] * (MAX_LEN//STEP)
# with open("prompts.json", 'r') as f:
# prompts = json.load(f)
# for prompt in prompts:
# tokens = len(prompt.strip().split(' '))
# token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1
# plt.xticks(x, x+1)
# plt.xlabel("token counts")
# plt.bar(x, token_counts, width=1.3)
# # plt.show()
# plt.savefig("token_counts.png")
## Generate images from the text prompts in prompts.json
#
# For each prompt at index >= START_INDEX, POST the prompt to the Stability AI
# text-to-image endpoint and write every returned artifact to
# ./images/<prompt index>/v1_txt2img_<sample index>.png.
# Failed requests are reported on stdout and skipped so one bad prompt does
# not abort the whole batch.
with open("prompts.json", encoding="utf-8") as f:
    text_prompts = json.load(f)

engine_id = "stable-diffusion-v1-6"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')

# SECURITY: the key is read from the environment only. A key literal must
# never be committed as a fallback default; any key that was previously
# embedded in this file should be considered leaked and be revoked.
api_key = os.getenv("STABILITY_API_KEY")
if not api_key:
    # Covers both an unset variable and an empty string.
    raise Exception("Missing Stability API key.")

# Index of the first prompt to process; prompts before it are skipped.
# NOTE(review): presumably prompts 0-20 were generated by an earlier run --
# set to 0 to process the whole file.
START_INDEX = 21

# (connect, read) timeouts in seconds, so a stalled connection cannot hang
# the batch forever. The read timeout is generous because each request
# renders several 1024x1024 samples.
REQUEST_TIMEOUT = (10, 300)

for idx, text in enumerate(text_prompts):
    if idx < START_INDEX:
        continue

    print(f"Start generate prompt[{idx}]: {text}")
    try:
        response = requests.post(
            f"{api_host}/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}"
            },
            json={
                "text_prompts": [
                    {
                        "text": text.strip()
                    }
                ],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 3,
                "steps": 30,
            },
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as exc:
        # Transport-level failure (DNS, connection reset, timeout, ...):
        # handled like a non-200 response -- report it and move on.
        print(f"{idx} Failed!!! {exc}")
        continue

    if response.status_code != 200:
        # Best-effort batch: log the API error body and try the next prompt.
        print(f"{idx} Failed!!! {str(response.text)}")
        continue
    print("Finished!")

    data = response.json()
    # One output directory per prompt, created once rather than per image.
    out_dir = f"./images/{idx}"
    os.makedirs(out_dir, exist_ok=True)
    for i, image in enumerate(data["artifacts"]):
        img_path = f"{out_dir}/v1_txt2img_{i}.png"
        # Each artifact carries the image bytes base64-encoded under "base64".
        with open(img_path, "wb") as img_file:
            img_file.write(base64.b64decode(image["base64"]))