Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
from datasets import load_dataset | |
from huggingface_hub import Repository | |
from huggingface_hub import HfApi, HfFolder, Repository, create_repo | |
import os | |
import pandas as pd | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from small_256_model import UNet as small_UNet | |
from big_1024_model import UNet as big_UNet | |
from CLIP import load as load_clip,load_vae,encode_prompt | |
from rich import print as rp | |
from diffusers import AutoencoderKL | |
# Alternative kept for reference: load the VAE from a single checkpoint file.
#url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be a local file
#model = AutoencoderKL.from_single_file(url)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The 1024px UNet is selected whenever we are not running on the CPU.
big = device.type != 'cpu'

# Training parameters and Hub repository identifiers.
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 1  # same value for both model sizes
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Model shared by training and inference; assigned by load_model().
global_model = None

# CLIP model/tokenizer and VAE used when building training samples.
clip_model, clip_tokenizer = load_clip()
vae = load_vae()
def load_model():
    """Set ``global_model`` from the local weights file, or from scratch.

    The file name depends on the module-level ``big`` flag
    (``big_model_weights.pth`` or ``small_model_weights.pth``). The
    architecture that is instantiated follows the ``model_config`` stored
    inside the checkpoint. If anything goes wrong the error is printed and
    an untrained UNet matching ``big`` is used instead.

    Returns:
        The model that was assigned to ``global_model``.
    """
    global global_model
    if big:
        weights_name = 'big_model_weights.pth'
    else:
        weights_name = 'small_model_weights.pth'
    try:
        checkpoint = torch.load(weights_name, map_location=device)
        # The checkpoint itself says which architecture produced it.
        architecture = big_UNet if checkpoint['model_config']['big'] else small_UNet
        restored = architecture()
        restored.load_state_dict(checkpoint['model_state_dict'])
        restored.to(device)
        restored.eval()
        global_model = restored
        rp("Model loaded successfully!")
        return restored
    except Exception as e:
        rp(f"Error loading model: {e}")
        # Fall back to freshly initialised weights so the app can still start.
        if big:
            fallback = big_UNet().to(device)
        else:
            fallback = small_UNet().to(device)
        global_model = fallback
        return fallback
class Pix2PixDataset(torch.utils.data.Dataset):
    """Dataset yielding VAE latents for an image pair plus the encoded prompt.

    Each row of ``combined_data`` must provide an ``image_path`` and an
    ``enhanced_prompt`` column. Only the file name of ``image_path`` is used;
    the pair is read from ``images_dataset/original/`` and
    ``images_dataset/target/``.

    Args:
        combined_data: pandas DataFrame describing the image pairs.
        transform: callable applied to each PIL image before VAE encoding.
        clip_tokenizer: tokenizer passed to ``encode_prompt``.
        clip_model: text model passed to ``encode_prompt``.
    """

    def __init__(self, combined_data, transform, clip_tokenizer, clip_model):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        # Previously dropped, which made __getitem__ depend on module globals.
        self.clip_model = clip_model
        self.original_folder = 'images_dataset/original/'
        self.target_folder = 'images_dataset/target/'

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.data)

    @staticmethod
    def _encode_image(image_tensor):
        """Encode one transformed image with the module-level ``vae``.

        NOTE(review): assumes ``vae`` is a diffusers-style ``AutoencoderKL``
        on the same device as ``image_tensor`` and that ``transform`` yields
        a (C, H, W) tensor -- confirm against ``load_vae()``.
        """
        # The VAE only produces training data here; no autograd graph needed.
        with torch.no_grad():
            # encode() works on batches: add a batch axis and remove it again
            # so the DataLoader can stack the samples itself.
            batch = image_tensor.unsqueeze(0)
            return vae.encode(batch).latent_dist.sample().squeeze(0)

    def __getitem__(self, idx):
        """Return ``(original_latents, target_latents, prompt_latents)``."""
        row = self.data.iloc[idx]
        filename = os.path.basename(row['image_path'])
        original_img = Image.open(os.path.join(self.original_folder, filename)).convert('RGB')
        target_img = Image.open(os.path.join(self.target_folder, filename)).convert('RGB')

        # Resize / tensorise first: the VAE is fed the transformed tensors,
        # not the raw PIL images.
        original = self.transform(original_img)
        target = self.transform(target_img)

        original_image_latents = self._encode_image(original)
        target_image_latents = self._encode_image(target)

        # Only the enhanced prompt conditions the model.
        prompt_latents = encode_prompt(
            row['enhanced_prompt'], self.clip_model, self.clip_tokenizer
        )
        return original_image_latents, target_image_latents, prompt_latents
