#!/usr/bin/env python
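"""Trace a Stable Diffusion UNet with torch.jit and benchmark it against the eager model.

Loads an SD 1.5 or SDXL checkpoint from a single safetensors file via diffusers,
warms up the UNet, traces it with torch.jit.trace, optionally saves the traced
module next to the checkpoint, and, with --benchmark, reports the relative speedup.

Example invocation (a sketch: the script and model filenames below are illustrative):
    python trace-unet.py --model /models/v1-5-pruned-emaonly.safetensors --type sd15 --benchmark --save
"""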
import os
import time
import functools
import argparse
import logging
import warnings
from dataclasses import dataclass
logging.getLogger("DeepSpeed").disabled = True
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
import torch
import diffusers
n_warmup = 5 # eager unet passes before tracing
n_traces = 10 # passes over the traced module so the jit can optimize the graph
n_runs = 100 # benchmark passes
args = {} # replaced by the parsed argparse namespace in __main__
pipe = None # global diffusers pipeline, set by load_model()
log = logging.getLogger("sd")
def setup_logging():
    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.traceback import install
    log.setLevel(logging.DEBUG)
    console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ "traceback.border": "black", "traceback.border.syntax_error": "black", "inspect.value.border": "black" }))
    logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
    rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG, console=console)
    rh.setLevel(logging.DEBUG)
    log.addHandler(rh)
    logging.getLogger("diffusers").setLevel(logging.ERROR)
    logging.getLogger("torch").setLevel(logging.ERROR)
    warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)
    install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
def generate_inputs():
    # dummy fp16 cuda tensors matching the unet forward signature for the selected model type
    if args.type == 'sd15':
        sample = torch.randn(2, 4, 64, 64).half().cuda()
        timestep = torch.rand(1).half().cuda() * 999
        encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
        return sample, timestep, encoder_hidden_states
    if args.type == 'sdxl':
        sample = torch.randn(2, 4, 64, 64).half().cuda()
        timestep = torch.rand(1).half().cuda() * 999
        encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
        text_embeds = torch.randn(1, 77, 2048).half().cuda()
        return sample, timestep, encoder_hidden_states, text_embeds
def load_model():
    log.info(f'versions: torch={torch.__version__} diffusers={diffusers.__version__}')
    diffusers_load_config = {
        "low_cpu_mem_usage": True,
        "torch_dtype": torch.float16,
        "safety_checker": None,
        "requires_safety_checker": False,
        "load_safety_checker": False,
        "load_connected_pipeline": True,
        "use_safetensors": True,
    }
    pipeline = diffusers.StableDiffusionPipeline if args.type == 'sd15' else diffusers.StableDiffusionXLPipeline
    global pipe # pylint: disable=global-statement
    t0 = time.time()
    pipe = pipeline.from_single_file(args.model, **diffusers_load_config).to('cuda')
    size = os.path.getsize(args.model)
    log.info(f'load: model={args.model} type={args.type} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
def load_trace(fn: str):
    @dataclass
    class UNet2DConditionOutput:
        sample: torch.FloatTensor

    class TracedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.in_channels
            self.device = pipe.unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states):
            sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
            return UNet2DConditionOutput(sample=sample)

    t0 = time.time()
    unet_traced = torch.jit.load(fn)
    pipe.unet = TracedUNet()
    size = os.path.getsize(fn)
    log.info(f'load: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
def trace_model():
    log.info(f'tracing model: {args.model}')
    torch.set_grad_enabled(False)
    unet = pipe.unet
    unet.eval()
    # unet.to(memory_format=torch.channels_last) # use channels_last memory format
    unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
    # warmup
    t0 = time.time()
    for _ in range(n_warmup):
        with torch.inference_mode():
            inputs = generate_inputs()
            _output = unet(*inputs)
    log.info(f'warmup: time={time.time() - t0:.3f}s passes={n_warmup}')
    # trace
    t0 = time.time()
    unet_traced = torch.jit.trace(unet, inputs, check_trace=True)
    unet_traced.eval()
    log.info(f'trace: time={time.time() - t0:.3f}s')
    # optimize graph
    t0 = time.time()
    for _ in range(n_traces):
        with torch.inference_mode():
            inputs = generate_inputs()
            _output = unet_traced(*inputs)
    log.info(f'optimize: time={time.time() - t0:.3f}s passes={n_traces}')
    # save the model
    if args.save:
        t0 = time.time()
        basename, _ext = os.path.splitext(args.model)
        fn = f"{basename}.pt"
        unet_traced.save(fn)
        size = os.path.getsize(fn)
        log.info(f'save: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
        return fn
    pipe.unet = unet_traced
    return None
def benchmark_model(msg: str):
    with torch.inference_mode():
        inputs = generate_inputs()
        torch.cuda.synchronize()
        for n in range(n_runs):
            if n == n_runs // 10: # treat the first 10% of passes as warmup, then start the timer
                t0 = time.time()
            _output = pipe.unet(*inputs)
        torch.cuda.synchronize()
        t1 = time.time()
    log.info(f"benchmark unet: {t1 - t0:.3f}s passes={n_runs} type={msg}")
    return t1 - t0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SD.Next')
    parser.add_argument('--model', type=str, default='', required=True, help='model path')
    parser.add_argument('--type', type=str, default='sd15', choices=['sd15', 'sdxl'], required=False, help='model type, default: %(default)s')
    parser.add_argument('--benchmark', default=False, action='store_true', help="run benchmarks, default: %(default)s")
    parser.add_argument('--trace', default=True, action='store_true', help="run jit tracing, default: %(default)s")
    parser.add_argument('--save', default=False, action='store_true', help="save optimized unet, default: %(default)s")
    args = parser.parse_args()
    setup_logging()
    log.info('sdnext model jit tracing')
    if not os.path.isfile(args.model):
        log.error(f"invalid model path: {args.model}")
        exit(1)
    load_model()
    if args.benchmark:
        time0 = benchmark_model('original')
    unet_saved = trace_model()
    if unet_saved is not None:
        load_trace(unet_saved)
    if args.benchmark:
        time1 = benchmark_model('traced')
        log.info(f'benchmark speedup: {100 * (time0 - time1) / time0:.3f}%')