import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, hf_hub_download
import gradio as gr

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # Use the big (1024px) model only when a GPU is available

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"


# Dataset that pairs 'original' images with their 'target' counterparts
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Filter the dataset into 'original' (label = 0) and 'target' (label = 1) images
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]

        # Ensure the number of original and target images match
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        # Debug: print dataset sizes
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

        self.transform = transform  # Store the transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        original = original_img.convert('RGB')  # Convert to RGB if needed
        target = target_img.convert('RGB')

        # Apply the resize/tensor transforms to both images
        return self.transform(original), self.transform(target)


class UNetWrapper:
    """Wraps a UNet and handles pushing the checkpoint plus a model card to the Hub."""

    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)

    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }

            # Save the checkpoint locally
            pth_name = 'model_weights.pth'
            torch.save(save_dict, pth_name)

            # Create the repo if it doesn't exist yet
            try:
                create_repo(repo_id=self.repo_id, token=self.token, exist_ok=True)
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Upload the checkpoint
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Create and upload the model card
            model_card = f"""---
tags:
- unet
- pix2pix
library_name: pytorch
---

# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Load the model
checkpoint = torch.load('model_weights.pth')
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
```
{str(self.model)}
```
"""

            # Save and upload the README
            with open("README.md", "w") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")

            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")


# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"Dataset: {ds}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize the model; try to resume from the checkpoint that
    # UNetWrapper pushes to the Hub, otherwise start from fresh weights
    model = big_UNet().to(device) if big else small_UNet().to(device)
    try:
        weights_path = hf_hub_download(
            repo_id=model_repo_id,
            filename='model_weights.pth',
            token=os.getenv('NEW_TOKEN')
        )
        checkpoint = torch.load(weights_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Resumed from checkpoint in {model_repo_id}")
    except Exception as e:
        print(f"No usable Hub checkpoint ({e}); training from scratch.")

    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []

    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass
            output = model(target)  # Generate cutout image from the target
            loss = criterion(output, original)  # Compare with the original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)


# Push the trained model to the Hugging Face Hub
def push_model_to_hub(model, repo_id):
    wrapper = UNetWrapper(model, repo_id)
    wrapper.push_to_hub()


# Gradio interface function
def gradio_train(epochs):
    model, training_log = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return (f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} "
            f"dataset and pushed to the Hugging Face Hub repository {model_repo_id}.")


# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs=gr.Textbox(label="Training Progress", lines=10),
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)

if __name__ == '__main__':
    # Launch the Gradio app
    gr_interface.launch()
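

# --- Optional inference sketch (illustrative, not wired into the app) ---
# A minimal example of running the trained UNet on a single image, assuming the
# model maps a target-style image back to an original-style image, as in
# train_model() above. The function name and the use of ToPILImage/clamp are
# illustrative choices, not part of the original app.
def translate_image(model, pil_image):
    to_tensor = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    model.eval()
    with torch.no_grad():
        # Add a batch dimension, run the UNet, then drop the batch dimension
        batch = to_tensor(pil_image.convert('RGB')).unsqueeze(0).to(device)
        output = model(batch).squeeze(0).clamp(0, 1).cpu()
    return transforms.ToPILImage()(output)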