import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, hf_hub_download
import gradio as gr

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

big = device.type == 'cuda'  # use the big 1024 model only when a GPU is available

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]

        # Ensure the number of original and target images match
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        # Debug: Print dataset size
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

        self.transform = transform  # Store the transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        original = original_img.convert('RGB')  # Convert to RGB if needed
        target = target_img.convert('RGB')      # Convert to RGB if needed

        # Apply the necessary transforms
        # Apply the same transforms to both sides of the pair
        return self.transform(original), self.transform(target)
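
# Pairing example: a minimal sketch (assumption: the Hub dataset stores matching
# original/target pairs in the same order, so originals[i] corresponds to targets[i]):
#   ds = load_dataset(dataset_id)
#   pairs = Pix2PixDataset(ds, transforms.ToTensor())
#   original, target = pairs[0]  # two (3, H, W) float tensors in [0, 1]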

class UNetWrapper:
    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Hugging Face write token; make sure this environment variable is set
        self.api = HfApi(token=self.token)

    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }
            
            # Save model locally
            pth_name = 'model_weights.pth'
            torch.save(save_dict, pth_name)
            
            # Create repo if it doesn't exist
            try:
                create_repo(
                    repo_id=self.repo_id, 
                    token=self.token,
                    exist_ok=True
                )
            except Exception as e:
                print(f"Repository creation note: {e}")
            
            # Upload the model file
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            
            # Create and upload the model card (the YAML front matter must start
            # at column 0, and the usage code fence must be closed, or the card
            # will not render or be parsed correctly on the Hub)
            model_card = f"""---
tags:
- unet
- pix2pix
library_name: pytorch
---

# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
- Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Load the model
checkpoint = torch.load('model_weights.pth')
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
{str(self.model)}
"""
            # Save and upload README
            with open("README.md", "w") as f:
                f.write(model_card)
            
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            
            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            
            print(f"Model successfully uploaded to {self.repo_id}")
            
        except Exception as e:
            print(f"Error uploading model: {e}")
            
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"ds{ds}")
    
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    
    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer. Try to resume from the
    # checkpoint pushed to the Hub as 'model_weights.pth' by UNetWrapper;
    # fall back to a freshly initialized UNet.
    try:
        weights_path = hf_hub_download(repo_id=model_repo_id, filename='model_weights.pth')
        checkpoint = torch.load(weights_path, map_location=device)
        model = (big_UNet() if checkpoint['model_config']['big'] else small_UNet()).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        print(f"No usable checkpoint on the Hub ({e}); training from scratch.")
        model = big_UNet().to(device) if big else small_UNet().to(device)

    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass
            output = model(target)  # Generate cutout image
            loss = criterion(output, original)  # Compare with original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)

# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_id):
    wrapper = UNetWrapper(model, repo_id)
    wrapper.push_to_hub()
    # Push the model to the Hugging Face hub
    #model.push_to_hub(repo_name)
    
# Gradio interface function
def gradio_train(epochs):
    model, training_log = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."

# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs=gr.Textbox(label="Training Progress", lines=10),
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)
if __name__ == '__main__':
    # Create or clone the repository if necessary
    #repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    #repo.git_pull()

    # Launch the Gradio app
    gr_interface.launch()