Joshua Lochner committed
Commit: c2ccf6d
Parent: 063e3aa

Improve segmentation logic

Files changed (1): src/segment.py (+13 -8)
src/segment.py CHANGED
@@ -99,25 +99,30 @@ def generate_segments(words, tokenizer, segmentation_args):
         current_segment_num_tokens = 0
         current_segment = []
         for word in segment:
-            if current_segment_num_tokens + word['num_tokens'] < max_q_size:
-                # Can add tokens to current segment
-                current_segment.append(word)
-                current_segment_num_tokens += word['num_tokens']
-            else:
+            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
+            if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
                 second_pass_segments.append(current_segment.copy())
 
-                current_segment.append(word)
-                current_segment_num_tokens += word['num_tokens']
+            # Add tokens to current segment
+            current_segment.append(word)
+            current_segment_num_tokens += word['num_tokens']
 
+            if new_seg:
+                # Just created a new segment, so we remove until we only have buffer_size tokens
                 while current_segment_num_tokens > buffer_size and current_segment:
                     first_word = current_segment.pop(0)
                     current_segment_num_tokens -= first_word['num_tokens']
 
-        if current_segment:
+        if current_segment: # Add remaining segment
             second_pass_segments.append(current_segment.copy())
 
+    # Cleaning up, delete 'num_tokens' from each word
+    for segment in second_pass_segments:
+        for word in segment:
+            word.pop('num_tokens', None)
+
     return second_pass_segments
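Note: the hunk above shows only the second pass of generate_segments; the first-pass segments, max_q_size, buffer_size, and the enclosing loop come from code outside this diff. As a rough guide to what the commit does, here is a minimal, self-contained sketch of the post-commit logic. The wrapper function name, its signature, and the toy one-token input are assumptions for illustration, not the repository's actual API; only the loop body mirrors the hunk.

# Minimal sketch of the second-pass segmentation logic after this commit.
# ASSUMPTION: the wrapper name/signature and the toy input below are
# illustrative; only the loop body mirrors the hunk above.
def second_pass(first_pass_segments, max_q_size, buffer_size):
    second_pass_segments = []
    for segment in first_pass_segments:
        current_segment_num_tokens = 0
        current_segment = []
        for word in segment:
            # Would adding this word push the segment past the token budget?
            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
            if new_seg:
                # Save the full batch before starting a new one
                second_pass_segments.append(current_segment.copy())

            # Always add the word to the (possibly new) current segment
            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if new_seg:
                # Trim the new segment down to at most buffer_size tokens,
                # keeping a small overlap with the previous segment
                while current_segment_num_tokens > buffer_size and current_segment:
                    first_word = current_segment.pop(0)
                    current_segment_num_tokens -= first_word['num_tokens']

        if current_segment: # Add remaining segment
            second_pass_segments.append(current_segment.copy())

    # Cleaning up, delete 'num_tokens' from each word
    # (pop with a default, since shallow copies share word dicts)
    for segment in second_pass_segments:
        for word in segment:
            word.pop('num_tokens', None)

    return second_pass_segments

if __name__ == '__main__':
    # Toy input: seven one-token words, budget of 4 tokens, overlap buffer of 2
    words = [{'text': w, 'num_tokens': 1} for w in 'a b c d e f g'.split()]
    for seg in second_pass([words], max_q_size=4, buffer_size=2):
        print([w['text'] for w in seg])
    # Prints: ['a', 'b', 'c'], then ['c', 'd', 'e'], then ['e', 'f', 'g']

Behaviour-wise, the loop restructuring is equivalent to the old if/else: the condition is simply negated (>= max_q_size instead of < max_q_size) and the shared append is hoisted out of both branches. The functional change in this commit is the final cleanup pass, which strips the temporary 'num_tokens' bookkeeping key from every word before the segments are returned.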