from copy import copy
from functools import partial

from outlines.fsm.guide import RegexGuide
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Merge transitions on token `i` followed by transitions on token `j`
    into single transitions on `i` that skip the intermediate states."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Keep only the transitions where `i` and `j` differ
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        # Follow the chain of `j` transitions from the target of the `i` transition
        while s2 in transitions_j:
            s2 = transitions_j[s2]
        if s2 != transitions_i[s1]:
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
            states_to_token_maps[s1][i] = s2
    return states_to_token_maps
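
# Toy illustration of merge_successive_transitions: a transition on token 5
# followed by a transition on token 7 collapses into a single transition on 5
# (states and token ids here are made up):
#   merge_successive_transitions({0: {5: 1}, 1: {7: 2}}, 5, 7)
#   == {0: {5: 2}, 1: {7: 2}}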

def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Rewrite the transitions on token `i` to be transitions on token `j`.
    If a state already has a different transition on `j`, the two targets
    are swapped so that no transition is lost."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Keep only the transitions where `i` and `j` differ
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        if s2 != transitions_j.get(s1):
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
            if s1 in transitions_j:
                states_to_token_maps[s1][i] = transitions_j[s1]
            else:
                states_to_token_maps[s1].pop(i)
            states_to_token_maps[s1][j] = s2
    return states_to_token_maps
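
# Toy illustration of replace_transitions: token 9 takes over the transition
# of token 5 (states and token ids are made up):
#   replace_transitions({0: {5: 1}}, 5, 9)
#   == {0: {9: 1}}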

def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
    """Find every path of states that follows `transitions` token by token.
    Each returned path contains `len(transitions) + 1` states."""
    possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
    possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
    starts = sorted(
        s0 for s0 in possible_s0
        if states_to_token_maps[s0][transitions[0]] in possible_s1
    )
    paths = [[start] for start in starts]
    for path in paths:
        for i in transitions:
            if i in states_to_token_maps[path[-1]]:
                path.append(states_to_token_maps[path[-1]][i])
            else:
                break
    return [path for path in paths if len(path) == len(transitions) + 1]
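
# Toy illustration of find_paths_with_transitions: the only path following
# tokens [5, 7] in this map is 0 -> 1 -> 2 (states and token ids are made up):
#   find_paths_with_transitions({0: {5: 1}, 1: {7: 2}}, [5, 7])
#   == [[0, 1, 2]]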

def replace_fields(
    fsm: RegexGuide,
    model: BaseModel,
    new_fields: list[str],
    tokenizer: PreTrainedTokenizerBase,
    make_infinite_loop: bool = False,
) -> RegexGuide:
    """Return a copy of `fsm` where the placeholder field names of `model` are
    replaced by `new_fields`. Each new field name must not tokenize to more
    tokens than its placeholder has characters."""
    assert len(new_fields) <= len(model.model_fields)
    sttm = dict(fsm.states_to_token_maps)
    encode = partial(tokenizer.encode, add_special_tokens=False)
    quote = encode('"')[0]
    # Replace the placeholder fields of the model in the finite state machine with the new fields
    for orig_field, new_field in zip(model.model_fields, new_fields):
        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
        new_field_tokens = encode(new_field)
        assert len(new_field_tokens) <= len(orig_field_tokens)
        # Merge transitions until the number of transitions equals the number of tokens in the new field name
        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
        # Replace the token ids in the transitions with those of the new field name
        for k in range(len(new_field_tokens)):
            sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
    if len(new_fields) < len(model.model_fields) or make_infinite_loop:
        # Reroute the last new field to the last state of the original last field,
        # so that generation stops after fewer fields than the model defines.
        # This must be done for every possible path: multiple paths exist e.g.
        # to count items when a min/max length is set.
        orig_last_field = list(model.model_fields)[-1]
        new_last_field = new_fields[-1]
        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
        if make_infinite_loop:  # hack to loop over the same states again and again
            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
        for orig_last_field_path, new_last_field_path in zip(
            orig_last_field_paths,
            new_last_field_paths,
        ):
            orig_last_field_last_state = orig_last_field_path[-1]
            new_last_field_second_last_state = new_last_field_path[-2]
            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])
            sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
    fsm = copy(fsm)
    fsm.states_to_token_maps = sttm
    return fsm
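
# A minimal usage sketch (not part of the original file). It assumes an
# outlines version in which `outlines.fsm.json_schema.build_regex_from_schema`,
# `outlines.models.transformers.TransformerTokenizer` and the
# `RegexGuide(regex_string, tokenizer)` constructor are all available; the
# `Placeholder` model and the field names below are hypothetical.
if __name__ == "__main__":
    import json

    from outlines.fsm.json_schema import build_regex_from_schema
    from outlines.models.transformers import TransformerTokenizer
    from transformers import AutoTokenizer

    class Placeholder(BaseModel):
        # Placeholder field names, meant to be renamed at runtime
        abcdefghijklmnop: str
        qrstuvwxyzabcdef: str

    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer
    regex = build_regex_from_schema(json.dumps(Placeholder.model_json_schema()))
    guide = RegexGuide(regex, TransformerTokenizer(hf_tokenizer))
    # Rename the placeholders to the actual field names
    guide = replace_fields(guide, Placeholder, ["title", "year"], hf_tokenizer)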