{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PSET 1: Bottom-Up Synthesis\n", "\n", "I follow Algorithm 1 in the BUSTLE paper:\n", "\n", "> Odena, A. *et al.* BUSTLE: Bottom-Up Program Synthesis Through Learning-Guided Exploration. in *9th International Conference on Learning Representations*; 2021 May 3-7; Austria.\n", "\n", "First, I import the required libraries." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import itertools\n", "\n", "# argument parser for command line arguments\n", "import argparse\n", "\n", "# import arithmetic module\n", "from arithmetic import *\n", "from abstract_syntax_tree import OperatorNode\n", "from examples import example_set, check_examples\n", "from synthesizer import extract_constants, observationally_equivalent, check_program\n", "import config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, I define variables as proxies for command-line arguments provided to the synthesizer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "domain = \"arithmetic\"\n", "examples_key = \"addition\"\n", "examples = example_set[examples_key]\n", "max_weight = 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I define a function to extract constants from examples." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# initialize program bank\n", "program_bank = extract_constants(examples)\n", "program_bank_str = [p.str() for p in program_bank]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, I define the bottom-up synthesis algorithm." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(x0 + x1)\n" ] } ], "source": [ "# define operators\n", "operators = arithmetic_operators\n", "\n", "# iterate over each level\n", "for weight in range(2, max_weight):\n", "\n", " for op in operators:\n", "\n", " # get all possible combinations of primitives in program bank\n", " combinations = itertools.combinations(program_bank, op.arity)\n", "\n", " # iterate over each combination\n", " for combination in combinations:\n", "\n", " # get type signature\n", " type_signature = [p.type for p in combination]\n", "\n", " # check if type signature matches operator\n", " if type_signature != op.arg_types:\n", " continue\n", "\n", " # check that sum of weights of arguments <= w\n", " if sum([p.weight for p in combination]) > weight:\n", " continue\n", "\n", " # create new program\n", " program = OperatorNode(op, combination)\n", "\n", " # check if program is in program bank using string representation\n", " if program.str() in program_bank_str:\n", " continue\n", " \n", " # check if program is observationally equivalent to any program in program bank\n", " if any([observationally_equivalent(program, p, examples) for p in program_bank]):\n", " continue\n", "\n", " # add program to program bank\n", " program_bank.append(program)\n", " program_bank_str.append(program.str())\n", "\n", " # check if program passes all examples\n", " if check_program(program, examples):\n", " # return(program)\n", " print(program.str())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }