import difflib
import os
from typing import List, Optional, Tuple

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")


def diff_texts(text1: str, text2: str) -> List[Tuple[str, Optional[str]]]:
    """Token-level diff of two texts.

    Returns a list of (decoded_text, category) tuples, where category is
    '+' (only in text2), '-' (only in text1), or None (shared).
    """
    # Encode the input texts to token IDs
    text1_tokens_ids = tokenizer.encode(text1, add_special_tokens=False)
    text2_tokens_ids = tokenizer.encode(text2, add_special_tokens=False)

    # Create a SequenceMatcher object over the token ID sequences
    matcher = difflib.SequenceMatcher(None, text1_tokens_ids, text2_tokens_ids)

    # Get the opcodes (operations) for the differences
    opcodes = matcher.get_opcodes()

    # Process the opcodes to create the merged_tokens list
    merged_tokens = []
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == 'replace':
            merged_tokens.append((tokenizer.decode(text2_tokens_ids[j1:j2]), '+'))
            merged_tokens.append((tokenizer.decode(text1_tokens_ids[i1:i2]), '-'))
        elif tag == 'delete':
            merged_tokens.append((tokenizer.decode(text1_tokens_ids[i1:i2]), '-'))
        elif tag == 'insert':
            merged_tokens.append((tokenizer.decode(text2_tokens_ids[j1:j2]), '+'))
        elif tag == 'equal':
            merged_tokens.append((tokenizer.decode(text1_tokens_ids[i1:i2]), None))
        else:
            raise ValueError(f"Unknown tag: {tag}")

        # If a deletion ends with the same text as the segment before it,
        # shift the deletion window left so the shared text stays contiguous.
        if len(merged_tokens) >= 3 and \
                merged_tokens[-1][1] is None and \
                merged_tokens[-2][1] == '-' and \
                merged_tokens[-3][1] in ['+', None]:
            if merged_tokens[-2][0].endswith(merged_tokens[-3][0]):
                token_3 = merged_tokens.pop()
                token_2 = merged_tokens.pop()
                token_1 = merged_tokens.pop()
                if token_1[1] is None:
                    merged_tokens.append((token_1[0] + token_2[0][:-len(token_1[0])], token_2[1]))
                elif token_1[1] == '+':
                    merged_tokens.append((token_2[0][:-len(token_1[0])], token_2[1]))
                merged_tokens.append((token_1[0] + token_3[0], token_3[1]))
        # If an insertion and a deletion sit next to each other, factor their
        # common prefix and suffix out into shared (None) segments.
        elif len(merged_tokens) >= 2:
            if {merged_tokens[-1][1], merged_tokens[-2][1]} == {'+', '-'}:
                token_2 = merged_tokens.pop()
                token_1 = merged_tokens.pop()
                common_prefix = os.path.commonprefix([token_1[0], token_2[0]])
                common_suffix = os.path.commonprefix(
                    [token_1[0][len(common_prefix):][::-1],
                     token_2[0][len(common_prefix):][::-1]])[::-1]
                if common_prefix:
                    token_1 = (token_1[0][len(common_prefix):], token_1[1])
                    token_2 = (token_2[0][len(common_prefix):], token_2[1])
                if common_suffix:
                    token_1 = (token_1[0][:-len(common_suffix)], token_1[1])
                    token_2 = (token_2[0][:-len(common_suffix)], token_2[1])
                if common_prefix:
                    merged_tokens.append((common_prefix, None))
                if token_1[0]:
                    merged_tokens.append(token_1)
                if token_2[0]:
                    merged_tokens.append(token_2)
                if common_suffix:
                    merged_tokens.append((common_suffix, None))

    # Merge adjacent tokens with the same category
    final_merged_tokens = [merged_tokens[0]]
    last_token_category = merged_tokens[0][1]
    for token, category in merged_tokens[1:]:
        if category == last_token_category:
            final_merged_tokens[-1] = (final_merged_tokens[-1][0] + token, category)
        else:
            final_merged_tokens.append((token, category))
            last_token_category = category

    return final_merged_tokens

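# A minimal usage sketch for diff_texts. The sentences are made-up examples;
# the exact segment boundaries depend on how the Llama-3 tokenizer splits the
# strings, so the illustrated output is indicative rather than exact:
#
#     diff_texts("The quick brown fox", "The quick red fox")
#     # -> [('The quick ', None), ('red', '+'), ('brown', '-'), (' fox', None)]
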
def compare_lists(list1, list2):
    """Return the common prefix and common suffix of two lists.

    The suffix is computed on the remainders left after stripping the common
    prefix (mirroring the string logic in diff_texts), so the prefix and
    suffix can never overlap even when one list is contained in the other.
    """
    # Find common prefix
    prefix = []
    for x, y in zip(list1, list2):
        if x == y:
            prefix.append(x)
        else:
            break

    # Find common suffix of what remains after the prefix
    rest1, rest2 = list1[len(prefix):], list2[len(prefix):]
    suffix = []
    for x, y in zip(reversed(rest1), reversed(rest2)):
        if x == y:
            suffix.insert(0, x)  # Insert at the beginning to maintain order
        else:
            break

    return prefix, suffix

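# Illustrative calls for compare_lists (the ID values are arbitrary):
#
#     compare_lists([1, 2, 3, 4], [1, 9, 9, 4])   # -> ([1], [4])
#     compare_lists([1, 2], [1, 2, 5, 2])         # -> ([1, 2], [])
#
# The second call shows why the suffix is taken from the prefix-stripped
# remainders: computing it over the full lists would return suffix [2],
# which overlaps the prefix and double-counts tokens downstream.
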
def diff_token_ids(text1_tokens_ids: List[int], text2_tokens_ids: List[int]):
    """Token-ID-level variant of diff_texts.

    Args:
        text1_tokens_ids: List of token IDs for the first text.
        text2_tokens_ids: List of token IDs for the second text.

    Returns:
        Two lists -- one per text -- of [token_ids, category] pairs, where
        token_ids is a list of token IDs and category is '+', '-', or None.
    """
    # Create a SequenceMatcher object over the token ID sequences
    matcher = difflib.SequenceMatcher(None, text1_tokens_ids, text2_tokens_ids)

    # Get the opcodes (operations) for the differences
    opcodes = matcher.get_opcodes()

    # Process the opcodes to create the merged_tokens list
    merged_tokens = []
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == 'replace':
            merged_tokens.append([text2_tokens_ids[j1:j2], '+'])
            merged_tokens.append([text1_tokens_ids[i1:i2], '-'])
        elif tag == 'delete':
            merged_tokens.append([text1_tokens_ids[i1:i2], '-'])
        elif tag == 'insert':
            merged_tokens.append([text2_tokens_ids[j1:j2], '+'])
        elif tag == 'equal':
            merged_tokens.append([text1_tokens_ids[i1:i2], None])
        else:
            raise ValueError(f"Unknown tag: {tag}")

        # Same left-shift normalization as in diff_texts, expressed on ID
        # lists (lists have no endswith, hence the slice comparison).
        if len(merged_tokens) >= 3 and \
                merged_tokens[-1][1] is None and \
                merged_tokens[-2][1] == '-' and \
                merged_tokens[-3][1] in ['+', None]:
            if merged_tokens[-3][0] == merged_tokens[-2][0][-len(merged_tokens[-3][0]):]:
                token_3 = merged_tokens.pop()
                token_2 = merged_tokens.pop()
                token_1 = merged_tokens.pop()
                if token_1[1] is None:
                    merged_tokens.append([token_1[0] + token_2[0][:-len(token_1[0])], token_2[1]])
                elif token_1[1] == '+':
                    merged_tokens.append([token_2[0][:-len(token_1[0])], token_2[1]])
                merged_tokens.append([token_1[0] + token_3[0], token_3[1]])
        # Factor the common prefix/suffix out of an adjacent +/- pair.
        elif len(merged_tokens) >= 2:
            if {merged_tokens[-1][1], merged_tokens[-2][1]} == {'+', '-'}:
                token_2 = merged_tokens.pop()
                token_1 = merged_tokens.pop()
                common_prefix, common_suffix = compare_lists(token_1[0], token_2[0])
                if common_prefix:
                    token_1 = [token_1[0][len(common_prefix):], token_1[1]]
                    token_2 = [token_2[0][len(common_prefix):], token_2[1]]
                if common_suffix:
                    token_1 = [token_1[0][:-len(common_suffix)], token_1[1]]
                    token_2 = [token_2[0][:-len(common_suffix)], token_2[1]]
                if common_prefix:
                    merged_tokens.append([common_prefix, None])
                if token_1[0]:
                    merged_tokens.append(token_1)
                if token_2[0]:
                    merged_tokens.append(token_2)
                if common_suffix:
                    merged_tokens.append([common_suffix, None])

    # Merge adjacent tokens with the same category
    final_merged_tokens = [merged_tokens[0]]
    last_token_category = merged_tokens[0][1]
    for token, category in merged_tokens[1:]:
        if category == last_token_category:
            final_merged_tokens[-1] = [final_merged_tokens[-1][0] + token, category]
        else:
            final_merged_tokens.append([token, category])
            last_token_category = category

    text1_diff_tokens = [x for x in final_merged_tokens if x[1] in ['-', None]]
    text2_diff_tokens = [x for x in final_merged_tokens if x[1] in ['+', None]]

    # Sanity check: each side's segments must reassemble the original IDs.
    text1_tokens_after_diff = []
    for x in text1_diff_tokens:
        text1_tokens_after_diff.extend(x[0])
    text2_tokens_after_diff = []
    for x in text2_diff_tokens:
        text2_tokens_after_diff.extend(x[0])
    assert text1_tokens_after_diff == text1_tokens_ids, \
        f"After diff, the tokens in text1 are \n{text1_tokens_after_diff}\n but expected \n{text1_tokens_ids}\n"
    assert text2_tokens_after_diff == text2_tokens_ids, \
        f"After diff, the tokens in text2 are \n{text2_tokens_after_diff}\n but expected \n{text2_tokens_ids}\n"

    return text1_diff_tokens, text2_diff_tokens

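# A minimal sketch of diff_token_ids on hand-written ID lists (the IDs are
# arbitrary placeholders, not real Llama-3 vocabulary entries):
#
#     t1, t2 = diff_token_ids([10, 20, 30, 40], [10, 99, 30, 40])
#     # t1 -> [[[10], None], [[20], '-'], [[30, 40], None]]
#     # t2 -> [[[10], None], [[99], '+'], [[30, 40], None]]
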
def get_diff_label_marks(text1_tokens_ids: List[int], text2_tokens_ids: List[int]):
    """Return per-token boolean masks: True where a token is '-' in text1,
    and True where a token is '+' in text2, respectively."""
    text1_diff_tokens, text2_diff_tokens = diff_token_ids(text1_tokens_ids, text2_tokens_ids)

    # Only set labels for segments marked '-' in text1_diff_tokens
    text1_diff_labels = []
    for token, category in text1_diff_tokens:
        if category == '-':
            text1_diff_labels.extend([1] * len(token))
        else:
            text1_diff_labels.extend([0] * len(token))

    # Only set labels for segments marked '+' in text2_diff_tokens
    text2_diff_labels = []
    for token, category in text2_diff_tokens:
        if category == '+':
            text2_diff_labels.extend([1] * len(token))
        else:
            text2_diff_labels.extend([0] * len(token))

    # Convert to boolean arrays
    text1_diff_labels = np.array(text1_diff_labels) == 1
    text2_diff_labels = np.array(text2_diff_labels) == 1
    return text1_diff_labels, text2_diff_labels


def get_diff_label_indices(text1_tokens_ids: List[int], text2_tokens_ids: List[int]):
    """Like get_diff_label_marks, but returns the integer positions of the
    labeled tokens so a tensor or numpy array can be sliced with them."""
    text1_diff_tokens, text2_diff_tokens = diff_token_ids(text1_tokens_ids, text2_tokens_ids)

    # Only set labels for segments marked '-' in text1_diff_tokens
    text1_diff_labels = []
    for token, category in text1_diff_tokens:
        if category == '-':
            text1_diff_labels.extend([1] * len(token))
        else:
            text1_diff_labels.extend([0] * len(token))

    # Only set labels for segments marked '+' in text2_diff_tokens
    text2_diff_labels = []
    for token, category in text2_diff_tokens:
        if category == '+':
            text2_diff_labels.extend([1] * len(token))
        else:
            text2_diff_labels.extend([0] * len(token))

    # Convert to indices so a tensor or numpy array can be sliced with them
    text1_diff_label_indices = [i for i, x in enumerate(text1_diff_labels) if x == 1]
    text2_diff_label_indices = [i for i, x in enumerate(text2_diff_labels) if x == 1]
    return text1_diff_label_indices, text2_diff_label_indices

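# Continuing the placeholder-ID sketch from above:
#
#     get_diff_label_marks([10, 20, 30, 40], [10, 99, 30, 40])
#     # -> (array([False,  True, False, False]), array([False,  True, False, False]))
#     get_diff_label_indices([10, 20, 30, 40], [10, 99, 30, 40])
#     # -> ([1], [1])
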
def get_diff_labels_for_demo(text1: str, text2: str, mode='or'):
    text1_tokens_ids = tokenizer.encode(text1, add_special_tokens=False)
    text2_tokens_ids = tokenizer.encode(text2, add_special_tokens=False)

    if mode == 'naive':
        text1_diff_tokens, text2_diff_tokens = diff_token_ids(text1_tokens_ids, text2_tokens_ids)
    elif mode == 'or':
        text1_diff_tokens_marks, text2_diff_tokens_marks = get_diff_label_marks(text1_tokens_ids, text2_tokens_ids)
        # Convert the boolean masks to ints so a third mark value can be used
        text1_diff_tokens_marks = text1_diff_tokens_marks.astype(int)
        text2_diff_tokens_marks = text2_diff_tokens_marks.astype(int)

        # "Or" the marks: a position marked only on the other side gets mark 2.
        # This is a positional comparison across two different tokenizations,
        # so it is only meaningful where the sequences align; the bounds
        # checks guard against differing lengths.
        for i in range(len(text1_diff_tokens_marks)):
            if i < len(text2_diff_tokens_marks):
                if text2_diff_tokens_marks[i] == 1 and text1_diff_tokens_marks[i] == 0:
                    text1_diff_tokens_marks[i] = 2
        for i in range(len(text2_diff_tokens_marks)):
            if i < len(text1_diff_tokens_marks):
                if text1_diff_tokens_marks[i] == 1 and text2_diff_tokens_marks[i] == 0:
                    text2_diff_tokens_marks[i] = 2

        # Group consecutive tokens that share the same mark
        text1_diff_tokens_for_demo = [[[text1_tokens_ids[0]], text1_diff_tokens_marks[0]]]
        text2_diff_tokens_for_demo = [[[text2_tokens_ids[0]], text2_diff_tokens_marks[0]]]
        text1_last_token_category = text1_diff_tokens_marks[0]
        text2_last_token_category = text2_diff_tokens_marks[0]
        for i in range(1, len(text1_diff_tokens_marks)):
            if text1_diff_tokens_marks[i] == text1_last_token_category:
                text1_diff_tokens_for_demo[-1][0].append(text1_tokens_ids[i])
            else:
                text1_diff_tokens_for_demo.append([[text1_tokens_ids[i]], text1_diff_tokens_marks[i]])
                text1_last_token_category = text1_diff_tokens_marks[i]
        for i in range(1, len(text2_diff_tokens_marks)):
            if text2_diff_tokens_marks[i] == text2_last_token_category:
                text2_diff_tokens_for_demo[-1][0].append(text2_tokens_ids[i])
            else:
                text2_diff_tokens_for_demo.append([[text2_tokens_ids[i]], text2_diff_tokens_marks[i]])
                text2_last_token_category = text2_diff_tokens_marks[i]

        # Decode each group: 1 -> '-'/'+', 2 -> '#' (marked only on the other side)
        for i in range(len(text1_diff_tokens_for_demo)):
            text1_diff_tokens_for_demo[i] = (
                tokenizer.decode(text1_diff_tokens_for_demo[i][0]),
                '-' if text1_diff_tokens_for_demo[i][1] == 1 else ('#' if text1_diff_tokens_for_demo[i][1] == 2 else None))
        for i in range(len(text2_diff_tokens_for_demo)):
            text2_diff_tokens_for_demo[i] = (
                tokenizer.decode(text2_diff_tokens_for_demo[i][0]),
                '+' if text2_diff_tokens_for_demo[i][1] == 1 else ('#' if text2_diff_tokens_for_demo[i][1] == 2 else None))
        return text1_diff_tokens_for_demo, text2_diff_tokens_for_demo
    elif mode == 'advanced':
        raise NotImplementedError("This mode is not implemented yet")
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # 'naive' mode: decode the token IDs of each segment
    for i in range(len(text1_diff_tokens)):
        text1_diff_tokens[i] = (tokenizer.decode(text1_diff_tokens[i][0]), text1_diff_tokens[i][1])
    for i in range(len(text2_diff_tokens)):
        text2_diff_tokens[i] = (tokenizer.decode(text2_diff_tokens[i][0]), text2_diff_tokens[i][1])
    return text1_diff_tokens, text2_diff_tokens

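# A small smoke test, assuming the Meta-Llama-3-8B-Instruct tokenizer loaded
# at module import is available (the model is gated on the Hugging Face Hub,
# so this requires accepting the license and authenticating). The example
# strings are arbitrary; the printed spans depend on tokenizer segmentation.
if __name__ == "__main__":
    a = "The quick brown fox jumps over the lazy dog."
    b = "The quick red fox leaps over the lazy dog."
    print(diff_texts(a, b))
    print(get_diff_labels_for_demo(a, b, mode='naive'))
    print(get_diff_labels_for_demo(a, b, mode='or'))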