test / cli /generate.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
#!/usr/bin/env python
# pylint: disable=no-member
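"""generate batches of images from prompts and optionally upscale them
params: run with `--help`
workflow loops until `--max` images have been generated (0 = run until interrupted) and prints stats on exit:
1. pick a sampler: the one given with `--sampler`, or a random one from all samplers the server reports
2. generate dynamic prompt based on styles, embeddings, places, artists, suffixes
3. optionally beautify prompt (`--beautify`)
4. generate a batch of images (batch size x number of iterations)
5. create image grid
6. optionally upscale images (when the upscale factor is > 1), with face restoration when `--faces` is set
"""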
import argparse
import asyncio
import base64
import io
import json
import logging
import math
import os
import pathlib
import secrets
import time
import sys
import importlib
from random import randrange
from PIL import Image
from PIL.ExifTags import TAGS
from PIL.TiffImagePlugin import ImageFileDirectory_v2
from sdapi import close, get, interrupt, post, session
from util import Map, log, safestring
sd = dict()  # run configuration; replaced in args() (and by generate() when it is passed options)
random = dict()  # randomization template; replaced in args() when a random/dynamic value is needed
stats = Map(dict(images=0, wall=0, generate=0, upscale=0))  # cumulative counters reported on exit
avg = dict()  # sampler name -> seconds per image of that sampler's most recent batch
def grid(data):
    """Tile a batch of generated images into one grid image and save it.

    With zero or one image nothing is saved and `data.image` is returned unchanged.
    Otherwise the images are pasted row-major onto a roughly square black RGB canvas
    whose cell size comes from the first image. The canvas is saved as JPEG under
    `<root>/<grid>/` with an exif description and returned.
    """
    tiles = data.image
    if len(tiles) <= 1:
        return tiles
    tile_w, tile_h = tiles[0].size
    n_rows = round(math.sqrt(len(tiles)))
    n_cols = math.ceil(len(tiles) / n_rows)
    canvas = Image.new('RGB', size=(n_cols * tile_w, n_rows * tile_h), color='black')
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    label = data.info.prompt[:96]  # keep the prompt part of the filename to at most 96 chars
    first_seed = data.info.all_seeds[0]
    filename = safestring(f'{first_seed:0>9} {label}') + '.jpg'
    target = os.path.join(sd.paths.root, sd.paths.grid, filename)
    log.info({ 'grid': { 'name': target, 'size': canvas.size, 'images': len(tiles) } })
    canvas.save(target, 'JPEG', exif=exif(data.info, None, 'grid'), optimize=True, quality=70)
    return canvas
def exif(info, i = None, op = 'generate'):
    """Build a raw EXIF payload whose ImageDescription records the generation parameters.

    info: generation info mapping. Its keys are expanded into the description template,
          so it must provide prompt, negative_prompt, steps, cfg_scale, sampler_name,
          batch_size, job_timestamp and all_seeds.
    i:    index of the image within the batch. When given (and seeds exist) only that
          image's seed is recorded; otherwise all seeds are listed.
    op:   'generate'; 'upscale' (appends face-restore / upscaler details); or 'grid'
          (appends the grid image count).
    Returns bytes starting with b'Exif' and two NUL bytes, usable as `Image.save(exif=...)`.
    """
    if len(info.all_seeds) > 0 and i is not None:
        seeds = [info.all_seeds[i]]
    else:
        seeds = info.all_seeds
    seed_text = ', '.join(map(str, seeds))
    description = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed_text, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
    extras = []
    if op == 'upscale':
        up = sd.upscale
        if up.gfpgan_visibility > 0:
            extras.append(' | faces gfpgan')
        if up.codeformer_visibility > 0:
            extras.append(' | faces codeformer')
        for upscaler in (up.upscaler_1, up.upscaler_2):
            if upscaler != 'None':
                extras.append(' | upscale {resize}x {upscaler}'.format(resize = up.upscaling_resize, upscaler = upscaler)) # pylint: disable=consider-using-f-string
    if op == 'grid':
        extras.append(' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter)) # pylint: disable=consider-using-f-string
    description += ''.join(extras)
    tag_ids = {name: tag for tag, name in TAGS.items()}  # reverse lookup: tag name -> numeric id
    directory = ImageFileDirectory_v2()
    directory[tag_ids['ImageDescription']] = description
    buffer = io.BytesIO()
    directory.save(buffer)
    return b'Exif\x00\x00' + buffer.getvalue()
def randomize(lst):
    """Return a random element of `lst` (via `secrets.choice`), or '' when there is nothing to pick.

    A `None` input (a missing random-template section may arrive as None) is treated
    like an empty sequence instead of raising TypeError.
    """
    if not lst:
        return ''
    return secrets.choice(lst)
def prompt(params): # generate dynamic prompt or use one if provided
    """Build the positive and negative prompt for the next generation and store them in `sd.generate`.

    `params.prompt` / `params.negative` are used verbatim unless they equal 'dynamic'.
    In that case a template is picked from the random template's `prompts` / `negative` lists.
    The placeholders <embedding>, <artist>, <style>, <suffix> and <place> are then replaced
    with the matching `params` value, or with a random pick when that value is 'random'.
    A section missing from the random template (or an unloaded template) resolves to ''.
    Returns the final prompt string.
    """
    def pick(value, marker, section):
        # an explicit CLI value wins; otherwise draw from the template section (missing -> empty)
        if value != marker:
            return value
        return randomize(random.get(section) or [])
    sd.generate.prompt = pick(params.prompt, 'dynamic', 'prompts')
    sd.generate.negative_prompt = pick(params.negative, 'dynamic', 'negative')
    placeholders = (
        ('<embedding>', params.embedding, 'embeddings'),
        ('<artist>', params.artist, 'artists'),
        ('<style>', params.style, 'styles'),
        ('<suffix>', params.suffix, 'suffixes'),
        ('<place>', params.place, 'places'),  # <place> is driven by --place, not --suffix
    )
    for token, value, section in placeholders:
        sd.generate.prompt = sd.generate.prompt.replace(token, pick(value, 'random', section))
    if params.prompts or params.debug:
        log.info({ 'random initializers': random })
    if params.prompt == 'dynamic':
        log.info({ 'dynamic prompt': sd.generate.prompt })
    return sd.generate.prompt
def sampler(params, options): # find sampler
    """Resolve `params.sampler` against `options.samplers` and store the result in `sd.generate`.

    'random' picks any available sampler. Otherwise an exact name match is preferred and
    the first name starting with `params.sampler` is the fallback, so a short name does not
    silently resolve to a longer variant listed earlier. Logs the available samplers and
    exits when nothing matches. Returns the selected sampler name.
    """
    if params.sampler == 'random':
        sd.generate.sampler_name = randomize(options.samplers)
        log.info({ 'random sampler': sd.generate.sampler_name })
        return sd.generate.sampler_name
    if params.sampler in options.samplers:
        sd.generate.sampler_name = params.sampler
        return sd.generate.sampler_name
    found = [name for name in options.samplers if name.startswith(params.sampler)]
    if not found:
        log.error({ 'sampler error': params.sampler, 'available': options.samplers })
        sys.exit()
    sd.generate.sampler_name = found[0]
    return sd.generate.sampler_name
async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
    """Run one txt2img request and save every returned image as JPEG under `<root>/<generate>/`.

    prompt:  when not None, overrides `sd.generate.prompt`.
    options: when truthy, replaces the module-level `sd` configuration (wrapped in Map).
    quiet:   suppress the per-request and per-image info logging.
    Returns Map(name=[saved paths], image=[PIL images], b64=[base64 payloads], info=generation info),
    or an empty Map when the server response contains 'error'.
    """
    global sd # pylint: disable=global-statement
    if options:
        sd = Map(options)
    if prompt is not None:
        sd.generate.prompt = prompt
    if not quiet:
        log.info({ 'generate': sd.generate })
    if sd.get('options', None) is None:
        sd['options'] = await get('/sdapi/v1/options')  # exif() reads model/vae names from sd.options
    data = await post('/sdapi/v1/txt2img', sd.generate)
    if 'error' in data:
        log.error({ 'generate': data['error'], 'reason': data.get('reason') })
        return Map({})
    info = Map(json.loads(data['info']))
    log.debug({ 'info': info })
    short = info.prompt[:96]  # limit prompt part of filename to 96 chars
    names = []
    images = []
    b64s = list(data['images'])
    for i, b64 in enumerate(b64s):
        # drop an optional data-URI prefix ("data:...;base64,"): the payload is AFTER the comma
        image = Image.open(io.BytesIO(base64.b64decode(b64.split(',', 1)[-1])))
        name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
        f = os.path.join(sd.paths.root, sd.paths.generate, safestring(name) + '.jpg')
        names.append(f)
        images.append(image)
        if not quiet:
            log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
    return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })
async def upscale(data):
    """Upscale every generated image via the extras endpoint and save the results as JPEG.

    Sets `data.upscaled = []` and returns immediately when `sd.upscale.upscaling_resize <= 1`.
    Otherwise each result is saved under `<root>/<upscale>/` with the same filename as its
    generated source and appended to `data.upscaled`. Images the server fails to upscale
    are logged and skipped, so `data.upscaled` may be shorter than `data.image`.
    Returns `data`.
    """
    data.upscaled = []
    if sd.upscale.upscaling_resize <= 1:
        return data
    sd.upscale.image = ''  # keep the previous base64 payload out of the log line below
    log.info({ 'upscale': sd.upscale })
    for i, (src, b64) in enumerate(zip(data.name, data.b64)):
        # rebuild the path from the basename: str.replace() on the full path would also rewrite
        # any other occurrence of the generate-folder name (inside root or the prompt text)
        f = os.path.join(sd.paths.root, sd.paths.upscale, os.path.basename(src))
        sd.upscale.image = b64
        res = await post('/sdapi/v1/extra-single-image', sd.upscale)
        if 'error' in res or 'image' not in res:
            log.error({ 'upscale': res.get('error'), 'reason': res.get('reason'), 'name': f })
            continue
        # drop an optional data-URI prefix: the base64 payload is AFTER the comma
        image = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',', 1)[-1])))
        data.upscaled.append(image)
        log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(data.info, i, 'upscale'), optimize = True, quality = 70)
    return data
async def init():
    """Query the server, pin the configured model and prepare the output folders.

    Collects cmd flags, model titles, sampler, upscaler and face-restorer names into a Map.
    Resolves `sd.options.sd_model_checkpoint` against the model titles: an exact title wins,
    otherwise the first title starting with it; logs and exits when none match.
    Then calls interrupt(), posts `sd.options` to the server, logs the resulting model/vae,
    paths and queue status, and creates the root/generate/upscale/grid output directories.
    Returns the collected Map.
    """
    options = Map({})
    options.flags = await get('/sdapi/v1/cmd-flags')
    log.debug({ 'flags': options.flags })
    data = await get('/sdapi/v1/sd-models')
    options.models = [obj['title'] for obj in data]
    log.debug({ 'registered models': options.models })
    wanted = sd.options.sd_model_checkpoint
    if wanted in options.models:
        match = wanted  # exact title: use it as-is
    else:
        candidates = [title for title in options.models if title.startswith(wanted)]
        if not candidates:
            log.error({ 'model error': wanted, 'available': options.models })
            sys.exit()
        match = candidates[0]
    sd.options.sd_model_checkpoint = match
    data = await get('/sdapi/v1/samplers')
    options.samplers = [obj['name'] for obj in data]
    log.debug({ 'registered samplers': options.samplers })
    data = await get('/sdapi/v1/upscalers')
    options.upscalers = [obj['name'] for obj in data]
    log.debug({ 'registered upscalers': options.upscalers })
    data = await get('/sdapi/v1/face-restorers')
    options.restorers = [obj['name'] for obj in data]
    log.debug({ 'registered face restorers': options.restorers })
    await interrupt()
    await post('/sdapi/v1/options', sd.options)
    options.options = await get('/sdapi/v1/options')
    log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
    log.info({ 'paths': sd.paths })
    options.queue = await get('/queue/status')
    log.info({ 'queue': options.queue })
    for folder in ('', sd.paths.generate, sd.paths.upscale, sd.paths.grid):
        pathlib.Path(sd.paths.root, folder).mkdir(parents = True, exist_ok = True)
    return options
def _load_json_map(path, home, label):
    """Load JSON from `path`, or from `home/path` as a fallback, and wrap it in a Map.

    Logs '<label> error' and exits when the first existing file cannot be read or parsed.
    Logs '<label> file not found' and exits when neither location exists.
    """
    for candidate in (path, os.path.join(home, path)):
        if not os.path.isfile(candidate):
            continue
        try:
            with open(candidate, 'r', encoding='utf-8') as f:
                return Map(json.load(f))
        except Exception as e: # pylint: disable=broad-except
            log.error({ f'{label} error': path, 'exception': e })
            sys.exit()
    log.error({ f'{label} file not found': path })
    sys.exit()


def args(): # parse cmd arguments
    """Parse command-line arguments and build the run configuration.

    Loads the --config JSON into the global `sd`. When any prompt part asks for a dynamic or
    random value, it also loads the --random JSON template into the global `random`.
    It then builds the initial prompt and overlays every non-default CLI value onto `sd`.
    Exits on missing or unreadable files. Returns the parsed argparse namespace.
    """
    global sd # pylint: disable=global-statement
    global random # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'sd pipeline')
    parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
    parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
    parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
    parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
    parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
    parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--faces', default = False, action='store_true', help = 'restore faces during upscaling')
    parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
    parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
    parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
    parser.add_argument('--cfg', type = int, default = 0, required = False, help = 'classifier free guidance scale')
    parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
    parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
    parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
    parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
    parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
    parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
    parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
    parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
    parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
    parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    home = pathlib.Path(sys.argv[0]).parent  # fallback lookup directory: next to the script
    sd = _load_json_map(params.config, home, 'config')
    log.debug({ 'config': sd })
    # the template feeds every 'dynamic'/'random' option, not only --prompt
    randomized = (params.artist, params.embedding, params.style, params.suffix, params.place)
    if params.prompt == 'dynamic' or params.negative == 'dynamic' or 'random' in randomized:
        log.info({ 'prompt template': params.random })
        random = _load_json_map(params.random, home, 'random template')
        log.debug({ 'random template': random })
    prompt(params)  # populate sd.generate.prompt / negative_prompt before the first iteration
    sd.paths.root = params.path if params.path != '' else sd.paths.root
    # --faces is a store_true flag (never None): it may only switch face restoration on,
    # otherwise the configured value is kept
    sd.generate.restore_faces = True if params.faces else bool(sd.generate.restore_faces)
    sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
    sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
    sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
    sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
    sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
    sd.generate.width = params.width if params.width > 0 else sd.generate.width
    sd.generate.height = params.height if params.height > 0 else sd.generate.height
    sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
    sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
    sd.upscale.codeformer_visibility = 1 if params.faces else sd.upscale.codeformer_visibility
    sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
    sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
    sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
    if sd.generate.cfg_scale == 0:
        sd.generate.cfg_scale = randrange(5, 10)  # no cfg configured anywhere: pick one in [5, 9]
    return params
async def main():
    """Run the prompt -> generate -> grid -> upscale loop.

    Stops when the server returns no images, or once `--max` images have been generated;
    `--max 0` keeps looping until interrupted. Per-iteration timings are logged and
    accumulated into `stats`; the last seconds-per-image of each sampler goes into `avg`.
    """
    params = args()
    sess = await session()
    if sess is None:
        await close()
        sys.exit()
    options = await init()
    iteration = 0
    while True:
        iteration += 1
        log.info('')
        log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
        dynamic = prompt(params)
        if params.beautify:
            try:
                promptist = importlib.import_module('modules.promptist')
                sd.generate.prompt = promptist.beautify(dynamic)
            except Exception as e: # pylint: disable=broad-except
                log.error({ 'beautify': e })  # beautify is optional: keep the un-beautified prompt
        scheduler = sampler(params, options)
        t0 = time.perf_counter()
        data = await generate() # returns Map with name/image/b64/info, or empty Map on error
        if 'image' not in data:
            break
        count = len(data.image)
        stats.images += count
        t1 = time.perf_counter()
        per_image = (t1 - t0) / count if count > 0 else 0
        if count > 0:
            avg[scheduler] = per_image
        stats.generate += t1 - t0
        _image = grid(data)
        data = await upscale(data)
        t2 = time.perf_counter()
        stats.upscale += t2 - t1
        stats.wall += t2 - t0
        its = sd.generate.steps / per_image if per_image > 0 else 0
        log.info({ 'time' : { 'wall': round(t2 - t0), 'average': round(per_image), 'upscale': round(t2 - t1), 'its': round(its, 2) } })
        # --max 0 means "run until interrupted": there is no progress percentage to divide by
        progress = round(100 * stats.images / params.max, 1) if params.max != 0 else None
        log.info({ 'generated': stats.images, 'max': params.max, 'progress': progress })
        if params.max != 0 and stats.images >= params.max:
            break
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        asyncio.run(interrupt())  # presumably cancels the in-flight server job (sdapi.interrupt) - unverified
        log.info({ 'interrupt': True })
    finally:
        log.info({ 'sampler performance': avg })
        log.info({ 'stats' : stats })
        asyncio.run(close())  # single close for every exit path (normal end, interrupt, error)