tnk2908 commited on
Commit
e60c05d
1 Parent(s): 946ba2b

Update analysis.py

Browse files
Files changed (1) hide show
  1. analyse.py +102 -75
analyse.py CHANGED
@@ -43,7 +43,7 @@ def load_msgs(msg_lens: list[int], file: str | None = None):
43
 
44
  while True:
45
  readable_msg = next(iterator)["text"]
46
- try :
47
  readable_msg[:length].encode("ascii")
48
  break
49
  except Exception as e:
@@ -280,15 +280,20 @@ def process_results(results, save_dir):
280
  base = r["base"]
281
  delta = r["delta"]
282
  msg_rate = r["msg_rate"]
 
283
  perplexity = r["perplexity"]
284
 
285
- if (base, delta) not in data["msg_rates"][msg_type]:
286
- data["msg_rates"][msg_type][(base, delta)] = []
287
- data["msg_rates"][msg_type][(base, delta)].append(msg_rate)
 
 
288
 
289
- if (base, delta) not in data["perplexities"][msg_type]:
290
- data["perplexities"][msg_type][(base, delta)] = []
291
- data["perplexities"][msg_type][(base, delta)].append(perplexity)
 
 
292
 
293
  bases = {
294
  "perplexities": {
@@ -310,6 +315,16 @@ def process_results(results, save_dir):
310
  "readable": [],
311
  },
312
  }
 
 
 
 
 
 
 
 
 
 
