crimeacs commited on
Commit
6762dd9
1 Parent(s): 3c04bfb

Fixed imports

Browse files
Files changed (2) hide show
  1. phasehunter/app.py +164 -0
  2. requirements.txt +1 -0
phasehunter/app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio app that takes seismic waveform as input and marks 2 phases on the waveform as output.
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ from phasehunter.model import Onset_picker, Updated_onset_picker
6
+ from phasehunter.data_preparation import prepare_waveform
7
+ import torch
8
+
9
+ from scipy.stats import gaussian_kde
10
+
11
+ import obspy
12
+ from obspy.clients.fdsn import Client
13
+ from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
14
+ from obspy.geodetics.base import locations2degrees
15
+ from obspy.taup import TauPyModel
16
+ from obspy.taup.helper_classes import SlownessModelError
17
+
18
+ from obspy.clients.fdsn.header import URL_MAPPINGS
19
+
20
+ def make_prediction(waveform):
21
+ waveform = np.load(waveform)
22
+ processed_input = prepare_waveform(waveform)
23
+
24
+ # Make prediction
25
+ with torch.no_grad():
26
+ output = model(processed_input)
27
+
28
+ p_phase = output[:, 0]
29
+ s_phase = output[:, 1]
30
+
31
+ return processed_input, p_phase, s_phase
32
+
33
+ def mark_phases(waveform):
34
+ processed_input, p_phase, s_phase = make_prediction(waveform)
35
+
36
+ # Create a plot of the waveform with the phases marked
37
+ if sum(processed_input[0][2] == 0): #if input is 1C
38
+ fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
39
+
40
+ ax[0].plot(processed_input[0][0])
41
+ ax[0].set_ylabel('Norm. Ampl.')
42
+
43
+ else: #if input is 3C
44
+ fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
45
+ ax[0].plot(processed_input[0][0])
46
+ ax[1].plot(processed_input[0][1])
47
+ ax[2].plot(processed_input[0][2])
48
+
49
+ ax[0].set_ylabel('Z')
50
+ ax[1].set_ylabel('N')
51
+ ax[2].set_ylabel('E')
52
+
53
+ p_phase_plot = p_phase*processed_input.shape[-1]
54
+ p_kde = gaussian_kde(p_phase_plot)
55
+ p_dist_space = np.linspace( min(p_phase_plot)-10, max(p_phase_plot)+10, 500 )
56
+ ax[-1].plot( p_dist_space, p_kde(p_dist_space), color='r')
57
+
58
+ s_phase_plot = s_phase*processed_input.shape[-1]
59
+ s_kde = gaussian_kde(s_phase_plot)
60
+ s_dist_space = np.linspace( min(s_phase_plot)-10, max(s_phase_plot)+10, 500 )
61
+ ax[-1].plot( s_dist_space, s_kde(s_dist_space), color='b')
62
+
63
+ for a in ax:
64
+ a.axvline(p_phase.mean()*processed_input.shape[-1], color='r', linestyle='--', label='P')
65
+ a.axvline(s_phase.mean()*processed_input.shape[-1], color='b', linestyle='--', label='S')
66
+
67
+ ax[-1].set_xlabel('Time, samples')
68
+ ax[-1].set_ylabel('Uncert.')
69
+ ax[-1].legend()
70
+
71
+ plt.subplots_adjust(hspace=0., wspace=0.)
72
+
73
+ # Convert the plot to an image and return it
74
+ fig.canvas.draw()
75
+ image = np.array(fig.canvas.renderer.buffer_rgba())
76
+ plt.close(fig)
77
+ return image
78
+
79
+ def download_data(timestamp, eq_lat, eq_lon, client_name, radius_km):
80
+ client = Client(client_name)
81
+ window = radius_km / 111.2
82
+
83
+ assert eq_lat - window > -90 and eq_lat + window < 90, "Latitude out of bounds"
84
+ assert eq_lon - window > -180 and eq_lon + window < 180, "Longitude out of bounds"
85
+
86
+
87
+
88
+ return 0
89
+
90
+ model = Onset_picker.load_from_checkpoint("./weights.ckpt",
91
+ picker=Updated_onset_picker(),
92
+ learning_rate=3e-4)
93
+ model.eval()
94
+
95
+
96
+
97
+ # # Create the Gradio interface
98
+ # gr.Interface(mark_phases, inputs, outputs, title='PhaseHunter').launch()
99
+
100
+
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("# PhaseHunter")
103
+ with gr.Tab("Default example"):
104
+ # Define the input and output types for Gradio
105
+ inputs = gr.Dropdown(
106
+ ["data/sample/sample_0.npy",
107
+ "data/sample/sample_1.npy",
108
+ "data/sample/sample_2.npy"],
109
+ label="Sample waveform",
110
+ info="Select one of the samples",
111
+ value = "data/sample/sample_0.npy"
112
+ )
113
+
114
+ button = gr.Button("Predict phases")
115
+ outputs = gr.outputs.Image(label='Waveform with Phases Marked', type='numpy')
116
+
117
+ button.click(mark_phases, inputs=inputs, outputs=outputs)
118
+
119
+ with gr.Tab("Select earthquake from catalogue"):
120
+ gr.Markdown('TEST')
121
+
122
+ client_inputs = gr.Dropdown(
123
+ choices = list(URL_MAPPINGS.keys()),
124
+ label="FDSN Client",
125
+ info="Select one of the available FDSN clients",
126
+ value = "IRIS",
127
+ interactive=True
128
+ )
129
+ with gr.Row():
130
+
131
+ timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
132
+ placeholder='YYYY-MM-DD HH:MM:SS',
133
+ label="Timestamp",
134
+ info="Timestamp of the earthquake",
135
+ max_lines=1,
136
+ interactive=True)
137
+
138
+ eq_lat_inputs = gr.Number(value=35.766,
139
+ label="Latitude",
140
+ info="Latitude of the earthquake",
141
+ interactive=True)
142
+
143
+ eq_lo_inputs = gr.Number(value=117.605,
144
+ label="Longitude",
145
+ info="Longitude of the earthquake",
146
+ interactive=True)
147
+
148
+ radius_inputs = gr.Slider(minimum=1,
149
+ maximum=150,
150
+ value=50, label="Radius (km)",
151
+ info="Select the radius around the earthquake to download data from",
152
+ interactive=True)
153
+
154
+ button = gr.Button("Predict phases")
155
+
156
+ with gr.Tab("Predict on your own waveform"):
157
+ gr.Markdown("""
158
+ Please upload your waveform in .npy (numpy) format.
159
+ Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
160
+ """)
161
+
162
+ button.click(mark_phases, inputs=inputs, outputs=outputs)
163
+
164
+ demo.launch()
requirements.txt CHANGED
@@ -11,4 +11,5 @@ torchmetrics==0.11.4
11
  torchvision==0.15.1
12
  tqdm==4.65.0
13
  webdataset==0.2.48
 
14
  git+http://github.com/nikitadurasov/masksembles
 
11
  torchvision==0.15.1
12
  tqdm==4.65.0
13
  webdataset==0.2.48
14
+ obspy
15
  git+http://github.com/nikitadurasov/masksembles