K00B404 commited on
Commit
42ddf51
1 Parent(s): 91e19b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -148,19 +148,18 @@ class UNetWrapper:
148
 
149
  except Exception as e:
150
  print(f"Error uploading model: {e}")
151
-
152
  # Training function
153
  def train_model(epochs):
154
  # Load the dataset
155
  ds = load_dataset(dataset_id)
156
  print(f"ds{ds}")
157
- # Create the transform function outside of the dataset class
158
  transform = transforms.Compose([
159
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
160
  transforms.ToTensor(),
161
  ])
162
 
163
- # Create dataset and dataloader
164
  dataset = Pix2PixDataset(ds, transform)
165
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
166
 
@@ -173,8 +172,10 @@ def train_model(epochs):
173
  criterion = nn.L1Loss()
174
  optimizer = optim.Adam(model.parameters(), lr=LR)
175
  output_text = []
 
176
  # Training loop
177
  for epoch in range(epochs):
 
178
  for i, (original, target) in enumerate(dataloader):
179
  original, target = original.to(device), target.to(device)
180
  optimizer.zero_grad()
@@ -188,36 +189,26 @@ def train_model(epochs):
188
  optimizer.step()
189
 
190
  if i % 100 == 0:
191
- status=f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
192
  print(status)
193
  output_text.append(status)
194
-
195
- # Here you could also use a delay to simulate training time
196
- yield "\n".join(output_text) # Send output to Gradio
197
-
198
- # Return trained model
199
- return model
200
 
201
- # Push model to Hugging Face Hub
202
- def push_model_to_hub(model, repo_id):
203
- wrapper = UNetWrapper(model, repo_id)
204
- wrapper.push_to_hub()
205
 
206
  # Gradio interface function
207
  def gradio_train(epochs):
208
- model = train_model(int(epochs))
209
  push_model_to_hub(model, model_repo_id)
210
- return f"Model trained for {epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."
211
 
212
  # Gradio Interface
213
  gr_interface = gr.Interface(
214
  fn=gradio_train,
215
  inputs=gr.Number(label="Number of Epochs"),
216
- outputs="text",
217
  title="Pix2Pix Model Training",
218
  description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
219
  )
220
-
221
  if __name__ == '__main__':
222
  # Create or clone the repository if necessary
223
  repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
 
148
 
149
  except Exception as e:
150
  print(f"Error uploading model: {e}")
151
+
152
  # Training function
153
  def train_model(epochs):
154
  # Load the dataset
155
  ds = load_dataset(dataset_id)
156
  print(f"ds{ds}")
157
+
158
  transform = transforms.Compose([
159
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
160
  transforms.ToTensor(),
161
  ])
162
 
 
163
  dataset = Pix2PixDataset(ds, transform)
164
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
165
 
 
172
  criterion = nn.L1Loss()
173
  optimizer = optim.Adam(model.parameters(), lr=LR)
174
  output_text = []
175
+
176
  # Training loop
177
  for epoch in range(epochs):
178
+ model.train()
179
  for i, (original, target) in enumerate(dataloader):
180
  original, target = original.to(device), target.to(device)
181
  optimizer.zero_grad()
 
189
  optimizer.step()
190
 
191
  if i % 100 == 0:
192
+ status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
193
  print(status)
194
  output_text.append(status)
 
 
 
 
 
 
195
 
196
+ return model, "\n".join(output_text)
 
 
 
197
 
198
  # Gradio interface function
199
  def gradio_train(epochs):
200
+ model, training_log = train_model(int(epochs))
201
  push_model_to_hub(model, model_repo_id)
202
+ return f"{training_log}\n\nModel trained for {epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."
203
 
204
  # Gradio Interface
205
  gr_interface = gr.Interface(
206
  fn=gradio_train,
207
  inputs=gr.Number(label="Number of Epochs"),
208
+ outputs=gr.Textbox(label="Training Progress", lines=10),
209
  title="Pix2Pix Model Training",
210
  description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
211
  )
 
212
  if __name__ == '__main__':
213
  # Create or clone the repository if necessary
214
  repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)