plaggy committed on
Commit
8134cf8
1 Parent(s): 7ddbf9d

refactor, a single process

Browse files
Files changed (1) hide show
  1. app.py +97 -102
app.py CHANGED
@@ -27,7 +27,11 @@ class Chunker:
27
  self.split_seq = split_seq
28
  self.chunk_len = chunk_len
29
  if strategy == "recursive":
30
- self.split = RecursiveCharacterTextSplitter().split_text
 
 
 
 
31
  if strategy == "sequence":
32
  self.split = self.seq_splitter
33
  if strategy == "constant":
@@ -51,26 +55,6 @@ def generator(input_ds, input_text_col, chunker):
51
  yield {input_text_col: chunk}
52
 
53
 
54
- def chunk(input_ds, input_splits, input_text_col, output_ds, strategy, split_seq, chunk_len, private):
55
- input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
56
- input_ds = load_dataset(input_ds, split="+".join(input_splits))
57
- chunker = Chunker(strategy, split_seq, chunk_len)
58
-
59
- gen_kwargs = {
60
- "input_ds": input_ds,
61
- "input_text_col": input_text_col,
62
- "chunker": chunker
63
- }
64
- dataset = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
65
- dataset.push_to_hub(
66
- output_ds,
67
- private=private,
68
- token=HF_TOKEN
69
- )
70
-
71
- logger.info("Done chunking")
72
-
73
-
74
  async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file):
75
  async with semaphore:
