Maple728 commited on
Commit
e1a0d38
1 Parent(s): b4a2d57

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +4 -1
ts_generation_mixin.py CHANGED
@@ -28,6 +28,8 @@ class TSGenerationMixin(GenerationMixin):
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
 
 
31
  if len(input_ids.shape) == 2:
32
  batch_size, cur_len = input_ids.shape
33
  else:
@@ -169,6 +171,7 @@ class TSGenerationMixin(GenerationMixin):
169
  if streamer is not None:
170
  streamer.end()
171
 
 
172
  if return_dict_in_generate:
173
  if self.config.is_encoder_decoder:
174
  return GenerateEncoderDecoderOutput(
@@ -192,7 +195,7 @@ class TSGenerationMixin(GenerationMixin):
192
  past_key_values=model_kwargs.get("past_key_values"),
193
  )
194
  else:
195
- return input_ids.squeeze(dim=-1)
196
 
197
  def _update_model_kwargs_for_generation(
198
  self,
 
28
  streamer: Optional["BaseStreamer"] = None,
29
  **model_kwargs,
30
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
31
+ input_ids_origin_device = input_ids.device
32
+ input_ids = input_ids.to(self.device)
33
  if len(input_ids.shape) == 2:
34
  batch_size, cur_len = input_ids.shape
35
  else:
 
171
  if streamer is not None:
172
  streamer.end()
173
 
174
+ input_ids.squeeze_(dim=-1).to(input_ids_origin_device)
175
  if return_dict_in_generate:
176
  if self.config.is_encoder_decoder:
177
  return GenerateEncoderDecoderOutput(
 
195
  past_key_values=model_kwargs.get("past_key_values"),
196
  )
197
  else:
198
+ return input_ids
199
 
200
  def _update_model_kwargs_for_generation(
201
  self,