Reshinth Adithyan committed on
Commit
efc64b8
1 Parent(s): 5ef742a

diff_utils commit

Browse files
Files changed (4) hide show
  1. __pycache__/utils.cpython-39.pyc +0 -0
  2. app.py +10 -3
  3. requirements.txt +1 -0
  4. utils.py +324 -0
__pycache__/utils.cpython-39.pyc ADDED
Binary file (8.28 kB). View file
 
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import difflib
4
  import re
 
5
 
6
  commit_message_per_brush = {
7
  "Annotate Type": "annotate type to the variables.",
@@ -11,7 +12,7 @@ commit_message_per_brush = {
11
  }
12
 
13
 
14
- def load_model_and_tokenizer(model_name:str="CarperAI/diff-codegen-2B-v2"):
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
  model = AutoModelForCausalLM.from_pretrained(model_name)
17
  return tokenizer, model
@@ -29,7 +30,7 @@ def generate_diff(code:str):
29
 
30
 
31
  def postprocess_output(generated_output:str):
32
- pass
33
 
34
  st.title("Code Brush")
35
  st.write("A tool to brush up your code")
@@ -43,7 +44,13 @@ with st.form("my_form"):
43
  submit_button = st.form_submit_button("Submit")
44
  if submit_button:
45
  st.write("## Diff:")
46
- st.text_area(generate_diff(make_prompt(text,brush_type)))
 
 
 
 
 
 
47
 
48
 
49
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import difflib
4
  import re
5
+ from utils import verify_diff, apply_diff_from_output
6
 
7
  commit_message_per_brush = {
8
  "Annotate Type": "annotate type to the variables.",
 
12
  }
13
 
14
 
15
+ def load_model_and_tokenizer(model_name:str="CarperAI/diff-codegen-350M-v2"):
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForCausalLM.from_pretrained(model_name)
18
  return tokenizer, model
 
30
 
31
 
32
  def postprocess_output(generated_output:str):
33
+ return verify_diff(generated_output)
34
 
35
  st.title("Code Brush")
36
  st.write("A tool to brush up your code")
 
44
  submit_button = st.form_submit_button("Submit")
45
  if submit_button:
46
  st.write("## Diff:")
47
+ generate_diff = generate_diff(make_prompt(text,brush_type))
48
+ after_file = apply_diff_from_output(generate_diff)
49
+ generate_diff_processed = postprocess_output(generate_diff)
50
+ st.write(after_file)
51
+ st.write(generate_diff_processed)
52
+ #st.text_area(generate_diff_processed)
53
+ #st.text_area(generate_diff, height=150, value=generate_diff)
54
 
55
 
56
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  streamlit
2
  transformers
 
 
1
  streamlit
2
  transformers
3
+ torch
utils.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Taken from https://github.com/CarperAI/OpenELM/blob/main/src/openelm/utils/diff_eval.py
2
+ import re
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
# Matches a hunk header of the form "@@ -x,y +a,b @@" at the start of a line.
# Named groups: l1/l2 capture the two starting line numbers, s1/s2 the two (possibly empty) line counts.
line_number_pattern = re.compile(r"(?m)^@@ -(?P<l1>\d*),*?(?P<s1>\d*?) \+(?P<l2>\d*),*?(?P<s2>\d*?) @@")
# Matches a complete diff text laid out as four tagged sections (<NME>, <BEF>, <MSG>, <DFF>),
# each starting on its own line, and captures them as the groups name / file / message / diff.
diff_pattern = re.compile(
    r"""<NME> (?P<name>.*?)
<BEF> (?P<file>(.|\n)*?)
<MSG> (?P<message>(.|\n)*?)
<DFF> (?P<diff>(.|\n)*)"""
)
# Splits a diff on its "@@ ... @@" header lines. Because the header is a capture group,
# `split` keeps each header in the resulting list, followed by the hunk body it introduces.
hunk_split_pattern = re.compile(r"(?m)^(@@ .*? @@).*\n")
# Matches a "$" at the start of a line together with the newline right after it (if any);
# used with `.sub("", ...)` to strip such markers before parsing.
ignored = re.compile(r"(?m)^\$\n?")
15
+
16
+
17
class DiffState(Enum):
    """
    Validity verdict for a diff text; this is what `verify_diff` returns.

    Every member carries a 3-bit code. The lowest bit flags a text/hunk problem, the middle bit
    flags a line-number problem, and the highest bit separates recoverable mismatches (0) from
    hard format errors (1). The bits are additive, so two problems of the same severity combine
    into a single member.
    """

    # The diff is valid.
    VALID = 0b000

    # Recoverable mismatches: these can still be ignored or repaired.
    INVALID_TEXT = 0b001  # the pre-diff text cannot be found in the context
    INVALID_LINE_NUM = 0b010  # the numbers in @@ -x,y +a,b @@ parse but are wrong
    INVALID_TEXT_AND_LINE_NUM = INVALID_TEXT | INVALID_LINE_NUM  # 0b011

    # Format errors: these cannot be ignored.
    BAD_FORMAT = 0b100  # the <NME>/<BEF>/<MSG>/<DFF> layout cannot be parsed
    BAD_DIFF_HUNK_FORMAT = BAD_FORMAT | 0b001  # a hunk line does not start with ' ', '+' or '-'
    BAD_LINE_NUM_FORMAT = BAD_FORMAT | 0b010  # @@ ... @@ is present but its numbers cannot be parsed
    BAD_HUNK_AND_LINE_FORMAT = BAD_FORMAT | 0b011  # both of the two errors above (0b111)
40
+
41
+
42
def split_diff(content: str) -> dict:
    """
    Break a complete diff text into its four tagged sections.

    Args:
        content: the diff text, expected to start with the <NME> section.

    Returns:
        On a successful match, a dict with the keys
            name: the filename
            file: the file content
            message: the diff message
            diff: the diff hunk(s)
        An empty dict when `content` does not match the expected layout.
    """
    parsed = diff_pattern.match(content)
    if parsed is None:
        return {}
    return parsed.groupdict()
57
+
58
+
59
def parse_line_info(content: str) -> tuple:
    """
    Extract the four numbers from a hunk header "@@ -x,y +a,b @@".

    Args:
        content: the "@@ ... @@" line.
    Returns:
        (x, y, a, b) as integers, where a missing y or b is taken to be 1;
        an empty tuple if the header does not match or if x or a is missing.
    """
    header = line_number_pattern.match(content)
    if header is None:
        return ()

    # The two starting line numbers are mandatory.
    start_before, start_after = header.group("l1"), header.group("l2")
    if not (start_before and start_after):
        return ()

    # The two line counts are optional and default to 1.
    size_before = header.group("s1") or "1"
    size_after = header.group("s2") or "1"
    return int(start_before), int(size_before), int(start_after), int(size_after)
79
+
80
+
81
def parse_diff_content(
    hunk: str, separate_lines=False, reject_invalid=False
) -> Optional[tuple]:
    """
    Turn a hunk body into its pre-diff and post-diff text using each line's first character.

    A line starting with ' ' belongs to both sides, '-' only to the pre-diff side and '+' only to
    the post-diff side; the marker character itself is dropped.

    Args:
        hunk: the hunk body (without the "@@ ... @@" header).
        separate_lines: (Optional) True to return lists of lines instead of joined strings.
        reject_invalid: (Optional) True to return None as soon as a non-empty line starts with
            anything other than ' ', '-' or '+'. When False such lines are silently skipped.
    Returns:
        (before_diff, after_diff);
        None if reject_invalid is True and an invalid line is encountered.
    """
    pre_lines, post_lines = [], []
    for raw in hunk.split("\n"):
        # Completely empty lines (e.g. from a trailing '\n') carry no marker and are skipped;
        # a genuine blank line inside a hunk is at least a single ' '.
        if raw == "":
            continue
        marker, body = raw[0], raw[1:]
        if marker == " ":
            pre_lines.append(body)
            post_lines.append(body)
        elif marker == "-":
            pre_lines.append(body)
        elif marker == "+":
            post_lines.append(body)
        elif reject_invalid:
            return None

    if separate_lines:
        return pre_lines, post_lines
    return "\n".join(pre_lines), "\n".join(post_lines)
113
+
114
+
115
def replace_text(text: str,
                 before: str,
                 after: str,
                 start_pointer: int,
                 reject_incomplete_line: bool = True) -> tuple[str, int]:
    """
    Replace the first occurrence of `before` at or after `start_pointer` in `text` with `after`.

    Args:
        text: the original text.
        before: the text to look for.
        after: the text to put in its place.
        start_pointer: the index where the search starts (inclusive).
        reject_incomplete_line: (Optional) refuse the replacement when the match stops in the
            middle of a line, i.e. is followed by something other than '\\n' or the end of text.
    Returns:
        (diff_result, new_start_pointer)
        On success, the patched text and the index just past the inserted `after`.
        If nothing matches or the match is refused, the unchanged text and `start_pointer`.
    """
    offset = text[start_pointer:].find(before)
    if offset < 0:
        return text, start_pointer

    match_start = start_pointer + offset
    match_end = match_start + len(before)

    # A match that ends neither at EOF nor right before a newline cuts a line in half.
    if reject_incomplete_line and match_end < len(text) and text[match_end] != '\n':
        return text, start_pointer

    # Slicing past the end simply yields "", so match_end == len(text) needs no special case.
    patched = text[:match_start] + after + text[match_end:]
    return patched, match_start + len(after)
147
+
148
+
149
def apply_diff(file: str, diff: str, use_line_number=False, allow_add_file=True) -> str:
    """
    Apply the diff to the file content. We try to be lenient and keep applying hunks naively;
    hunks that cannot be applied are skipped.

    (Warning: if use_line_number==False, a hunk whose lines all start with "+" has an empty
    pre-diff text. Because only the pre-diff text is used for matching, such a hunk is simply
    inserted at the very beginning of the file, and only if no earlier hunk moved the pointer.)

    Args:
        file: the file content.
        diff: the diff hunk(s), each introduced by a "@@ -x,y +a,b @@" header line.
        use_line_number: (Optional) use the line numbers in "@@ ... @@" faithfully and patch
            line by line instead of by text search.
        allow_add_file: (Optional) when file is exactly "ADDFILE", return the post-diff text of
            the first hunk directly. With use_line_number=True this additionally requires the
            header to start with "-0,0"; otherwise "" is returned.
    Return:
        the maximally patched file content.
    """
    # re.split with one capture group returns [prefix, header_1, hunk_1, header_2, hunk_2, ...].
    parts = hunk_split_pattern.split(ignored.sub("", diff))

    # If we use the line numbers, we match-and-replace in a line-by-line fashion.
    file_by_line = file.split("\n") if use_line_number else None
    line_offset = 0  # the offset between pre-/post-patching line numbers

    # Without line numbers, successive hunks only ever move forward through the text.
    patch_pointer = 0

    # Headers always sit at odd indices. parts[0] is whatever precedes the first header (normally
    # empty) and is never a header itself, so it is always skipped; starting at 0 when it is
    # non-empty would pair every header with the wrong hunk.
    i = 1
    while i < len(parts) - 1:  # Need at least a header and its hunk to continue
        line_info = parse_line_info(parts[i])
        diff_content = parts[i + 1]
        i += 2

        # Build the pre-/post-diff text (or line lists) from the leading '+', '-', ' ' markers.
        parsed_diff = parse_diff_content(diff_content, separate_lines=use_line_number)

        # If we allow the recognition of "ADDFILE", special treatment is needed.
        if allow_add_file and file == "ADDFILE":
            if use_line_number:
                # parsed_diff[1] is a list of lines here; join it so a str is returned.
                # line_info is a 4-tuple (or () if unparsable), so compare only its "-x,y" half.
                return "\n".join(parsed_diff[1]) if line_info[:2] == (0, 0) else ""
            # Immediately apply the first hunk and ignore the rest.
            return parsed_diff[1]

        if use_line_number:
            # If line numbers cannot be parsed, skip.
            if not line_info:
                continue

            # Offset the starting line by the net growth of the hunks applied so far.
            start_idx = line_info[0] + line_offset

            # Match the referred lines with the file context.
            referred_lines = file_by_line[start_idx - 1: start_idx - 1 + line_info[1]]
            valid = all(l1 == l2 for l1, l2 in zip(parsed_diff[0], referred_lines))

            # If lines fully match and the number of lines is consistent, apply the patch.
            # We ignore the second pair "+a,b" just to be lenient.
            if valid and len(parsed_diff[0]) == line_info[1]:
                if start_idx == 0:  # Add lines to the beginning.
                    file_by_line = parsed_diff[1] + file_by_line
                else:
                    file_by_line = (
                        file_by_line[: start_idx - 1]
                        + parsed_diff[1]
                        + file_by_line[start_idx - 1 + line_info[1]:]
                    )
                line_offset += len(parsed_diff[1]) - line_info[1]
        else:
            # CAUTION: this way of handling empty context is being very lenient and could lead to
            # undesirable behaviors. Only do this when you want to be as tolerant as possible.
            if parsed_diff[0] == "":
                if patch_pointer != 0:  # Lack of matching context can only happen at the beginning of file.
                    continue
                file = parsed_diff[1] + "\n" + file
                # Continue searching after the inserted text and its newline, so that later
                # hunks cannot match inside the lines that were just added.
                patch_pointer = len(parsed_diff[1]) + 1
            else:
                # Directly (and naively) apply patch by match-and-replace.
                file, patch_pointer = replace_text(file, parsed_diff[0], parsed_diff[1], patch_pointer)

    if use_line_number:
        file = "\n".join(file_by_line)
    return file
234
+
235
+
236
def apply_diff_from_output(generated_text: str):
    """
    Apply the diff contained in a generated text to the file contained in the same text.

    The text is expected to consist of the tagged sections <NME>, <BEF>, <MSG> and <DFF>.

    Args:
        generated_text: the complete generated text.
    Returns:
        The patched file content when the text parses into the four sections.
        If only the <DFF> tag can be found, everything before the tag is patched and returned.
        If there is no <DFF> tag at all, the text is returned unchanged.
    """
    before, marker, diff_hunk = generated_text.partition("<DFF>")
    if not marker:
        # No diff section: nothing to apply (indexing the split result used to raise here).
        return generated_text

    sections = split_diff(generated_text)
    if sections:
        # Patch only the <BEF> file content, not the <NME>/<MSG> header around it.
        return apply_diff(sections["file"], sections["diff"])

    # Fallback for text that has a <DFF> tag but not the full layout. The whitespace that
    # follows the tag must go, otherwise the first "@@" header is not at the start of a line
    # and the first hunk is never recognised.
    return apply_diff(before, diff_hunk.lstrip())
240
+
241
def verify_diff(diff_text: str) -> DiffState:
    """
    Verify the validity of a complete diff text.

    Args:
        diff_text: the complete diff text. It is expected to consist of the tagged sections
            <NME>, <BEF>, <MSG> and <DFF>, where the text after <DFF> has one or more
            "@@ -x,y +a,b @@" header lines, each followed by the corresponding hunk.
    Returns:
        A DiffState. Format errors (unparsable layout, header numbers or hunk lines) take
        precedence over context / line-number mismatches.
    """
    diff_dict = split_diff(ignored.sub("", diff_text))  # Ignore the GitHub warning on the end of file

    if any(key not in diff_dict for key in ("name", "file", "message", "diff")):
        return DiffState(0b100)  # Invalid overall format

    diff_parts = hunk_split_pattern.split(diff_dict["diff"])
    if not diff_parts:
        return DiffState(0b100)  # Invalid overall format

    context_mismatch, line_number_mismatch = False, False
    bad_diff_hunk, bad_line_number = False, False

    # Text in front of the first header is deliberately treated as a (bad) header here.
    i = 0 if diff_parts[0] else 1
    while i < len(diff_parts) - 1:  # Need at least a pair of '@@ ... @@' and diff hunk to continue
        line_info = parse_line_info(diff_parts[i])
        diff_content = parse_diff_content(diff_parts[i + 1], reject_invalid=True)
        i += 2

        # Special treatment if we are adding a new file: exactly one hunk, headed "-0,0 +1,n".
        if diff_dict["file"] == "ADDFILE":
            if len(diff_parts) != i or not line_info or line_info[:3] != (0, 0, 1):
                return DiffState(0b110)
            if diff_content is None:
                # The hunk has a malformed line; report it instead of indexing into None.
                return DiffState(0b101)
            if line_info[3] != len(diff_content[1].split("\n")) or diff_content[0]:
                return DiffState(0b110)
            return DiffState(0b000)

        if not line_info or len(line_info) != 4:
            bad_line_number = True
        if diff_content is None:
            bad_diff_hunk = True

        # Skip the diff matching checks if bad format already occurred
        if bad_diff_hunk or bad_line_number:
            continue

        # Try to see if there is a match in the file context. Must match complete lines or till EOF.
        match_idx = diff_dict["file"].find(diff_content[0])
        if match_idx == -1 or (
            match_idx + len(diff_content[0]) != len(diff_dict["file"])
            and diff_dict["file"][match_idx + len(diff_content[0])] != "\n"
        ):
            context_mismatch = True

        if line_info[0] <= 0:
            # -0,0 only happens when we create a new file (in which case the context is ADDFILE).
            if line_info[1] != 0 or diff_dict["file"] != "ADDFILE":
                line_number_mismatch = True
        else:
            # Check the line numbers regardless of whether the context matches.
            pre_diff_line_number = len(diff_content[0].split("\n"))
            post_diff_line_number = len(diff_content[1].split("\n"))
            if (pre_diff_line_number, post_diff_line_number) != (line_info[1], line_info[3]):
                line_number_mismatch = True

    if bad_diff_hunk or bad_line_number:
        return DiffState(bad_diff_hunk * 0b001 + bad_line_number * 0b010 + 0b100)
    return DiffState(context_mismatch * 0b001 + line_number_mismatch * 0b010)