class UNetWrapper:
    """Bundle a UNet with its training state for checkpointing and Hub upload.

    Args:
        unet_model: model to save / restore.
        repo_id: Hugging Face model repository that receives the uploads.
        epoch: epoch number recorded in checkpoints.
        loss: loss value recorded in checkpoints.
        optimizer: optional optimizer whose state is saved / restored.
        scheduler: optional LR scheduler whose state is saved / restored.
    """

    def __init__(self, unet_model, repo_id, epoch, loss, optimizer=None, scheduler=None):
        self.loss = loss
        self.epoch = epoch
        self.model = unet_model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.repo_id = repo_id
        # Contents of the most recent save_checkpoint() call, if any.
        self.save_dict = None
        self.token = os.getenv('NEW_TOKEN')  # Ensure the token is set in the environment
        self.api = HfApi(token=self.token)

    def _model_config(self):
        """Describe the wrapped model (size flag and image size)."""
        is_big = isinstance(self.model, big_UNet)
        return {'big': is_big, 'img_size': 1024 if is_big else 256}

    def save_checkpoint(self, save_path):
        """Save checkpoint with model, optimizer, and scheduler states.

        Missing optimizer / scheduler are stored as ``None``.
        """
        self.save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer is not None else None,
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'model_config': self._model_config(),
            'epoch': self.epoch,
            'loss': self.loss
        }
        torch.save(self.save_dict, save_path)
        print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")

    def load_checkpoint(self, checkpoint_path):
        """Load model, optimizer, and scheduler states from the checkpoint.

        Optimizer / scheduler state is restored only when both the wrapper
        and the checkpoint have one.
        """
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if self.optimizer is not None and checkpoint.get('optimizer_state_dict'):
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and checkpoint['scheduler_state_dict']:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epoch = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")

    def _build_model_card(self):
        """Return the README / model card text, including Hub metadata."""
        config = self._model_config()
        # Derived from the wrapped model so the card cannot disagree with it.
        size_label = 'big' if config['big'] else 'small'
        lines = [
            "---",
            "tags:",
            "- unet",
            "- pix2pix",
            "- pytorch",
            "library_name: pytorch",
            "license: wtfpl",
            "datasets:",
            "- K00B404/pix2pix_flux_set",
            "language:",
            "- en",
            "pipeline_tag: image-to-image",
            "---",
            "# Pix2Pix UNet Model",
            "## Model Description",
            "Custom UNet model for Pix2Pix image translation.",
            f"- **Image Size:** {config['img_size']}",
            f"- **Model Type:** {size_label}_UNet ({config['img_size']})",
            "## Usage",
            "```python",
            "import torch",
            "from small_256_model import UNet as small_UNet",
            "from big_1024_model import UNet as big_UNet",
            "big = True",
            "# Load the model",
            "name='big_model_weights.pth' if big else 'small_model_weights.pth'",
            "checkpoint = torch.load(name)",
            "model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()",
            "model.load_state_dict(checkpoint['model_state_dict'])",
            "model.eval()",
            "```",
            "## Model Architecture",
            str(self.model),
        ]
        return "\n".join(lines)

    def push_to_hub(self, pth_name=None):
        """Push model checkpoint and model card to the Hugging Face Hub.

        Args:
            pth_name: path of an existing checkpoint file. When omitted, a
                checkpoint is first written to
                ``<big|small>_checkpoint_epoch_<epoch>.pth``.

        Upload errors are printed, not raised. The local checkpoint is
        deleted only after it was uploaded successfully.
        """
        if pth_name is None:
            size_label = 'big' if self._model_config()['big'] else 'small'
            pth_name = f'{size_label}_checkpoint_epoch_{self.epoch}.pth'
            self.save_checkpoint(pth_name)

        checkpoint_uploaded = False
        try:
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            checkpoint_uploaded = True
            print(f"Model checkpoint successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")

        # Create and upload model card
        model_card = self._build_model_card()
        rp(model_card)
        try:
            # The README is the model card itself so the Hub sees its metadata.
            with open("README.md", "w", encoding="utf-8") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
        finally:
            # Clean up local files
            if os.path.exists("README.md"):
                os.remove("README.md")

        # Never discard a checkpoint that did not make it to the Hub.
        if checkpoint_uploaded and os.path.exists(pth_name):
            os.remove(pth_name)
