from data_provider.context_gen import *

# NOTE(review): argparse, os, np, plt, set_random_seed and Reaction_Cluster are
# used below but never imported explicitly in this file; they appear to be
# re-exported by the star import above -- confirm before trimming it.


def parse_args():
    """Parse the command-line options shared by every entry point of this script."""
    parser = argparse.ArgumentParser(description="A simple argument parser")
    # Script arguments
    parser.add_argument('--name', default='none', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--chunk_size', default=100, type=int)
    parser.add_argument('--rxn_num', default=50000, type=int)
    parser.add_argument('--k', default=4, type=int)
    parser.add_argument('--root', default='data/pretrain_data', type=str)
    return parser.parse_args()


def pad_shorter_array(arr1, arr2):
    """Right-pad the shorter of two 1-D arrays with zeros so both have equal length.

    Inputs go through ``np.asarray`` so plain sequences are accepted as well.

    Returns:
        The (possibly padded) pair ``(arr1, arr2)`` in the original order.
    """
    arr1, arr2 = np.asarray(arr1), np.asarray(arr2)
    len1, len2 = arr1.shape[0], arr2.shape[0]
    if len1 > len2:
        arr2 = np.pad(arr2, (0, len1 - len2), 'constant')
    elif len2 > len1:
        arr1 = np.pad(arr1, (0, len2 - len1), 'constant')
    return arr1, arr2


def _chunk_means_desc(values, chunk_size):
    """Average ``values`` over consecutive ``chunk_size`` chunks, sorted descending.

    Trailing elements that do not fill a whole chunk are dropped.
    """
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    chunk_means = np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1)
    return np.sort(chunk_means)[::-1]


def _set_chunk_xticks(chunk_size):
    """Label the x axis in original item indices even though each bar is one chunk."""
    tick_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
    plt.xticks((tick_values / chunk_size).astype(int), tick_values)


def _apply_limits(x_lim, y_lim):
    """Apply optional (low, high) axis limits to the current figure."""
    if x_lim:
        plt.xlim(*x_lim)
    if y_lim:
        plt.ylim(*y_lim)


def _save_and_close(target_path):
    """Save the current figure to ``target_path`` and release it."""
    directory = os.path.dirname(target_path)
    if directory:
        # savefig does not create missing directories.
        os.makedirs(directory, exist_ok=True)
    plt.savefig(target_path)
    print(f'Figure saved to {target_path}')
    # plt.close (rather than plt.clf) frees the figure; clf left every figure
    # created by plt.figure open, accumulating memory across calls.
    plt.close()


def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    """Plot chunk-averaged ``values`` as a descending bar chart and save it.

    Args:
        values: 1-D sequence of per-item values (e.g. molecule frequencies).
        target_path: Output image path; its directory is created if missing.
        x_lim, y_lim: Optional (low, high) axis limits; x is in chunk units.
        chunk_size: Number of consecutive items averaged into one bar.
        color: Bar colour.
    """
    chunk_means = _chunk_means_desc(values, chunk_size)
    plt.figure(figsize=(10, 4), dpi=100)
    plt.bar(np.arange(len(chunk_means)), chunk_means, color=color)
    _set_chunk_xticks(chunk_size)
    plt.ylabel('Molecule Frequency', fontsize=20)
    _apply_limits(x_lim, y_lim)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout(pad=0.5)
    _save_and_close(target_path)


