# pix2pix_flux_train / app_back.py
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, hf_hub_download
import gradio as gr
from PIL import Image

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Hugging Face write token; must be set in the environment
token = os.getenv('NEW_TOKEN')
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # Use the large 1024px model only when a GPU is available
# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]
        # Ensure the number of original and target images match
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
        # Debug: print dataset sizes
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")
        self.transform = transform  # Store the transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        # Pairs are matched by list index, so originals and targets must share ordering
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']
        original = original_img.convert('RGB')  # Convert to RGB if needed
        target = target_img.convert('RGB')
        # Apply the shared transforms to both images
        return self.transform(original), self.transform(target)
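
# Quick sanity check (a sketch, assuming the dataset repo above is reachable):
#   ds = load_dataset(dataset_id)
#   pairs = Pix2PixDataset(ds, transforms.Compose([
#       transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor()]))
#   original, target = pairs[0]
#   print(original.shape, target.shape)  # both torch.Size([3, IMG_SIZE, IMG_SIZE])
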
class UNetWrapper:
    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)
    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }
            # Save model locally
            pth_name = 'model_weights.pth'
            torch.save(save_dict, pth_name)

            # Create repo if it doesn't exist
            try:
                create_repo(
                    repo_id=self.repo_id,
                    token=self.token,
                    exist_ok=True
                )
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Upload the model file
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            # Create and upload model card
            model_card = f"""---
tags:
- unet
- pix2pix
library_name: pytorch
---

# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}

## Usage
```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Load the model
checkpoint = torch.load('model_weights.pth')
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
```
{str(self.model)}
```
"""
            # Save and upload README
            with open("README.md", "w") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"Dataset: {ds}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Initialize model, loss function, and optimizer; resume from the Hub when possible
    try:
        model = load_model_from_hub(model_repo_id, device)
        print(f"Resumed weights from {model_repo_id}")
    except Exception as e:
        print(f"Could not load weights from the Hub ({e}); training from scratch.")
        model = big_UNet().to(device) if big else small_UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []
    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: the model maps the target image back to the original
            output = model(target)
            loss = criterion(output, original)  # Compare with the original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)
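
# Hedged inference sketch: run a single PIL image through a trained model.
# Not wired into the Gradio app; the Resize/ToTensor pipeline matches the
# training transforms above, and the clamp to [0, 1] is an assumption about
# the model's output range.
def translate_image(model, image):
    tfm = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    model.eval()
    with torch.no_grad():
        x = tfm(image.convert('RGB')).unsqueeze(0).to(device)  # (1, 3, H, W)
        y = model(x).squeeze(0).clamp(0, 1).cpu()
    return transforms.ToPILImage()(y)
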
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_id):
    wrapper = UNetWrapper(model, repo_id)
    wrapper.push_to_hub()
# Gradio interface function
def gradio_train(epochs):
    model, training_log = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to the {model_repo_id} repository on the Hugging Face Hub."
# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS, precision=0),
    outputs=gr.Textbox(label="Training Progress", lines=10),
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)
if __name__ == '__main__':
    # Create or clone the repository if necessary
    # repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    # repo.git_pull()

    # Launch the Gradio app
    gr_interface.launch()