# TikZ-Assistant / infer.py
# Commit 77f4da6 by waleko: "fix generation"
from collections import namedtuple
from functools import cache, cached_property
from io import BytesIO
from os import environ
from os.path import isfile, join
from re import MULTILINE, escape, search, sub
from subprocess import CalledProcessError, DEVNULL, TimeoutExpired
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Optional, Union
import warnings
from PIL import Image, ImageOps
import requests
import torch
from torch.cuda import current_device, is_available as has_cuda
from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP
from transformers.utils import logging
from transformers.utils.hub import is_remote_url
from pdf2image.pdf2image import convert_from_bytes
from pdfCropMargins import crop
import fitz
# Module logger obtained through transformers' logging utilities under the
# "transformers" name (presumably so messages follow transformers' verbosity
# settings -- confirm if that matters).
logger = logging.get_logger("transformers")
from os import killpg, getpgid
from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE
from signal import SIGKILL
def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
    """Variant of ``subprocess.run`` that kills the child's whole process group.

    The child is started with ``start_new_session=True``, so it leads its own
    process group. On a timeout, or on any other exception raised while
    waiting (including ``KeyboardInterrupt``), the entire group receives
    ``SIGKILL``. Presumably this is so that programs spawned by the child
    (e.g. the TeX engines run by latexmk) do not outlive it.

    Args:
        *popenargs: Forwarded to ``Popen``.
        input: Data written to the child's stdin by ``communicate``.
        timeout: Seconds to wait in ``communicate`` before killing the group.
        check: If true, raise ``CalledProcessError`` on a non-zero exit status.
        **kwargs: Forwarded to ``Popen``.

    Returns:
        A ``CompletedProcess`` with the exit status and captured stdout/stderr.

    Raises:
        TimeoutExpired: The child did not finish within ``timeout``.
        CalledProcessError: ``check`` is true and the child exited non-zero.
    """

    def kill_group(process):
        # The group may already be gone (e.g. communicate() reaped the child
        # just before an interrupt arrived). A ProcessLookupError raised here
        # must not replace the exception that is being propagated.
        try:
            killpg(getpgid(process.pid), SIGKILL)
        except ProcessLookupError:
            pass

    with Popen(*popenargs, start_new_session=True, **kwargs) as process:
        try:
            stdout, stderr = process.communicate(input, timeout=timeout)
        except TimeoutExpired:
            kill_group(process)
            process.wait()
            raise
        except BaseException:  # deliberately broad: also KeyboardInterrupt, like subprocess.run
            kill_group(process)
            raise
        retcode = process.poll()
        if check and retcode:
            raise CalledProcessError(retcode, process.args,
                                     output=stdout, stderr=stderr)
        return CompletedProcess(process.args, retcode, stdout, stderr)  # type: ignore
def check_output(*popenargs, timeout=None, **kwargs):
    """Run a command and return whatever it wrote to stdout.

    Mirrors ``subprocess.check_output`` but goes through this module's ``run``
    (which kills the whole process group on timeout). Stdout is captured via
    ``PIPE`` and ``check=True`` is always applied.

    Raises:
        CalledProcessError: The command exited with a non-zero status.
        TimeoutExpired: The command ran longer than ``timeout``.
    """
    completed = run(
        *popenargs,
        stdout=PIPE,
        timeout=timeout,
        check=True,
        **kwargs,
    )
    return completed.stdout
class PdfDocument:
    """Thin wrapper around the raw bytes of a PDF file."""

    def __init__(self, raw: bytes):
        # complete file contents, kept in memory
        self.raw = raw

    def save(self, filename):
        """Write the stored bytes to ``filename``, replacing any existing file."""
        payload = self.raw
        with open(filename, mode="wb") as handle:
            handle.write(payload)
