ydshieh commited on
Commit
f082d66
1 Parent(s): b8c22f0

update test_model.py

Browse files
Files changed (1) hide show
  1. tests/test_model.py +6 -1
tests/test_model.py CHANGED
@@ -25,9 +25,12 @@ max_length = 8
25
  vision_model_name = 'google/vit-base-patch16-224-in21k'
26
  text_model_name = 'asi/gpt-fr-cased-small'
27
 
 
 
28
  flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
29
  vision_pretrained_model_name_or_path=vision_model_name,
30
- text_pretrained_model_name_or_path=text_model_name
 
31
  )
32
  model = flax_vit_gpt2_lm
33
 
@@ -103,6 +106,8 @@ print('=' * 60)
103
  print(f'orig. GPT2 generated caption: {orig_caption}')
104
  print(f'GPT2 generated caption: {caption}')
105
 
 
 
106
  # model data
107
  model_inputs = {
108
  'pixel_values': pixel_values,
 
25
  vision_model_name = 'google/vit-base-patch16-224-in21k'
26
  text_model_name = 'asi/gpt-fr-cased-small'
27
 
28
+ project_encoder = False
29
+
30
  flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
31
  vision_pretrained_model_name_or_path=vision_model_name,
32
+ text_pretrained_model_name_or_path=text_model_name,
33
+ project_encoder=project_encoder
34
  )
35
  model = flax_vit_gpt2_lm
36
 
 
106
  print(f'orig. GPT2 generated caption: {orig_caption}')
107
  print(f'GPT2 generated caption: {caption}')
108
 
109
+ assert list(orig_token_ids) == list(token_ids)
110
+
111
  # model data
112
  model_inputs = {
113
  'pixel_values': pixel_values,