"""Gradio app for training and running a Pix2Pix UNet, with checkpoints pushed to the HF Hub."""
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
from PIL import Image
from rich import print as rp
from torch.utils.data import DataLoader
from torchvision import transforms

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
from CLIP import load as load_clip

# Device configuration: the big (1024px) model is used only when CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 1
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Global model variable, populated by load_model() and replaced after training.
global_model = None

# CLIP (only the tokenizer is used below, to tokenize prompts in the dataset).
clip_model, clip_tokenizer = load_clip()

# Model card uploaded as README.md next to each checkpoint. The YAML front matter
# must start at column 0, which is why this template lives at module level.
_MODEL_CARD_TEMPLATE = """---
tags:
- unet
- pix2pix
- pytorch
library_name: pytorch
license: wtfpl
datasets:
- K00B404/pix2pix_flux_set
language:
- en
pipeline_tag: image-to-image
---
# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.

- **Image Size:** {img_size}
- **Model Type:** {model_type}_UNet ({img_size})

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

big = True
# Load the model
name='big_model_weights.pth' if big else 'small_model_weights.pth'
checkpoint = torch.load(name)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
{architecture}
"""


def load_model():
    """Load the model weights at startup and store the model in ``global_model``.

    Reads ``big_model_weights.pth`` or ``small_model_weights.pth`` depending on
    the device. If loading fails for any reason, falls back to a freshly
    initialised UNet of the matching size so the app can still start.

    Returns:
        The loaded (or freshly initialised) model, already moved to ``device``.
    """
    global global_model
    weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    try:
        # NOTE: torch.load unpickles; only point this at trusted local files.
        checkpoint = torch.load(weights_name, map_location=device)
        model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        global_model = model
        rp("Model loaded successfully!")
        return model
    except Exception as e:
        # Deliberate best-effort: an untrained model still lets the UI start.
        rp(f"Error loading model: {e}")
        model = big_UNet().to(device) if big else small_UNet().to(device)
        global_model = model
        return model


class Pix2PixDataset(torch.utils.data.Dataset):
    """Pairs of original/target images plus CLIP-tokenized prompts.

    Each row of ``combined_data`` must provide ``image_path``, ``original_prompt``
    and ``enhanced_prompt``. Only the basename of ``image_path`` is used; the
    same filename is looked up in both the original and the target folder.
    """

    def __init__(self, combined_data, transform, clip_tokenizer):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        self.original_folder = 'images_dataset/original/'
        self.target_folder = 'images_dataset/target/'

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(original, target, original_tokens, enhanced_tokens)`` for row ``idx``."""
        row = self.data.iloc[idx]
        original_img_filename = os.path.basename(row['image_path'])
        original_img_path = os.path.join(self.original_folder, original_img_filename)
        target_img_path = os.path.join(self.target_folder, original_img_filename)

        original_img = Image.open(original_img_path).convert('RGB')
        target_img = Image.open(target_img_path).convert('RGB')

        # Transform images
        original = self.transform(original_img)
        target = self.transform(target_img)

        # Get prompts from the DataFrame
        original_prompt = row['original_prompt']
        enhanced_prompt = row['enhanced_prompt']

        # Tokenize the prompts using the CLIP tokenizer (77 = CLIP context length).
        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)

        return original, target, original_tokens, enhanced_tokens


class UNetWrapper:
    """Bundle a UNet with its training state for checkpointing and Hub upload.

    Args:
        unet_model: the model to save/upload.
        repo_id: Hugging Face model repo to push to.
        epoch: epoch number recorded in the checkpoint.
        loss: loss value recorded in the checkpoint.
        optimizer: optimizer whose state is checkpointed; may be None when
            only the model weights are being pushed.
        scheduler: optional LR scheduler whose state is checkpointed.
    """

    def __init__(self, unet_model, repo_id, epoch, loss, optimizer=None, scheduler=None):
        self.loss = loss
        self.epoch = epoch
        self.model = unet_model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Ensure the token is set in the environment
        self.api = HfApi(token=self.token)

    def _model_config(self):
        """Describe the wrapped model's size (derived from its class)."""
        is_big = isinstance(self.model, big_UNet)
        return {'big': is_big, 'img_size': 1024 if is_big else 256}

    def save_checkpoint(self, save_path):
        """Save checkpoint with model, optimizer, and scheduler states."""
        self.save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer is not None else None,
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler is not None else None,
            'model_config': self._model_config(),
            'epoch': self.epoch,
            'loss': self.loss,
        }
        torch.save(self.save_dict, save_path)
        print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")

    def load_checkpoint(self, checkpoint_path):
        """Load model, optimizer, and scheduler states from the checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if self.optimizer is not None and checkpoint.get('optimizer_state_dict'):
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epoch = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")

    def push_to_hub(self, pth_name):
        """Upload the checkpoint file ``pth_name`` and a model card to the Hub.

        On full success both local files (the checkpoint and README.md) are
        deleted. If the checkpoint upload fails, nothing else is attempted and
        the local checkpoint is kept, so no work is lost.
        """
        try:
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
            )
            print(f"Model checkpoint successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
            return

        # Create and upload the model card (with HF YAML metadata).
        config = self._model_config()
        model_card = _MODEL_CARD_TEMPLATE.format(
            img_size=config['img_size'],
            model_type='big' if config['big'] else 'small',
            architecture=str(self.model),
        )
        rp(model_card)

        try:
            # Save and upload README
            with open("README.md", "w", encoding="utf-8") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
            )
            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")


def _image_transform():
    """Resize to IMG_SIZE x IMG_SIZE and convert to a float tensor in [0, 1]."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])