def prepare_input(image, device='cpu'):
    """Convert a PIL image or NumPy array into a batched tensor on ``device``.

    The image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, passed through
    ``ToTensor`` and given a leading batch axis of size one.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    pipeline = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    batched = pipeline(pil_image).unsqueeze(0)
    return batched.to(device)
def run_inference(image):
    """Run inference on a single image.

    Args:
        image: PIL image or NumPy array accepted by ``prepare_input``.

    Returns:
        A ``uint8`` array (H, W, C) with the model output min-max scaled to
        0..255, or the string ``"Error: Model not loaded"`` when
        ``global_model`` has not been set.

    NOTE(review): the model is called with the image only, whereas
    ``train_model`` calls it with latents and a prompt -- confirm which
    signature the UNet actually has.
    """
    if global_model is None:
        return "Error: Model not loaded"
    global_model.eval()
    input_tensor = prepare_input(image, device)
    with torch.no_grad():
        output = global_model(input_tensor)
    # Convert output to image: drop the batch axis, channels last.
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    low = output.min()
    value_range = output.max() - low
    if value_range > 0:
        output = (output - low) / value_range * 255
    else:
        # Constant (or NaN) output: avoid dividing by zero, return black.
        output = np.zeros_like(output)
    output = output.astype(np.uint8)
    rp(output[0])
    return output
def to_hub(model, epoch, loss):
    """Save a checkpoint of ``model`` and upload it to ``model_repo_id``.

    Args:
        model: the UNet to upload.
        epoch: epoch number recorded in the checkpoint and its file name.
        loss: loss recorded in the checkpoint; a scalar tensor is converted
            to a plain number.
    """
    if hasattr(loss, 'item'):
        loss = loss.item()
    # The caller has no optimizer to hand over, so a fresh one is created
    # only to give the checkpoint the layout load_checkpoint() expects.
    optimizer = optim.Adam(model.parameters(), lr=LR)
    wrapper = UNetWrapper(model, model_repo_id, epoch, loss, optimizer)
    checkpoint_path = f'big_checkpoint_epoch_{epoch}.pth' if big else f'small_checkpoint_epoch_{epoch}.pth'
    # push_to_hub() uploads an existing file, so write the checkpoint first.
    wrapper.save_checkpoint(checkpoint_path)
    wrapper.push_to_hub(checkpoint_path)
def train_model(epochs, save_interval=1):
    """Training function with checkpoint saving and model uploading.

    Args:
        epochs: number of passes over ``combined_data.csv``.
        save_interval: save and upload a checkpoint every this many epochs;
            ``0`` or ``None`` disables checkpointing.

    Returns:
        ``(model, log)`` where ``log`` is the newline-joined status lines.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    global global_model
    # Load combined data CSV
    data_path = 'combined_data.csv'
    combined_data = pd.read_csv(data_path)
    # Define the transformation
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    # Initialize dataset and dataloader
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer, clip_model)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early instead of dividing by zero when averaging the loss.
        raise ValueError(f"No training samples found in {data_path}")

    # Training may be requested before load_model() has run.
    model = global_model if global_model is not None else load_model()
    criterion = nn.L1Loss()  # You may change this to suit your loss calculation needs
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
    output_text = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (latent_original, latent_target, latent_prompt) in enumerate(dataloader):
            # Move data to device
            latent_original = latent_original.to(device)
            latent_target = latent_target.to(device)
            latent_prompt = latent_prompt.to(device)
            optimizer.zero_grad()
            # Forward pass with the latents
            output = model(latent_target, latent_prompt)  # Assuming your model can take both target and prompt latents
            # Calculate loss using the original latents
            total_loss = criterion(output, latent_original)
            total_loss.backward()
            optimizer.step()
            batch_loss = total_loss.item()
            running_loss += batch_loss
            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{num_batches}], Loss: {batch_loss:.8f}"
                print(status)
                output_text.append(status)
        # Update the epoch and loss for checkpoint
        wrapper.epoch = epoch + 1
        wrapper.loss = running_loss / num_batches
        # Save checkpoint at specified intervals
        if save_interval and (epoch + 1) % save_interval == 0:
            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            wrapper.push_to_hub(checkpoint_path)
        scheduler.step()  # Update learning rate scheduler
    global_model = model  # Update global model after training
    return model, "\n".join(output_text)
