Spaces:
Runtime error
Runtime error
from collections import namedtuple | |
from functools import cache, cached_property | |
from io import BytesIO | |
from os import environ | |
from os.path import isfile, join | |
from re import MULTILINE, escape, search, sub | |
from subprocess import CalledProcessError, DEVNULL, TimeoutExpired | |
from tempfile import NamedTemporaryFile, TemporaryDirectory | |
from typing import Optional, Union | |
import warnings | |
from PIL import Image, ImageOps | |
import requests | |
import torch | |
from torch.cuda import current_device, is_available as has_cuda | |
from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP | |
from transformers.utils import logging | |
from transformers.utils.hub import is_remote_url | |
from pdf2image.pdf2image import convert_from_bytes | |
from pdfCropMargins import crop | |
import fitz | |
# Module logger: reuses the "transformers" logger, presumably so that messages
# here follow the verbosity configured for the transformers library.
logger = logging.get_logger(name="transformers")
from os import killpg, getpgid | |
from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE | |
from signal import SIGKILL | |
def _kill_process_group(process):
    """SIGKILL every process in the process group of *process*, if it still exists.

    A group that has already disappeared (``ProcessLookupError`` from
    ``getpgid``/``killpg``) is ignored, so that cleanup can never mask the
    exception the caller is in the middle of propagating.
    """
    try:
        killpg(getpgid(process.pid), SIGKILL)
    except ProcessLookupError:
        pass


def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
    """Drop-in variant of :func:`subprocess.run` that kills the whole process group.

    The child is started with ``start_new_session=True``, so it leads its own
    process group. On timeout or any other error during ``communicate`` the
    entire group is killed, which also takes down grandchildren the command
    spawned (e.g. the TeX engines started by latexmk).

    Returns:
        ``CompletedProcess`` with args, return code, stdout and stderr.

    Raises:
        TimeoutExpired: if *timeout* elapses (after the group was killed and reaped).
        CalledProcessError: if *check* is true and the exit status is non-zero.
    """
    with Popen(*popenargs, start_new_session=True, **kwargs) as process:
        try:
            stdout, stderr = process.communicate(input, timeout=timeout)
        except TimeoutExpired:
            _kill_process_group(process)
            process.wait()
            raise
        except BaseException:
            # Deliberately broad (as in subprocess.run): also clean up on
            # KeyboardInterrupt etc.; Popen.__exit__ then reaps the child.
            _kill_process_group(process)
            raise
        retcode = process.poll()
        if check and retcode:
            raise CalledProcessError(retcode, process.args,
                                     output=stdout, stderr=stderr)
    return CompletedProcess(process.args, retcode, stdout, stderr)  # type: ignore
def check_output(*popenargs, timeout=None, **kwargs):
    """Run a command and return what it wrote to stdout.

    Counterpart of :func:`subprocess.check_output` built on this module's
    :func:`run`: a non-zero exit raises ``CalledProcessError`` and a timeout
    kills the command's whole process group before ``TimeoutExpired`` is raised.
    """
    finished = run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs)
    return finished.stdout
class PdfDocument:
    """Thin wrapper around the raw bytes of a PDF file."""

    def __init__(self, raw: bytes):
        # the PDF bytes; other code reads this attribute directly
        self.raw = raw

    def save(self, filename):
        """Write the raw PDF bytes to *filename*, replacing any existing file."""
        with open(filename, mode="wb") as handle:
            handle.write(self.raw)
