crimeacs commited on
Commit
4ea56e8
1 Parent(s): d59f437

section plot updated

Browse files
Files changed (2) hide show
  1. Gradio_app.ipynb +21 -0
  2. app.py +47 -26
Gradio_app.ipynb CHANGED
@@ -61,6 +61,27 @@
61
  "execution_count": 4,
62
  "metadata": {},
63
  "output_type": "execute_result"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  }
65
  ],
66
  "source": [
 
61
  "execution_count": 4,
62
  "metadata": {},
63
  "output_type": "execute_result"
64
+ },
65
+ {
66
+ "name": "stdout",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "Error in callback <function _draw_all_if_interactive at 0x1774d0ea0> (for post_execute):\n"
70
+ ]
71
+ },
72
+ {
73
+ "ename": "KeyboardInterrupt",
74
+ "evalue": "",
75
+ "output_type": "error",
76
+ "traceback": [
77
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
78
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
79
+ "File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/pyplot.py:120\u001b[0m, in \u001b[0;36m_draw_all_if_interactive\u001b[0;34m()\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_draw_all_if_interactive\u001b[39m():\n\u001b[1;32m 119\u001b[0m \u001b[39mif\u001b[39;00m matplotlib\u001b[39m.\u001b[39mis_interactive():\n\u001b[0;32m--> 120\u001b[0m draw_all()\n",
80
+ "File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/_pylab_helpers.py:132\u001b[0m, in \u001b[0;36mGcf.draw_all\u001b[0;34m(cls, force)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[39mfor\u001b[39;00m manager \u001b[39min\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mget_all_fig_managers():\n\u001b[1;32m 131\u001b[0m \u001b[39mif\u001b[39;00m force \u001b[39mor\u001b[39;00m manager\u001b[39m.\u001b[39mcanvas\u001b[39m.\u001b[39mfigure\u001b[39m.\u001b[39mstale:\n\u001b[0;32m--> 132\u001b[0m manager\u001b[39m.\u001b[39;49mcanvas\u001b[39m.\u001b[39;49mdraw_idle()\n",
81
+ "File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/backend_bases.py:2082\u001b[0m, in \u001b[0;36mFigureCanvasBase.draw_idle\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2080\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_is_idle_drawing:\n\u001b[1;32m 2081\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_idle_draw_cntx():\n\u001b[0;32m-> 2082\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdraw(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
82
+ "File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/backends/backend_agg.py:397\u001b[0m, in \u001b[0;36mFigureCanvasAgg.draw\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrenderer\u001b[39m.\u001b[39mclear()\n\u001b[1;32m 396\u001b[0m \u001b[39m# Acquire a lock on the shared font cache.\u001b[39;00m\n\u001b[0;32m--> 397\u001b[0m \u001b[39mwith\u001b[39;49;00m RendererAgg\u001b[39m.\u001b[39;49mlock, \\\n\u001b[1;32m 398\u001b[0m (\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtoolbar\u001b[39m.\u001b[39;49m_wait_cursor_for_draw_cm() \u001b[39mif\u001b[39;49;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtoolbar\n\u001b[1;32m 399\u001b[0m \u001b[39melse\u001b[39;49;00m nullcontext()):\n\u001b[1;32m 400\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfigure\u001b[39m.\u001b[39;49mdraw(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrenderer)\n\u001b[1;32m 401\u001b[0m \u001b[39m# A GUI class may be need to update a window using this draw, so\u001b[39;49;00m\n\u001b[1;32m 402\u001b[0m \u001b[39m# don't forget to call the superclass.\u001b[39;49;00m\n",
83
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
84
+ ]
85
  }
86
  ],
87
  "source": [
app.py CHANGED
@@ -21,6 +21,8 @@ from obspy.clients.fdsn.header import URL_MAPPINGS
21
  import matplotlib.pyplot as plt
22
  import matplotlib.dates as mdates
23
 
 
 
24
  def make_prediction(waveform):
25
  waveform = np.load(waveform)
26
  processed_input = prepare_waveform(waveform)
@@ -80,6 +82,15 @@ def mark_phases(waveform):
80
  plt.close(fig)
81
  return image
82
 
 
 
 
 
 
 
 
 
 
83
  def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
84
  distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []
85
 
@@ -101,6 +112,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
101
  level='station')
