# pix2pix_flux_train / app_ko.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
import os
import pandas as pd
import gradio as gr
from PIL import Image
import numpy as np
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
from CLIP import load as load_clip, load_vae, encode_prompt
from rich import print as rp
from diffusers import AutoencoderKL
#url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be a local file
#model = AutoencoderKL.from_single_file(url)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # use the big 1024px model only when a GPU is available
# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 1  # one sample per batch for both model sizes; latents are memory-heavy
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Global model variable
global_model = None
# CLIP and VAE
clip_model, clip_tokenizer = load_clip()
vae = load_vae()
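# NOTE: CLIP and the VAE are loaded once at import time and shared as module-level
# globals by Pix2PixDataset and encode_prompt.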
def load_model():
"""Load the models at startup"""
global global_model
weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
try:
checkpoint = torch.load(weights_name, map_location=device)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
global_model = model
rp("Model loaded successfully!")
return model
except Exception as e:
rp(f"Error loading model: {e}")
model = big_UNet().to(device) if big else small_UNet().to(device)
global_model = model
return model
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, combined_data, transform, clip_tokenizer, clip_model):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        self.clip_model = clip_model
        self.original_folder = 'images_dataset/original/'
        self.target_folder = 'images_dataset/target/'
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
original_img_filename = os.path.basename(self.data.iloc[idx]['image_path'])
original_img_path = os.path.join(self.original_folder, original_img_filename)
target_img_path = os.path.join(self.target_folder, original_img_filename)
original_img = Image.open(original_img_path).convert('RGB')
target_img = Image.open(target_img_path).convert('RGB')
# Transform images
original = self.transform(original_img)
target = self.transform(target_img)
# Get prompts from the DataFrame
original_prompt = self.data.iloc[idx]['original_prompt']
enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
        # Encode images with the VAE; AutoencoderKL.encode expects a batched
        # tensor scaled to [-1, 1], so encode the transformed tensors, not the PIL images
        with torch.no_grad():
            original_image_latents = vae.encode(original.unsqueeze(0) * 2 - 1).latent_dist.sample().squeeze(0)
            target_image_latents = vae.encode(target.unsqueeze(0) * 2 - 1).latent_dist.sample().squeeze(0)
            # Encode the enhanced prompt with CLIP
            prompt_latents = encode_prompt(enhanced_prompt, self.clip_model, self.clip_tokenizer)
        # The training loop passes these latents to the Pix2Pix model
        return original_image_latents, target_image_latents, prompt_latents
        # Legacy path (used by train_model_old): tokenize the prompts with the CLIP
        # tokenizer and return the raw image tensors instead of latents.
        # original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
        # enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
        # return original, target, original_tokens, enhanced_tokens
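        # NOTE (assumption): with a standard 4-channel AutoencoderKL the image latents are
        # downsampled 8x (e.g. (4, 32, 32) for a 256x256 input), and CLIP ViT-L prompt
        # embeddings are typically of shape (77, 768); verify against the actual VAE/CLIP used.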
class UNetWrapper:
    def __init__(self, unet_model, repo_id, epoch, loss, optimizer=None, scheduler=None):
self.loss = loss
self.epoch = epoch
self.model = unet_model
self.optimizer = optimizer
self.scheduler = scheduler
self.repo_id = repo_id
self.token = os.getenv('NEW_TOKEN') # Ensure the token is set in the environment
self.api = HfApi(token=self.token)
def save_checkpoint(self, save_path):
"""Save checkpoint with model, optimizer, and scheduler states."""
self.save_dict = {
'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer else None,
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'model_config': {
'big': isinstance(self.model, big_UNet),
'img_size': 1024 if isinstance(self.model, big_UNet) else 256
},
'epoch': self.epoch,
'loss': self.loss
}
torch.save(self.save_dict, save_path)
print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
def load_checkpoint(self, checkpoint_path):
"""Load model, optimizer, and scheduler states from the checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and checkpoint.get('scheduler_state_dict'):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.epoch = checkpoint['epoch']
self.loss = checkpoint['loss']
print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
def push_to_hub(self, pth_name):
"""Push model checkpoint and metadata to the Hugging Face Hub."""
try:
self.api.upload_file(
path_or_fileobj=pth_name,
path_in_repo=pth_name,
repo_id=self.repo_id,
token=self.token,
repo_type="model"
)
print(f"Model checkpoint successfully uploaded to {self.repo_id}")
except Exception as e:
print(f"Error uploading model: {e}")
# Create and upload model card
model_card = f"""---
tags:
- unet
- pix2pix
- pytorch
library_name: pytorch
license: wtfpl
datasets:
- K00B404/pix2pix_flux_set
language:
- en
pipeline_tag: image-to-image
---
# Pix2Pix UNet Model
## Model Description
Custom UNet model for Pix2Pix image translation.
- **Image Size:** {self.save_dict['model_config']['img_size']}
- **Model Type:** {"big" if self.save_dict['model_config']['big'] else "small"}_UNet ({self.save_dict['model_config']['img_size']})
## Usage
```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
big = True
# Load the model
name='big_model_weights.pth' if big else 'small_model_weights.pth'
checkpoint = torch.load(name)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```
## Model Architecture
{str(self.model)} """
rp(model_card)
        try:
            # Save and upload the full model card (with YAML front matter) as the README
            with open("README.md", "w") as f:
                f.write(model_card)
self.api.upload_file(
path_or_fileobj="README.md",
path_in_repo="README.md",
repo_id=self.repo_id,
token=self.token,
repo_type="model"
)
# Clean up local files
os.remove(pth_name)
os.remove("README.md")
print(f"Model successfully uploaded to {self.repo_id}")
except Exception as e:
print(f"Error uploading model: {e}")
def prepare_input(image, device='cpu'):
"""Prepare image for inference"""
transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
])
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
input_tensor = transform(image).unsqueeze(0).to(device)
return input_tensor
def run_inference(image):
"""Run inference on a single image"""
global global_model
if global_model is None:
return "Error: Model not loaded"
global_model.eval()
input_tensor = prepare_input(image, device)
with torch.no_grad():
output = global_model(input_tensor)
    # Convert output to a uint8 image; the epsilon guards against division by zero
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    output = ((output - output.min()) / (output.max() - output.min() + 1e-8) * 255).astype(np.uint8)
    return output
def to_hub(model, epoch, loss):
    """Save a checkpoint locally, then push it (and a model card) to the Hub."""
    wrapper = UNetWrapper(model, model_repo_id, epoch, loss)
    pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    wrapper.save_checkpoint(pth_name)
    wrapper.push_to_hub(pth_name)
def train_model(epochs, save_interval=1):
"""Training function with checkpoint saving and model uploading."""
global global_model
# Load combined data CSV
data_path = 'combined_data.csv'
combined_data = pd.read_csv(data_path)
# Define the transformation
transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
])
# Initialize dataset and dataloader
dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer, clip_model)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = global_model
criterion = nn.L1Loss() # You may change this to suit your loss calculation needs
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # Example scheduler
wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
output_text = []
for epoch in range(epochs):
model.train()
running_loss = 0.0
for i, (latent_original, latent_target, latent_prompt) in enumerate(dataloader):
# Move data to device
latent_original, latent_target, latent_prompt = latent_original.to(device), latent_target.to(device), latent_prompt.to(device)
optimizer.zero_grad()
# Forward pass with the latents
output = model(latent_target, latent_prompt) # Assuming your model can take both target and prompt latents
# Calculate loss using the original latents
img_loss = criterion(output, latent_original)
total_loss = img_loss
total_loss.backward()
optimizer.step()
running_loss += total_loss.item()
if i % 10 == 0:
status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
print(status)
output_text.append(status)
# Update the epoch and loss for checkpoint
wrapper.epoch = epoch + 1
wrapper.loss = running_loss / len(dataloader)
# Save checkpoint at specified intervals
if (epoch + 1) % save_interval == 0:
checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
wrapper.save_checkpoint(checkpoint_path)
wrapper.push_to_hub(checkpoint_path)
scheduler.step() # Update learning rate scheduler
global_model = model # Update global model after training
return model, "\n".join(output_text)
def train_model_old(epochs):
"""Training function"""
global global_model
# Load combined data CSV
data_path = 'combined_data.csv' # Adjust this path
combined_data = pd.read_csv(data_path)
# Define the transformation
transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
])
    # Initialize the dataset and dataloader
    # NOTE: this legacy loop expects the dataset's old 4-tuple return
    # (original, target, original_tokens, enhanced_tokens); see the commented-out
    # legacy path at the end of Pix2PixDataset.__getitem__.
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer, clip_model)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = global_model
criterion = nn.L1Loss() # L1 loss for image reconstruction
optimizer = optim.Adam(model.parameters(), lr=LR)
output_text = []
for epoch in range(epochs):
model.train()
for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
# Move images and prompt embeddings to the appropriate device (CPU or GPU)
original, target = original.to(device), target.to(device)
original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float() # Convert to float
enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float() # Convert to float
optimizer.zero_grad()
# Forward pass through the model
output = model(target)
# Compute image reconstruction loss
img_loss = criterion(output, original)
rp(f"Image {i} Loss:{img_loss}")
# Combine losses
total_loss = img_loss # Add any other losses if necessary
total_loss.backward()
# Optimizer step
optimizer.step()
if i % 10 == 0:
status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
rp(status)
output_text.append(status)
# Push model to Hugging Face Hub at the end of each epoch
to_hub(model, epoch, total_loss)
global_model = model # Update the global model after training
return model, "\n".join(output_text)
def gradio_train(epochs):
# Gradio training interface function
model, training_log = train_model(int(epochs))
#to_hub(model)
return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
def gradio_inference(input_image):
    # Gradio inference interface function
    # run_inference returns a single numpy image (or an error string if the model failed to load)
    output_image = run_inference(input_image)
    return output_image
# Create Gradio interface with tabs
with gr.Blocks() as app:
gr.Markdown("# Pix2Pix Model Training and Inference")
with gr.Tab("Train"):
epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
train_button = gr.Button("Train")
training_output = gr.Textbox(label="Training Log", interactive=False)
train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
with gr.Tab("Inference"):
image_input = gr.Image(type='numpy')
prompt_input = gr.Textbox(label="Prompt")
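        # NOTE: prompt_input is not yet wired into run_inference; only the image is passed below.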
inference_button = gr.Button("Generate")
inference_output = gr.Image(type='numpy', label="Generated Image")
inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])
load_model()
app.launch()