Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -148,19 +148,18 @@ class UNetWrapper:
|
|
148 |
|
149 |
except Exception as e:
|
150 |
print(f"Error uploading model: {e}")
|
151 |
-
|
152 |
# Training function
|
153 |
def train_model(epochs):
|
154 |
# Load the dataset
|
155 |
ds = load_dataset(dataset_id)
|
156 |
print(f"ds{ds}")
|
157 |
-
|
158 |
transform = transforms.Compose([
|
159 |
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
160 |
transforms.ToTensor(),
|
161 |
])
|
162 |
|
163 |
-
# Create dataset and dataloader
|
164 |
dataset = Pix2PixDataset(ds, transform)
|
165 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
166 |
|
@@ -173,8 +172,10 @@ def train_model(epochs):
|
|
173 |
criterion = nn.L1Loss()
|
174 |
optimizer = optim.Adam(model.parameters(), lr=LR)
|
175 |
output_text = []
|
|
|
176 |
# Training loop
|
177 |
for epoch in range(epochs):
|
|
|
178 |
for i, (original, target) in enumerate(dataloader):
|
179 |
original, target = original.to(device), target.to(device)
|
180 |
optimizer.zero_grad()
|
@@ -188,36 +189,26 @@ def train_model(epochs):
|
|
188 |
optimizer.step()
|
189 |
|
190 |
if i % 100 == 0:
|
191 |
-
status=f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
|
192 |
print(status)
|
193 |
output_text.append(status)
|
194 |
-
|
195 |
-
# Here you could also use a delay to simulate training time
|
196 |
-
yield "\n".join(output_text) # Send output to Gradio
|
197 |
-
|
198 |
-
# Return trained model
|
199 |
-
return model
|
200 |
|
201 |
-
|
202 |
-
def push_model_to_hub(model, repo_id):
|
203 |
-
wrapper = UNetWrapper(model, repo_id)
|
204 |
-
wrapper.push_to_hub()
|
205 |
|
206 |
# Gradio interface function
|
207 |
def gradio_train(epochs):
|
208 |
-
model = train_model(int(epochs))
|
209 |
push_model_to_hub(model, model_repo_id)
|
210 |
-
return f"
|
211 |
|
212 |
# Gradio Interface
|
213 |
gr_interface = gr.Interface(
|
214 |
fn=gradio_train,
|
215 |
inputs=gr.Number(label="Number of Epochs"),
|
216 |
-
outputs="
|
217 |
title="Pix2Pix Model Training",
|
218 |
description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
|
219 |
)
|
220 |
-
|
221 |
if __name__ == '__main__':
|
222 |
# Create or clone the repository if necessary
|
223 |
repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
|
|
|
148 |
|
149 |
except Exception as e:
|
150 |
print(f"Error uploading model: {e}")
|
151 |
+
|
152 |
# Training function
|
153 |
def train_model(epochs):
|
154 |
# Load the dataset
|
155 |
ds = load_dataset(dataset_id)
|
156 |
print(f"ds{ds}")
|
157 |
+
|
158 |
transform = transforms.Compose([
|
159 |
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
160 |
transforms.ToTensor(),
|
161 |
])
|
162 |
|
|
|
163 |
dataset = Pix2PixDataset(ds, transform)
|
164 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
165 |
|
|
|
172 |
criterion = nn.L1Loss()
|
173 |
optimizer = optim.Adam(model.parameters(), lr=LR)
|
174 |
output_text = []
|
175 |
+
|
176 |
# Training loop
|
177 |
for epoch in range(epochs):
|
178 |
+
model.train()
|
179 |
for i, (original, target) in enumerate(dataloader):
|
180 |
original, target = original.to(device), target.to(device)
|
181 |
optimizer.zero_grad()
|
|
|
189 |
optimizer.step()
|
190 |
|
191 |
if i % 100 == 0:
|
192 |
+
status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
|
193 |
print(status)
|
194 |
output_text.append(status)
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
+
return model, "\n".join(output_text)
|
|
|
|
|
|
|
197 |
|
198 |
# Gradio interface function
|
199 |
def gradio_train(epochs):
|
200 |
+
model, training_log = train_model(int(epochs))
|
201 |
push_model_to_hub(model, model_repo_id)
|
202 |
+
return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."
|
203 |
|
204 |
# Gradio Interface
|
205 |
gr_interface = gr.Interface(
|
206 |
fn=gradio_train,
|
207 |
inputs=gr.Number(label="Number of Epochs"),
|
208 |
+
outputs=gr.Textbox(label="Training Progress", lines=10),
|
209 |
title="Pix2Pix Model Training",
|
210 |
description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
|
211 |
)
|
|
|
212 |
if __name__ == '__main__':
|
213 |
# Create or clone the repository if necessary
|
214 |
repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
|