ydshieh
commited on
Commit
•
f082d66
1
Parent(s):
b8c22f0
update test_model.py
Browse files- tests/test_model.py +6 -1
tests/test_model.py
CHANGED
@@ -25,9 +25,12 @@ max_length = 8
|
|
25 |
vision_model_name = 'google/vit-base-patch16-224-in21k'
|
26 |
text_model_name = 'asi/gpt-fr-cased-small'
|
27 |
|
|
|
|
|
28 |
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
|
29 |
vision_pretrained_model_name_or_path=vision_model_name,
|
30 |
-
text_pretrained_model_name_or_path=text_model_name
|
|
|
31 |
)
|
32 |
model = flax_vit_gpt2_lm
|
33 |
|
@@ -103,6 +106,8 @@ print('=' * 60)
|
|
103 |
print(f'orig. GPT2 generated caption: {orig_caption}')
|
104 |
print(f'GPT2 generated caption: {caption}')
|
105 |
|
|
|
|
|
106 |
# model data
|
107 |
model_inputs = {
|
108 |
'pixel_values': pixel_values,
|
|
|
25 |
vision_model_name = 'google/vit-base-patch16-224-in21k'
|
26 |
text_model_name = 'asi/gpt-fr-cased-small'
|
27 |
|
28 |
+
project_encoder = False
|
29 |
+
|
30 |
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
|
31 |
vision_pretrained_model_name_or_path=vision_model_name,
|
32 |
+
text_pretrained_model_name_or_path=text_model_name,
|
33 |
+
project_encoder=project_encoder
|
34 |
)
|
35 |
model = flax_vit_gpt2_lm
|
36 |
|
|
|
106 |
print(f'orig. GPT2 generated caption: {orig_caption}')
|
107 |
print(f'GPT2 generated caption: {caption}')
|
108 |
|
109 |
+
assert list(orig_token_ids) == list(token_ids)
|
110 |
+
|
111 |
# model data
|
112 |
model_inputs = {
|
113 |
'pixel_values': pixel_values,
|