{ "cells": [ { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import tensorflow as tf\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "\n", "model_path = \"kia_apartment_keras_model.keras\"\n", "model = tf.keras.models.load_model(model_path)\n", "\n", "df_bfs_data = pd.read_csv('bfs_municipality_and_tax_data.csv', sep=',', encoding='utf-8')\n", "df_bfs_data['tax_income'] = df_bfs_data['tax_income'].str.replace(\"'\", \"\").astype(float)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "locations = {\n", " \"Zürich\": 261,\n", " \"Kloten\": 62,\n", " \"Uster\": 198,\n", " \"Illnau-Effretikon\": 296,\n", " \"Feuerthalen\": 27,\n", " \"Pfäffikon\": 177,\n", " \"Ottenbach\": 11,\n", " \"Dübendorf\": 191,\n", " \"Richterswil\": 138,\n", " \"Maur\": 195,\n", " \"Embrach\": 56,\n", " \"Bülach\": 53,\n", " \"Winterthur\": 230,\n", " \"Oetwil am See\": 157,\n", " \"Russikon\": 178,\n", " \"Obfelden\": 10,\n", " \"Wald (ZH)\": 120,\n", " \"Niederweningen\": 91,\n", " \"Dällikon\": 84,\n", " \"Buchs (ZH)\": 83,\n", " \"Rüti (ZH)\": 118,\n", " \"Hittnau\": 173,\n", " \"Bassersdorf\": 52,\n", " \"Glattfelden\": 58,\n", " \"Opfikon\": 66,\n", " \"Hinwil\": 117,\n", " \"Regensberg\": 95,\n", " \"Langnau am Albis\": 136,\n", " \"Dietikon\": 243,\n", " \"Erlenbach (ZH)\": 151,\n", " \"Kappel am Albis\": 6,\n", " \"Stäfa\": 158,\n", " \"Zell (ZH)\": 231,\n", " \"Turbenthal\": 228,\n", " \"Oberglatt\": 92,\n", " \"Winkel\": 72,\n", " \"Volketswil\": 199,\n", " \"Kilchberg (ZH)\": 135,\n", " \"Wetzikon (ZH)\": 121,\n", " \"Zumikon\": 160,\n", " \"Weisslingen\": 180,\n", " \"Elsau\": 219,\n", " \"Hettlingen\": 221,\n", " \"Rüschlikon\": 139,\n", " \"Stallikon\": 13,\n", " \"Dielsdorf\": 86,\n", " \"Wallisellen\": 69,\n", " \"Dietlikon\": 54,\n", " \"Meilen\": 156,\n", " \"Wangen-Brüttisellen\": 200,\n", " \"Flaach\": 28,\n", " \"Regensdorf\": 96,\n", " \"Niederhasli\": 90,\n", " \"Bauma\": 297,\n", " \"Aesch (ZH)\": 241,\n", " \"Schlieren\": 247,\n", " \"Dürnten\": 113,\n", " \"Unterengstringen\": 249,\n", " \"Gossau (ZH)\": 115,\n", " \"Oberengstringen\": 245,\n", " \"Schleinikon\": 98,\n", " \"Aeugst am Albis\": 1,\n", " \"Rheinau\": 38,\n", " \"Höri\": 60,\n", " \"Rickenbach (ZH)\": 225,\n", " \"Rafz\": 67,\n", " \"Adliswil\": 131,\n", " \"Zollikon\": 161,\n", " \"Urdorf\": 250,\n", " \"Hombrechtikon\": 153,\n", " \"Birmensdorf (ZH)\": 242,\n", " \"Fehraltorf\": 172,\n", " \"Weiach\": 102,\n", " \"Männedorf\": 155,\n", " \"Küsnacht (ZH)\": 154,\n", " \"Hausen am Albis\": 4,\n", " \"Hochfelden\": 59,\n", " \"Fällanden\": 193,\n", " \"Greifensee\": 194,\n", " \"Mönchaltorf\": 196,\n", " \"Dägerlen\": 214,\n", " \"Thalheim an der Thur\": 39,\n", " \"Uetikon am See\": 159,\n", " \"Seuzach\": 227,\n", " \"Uitikon\": 248,\n", " \"Affoltern am Albis\": 2,\n", " \"Geroldswil\": 244,\n", " \"Niederglatt\": 89,\n", " \"Thalwil\": 141,\n", " \"Rorbas\": 68,\n", " \"Pfungen\": 224,\n", " \"Weiningen (ZH)\": 251,\n", " \"Bubikon\": 112,\n", " \"Neftenbach\": 223,\n", " \"Mettmenstetten\": 9,\n", " \"Otelfingen\": 94,\n", " \"Flurlingen\": 29,\n", " \"Stadel\": 100,\n", " \"Grüningen\": 116,\n", " \"Henggart\": 31,\n", " \"Dachsen\": 25,\n", " \"Bonstetten\": 3,\n", " \"Bachenbülach\": 51,\n", " \"Horgen\": 295\n", "}" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Define the core prediction function\n", "def predict_apartment(rooms, area, town):\n", " bfs_number = locations[town]\n", " df = df_bfs_data[df_bfs_data['bfs_number']==bfs_number]\n", " \n", " if len(df) != 1: # if there are more than two records with the same bfs_number reutrn -1\n", " return -1\n", " \n", " input_data = np.array([rooms, area, df['pop'].iloc[0], df['pop_dens'].iloc[0], df['frg_pct'].iloc[0], df['emp'].iloc[0], df['tax_income'].iloc[0]])\n", " input_data = input_data.reshape(1, 7)\n", " prediction = model.predict(input_data)\n", " return float(np.round(prediction[0][0], 0))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7863\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "