|
import io
import os
import re
import subprocess
import sys

from flask import Flask, request, jsonify, render_template_string
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
|
|
|
|
|
# Best-effort install of runtime dependencies the remote model code needs
# (the return code is intentionally not checked, matching the original).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is read by flash-attn's setup to skip
# compiling its CUDA extension. It is *added* to the inherited environment:
# passing a bare dict as env= would replace the whole environment and drop
# PATH, HOME, virtualenv and CUDA settings that pip and the build rely on.
# Invoking pip through sys.executable installs into the interpreter running
# this app, and the argv list avoids spawning a shell.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', 'einops', 'flask'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    check=False,
)

app = Flask(__name__)

# Hugging Face repo id used for both the model weights and its processor.
MODEL_ID = 'gokaygokay/Florence-2-SD3-Captioner'

# NOTE: trust_remote_code=True executes Python code shipped in the model repo;
# only point MODEL_ID at repositories you trust.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
|
|
|
def modify_caption(caption: str) -> str:
    """
    Removes specific prefixes from captions if present, otherwise returns the original caption.

    Only a phrase at the very start of the caption is treated as a prefix;
    matching is case-insensitive. The same phrases appearing later in the
    caption (e.g. "a photo captured from above") are left untouched.

    Args:
        caption (str): A string containing a caption.

    Returns:
        str: The caption with the prefix replaced by its configured replacement
        (currently the empty string) if it was present, or the original caption.
    """
    # (prefix, replacement) pairs; prefixes are matched case-insensitively.
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]

    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    # Anchor with '^' so only a genuine leading prefix is removed. Without
    # re.MULTILINE, '^' matches only at the start of the whole string.
    alternatives = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    pattern = '^(?:' + alternatives + ')'

    def replace_fn(match):
        # Normalise case so "Captured From " finds its replacement.
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
|
|
|
@app.route('/')
def index():
    """Serve the single-page upload UI.

    The page posts the chosen file as multipart field ``image`` to
    ``/generate``. It shows the returned ``caption`` on success, or the
    server's ``error`` message (or a request failure) otherwise.
    """
    html = '''
<!DOCTYPE html>
<html>
<head>
    <title>Florence-2 SD3 Long Captioner</title>
    <style>
        #output {
            height: 500px;
            overflow: auto;
            border: 1px solid #ccc;
        }
    </style>
</head>
<body>
    <h1>Florence-2 SD3 Long Captioner</h1>
    <p>Florence-2 Base fine-tuned on Long SD3 Prompt and Image pairs. Check the
    <a href="https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner">Hugging Face link</a>
    for datasets that are used for fine-tuning.</p>
    <form id="uploadForm">
        <label for="imageInput">Input Picture</label>
        <input type="file" id="imageInput" name="image" accept="image/*">
        <button type="submit">Submit</button>
    </form>
    <div id="output">
        <h3>Output Text</h3>
        <p id="outputText"></p>
    </div>
    <script>
        document.getElementById('uploadForm').onsubmit = async function(event) {
            event.preventDefault();
            const output = document.getElementById('outputText');
            const imageFile = document.getElementById('imageInput').files[0];
            // Without this guard FormData would send the string "undefined".
            if (!imageFile) {
                output.innerText = 'Please choose an image first.';
                return;
            }
            const formData = new FormData();
            formData.append('image', imageFile);
            output.innerText = 'Generating caption...';
            try {
                const response = await fetch('/generate', {
                    method: 'POST',
                    body: formData
                });
                const data = await response.json();
                // Error responses carry "error", not "caption".
                output.innerText = response.ok
                    ? data.caption
                    : 'Error: ' + (data.error || response.statusText);
            } catch (err) {
                // Network failure or a non-JSON (e.g. HTML 500) response body.
                output.innerText = 'Request failed: ' + err;
            }
        };
    </script>
</body>
</html>
'''
    return render_template_string(html)
|
|
|
@app.route('/generate', methods=['POST'])
def generate():
    """Caption an uploaded image with the Florence-2 SD3 captioner.

    Expects a multipart form upload with the file in field ``image``.

    Returns:
        200 with ``{"caption": str}`` on success.
        400 with ``{"error": str}`` when no file was uploaded or the upload
        cannot be decoded as an image.
    """
    image_file = request.files.get('image')
    # Browsers submit an empty file part (filename == '') when nothing is chosen.
    if image_file is None or not image_file.filename:
        return jsonify({"error": "No image provided"}), 400

    try:
        image = Image.open(image_file.stream)
        # Image.open is lazy; load() forces a full decode now so unreadable,
        # truncated or oversized uploads become a client error, not a 500.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError is a subclass of OSError.
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    if image.mode != "RGB":
        image = image.convert("RGB")

    task_prompt = "<DESCRIPTION>"
    prompt = task_prompt + "Describe this image in great detail."

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    # Special tokens are kept; presumably post_process_generation needs them
    # to parse the task output — confirm against the processor's remote code.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    # Index by task_prompt so the lookup key always matches the requested task.
    caption = modify_caption(parsed_answer[task_prompt])

    return jsonify({"caption": caption})
|
|
|
if __name__ == '__main__':
    # Debug mode is not forced on. The Werkzeug interactive debugger permits
    # arbitrary code execution and must never be reachable on 0.0.0.0. Its
    # reloader also re-executes this module in a child process, repeating the
    # pip install and the model load. With debug left unset, Flask.run honours
    # the FLASK_DEBUG environment variable for opt-in local debugging.
    app.run(port=7860, host="0.0.0.0")