pszemraj committed on
Commit e741efb
1 Parent(s): de56fbd

move to ZeroGPU

Files changed (1)
  1. app.py +32 -39
app.py CHANGED
@@ -3,6 +3,9 @@ import logging
import os
import re

+
+ import spaces
+
import torch
from cleantext import clean
import gradio as gr
@@ -13,45 +16,34 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

+ # Model names
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

- # pipelines
- checker = pipeline(
-     "text-classification",
-     checker_model_name,
- )
- # checker.model = torch.compile(checker.model)

- gc.collect()
-
- if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
-     # load onnx runtime unless HF_DEMO_NO_USE_ONNX is set
-     from optimum.pipelines import pipeline
-
-     corrector = pipeline(
-         "text2text-generation", model=corrector_model_name, accelerator="ort"
-     )
- else:
-     corrector = pipeline("text2text-generation", corrector_model_name)
+ checker = pipeline(
+     "text-classification",
+     checker_model_name,
+     device_map="cuda",
+ )

+ corrector = pipeline(
+     "text2text-generation",
+     corrector_model_name,
+     device_map="cuda",
+ )

def split_text(text: str) -> list:
    # Split the text into sentences using regex
    sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)

-     # Initialize a list to store the sentence batches
+     # Initialize lists for batching
    sentence_batches = []
-
-     # Initialize a temporary list to store the current batch of sentences
    temp_batch = []

-     # Iterate through the sentences
+     # Create batches of 2-3 sentences
    for sentence in sentences:
-         # Add the sentence to the temporary batch
        temp_batch.append(sentence)
-
-         # If the length of the temporary batch is between 2 and 3 sentences, or if it is the last batch, add it to the list of sentence batches
        if len(temp_batch) >= 2 and len(temp_batch) <= 3 or sentence == sentences[-1]:
            sentence_batches.append(temp_batch)
            temp_batch = []
@@ -59,44 +51,44 @@ def split_text(text: str) -> list:
    return sentence_batches


- def correct_text(text: str, checker, corrector, separator: str = " ") -> str:
+ @spaces.GPU(duration=60)
+ def correct_text(text: str, separator: str = " ") -> str:
+
    # Split the text into sentence batches
    sentence_batches = split_text(text)

    # Initialize a list to store the corrected text
    corrected_text = []

-     # Iterate through the sentence batches
+     # Process each batch
    for batch in tqdm(
        sentence_batches, total=len(sentence_batches), desc="correcting text.."
    ):
-         # Join the sentences in the batch into a single string
        raw_text = " ".join(batch)

-         # Check the grammar quality of the text using the text-classification pipeline
+         # Check grammar quality
        results = checker(raw_text)

-         # Only correct the text if the results of the text-classification are not LABEL_1 or are LABEL_1 with a score below 0.9
+         # Correct text if needed
        if results[0]["label"] != "LABEL_1" or (
            results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
        ):
-             # Correct the text using the text-generation pipeline
            corrected_batch = corrector(raw_text)
            corrected_text.append(corrected_batch[0]["generated_text"])
        else:
            corrected_text.append(raw_text)

-     # Join the corrected text into a single string
-     corrected_text = separator.join(corrected_text)
-
-     return corrected_text
+     # Join the corrected text
+     return separator.join(corrected_text)


def update(text: str):
+     # Clean and truncate input text
    text = clean(text[:4000], lower=False)
-     return correct_text(text, checker, corrector)
+     return correct_text(text)


+ # Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction with FLAN-T5</center>")
    gr.Markdown(
@@ -111,7 +103,7 @@ with gr.Blocks() as demo:
    with gr.Row():
        inp = gr.Textbox(
            label="input",
-             placeholder="PUT TEXT TO CHECK & CORRECT BROSKI",
+             placeholder="Enter text to check & correct",
            value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
        )
        out = gr.Textbox(label="output", interactive=False)
@@ -119,7 +111,8 @@ with gr.Blocks() as demo:
    btn.click(fn=update, inputs=inp, outputs=out)
    gr.Markdown("---")
    gr.Markdown(
-         "- see the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
+         "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
    )
-     gr.Markdown("- if experiencing long wait times, feel free to duplicate the space!")
-     demo.launch(debug=True)
+
+ # Launch the demo
+ demo.launch(debug=True)
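For context, the core of this change is the Hugging Face Spaces ZeroGPU pattern: the ONNX/CPU fallback is dropped, pipelines are built once at import time with `device_map="cuda"`, and the GPU-bound entry point is wrapped in `@spaces.GPU` so a GPU is attached only while it runs. Below is a minimal sketch of that pattern, assuming the `spaces` package that ZeroGPU hardware provides; the model name and 60-second duration are taken from the diff above, everything else is illustrative rather than part of this commit:

```python
import spaces  # assumption: ZeroGPU Spaces provide this package
import gradio as gr
from transformers import pipeline

# Build the pipeline once at import time; ZeroGPU attaches a GPU per decorated call.
corrector = pipeline(
    "text2text-generation",
    "pszemraj/flan-t5-large-grammar-synthesis",
    device_map="cuda",
)


@spaces.GPU(duration=60)  # request a GPU for up to ~60 s per invocation
def correct(text: str) -> str:
    # Run the grammar-correction model and return the generated text
    return corrector(text)[0]["generated_text"]


demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()
```

On ZeroGPU hardware the GPU is only allocated inside functions decorated with `@spaces.GPU`, which is why `correct_text` gains the decorator in this commit and the environment-variable switch between ONNX and plain PyTorch pipelines is no longer needed.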