Spaces:
Running
Running
annotation
Browse files
app.py
CHANGED
@@ -133,9 +133,11 @@ def main():
|
|
133 |
# col1, col2, col3 = st.columns(3, gap="medium")
|
134 |
col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
|
135 |
sentiment = col1.slider(
|
136 |
-
"Sentiment
|
|
|
137 |
detoxification = col2.slider(
|
138 |
-
"Detoxification Strength
|
|
|
139 |
steer_interval)
|
140 |
max_length = col3.number_input("Max length", 50, 300, 50, 50)
|
141 |
col1, col2, col3, _ = st.columns(4)
|
@@ -144,15 +146,16 @@ def main():
|
|
144 |
if "output" not in st.session_state:
|
145 |
st.session_state.output = ""
|
146 |
if col1.button("Steer and generate!", type="primary"):
|
147 |
-
|
148 |
-
|
149 |
-
st.session_state.
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
156 |
analyzed_text = \
|
157 |
st.text_area("Generated text:", st.session_state.output, height=200)
|
158 |
|
@@ -176,46 +179,51 @@ def main():
|
|
176 |
[2, 0],
|
177 |
["#ff7f0e", "#1f77b4"],
|
178 |
):
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
st.divider()
|
221 |
st.divider()
|
@@ -234,7 +242,8 @@ def main():
|
|
234 |
["Sentiment", "Detoxification"],
|
235 |
)
|
236 |
dim = 2 if dimension == "Sentiment" else 0
|
237 |
-
|
|
|
238 |
|
239 |
|
240 |
if __name__ == "__main__":
|
|
|
133 |
# col1, col2, col3 = st.columns(3, gap="medium")
|
134 |
col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
|
135 |
sentiment = col1.slider(
|
136 |
+
"Sentiment (the larger the more positive)",
|
137 |
+
-steer_range, steer_range, 3.0, steer_interval)
|
138 |
detoxification = col2.slider(
|
139 |
+
"Detoxification Strength (the larger the less toxic)",
|
140 |
+
-steer_range, steer_range, 0.0,
|
141 |
steer_interval)
|
142 |
max_length = col3.number_input("Max length", 50, 300, 50, 50)
|
143 |
col1, col2, col3, _ = st.columns(4)
|
|
|
146 |
if "output" not in st.session_state:
|
147 |
st.session_state.output = ""
|
148 |
if col1.button("Steer and generate!", type="primary"):
|
149 |
+
with st.spinner("Generating..."):
|
150 |
+
steer_values = [detoxification, 0, sentiment, 0]
|
151 |
+
st.session_state.output = model.generate(
|
152 |
+
st.session_state.prompt,
|
153 |
+
steer_values,
|
154 |
+
seed=None if randomness else 0,
|
155 |
+
min_length=0,
|
156 |
+
max_length=max_length,
|
157 |
+
do_sample=True,
|
158 |
+
)
|
159 |
analyzed_text = \
|
160 |
st.text_area("Generated text:", st.session_state.output, height=200)
|
161 |
|
|
|
179 |
[2, 0],
|
180 |
["#ff7f0e", "#1f77b4"],
|
181 |
):
|
182 |
+
with st.spinner(f"Analyzing {name}..."):
|
183 |
+
col.subheader(name)
|
184 |
+
# classification
|
185 |
+
col.markdown(
|
186 |
+
"##### Dimension-Wise Classification Distribution")
|
187 |
+
_, dist_list, _ = model.steer_analysis(
|
188 |
+
analyzed_text,
|
189 |
+
dim, -steer_range, steer_range,
|
190 |
+
bins=2*int(steer_range)+1,
|
191 |
+
)
|
192 |
+
dist_list = np.array(dist_list)
|
193 |
+
col.bar_chart(
|
194 |
+
pd.DataFrame(
|
195 |
+
{
|
196 |
+
"Value": dist_list[:, 0],
|
197 |
+
"Probability": dist_list[:, 1],
|
198 |
+
}
|
199 |
+
), x="Value", y="Probability",
|
200 |
+
color=color,
|
201 |
+
)
|
202 |
+
|
203 |
+
# key tokens
|
204 |
+
pos_steer, neg_steer = np.zeros((2, 4))
|
205 |
+
pos_steer[dim] = 1
|
206 |
+
neg_steer[dim] = -1
|
207 |
+
_, token_evidence = model.evidence_words(
|
208 |
+
analyzed_text,
|
209 |
+
[pos_steer, neg_steer],
|
210 |
+
)
|
211 |
+
tokens = tokenizer(analyzed_text).input_ids
|
212 |
+
tokens = [f"{i:3d}: {tokenizer.decode([t])}"
|
213 |
+
for i, t in enumerate(tokens)]
|
214 |
+
col.markdown("##### Token's Evidence Score in the Dimension")
|
215 |
+
col.write("The polarity of the token's evidence score "
|
216 |
+
"which aligns with sliding bar directions."
|
217 |
+
)
|
218 |
+
col.bar_chart(
|
219 |
+
pd.DataFrame(
|
220 |
+
{
|
221 |
+
"Token": tokens[1:],
|
222 |
+
"Evidence": token_evidence,
|
223 |
+
}
|
224 |
+
), x="Token", y="Evidence",
|
225 |
+
horizontal=True, color=color,
|
226 |
+
)
|
227 |
|
228 |
st.divider()
|
229 |
st.divider()
|
|
|
242 |
["Sentiment", "Detoxification"],
|
243 |
)
|
244 |
dim = 2 if dimension == "Sentiment" else 0
|
245 |
+
with st.spinner("Analyzing..."):
|
246 |
+
word_embedding_space_analysis(model, tokenizer, dim)
|
247 |
|
248 |
|
249 |
if __name__ == "__main__":
|