Gabriel commited on
Commit
3810c45
1 Parent(s): a3822fa

fixed bug with utils for loading db

Browse files
Files changed (3) hide show
  1. app.py +5 -6
  2. helper/utils.py +41 -34
  3. tabs/htr_tool.py +4 -2
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
 
3
  import gradio as gr
4
 
@@ -10,8 +10,7 @@ from tabs.help_tab import help_tab
10
  from tabs.htr_tool import htr_tool_tab
11
  from tabs.stepwise_htr_tool import stepwise_htr_tool_tab
12
 
13
- handler = TrafficDataHandler()
14
-
15
  VERSION = "Demo version 0.0.2"
16
 
17
  with gr.Blocks(title="Riksarkivet", theme=theme, css=css) as demo:
@@ -36,9 +35,9 @@ with gr.Blocks(title="Riksarkivet", theme=theme, css=css) as demo:
36
  with gr.Tab("About"):
37
  about_tab.render()
38
 
39
- SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)
40
- if SECRET_KEY:
41
- demo.load(handler.onload_store_metric_data)
42
 
43
 
44
  demo.queue(concurrency_count=2, max_size=2)
 
1
+ import uuid
2
 
3
  import gradio as gr
4
 
 
10
  from tabs.htr_tool import htr_tool_tab
11
  from tabs.stepwise_htr_tool import stepwise_htr_tool_tab
12
 
13
+ session_uuid = str(uuid.uuid1())
 
14
  VERSION = "Demo version 0.0.2"
15
 
16
  with gr.Blocks(title="Riksarkivet", theme=theme, css=css) as demo:
 
35
  with gr.Tab("About"):
36
  about_tab.render()
37
 
38
+ # SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)
39
+ # if SECRET_KEY:
40
+ demo.load(fn=TrafficDataHandler.onload_store_metric_data, inputs=None, outputs=None)
41
 
42
 
43
  demo.queue(concurrency_count=2, max_size=2)
helper/utils.py CHANGED
@@ -18,44 +18,48 @@ class TrafficDataHandler:
18
  _TOKEN = os.environ.get("HUB_TOKEN")
19
  _TZ = "Europe/Stockholm"
20
  _INTERVAL_MIN_UPDATE = 30
 
 
 
 
21
 
22
- def __init__(self, dataset_repo="Riksarkivet/traffic_demo_data"):
23
- self._repo = huggingface_hub.Repository(
24
- local_dir="data", repo_type="dataset", clone_from=dataset_repo, use_auth_token=self._TOKEN
25
- )
26
- self._pull_repo_data()
27
- self._setup_database()
28
-
29
- def _pull_repo_data(self):
30
- self._repo.git_pull()
31
- shutil.copyfile(self._DB_TEMP_PATH, self._DB_FILE_PATH)
32
 
33
- def _hash_ip(self, ip_address):
 
34
  return hashlib.sha256(ip_address.encode()).hexdigest()
35
 
36
- def _current_time_in_sweden(self):
37
- swedish_tz = pytz.timezone(self._TZ)
 
38
  return datetime.now(swedish_tz).strftime("%Y-%m-%d %H:%M:%S")
39
 
40
- def onload_store_metric_data(self, request: gr.Request):
41
- self._session_uuid = str(uuid.uuid1())
42
- hashed_host = self._hash_ip(request.client.host)
43
- self._backup_and_update_database(hashed_host, "load")
 
 
44
 
45
- def store_metric_data(self, action, request: gr.Request):
46
- self._session_uuid = str(uuid.uuid1())
47
- hashed_host = self._hash_ip(request.client.host)
48
- self._backup_and_update_database(hashed_host, action)
49
 
50
- def _commit_host_to_database(self, hashed_host, action):
51
- with sqlite3.connect(self._DB_FILE_PATH) as db:
 
52
  db.execute(
53
  "INSERT INTO ip_data(current_time, hashed_ip, session_uuid, action) VALUES(?,?,?,?)",
54
- [self._current_time_in_sweden(), hashed_host, self._session_uuid, action],
55
  )
56
 
57
- def _setup_database(self):
58
- with sqlite3.connect(self._DB_FILE_PATH) as db:
 
59
  try:
60
  db.execute("SELECT * FROM ip_data").fetchall()
61
  except sqlite3.OperationalError:
@@ -68,23 +72,26 @@ class TrafficDataHandler:
68
  action TEXT)
