Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import logging | |
import time | |
from typing import Annotated, Iterator | |
import ijson | |
import outlines | |
import torch | |
from pydantic import BaseModel, StringConstraints, conlist, conset | |
from outlines import generate, models | |
from outlines.generate.api import SequenceGenerator | |
from transformers import AutoTokenizer | |
from fsm import replace_fields | |
from samplers import PenalizedMultinomialSampler | |
from utils import StringIteratorIO | |
# Module-level setup: the model, tokenizer and sampler are created once, at import time.
logger = logging.getLogger(__name__)

logger.warning("Loading model...")
model_id = "google/gemma-2b-it"
# Alternative checkpoint, kept for reference:
# model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Use the MPS backend when torch reports it as available, otherwise CUDA.
_device = "mps" if torch.backends.mps.is_available() else "cuda"
model = models.transformers(model_id, device=_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
sampler = PenalizedMultinomialSampler()

# Ids of the tokens whose decoded text is empty or whitespace-only.
empty_tokens = [
    candidate_id
    for candidate_id in range(tokenizer.vocab_size)
    if not tokenizer.decode([candidate_id]).strip()
]
# Register those tokens with a max-repeat value of 1
# (exact semantics are defined by PenalizedMultinomialSampler in samplers.py).
sampler.set_max_repeats(empty_tokens, 1)
# These Sample & Dataset models are just templates with placeholder fields
class Sample(BaseModel):
    # The five field names below are placeholders only:
    # get_samples_generator() replaces them with the requested fields.
    ABCDabcd: str
    EFGHefgh: str
    IJKLijkl: str
    MNOPmnop: str
    QRSTqrst: str
    # NOTE: do not use StringConstraints with max_length on these fields,
    # it creates a fsm that is too big.
class Dataset(BaseModel):
    # The 2..3 length bound only applies to this template:
    # get_samples_generator() is used to set the length to infinity.
    data: conlist(  # type: ignore
        Sample,
        min_length=2,
        max_length=3,
    )


# JSON generator for the placeholder schema; get_samples_generator() reuses
# its fsm, model, sampler and device.
samples_generator_template = generate.json(model, Dataset, sampler=sampler)
class Columns(BaseModel):
    # A set of column names, each a string constrained by the pattern [a-z0-9_]+.
    # At least 2 names, at most one fewer than the number of Sample placeholder fields.
    columns: conset(  # type: ignore
        Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')],
        min_length=2,
        max_length=len(Sample.model_fields) - 1,
    )


# JSON generator used to propose column names when the caller supplies none.
columns_generator = generate.json(model, Columns, sampler=sampler)
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
    """Build a samples generator that uses `new_fields` instead of the placeholders.

    The fsm of `samples_generator_template` is passed through `replace_fields`
    and wrapped in a new `SequenceGenerator` that reuses the template's model,
    sampler and device.

    Args:
        new_fields: Field names to use in place of the `Sample` placeholder fields.

    Returns:
        A `SequenceGenerator` driven by the rewritten fsm.
    """
    fsm = replace_fields(  # replace the placeholder fields by the real fields
        fsm=samples_generator_template.fsm,
        model=Sample,
        new_fields=new_fields,
        tokenizer=tokenizer,
        make_infinite_loop=True,  # to generate as many samples as we want
    )
    return SequenceGenerator(
        fsm=fsm,
        model=samples_generator_template.model,
        sampler=samples_generator_template.sampler,
        device=samples_generator_template.device,
    )
# Prompt asking the model to propose column names for `filename`.
# The docstring below is a Jinja2 template, not documentation: `outlines.prompt`
# turns the function into one that returns the rendered template text.
# Without the decorator the function has no body besides the docstring and
# returns None, which would end up as the chat message content.
@outlines.prompt
def columns_prompt(filename: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of column names / columns for this dataset to populate a SQL schema.
    Please reply in JSON format and place the columns in a field named "columns".
    """
# Prompt asking the model for sample rows using the given columns, followed by
# the caller's free-form `prompt`.
# The docstring below is a Jinja2 template, not documentation: `outlines.prompt`
# turns the function into one that returns the rendered template text.
# Without the decorator the function has no body besides the docstring and
# returns None, which would end up as the chat message content.
# NOTE: the name is misspelled ("prommpt") but is kept because callers use it.
@outlines.prompt
def samples_prommpt(filename: str, prompt: str, columns: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
    {{ prompt }}
    """
def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
    """Yield generated samples for `filename` as JSON lines.

    This is a generator: nothing runs until it is first iterated.

    Args:
        filename: Name inserted into the prompts sent to the model.
        prompt: Extra instructions appended to the samples prompt.
        columns: Column names for the samples. When empty, column names are
            generated first and appended to this list in place.
        seed: Seed for the torch random generator passed to both generators.
        size: Maximum number of samples to yield.

    Yields:
        One JSON-encoded sample per item, terminated by a newline.
    """

    def _render_chat(content: str) -> str:
        # Wrap a single user message with the tokenizer's chat template.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": content}],
            tokenize=False,
            add_generation_prompt=True,
        )

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
    _start = time.time()
    rng = torch.Generator(device=model.device)
    rng.manual_seed(seed)

    if not columns:
        # No columns requested: ask the model for some and parse them as they stream in.
        columns_request = _render_chat(columns_prompt(filename=filename))
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
        column_tokens = columns_generator.stream(columns_request, rng=rng)
        columns.extend(
            ijson.items(StringIteratorIO(column_tokens), "columns.item", buf_size=16)
        )
        # NOTE(review): assumed to belong to this branch; the original indentation was lost.
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
    samples_generator = get_samples_generator(new_fields=columns)
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
    quoted_columns = "'" + "', '".join(columns) + "'"
    samples_request = _render_chat(
        samples_prommpt(filename=filename, prompt=prompt, columns=quoted_columns)
    )
    sample_tokens = samples_generator.stream(samples_request, rng=rng)
    parsed_samples = ijson.items(StringIteratorIO(sample_tokens), "data.item", buf_size=4)
    # zip() with range(size) stops after `size` samples without pulling an extra one.
    for _, sample in zip(range(size), parsed_samples):
        yield json.dumps(sample, ensure_ascii=False) + "\n"
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")