class TikzDocument:
"""
Faciliate some operations with TikZ code. To compile the images a full
TeXLive installation is assumed to be on the PATH. Cropping additionally
requires Ghostscript, and rasterization needs poppler (apart from the 'pdf'
optional dependencies).
"""
# engines to try, could also try: https://tex.stackexchange.com/a/495999
engines = ["pdflatex", "lualatex", "xelatex"]
Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""])
def __init__(self, code: str, timeout=120):
self.code = code
self.timeout = timeout
@property
def status(self) -> int:
return self.compile().status
@property
def pdf(self) -> Optional[PdfDocument]:
return self.compile().pdf
@property
def log(self) -> str:
return self.compile().log
@property
def compiled_with_errors(self) -> bool:
return self.status != 0
@cached_property
def has_content(self) -> bool:
"""true if we have an image that isn't empty"""
return (img:=self.rasterize()) is not None and img.getcolors(1) is None
@classmethod
def set_engines(cls, engines: Union[str, list]):
cls.engines = [engines] if isinstance(engines, str) else engines
@cache
def compile(self) -> "Output":
output = dict()
with TemporaryDirectory() as tmpdirname:
with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile:
codelines = self.code.split("\n")
# make sure we don't have page numbers in compiled pdf (for cropping)
codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}"))
tmpfile.write("\n".join(codelines).encode())
try:
# compile
errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf")
open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile
def try_save_last_page():
try:
doc = fitz.open(tmppdf) # type: ignore
doc.select([len(doc)-1])
doc.save(outpdf)
except:
pass
for engine in self.engines:
try:
check_output(
cwd=tmpdirname,
timeout=self.timeout,
stderr=DEVNULL,
env=environ | dict(max_print_line="1000"), # improve formatting of log
args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name]
)
except (CalledProcessError, TimeoutExpired) as proc:
log = getattr(proc, "output", b'').decode(errors="ignore")
error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE)
# only update status and log if first error occurs later than in previous engine
if (linenr:=int(error.group(1)) if error else 0) > errorln:
errorln = linenr
output.update(status=getattr(proc, 'returncode', -1), log=log)
try_save_last_page()
else:
output.update(status=0, log='')
try_save_last_page()
break
# crop
croppdf = f"{tmpfile.name}.crop"
crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True)
if isfile(croppdf):
with open(croppdf, "rb") as pdf:
output['pdf'] = PdfDocument(pdf.read())
except (FileNotFoundError, NameError) as e:
logger.error("Missing dependencies: " + (
"Install this project with the [pdf] feature name!" if isinstance(e, NameError)
else "Did you install TeX Live?"
))
except RuntimeError: # pdf error during cropping
pass
if output.get("status") == 0 and not output.get("pdf", None):
logger.warning("Could compile document but something seems to have gone wrong during cropping!")
return self.Output(**output)
def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]:
if self.pdf:
image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0]
if expand_to_square:
image = ImageOps.pad(image, (size, size), color='white')
return image
def save(self, filename: str, *args, **kwargs):
match filename.split(".")[-1]:
case "tex": content = self.code.encode()
case "pdf": content = getattr(self.pdf, "raw", bytes())
case fmt if img := self.rasterize(*args, **kwargs):
img.save(imgByteArr:=BytesIO(), format=fmt)
content = imgByteArr.getvalue()
case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!")
with open(filename, "wb") as f:
f.write(content)
class TikzGenerator:
    """Generate TikZ code for an image using an image-to-text pipeline.

    Args:
        pipe: Image-to-text pipeline used for generation.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling value, passed through to generation unchanged.
        stream: Accepted for API compatibility but currently unused.
            NOTE(review): streaming support appears to have been disabled.
        expand_to_square: Stored on the instance; ``generate`` does not use it.
        clean_up_output: Strip spurious leading whitespace and fix known
            artifacts in the generated text.
    """

    def __init__(
        self,
        pipe: ITP,
        temperature: float = 0.8,  # based on "a systematic evaluation of large language models of code"
        top_p: float = 0.95,
        top_k: int = 0,
        stream: bool = False,
        expand_to_square: bool = False,
        clean_up_output: bool = True,
    ):
        self.expand_to_square = expand_to_square
        self.clean_up_output = clean_up_output
        self.pipeline = pipe
        # per-call generate_kwargs passed to generate() override these
        self.default_kwargs = dict(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=True,
            max_new_tokens=1024,
        )

    def generate(self, image: Image.Image, **generate_kwargs):
        """Generate TikZ code for ``image`` and wrap it in a ``TikzDocument``.

        ``generate_kwargs`` are merged over the defaults from ``__init__`` and
        forwarded to the pipeline as its ``generate_kwargs``.
        """
        # NOTE(review): "shown in the lol" looks like a typo, but it may match
        # the prompt the model was trained with -- confirm before changing it.
        prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
        tokenizer = self.pipeline.tokenizer
        text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"]  # type: ignore

        if self.clean_up_output:
            # all_special_tokens is rebuilt on every property access; read it once
            special_tokens = set(tokenizer.all_special_tokens)  # type: ignore
            for token in reversed(tokenizer.tokenize(prompt)):  # type: ignore
                # remove leading characters because skip_special_tokens in pipeline
                # adds unwanted prefix spaces if prompt ends with a special tokens
                if text and text[0].isspace() and token in special_tokens:
                    text = text[1:]
                else:
                    break

            # occasionally observed artifacts (regex -> replacement)
            artifacts = {
                r'\bamsop\b': 'amsopn'
            }
            for artifact, replacement in artifacts.items():
                text = sub(artifact, replacement, text)  # type: ignore

        return TikzDocument(text)

    def __call__(self, *args, **kwargs):
        """Alias for ``generate``."""
        return self.generate(*args, **kwargs)