import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, hf_hub_download
import gradio as gr
from PIL import Image

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Hugging Face token (must be set as an environment variable / Space secret)
token = os.getenv('NEW_TOKEN')
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = torch.cuda.is_available()  # use the 1024px model only when a GPU is available

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]

        # Pairs are matched by index, so the counts must agree
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        # Debug: print dataset size
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

        self.transform = transform  # Store the transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        # Convert to RGB if needed
        original = original_img.convert('RGB')
        target = target_img.convert('RGB')

        # Apply the same transforms to both images
        return self.transform(original), self.transform(target)
class UNetWrapper:
    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)

    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }

            # Save model locally
            pth_name = 'model_weights.pth'
            torch.save(save_dict, pth_name)

            # Create repo if it doesn't exist
            try:
                create_repo(
                    repo_id=self.repo_id,
                    token=self.token,
                    exist_ok=True
                )
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Upload the model file
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Create and upload model card
            model_card = f"""---
tags:
- unet
- pix2pix
library_name: pytorch
---

# Pix2Pix UNet Model

## Model Description

Custom UNet model for Pix2Pix image translation.
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Load the model
checkpoint = torch.load('model_weights.pth')
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture

```
{str(self.model)}
```
"""

            # Save and upload README
            with open("README.md", "w") as f:
                f.write(model_card)

            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            print(f"Model successfully uploaded to {self.repo_id}")

        except Exception as e:
            print(f"Error uploading model: {e}")
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"ds: {ds}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer.
    # Try to resume from the checkpoint previously pushed to the Hub; fall back to a fresh model.
    try:
        ckpt_path = hf_hub_download(repo_id=model_repo_id, filename='model_weights.pth', token=token)
        checkpoint = torch.load(ckpt_path, map_location=device)
        assert checkpoint['model_config']['big'] == big, "checkpoint size does not match current device setting"
        model = (big_UNet() if big else small_UNet()).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        print(f"Could not load pretrained weights ({e}); training from scratch.")
        model = big_UNet().to(device) if big else small_UNet().to(device)

    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []

    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: translate the target image
            output = model(target)
            loss = criterion(output, original)  # Compare with original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_id):
    wrapper = UNetWrapper(model, repo_id)
    wrapper.push_to_hub()
    # Alternatively, push the model to the Hugging Face hub directly:
    # model.push_to_hub(repo_name)
# Gradio interface function
def gradio_train(epochs):
    model, training_log = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to the {model_repo_id} repository on the Hugging Face Hub."
# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs=gr.Textbox(label="Training Progress", lines=10),
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)
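
# Illustrative sketch (not wired into the Gradio UI above): load the checkpoint that
# UNetWrapper.push_to_hub() uploads and run a single inference pass. The filename
# 'model_weights.pth' and the config keys come from push_to_hub(); the helper name
# run_inference and its preprocessing choices are assumptions for demonstration only.
def run_inference(image: Image.Image) -> torch.Tensor:
    # Download the checkpoint from the model repo and rebuild the matching UNet
    ckpt_path = hf_hub_download(repo_id=model_repo_id, filename='model_weights.pth', token=token)
    checkpoint = torch.load(ckpt_path, map_location=device)
    net = (big_UNet() if checkpoint['model_config']['big'] else small_UNet()).to(device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # Preprocess the input the same way as training: resize to the model's image size, then ToTensor
    size = checkpoint['model_config']['img_size']
    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    with torch.no_grad():
        return net(preprocess(image.convert('RGB')).unsqueeze(0).to(device))
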
if __name__ == '__main__':
    # Create or clone the repository if necessary
    # repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    # repo.git_pull()

    # Launch the Gradio app
    gr_interface.launch()