Spaces:
Runtime error
Runtime error
ckpt download as well as torch bug fix
Browse files- app.py +6 -3
- requirements.txt +2 -2
app.py
CHANGED
@@ -64,11 +64,12 @@ class Demo:
|
|
64 |
label="Learning Rate",
|
65 |
info='Learning rate used to train'
|
66 |
)
|
67 |
-
self.progress_bar = gr.Text(interactive=False, label="Training Progress")
|
68 |
|
69 |
self.train_button = gr.Button(
|
70 |
value="Train",
|
71 |
)
|
|
|
|
|
72 |
|
73 |
with gr.Column(scale=2) as inference_column:
|
74 |
|
@@ -125,7 +126,7 @@ class Demo:
|
|
125 |
self.iterations_input,
|
126 |
self.lr_input
|
127 |
],
|
128 |
-
outputs=[self.train_button, self.infr_button, self.
|
129 |
)
|
130 |
|
131 |
def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
@@ -201,6 +202,8 @@ class Demo:
|
|
201 |
loss.backward()
|
202 |
optimizer.step()
|
203 |
|
|
|
|
|
204 |
self.finetuner = finetuner.eval().half()
|
205 |
|
206 |
self.diffuser = self.diffuser.eval().half()
|
@@ -209,7 +212,7 @@ class Demo:
|
|
209 |
|
210 |
self.training = False
|
211 |
|
212 |
-
return [gr.update(interactive=True), gr.update(interactive=True),
|
213 |
|
214 |
|
215 |
def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
|
|
|
64 |
label="Learning Rate",
|
65 |
info='Learning rate used to train'
|
66 |
)
|
|
|
67 |
|
68 |
self.train_button = gr.Button(
|
69 |
value="Train",
|
70 |
)
|
71 |
+
|
72 |
+
self.download = gr.Files()
|
73 |
|
74 |
with gr.Column(scale=2) as inference_column:
|
75 |
|
|
|
126 |
self.iterations_input,
|
127 |
self.lr_input
|
128 |
],
|
129 |
+
outputs=[self.train_button, self.infr_button, self.download]
|
130 |
)
|
131 |
|
132 |
def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
|
|
202 |
loss.backward()
|
203 |
optimizer.step()
|
204 |
|
205 |
+
torch.save(finetuner.state_dict(), 'ft.ckpt')
|
206 |
+
|
207 |
self.finetuner = finetuner.eval().half()
|
208 |
|
209 |
self.diffuser = self.diffuser.eval().half()
|
|
|
212 |
|
213 |
self.training = False
|
214 |
|
215 |
+
return [gr.update(interactive=True), gr.update(interactive=True), 'ft.ckpt']
|
216 |
|
217 |
|
218 |
def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
gradio
|
2 |
-
torch
|
3 |
-
torchvision
|
4 |
diffusers
|
5 |
transformers
|
6 |
accelerate
|
|
|
1 |
gradio
|
2 |
+
torch==1.13.1 --index-url https://download.pytorch.org/whl/cu118
|
3 |
+
torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cu118
|
4 |
diffusers
|
5 |
transformers
|
6 |
accelerate
|