K00B404 committed on
Commit
d89262b
1 Parent(s): 01a7d56

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Define the Pix2Pix model (UNet)
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.utils.data import DataLoader
7
+ from torchvision import transforms
8
+ from datasets import load_dataset
9
+ from huggingface_hub import Repository, create_repo
10
+ import gradio as gr
11
+ from PIL import Image
12
+ import os
13
+
14
+ # Parameters
15
+ IMG_SIZE = 256
16
+ BATCH_SIZE = 1
17
+ EPOCHS = 12
18
+ LR = 0.0002
19
+
20
+ # Device configuration
21
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+
23
+ # Define the Pix2Pix model (Simplified UNet)
24
+ class UNet(nn.Module):
25
+ def __init__(self):
26
+ super(UNet, self).__init__()
27
+
28
+ # Encoder
29
+ self.encoder = nn.Sequential(
30
+ nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), # 256 -> 128
31
+ nn.ReLU(inplace=True),
32
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128 -> 64
33
+ nn.ReLU(inplace=True),
34
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 64 -> 32
35
+ nn.ReLU(inplace=True),
36
+ nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # 32 -> 16
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1), # 16 -> 8
39
+ nn.ReLU(inplace=True)
40
+ )
41
+
42
+ # Decoder
43
+ self.decoder = nn.Sequential(
44
+ nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1), # 8 -> 16
45
+ nn.ReLU(inplace=True),
46
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # 16 -> 32
47
+ nn.ReLU(inplace=True),
48
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 32 -> 64
49
+ nn.ReLU(inplace=True),
50
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 64 -> 128
51
+ nn.ReLU(inplace=True),
52
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1), # 128 -> 256
53
+ nn.Tanh() # Output range [-1, 1]
54
+ )
55
+
56
+ def forward(self, x):
57
+ enc = self.encoder(x)
58
+ dec = self.decoder(enc)
59
+ return dec
60
+
61
+ # Training function
62
+ def train_model(epochs):
63
+ # Load the dataset
64
+ ds = load_dataset("K00B404/pix2pix_flux_set")
65
+
66
+ # Transform function to resize and convert to tensor
67
+ transform = transforms.Compose([
68
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
69
+ transforms.ToTensor(),
70
+ ])
71
+
72
+ # Create dataset and dataloader
73
+ class Pix2PixDataset(torch.utils.data.Dataset):
74
+ def __init__(self, ds):
75
+ self.ds = ds
76
+
77
+ def __len__(self):
78
+ return len(self.ds["train"])
79
+
80
+ def __getitem__(self, idx):
81
+ original = Image.open(self.ds["train"][idx]['original_image']).convert('RGB')
82
+ target = Image.open(self.ds["train"][idx]['target_image']).convert('RGB')
83
+ return transform(original), transform(target)
84
+
85
+ dataset = Pix2PixDataset(ds)
86
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
87
+
88
+ # Initialize model, loss function, and optimizer
89
+ model = UNet().to(device)
90
+ criterion = nn.L1Loss()
91
+ optimizer = optim.Adam(model.parameters(), lr=LR)
92
+
93
+ # Training loop
94
+ for epoch in range(epochs):
95
+ for i, (original, target) in enumerate(dataloader):
96
+ original, target = original.to(device), target.to(device)
97
+ optimizer.zero_grad()
98
+
99
+ # Forward pass
100
+ output = model(target)
101
+ loss = criterion(output, original)
102
+
103
+ # Backward pass
104
+ loss.backward()
105
+ optimizer.step()
106
+
107
+ if i % 100 == 0:
108
+ print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")
109
+
110
+ # Return trained model
111
+ return model
112
+
113
+ # Push model to Hugging Face Hub
114
+ def push_model_to_hub(model, repo_name):
115
+ repo = Repository(repo_name)
116
+ repo.push_to_hub()
117
+
118
+ # Save the model state dict
119
+ model_save_path = os.path.join(repo_name, "pix2pix_model.pth")
120
+ torch.save(model.state_dict(), model_save_path)
121
+
122
+ # Push the model to the repo
123
+ repo.push_to_hub(commit_message="Initial commit with trained Pix2Pix model.")
124
+
125
+ # Gradio interface function
126
+ def gradio_train(epochs):
127
+ model = train_model(int(epochs))
128
+ push_model_to_hub(model, "K00B404/pix2pix_flux")
129
+ return f"Model trained for {epochs} epochs and pushed to Hugging Face Hub repository 'K00B404/pix2pix_flux'."
130
+
131
+ # Gradio Interface
132
+ gr_interface = gr.Interface(
133
+ fn=gradio_train,
134
+ inputs=gr.Number(label="Number of Epochs"),
135
+ outputs="text",
136
+ title="Pix2Pix Model Training",
137
+ description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
138
+ )
139
+
140
+ if __name__ == '__main__':
141
+ # Create or clone the repository
142
+ create_repo("K00B404/pix2pix_flux", exist_ok=True)
143
+
144
+ # Launch the Gradio app
145
+ gr_interface.launch()