Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -20,7 +20,7 @@ class PreTrainedPipeline():
|
|
20 |
max_length = 16
|
21 |
num_beams = 4
|
22 |
# self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
23 |
-
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, return_dict_in_generate
|
24 |
|
25 |
self.model.to("cpu")
|
26 |
self.model.eval()
|
|
|
20 |
max_length = 16
|
21 |
num_beams = 4
|
22 |
# self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
23 |
+
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "return_dict_in_generate": True}
|
24 |
|
25 |
self.model.to("cpu")
|
26 |
self.model.eval()
|