mtensor
commited on
Commit
•
3856d86
1
Parent(s):
8c4b5c5
only decode the end of the generation
Browse files
README.md
CHANGED
@@ -64,8 +64,8 @@ for k, v in model_inputs.items():
|
|
64 |
model_inputs[k] = v.to("cuda:0")
|
65 |
|
66 |
generation_output = model.generate(**model_inputs, max_new_tokens=7)
|
67 |
-
generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)
|
68 |
-
assert generation_text ==
|
69 |
```
|
70 |
|
71 |
Fuyu can also perform some question answering on natural images and charts/diagrams (thought fine-tuning may be required for good performance):
|
@@ -79,8 +79,8 @@ for k, v in model_inputs.items():
|
|
79 |
model_inputs[k] = v.to("cuda:0")
|
80 |
|
81 |
generation_output = model.generate(**model_inputs, max_new_tokens=6)
|
82 |
-
generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)
|
83 |
-
assert generation_text == "The bus is blue.\n"
|
84 |
|
85 |
|
86 |
text_prompt = "What is the highest life expectancy at birth of male?\n"
|
@@ -92,8 +92,8 @@ for k, v in model_inputs.items():
|
|
92 |
model_inputs[k] = v.to("cuda:0")
|
93 |
|
94 |
generation_output = model.generate(**model_inputs, max_new_tokens=16)
|
95 |
-
generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)
|
96 |
-
assert generation_text == "The life expectancy at birth of males in 2018 is 80.7.\n"
|
97 |
```
|
98 |
|
99 |
## Uses
|
|
|
64 |
model_inputs[k] = v.to("cuda:0")
|
65 |
|
66 |
generation_output = model.generate(**model_inputs, max_new_tokens=7)
|
67 |
+
generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
|
68 |
+
assert generation_text == ['A bus parked on the side of a road.']
|
69 |
```
|
70 |
|
71 |
Fuyu can also perform some question answering on natural images and charts/diagrams (thought fine-tuning may be required for good performance):
|
|
|
79 |
model_inputs[k] = v.to("cuda:0")
|
80 |
|
81 |
generation_output = model.generate(**model_inputs, max_new_tokens=6)
|
82 |
+
generation_text = processor.batch_decode(generation_output[:, -6:], skip_special_tokens=True)
|
83 |
+
assert generation_text == ["The bus is blue.\n"]
|
84 |
|
85 |
|
86 |
text_prompt = "What is the highest life expectancy at birth of male?\n"
|
|
|
92 |
model_inputs[k] = v.to("cuda:0")
|
93 |
|
94 |
generation_output = model.generate(**model_inputs, max_new_tokens=16)
|
95 |
+
generation_text = processor.batch_decode(generation_output[:, -16:], skip_special_tokens=True)
|
96 |
+
assert generation_text == ["The life expectancy at birth of males in 2018 is 80.7.\n"]
|
97 |
```
|
98 |
|
99 |
## Uses
|