Spaces:
Runtime error
Runtime error
from data_provider.context_gen import * | |
def parse_args(argv=None):
    """Parse command-line options for the statistics / visualisation entry points.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests drive the parser without touching sys.argv.

    Returns:
        argparse.Namespace with the attributes ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")
    # Script arguments
    parser.add_argument('--name', default='none', type=str,
                        help='experiment name; outputs are written under results/<name>/')
    parser.add_argument('--seed', default=0, type=int,
                        help='random seed; 0 (the default) skips seeding')
    parser.add_argument('--epochs', default=100, type=int,
                        help='epochs forwarded to visualize_mol_frequency')
    parser.add_argument('--chunk_size', default=100, type=int,
                        help='number of consecutive values averaged into one plotted bar')
    parser.add_argument('--rxn_num', default=50000, type=int,
                        help='reaction count forwarded to visualize_mol_frequency')
    parser.add_argument('--k', default=4, type=int,
                        help='k forwarded to visualize_mol_frequency')
    parser.add_argument('--root', default='data/pretrain_data', type=str,
                        help='data root passed to Reaction_Cluster')
    return parser.parse_args(argv)
def pad_shorter_array(arr1, arr2):
    """Zero-pad the shorter of two arrays along axis 0 so both have equal length.

    Args:
        arr1: First array-like (converted with ``np.asarray``).
        arr2: Second array-like (converted with ``np.asarray``).

    Returns:
        ``(arr1, arr2)`` where the one with fewer rows has been padded at the
        end of its first axis with zeros. If both already have the same
        length, they are returned as-is (after ``np.asarray`` conversion).
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    len1, len2 = arr1.shape[0], arr2.shape[0]

    def _pad_to(arr, length):
        # Pad only the first axis; a bare (0, n) pad width would be applied to
        # every axis of an n-D array and leave the two shapes mismatched.
        widths = [(0, length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1)
        return np.pad(arr, widths, 'constant')

    if len1 > len2:
        arr2 = _pad_to(arr2, len1)
    elif len2 > len1:
        arr1 = _pad_to(arr1, len2)
    return arr1, arr2
def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    """Plot chunk-averaged values as a descending bar chart and save it.

    ``values`` is truncated to a whole number of ``chunk_size``-long chunks,
    each chunk is replaced by its mean, and the means are sorted in
    descending order before being drawn as bars.

    Args:
        values: 1-D array-like of per-item values.
        target_path: Output file path; missing parent directories are created.
        x_lim: Optional ``(left, right)`` passed to ``plt.xlim``.
        y_lim: Optional ``(bottom, top)`` passed to ``plt.ylim``.
        chunk_size: Number of consecutive values averaged into one bar.
        color: Bar colour.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    import os

    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    # Drop the incomplete trailing chunk so the reshape is exact.
    values = np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1)
    values = np.sort(values)[::-1]

    # savefig does not create missing directories.
    parent = os.path.dirname(target_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    fig = plt.figure(figsize=(10, 4), dpi=100)
    try:
        x = np.arange(len(values))
        plt.bar(x, values, color=color)
        # Tick labels are in item units; bar positions are in chunk units.
        current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((current_values / chunk_size).astype(int), current_values)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim is not None:
            plt.xlim(*x_lim)
        if y_lim is not None:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=12)
        plt.tight_layout(pad=0.5)
        plt.savefig(target_path)
        print(f'Figure saved to {target_path}')
    finally:
        # plt.clf() only cleared the figure and left it registered with pyplot;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)
def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100):
    """Overlay the chunk-averaged, descending distributions of two series.

    Both inputs are zero-padded to a common length, truncated to a whole
    number of ``chunk_size`` chunks, averaged per chunk and sorted in
    descending order (each series independently), then drawn as two
    semi-transparent bar series on one axis and saved.

    Args:
        list1: First 1-D array-like (drawn first, labelled ``labels[0]``).
        list2: Second 1-D array-like (drawn second, labelled ``labels[1]``).
        target_path: Output file path; missing parent directories are created.
        x_lim: Optional ``(left, right)`` passed to ``plt.xlim``.
        y_lim: Optional ``(bottom, top)`` passed to ``plt.ylim``.
        labels: Legend labels for ``list1`` and ``list2``.
        colors: Bar colours for ``list1`` and ``list2``.
        chunk_size: Number of consecutive values averaged into one bar.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    import os

    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    # Derive the chunk count from the *padded* length. Computing it from
    # len(list1) before padding truncated a longer list2 back to list1's
    # length, silently discarding its tail.
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
        for values in (list1, list2)
    ]

    # savefig does not create missing directories.
    parent = os.path.dirname(target_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    fig = plt.figure(figsize=(10, 6), dpi=100)
    try:
        x = np.arange(len(values1))
        plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
        plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
        # Tick labels are in item units; bar positions are in chunk units.
        current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((current_values / chunk_size).astype(int), current_values)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim is not None:
            plt.xlim(*x_lim)
        if y_lim is not None:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.tight_layout(pad=0.5)
        plt.legend(fontsize=24, loc='upper right')
        plt.savefig(target_path)
        print(f'Figure saved to {target_path}')
    finally:
        # plt.clf() only cleared the figure and left it registered with pyplot;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def statistics(args):
    """Print summary statistics of the reaction/molecule data under ``args.root``.

    Reports:
      * the number of reactions in ``cluster.reaction_data``;
      * the number of distinct molecules appearing under the REACTANT,
        CATALYST, SOLVENT or PRODUCT keys of those reactions;
      * how many ``cluster.property_data`` entries carry an ``'abstract'`` or a
        ``'property'`` record, split into experimental and computed properties.

    Percentages for abstracts/properties use the distinct-molecule count as the
    denominator (presumably property_data has one entry per molecule — TODO
    confirm). "items per molecule" is likewise averaged over all distinct
    molecules, not only those with properties. Ratios with a zero denominator
    are printed as 0.00 instead of raising ZeroDivisionError.

    Args:
        args: Namespace from ``parse_args``; reads ``seed`` and ``root``.
    """
    if args.seed:  # seed 0 (the CLI default) means "do not seed"
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)
    rxn_num = len(cluster.reaction_data)

    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[key])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    calculated_property_num = 0
    experimental_property_num = 0
    # Running totals of property items (later divided by mol_num).
    total_calculated_property_len = 0
    total_experimental_property_len = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' not in mol_dict:
            continue
        property_num += 1
        props = mol_dict['property']
        if 'Experimental Properties' in props:
            experimental_property_num += 1
            total_experimental_property_len += len(props['Experimental Properties'])
        if 'Computed Properties' in props:
            calculated_property_num += 1
            total_calculated_property_len += len(props['Computed Properties'])

    abstract_pct = _safe_ratio(abstract_num, mol_num) * 100
    property_pct = _safe_ratio(property_num, mol_num) * 100
    experimental_pct = _safe_ratio(experimental_property_num, property_num) * 100
    calculated_pct = _safe_ratio(calculated_property_num, property_num) * 100
    experimental_per_mol = _safe_ratio(total_experimental_property_len, mol_num)
    calculated_per_mol = _safe_ratio(total_calculated_property_len, mol_num)

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({abstract_pct:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({property_pct:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({experimental_pct:.2f}%), {experimental_per_mol:.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({calculated_pct:.2f}%), {calculated_per_mol:.2f} items per molecule')
def visualize(args):
    """Plot molecule and reaction distributions, plus their random-baseline counterparts.

    Calls ``cluster.visualize_mol_distribution()`` directly and through
    ``cluster._randomly`` (presumably the same computation under a randomised
    schedule — TODO confirm in Reaction_Cluster). It then writes four single
    distribution plots and two comparison plots to ``results/<args.name>/``.

    Args:
        args: Namespace from ``parse_args``; reads ``seed``, ``root`` and ``name``.
    """
    import os

    if args.seed:  # seed 0 (the CLI default) means "do not seed"
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )

    fig_root = f'results/{args.name}/'
    # savefig does not create missing directories.
    os.makedirs(fig_root, exist_ok=True)
    plot_distribution(prob_values, fig_root + 'mol_distribution.pdf')
    plot_distribution(rxn_weights, fig_root + 'rxns_distribution.pdf')
    plot_distribution(rand_prob_values, fig_root + 'mol_distribution_random.pdf')
    plot_distribution(rand_rxn_weights, fig_root + 'rxns_distribution_random.pdf')
    plot_compare_distribution(prob_values, rand_prob_values, fig_root + 'Compare_mol.pdf', y_lim=(-0.5, 15.5))
    plot_compare_distribution(rxn_weights, rand_rxn_weights, fig_root + 'Compare_rxns.pdf')
def visualize_frequency(args):
    """Plot molecule-frequency distributions before and after adjustment, with caching.

    Frequencies come from ``cluster.visualize_mol_frequency``, called directly
    ("After Adjustment") and through ``cluster._randomly`` ("Before
    Adjustment"; presumably a randomised baseline — TODO confirm). The four
    returned arrays are cached in ``results/<name>/freq_<suffix>.npy``. The
    cache key covers epochs, rxn_num and k only, not seed or root.

    Args:
        args: Namespace from ``parse_args``; reads ``seed``, ``name``,
            ``epochs``, ``rxn_num``, ``k``, ``root`` and ``chunk_size``.
    """
    import os

    if args.seed:  # seed 0 (the CLI default) means "do not seed"
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    # np.save and savefig do not create missing directories.
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = os.path.join(fig_root, f'freq_{name_suffix}.npy')

    if os.path.exists(cache_path):
        # The cache is a pickled object array written below; allow_pickle
        # unpickles it, so only load cache files this script produced.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency,
            rxn_num=args.rxn_num, k=args.k, epochs=args.epochs,
        )
        # Fill a 1-D object array element by element: np.array([...], dtype=object)
        # turns equal-length inputs into a 2-D array instead of four entries.
        results = (mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq)
        cache = np.empty(len(results), dtype=object)
        for i, arr in enumerate(results):
            cache[i] = arr
        np.save(cache_path, cache, allow_pickle=True)

    color1 = '#FA7F6F'  # "Before Adjustment" series in the comparison plot
    color2 = '#80AFBF'  # single-series plots and "After Adjustment"
    x_lim = (-50000 // args.chunk_size, 1200000 // args.chunk_size)
    y_lim = (-2, 62)
    plot_distribution(mol_freq, fig_root + f'mol_frequency_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_distribution(rand_mol_freq, fig_root + f'mol_frequency_random_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root + f'Compare_mol_{name_suffix}.pdf', y_lim=y_lim, labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
    # plot_compare_distribution(rxn_freq, rand_rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)
if __name__ == '__main__':
    # Parse the CLI once and echo it so each run's configuration is logged.
    args = parse_args()
    print(args, flush=True)
    # Other available entry points (swap in as needed):
    #   statistics(args)
    #   visualize(args)
    visualize_frequency(args)