def train_model_old(epochs):
    """Legacy training loop that works on image tensors and token ids.

    Trains ``global_model`` with an L1 loss between ``model(target)`` and
    the original image, and calls ``to_hub`` after every epoch.

    NOTE(review): this function does not match the current
    ``Pix2PixDataset``: the dataset constructor takes a ``clip_model``
    argument that is not passed here, and ``__getitem__`` returns three
    latents while the loop below unpacks four values. The Gradio UI calls
    ``train_model`` instead -- confirm whether this can be removed.

    Args:
        epochs: number of passes over ``combined_data.csv``.

    Returns:
        ``(model, log)`` where ``log`` is the newline-joined status lines.
    """
    global global_model
    # Load combined data CSV
    data_path = 'combined_data.csv'  # Adjust this path
    combined_data = pd.read_csv(data_path)
    # Define the transformation
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    # Initialize the dataset and dataloader
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    model = global_model
    criterion = nn.L1Loss()  # L1 loss for image reconstruction
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []
    for epoch in range(epochs):
        model.train()
        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
            # Move images and prompt embeddings to the appropriate device (CPU or GPU)
            original, target = original.to(device), target.to(device)
            # The token ids are moved and cast here but not used further below.
            original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float()  # Convert to float
            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()  # Convert to float
            optimizer.zero_grad()
            # Forward pass through the model
            output = model(target)
            # Compute image reconstruction loss
            img_loss = criterion(output, original)
            rp(f"Image {i} Loss:{img_loss}")
            # Combine losses
            total_loss = img_loss  # Add any other losses if necessary
            total_loss.backward()
            # Optimizer step
            optimizer.step()
            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                rp(status)
                output_text.append(status)
        # Push model to Hugging Face Hub at the end of each epoch
        # (total_loss is the loss tensor of the last batch of this epoch).
        to_hub(model, epoch, total_loss)
    global_model = model  # Update the global model after training
    return model, "\n".join(output_text)
def gradio_train(epochs):
    """Gradio callback: train for ``epochs`` epochs and return the log text.

    ``epochs`` arrives from a ``gr.Number`` field, so it is cast to ``int``
    before being handed to ``train_model``.
    """
    num_epochs = int(epochs)
    _trained_model, training_log = train_model(num_epochs)
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
def gradio_inference(input_image):
    """Gradio callback: run the model on ``input_image`` and return the result.

    Whatever ``run_inference`` returns is echoed to the console and passed
    through unchanged.
    """
    result = run_inference(input_image)
    rp(result)
    return result
# Create Gradio interface with tabs
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")
    # Training tab: epoch count in, training log out.
    with gr.Tab("Train"):
        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
        train_button = gr.Button("Train")
        training_output = gr.Textbox(label="Training Log", interactive=False)
        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
    # Inference tab: image in, generated image out.
    with gr.Tab("Inference"):
        image_input = gr.Image(type='numpy')
        # NOTE(review): the prompt box is shown but not passed to
        # gradio_inference below -- confirm whether it should be wired in.
        prompt_input = gr.Textbox(label="Prompt")
        inference_button = gr.Button("Generate")
        inference_output = gr.Image(type='numpy', label="Generated Image")
        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])
# Populate global_model before the UI starts serving requests.
load_model()
app.launch()