Spaces:
Runtime error
Runtime error
Benjamin Bossan
commited on
Commit
•
1643735
1
Parent(s):
3efe4b4
Blacken
Browse files- src/gistillery/ml.py +19 -6
- src/gistillery/worker.py +3 -1
src/gistillery/ml.py
CHANGED
@@ -32,7 +32,9 @@ class Processor(abc.ABC):
|
|
32 |
|
33 |
|
34 |
class Summarizer(abc.ABC):
|
35 |
-
def __init__(
|
|
|
|
|
36 |
raise NotImplementedError
|
37 |
|
38 |
def get_name(self) -> str:
|
@@ -44,7 +46,9 @@ class Summarizer(abc.ABC):
|
|
44 |
|
45 |
|
46 |
class Tagger(abc.ABC):
|
47 |
-
def __init__(
|
|
|
|
|
48 |
raise NotImplementedError
|
49 |
|
50 |
def get_name(self) -> str:
|
@@ -90,7 +94,9 @@ class MlRegistry:
|
|
90 |
|
91 |
|
92 |
class HfTransformersSummarizer(Summarizer):
|
93 |
-
def __init__(
|
|
|
|
|
94 |
self.model_name = model_name
|
95 |
self.model = model
|
96 |
self.tokenizer = tokenizer
|
@@ -101,7 +107,9 @@ class HfTransformersSummarizer(Summarizer):
|
|
101 |
def __call__(self, x: str) -> str:
|
102 |
text = self.template.format(x)
|
103 |
inputs = self.tokenizer(text, return_tensors="pt")
|
104 |
-
outputs = self.model.generate(
|
|
|
|
|
105 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
106 |
assert isinstance(output, str)
|
107 |
return output
|
@@ -111,7 +119,9 @@ class HfTransformersSummarizer(Summarizer):
|
|
111 |
|
112 |
|
113 |
class HfTransformersTagger(Tagger):
|
114 |
-
def __init__(
|
|
|
|
|
115 |
self.model_name = model_name
|
116 |
self.model = model
|
117 |
self.tokenizer = tokenizer
|
@@ -132,7 +142,9 @@ class HfTransformersTagger(Tagger):
|
|
132 |
def __call__(self, x: str) -> list[str]:
|
133 |
text = self.template.format(x)
|
134 |
inputs = self.tokenizer(text, return_tensors="pt")
|
135 |
-
outputs = self.model.generate(
|
|
|
|
|
136 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
137 |
tags = self._extract_tags(output)
|
138 |
return tags
|
@@ -171,6 +183,7 @@ class DefaultUrlProcessor(Processor):
|
|
171 |
text = self.template.format(url=self.url, content=text)
|
172 |
return text
|
173 |
|
|
|
174 |
# class ProcessorRegistry:
|
175 |
# def __init__(self) -> None:
|
176 |
# self.registry: list[Processor] = []
|
|
|
32 |
|
33 |
|
34 |
class Summarizer(abc.ABC):
|
35 |
+
def __init__(
|
36 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
37 |
+
) -> None:
|
38 |
raise NotImplementedError
|
39 |
|
40 |
def get_name(self) -> str:
|
|
|
46 |
|
47 |
|
48 |
class Tagger(abc.ABC):
|
49 |
+
def __init__(
|
50 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
51 |
+
) -> None:
|
52 |
raise NotImplementedError
|
53 |
|
54 |
def get_name(self) -> str:
|
|
|
94 |
|
95 |
|
96 |
class HfTransformersSummarizer(Summarizer):
|
97 |
+
def __init__(
|
98 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
99 |
+
) -> None:
|
100 |
self.model_name = model_name
|
101 |
self.model = model
|
102 |
self.tokenizer = tokenizer
|
|
|
107 |
def __call__(self, x: str) -> str:
|
108 |
text = self.template.format(x)
|
109 |
inputs = self.tokenizer(text, return_tensors="pt")
|
110 |
+
outputs = self.model.generate(
|
111 |
+
**inputs, generation_config=self.generation_config
|
112 |
+
)
|
113 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
114 |
assert isinstance(output, str)
|
115 |
return output
|
|
|
119 |
|
120 |
|
121 |
class HfTransformersTagger(Tagger):
|
122 |
+
def __init__(
|
123 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
124 |
+
) -> None:
|
125 |
self.model_name = model_name
|
126 |
self.model = model
|
127 |
self.tokenizer = tokenizer
|
|
|
142 |
def __call__(self, x: str) -> list[str]:
|
143 |
text = self.template.format(x)
|
144 |
inputs = self.tokenizer(text, return_tensors="pt")
|
145 |
+
outputs = self.model.generate(
|
146 |
+
**inputs, generation_config=self.generation_config
|
147 |
+
)
|
148 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
149 |
tags = self._extract_tags(output)
|
150 |
return tags
|
|
|
183 |
text = self.template.format(url=self.url, content=text)
|
184 |
return text
|
185 |
|
186 |
+
|
187 |
# class ProcessorRegistry:
|
188 |
# def __init__(self) -> None:
|
189 |
# self.registry: list[Processor] = []
|
src/gistillery/worker.py
CHANGED
@@ -122,7 +122,9 @@ def load_mlregistry(model_name: str) -> MlRegistry:
|
|
122 |
# increase the temperature to make the model more creative
|
123 |
config_tagger.temperature = 1.5
|
124 |
|
125 |
-
summarizer = HfTransformersSummarizer(
|
|
|
|
|
126 |
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
127 |
|
128 |
registry = MlRegistry()
|
|
|
122 |
# increase the temperature to make the model more creative
|
123 |
config_tagger.temperature = 1.5
|
124 |
|
125 |
+
summarizer = HfTransformersSummarizer(
|
126 |
+
model_name, model, tokenizer, config_summarizer
|
127 |
+
)
|
128 |
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
129 |
|
130 |
registry = MlRegistry()
|