ydshieh committed commit ec3ceb6
1 Parent(s): 7d3b1a0
Update test_model.py
Files changed: tests/test_model.py (+29 -20)
tests/test_model.py CHANGED
@@ -95,7 +95,7 @@ print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
 orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
 gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)

-#
+# generation!
 num_beams = 1
 gen_kwargs = {"max_length": 6, "num_beams": num_beams}

@@ -138,20 +138,27 @@ logits = model_outputs[0]
 preds = np.argmax(logits, axis=-1)

 print('=' * 60)
-print('Flax
-print('predicted token ids:')
+print('Flax ViT-GPT2-LM - predicted token ids:')
 print(preds)

-
-
-
-
-
+encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
+print('=' * 60)
+print("encoder_last_hidden_state given by model.__call__():")
+print(encoder_last_hidden_state)
+
+encoder_outputs = model.encode(pixel_values, return_dict=True)
+print('=' * 60)
+print("encoder's last_hidden_state given by model.encode():")
+print(encoder_outputs['last_hidden_state'])
+
+total_diff = np.sum(np.abs(encoder_outputs['last_hidden_state'] - encoder_last_hidden_state))
+print('=' * 60)
+print(f"total difference: {total_diff}")

 # ================================================================================
-# Check generation
+# Check model generation

-#
+# generation
 num_beams = 1
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

@@ -215,17 +222,19 @@ logits = text_model_pt_outputs[0]
 preds = np.argmax(logits.detach().numpy(), axis=-1)

 print('=' * 60)
-print('PyTroch:
+print('PyTroch: ViT --> GPT2-LM')
 print('predicted token ids:')
 print(preds)

-
-
+model_logits = np.array(model_outputs.logits)
+text_model_pt_logits = text_model_pt_outputs.logits.detach().cpu().numpy()
+total_diff = np.sum(np.abs(model_logits - text_model_pt_logits))

-
-
-
-
-
-
-
+print('=' * 60)
+print("model_logits:")
+print(model_logits)
+print('=' * 60)
+print("text_model_pt_logits:")
+print(text_model_pt_logits)
+print('=' * 60)
+print(f"total difference between logits: {total_diff}")