# Sanshruth's picture — Upload 3 files (c3acf88 verified)
# NOTE: the three lines above were Hugging Face page residue captured during
# download; kept here as a comment so the module remains valid Python.
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
from PIL import Image
import numpy as np
import cv2
class SegmentAnythingModel:
    """Thin wrapper around Segment Anything (SAM) automatic mask generation.

    Loads a SAM checkpoint onto the given device and exposes a single
    `generate_masks` entry point that resizes any input to 512x512 RGB
    before running the automatic mask generator.
    """

    def __init__(self, sam_checkpoint, model_type, device):
        """
        Args:
            sam_checkpoint (str): Path to the SAM checkpoint file.
            model_type (str): Key into `sam_model_registry` (e.g. "vit_h").
            device (str or torch.device): Device to load the model on.
        """
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)
        # Thresholds tuned for high-precision masks (high pred_iou_thresh)
        # while still pruning tiny regions via min_mask_region_area.
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.99,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100
        )
        # All downstream processing (inpainting) assumes 512x512.
        self.target_size = (512, 512)

    def preprocess_image(self, image):
        """Normalize input to a 512x512 RGB uint8 numpy array.

        Args:
            image (str | np.ndarray | PIL.Image.Image): File path, array,
                or PIL image.

        Returns:
            np.ndarray: HxWx3 uint8 array resized to `self.target_size`.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # SAM's mask generator expects 3-channel RGB input; normalize mode
        # so RGBA / grayscale / palette images don't break it downstream.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize(self.target_size, Image.Resampling.LANCZOS)
        return np.array(image)

    def generate_masks(self, image):
        """Run SAM automatic mask generation on a preprocessed image.

        Returns:
            list[dict]: Mask records as produced by
            `SamAutomaticMaskGenerator.generate`.
        """
        processed_image = self.preprocess_image(image)
        return self.mask_generator.generate(processed_image)
class StableDiffusionInpaintingPipeline:
    """Wrapper around `StableDiffusionInpaintPipeline` fixed at 512x512 output.

    Handles device selection, dtype selection, and input normalization so
    callers can pass paths-free PIL images or numpy arrays directly.
    """

    def __init__(self, model_dir, device=None):
        """
        Args:
            model_dir (str): Directory (or hub id) of the inpainting model.
            device (str, optional): Torch device string. Defaults to "cuda"
                when available, otherwise "cpu" — previously this class
                unconditionally moved the pipeline to CUDA and crashed on
                CPU-only machines.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # fp16 weights are only usable on GPU; fall back to fp32 on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        # Initialize the scheduler first, then hand it to the pipeline.
        self.scheduler = EulerDiscreteScheduler.from_pretrained(
            model_dir, subfolder="scheduler"
        )
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_dir,
            scheduler=self.scheduler,
            # NOTE(review): revision="fp16" is deprecated in recent diffusers
            # releases in favor of variant="fp16" — confirm installed version.
            revision="fp16",
            torch_dtype=dtype,
        )
        self.pipe = self.pipe.to(device)
        # xformers is an optional dependency; enabling it is a best-effort
        # optimization and must not make construction fail when absent.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
        self.target_size = (512, 512)

    def preprocess_image(self, image):
        """Coerce to PIL and resize to 512x512 with LANCZOS resampling."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image.resize(self.target_size, Image.Resampling.LANCZOS)

    def inpaint(self, prompt, image, mask_image, guidance_scale=10,
                num_inference_steps=60, generator=None):
        """Inpaint `image` under `mask_image` guided by `prompt`.

        Args:
            prompt (str): The prompt for inpainting.
            image (PIL.Image or np.ndarray): The original image.
            mask_image (PIL.Image or np.ndarray): The mask for inpainting
                (white = region to repaint, per diffusers convention —
                TODO confirm with callers).
            guidance_scale (float): Higher values push results closer to
                the prompt.
            num_inference_steps (int): Number of denoising steps.
            generator (torch.Generator): Generator for reproducibility.

        Returns:
            PIL.Image.Image: The first generated 512x512 image.
        """
        # Normalize both inputs to PIL before resizing.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask_image, np.ndarray):
            mask_image = Image.fromarray(mask_image)
        # LANCZOS for the photo; NEAREST for the mask so its hard
        # binary edges are not blurred into partial-inpaint weights.
        image = image.resize(self.target_size, Image.Resampling.LANCZOS)
        mask_image = mask_image.resize(self.target_size, Image.Resampling.NEAREST)
        output = self.pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=512,
            width=512
        )
        return output.images[0]