# Spaces: Running on Zero
import argparse
import csv
import gc
import os
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
import torch.utils.benchmark as benchmark

# Commit SHA injected by CI; None when run outside GitHub Actions.
GITHUB_SHA = os.getenv("GITHUB_SHA", None)

# Column order shared by every benchmark CSV written or collated in this module.
BENCHMARK_FIELDS = [
    "pipeline_cls",
    "ckpt_id",
    "batch_size",
    "num_inference_steps",
    "model_cpu_offload",
    "run_compile",
    "time (secs)",
    "memory (gbs)",
    "actual_gpu_memory (gbs)",
    "github_sha",
]

PROMPT = "ghibli style, a fantasy landscape with castles"
BASE_PATH = os.getenv("BASE_PATH", ".")
# Total GPU memory in GiB. The CUDA query is evaluated lazily so that setting
# TOTAL_GPU_MEMORY (or running on a CPU-only machine) does not touch CUDA at
# import time; without a GPU and without the env var this falls back to 0.0.
TOTAL_GPU_MEMORY = float(
    os.getenv("TOTAL_GPU_MEMORY")
    or (
        torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if torch.cuda.is_available()
        else 0.0
    )
)
REPO_ID = "diffusers/benchmarks"
FINAL_CSV_FILE = "collated_results.csv"
@dataclass
class BenchmarkInfo:
    """Container for the measurements of a single benchmark run.

    Without the ``@dataclass`` decorator (which the file imports but the
    original class never applied) these annotations create no attributes and
    the class cannot be instantiated with field values.
    """

    # NOTE(review): benchmark_fn/bytes_to_giga_bytes return formatted strings,
    # so callers may actually store str here despite the float annotations.
    time: float
    memory: float
def flush():
    """Release Python and CUDA cached memory and reset peak-memory stats.

    Called between benchmark runs so each run's peak-memory reading is
    isolated from the previous one. No-op on machines without CUDA.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(), so a single call covers both.
        torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes, formatted to three decimal places."""
    gib = bytes / (1024**3)
    return f"{gib:.3f}"
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with torch's benchmark Timer.

    Returns the mean wall-clock time of ``blocked_autorange`` as a string
    formatted to three decimal places (seconds).
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    mean_secs = timer.blocked_autorange().mean
    return f"{mean_secs:.3f}"
def generate_csv_dict(
    pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
) -> Dict[str, Union[str, bool, float]]:
    """Pack one benchmark run's settings and measurements into a row dict.

    The keys match ``BENCHMARK_FIELDS`` so the result can be handed straight
    to ``write_to_csv`` for later serialization.
    """
    return {
        "pipeline_cls": pipeline_cls,
        "ckpt_id": ckpt,
        "batch_size": args.batch_size,
        "num_inference_steps": args.num_inference_steps,
        "model_cpu_offload": args.model_cpu_offload,
        "run_compile": args.run_compile,
        "time (secs)": benchmark_info.time,
        "memory (gbs)": benchmark_info.memory,
        "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
        "github_sha": GITHUB_SHA,
    }
def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
    """Serialize a single benchmark row (with header) to ``file_name``."""
    with open(file_name, "w", newline="") as fh:
        csv_writer = csv.DictWriter(fh, fieldnames=BENCHMARK_FIELDS)
        csv_writer.writeheader()
        csv_writer.writerow(data_dict)
def collate_csv(input_files: List[str], output_file: str):
    """Merge identically structured CSVs into ``output_file`` with one header.

    Each input file is expected to carry the ``BENCHMARK_FIELDS`` header; data
    rows from every input are appended under a single shared header.
    """
    with open(output_file, mode="w", newline="") as outfile:
        writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
        writer.writeheader()
        for input_file in input_files:
            # newline="" per the csv module docs so quoted embedded newlines
            # in fields round-trip correctly on read as well as write.
            with open(input_file, mode="r", newline="") as infile:
                for row in csv.DictReader(infile):
                    writer.writerow(row)