ifmain committed
Commit 3474710
1 Parent(s): d24e86c

Update README.md

Files changed (1)
  1. README.md +33 -29
README.md CHANGED
@@ -21,48 +21,52 @@ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
 import torch
 from PIL import Image
 import re
+import requests
 
 def prepare(text):
-    text = text.replace('. ', '.').replace(' .', '.')
-    text = text.replace('( ', '(').replace(' (', '(')
-    text = text.replace(') ', ')').replace(' )', ')')
-    text = text.replace(': ', ':').replace(' :', ':')
-    text = text.replace('_ ', '_').replace(' _', '_')
-    text = text.replace(',(())', '').replace('(()),', '')
-    for i in range(10):
-        text = text.replace(')))', '))').replace('(((', '((')
     text = re.sub(r'<[^>]*>', '', text)
+    text = ','.join(list(set(text.split(',')))[:-1])
+    for i in range(5):
+        if text[0]==',' or text[0]==' ':
+            text=text[1:]
+
     return text
 
-model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+path_to_model = "ifmain/vit-gpt2-image2promt-stable-diffusion"
+
+model = VisionEncoderDecoderModel.from_pretrained(path_to_model)
+feature_extractor = ViTImageProcessor.from_pretrained(path_to_model)
+tokenizer = AutoTokenizer.from_pretrained(path_to_model)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-max_length = 16
+max_length = 256
 num_beams = 4
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-def predict_step(image_paths):
-    images = []
-    for image_path in image_paths:
-        i_image = Image.open(image_path)
-        if i_image.mode != "RGB":
-            i_image = i_image.convert(mode="RGB")
-
-        images.append(i_image)
 
-    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-    pixel_values = pixel_values.to(device)
 
-    output_ids = model.generate(pixel_values, **gen_kwargs)
-
-    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-    preds = [prepare(pred).strip() for pred in preds]
-    return preds
-
-predict_step(['doctor.e16ba4e4.jpg']) # ['a woman in a hospital bed with a woman in a hospital bed']
+def predict_step(image_paths):
+    images = []
+    for image_path in image_paths:
+        if 'http' in image_path:
+            i_image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
+        else:
+            i_image = Image.open(image_path).convert('RGB')
+        images.append(i_image)
+
+    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
+    pixel_values = pixel_values.to(device)
+
+    output_ids = model.generate(pixel_values, **gen_kwargs)
+
+    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [prepare(pred).strip() for pred in preds]
+    return preds
+
+img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
+result = predict_step([img_url]) # ['red shirt, chromatic aberration, light emitting object, barefoot, best quality, ocean background, 1girl, 8k wallpaper, intricate details, chromatic light, light, ocean, backpack, ultra-detailed, ocean light,masterpiece']
+print(result)
 ```
 
 ## Additional Information
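
Note: the rewritten `prepare()` no longer patches spacing and parentheses in captions; it strips HTML-like tags, deduplicates the comma-separated prompt tags, and trims leading commas and spaces. A minimal standalone sketch of that behavior, using the committed logic with a hypothetical sample string (note that `set()` makes the resulting tag order non-deterministic):

```python
import re

# Same post-processing logic as the committed prepare(), for illustration only.
def prepare(text):
    text = re.sub(r'<[^>]*>', '', text)               # strip HTML-like tags
    text = ','.join(list(set(text.split(',')))[:-1])  # dedupe tags; order is arbitrary, one entry is dropped
    for i in range(5):                                # trim up to 5 leading commas/spaces
        if text[0] == ',' or text[0] == ' ':
            text = text[1:]
    return text

raw = '1girl, ocean, best quality, ocean, 8k wallpaper'  # hypothetical raw model output
print(prepare(raw))  # e.g. '1girl, ocean, best quality' (tag order varies per run; one tag is dropped)
```

Because `list(set(...))[:-1]` discards an arbitrary element and loses the model's tag order, output differs between runs; an order-preserving dedupe such as `dict.fromkeys(text.split(','))` would avoid this if determinism matters.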