import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, hf_hub_download
import gradio as gr

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Hub token (set the NEW_TOKEN environment variable before launching)
token = os.getenv('NEW_TOKEN')
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # use the big 1024px model only when a GPU is available
# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
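# Assumed dataset layout (see Pix2PixDataset below): paired images stored under
# a 'label' column, where label 0 = original (input) image and label 1 = target.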
# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]

        # Ensure the number of original and target images match
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        # Debug: print dataset sizes
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

        self.transform = transform  # Store the transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        # Convert to RGB if needed, then apply the transforms
        original = original_img.convert('RGB')
        target = target_img.convert('RGB')
        return self.transform(original), self.transform(target)
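# Quick sanity-check sketch (assumes the dataset above is reachable):
#   ds = load_dataset(dataset_id)
#   original_t, target_t = Pix2PixDataset(ds, transforms.ToTensor())[0]
#   print(original_t.shape, target_t.shape)  # two C x H x W tensors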
class UNetWrapper:
    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)

    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }

            # Save model locally
            pth_name = 'model_weights.pth'
            torch.save(save_dict, pth_name)

            # Create repo if it doesn't exist
            try:
                create_repo(
                    repo_id=self.repo_id,
                    token=self.token,
                    exist_ok=True
                )
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Upload the model file
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Create and upload model card
            model_card = f"""---
tags:
- unet
- pix2pix
library_name: pytorch
---
# Pix2Pix UNet Model
## Model Description
Custom UNet model for Pix2Pix image translation.
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}
## Usage
```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
# Load the model
checkpoint = torch.load('model_weights.pth')
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
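# To fetch the checkpoint from the Hub first (it is uploaded as 'model_weights.pth'):
#   from huggingface_hub import hf_hub_download
#   checkpoint_path = hf_hub_download(repo_id="{self.repo_id}", filename="model_weights.pth")
# then pass checkpoint_path to torch.load above.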
```

## Model Architecture

```
{str(self.model)}
```
"""
            # Save and upload README
            with open("README.md", "w") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")

            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"ds: {ds}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize the model: resume from the checkpoint on the Hub if one
    # exists (saved in the format produced by UNetWrapper.push_to_hub),
    # otherwise start from freshly initialized weights
    model = big_UNet().to(device) if big else small_UNet().to(device)
    try:
        ckpt_path = hf_hub_download(repo_id=model_repo_id, filename='model_weights.pth', token=token)
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded existing weights from {model_repo_id}")
    except Exception as e:
        print(f"No existing checkpoint loaded ({e}); training from scratch")

    # Loss function and optimizer
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []
    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: generate the original (cutout) image from the target
            output = model(target)
            loss = criterion(output, original)  # Compare with the original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)
# Push model to the Hugging Face Hub
def push_model_to_hub(model, repo_id):
    wrapper = UNetWrapper(model, repo_id)
    wrapper.push_to_hub()

# Gradio interface function
def gradio_train(epochs):
    model, training_log = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to the Hugging Face Hub repository {model_repo_id}."
# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS, precision=0),  # default to the EPOCHS constant above
    outputs=gr.Textbox(label="Training Progress", lines=10),
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)
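# Note (assumption): a full training run can take far longer than the default
# HTTP request timeout; enabling Gradio's request queue keeps long runs alive:
#   gr_interface.queue()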
if __name__ == '__main__':
    # Launch the Gradio app
    gr_interface.launch()