Update ts_generation_mixin.py
ts_generation_mixin.py  +4 -1  CHANGED
@@ -28,6 +28,8 @@ class TSGenerationMixin(GenerationMixin):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+        input_ids_origin_device = input_ids.device
+        input_ids = input_ids.to(self.device)
         if len(input_ids.shape) == 2:
             batch_size, cur_len = input_ids.shape
         else:
@@ -169,6 +171,7 @@ class TSGenerationMixin(GenerationMixin):
         if streamer is not None:
             streamer.end()
 
+        input_ids.squeeze_(dim=-1).to(input_ids_origin_device)
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(
@@ -192,7 +195,7 @@ class TSGenerationMixin(GenerationMixin):
                 past_key_values=model_kwargs.get("past_key_values"),
             )
         else:
-            return input_ids
+            return input_ids
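
For context, a minimal sketch of the device round-trip this commit introduces around generation: the input tensor is moved onto the model's device before decoding and restored to the caller's device afterwards. This is not code from the repository; the function name and arguments below are hypothetical. One detail worth keeping in mind: torch.Tensor.to is not an in-place operation, so its result has to be assigned for the final move back to the original device to take effect.

    import torch

    def run_on_model_device(model_device: torch.device, input_ids: torch.Tensor) -> torch.Tensor:
        # Remember where the caller's tensor lives, then move it to the model's device.
        origin_device = input_ids.device
        input_ids = input_ids.to(model_device)

        # ... the decoding loop would append generated steps to input_ids here ...

        # Drop a trailing singleton dimension in place; Tensor.to returns a new
        # tensor, so its result must be kept for the move back to take effect.
        return input_ids.squeeze_(dim=-1).to(origin_device)

Usage sketch: run_on_model_device(torch.device("cpu"), torch.randn(2, 16, 1)) returns a tensor of shape (2, 16) on the CPU, regardless of which device the model runs on.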