File size: 2,134 Bytes
02d46b3
 
 
 
 
 
 
 
 
 
 
 
 
bb70c02
02d46b3
bca5f5e
02d46b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI
from pydantic import BaseModel
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from PIL import Image
import torch
import base64
from io import BytesIO

# Initialize FastAPI app
app = FastAPI()

# Choose the device once and derive the weight dtype from it. Half precision
# is only used on CUDA; on CPU the weights are loaded as float32 because
# diffusers warns that float16 pipelines cannot run on the "cpu" device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Load Hugging Face pipeline components
model_id = "fyp1/sketchToImage"
controlnet = ControlNetModel.from_pretrained(
    model_id,
    subfolder="controlnet",
    torch_dtype=TORCH_DTYPE,
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=TORCH_DTYPE,
)
# NOTE(review): torch_dtype is kept from the original call; a scheduler
# presumably holds no weights, so this argument is likely ignored - confirm.
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    torch_dtype=TORCH_DTYPE,
)

# Assemble the SDXL base pipeline around the custom ControlNet, VAE and
# scheduler loaded above, then move it to the selected device.
# NOTE(review): safety_checker=None is kept from the original call; verify
# that this pipeline class actually accepts the argument.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=TORCH_DTYPE,
).to(DEVICE)

class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Text prompt describing the image to generate.
    prompt: str
    # Negative prompt; optional, an omitted value is treated as an empty one.
    negative_prompt: str = ""
    # Base64 encoded image
    sketch: str

@app.post("/generate")
async def generate_image(data: GenerateRequest):
    """Generate an image from a text prompt and a base64-encoded sketch.

    The sketch is decoded, converted to grayscale, resized to 1024x1024 and
    passed to the module-level ``pipe`` as the ControlNet conditioning image.

    Args:
        data: Request body carrying ``prompt``, ``negative_prompt`` and the
            base64 ``sketch``. A ``data:...;base64,`` prefix on the sketch is
            accepted and stripped.

    Returns:
        ``{"image": <base64 PNG>}`` on success. On failure, a JSON response
        whose body is ``{"error": <message>}`` with status 400 when the
        sketch cannot be decoded as an image, or 500 when generation fails.
    """
    # Function-scope imports keep this handler self-contained; their cost is
    # negligible next to a pipeline run.
    import logging

    from fastapi.responses import JSONResponse

    logger = logging.getLogger(__name__)

    # --- Decode and preprocess the sketch image (client-side errors) ---
    try:
        payload = data.sketch
        # Browser clients commonly send a data URL. Without stripping it, the
        # prefix characters would be decoded as payload and corrupt the image.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]

        sketch_bytes = base64.b64decode(payload)
        sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L")  # Convert to grayscale
        sketch_image = sketch_image.resize((1024, 1024))
    except (ValueError, OSError, Image.DecompressionBombError) as exc:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError and
        # truncated-file errors are OSErrors.
        return JSONResponse(status_code=400, content={"error": str(exc)})

    # --- Run the pipeline and encode the result (server-side errors) ---
    try:
        # NOTE: pipe(...) is a synchronous call inside an async handler, so
        # the event loop is blocked for the duration of the generation.
        with torch.no_grad():
            images = pipe(
                prompt=data.prompt,
                negative_prompt=data.negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=1024,
                height=1024,
                num_inference_steps=30,
            ).images

        # Convert output image to Base64
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as exc:
        # Top-level boundary: record the traceback instead of swallowing it,
        # and signal the failure through the status code.
        logger.exception("Image generation failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})

    return {"image": image_base64}