jadechoghari
commited on
Commit
•
ec033ea
1
Parent(s):
880499e
Update pipeline.py
Browse files- pipeline.py +18 -14
pipeline.py
CHANGED
@@ -2,7 +2,7 @@ from diffusers import DiffusionPipeline
|
|
2 |
import os
|
3 |
import sys
|
4 |
from huggingface_hub import HfApi, hf_hub_download
|
5 |
-
from .tools import build_dataset_json_from_list
|
6 |
import torch
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
@@ -67,28 +67,32 @@ class MOSDiffusionPipeline(DiffusionPipeline):
|
|
67 |
|
68 |
|
69 |
@torch.no_grad()
|
70 |
-
def __call__(self,
|
71 |
"""
|
72 |
Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
|
73 |
"""
|
74 |
from .infer.infer_mos5 import infer
|
|
|
75 |
|
|
|
76 |
infer(
|
77 |
-
dataset_key=
|
78 |
configs=self.configs,
|
79 |
config_yaml_path=self.config_yaml,
|
80 |
-
exp_group_name=
|
81 |
-
exp_name=
|
82 |
)
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
# Example of how to use the pipeline
|
85 |
if __name__ == "__main__":
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
|
90 |
-
base_folder=None
|
91 |
-
)
|
92 |
-
|
93 |
-
# Run the pipeline
|
94 |
-
pipeline()
|
|
|
2 |
import os
|
3 |
import sys
|
4 |
from huggingface_hub import HfApi, hf_hub_download
|
5 |
+
# from .tools import build_dataset_json_from_list
|
6 |
import torch
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
|
|
67 |
|
68 |
|
69 |
@torch.no_grad()
|
70 |
+
def __call__(self, prompt: str):
|
71 |
"""
|
72 |
Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
|
73 |
"""
|
74 |
from .infer.infer_mos5 import infer
|
75 |
+
dataset_key = self.build_dataset_json_from_prompt(prompt)
|
76 |
|
77 |
+
# we run inference with the prompt - configs - and other settings
|
78 |
infer(
|
79 |
+
dataset_key=dataset_key,
|
80 |
configs=self.configs,
|
81 |
config_yaml_path=self.config_yaml,
|
82 |
+
exp_group_name="qa_mdt",
|
83 |
+
exp_name="mos_as_token"
|
84 |
)
|
85 |
|
86 |
+
def build_dataset_json_from_prompt(self, prompt: str):
|
87 |
+
"""
|
88 |
+
Build dataset_key dynamically from the provided prompt.
|
89 |
+
"""
|
90 |
+
# for simplicity let's just return the prompt as the dataset_key
|
91 |
+
return {"prompt": prompt}
|
92 |
+
|
93 |
+
|
94 |
# Example of how to use the pipeline
|
95 |
if __name__ == "__main__":
|
96 |
+
pipe = MOSDiffusionPipeline()
|
97 |
+
result = pipe("Generate a description of a sunny day.")
|
98 |
+
print(result)
|
|
|
|
|
|
|
|
|
|
|
|