JGKaaij commited on
Commit
ad9b42d
1 Parent(s): ada6118

Upload flask_kosmos2.py

Browse files

Added requested comments.

Files changed (1) hide show
  1. flask_kosmos2.py +50 -0
flask_kosmos2.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a Flask app to serve the model as a REST API.
2
+ # After starting the server. You can send a POST request to `http://localhost:8005/process_prompt` with the following form data:
3
+ # - `prompt`: For example, `<grounding> an image of`
4
+ # - 'image': The image file as binary data
5
+ # This will produce a reply with the following JSON format:
6
+ # - `message`: The Kosmos-2 generated text
7
+ # - `entities`: The extracted entities
8
+ # An easy way to test this is through an application like Postman. Make sure the image field is set to `File`.
9
+
10
+ from PIL import Image
11
+ from transformers import AutoProcessor, AutoModelForVision2Seq
12
+ from flask import Flask, request, jsonify
13
+
14
+ app = Flask(__name__)
15
+
16
+ model = AutoModelForVision2Seq.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
17
+ processor = AutoProcessor.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
18
+
19
+
20
+ @app.route('/process_prompt', methods=['POST'])
21
+ def process_prompt():
22
+ try:
23
+ # Get the uploaded image data from the POST request
24
+ uploaded_file = request.files['image']
25
+ prompt = request.form.get('prompt')
26
+ image = Image.open(uploaded_file.stream)
27
+
28
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
29
+
30
+ generated_ids = model.generate(
31
+ pixel_values=inputs["pixel_values"],
32
+ input_ids=inputs["input_ids"][:, :-1],
33
+ attention_mask=inputs["attention_mask"][:, :-1],
34
+ img_features=None,
35
+ img_attn_mask=inputs["img_attn_mask"][:, :-1],
36
+ use_cache=True,
37
+ max_new_tokens=64,
38
+ )
39
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
40
+
41
+ # By default, the generated text is cleanup and the entities are extracted.
42
+ processed_text, entities = processor.post_process_generation(generated_text)
43
+
44
+ return jsonify({"message": processed_text, 'entities': entities})
45
+ except Exception as e:
46
+ return jsonify({"error": str(e)})
47
+
48
+
49
+ if __name__ == '__main__':
50
+ app.run(host='localhost', port=8005)