ydshieh commited on
Commit
e5b3a97
1 Parent(s): 87485e5

remove unused scripts

Browse files
Files changed (1) hide show
  1. tests/test_vit_gpt2.py +0 -83
tests/test_vit_gpt2.py DELETED
@@ -1,83 +0,0 @@
1
- import sys, os
2
-
3
- current_path = os.path.dirname(os.path.abspath(__file__))
4
- sys.path.append(current_path)
5
-
6
- # Vit - as encoder
7
- from transformers import ViTFeatureExtractor
8
- from PIL import Image
9
- import requests
10
- import numpy as np
11
-
12
- url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
13
- image = Image.open(requests.get(url, stream=True).raw)
14
-
15
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
16
- encoder_inputs = feature_extractor(images=image, return_tensors="jax")
17
- pixel_values = encoder_inputs.pixel_values
18
-
19
- # GPT2 / GPT2LM - as decoder
20
- from transformers import ViTFeatureExtractor, GPT2Tokenizer
21
-
22
- name = 'asi/gpt-fr-cased-small'
23
- tokenizer = GPT2Tokenizer.from_pretrained(name)
24
- decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
25
-
26
- inputs = dict(decoder_inputs)
27
- inputs['pixel_values'] = pixel_values
28
- print(inputs)
29
-
30
- # With new added LM head
31
- from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration
32
- flax_vit_gpt2 = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
33
- 'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
34
- )
35
- logits = flax_vit_gpt2(**inputs)[0]
36
- preds = np.argmax(logits, axis=-1)
37
- print('=' * 60)
38
- print('Flax: Vit + modified GPT2 + LM')
39
- print(preds)
40
-
41
- del flax_vit_gpt2
42
-
43
- # With the LM head in GPT2LM
44
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
45
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
46
- 'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
47
- )
48
-
49
- logits = flax_vit_gpt2_lm(**inputs)[0]
50
- preds = np.argmax(logits, axis=-1)
51
- print('=' * 60)
52
- print('Flax: Vit + modified GPT2LM')
53
- print(preds)
54
-
55
- del flax_vit_gpt2_lm
56
-
57
- # With PyTorch [Vit + unmodified GPT2LMHeadModel]
58
- import torch
59
- from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
60
-
61
- vit_model_pt = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
62
- encoder_inputs = feature_extractor(images=image, return_tensors="pt")
63
- vit_outputs = vit_model_pt(**encoder_inputs)
64
- vit_last_hidden_states = vit_outputs.last_hidden_state
65
-
66
- del vit_model_pt
67
-
68
- inputs_pt = tokenizer("mon chien est mignon", return_tensors="pt")
69
- inputs_pt = dict(inputs_pt)
70
- inputs_pt['encoder_hidden_states'] = vit_last_hidden_states
71
-
72
- config = GPT2Config.from_pretrained('asi/gpt-fr-cased-small')
73
- config.add_cross_attention = True
74
- gpt2_model_pt = GPT2LMHeadModel.from_pretrained('asi/gpt-fr-cased-small', config=config)
75
-
76
- gp2lm_outputs = gpt2_model_pt(**inputs_pt)
77
- logits_pt = gp2lm_outputs.logits
78
- preds_pt = torch.argmax(logits_pt, dim=-1).cpu().detach().numpy()
79
- print('=' * 60)
80
- print('Pytorch: Vit + unmodified GPT2LM')
81
- print(preds_pt)
82
-
83
- del gpt2_model_pt