313
  values = {
314
  "perplexities": {
315
  "random": [],
@@ -322,6 +337,7 @@ def process_results(results, save_dir):
322
  }
323
  base_set = set()
324
  delta_set = set()
 
325
  for metric in data:
326
  for msg_type in data[metric]:
327
  for k in data[metric][msg_type]:
@@ -331,9 +347,12 @@ def process_results(results, save_dir):
331
 
332
  bases[metric][msg_type].append(k[0])
333
  deltas[metric][msg_type].append(k[1])
 
334
  values[metric][msg_type].append(s / cnt)
335
  base_set.add(k[0])
336
  delta_set.add(k[1])
 
 
337
  for metric in data:
338
  for msg_type in data[metric]:
339
  bases[metric][msg_type] = np.array(
@@ -350,7 +369,7 @@ def process_results(results, save_dir):
350
  for metric in data:
351
  for msg_type in data[metric]:
352
  fig = plt.figure(dpi=300)
353
- s = lambda x: 3.0 + x * (3 if metric == "msg_rates" else 0.1)
354
  plt.scatter(
355
  bases[metric][msg_type],
356
  deltas[metric][msg_type],
@@ -365,84 +384,92 @@ def process_results(results, save_dir):
365
  os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
366
  for metric in data:
367
  for msg_type in data[metric]:
 
368
  for base_value in base_set:
369
- mask = bases[metric][msg_type] == base_value
370
- fig = plt.figure(dpi=300)
371
- s = lambda x: x / (1.0 if metric == "msg_rates" else 10.0)
372
- plt.plot(
373
- deltas[metric][msg_type][mask],
374
- values[metric][msg_type][mask],
375
- )
376
- plt.savefig(
377
- os.path.join(
378
- save_dir,
379
- f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf",
380
- ),
381
- bbox_inches="tight",
382
- )
383
- plt.close(fig)
 
 
 
384
 
385
- os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
386
  for metric in data:
387
  for msg_type in data[metric]:
388
- for delta_value in delta_set:
389
- mask = deltas[metric][msg_type] == delta_value
390
- fig = plt.figure(dpi=300)
391
- s = lambda x: x / (1.0 if metric == "msg_rates" else 10.0)
392
- plt.plot(
393
- bases[metric][msg_type][mask],
394
- values[metric][msg_type][mask],
395
- )
396
- plt.savefig(
397
- os.path.join(
398
- save_dir,
399
- f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf",
400
- ),
401
- bbox_inches="tight",
402
- )
403
- plt.close(fig)
 
 
 
 
 
404
 
405
 
406
  def main(args):
407
- prompts = load_prompts(
408
- args.num_prompts,
409
- args.prompt_size,
410
- args.prompts_file if not args.overwrite else None,
411
- )
 
412
 
413
- msgs_lens = []
414
- for i in np.linspace(
415
- args.msgs_lengths[0],
416
- args.msgs_lengths[1],
417
- int(args.msgs_lengths[2]),
418
- dtype=np.int32,
419
- ):
420
- for _ in range(args.msgs_per_length):
421
- msgs_lens.append(i)
422
-
423
- msgs = load_msgs(
424
- msgs_lens,
425
- args.msgs_file if not args.overwrite else None,
426
- )
427
 
428
- if args.msgs_file:
429
- if not os.path.isfile(args.msgs_file) or args.overwrite:
430
- os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
431
- with open(args.msgs_file, "w") as f:
432
- json.dump(msgs, f)
433
- print(f"Saved messages to {args.msgs_file}")
434
- if args.prompts_file:
435
- if not os.path.isfile(args.prompts_file) or args.overwrite:
436
- os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
437
- with open(args.prompts_file, "w") as f:
438
- json.dump(prompts, f)
439
- print(f"Saved prompts to {args.prompts_file}")
440
-
441
- if args.results_load_file:
442
  with open(args.results_load_file, "r") as f:
443
  results = json.load(f)
444
- else:
445
- results = get_results(args, prompts, msgs)
446
 
447
  if args.results_save_file:
448
  os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)
 
43
 
44
  while True:
45
  readable_msg = next(iterator)["text"]
46
+ try:
47
  readable_msg[:length].encode("ascii")
48
  break
49
  except Exception as e:
 
280
  base = r["base"]
281
  delta = r["delta"]
282
  msg_rate = r["msg_rate"]
283
+ msg_len = r["msg_len"]
284
  perplexity = r["perplexity"]
285
 
286
+ if (base, delta, msg_len) not in data["msg_rates"][msg_type]:
287
+ data["msg_rates"][msg_type][(base, delta, msg_len)] = []
288
+ data["msg_rates"][msg_type][(base, delta, msg_len)].append(
289
+ msg_rate
290
+ )
291
 
292
+ if (base, delta, msg_len) not in data["perplexities"][msg_type]:
293
+ data["perplexities"][msg_type][(base, delta, msg_len)] = []
294
+ data["perplexities"][msg_type][(base, delta, msg_len)].append(
295
+ perplexity
296
+ )
297
 
298
  bases = {
299
  "perplexities": {
 
315
  "readable": [],
316
  },
317
  }
318
+ msgs_lens = {
319
+ "perplexities": {
320
+ "random": [],
321
+ "readable": [],
322
+ },
323
+ "msg_rates": {
324
+ "random": [],
325
+ "readable": [],
326
+ },
327
+ }
328
  values = {
329
  "perplexities": {
330
  "random": [],
 
337
  }
338
  base_set = set()
339
  delta_set = set()
340
+ msgs_lens_set = set()
341
  for metric in data:
342
  for msg_type in data[metric]:
343
  for k in data[metric][msg_type]:
 
347
 
348
  bases[metric][msg_type].append(k[0])
349
  deltas[metric][msg_type].append(k[1])
350
+ msgs_lens[metric][msg_type].append(k[2])
351
  values[metric][msg_type].append(s / cnt)
352
  base_set.add(k[0])
353
  delta_set.add(k[1])
354
+ msgs_lens_set.add(k[2])
355
+
356
  for metric in data:
357
  for msg_type in data[metric]:
358
  bases[metric][msg_type] = np.array(
 
369
  for metric in data:
370
  for msg_type in data[metric]:
371
  fig = plt.figure(dpi=300)
372
+ s = lambda x: 3.0 + x * (30 if metric == "msg_rates" else 10)
373
  plt.scatter(
374
  bases[metric][msg_type],
375
  deltas[metric][msg_type],
 
384
  os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
385
  for metric in data:
386
  for msg_type in data[metric]:
387
+ fig = plt.figure(dpi=300)
388
  for base_value in base_set:
389
+ deltas_avg = np.array(list(delta_set))
390
+ values_avg = np.zeros_like(deltas_avg)
391
+ for i in range(len(deltas_avg)):
392
+ mask = (deltas[metric][msg_type] == deltas_avg[i]) & (
393
+ bases[metric][msg_type] == base_value
394
+ )
395
+ values_avg[i] = np.mean(values[metric][msg_type][mask])
396
+ plt.plot(deltas_avg, values_avg, label=f"Base {base_value}")
397
+
398
+ plt.legend()
399
+ plt.savefig(
400
+ os.path.join(
401
+ save_dir,
402
+ f"delta_effect/{metric}_{msg_type}.pdf",
403
+ ),
404
+ bbox_inches="tight",
405
+ )
406
+ plt.close(fig)
407
 
408
+ os.makedirs(os.path.join(save_dir, "msg_len_effect"), exist_ok=True)
409
  for metric in data:
410
  for msg_type in data[metric]:
411
+ fig = plt.figure(dpi=300)
412
+ for base_value in base_set:
413
+ msgs_lens_avg = np.array(list(msgs_lens_set))
414
+ values_avg = np.zeros_like(msgs_lens_avg)
415
+ for i in range(len(msgs_lens_avg)):
416
+ mask = (msgs_lens[metric][msg_type] == msgs_lens_avg[i]) & (
417
+ bases[metric][msg_type] == base_value
418
+ )
419
+ values_avg[i] = np.mean(values[metric][msg_type][mask])
420
+
421
+ plt.plot(msgs_lens_avg, values_avg, label=f"Base {base_value}")
422
+
423
+ plt.legend()
424
+ plt.savefig(
425
+ os.path.join(
426
+ save_dir,
427
+ f"msg_len_effect/{metric}_{msg_type}.pdf",
428
+ ),
429
+ bbox_inches="tight",
430
+ )
431
+ plt.close(fig)
432
 
433
 
434
  def main(args):
435
+ if not args.results_load_file:
436
+ prompts = load_prompts(
437
+ args.num_prompts,
438
+ args.prompt_size,
439
+ args.prompts_file if not args.overwrite else None,
440
+ )
441
 
442
+ msgs_lens = []
443
+ for i in np.linspace(
444
+ args.msgs_lengths[0],
445
+ args.msgs_lengths[1],
446
+ int(args.msgs_lengths[2]),
447
+ dtype=np.int32,
448
+ ):
449
+ for _ in range(args.msgs_per_length):
450
+ msgs_lens.append(i)
451
+
452
+ msgs = load_msgs(
453
+ msgs_lens,
454
+ args.msgs_file if not args.overwrite else None,
455
+ )
456
 
457
+ if args.msgs_file:
458
+ if not os.path.isfile(args.msgs_file) or args.overwrite:
459
+ os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
460
+ with open(args.msgs_file, "w") as f:
461
+ json.dump(msgs, f)
462
+ print(f"Saved messages to {args.msgs_file}")
463
+ if args.prompts_file:
464
+ if not os.path.isfile(args.prompts_file) or args.overwrite:
465
+ os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
466
+ with open(args.prompts_file, "w") as f:
467
+ json.dump(prompts, f)
468
+ print(f"Saved prompts to {args.prompts_file}")
469
+ results = get_results(args, prompts, msgs)
470
+ else:
471
  with open(args.results_load_file, "r") as f:
472
  results = json.load(f)
 
 
473
 
474
  if args.results_save_file:
475
  os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)