def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None,
                              labels=('Random', 'Ours'), colors=('blue', 'orange'),
                              chunk_size=100):
    """Overlay the chunk-averaged, descending distributions of two value lists.

    The shorter input is zero-padded first, so both series contribute the same
    number of chunks (previously the chunk count came from the unpadded
    ``list1`` and silently truncated a longer ``list2``).

    Args:
        list1, list2: 1-D sequences drawn with ``labels[0]/colors[0]`` and
            ``labels[1]/colors[1]`` respectively.
        target_path: Output image path; its directory is created if missing.
        x_lim, y_lim: Optional (low, high) axis limits; x is in chunk units.
        labels: Legend labels for the two series.
        colors: Bar colours for the two series.
        chunk_size: Number of consecutive items averaged into one bar.
    """
    list1, list2 = pad_shorter_array(list1, list2)
    values1 = _chunk_means_desc(list1, chunk_size)
    values2 = _chunk_means_desc(list2, chunk_size)

    plt.figure(figsize=(10, 6), dpi=100)
    x = np.arange(len(values1))
    plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
    plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
    _set_chunk_xticks(chunk_size)
    plt.ylabel('Molecule Frequency', fontsize=20)
    _apply_limits(x_lim, y_lim)
    plt.tick_params(axis='both', which='major', labelsize=18)
    plt.tight_layout(pad=0.5)
    plt.legend(fontsize=24, loc='upper right')
    _save_and_close(target_path)


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def statistics(args):
    """Print dataset-level counts for the reaction and molecule-property data.

    Reports the number of reactions and the number of distinct molecules found
    in any REACTANT / CATALYST / SOLVENT / PRODUCT slot. It also reports how many
    ``cluster.property_data`` entries carry an 'abstract' or a 'property' block,
    split into experimental and computed properties. Ratios with a zero
    denominator print as 0.00 instead of raising ZeroDivisionError.
    """
    if args.seed:  # a seed of 0 (the default) leaves the RNG unseeded
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)
    rxn_num = len(cluster.reaction_data)

    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[key])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    experimental_property_num = 0
    calculated_property_num = 0
    # Running totals of property items; averaged over *all* molecules (mol_num)
    # when printed, not only over molecules that have properties.
    experimental_items = 0
    calculated_items = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' not in mol_dict:
            continue
        property_num += 1
        properties = mol_dict['property']
        if 'Experimental Properties' in properties:
            experimental_property_num += 1
            experimental_items += len(properties['Experimental Properties'])
        if 'Computed Properties' in properties:
            calculated_property_num += 1
            calculated_items += len(properties['Computed Properties'])

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({_ratio(abstract_num, mol_num)*100:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({_ratio(property_num, mol_num)*100:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}'
          f'({_ratio(experimental_property_num, property_num)*100:.2f}%), '
          f'{_ratio(experimental_items, mol_num):.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}'
          f'({_ratio(calculated_property_num, property_num)*100:.2f}%), '
          f'{_ratio(calculated_items, mol_num):.2f} items per molecule')


def visualize(args):
    """Plot molecule/reaction distributions next to their ``_randomly`` counterparts.

    NOTE(review): ``cluster._randomly(fn)`` presumably re-runs ``fn`` under a
    random-sampling baseline (the compare plots default to a 'Random' label) --
    confirm against Reaction_Cluster.
    """
    if args.seed:  # a seed of 0 (the default) leaves the RNG unseeded
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(cluster.visualize_mol_distribution)

    fig_root = f'results/{args.name}/'
    plot_distribution(prob_values, fig_root + 'mol_distribution.pdf')
    plot_distribution(rxn_weights, fig_root + 'rxns_distribution.pdf')
    plot_distribution(rand_prob_values, fig_root + 'mol_distribution_random.pdf')
    plot_distribution(rand_rxn_weights, fig_root + 'rxns_distribution_random.pdf')
    plot_compare_distribution(prob_values, rand_prob_values, fig_root + 'Compare_mol.pdf',
                              y_lim=(-0.5, 15.5))
    plot_compare_distribution(rxn_weights, rand_rxn_weights, fig_root + 'Compare_rxns.pdf')


def _pack_object_array(arrays):
    """Pack arrays into a 1-D object array.

    ``np.array(arrays, dtype=object)`` would stack equal-length inputs into a
    2-D object array instead of keeping one slot per array.
    """
    packed = np.empty(len(arrays), dtype=object)
    for index, array in enumerate(arrays):
        packed[index] = array
    return packed


def _load_or_compute_frequencies(args, cache_path):
    """Return ``(mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq)``, cached in ``cache_path``.

    On a cache miss, frequencies come from ``cluster.visualize_mol_frequency``,
    called once directly and once through ``cluster._randomly``. The result is
    saved to ``cache_path``, whose directory is created if needed.
    """
    if os.path.exists(cache_path):
        # allow_pickle is needed for the object array written below; only load
        # cache files this script produced itself (pickle can execute code).
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
        return mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq

    cluster = Reaction_Cluster(args.root)
    freq_kwargs = dict(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
    mol_freq, rxn_freq = cluster.visualize_mol_frequency(**freq_kwargs)
    rand_mol_freq, rand_rxn_freq = cluster._randomly(cluster.visualize_mol_frequency, **freq_kwargs)

    # Create the directory before saving so a long computation is not lost to
    # FileNotFoundError.
    os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
    np.save(cache_path,
            _pack_object_array([mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq]),
            allow_pickle=True)
    return mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq


def visualize_frequency(args):
    """Plot molecule sampling frequencies before/after adjustment.

    Frequencies are cached under ``results/<name>/freq_E<epochs>_Rxn<rxn_num>_K<k>.npy``
    so repeated runs with the same settings skip the computation.
    """
    if args.seed:  # a seed of 0 (the default) leaves the RNG unseeded
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = f'{fig_root}freq_{name_suffix}.npy'
    # Reaction frequencies are cached but currently not plotted.
    mol_freq, _rxn_freq, rand_mol_freq, _rand_rxn_freq = _load_or_compute_frequencies(args, cache_path)

    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    # x limits are in chunk units because each bar averages chunk_size items.
    x_lim = (-50000 // args.chunk_size, 1200000 // args.chunk_size)
    y_lim = (-2, 62)

    plot_distribution(mol_freq, fig_root + f'mol_frequency_{name_suffix}.pdf',
                      x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    plot_distribution(rand_mol_freq, fig_root + f'mol_frequency_random_{name_suffix}.pdf',
                      x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root + f'Compare_mol_{name_suffix}.pdf',
                              y_lim=y_lim, labels=['Before Adjustment', 'After Adjustment'],
                              colors=[color1, color2], chunk_size=args.chunk_size)


if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    # Alternative entry points: statistics(args), visualize(args)
    visualize_frequency(args)