76
  payload = {
@@ -108,6 +92,7 @@ async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
108
 
109
 
110
  def wake_up_endpoint(url):
 
111
  n_loop = 0
112
  while requests.get(
113
  url=url,
@@ -115,30 +100,61 @@ def wake_up_endpoint(url):
115
  ).status_code != 200:
116
  time.sleep(2)
117
  n_loop += 1
118
- if n_loop > 30:
119
- raise TimeoutError("TEI endpoint is unavailable")
120
  logger.info("TEI endpoint is up")
121
 
122
 
123
- def run_embed(input_ds, input_splits, embed_in_text_col, output_ds, tei_url, private):
124
- wake_up_endpoint(tei_url)
125
- input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
126
- input_ds = load_dataset(input_ds, split="+".join(input_splits))
127
- with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
128
- asyncio.run(embed_ds(input_ds, tei_url, embed_in_text_col, temp_file))
 
 
 
129
 
130
- dataset = Dataset.from_json(temp_file.name)
131
- dataset.push_to_hub(
132
- output_ds,
133
- private=private,
134
- token=HF_TOKEN
135
- )
 
 
 
 
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  logger.info("Done embedding")
138
 
139
 
140
  def change_dropdown(choice):
141
- if choice == "recursive" or choice == "sequence":
 
 
 
 
 
142
  return [
143
  gr.Textbox(visible=True),
144
  gr.Textbox(visible=False)
@@ -153,73 +169,52 @@ def change_dropdown(choice):
153
  with gr.Blocks() as demo:
154
  gr.Markdown(
155
  """
156
- ## Chunk your dataset before embedding
157
  """
158
  )
159
- with gr.Tab("Chunk"):
160
- chunk_in_ds = gr.Textbox(lines=1, label="Input dataset name")
161
- with gr.Row():
162
- chunk_in_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
163
- chunk_in_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
164
- with gr.Row():
165
- chunk_out_ds = gr.Textbox(lines=1, label="Output dataset name", scale=6)
166
- chunk_private = gr.Checkbox(label="Make chunked dataset private")
167
- with gr.Row():
168
- dropdown = gr.Dropdown(
169
- ["recursive", "sequence", "constant"], label="Chunking strategy",
170
- info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
171
- "'constant' makes chunks of the constant size",
172
- scale=2
173
- )
174
- split_seq = gr.Textbox(
175
- lines=1,
176
- interactive=True,
177
- visible=False,
178
- label="Sequence",
179
- info="A text sequence to split on",
180
- placeholder="\n\n"
181
- )
182
- chunk_len = gr.Textbox(
183
- lines=1,
184
- interactive=True,
185
- visible=False,
186
- label="Length",
187
- info="The length of chunks to split into",
188
- placeholder="512"
189
- )
190
- dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
191
- with gr.Row():
192
- gr.ClearButton(
193
- components=[
194
- chunk_in_ds, chunk_in_splits, chunk_in_text_col, chunk_out_ds,
195
- dropdown, split_seq, chunk_len, chunk_private
196
- ]
197
- )
198
- chunk_btn = gr.Button("Chunk")
199
- chunk_btn.click(
200
- fn=chunk,
201
- inputs=[chunk_in_ds, chunk_in_splits, chunk_in_text_col, chunk_out_ds,
202
- dropdown, split_seq, chunk_len, chunk_private]
203
- )
204
-
205
- with gr.Tab("Embed"):
206
- embed_in_ds = gr.Textbox(lines=1, label="Input dataset name")
207
- with gr.Row():
208
- embed_in_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
209
- embed_in_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
210
- with gr.Row():
211
- embed_out_ds = gr.Textbox(lines=1, label="Output dataset name", scale=6)
212
- embed_private = gr.Checkbox(label="Make embedded dataset private")
213
- tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
214
- with gr.Row():
215
- gr.ClearButton(
216
- components=[embed_in_ds, embed_in_splits, embed_in_text_col, embed_out_ds, tei_url, embed_private]
217
- )
218
- embed_btn = gr.Button("Run embed")
219
- embed_btn.click(
220
- fn=run_embed,
221
- inputs=[embed_in_ds, embed_in_splits, embed_in_text_col, embed_out_ds, tei_url, embed_private]
222
- )
223
-
224
 
 
225
  demo.launch(debug=True)
 
27
  self.split_seq = split_seq
28
  self.chunk_len = chunk_len
29
  if strategy == "recursive":
30
+ # https://huggingface.co/spaces/m-ric/chunk_visualizer
31
+ self.split = RecursiveCharacterTextSplitter(
32
+ chunk_size=chunk_len,
33
+ separators=[split_seq]
34
+ ).split_text
35
  if strategy == "sequence":
36
  self.split = self.seq_splitter
37
  if strategy == "constant":
 
55
  yield {input_text_col: chunk}
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file):
59
  async with semaphore:
60
  payload = {
 
92
 
93
 
94
  def wake_up_endpoint(url):
95
+ logger.info("Starting up TEI endpoint")
96
  n_loop = 0
97
  while requests.get(
98
  url=url,
 
100
  ).status_code != 200:
101
  time.sleep(2)
102
  n_loop += 1
103
+ if n_loop > 40:
104
+ raise gr.Error("TEI endpoint is unavailable")
105
  logger.info("TEI endpoint is up")
106
 
107
 
108
+ def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
109
+ strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
110
+ gr.Info("Started chunking")
111
+ try:
112
+ input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
113
+ input_ds = load_dataset(input_ds, split="+".join(input_splits), token=HF_TOKEN)
114
+ chunker = Chunker(strategy, split_seq, chunk_len)
115
+ except Exception as e:
116
+ raise gr.Error(str(e))
117
 
118
+ gen_kwargs = {
119
+ "input_ds": input_ds,
120
+ "input_text_col": input_text_col,
121
+ "chunker": chunker
122
+ }
123
+ chunked_ds = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
124
+ chunked_ds.push_to_hub(
125
+ chunk_out_ds,
126
+ private=private,
127
+ token=HF_TOKEN
128
+ )
129
 
130
+ gr.Info("Done chunking")
131
+ logger.info("Done chunking")
132
+
133
+ try:
134
+ wake_up_endpoint(tei_url)
135
+ with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
136
+ asyncio.run(embed_ds(chunked_ds, tei_url, input_text_col, temp_file))
137
+
138
+ embedded_ds = Dataset.from_json(temp_file.name)
139
+ embedded_ds.push_to_hub(
140
+ embed_out_ds,
141
+ private=private,
142
+ token=HF_TOKEN
143
+ )
144
+ except Exception as e:
145
+ raise gr.Error(str(e))
146
+
147
+ gr.Info("Done embedding")
148
  logger.info("Done embedding")
149
 
150
 
151
  def change_dropdown(choice):
152
+ if choice == "recursive":
153
+ return [
154
+ gr.Textbox(visible=True),
155
+ gr.Textbox(visible=True)
156
+ ]
157
+ elif choice == "sequence":
158
  return [
159
  gr.Textbox(visible=True),
160
  gr.Textbox(visible=False)
 
169
  with gr.Blocks() as demo:
170
  gr.Markdown(
171
  """
172
+ ## Chunk and embed
173
  """
174
  )
175
+ input_ds = gr.Textbox(lines=1, label="Input dataset name")
176
+ with gr.Row():
177
+ input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
178
+ input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
179
+ chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
180
+ with gr.Row():
181
+ dropdown = gr.Dropdown(
182
+ ["recursive", "sequence", "constant"], label="Chunking strategy",
183
+ info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
184
+ "'constant' makes chunks of the constant size",
185
+ scale=2
186
+ )
187
+ split_seq = gr.Textbox(
188
+ lines=1,
189
+ interactive=True,
190
+ visible=False,
191
+ label="Sequence",
192
+ info="A text sequence to split on",
193
+ placeholder="\n\n"
194
+ )
195
+ chunk_len = gr.Textbox(
196
+ lines=1,
197
+ interactive=True,
198
+ visible=False,
199
+ label="Length",
200
+ info="The length of chunks to split into in characters",
201
+ placeholder="512"
202
+ )
203
+ dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
204
+ embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
205
+ private = gr.Checkbox(label="Make output datasets private")
206
+ tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
207
+ with gr.Row():
208
+ clear = gr.ClearButton(
209
+ components=[input_ds, input_splits, input_text_col, chunk_out_ds,
210
+ dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
211
+ )
212
+ embed_btn = gr.Button("Submit")
213
+ embed_btn.click(
214
+ fn=chunk_embed,
215
+ inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
216
+ dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
217
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
+ demo.queue()
220
  demo.launch(debug=True)