69
  """
70
  )
 
71
 
72
- def _backup_and_update_database(self, hashed_host, action):
73
- self._commit_host_to_database(hashed_host, action)
74
- shutil.copyfile(self._DB_FILE_PATH, self._DB_TEMP_PATH)
 
75
 
76
- with sqlite3.connect(self._DB_FILE_PATH) as db:
77
  ip_data = db.execute("SELECT * FROM ip_data").fetchall()
78
  pd.DataFrame(ip_data, columns=["id", "current_time", "hashed_ip", "session_uuid", "action"]).to_csv(
79
  "./data/ip_data.csv", index=False
80
  )
81
 
82
- self._repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.now()}")
83
 
84
- def _initialize_and_schedule_backup(self, hashed_host, action):
85
- self._backup_and_update_database(hashed_host, action)
 
86
  scheduler = BackgroundScheduler()
87
  scheduler.add_job(
88
- self._backup_and_update_database, "interval", minutes=self._INTERVAL_MIN_UPDATE, args=(hashed_host, action)
89
  )
90
  scheduler.start()
 
18
  _TOKEN = os.environ.get("HUB_TOKEN")
19
  _TZ = "Europe/Stockholm"
20
  _INTERVAL_MIN_UPDATE = 30
21
+ _repo = huggingface_hub.Repository(
22
+ local_dir="data", repo_type="dataset", clone_from="Riksarkivet/traffic_demo_data", use_auth_token=_TOKEN
23
+ )
24
+ _session_uuid = None
25
 
26
+ @classmethod
27
+ def _pull_repo_data(cls):
28
+ cls._repo.git_pull()
29
+ shutil.copyfile(cls._DB_TEMP_PATH, cls._DB_FILE_PATH)
 
 
 
 
 
 
30
 
31
+ @staticmethod
32
+ def _hash_ip(ip_address):
33
  return hashlib.sha256(ip_address.encode()).hexdigest()
34
 
35
+ @classmethod
36
+ def _current_time_in_sweden(cls):
37
+ swedish_tz = pytz.timezone(cls._TZ)
38
  return datetime.now(swedish_tz).strftime("%Y-%m-%d %H:%M:%S")
39
 
40
+ @classmethod
41
+ def onload_store_metric_data(cls, request: gr.Request):
42
+ cls._session_uuid = str(uuid.uuid1())
43
+ cls._setup_database()
44
+ hashed_host = cls._hash_ip(request.client.host)
45
+ cls._backup_and_update_database(hashed_host, "load")
46
 
47
+ @classmethod
48
+ def store_metric_data(cls, action, request: gr.Request):
49
+ hashed_host = cls._hash_ip(request.client.host)
50
+ cls._backup_and_update_database(hashed_host, action)
51
 
52
+ @classmethod
53
+ def _commit_host_to_database(cls, hashed_host, action):
54
+ with sqlite3.connect(cls._DB_FILE_PATH) as db:
55
  db.execute(
56
  "INSERT INTO ip_data(current_time, hashed_ip, session_uuid, action) VALUES(?,?,?,?)",
57
+ [cls._current_time_in_sweden(), hashed_host, cls._session_uuid, action],
58
  )
59
 
60
+ @classmethod
61
+ def _setup_database(cls):
62
+ with sqlite3.connect(cls._DB_FILE_PATH) as db:
63
  try:
64
  db.execute("SELECT * FROM ip_data").fetchall()
65
  except sqlite3.OperationalError:
 
72
  action TEXT)
73
  """
74
  )
75
+ cls._pull_repo_data()
76
 
77
+ @classmethod
78
+ def _backup_and_update_database(cls, hashed_host, action):
79
+ cls._commit_host_to_database(hashed_host, action)
80
+ shutil.copyfile(cls._DB_FILE_PATH, cls._DB_TEMP_PATH)
81
 
82
+ with sqlite3.connect(cls._DB_FILE_PATH) as db:
83
  ip_data = db.execute("SELECT * FROM ip_data").fetchall()
84
  pd.DataFrame(ip_data, columns=["id", "current_time", "hashed_ip", "session_uuid", "action"]).to_csv(
85
  "./data/ip_data.csv", index=False
86
  )
87
 
88
+ cls._repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.now()}")
89
 
90
+ @classmethod
91
+ def _initialize_and_schedule_backup(cls, hashed_host, action):
92
+ cls._backup_and_update_database(hashed_host, action)
93
  scheduler = BackgroundScheduler()
94
  scheduler.add_job(
95
+ cls._backup_and_update_database, "interval", minutes=cls._INTERVAL_MIN_UPDATE, args=(hashed_host, action)
96
  )
97
  scheduler.start()
tabs/htr_tool.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
 
3
  from helper.examples.examples import DemoImages
 
4
  from src.htr_pipeline.gradio_backend import FastTrack, SingletonModelLoader
5
 
6
  model_loader = SingletonModelLoader()
@@ -32,6 +33,7 @@ with gr.Blocks() as htr_tool_tab:
32
  visible=True,
33
  elem_id="run_pipeline_button",
34
  )
 
35
 
36
  htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
37
 
@@ -260,5 +262,5 @@ with gr.Blocks() as htr_tool_tab:
260
  fast_track_output_image.select(
261
  fast_track.get_text_from_coords, inputs=text_polygon_dict, outputs=selection_text_from_image_viewer
262
  )
263
-
264
- htr_pipeline_button.click(fn=handler.store_metric_data, inputs="htr_pipeline_button")
 
1
  import gradio as gr
2
 
3
  from helper.examples.examples import DemoImages
4
+ from helper.utils import TrafficDataHandler
5
  from src.htr_pipeline.gradio_backend import FastTrack, SingletonModelLoader
6
 
7
  model_loader = SingletonModelLoader()
 
33
  visible=True,
34
  elem_id="run_pipeline_button",
35
  )
36
+ htr_pipeline_button_var = gr.State(value="htr_pipeline_button")
37
 
38
  htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
39
 
 
262
  fast_track_output_image.select(
263
  fast_track.get_text_from_coords, inputs=text_polygon_dict, outputs=selection_text_from_image_viewer
264
  )
265
+ gr.Variable()
266
+ htr_pipeline_button.click(fn=TrafficDataHandler.store_metric_data, inputs=htr_pipeline_button_var)