Crystalcareai committed
Commit 835534a
1 Parent(s): f2459a7

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py  +28 -2
modeling_quiet.py CHANGED
@@ -22,6 +22,7 @@ import inspect
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
+from dataclasses import dataclass
 
 import torch
 import torch.nn.functional as F
@@ -56,6 +57,31 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
+@dataclass
+class ModelOutput:
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+    """
+
+    def to_tuple(self):
+        """
+        Convert the output to a tuple.
+        """
+        return tuple(self[k] for k in self.keys())
+
+@dataclass
+class BaseModelOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+@dataclass
+class QuietModelOutputWithPast(BaseModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    logits: torch.FloatTensor = None
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -1097,7 +1123,7 @@ class QuietModel(QuietPreTrainedModel):
 
         if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+        return QuietModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
@@ -1216,7 +1242,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
        )
 
        hidden_states = outputs.last_hidden_state
-       base_logits = outputs.logits  # Use the logits from the model output
+       base_logits = outputs.logits
 
        thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
        thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
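
For reference, a minimal usage sketch (not part of the commit): it assumes the QuietModelOutputWithPast dataclass added above can be imported from modeling_quiet.py, and the tensor shapes are illustrative only. It shows that, with return_dict=True, QuietModel.forward now returns an object whose fields, including the new logits slot, are read by attribute, which is what the base_logits = outputs.logits line in QuietForCausalLM relies on.

# Illustrative sketch only; assumes the dataclass from this commit is importable.
import torch
from modeling_quiet import QuietModelOutputWithPast  # hypothetical import path

# Shapes are made up for the example:
# (batch, seq_len, hidden_size) and (batch, seq_len, vocab_size).
hidden_states = torch.randn(1, 5, 32)
logits = torch.randn(1, 5, 100)

outputs = QuietModelOutputWithPast(
    last_hidden_state=hidden_states,
    past_key_values=None,      # populated only when use_cache=True
    hidden_states=None,        # populated only when output_hidden_states=True
    attentions=None,           # populated only when output_attentions=True
    logits=logits,             # the extra field this commit introduces
)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 5, 32])
print(outputs.logits.shape)             # torch.Size([1, 5, 100])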