Spaces:
Running
on
Zero
Running
on
Zero
Quentin Lhoest
commited on
Commit
•
4f83ec0
1
Parent(s):
7238d75
initial commit
Browse files- README.md +1 -1
- __init__.py +0 -0
- api.py +57 -0
- fsm.py +92 -0
- generate.py +129 -0
- gradio_app.py +77 -0
- requirements.txt +6 -0
- samplers.py +72 -0
- utils.py +60 -0
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
|
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.25.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.25.0
|
8 |
+
app_file: gradio_app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
__init__.py
ADDED
File without changes
|
api.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Annotated
|
3 |
+
|
4 |
+
from fastapi import FastAPI, Request
|
5 |
+
from fastapi.responses import StreamingResponse
|
6 |
+
from pydantic import BaseModel, StringConstraints
|
7 |
+
from outlines import generate
|
8 |
+
|
9 |
+
from generate import model, sampler, stream_file
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class Status(BaseModel):
|
15 |
+
status: Annotated[str, StringConstraints(pattern="ok")]
|
16 |
+
|
17 |
+
status_generator = generate.json(model, Status, sampler=sampler)
|
18 |
+
|
19 |
+
status_stream = status_generator.stream("status:")
|
20 |
+
status = "".join(char.strip() for char in status_stream if char.strip())
|
21 |
+
logger.warning("Model status: " + status)
|
22 |
+
|
23 |
+
|
24 |
+
async def stream_response(filename: str, prompt: str, columns: list[str], seed: int, size: int):
|
25 |
+
for chunk in stream_file(
|
26 |
+
filename=filename,
|
27 |
+
prompt=prompt,
|
28 |
+
columns=columns,
|
29 |
+
seed=seed,
|
30 |
+
size=size,
|
31 |
+
):
|
32 |
+
yield chunk
|
33 |
+
|
34 |
+
|
35 |
+
async def dummy_stream():
|
36 |
+
yield ""
|
37 |
+
|
38 |
+
|
39 |
+
app = FastAPI()
|
40 |
+
|
41 |
+
@app.head("/{filename}.jsonl")
|
42 |
+
@app.get("/{filename}.jsonl")
|
43 |
+
async def read_item(request: Request, filename: str, prompt: str = "", columns: str = "", seed: int = 42, size: int = 3):
|
44 |
+
if request.method == 'GET':
|
45 |
+
columns = [field.strip() for field in columns.split(",") if field.strip()]
|
46 |
+
content = stream_response(
|
47 |
+
filename,
|
48 |
+
prompt=prompt,
|
49 |
+
columns=columns,
|
50 |
+
seed=seed,
|
51 |
+
size=size
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
content = dummy_stream()
|
55 |
+
response = StreamingResponse(content, media_type="text/jsonlines")
|
56 |
+
response.headers["Content-Disposition"] = f"attachment; filename={filename}.jsonl"
|
57 |
+
return response
|
fsm.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import copy
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
from outlines.fsm.guide import RegexGuide
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from transformers import PreTrainedTokenizerBase
|
7 |
+
|
8 |
+
|
9 |
+
def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
|
10 |
+
states_to_token_maps = dict(states_to_token_maps)
|
11 |
+
transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
|
12 |
+
transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
|
13 |
+
transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
|
14 |
+
for s1, s2 in transitions_i.items():
|
15 |
+
while s2 in transitions_j:
|
16 |
+
s2 = transitions_j[s2]
|
17 |
+
if s2 != transitions_i[s1]:
|
18 |
+
states_to_token_maps[s1] = dict(states_to_token_maps[s1])
|
19 |
+
states_to_token_maps[s1][i] = s2
|
20 |
+
return states_to_token_maps
|
21 |
+
|
22 |
+
|
23 |
+
def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
|
24 |
+
states_to_token_maps = dict(states_to_token_maps)
|
25 |
+
transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
|
26 |
+
transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
|
27 |
+
transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
|
28 |
+
for s1, s2 in transitions_i.items():
|
29 |
+
if s2 != transitions_j.get(s1):
|
30 |
+
states_to_token_maps[s1] = dict(states_to_token_maps[s1])
|
31 |
+
if s1 in transitions_j:
|
32 |
+
states_to_token_maps[s1][i] = transitions_j[s1]
|
33 |
+
else:
|
34 |
+
states_to_token_maps[s1].pop(i)
|
35 |
+
states_to_token_maps[s1][j] = s2
|
36 |
+
return states_to_token_maps
|
37 |
+
|
38 |
+
|
39 |
+
def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
|
40 |
+
possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
|
41 |
+
possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
|
42 |
+
starts = sorted(
|
43 |
+
s0 for s0 in possible_s0
|
44 |
+
if states_to_token_maps[s0][transitions[0]] in possible_s1
|
45 |
+
)
|
46 |
+
paths = [[start] for start in starts]
|
47 |
+
for path in paths:
|
48 |
+
for i in transitions:
|
49 |
+
if i in states_to_token_maps[path[-1]]:
|
50 |
+
path.append(states_to_token_maps[path[-1]][i])
|
51 |
+
else:
|
52 |
+
break
|
53 |
+
return [path for path in paths if len(path) == len(transitions) + 1]
|
54 |
+
|
55 |
+
|
56 |
+
def replace_fields(fsm: RegexGuide, model: BaseModel, new_fields: list[str], tokenizer: PreTrainedTokenizerBase, make_infinite_loop: bool = False) -> RegexGuide:
|
57 |
+
assert len(new_fields) <= len(model.model_fields)
|
58 |
+
sttm = dict(fsm.states_to_token_maps)
|
59 |
+
encode = partial(tokenizer.encode, add_special_tokens=False)
|
60 |
+
quote = encode('"')[0]
|
61 |
+
# Let's replace the placeholder fields from the model in the finite state model by the new fields
|
62 |
+
for orig_field, new_field in zip(model.model_fields, new_fields):
|
63 |
+
orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
|
64 |
+
new_field_tokens = encode(new_field)
|
65 |
+
assert len(new_field_tokens) <= len(orig_field_tokens)
|
66 |
+
# Merge transitions until we have number of transitions = number of tokens in the field name
|
67 |
+
for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
|
68 |
+
sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
|
69 |
+
# Replace the token ids in the transitions with the ones of the new field name
|
70 |
+
for k in range(len(new_field_tokens)):
|
71 |
+
sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
|
72 |
+
if len(new_fields) < len(model.model_fields) or make_infinite_loop:
|
73 |
+
# Set the last field last state to generate less than the number of fields in the model
|
74 |
+
# We need to do this for every possible path
|
75 |
+
# e.g. multiple paths are used to count items when setting a min/max length
|
76 |
+
orig_last_field = list(model.model_fields)[-1]
|
77 |
+
new_last_field = new_fields[-1]
|
78 |
+
orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
|
79 |
+
new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
|
80 |
+
if make_infinite_loop: # this is a hack to loop on the same states over and over again
|
81 |
+
orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
|
82 |
+
for orig_last_field_path, new_last_field_path in zip(
|
83 |
+
orig_last_field_paths,
|
84 |
+
new_last_field_paths
|
85 |
+
):
|
86 |
+
orig_last_field_last_state = orig_last_field_path[-1]
|
87 |
+
new_last_field_second_last_state = new_last_field_path[-2]
|
88 |
+
sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])
|
89 |
+
sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
|
90 |
+
fsm = copy(fsm)
|
91 |
+
fsm.states_to_token_maps = sttm
|
92 |
+
return fsm
|
generate.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import time
|
5 |
+
from typing import Annotated, Iterator
|
6 |
+
|
7 |
+
import ijson
|
8 |
+
import outlines
|
9 |
+
import torch
|
10 |
+
from pydantic import BaseModel, StringConstraints, conlist, conset
|
11 |
+
from outlines import generate, models
|
12 |
+
from outlines.generate.api import SequenceGenerator
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
|
15 |
+
from fsm import replace_fields
|
16 |
+
from samplers import PenalizedMultinomialSampler
|
17 |
+
from utils import StringIteratorIO
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
device = "cuda"
|
24 |
+
elif torch.backends.mps.is_available():
|
25 |
+
device = "mps"
|
26 |
+
else:
|
27 |
+
raise RuntimeError("couldn't find cuda or mps")
|
28 |
+
|
29 |
+
logger.warning("Loading model...")
|
30 |
+
model_id = "google/gemma-2b-it"
|
31 |
+
# model_id = "Qwen/Qwen1.5-0.5B-Chat"
|
32 |
+
model = models.transformers(model_id, device=device)
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
34 |
+
sampler = PenalizedMultinomialSampler()
|
35 |
+
empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
|
36 |
+
sampler.set_max_repeats(empty_tokens, 1)
|
37 |
+
|
38 |
+
# This Sample & Dataset models ztr just templated with placeholder fields
|
39 |
+
|
40 |
+
class Sample(BaseModel):
|
41 |
+
# We use get_samples_generator() to replace the placeholder with the requested fields
|
42 |
+
ABCDabcd: str
|
43 |
+
EFGHefgh: str
|
44 |
+
IJKLijkl: str
|
45 |
+
MNOPmnop: str
|
46 |
+
QRSTqrst: str
|
47 |
+
# PS: don't use StringConstraints with max_length here since it creates a fsm that is too big
|
48 |
+
|
49 |
+
|
50 |
+
class Dataset(BaseModel):
|
51 |
+
# We use get_samples_generator() to set the length to infinity
|
52 |
+
data: conlist(Sample, min_length=2, max_length=3) # type: ignore
|
53 |
+
|
54 |
+
|
55 |
+
samples_generator_template = generate.json(model, Dataset, sampler=sampler)
|
56 |
+
|
57 |
+
class Columns(BaseModel):
|
58 |
+
columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields)) # type: ignore
|
59 |
+
|
60 |
+
columns_generator = generate.json(model, Columns, sampler=sampler)
|
61 |
+
|
62 |
+
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
|
63 |
+
fsm=samples_generator_template.fsm
|
64 |
+
fsm = replace_fields( # replace the placeholder fields by the real fields
|
65 |
+
fsm=samples_generator_template.fsm,
|
66 |
+
model=Sample,
|
67 |
+
new_fields=new_fields,
|
68 |
+
tokenizer=tokenizer,
|
69 |
+
make_infinite_loop=True # to generate as many samples as we want
|
70 |
+
)
|
71 |
+
return SequenceGenerator(
|
72 |
+
fsm=fsm,
|
73 |
+
model=samples_generator_template.model,
|
74 |
+
sampler=samples_generator_template.sampler,
|
75 |
+
device=samples_generator_template.device
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
@outlines.prompt
|
80 |
+
def columns_prompt(filename: str):
|
81 |
+
"""I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
|
82 |
+
Give an example of column names / columns for this dataset to populate a SQL schema.
|
83 |
+
Please reply in JSON format and place the columns in a field named "columns".
|
84 |
+
"""
|
85 |
+
|
86 |
+
@outlines.prompt
|
87 |
+
def samples_prommpt(filename: str, prompt: str, columns: str):
|
88 |
+
"""I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
|
89 |
+
Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
|
90 |
+
{{ prompt }}
|
91 |
+
"""
|
92 |
+
|
93 |
+
def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
|
94 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
|
95 |
+
_start = time.time()
|
96 |
+
rng = torch.Generator(device=model.device)
|
97 |
+
rng.manual_seed(seed)
|
98 |
+
if not columns:
|
99 |
+
|
100 |
+
messages = [
|
101 |
+
{"role": "user", "content": columns_prompt(filename=filename)}
|
102 |
+
]
|
103 |
+
text = tokenizer.apply_chat_template(
|
104 |
+
messages,
|
105 |
+
tokenize=False,
|
106 |
+
add_generation_prompt=True
|
107 |
+
)
|
108 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
|
109 |
+
columns_generator_tokens = columns_generator.stream(text, rng=rng)
|
110 |
+
for column in ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16):
|
111 |
+
columns.append(column)
|
112 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")
|
113 |
+
|
114 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
|
115 |
+
samples_generator = get_samples_generator(new_fields=columns)
|
116 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
|
117 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
|
118 |
+
messages = [
|
119 |
+
{"role": "user", "content": samples_prommpt(filename=filename, prompt=prompt, columns="'" + "', '".join(columns) + "'")}
|
120 |
+
]
|
121 |
+
text = tokenizer.apply_chat_template(
|
122 |
+
messages,
|
123 |
+
tokenize=False,
|
124 |
+
add_generation_prompt=True
|
125 |
+
)
|
126 |
+
samples_generator_tokens = samples_generator.stream(text, rng=rng)
|
127 |
+
for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
|
128 |
+
yield json.dumps(sample, ensure_ascii=False) + "\n"
|
129 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
|
gradio_app.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import io
|
5 |
+
import pandas as pd
|
6 |
+
import spaces
|
7 |
+
|
8 |
+
@spaces.GPU(duration=120)
|
9 |
+
def stream_output(filename: str):
|
10 |
+
if filename.endswith(".jsonl"):
|
11 |
+
filename = filename[:-len(".jsonl")]
|
12 |
+
from generate import stream_file
|
13 |
+
content = ""
|
14 |
+
size=3
|
15 |
+
start_time = time.time()
|
16 |
+
for i, chunk in enumerate(stream_file(
|
17 |
+
filename=filename,
|
18 |
+
prompt="",
|
19 |
+
columns=[],
|
20 |
+
seed=42,
|
21 |
+
size=size,
|
22 |
+
)):
|
23 |
+
content += chunk
|
24 |
+
df = pd.read_json(io.StringIO(content), lines=True)
|
25 |
+
state_msg = (
|
26 |
+
f"✅ Done generating {size} samples in {time.time() - start_time:.2f}s"
|
27 |
+
if i + 1 == size else
|
28 |
+
f"⚙️ Generating... [{i}/{size}]"
|
29 |
+
)
|
30 |
+
yield df, "```json\n" + content + "\n```", state_msg
|
31 |
+
|
32 |
+
def test(filename: str):
|
33 |
+
if not filename.endswith(".jsonl"):
|
34 |
+
yield "❌ 404: File name must end with .jsonl", None, ""
|
35 |
+
return
|
36 |
+
|
37 |
+
content = ""
|
38 |
+
size = 10
|
39 |
+
start_time = time.time()
|
40 |
+
for i in range(size):
|
41 |
+
content += f'{{"i": {i}, "filename": "{filename}"}}\n'
|
42 |
+
df = pd.read_json(io.StringIO(content), lines=True)
|
43 |
+
state_msg = (
|
44 |
+
f"✅ Done generating {size} samples in {time.time() - start_time:.2f}s"
|
45 |
+
if i + 1 == size else
|
46 |
+
f"⚙️ Generating... [{i}/{size}]"
|
47 |
+
)
|
48 |
+
yield df, "```json\n" + content + "\n```", state_msg
|
49 |
+
time.sleep(0.1)
|
50 |
+
|
51 |
+
title = "LLM DataGen"
|
52 |
+
description = "Generate and stream synthetic dataset files in JSON Lines format"
|
53 |
+
examples = [
|
54 |
+
"movies_data.jsonl",
|
55 |
+
"common_first_names.jsonl",
|
56 |
+
"bad_amazon_reviews_on_defunct_products_that_people_hate.jsonl",
|
57 |
+
"dungeon_and_dragon_characters.jsonl"
|
58 |
+
]
|
59 |
+
|
60 |
+
with gr.Blocks() as demo:
|
61 |
+
gr.Markdown(f"# {title}")
|
62 |
+
gr.Markdown(description)
|
63 |
+
filename_comp = gr.Textbox(examples[0], placeholder=examples[0])
|
64 |
+
gr.Examples(examples, filename_comp)
|
65 |
+
generate_button = gr.Button("Generate dataset")
|
66 |
+
state_msg_comp = gr.Markdown("🔥 Ready to generate")
|
67 |
+
with gr.Tab("Dataset"):
|
68 |
+
dataframe_comp = gr.DataFrame()
|
69 |
+
with gr.Tab("File content"):
|
70 |
+
with gr.Blocks(fill_height=True):
|
71 |
+
with gr.Row():
|
72 |
+
file_content_comp = gr.Markdown()
|
73 |
+
|
74 |
+
generate_button.click(test, filename_comp, [dataframe_comp, file_content_comp, state_msg_comp])
|
75 |
+
|
76 |
+
|
77 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
spaces
|
2 |
+
outlines==0.0.37
|
3 |
+
transformers
|
4 |
+
torch
|
5 |
+
ijson
|
6 |
+
pydantic
|
samplers.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from outlines.samplers import MultinomialSampler
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
class PenalizedMultinomialSampler(MultinomialSampler):
|
10 |
+
|
11 |
+
def __init__(self, **kwargs):
|
12 |
+
super().__init__(**kwargs)
|
13 |
+
self.penalized_tokens_group: list[torch.IntTensor] = []
|
14 |
+
self.max_repeats_per_token_group: list[int] = []
|
15 |
+
self.repeats_per_token_group: list[int] = []
|
16 |
+
self.token_id_to_tokens_groups: list[list[int]] = []
|
17 |
+
|
18 |
+
def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
|
19 |
+
max_token_ids = max(token_ids)
|
20 |
+
if max_token_ids >= len(self.token_id_to_tokens_groups):
|
21 |
+
self.token_id_to_tokens_groups += [[] for _ in range(len(self.token_id_to_tokens_groups), max_token_ids + 1)]
|
22 |
+
for token_id in token_ids:
|
23 |
+
self.token_id_to_tokens_groups[token_id].append(len(self.penalized_tokens_group))
|
24 |
+
self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
|
25 |
+
self.max_repeats_per_token_group.append(max_repeats)
|
26 |
+
self.repeats_per_token_group.append(0)
|
27 |
+
|
28 |
+
def __call__(
|
29 |
+
self,
|
30 |
+
next_token_logits: torch.DoubleTensor,
|
31 |
+
sequence_weights: torch.DoubleTensor,
|
32 |
+
rng: torch.Generator,
|
33 |
+
) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
|
34 |
+
"""Call the multinomial sampler.
|
35 |
+
|
36 |
+
Parameters
|
37 |
+
----------
|
38 |
+
next_token_logits
|
39 |
+
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
|
40 |
+
probability distribution of the next token over the vocabulary.
|
41 |
+
sequence_weights
|
42 |
+
A tensor of shape ``(n_seqs,)`` that represents the cumulative
|
43 |
+
weight of each sequence.
|
44 |
+
rng
|
45 |
+
A random number generator.
|
46 |
+
|
47 |
+
Returns
|
48 |
+
-------
|
49 |
+
A tuple with an array that contains the ids of the sampled tokens of
|
50 |
+
shape ``(n_seqs, 1)``, an array that contains the ancestors of each
|
51 |
+
sampled id of shape ``(n_seqs,)`` and an array that contains the updated
|
52 |
+
cumulative weights of each sequence of shape ``(n_seqs,)``.
|
53 |
+
|
54 |
+
"""
|
55 |
+
if sequence_weights.min() == sequence_weights.max() == 0:
|
56 |
+
self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)
|
57 |
+
else:
|
58 |
+
for penalized_tokens_group, max_repeats_per_token_group, repeats_per_token_group in zip(self.penalized_tokens_group, self.max_repeats_per_token_group, self.repeats_per_token_group):
|
59 |
+
if repeats_per_token_group >= max_repeats_per_token_group:
|
60 |
+
penalty = torch.zeros_like(next_token_logits)
|
61 |
+
penalty[:, penalized_tokens_group] = - torch.inf
|
62 |
+
next_token_logits = next_token_logits + penalty
|
63 |
+
next_token_ids, ancestors, weights = super().__call__(
|
64 |
+
next_token_logits=next_token_logits,
|
65 |
+
sequence_weights=sequence_weights,
|
66 |
+
rng=rng
|
67 |
+
)
|
68 |
+
for next_token_id in next_token_ids.cpu():
|
69 |
+
if next_token_id < len(self.token_id_to_tokens_groups):
|
70 |
+
for token_group in self.token_id_to_tokens_groups[next_token_id]:
|
71 |
+
self.repeats_per_token_group[token_group] += 1
|
72 |
+
return next_token_ids, ancestors, weights
|
utils.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import logging
|
3 |
+
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
|
8 |
+
class StringIteratorIO(io.TextIOBase):
|
9 |
+
"""From: https://stackoverflow.com/a/12604375"""
|
10 |
+
|
11 |
+
def __init__(self, iter):
|
12 |
+
self._iter = iter
|
13 |
+
self._left = ''
|
14 |
+
|
15 |
+
def readable(self):
|
16 |
+
return True
|
17 |
+
|
18 |
+
def _read1(self, n=None):
|
19 |
+
while not self._left:
|
20 |
+
try:
|
21 |
+
self._left = next(self._iter)
|
22 |
+
except StopIteration:
|
23 |
+
break
|
24 |
+
ret = self._left[:n]
|
25 |
+
self._left = self._left[len(ret):]
|
26 |
+
return ret
|
27 |
+
|
28 |
+
def read(self, n=None):
|
29 |
+
buf = []
|
30 |
+
if n is None or n < 0:
|
31 |
+
while True:
|
32 |
+
m = self._read1()
|
33 |
+
if not m:
|
34 |
+
break
|
35 |
+
buf.append(m)
|
36 |
+
else:
|
37 |
+
while n > 0:
|
38 |
+
m = self._read1(n)
|
39 |
+
if not m:
|
40 |
+
break
|
41 |
+
n -= len(m)
|
42 |
+
buf.append(m)
|
43 |
+
return ''.join(buf)
|
44 |
+
|
45 |
+
def readline(self):
|
46 |
+
buf = []
|
47 |
+
while True:
|
48 |
+
i = self._left.find('\n')
|
49 |
+
if i == -1:
|
50 |
+
buf.append(self._left)
|
51 |
+
try:
|
52 |
+
self._left = next(self._iter)
|
53 |
+
except StopIteration:
|
54 |
+
self._left = ''
|
55 |
+
break
|
56 |
+
else:
|
57 |
+
buf.append(self._left[:i+1])
|
58 |
+
self._left = self._left[i+1:]
|
59 |
+
break
|
60 |
+
return ''.join(buf)
|