Spaces:
Sleeping
Sleeping
Update analysis.py
Browse files- analyse.py +102 -75
analyse.py
CHANGED
@@ -43,7 +43,7 @@ def load_msgs(msg_lens: list[int], file: str | None = None):
|
|
43 |
|
44 |
while True:
|
45 |
readable_msg = next(iterator)["text"]
|
46 |
-
try
|
47 |
readable_msg[:length].encode("ascii")
|
48 |
break
|
49 |
except Exception as e:
|
@@ -280,15 +280,20 @@ def process_results(results, save_dir):
|
|
280 |
base = r["base"]
|
281 |
delta = r["delta"]
|
282 |
msg_rate = r["msg_rate"]
|
|
|
283 |
perplexity = r["perplexity"]
|
284 |
|
285 |
-
if (base, delta) not in data["msg_rates"][msg_type]:
|
286 |
-
data["msg_rates"][msg_type][(base, delta)] = []
|
287 |
-
data["msg_rates"][msg_type][(base, delta)].append(
|
|
|
|
|
288 |
|
289 |
-
if (base, delta) not in data["perplexities"][msg_type]:
|
290 |
-
data["perplexities"][msg_type][(base, delta)] = []
|
291 |
-
data["perplexities"][msg_type][(base, delta)].append(
|
|
|
|
|
292 |
|
293 |
bases = {
|
294 |
"perplexities": {
|
@@ -310,6 +315,16 @@ def process_results(results, save_dir):
|
|
310 |
"readable": [],
|
311 |
},
|
312 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
values = {
|
314 |
"perplexities": {
|
315 |
"random": [],
|
@@ -322,6 +337,7 @@ def process_results(results, save_dir):
|
|
322 |
}
|
323 |
base_set = set()
|
324 |
delta_set = set()
|
|
|
325 |
for metric in data:
|
326 |
for msg_type in data[metric]:
|
327 |
for k in data[metric][msg_type]:
|
@@ -331,9 +347,12 @@ def process_results(results, save_dir):
|
|
331 |
|
332 |
bases[metric][msg_type].append(k[0])
|
333 |
deltas[metric][msg_type].append(k[1])
|
|
|
334 |
values[metric][msg_type].append(s / cnt)
|
335 |
base_set.add(k[0])
|
336 |
delta_set.add(k[1])
|
|
|
|
|
337 |
for metric in data:
|
338 |
for msg_type in data[metric]:
|
339 |
bases[metric][msg_type] = np.array(
|
@@ -350,7 +369,7 @@ def process_results(results, save_dir):
|
|
350 |
for metric in data:
|
351 |
for msg_type in data[metric]:
|
352 |
fig = plt.figure(dpi=300)
|
353 |
-
s = lambda x: 3.0 + x * (
|
354 |
plt.scatter(
|
355 |
bases[metric][msg_type],
|
356 |
deltas[metric][msg_type],
|
@@ -365,84 +384,92 @@ def process_results(results, save_dir):
|
|
365 |
os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
|
366 |
for metric in data:
|
367 |
for msg_type in data[metric]:
|
|
|
368 |
for base_value in base_set:
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
plt.
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
384 |
|
385 |
-
os.makedirs(os.path.join(save_dir, "
|
386 |
for metric in data:
|
387 |
for msg_type in data[metric]:
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
|
406 |
def main(args):
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
|
|
412 |
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
with open(args.results_load_file, "r") as f:
|
443 |
results = json.load(f)
|
444 |
-
else:
|
445 |
-
results = get_results(args, prompts, msgs)
|
446 |
|
447 |
if args.results_save_file:
|
448 |
os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)
|
|
|
43 |
|
44 |
while True:
|
45 |
readable_msg = next(iterator)["text"]
|
46 |
+
try:
|
47 |
readable_msg[:length].encode("ascii")
|
48 |
break
|
49 |
except Exception as e:
|
|
|
280 |
base = r["base"]
|
281 |
delta = r["delta"]
|
282 |
msg_rate = r["msg_rate"]
|
283 |
+
msg_len = r["msg_len"]
|
284 |
perplexity = r["perplexity"]
|
285 |
|
286 |
+
if (base, delta, msg_len) not in data["msg_rates"][msg_type]:
|
287 |
+
data["msg_rates"][msg_type][(base, delta, msg_len)] = []
|
288 |
+
data["msg_rates"][msg_type][(base, delta, msg_len)].append(
|
289 |
+
msg_rate
|
290 |
+
)
|
291 |
|
292 |
+
if (base, delta, msg_len) not in data["perplexities"][msg_type]:
|
293 |
+
data["perplexities"][msg_type][(base, delta, msg_len)] = []
|
294 |
+
data["perplexities"][msg_type][(base, delta, msg_len)].append(
|
295 |
+
perplexity
|
296 |
+
)
|
297 |
|
298 |
bases = {
|
299 |
"perplexities": {
|
|
|
315 |
"readable": [],
|
316 |
},
|
317 |
}
|
318 |
+
msgs_lens = {
|
319 |
+
"perplexities": {
|
320 |
+
"random": [],
|
321 |
+
"readable": [],
|
322 |
+
},
|
323 |
+
"msg_rates": {
|
324 |
+
"random": [],
|
325 |
+
"readable": [],
|
326 |
+
},
|
327 |
+
}
|
328 |
values = {
|
329 |
"perplexities": {
|
330 |
"random": [],
|
|
|
337 |
}
|
338 |
base_set = set()
|
339 |
delta_set = set()
|
340 |
+
msgs_lens_set = set()
|
341 |
for metric in data:
|
342 |
for msg_type in data[metric]:
|
343 |
for k in data[metric][msg_type]:
|
|
|
347 |
|
348 |
bases[metric][msg_type].append(k[0])
|
349 |
deltas[metric][msg_type].append(k[1])
|
350 |
+
msgs_lens[metric][msg_type].append(k[2])
|
351 |
values[metric][msg_type].append(s / cnt)
|
352 |
base_set.add(k[0])
|
353 |
delta_set.add(k[1])
|
354 |
+
msgs_lens_set.add(k[2])
|
355 |
+
|
356 |
for metric in data:
|
357 |
for msg_type in data[metric]:
|
358 |
bases[metric][msg_type] = np.array(
|
|
|
369 |
for metric in data:
|
370 |
for msg_type in data[metric]:
|
371 |
fig = plt.figure(dpi=300)
|
372 |
+
s = lambda x: 3.0 + x * (30 if metric == "msg_rates" else 10)
|
373 |
plt.scatter(
|
374 |
bases[metric][msg_type],
|
375 |
deltas[metric][msg_type],
|
|
|
384 |
os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
|
385 |
for metric in data:
|
386 |
for msg_type in data[metric]:
|
387 |
+
fig = plt.figure(dpi=300)
|
388 |
for base_value in base_set:
|
389 |
+
deltas_avg = np.array(list(delta_set))
|
390 |
+
values_avg = np.zeros_like(deltas_avg)
|
391 |
+
for i in range(len(deltas_avg)):
|
392 |
+
mask = (deltas[metric][msg_type] == deltas_avg[i]) & (
|
393 |
+
bases[metric][msg_type] == base_value
|
394 |
+
)
|
395 |
+
values_avg[i] = np.mean(values[metric][msg_type][mask])
|
396 |
+
plt.plot(deltas_avg, values_avg, label=f"Base {base_value}")
|
397 |
+
|
398 |
+
plt.legend()
|
399 |
+
plt.savefig(
|
400 |
+
os.path.join(
|
401 |
+
save_dir,
|
402 |
+
f"delta_effect/{metric}_{msg_type}.pdf",
|
403 |
+
),
|
404 |
+
bbox_inches="tight",
|
405 |
+
)
|
406 |
+
plt.close(fig)
|
407 |
|
408 |
+
os.makedirs(os.path.join(save_dir, "msg_len_effect"), exist_ok=True)
|
409 |
for metric in data:
|
410 |
for msg_type in data[metric]:
|
411 |
+
fig = plt.figure(dpi=300)
|
412 |
+
for base_value in base_set:
|
413 |
+
msgs_lens_avg = np.array(list(msgs_lens_set))
|
414 |
+
values_avg = np.zeros_like(msgs_lens_avg)
|
415 |
+
for i in range(len(msgs_lens_avg)):
|
416 |
+
mask = (msgs_lens[metric][msg_type] == msgs_lens_avg[i]) & (
|
417 |
+
bases[metric][msg_type] == base_value
|
418 |
+
)
|
419 |
+
values_avg[i] = np.mean(values[metric][msg_type][mask])
|
420 |
+
|
421 |
+
plt.plot(msgs_lens_avg, values_avg, label=f"Base {base_value}")
|
422 |
+
|
423 |
+
plt.legend()
|
424 |
+
plt.savefig(
|
425 |
+
os.path.join(
|
426 |
+
save_dir,
|
427 |
+
f"msg_len_effect/{metric}_{msg_type}.pdf",
|
428 |
+
),
|
429 |
+
bbox_inches="tight",
|
430 |
+
)
|
431 |
+
plt.close(fig)
|
432 |
|
433 |
|
434 |
def main(args):
|
435 |
+
if not args.results_load_file:
|
436 |
+
prompts = load_prompts(
|
437 |
+
args.num_prompts,
|
438 |
+
args.prompt_size,
|
439 |
+
args.prompts_file if not args.overwrite else None,
|
440 |
+
)
|
441 |
|
442 |
+
msgs_lens = []
|
443 |
+
for i in np.linspace(
|
444 |
+
args.msgs_lengths[0],
|
445 |
+
args.msgs_lengths[1],
|
446 |
+
int(args.msgs_lengths[2]),
|
447 |
+
dtype=np.int32,
|
448 |
+
):
|
449 |
+
for _ in range(args.msgs_per_length):
|
450 |
+
msgs_lens.append(i)
|
451 |
+
|
452 |
+
msgs = load_msgs(
|
453 |
+
msgs_lens,
|
454 |
+
args.msgs_file if not args.overwrite else None,
|
455 |
+
)
|
456 |
|
457 |
+
if args.msgs_file:
|
458 |
+
if not os.path.isfile(args.msgs_file) or args.overwrite:
|
459 |
+
os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
|
460 |
+
with open(args.msgs_file, "w") as f:
|
461 |
+
json.dump(msgs, f)
|
462 |
+
print(f"Saved messages to {args.msgs_file}")
|
463 |
+
if args.prompts_file:
|
464 |
+
if not os.path.isfile(args.prompts_file) or args.overwrite:
|
465 |
+
os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
|
466 |
+
with open(args.prompts_file, "w") as f:
|
467 |
+
json.dump(prompts, f)
|
468 |
+
print(f"Saved prompts to {args.prompts_file}")
|
469 |
+
results = get_results(args, prompts, msgs)
|
470 |
+
else:
|
471 |
with open(args.results_load_file, "r") as f:
|
472 |
results = json.load(f)
|
|
|
|
|
473 |
|
474 |
if args.results_save_file:
|
475 |
os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)
|