from collections import namedtuple from functools import cache, cached_property from io import BytesIO from os import environ from os.path import isfile, join from re import MULTILINE, escape, search, sub from subprocess import CalledProcessError, DEVNULL, TimeoutExpired from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Optional, Union import warnings from PIL import Image, ImageOps import requests import torch from torch.cuda import current_device, is_available as has_cuda from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP from transformers.utils import logging from transformers.utils.hub import is_remote_url from pdf2image.pdf2image import convert_from_bytes from pdfCropMargins import crop import fitz logger = logging.get_logger("transformers") from os import killpg, getpgid from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE from signal import SIGKILL def run(*popenargs, input=None, timeout=None, check=False, **kwargs): with Popen(*popenargs, start_new_session=True, **kwargs) as process: try: stdout, stderr = process.communicate(input, timeout=timeout) except TimeoutExpired: killpg(getpgid(process.pid), SIGKILL) process.wait() raise except: killpg(getpgid(process.pid), SIGKILL) raise retcode = process.poll() if check and retcode: raise CalledProcessError(retcode, process.args, output=stdout, stderr=stderr) return CompletedProcess(process.args, retcode, stdout, stderr) # type: ignore def check_output(*popenargs, timeout=None, **kwargs): return run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs).stdout class PdfDocument: def __init__(self, raw: bytes): self.raw = raw def save(self, filename): with open(filename, "wb") as f: f.write(self.raw) class TikzDocument: """ Faciliate some operations with TikZ code. To compile the images a full TeXLive installation is assumed to be on the PATH. Cropping additionally requires Ghostscript, and rasterization needs poppler (apart from the 'pdf' optional dependencies). """ # engines to try, could also try: https://tex.stackexchange.com/a/495999 engines = ["pdflatex", "lualatex", "xelatex"] Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""]) def __init__(self, code: str, timeout=120): self.code = code self.timeout = timeout @property def status(self) -> int: return self.compile().status @property def pdf(self) -> Optional[PdfDocument]: return self.compile().pdf @property def log(self) -> str: return self.compile().log @property def compiled_with_errors(self) -> bool: return self.status != 0 @cached_property def has_content(self) -> bool: """true if we have an image that isn't empty""" return (img:=self.rasterize()) is not None and img.getcolors(1) is None @classmethod def set_engines(cls, engines: Union[str, list]): cls.engines = [engines] if isinstance(engines, str) else engines @cache def compile(self) -> "Output": output = dict() with TemporaryDirectory() as tmpdirname: with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile: codelines = self.code.split("\n") # make sure we don't have page numbers in compiled pdf (for cropping) codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}")) tmpfile.write("\n".join(codelines).encode()) try: # compile errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf") open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile def try_save_last_page(): try: doc = fitz.open(tmppdf) # type: ignore doc.select([len(doc)-1]) doc.save(outpdf) except: pass for engine in self.engines: try: check_output( cwd=tmpdirname, timeout=self.timeout, stderr=DEVNULL, env=environ | dict(max_print_line="1000"), # improve formatting of log args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name] ) except (CalledProcessError, TimeoutExpired) as proc: log = getattr(proc, "output", b'').decode(errors="ignore") error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE) # only update status and log if first error occurs later than in previous engine if (linenr:=int(error.group(1)) if error else 0) > errorln: errorln = linenr output.update(status=getattr(proc, 'returncode', -1), log=log) try_save_last_page() else: output.update(status=0, log='') try_save_last_page() break # crop croppdf = f"{tmpfile.name}.crop" crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True) if isfile(croppdf): with open(croppdf, "rb") as pdf: output['pdf'] = PdfDocument(pdf.read()) except (FileNotFoundError, NameError) as e: logger.error("Missing dependencies: " + ( "Install this project with the [pdf] feature name!" if isinstance(e, NameError) else "Did you install TeX Live?" )) except RuntimeError: # pdf error during cropping pass if output.get("status") == 0 and not output.get("pdf", None): logger.warning("Could compile document but something seems to have gone wrong during cropping!") return self.Output(**output) def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]: if self.pdf: image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0] if expand_to_square: image = ImageOps.pad(image, (size, size), color='white') return image def save(self, filename: str, *args, **kwargs): match filename.split(".")[-1]: case "tex": content = self.code.encode() case "pdf": content = getattr(self.pdf, "raw", bytes()) case fmt if img := self.rasterize(*args, **kwargs): img.save(imgByteArr:=BytesIO(), format=fmt) content = imgByteArr.getvalue() case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!") with open(filename, "wb") as f: f.write(content) class TikzGenerator: def __init__( self, pipe: ITP, temperature: float = 0.8, # based on "a systematic evaluation of large language models of code" top_p: float = 0.95, top_k: int = 0, stream: bool = False, expand_to_square: bool = False, clean_up_output: bool = True, ): self.expand_to_square = expand_to_square self.clean_up_output = clean_up_output self.pipeline = pipe # self.pipeline.model = torch.compile(self.pipeline.model) self.default_kwargs = dict( temperature=temperature, top_p=top_p, top_k=top_k, num_return_sequences=1, max_length=self.pipeline.tokenizer.model_max_length, # type: ignore do_sample=True, return_full_text=False, streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore skip_prompt=True, skip_special_tokens=True ), ) if not stream: self.default_kwargs.pop("streamer") def generate(self, image: Image.Image, **generate_kwargs): prompt = "Assistant helps to write down the TikZ code for the user's image. USER: \nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:" tokenizer = self.pipeline.tokenizer text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore if self.clean_up_output: for token in reversed(tokenizer.tokenize(prompt)): # type: ignore # remove leading characters because skip_special_tokens in pipeline # adds unwanted prefix spaces if prompt ends with a special tokens if text and text[0].isspace() and token in tokenizer.all_special_tokens: # type: ignore text = text[1:] else: break # occasionally observed artifacts artifacts = { r'\bamsop\b': 'amsopn' } for artifact, replacement in artifacts.items(): text = sub(artifact, replacement, text) # type: ignore return text def __call__(self, *args, **kwargs): return self.generate(*args, **kwargs)