lengyue233 committed
Commit 75e9ff1
Parent: 1caffd8

optimize compile by removing if branch

Files changed (2):
  1. app.py +16 -6
  2. tools/llama/generate.py +42 -45
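The change is motivated by how torch.compile handles Python-level control flow: a branch such as "if top_p is not None and top_p < 1.0" becomes a guard on the argument, so toggling sampling options between calls can cause guard failures and recompilation, while a single branch-free path compiles once. Below is a minimal, self-contained sketch of that idea (illustrative code, not taken from this repository; the branch-free variant mirrors what this commit does to logits_to_probs):

import torch

def apply_top_p(logits, top_p):
    # Nucleus filtering, same operations as in logits_to_probs.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, -float("Inf"))

def probs_branchy(logits, top_p=None):
    # Python branch: torch.compile guards on top_p, so calls with and without
    # nucleus sampling are compiled (and cached) as separate specializations.
    if top_p is not None and top_p < 1.0:
        logits = apply_top_p(logits, top_p)
    return torch.softmax(logits, dim=-1)

def probs_branchless(logits, top_p):
    # Always-on path, as in the updated logits_to_probs: one compiled graph
    # serves every call, at the cost of always requiring a top_p value.
    return torch.softmax(apply_top_p(logits, top_p), dim=-1)

compiled = torch.compile(probs_branchless)
probs = compiled(torch.randn(4096), top_p=0.7)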
app.py CHANGED

@@ -1,4 +1,5 @@
 import os
+import queue
 from huggingface_hub import snapshot_download
 import hydra

@@ -125,17 +126,26 @@ def inference(
     )

     payload = dict(
-        event=threading.Event(),
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)

-    # Wait for the result
-    payload["event"].wait()
-    if payload["success"] is False:
-        raise payload["response"]
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue

-    codes = payload["response"][0]
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)

     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
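The loop above still treats the "next" sentinel as a TODO and simply concatenates every tensor it receives. A small consumer sketch (illustrative only, not part of this commit) showing how "next" could be used to group codes per sentence when the worker runs generate_long in streaming mode:

import queue

import torch

def collect_sentences(response_queue: queue.Queue, payload: dict) -> list:
    # Assumes the worker streams partial codes tensors, puts "next" after each
    # finished sentence, and puts "done" once the whole request has finished.
    sentences, current = [], []
    while True:
        result = response_queue.get()
        if result == "next":
            if current:
                sentences.append(torch.cat(current, dim=1))
                current = []
            continue
        if result == "done":
            if payload.get("success") is False:
                raise payload["response"]
            break
        current.append(result)  # a partial codes tensor
    if current:
        sentences.append(torch.cat(current, dim=1))
    return sentences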
tools/llama/generate.py CHANGED

@@ -47,32 +47,32 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    if previous_tokens is not None and repetition_penalty != 1.0:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+    # if previous_tokens is not None and repetition_penalty != 1.0:
+    previous_tokens = previous_tokens.long()
+    score = torch.gather(logits, dim=0, index=previous_tokens)
+    score = torch.where(
+        score < 0, score * repetition_penalty, score / repetition_penalty
+    )
+    logits.scatter_(dim=0, index=previous_tokens, src=score)

-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    # if top_p is not None and top_p < 1.0:
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(
+        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+    )
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

     logits = logits / max(temperature, 1e-5)

-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
+    # if top_k is not None:
+    #     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+    #     pivot = v.select(-1, -1).unsqueeze(-1)
+    #     logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
@@ -470,16 +470,14 @@ def generate_long(
     texts = split_text(text, chunk_length) if iterative_prompt else [text]

     if use_prompt:
-        encoded.append(
-            encode_tokens(
-                tokenizer,
-                prompt_text,
-                prompt_tokens=prompt_tokens,
-                bos=True,
-                device=device,
-                speaker=speaker,
-                num_codebooks=model.config.num_codebooks,
-            )
+        encoded_prompts = encode_tokens(
+            tokenizer,
+            prompt_text,
+            prompt_tokens=prompt_tokens,
+            bos=True,
+            device=device,
+            speaker=speaker,
+            num_codebooks=model.config.num_codebooks,
         )

     for idx, text in enumerate(texts):
@@ -501,10 +499,6 @@ def generate_long(
         all_codes = []
         seg_idx = 0

-        if use_prompt:
-            seg_idx = 1
-            global_encoded.append(encoded[0])
-
         while seg_idx < len(encoded):
             logger.info(
                 f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -531,6 +525,9 @@ def generate_long(
             else:
                 partial_encoded = global_encoded

+            if use_prompt:
+                partial_encoded = [encoded_prompts] + partial_encoded
+
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)

@@ -593,14 +590,13 @@ def generate_long(

         if is_streaming:
             # This indicates the end of the current sample
-            yield None
+            yield "next"
         else:
             all_codes = torch.cat(all_codes, dim=1)
             assert (all_codes >= 0).all(), f"Negative code found: {codes}"
             yield all_codes


-
 def launch_thread_safe_queue(
     config_name,
     checkpoint_path,
@@ -624,20 +620,21 @@ def launch_thread_safe_queue(
                 break

             kwargs = item["request"]
-            event = item["event"]
+            response_queue = item["response_queue"]

             try:
                 item["success"] = True
-                item["response"] = list(
-                    generate_long(
-                        model=model, decode_one_token=decode_one_token, **kwargs
-                    )
-                )
+                for chunk in generate_long(
+                    model=model, decode_one_token=decode_one_token, **kwargs
+                ):
+                    response_queue.put(chunk)
+
+                response_queue.put("done")
             except Exception as e:
                 item["success"] = False
                 item["response"] = e

-            event.set()
+                response_queue.put("done")

     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
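One consequence of removing the guards in logits_to_probs is that previous_tokens, top_p and repetition_penalty are no longer optional inside the compiled path, and the top_k branch is gone entirely. A hedged sketch (assumptions, not code from this commit) of neutral values a caller could pass to keep the old "disabled" behaviour without reintroducing branches:

import torch

# With these values the always-executed penalty and top-p code become no-ops:
# score * 1.0 == score / 1.0, and cum_probs > 1.0 only fires on floating-point
# rounding in the lowest-probability tail.
neutral_sampling = dict(
    temperature=1.0,
    repetition_penalty=1.0,
    top_p=1.0,
)

# previous_tokens must now always be a LongTensor; with repetition_penalty=1.0
# the gather/scatter over a dummy index leaves the logits unchanged.
previous_tokens = torch.zeros(1, dtype=torch.long)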