Quentin Lhoest committed on
Commit 4f83ec0
1 Parent(s): 7238d75

initial commit

Files changed (9)
  1. README.md +1 -1
  2. __init__.py +0 -0
  3. api.py +57 -0
  4. fsm.py +92 -0
  5. generate.py +129 -0
  6. gradio_app.py +77 -0
  7. requirements.txt +6 -0
  8. samplers.py +72 -0
  9. utils.py +60 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
 colorTo: green
 sdk: gradio
 sdk_version: 4.25.0
-app_file: app.py
+app_file: gradio_app.py
 pinned: false
 ---
 
__init__.py ADDED
File without changes
api.py ADDED
@@ -0,0 +1,57 @@
+import logging
+from typing import Annotated
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, StringConstraints
+from outlines import generate
+
+from generate import model, sampler, stream_file
+
+logger = logging.getLogger(__name__)
+
+
+class Status(BaseModel):
+    status: Annotated[str, StringConstraints(pattern="ok")]
+
+status_generator = generate.json(model, Status, sampler=sampler)
+
+status_stream = status_generator.stream("status:")
+status = "".join(char.strip() for char in status_stream if char.strip())
+logger.warning("Model status: " + status)
+
+
+async def stream_response(filename: str, prompt: str, columns: list[str], seed: int, size: int):
+    for chunk in stream_file(
+        filename=filename,
+        prompt=prompt,
+        columns=columns,
+        seed=seed,
+        size=size,
+    ):
+        yield chunk
+
+
+async def dummy_stream():
+    yield ""
+
+
+app = FastAPI()
+
+@app.head("/{filename}.jsonl")
+@app.get("/{filename}.jsonl")
+async def read_item(request: Request, filename: str, prompt: str = "", columns: str = "", seed: int = 42, size: int = 3):
+    if request.method == 'GET':
+        columns = [field.strip() for field in columns.split(",") if field.strip()]
+        content = stream_response(
+            filename,
+            prompt=prompt,
+            columns=columns,
+            seed=seed,
+            size=size
+        )
+    else:
+        content = dummy_stream()
+    response = StreamingResponse(content, media_type="text/jsonlines")
+    response.headers["Content-Disposition"] = f"attachment; filename={filename}.jsonl"
+    return response
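For context, a minimal client sketch for the streaming endpoint above; the host/port, the `uvicorn` launch command, and the `httpx` dependency are assumptions, not part of this commit:

```python
# Stream the generated JSON Lines from the FastAPI endpoint.
# Assumes the app was started with e.g. `uvicorn api:app --port 8000`.
import httpx

with httpx.stream(
    "GET",
    "http://localhost:8000/movies_data.jsonl",
    params={"columns": "title,year", "seed": 42, "size": 3},
    timeout=None,  # token-by-token generation can be slow
) as response:
    for line in response.iter_lines():  # one JSON object per line
        print(line)
```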
fsm.py ADDED
@@ -0,0 +1,92 @@
+from copy import copy
+from functools import partial
+
+from outlines.fsm.guide import RegexGuide
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
+
+
+def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
+    states_to_token_maps = dict(states_to_token_maps)
+    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
+    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
+    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
+    for s1, s2 in transitions_i.items():
+        while s2 in transitions_j:
+            s2 = transitions_j[s2]
+        if s2 != transitions_i[s1]:
+            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
+            states_to_token_maps[s1][i] = s2
+    return states_to_token_maps
+
+
+def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
+    states_to_token_maps = dict(states_to_token_maps)
+    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
+    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
+    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
+    for s1, s2 in transitions_i.items():
+        if s2 != transitions_j.get(s1):
+            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
+            if s1 in transitions_j:
+                states_to_token_maps[s1][i] = transitions_j[s1]
+            else:
+                states_to_token_maps[s1].pop(i)
+            states_to_token_maps[s1][j] = s2
+    return states_to_token_maps
+
+
+def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
+    possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
+    possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
+    starts = sorted(
+        s0 for s0 in possible_s0
+        if states_to_token_maps[s0][transitions[0]] in possible_s1
+    )
+    paths = [[start] for start in starts]
+    for path in paths:
+        for i in transitions:
+            if i in states_to_token_maps[path[-1]]:
+                path.append(states_to_token_maps[path[-1]][i])
+            else:
+                break
+    return [path for path in paths if len(path) == len(transitions) + 1]
+
+
+def replace_fields(fsm: RegexGuide, model: BaseModel, new_fields: list[str], tokenizer: PreTrainedTokenizerBase, make_infinite_loop: bool = False) -> RegexGuide:
+    assert len(new_fields) <= len(model.model_fields)
+    sttm = dict(fsm.states_to_token_maps)
+    encode = partial(tokenizer.encode, add_special_tokens=False)
+    quote = encode('"')[0]
+    # Replace the model's placeholder fields in the finite state machine with the new fields
+    for orig_field, new_field in zip(model.model_fields, new_fields):
+        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
+        new_field_tokens = encode(new_field)
+        assert len(new_field_tokens) <= len(orig_field_tokens)
+        # Merge transitions until the number of transitions equals the number of tokens in the new field name
+        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
+            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
+        # Replace the token ids in the transitions with those of the new field name
+        for k in range(len(new_field_tokens)):
+            sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
+    if len(new_fields) < len(model.model_fields) or make_infinite_loop:
+        # Redirect the last new field to the original last field's final state,
+        # so generation can stop after fewer fields than the model defines.
+        # This must be done for every possible path, e.g. multiple paths are
+        # used to count items when setting a min/max length.
+        orig_last_field = list(model.model_fields)[-1]
+        new_last_field = new_fields[-1]
+        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
+        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
+        if make_infinite_loop:  # a hack to loop over the same states again and again
+            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
+        for orig_last_field_path, new_last_field_path in zip(
+            orig_last_field_paths,
+            new_last_field_paths
+        ):
+            orig_last_field_last_state = orig_last_field_path[-1]
+            new_last_field_second_last_state = new_last_field_path[-2]
+            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])
+            sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
+    fsm = copy(fsm)
+    fsm.states_to_token_maps = sttm
+    return fsm
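To make the transition surgery above concrete, here is a toy sketch of the two helpers on a hand-built `states_to_token_maps`; the states and token ids are illustrative, not from a real tokenizer:

```python
# Toy FSM: state 0 consumes token 10 to reach state 1,
# which consumes token 11 to reach state 2.
from fsm import merge_successive_transitions, replace_transitions

sttm = {0: {10: 1}, 1: {11: 2}}

# Merging tokens 10 and 11 makes state 0 jump straight to state 2 on token 10,
# as if the two tokens were a single one.
merged = merge_successive_transitions(sttm, 10, 11)
assert merged[0][10] == 2

# Replacing token 10 with token 20 renames the transition without moving states.
renamed = replace_transitions(merged, 10, 20)
assert renamed[0] == {20: 2}
```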
generate.py ADDED
@@ -0,0 +1,126 @@
+import json
+import logging
+import time
+from typing import Annotated, Iterator
+
+import ijson
+import outlines
+import torch
+from pydantic import BaseModel, StringConstraints, conlist, conset
+from outlines import generate, models
+from outlines.generate.api import SequenceGenerator
+from transformers import AutoTokenizer
+
+from fsm import replace_fields
+from samplers import PenalizedMultinomialSampler
+from utils import StringIteratorIO
+
+logger = logging.getLogger(__name__)
+
+
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    raise RuntimeError("couldn't find cuda or mps")
+
+logger.warning("Loading model...")
+model_id = "google/gemma-2b-it"
+# model_id = "Qwen/Qwen1.5-0.5B-Chat"
+model = models.transformers(model_id, device=device)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+sampler = PenalizedMultinomialSampler()
+empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
+sampler.set_max_repeats(empty_tokens, 1)
+
+# The Sample & Dataset models are just templates with placeholder fields
+
+class Sample(BaseModel):
+    # We use get_samples_generator() to replace the placeholders with the requested fields
+    ABCDabcd: str
+    EFGHefgh: str
+    IJKLijkl: str
+    MNOPmnop: str
+    QRSTqrst: str
+    # PS: don't use StringConstraints with max_length here since it creates a fsm that is too big
+
+
+class Dataset(BaseModel):
+    # We use get_samples_generator() to set the length to infinity
+    data: conlist(Sample, min_length=2, max_length=3)  # type: ignore
+
+
+samples_generator_template = generate.json(model, Dataset, sampler=sampler)
+
+class Columns(BaseModel):
+    columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields))  # type: ignore
+
+columns_generator = generate.json(model, Columns, sampler=sampler)
+
+def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
+    fsm = replace_fields(  # replace the placeholder fields with the real fields
+        fsm=samples_generator_template.fsm,
+        model=Sample,
+        new_fields=new_fields,
+        tokenizer=tokenizer,
+        make_infinite_loop=True  # to generate as many samples as we want
+    )
+    return SequenceGenerator(
+        fsm=fsm,
+        model=samples_generator_template.model,
+        sampler=samples_generator_template.sampler,
+        device=samples_generator_template.device
+    )
+
+
+@outlines.prompt
+def columns_prompt(filename: str):
+    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
+    Give an example of column names / columns for this dataset to populate a SQL schema.
+    Please reply in JSON format and place the columns in a field named "columns".
+    """
+
+@outlines.prompt
+def samples_prompt(filename: str, prompt: str, columns: str):
+    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
+    Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
+    {{ prompt }}
+    """
+
+def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
+    logger.warning(f"stream_file({filename=}, {prompt=}, {columns=})")
+    _start = time.time()
+    rng = torch.Generator(device=model.device)
+    rng.manual_seed(seed)
+    if not columns:
+        messages = [
+            {"role": "user", "content": columns_prompt(filename=filename)}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) Generating columns...")
+        columns_generator_tokens = columns_generator.stream(text, rng=rng)
+        for column in ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16):
+            columns.append(column)
+        logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")
+
+    logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
+    samples_generator = get_samples_generator(new_fields=columns)
+    logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
+    logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) - Generating samples...")
+    messages = [
+        {"role": "user", "content": samples_prompt(filename=filename, prompt=prompt, columns="'" + "', '".join(columns) + "'")}
+    ]
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    samples_generator_tokens = samples_generator.stream(text, rng=rng)
+    for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
+        yield json.dumps(sample, ensure_ascii=False) + "\n"
+    logger.warning(f"stream_file({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
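As a usage note, a minimal sketch of consuming `stream_file` directly; the column names here are purely illustrative (pass an empty list to let the model propose its own):

```python
# Write three generated samples to a local JSON Lines file.
from generate import stream_file

with open("movies_data.jsonl", "w", encoding="utf-8") as f:
    for line in stream_file(
        filename="movies_data",
        prompt="",
        columns=["title", "year", "rating"],  # illustrative; [] lets the model choose
        seed=42,
        size=3,
    ):
        f.write(line)  # each chunk is one JSON object plus a trailing newline
```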
gradio_app.py ADDED
@@ -0,0 +1,77 @@
+import time
+
+import gradio as gr
+import io
+import pandas as pd
+import spaces
+
+@spaces.GPU(duration=120)
+def stream_output(filename: str):
+    if filename.endswith(".jsonl"):
+        filename = filename[:-len(".jsonl")]
+    from generate import stream_file  # imported lazily so the model loads inside the GPU worker
+    content = ""
+    size = 3
+    start_time = time.time()
+    for i, chunk in enumerate(stream_file(
+        filename=filename,
+        prompt="",
+        columns=[],
+        seed=42,
+        size=size,
+    )):
+        content += chunk
+        df = pd.read_json(io.StringIO(content), lines=True)
+        state_msg = (
+            f"✅ Done generating {size} samples in {time.time() - start_time:.2f}s"
+            if i + 1 == size else
+            f"⚙️ Generating... [{i}/{size}]"
+        )
+        yield df, "```json\n" + content + "\n```", state_msg
+
+def test(filename: str):
+    if not filename.endswith(".jsonl"):
+        yield None, None, "❌ 404: File name must end with .jsonl"
+        return
+
+    content = ""
+    size = 10
+    start_time = time.time()
+    for i in range(size):
+        content += f'{{"i": {i}, "filename": "{filename}"}}\n'
+        df = pd.read_json(io.StringIO(content), lines=True)
+        state_msg = (
+            f"✅ Done generating {size} samples in {time.time() - start_time:.2f}s"
+            if i + 1 == size else
+            f"⚙️ Generating... [{i}/{size}]"
+        )
+        yield df, "```json\n" + content + "\n```", state_msg
+        time.sleep(0.1)
+
+title = "LLM DataGen"
+description = "Generate and stream synthetic dataset files in JSON Lines format"
+examples = [
+    "movies_data.jsonl",
+    "common_first_names.jsonl",
+    "bad_amazon_reviews_on_defunct_products_that_people_hate.jsonl",
+    "dungeon_and_dragon_characters.jsonl"
+]
+
+with gr.Blocks() as demo:
+    gr.Markdown(f"# {title}")
+    gr.Markdown(description)
+    filename_comp = gr.Textbox(examples[0], placeholder=examples[0])
+    gr.Examples(examples, filename_comp)
+    generate_button = gr.Button("Generate dataset")
+    state_msg_comp = gr.Markdown("🔥 Ready to generate")
+    with gr.Tab("Dataset"):
+        dataframe_comp = gr.DataFrame()
+    with gr.Tab("File content"):
+        with gr.Blocks(fill_height=True):
+            with gr.Row():
+                file_content_comp = gr.Markdown()
+
+    generate_button.click(test, filename_comp, [dataframe_comp, file_content_comp, state_msg_comp])
+
+
+demo.launch()
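Note that the button is currently wired to the dummy `test` handler; a sketch of swapping in the GPU-backed generator instead (this swap is an assumption, not part of the commit):

```python
# Inside the `with gr.Blocks() as demo:` context, replace the click wiring:
generate_button.click(
    stream_output,  # streams real model output instead of dummy rows
    filename_comp,
    [dataframe_comp, file_content_comp, state_msg_comp],
)
```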
requirements.txt ADDED
@@ -0,0 +1,6 @@
+spaces
+outlines==0.0.37
+transformers
+torch
+ijson
+pydantic
samplers.py ADDED
@@ -0,0 +1,72 @@
+import logging
+from typing import Tuple
+
+import torch
+from outlines.samplers import MultinomialSampler
+
+logger = logging.getLogger(__name__)
+
+class PenalizedMultinomialSampler(MultinomialSampler):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.penalized_tokens_group: list[torch.IntTensor] = []
+        self.max_repeats_per_token_group: list[int] = []
+        self.repeats_per_token_group: list[int] = []
+        self.token_id_to_tokens_groups: list[list[int]] = []
+
+    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
+        max_token_ids = max(token_ids)
+        if max_token_ids >= len(self.token_id_to_tokens_groups):
+            self.token_id_to_tokens_groups += [[] for _ in range(len(self.token_id_to_tokens_groups), max_token_ids + 1)]
+        for token_id in token_ids:
+            self.token_id_to_tokens_groups[token_id].append(len(self.penalized_tokens_group))
+        self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
+        self.max_repeats_per_token_group.append(max_repeats)
+        self.repeats_per_token_group.append(0)
+
+    def __call__(
+        self,
+        next_token_logits: torch.DoubleTensor,
+        sequence_weights: torch.DoubleTensor,
+        rng: torch.Generator,
+    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
+        """Call the multinomial sampler.
+
+        Parameters
+        ----------
+        next_token_logits
+            A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
+            probability distribution of the next token over the vocabulary.
+        sequence_weights
+            A tensor of shape ``(n_seqs,)`` that represents the cumulative
+            weight of each sequence.
+        rng
+            A random number generator.
+
+        Returns
+        -------
+        A tuple with an array that contains the ids of the sampled tokens of
+        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
+        sampled id of shape ``(n_seqs,)`` and an array that contains the updated
+        cumulative weights of each sequence of shape ``(n_seqs,)``.
+
+        """
+        if sequence_weights.min() == sequence_weights.max() == 0:
+            self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)
+        else:
+            for penalized_tokens_group, max_repeats_per_token_group, repeats_per_token_group in zip(self.penalized_tokens_group, self.max_repeats_per_token_group, self.repeats_per_token_group):
+                if repeats_per_token_group >= max_repeats_per_token_group:
+                    penalty = torch.zeros_like(next_token_logits)
+                    penalty[:, penalized_tokens_group] = -torch.inf
+                    next_token_logits = next_token_logits + penalty
+        next_token_ids, ancestors, weights = super().__call__(
+            next_token_logits=next_token_logits,
+            sequence_weights=sequence_weights,
+            rng=rng
+        )
+        for next_token_id in next_token_ids.cpu():
+            if next_token_id < len(self.token_id_to_tokens_groups):
+                for token_group in self.token_id_to_tokens_groups[next_token_id]:
+                    self.repeats_per_token_group[token_group] += 1
+        return next_token_ids, ancestors, weights
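A short sketch of the sampler's bookkeeping API; the token ids below are illustrative, not tied to a real tokenizer:

```python
# Cap a group of token ids at one occurrence per generation.
from samplers import PenalizedMultinomialSampler

sampler = PenalizedMultinomialSampler()
sampler.set_max_repeats(token_ids=[5, 17, 42], max_repeats=1)
# Once tokens of the group have been sampled `max_repeats` times in total, the
# whole group gets -inf logits on later calls, until the sequence weights
# reset to 0 at the start of a new generation.
```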
utils.py ADDED
@@ -0,0 +1,60 @@
+import io
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class StringIteratorIO(io.TextIOBase):
+    """From: https://stackoverflow.com/a/12604375"""
+
+    def __init__(self, iter):
+        self._iter = iter
+        self._left = ''
+
+    def readable(self):
+        return True
+
+    def _read1(self, n=None):
+        while not self._left:
+            try:
+                self._left = next(self._iter)
+            except StopIteration:
+                break
+        ret = self._left[:n]
+        self._left = self._left[len(ret):]
+        return ret
+
+    def read(self, n=None):
+        buf = []
+        if n is None or n < 0:
+            while True:
+                m = self._read1()
+                if not m:
+                    break
+                buf.append(m)
+        else:
+            while n > 0:
+                m = self._read1(n)
+                if not m:
+                    break
+                n -= len(m)
+                buf.append(m)
+        return ''.join(buf)
+
+    def readline(self):
+        buf = []
+        while True:
+            i = self._left.find('\n')
+            if i == -1:
+                buf.append(self._left)
+                try:
+                    self._left = next(self._iter)
+                except StopIteration:
+                    self._left = ''
+                    break
+            else:
+                buf.append(self._left[:i+1])
+                self._left = self._left[i+1:]
+                break
+        return ''.join(buf)
+ return ''.join(buf)