ydshieh committed
Commit bdb103e
1 Parent(s): f338d56

remove generate.py

Files changed (1)
  1. generate.py +0 -74
generate.py DELETED
@@ -1,74 +0,0 @@
- import sys, os
-
- current_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(current_path)
-
- # Main model - ViTGPT2LM
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
-
- # ViT - as encoder
- from transformers import ViTFeatureExtractor
- from PIL import Image
- import requests
- import numpy as np
-
- # GPT2 / GPT2LM - as decoder
- from transformers import ViTFeatureExtractor, GPT2Tokenizer
-
- model_name_or_path = './outputs/ckpt_2/'
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
-
- vit_model_name = 'google/vit-base-patch16-224-in21k'
- feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
-
- gpt2_model_name = 'asi/gpt-fr-cased-small'
- tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
-
- max_length = 32
- num_beams = 16
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-
- # encoder data
- url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
- image = Image.open(requests.get(url, stream=True).raw)
- # batch dim is added automatically
- encoder_inputs = feature_extractor(images=image, return_tensors="jax")
- pixel_values = encoder_inputs.pixel_values
- print(f'pixel_values.shape = {pixel_values.shape}')
-
- # decoder data
- sentence = 'mon chien est mignon'
- # IMPORTANT: For training/evaluation/attention_mask/loss
- sentence += ' ' + tokenizer.eos_token
- # batch dim is added automatically
- decoder_inputs = tokenizer(sentence, return_tensors="jax")
- print(decoder_inputs)
- print(f'input_ids.shape = {decoder_inputs.input_ids.shape}')
-
- # model data
- inputs = dict(decoder_inputs)
- inputs['pixel_values'] = pixel_values
-
-
- logits = flax_vit_gpt2_lm(**inputs)[0]
- preds = np.argmax(logits, axis=-1)
- print('=' * 60)
- print('Flax: ViT-GPT2-LM')
- print('predicted token ids:')
- print(preds)
- print('=' * 60)
-
-
- # Generation!
- batch = {'pixel_values': pixel_values}
- generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
- print('generation:')
- print(generation)
- print('=' * 60)
-
- token_ids = np.array(generation.sequences)[0]
- caption = tokenizer.decode(token_ids)
- print(f'token_ids: {token_ids}')
- print(f'caption: {caption}')
- print('=' * 60)
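
For reference, a minimal sketch of the same captioning flow using the upstream transformers FlaxVisionEncoderDecoderModel API instead of this repo's custom vit_gpt2 module. Whether the './outputs/ckpt_2/' checkpoint loads through that generic class is an assumption; the feature extractor and tokenizer names are simply the ones the removed script used.

import requests
import numpy as np
from PIL import Image
from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer

# Assumption: the checkpoint directory is compatible with the generic
# FlaxVisionEncoderDecoderModel class; the removed script used the custom
# FlaxViTGPT2LMForConditionalGeneration instead.
model = FlaxVisionEncoderDecoderModel.from_pretrained('./outputs/ckpt_2/')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = GPT2Tokenizer.from_pretrained('asi/gpt-fr-cased-small')

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# Beam-search generation; .sequences holds the generated token ids.
output = model.generate(pixel_values, max_length=32, num_beams=16)
token_ids = np.array(output.sequences)[0]
caption = tokenizer.decode(token_ids, skip_special_tokens=True)
print(f'caption: {caption}')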