102
 
103
  waveforms = []
 
 
104
  for network in inv:
105
  for station in network:
106
  try:
@@ -115,8 +128,12 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
115
  starttime = obspy.UTCDateTime(timestamp) + arrivals[0].time - 15
116
  endtime = starttime + 60
117
 
118
- waveform = client.get_waveforms(network=network.code, station=station.code, location="*", channel="*",
119
- starttime=starttime, endtime=endtime)
 
 
 
 
120
 
121
  waveform = waveform.select(channel="H[BH][ZNE]")
122
  waveform = waveform.merge(fill_value=0)
@@ -148,29 +165,39 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
148
  p_phases = output[:, 0]
149
  s_phases = output[:, 1]
150
 
151
- fig, ax = plt.subplots(nrows=1, figsize=(10, 3), sharex=True)
 
 
 
 
152
  for i in range(len(waveforms)):
153
  current_P = p_phases[i::len(waveforms)]
154
  current_S = s_phases[i::len(waveforms)]
 
155
  x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]
156
  x = mdates.date2num(x)
157
- ax.plot(x, waveforms[i][0, 0]+distances[i]*111.2, color='black', alpha=0.5)
158
- ax.scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r')
159
- ax.scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b')
160
- ax.set_ylabel('Z')
161
 
162
- ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
163
- ax.xaxis.set_major_locator(mdates.SecondLocator(interval=10))
 
164
 
165
- # for a in ax:
166
- # a.axvline(current_P.mean()*waveforms[i][0].shape[-1], color='r', linestyle='--', label='P')
167
- # a.axvline(current_S.mean()*waveforms[i][0].shape[-1], color='b', linestyle='--', label='S')
168
 
169
- # ax[-1].set_xlabel('Time, samples')
170
- # ax[-1].set_ylabel('Uncert.')
171
- # ax[-1].legend()
172
 
173
- plt.subplots_adjust(hspace=0., wspace=0.)
 
 
 
 
 
 
 
 
 
 
174
 
175
  fig.canvas.draw();
176
  image = np.array(fig.canvas.renderer.buffer_rgba())
@@ -184,12 +211,6 @@ model = Onset_picker.load_from_checkpoint("./weights.ckpt",
184
  learning_rate=3e-4)
185
  model.eval()
186
 
187
-
188
-
189
- # # Create the Gradio interface
190
- # gr.Interface(mark_phases, inputs, outputs, title='PhaseHunter').launch()
191
-
192
-
193
  with gr.Blocks() as demo:
194
  gr.Markdown("# PhaseHunter")
195
  gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection.
@@ -250,7 +271,9 @@ with gr.Blocks() as demo:
250
  radius_inputs = gr.Slider(minimum=1,
251
  maximum=150,
252
  value=50, label="Radius (km)",
253
- info="Select the radius around the earthquake to download data from",
 
 
254
  interactive=True)
255
 