class TikzDocument: | |
""" | |
Faciliate some operations with TikZ code. To compile the images a full | |
TeXLive installation is assumed to be on the PATH. Cropping additionally | |
requires Ghostscript, and rasterization needs poppler (apart from the 'pdf' | |
optional dependencies). | |
""" | |
# engines to try, could also try: https://tex.stackexchange.com/a/495999 | |
engines = ["pdflatex", "lualatex", "xelatex"] | |
Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""]) | |
def __init__(self, code: str, timeout=120): | |
self.code = code | |
self.timeout = timeout | |
def status(self) -> int: | |
return self.compile().status | |
def pdf(self) -> Optional[PdfDocument]: | |
return self.compile().pdf | |
def log(self) -> str: | |
return self.compile().log | |
def compiled_with_errors(self) -> bool: | |
return self.status != 0 | |
def has_content(self) -> bool: | |
"""true if we have an image that isn't empty""" | |
return (img:=self.rasterize()) is not None and img.getcolors(1) is None | |
def set_engines(cls, engines: Union[str, list]): | |
cls.engines = [engines] if isinstance(engines, str) else engines | |
def compile(self) -> "Output": | |
output = dict() | |
with TemporaryDirectory() as tmpdirname: | |
with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile: | |
codelines = self.code.split("\n") | |
# make sure we don't have page numbers in compiled pdf (for cropping) | |
codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}")) | |
tmpfile.write("\n".join(codelines).encode()) | |
try: | |
# compile | |
errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf") | |
open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile | |
def try_save_last_page(): | |
try: | |
doc = fitz.open(tmppdf) # type: ignore | |
doc.select([len(doc)-1]) | |
doc.save(outpdf) | |
except: | |
pass | |
for engine in self.engines: | |
try: | |
check_output( | |
cwd=tmpdirname, | |
timeout=self.timeout, | |
stderr=DEVNULL, | |
env=environ | dict(max_print_line="1000"), # improve formatting of log | |
args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name] | |
) | |
except (CalledProcessError, TimeoutExpired) as proc: | |
log = getattr(proc, "output", b'').decode(errors="ignore") | |
error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE) | |
# only update status and log if first error occurs later than in previous engine | |
if (linenr:=int(error.group(1)) if error else 0) > errorln: | |
errorln = linenr | |
output.update(status=getattr(proc, 'returncode', -1), log=log) | |
try_save_last_page() | |
else: | |
output.update(status=0, log='') | |
try_save_last_page() | |
break | |
# crop | |
croppdf = f"{tmpfile.name}.crop" | |
crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True) | |
if isfile(croppdf): | |
with open(croppdf, "rb") as pdf: | |
output['pdf'] = PdfDocument(pdf.read()) | |
except (FileNotFoundError, NameError) as e: | |
logger.error("Missing dependencies: " + ( | |
"Install this project with the [pdf] feature name!" if isinstance(e, NameError) | |
else "Did you install TeX Live?" | |
)) | |
except RuntimeError: # pdf error during cropping | |
pass | |
if output.get("status") == 0 and not output.get("pdf", None): | |
logger.warning("Could compile document but something seems to have gone wrong during cropping!") | |
return self.Output(**output) | |
def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]: | |
if self.pdf: | |
image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0] | |
if expand_to_square: | |
image = ImageOps.pad(image, (size, size), color='white') | |
return image | |
def save(self, filename: str, *args, **kwargs): | |
match filename.split(".")[-1]: | |
case "tex": content = self.code.encode() | |
case "pdf": content = getattr(self.pdf, "raw", bytes()) | |
case fmt if img := self.rasterize(*args, **kwargs): | |
img.save(imgByteArr:=BytesIO(), format=fmt) | |
content = imgByteArr.getvalue() | |
case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!") | |
with open(filename, "wb") as f: | |
f.write(content) | |
class TikzGenerator:
    """Generate TikZ code for an image using an image-to-text pipeline.

    Sampling defaults are fixed at construction time and can be overridden
    per call through ``generate(image, **generate_kwargs)``.
    """

    # regex -> replacement for artifacts occasionally observed in generated code
    _ARTIFACTS = {
        r'\bamsop\b': 'amsopn'
    }

    def __init__(
        self,
        pipe: ITP,
        temperature: float = 0.8, # based on "a systematic evaluation of large language models of code"
        top_p: float = 0.95,
        top_k: int = 0,
        stream: bool = False,
        expand_to_square: bool = False,
        clean_up_output: bool = True,
    ):
        self.expand_to_square = expand_to_square
        self.clean_up_output = clean_up_output
        self.pipeline = pipe
        self.default_kwargs = dict(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=1,
            max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
            do_sample=True,
            # NOTE(review): return_full_text is a text-generation pipeline option;
            # confirm the image-to-text pipeline / model.generate accepts it.
            return_full_text=False,
        )
        if stream:
            # only build the streamer when streaming was requested
            self.default_kwargs["streamer"] = TextStreamer(
                self.pipeline.tokenizer, # type: ignore
                skip_prompt=True,
                skip_special_tokens=True,
            )

    def generate(self, image: Image.Image, **generate_kwargs):
        """Return the TikZ code generated for *image* as a string.

        generate_kwargs override the defaults from ``__init__`` and are passed
        to the pipeline as ``generate_kwargs``. With ``clean_up_output`` set,
        stray leading whitespace and known artifacts are removed.
        """
        # NOTE(review): "shown in the lol" looks like a typo for "image"; kept
        # verbatim since the model may have been trained on this exact prompt.
        prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
        tokenizer = self.pipeline.tokenizer
        text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore

        if self.clean_up_output:
            for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
                # remove leading characters because skip_special_tokens in pipeline
                # adds unwanted prefix spaces if prompt ends with a special tokens
                if text and text[0].isspace() and token in tokenizer.all_special_tokens: # type: ignore
                    text = text[1:]
                else:
                    break

            # occasionally observed artifacts
            for artifact, replacement in self._ARTIFACTS.items():
                text = sub(artifact, replacement, text) # type: ignore

        return text

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`generate`."""
        return self.generate(*args, **kwargs)