yuhuili commited on
Commit
c2dc54b
1 Parent(s): f2ce589

Update model/cnets.py

Browse files
Files changed (1) hide show
  1. model/cnets.py +1 -1
model/cnets.py CHANGED
@@ -674,7 +674,7 @@ class Model(nn.Module):
674
 
675
  @torch.no_grad()
676
  def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
677
-
678
  input_ids = input_ids.to(hidden_states.device)
679
  total_tokens = self.total_tokens
680
  depth = self.depth
 
674
 
675
  @torch.no_grad()
676
  def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
677
+ self.init_tree()
678
  input_ids = input_ids.to(hidden_states.device)
679
  total_tokens = self.total_tokens
680
  depth = self.depth