256
  velocity_inputs = gr.Dropdown(
@@ -263,7 +286,7 @@ with gr.Blocks() as demo:
263
 
264
 
265
  button = gr.Button("Predict phases")
266
- outputs_section = gr.outputs.Image(label='Waveforms with Phases Marked', type='numpy')
267
 
268
  button.click(predict_on_section,
269
  inputs=[client_inputs, timestamp_inputs,
@@ -277,6 +300,4 @@ with gr.Blocks() as demo:
277
  Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
278
  """)
279
 
280
-
281
-
282
  demo.launch()
 
21
  import matplotlib.pyplot as plt
22
  import matplotlib.dates as mdates
23
 
24
+ from glob import glob
25
+
26
  def make_prediction(waveform):
27
  waveform = np.load(waveform)
28
  processed_input = prepare_waveform(waveform)
 
82
  plt.close(fig)
83
  return image
84
 
85
+ def variance_coefficient(residuals):
86
+ # calculate the variance of the residuals
87
+ var = residuals.var()
88
+
89
+ # scale the variance to a coefficient between 0 and 1
90
+ coeff = 1 - (var / (residuals.max() - residuals.min()))
91
+
92
+ return coeff
93
+
94
  def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
95
  distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []
96
 
 
112
  level='station')
113
 
114
  waveforms = []
115
+ cached_waveforms = glob("data/cached/*.mseed")
116
+
117
  for network in inv:
118
  for station in network:
119
  try:
 
128
  starttime = obspy.UTCDateTime(timestamp) + arrivals[0].time - 15
129
  endtime = starttime + 60
130
 
131
+ if f"data/cached/{network.code}_{station.code}_{starttime}.mseed" not in cached_waveforms:
132
+ waveform = client.get_waveforms(network=network.code, station=station.code, location="*", channel="*",
133
+ starttime=starttime, endtime=endtime)
134
+ waveform.write(f"data/cached/{network.code}_{station.code}_{starttime}.mseed", format="MSEED")
135
+ else:
136
+ waveform = obspy.read(f"data/cached/{network.code}_{station.code}_{starttime}.mseed")
137
 
138
  waveform = waveform.select(channel="H[BH][ZNE]")
139
  waveform = waveform.merge(fill_value=0)
 
165
  p_phases = output[:, 0]
166
  s_phases = output[:, 1]
167
 
168
+ # Max confidence - min variance
169
+ p_max_confidence = np.min([p_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
170
+ s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
171
+
172
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3), sharex=True)
173
  for i in range(len(waveforms)):
174
  current_P = p_phases[i::len(waveforms)]
175
  current_S = s_phases[i::len(waveforms)]
176
+
177
  x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]
178
  x = mdates.date2num(x)
 
 
 
 
179
 
180
+ # Normalize confidence for the plot
181
+ p_conf = 1/(current_P.std()/p_max_confidence).item()
182
+ s_conf = 1/(current_S.std()/s_max_confidence).item()
183
 
184
+ ax[0].plot(x, waveforms[i][0, 0]*10+distances[i]*111.2, color='black', alpha=0.5, lw=1)
 
 
185
 
186
+ ax[0].scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r', alpha=p_conf, marker='|')
187
+ ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
188
+ ax[0].set_ylabel('Z')
189
 
190
+ ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
191
+ ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=5))
192
+
193
+ ax[0].scatter(None, None, color='r', marker='|', label='P')
194
+ ax[0].scatter(None, None, color='b', marker='|', label='S')
195
+ ax[0].legend()
196
+
197
+ ax[1].scatter(st_lats, st_lons, color='b', marker='d', label='Stations')
198
+ ax[1].scatter(eq_lat, eq_lon, color='r', marker='*', label='Earthquake')
199
+ ax[1].legend()
200
+ plt.subplots_adjust(hspace=0., wspace=0.)
201
 
202
  fig.canvas.draw();
203
  image = np.array(fig.canvas.renderer.buffer_rgba())
 
211
  learning_rate=3e-4)
212
  model.eval()
213
 
 
 
 
 
 
 
214
  with gr.Blocks() as demo:
215
  gr.Markdown("# PhaseHunter")
216
  gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection.
 
271
  radius_inputs = gr.Slider(minimum=1,
272
  maximum=150,
273
  value=50, label="Radius (km)",
274
+ step=10,
275
+ info="""Select the radius around the earthquake to download data from.\n
276
+ Note that the larger the radius, the longer the app will take to run.""",
277
  interactive=True)
278
 
279
  velocity_inputs = gr.Dropdown(
 
286
 
287
 
288
  button = gr.Button("Predict phases")
289
+ outputs_section = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
290
 
291
  button.click(predict_on_section,
292
  inputs=[client_inputs, timestamp_inputs,
 
300
  Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
301
  """)
302
 
 
 
303
  demo.launch()