File size: 4,395 Bytes
0c19cc1
4e70ef0
0c19cc1
4e70ef0
30085b1
0c19cc1
4e70ef0
7a8d779
0c19cc1
4e70ef0
0c19cc1
4e70ef0
0c19cc1
4e70ef0
 
 
 
b5c347a
4e70ef0
 
 
b5c347a
4e70ef0
 
 
 
 
 
 
 
 
cb6599c
4e70ef0
 
 
cb6599c
4e70ef0
 
b5c347a
 
 
 
4e70ef0
0c19cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c569a26
 
4e70ef0
 
 
 
 
30085b1
4e70ef0
 
 
 
 
 
 
 
0c19cc1
4e70ef0
0c19cc1
4e70ef0
0c19cc1
538b43a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import io
import os
import re
import subprocess
import sys

from flask import Flask, request, jsonify, render_template_string
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Hugging Face model id used for both the model and its processor.
MODEL_ID = 'gokaygokay/Florence-2-SD3-Captioner'

# Install the necessary packages into the interpreter running this script.
# Use `python -m pip` with an argument list (no shell), so the install cannot
# land in a different environment than the one serving the app.
# The child inherits the current environment (PATH, HOME, proxy/cache vars).
# Passing only the flash-attn variable would wipe the rest.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's
# setup to skip compiling its CUDA extension. Confirm against the flash-attn
# install docs.
# Best-effort, as before: a failed install is not raised here.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', 'einops', 'flask'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    check=False,
)

app = Flask(__name__)

# trust_remote_code=True executes Python shipped with the model repository.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Filler phrases stripped from generated captions, paired with replacement text.
_CAPTION_PHRASE_REPLACEMENTS = (
    ('captured from ', ''),
    ('captured at ', ''),
)
# Lower-cased phrase -> replacement. A case-insensitive match is lower-cased
# before lookup, so it resolves to the right entry.
_CAPTION_REPLACERS = {
    phrase.lower(): replacement for phrase, replacement in _CAPTION_PHRASE_REPLACEMENTS
}
# Built once at import time instead of on every call; alternation order
# follows _CAPTION_PHRASE_REPLACEMENTS.
_CAPTION_PATTERN = re.compile(
    '|'.join(re.escape(phrase) for phrase, _ in _CAPTION_PHRASE_REPLACEMENTS),
    flags=re.IGNORECASE,
)


def modify_caption(caption: str) -> str:
    """
    Remove the first occurrence of a known filler phrase from a caption.

    The phrases ("captured from ", "captured at ") are matched
    case-insensitively anywhere in the caption, not only at its start.
    Only the first match is removed.

    Args:
        caption (str): A string containing a caption.
    Returns:
        str: The caption with the first matched phrase removed, or the
        original caption unchanged if no phrase matches.
    """
    def replace_fn(match: re.Match) -> str:
        # Map the matched text (in whatever case it appeared) to its replacement.
        return _CAPTION_REPLACERS[match.group(0).lower()]

    # re.sub already returns the input unchanged when nothing matches.
    return _CAPTION_PATTERN.sub(replace_fn, caption, count=1)
    
@app.route('/')
def index():
    """
    Serve the single-page upload UI.

    The page posts the chosen file as multipart field "image" to /generate.
    It shows the returned caption, or the error message when the request
    fails or no file was chosen.
    """
    html = '''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Florence-2 SD3 Long Captioner</title>
        <style>
            #output {
                height: 500px; 
                overflow: auto; 
                border: 1px solid #ccc; 
            }
        </style>
    </head>
    <body>
        <h1>Florence-2 SD3 Long Captioner</h1>
        <p>Florence-2 Base fine-tuned on Long SD3 Prompt and Image pairs. Check the Hugging Face link for datasets that are used for fine-tuning.</p>
        <form id="uploadForm">
            <label for="imageInput">Input Picture</label>
            <input type="file" id="imageInput" name="image" accept="image/*">
            <button type="submit">Submit</button>
        </form>
        <div id="output">
            <h3>Output Text</h3>
            <p id="outputText"></p>
        </div>
        <script>
            document.getElementById('uploadForm').onsubmit = async function(event) {
                event.preventDefault();
                const output = document.getElementById('outputText');
                const imageFile = document.getElementById('imageInput').files[0];
                // Without this guard FormData would send the string "undefined".
                if (!imageFile) {
                    output.innerText = 'Please choose an image first.';
                    return;
                }
                const formData = new FormData();
                formData.append('image', imageFile);
                output.innerText = 'Generating caption...';

                try {
                    const response = await fetch('/generate', {
                        method: 'POST',
                        body: formData
                    });
                    // A non-JSON body (e.g. an HTML 500 page) throws here and is
                    // reported by the catch below.
                    const data = await response.json();
                    output.innerText = response.ok
                        ? data.caption
                        : 'Error: ' + (data.error || response.statusText);
                } catch (err) {
                    output.innerText = 'Error: ' + err;
                }
            };
        </script>
    </body>
    </html>
    '''
    # The markup contains no Jinja placeholders; it is rendered as-is.
    return render_template_string(html)

@app.route('/generate', methods=['POST'])
def generate():
    """
    Caption an uploaded image with the Florence-2 model.

    Expects a multipart/form-data POST with the image file in field "image".

    Returns:
        200 with {"caption": str} on success.
        400 with {"error": str} when no file was sent, or when the file
        cannot be decoded as an image.
    """
    image_file = request.files.get('image')
    # A browser form submitted with no file chosen sends a part with an empty filename.
    if image_file is None or image_file.filename == '':
        return jsonify({"error": "No image provided"}), 400

    try:
        image = Image.open(image_file.stream)
        # Image.open is lazy. Force a full decode here, so corrupt or truncated
        # uploads become a 400 instead of an unhandled 500 later on.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL's UnidentifiedImageError and truncated-file errors.
        return jsonify({"error": "Uploaded file is not a readable image"}), 400

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    task_prompt = "<DESCRIPTION>"
    prompt = task_prompt + "Describe this image in great detail."

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    # Special tokens are kept in the decoded text.
    # NOTE(review): presumably post_process_generation needs them to parse the
    # task output, as in the Florence-2 usage example. Confirm.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    # The parsed result is keyed by the task token used in the prompt.
    caption = modify_caption(parsed_answer[task_prompt])

    return jsonify({"caption": caption})

if __name__ == '__main__':
    # The server binds to 0.0.0.0, i.e. every network interface. With debug=True,
    # the Werkzeug interactive debugger would let anyone who can reach the port
    # run arbitrary code. Debug mode (and its auto-reloader, which re-executes
    # this module, pip install and model load included, in a child process) is
    # therefore opt-in via FLASK_DEBUG=1/true/yes.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug, port=7860, host="0.0.0.0")