def prepare_input(image, device='cpu'):
    """Prepare a PIL image or numpy array as a 1-image batch on ``device``."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return _image_transform()(image).unsqueeze(0).to(device)


def run_inference(image):
    """Run inference on a single image.

    Returns:
        An HxWxC uint8 array, min-max stretched to 0..255, or the string
        ``"Error: Model not loaded"`` when no model is available.
    """
    if global_model is None:
        return "Error: Model not loaded"

    global_model.eval()
    input_tensor = prepare_input(image, device)
    with torch.no_grad():
        output = global_model(input_tensor)

    # Convert output to image (CHW -> HWC).
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    lo, hi = output.min(), output.max()
    if hi > lo:
        output = ((output - lo) / (hi - lo) * 255).astype(np.uint8)
    else:
        # Constant (or NaN) output: avoid a 0/0 division that yields garbage.
        output = np.zeros_like(output, dtype=np.uint8)
    return output


def to_hub(model, epoch, loss):
    """Checkpoint ``model`` (weights only, no optimizer) and push it to ``model_repo_id``.

    ``loss`` may be a Python number or a 0-dim tensor.
    """
    wrapper = UNetWrapper(model, model_repo_id, epoch, float(loss))
    checkpoint_path = f"{'big' if big else 'small'}_checkpoint_epoch_{epoch}.pth"
    wrapper.save_checkpoint(checkpoint_path)
    wrapper.push_to_hub(checkpoint_path)


def _build_dataloader():
    """Load ``combined_data.csv`` and wrap it in a shuffled DataLoader."""
    combined_data = pd.read_csv('combined_data.csv')
    dataset = Pix2PixDataset(combined_data, _image_transform(), clip_tokenizer)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


def train_model(epochs, save_interval=1):
    """Train ``global_model`` with checkpoint saving and Hub uploading.

    Args:
        epochs: number of epochs to train.
        save_interval: save and push a checkpoint every this many epochs;
            values <= 0 disable checkpointing.

    Returns:
        ``(model, log)`` where ``log`` is the newline-joined status lines.
    """
    global global_model

    dataloader = _build_dataloader()
    num_batches = len(dataloader)

    model = global_model if global_model is not None else load_model()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)

    output_text = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        # Prompt tokens are produced by the dataset but not used by this model yet.
        for i, (original, target, _original_tokens, _enhanced_tokens) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)

            optimizer.zero_grad()
            # NOTE(review): the model maps `target` -> `original`; confirm this
            # direction is intended (train_model_old does the same).
            output = model(target)
            total_loss = criterion(output, original)
            total_loss.backward()
            optimizer.step()

            step_loss = total_loss.item()
            running_loss += step_loss
            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{num_batches}], Loss: {step_loss:.8f}"
                print(status)
                output_text.append(status)

        # Update the epoch and mean loss for the checkpoint (guard empty datasets).
        wrapper.epoch = epoch + 1
        wrapper.loss = running_loss / max(num_batches, 1)

        # Save checkpoint at specified intervals
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            wrapper.push_to_hub(checkpoint_path)

        scheduler.step()  # Update learning rate scheduler

    global_model = model  # Update global model after training
    return model, "\n".join(output_text)


def train_model_old(epochs):
    """Legacy training loop: no scheduler, pushes weights after every epoch."""
    global global_model

    dataloader = _build_dataloader()
    model = global_model if global_model is not None else load_model()
    criterion = nn.L1Loss()  # L1 loss for image reconstruction
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []

    for epoch in range(epochs):
        model.train()
        total_loss = None
        for i, (original, target, _original_tokens, _enhanced_tokens) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass through the model and image reconstruction loss.
            output = model(target)
            img_loss = criterion(output, original)
            rp(f"Image {i} Loss:{img_loss}")

            total_loss = img_loss  # Add any other losses if necessary
            total_loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                rp(status)
                output_text.append(status)

        # Push model to Hugging Face Hub at the end of each epoch (skip if no batches ran).
        if total_loss is not None:
            to_hub(model, epoch, total_loss)

    global_model = model  # Update the global model after training
    return model, "\n".join(output_text)


def gradio_train(epochs):
    """Gradio training callback: train for ``epochs`` and return the log."""
    model, training_log = train_model(int(epochs))
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"


def gradio_inference(input_image):
    """Gradio inference callback; surfaces failures as UI errors instead of bad outputs."""
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    output_image = run_inference(input_image)
    if isinstance(output_image, str):
        # run_inference reports a missing model as a string; gr.Image cannot show it.
        raise gr.Error(output_image)
    return output_image


# Create Gradio interface with tabs
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")

    with gr.Tab("Train"):
        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
        train_button = gr.Button("Train")
        training_output = gr.Textbox(label="Training Log", interactive=False)
        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])

    with gr.Tab("Inference"):
        image_input = gr.Image(type='numpy')
        prompt_input = gr.Textbox(label="Prompt")  # not yet consumed by inference
        inference_button = gr.Button("Generate")
        inference_output = gr.Image(type='numpy', label="Generated Image")
        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])

load_model()
app.launch()