phase-hunter / phasehunter /data_preparation.py
crimeacs's picture
Fixed imports
2bbf18f
raw
history blame
No virus
6.4 kB
import torch
import numpy as np
from scipy import signal
from scipy.signal import butter, lfilter, detrend
# Make bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Args:
        lowcut: Lower corner frequency, in the same units as ``fs``.
        highcut: Upper corner frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` numerator/denominator coefficients from ``butter``.
    """
    # butter() expects corner frequencies normalised to the Nyquist frequency.
    nyquist = fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band, btype="band")
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a Butterworth filter using one ``lfilter`` pass.

    ``lfilter`` runs forward only, so the result is phase-shifted rather than
    zero-phase. It operates along the last axis of ``data`` by default.

    Args:
        data: Array to filter.
        lowcut: Lower corner frequency, in the same units as ``fs``.
        highcut: Upper corner frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Butterworth filter order.

    Returns:
        The filtered array, with the same shape as ``data``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)
def rotate_waveform(waveform, angle):
    """Multiply the spectrum of ``waveform`` by the unit phasor ``exp(1j * angle)``.

    The FFT is taken along the last axis (numpy's default). Every frequency bin
    is scaled by the same complex factor, and the inverse FFT is returned.

    Args:
        waveform: Array to transform.
        angle: Phase in radians. ``np.exp`` interprets it as radians.

    Returns:
        A complex array with the same shape as ``waveform``.

    NOTE(review): the same factor multiplies every bin, including the negative
    frequencies. By linearity of the FFT, the result is therefore (up to FFT
    round-off) ``waveform * exp(1j * angle)``, and its ``.real`` part is
    ``waveform * cos(angle)``. It is not a Hilbert-style constant-phase
    rotation. Confirm this is the intended augmentation.
    """
    phasor = np.exp(1j * angle)
    spectrum = np.fft.fft(waveform)
    return np.fft.ifft(spectrum * phasor)
def augment(sample):
    """Crop, preprocess and (for training samples) augment one seismic trace.

    A ``crop_length``-sample window is cut around a randomly chosen P or S pick.
    The window is detrended, band-passed, tapered, normalised to unit peak
    amplitude and padded to three channels. Training samples additionally
    receive a random spectral phasor multiplication (``rotate_waveform``) and
    channel dropout.

    Args:
        sample: Mapping with keys ``"waveform.npy"`` (an array sliced as
            ``(channels, samples)``) and ``"meta.json"`` (a dict providing
            ``"split"``, ``"trace_p_arrival_sample"`` and
            ``"trace_s_arrival_sample"``; a pick may be ``None``).

    Returns:
        ``(waveform, target_P, target_S)``. ``waveform`` has ``crop_length``
        samples per channel and 3 channels for inputs with at most 3 channels.
        Each target is the pick position as a fraction of the crop in
        ``(0, 1)``, or ``0`` when the pick is missing, NaN, or outside the crop.
    """
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120

    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]
    # Every split other than "train" gets a deterministic crop offset and no augmentation.
    test = meta["split"] != "train"

    # A missing pick becomes 0, which is treated as "no pick" throughout.
    target_sample_P = meta["trace_p_arrival_sample"]
    target_sample_S = meta["trace_s_arrival_sample"]
    if target_sample_P is None:
        target_sample_P = 0
    if target_sample_S is None:
        target_sample_S = 0

    # Randomly select a phase to anchor the crop (`x > 0` also rejects NaN picks).
    current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
    if current_phases:
        first_phase = current_phases[np.random.randint(0, len(current_phases))]
    else:
        # No usable pick: np.random.randint(0, 0) would raise ValueError here.
        # An anchor of 0 selects the fixed `start_indx = padding` branch below.
        first_phase = 0

    # Choose the crop start so the anchor phase lands inside the window.
    # Eval crops use a fixed offset and draw no random numbers.
    if first_phase - (crop_length - padding) > padding:
        if test:
            start_indx = int(first_phase - 2 * padding)
        else:
            offset = torch.randint(
                low=padding, high=(crop_length - padding), size=(1,)
            ).item()
            start_indx = int(first_phase - offset)
    elif int(first_phase - padding) > 0:
        if test:
            start_indx = int(first_phase - padding)
        else:
            offset = torch.randint(
                low=0, high=int(first_phase - padding), size=(1,)
            ).item()
            start_indx = int(first_phase - offset)
    else:
        start_indx = padding

    # Shift the window left if it would run past the end of the trace.
    end_indx = start_indx + crop_length
    if end_indx > waveform.shape[-1]:
        start_indx -= end_indx - waveform.shape[-1]
    # For traces shorter than crop_length the shift above goes negative. A negative
    # slice start would wrap to the END of the trace and corrupt the targets, so
    # clamp to 0 and zero-pad the crop up to crop_length instead.
    start_indx = max(start_indx, 0)
    end_indx = start_indx + crop_length

    # Update target
    new_target_P = target_sample_P - start_indx
    new_target_S = target_sample_S - start_indx

    # Cut
    waveform_cropped = waveform[:, start_indx:end_indx]
    missing = crop_length - waveform_cropped.shape[-1]
    if missing > 0:
        waveform_cropped = np.pad(
            waveform_cropped, ((0, 0), (0, missing)), mode="constant"
        )

    # Preprocess: detrend, band-pass 0.2-40 Hz (fs=100 is hard-coded, so the trace
    # is assumed to be sampled at 100 Hz), taper the edges, then detrend again.
    waveform_cropped = detrend(waveform_cropped)
    waveform_cropped = butter_bandpass_filter(
        waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform_cropped.shape[-1], alpha=0.1)
    waveform_cropped = detrend(waveform_cropped * window)

    # A NaN anywhere invalidates the whole trace; replace it with silence.
    if np.isnan(waveform_cropped).any():
        waveform_cropped = np.zeros(shape=waveform_cropped.shape)
    # A dead (all-zero) trace carries no picks. This uses `not np.any(...)` rather
    # than `np.sum(...) == 0`: after a linear detrend, a live trace also sums to ~0.
    if not np.any(waveform_cropped):
        new_target_P = 0
        new_target_S = 0

    # Normalize to unit peak amplitude. Skip the division for a dead trace, where
    # 0 / 0 would turn the zeros into NaN.
    max_val = np.max(np.abs(waveform_cropped))
    if max_val > 0:
        waveform_cropped_norm = waveform_cropped / max_val
    else:
        waveform_cropped_norm = waveform_cropped

    # Pad to 3 channels (e.g. Z-only records). Existing channels fill the first rows.
    n_channels = waveform_cropped_norm.shape[0]
    if n_channels < 3:
        padded = np.zeros((3, waveform_cropped_norm.shape[-1]))
        padded[:n_channels] = waveform_cropped_norm
        waveform_cropped_norm = padded

    if not test:
        ##### Rotate waveform #####
        # Applied with probability 1/2. The angle is always drawn, so the RNG stream
        # does not depend on the coin flip.
        # NOTE(review): the angle range looks like degrees, but rotate_waveform feeds
        # it to np.exp(1j * angle), i.e. treats it as radians -- confirm intent.
        probability = torch.randint(0, 2, size=(1,)).item()
        angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
        if probability == 1:
            waveform_cropped_norm = rotate_waveform(waveform_cropped_norm, angle).real
        #### Channel DropOUT #####
        # With probability 1/2, overwrite channel 1 or 2 (never channel 0) with 1e-6.
        probability = torch.randint(0, 2, size=(1,)).item()
        channel = torch.randint(1, 3, size=(1,)).item()
        if probability == 1:
            waveform_cropped_norm[channel, :] = 1e-6

    # Normalize target to a fraction of the crop; 0 means "no pick in this window".
    new_target_P = new_target_P / crop_length
    new_target_S = new_target_S / crop_length
    if (new_target_P <= 0) or (new_target_P >= 1) or (np.isnan(new_target_P)):
        new_target_P = 0
    if (new_target_S <= 0) or (new_target_S >= 1) or (np.isnan(new_target_S)):
        new_target_S = 0
    return waveform_cropped_norm, new_target_P, new_target_S
def collation_fn(sample):
    """Collate ``(waveform, target_P, target_S, ...)`` items into float tensors.

    Only the first three fields of each item are used.

    Args:
        sample: Sequence of items, each indexable at positions 0, 1 and 2.

    Returns:
        Tuple ``(waveforms, targets_P, targets_S)`` of ``torch.float`` tensors.
        Each is the ``np.stack`` of the corresponding field across ``sample``.
    """
    stacked = [np.stack([item[field] for item in sample]) for field in range(3)]
    return tuple(torch.tensor(column, dtype=torch.float) for column in stacked)
def my_split_by_node(urls):
    """Return the subset of ``urls`` assigned to this process.

    Shards are assigned round-robin: rank ``r`` of ``world_size`` receives
    ``urls[r::world_size]``. When ``torch.distributed`` is unavailable or no
    process group has been initialised (e.g. single-process runs or notebooks),
    every URL is returned instead of raising.

    Args:
        urls: Iterable of shard URLs.

    Returns:
        A list of the URLs this process should read.
    """
    import torch.distributed as dist

    urls = list(urls)
    # get_rank()/get_world_size() raise without an initialised default process group.
    if not (dist.is_available() and dist.is_initialized()):
        return urls
    return urls[dist.get_rank()::dist.get_world_size()]
def prepare_waveform(waveform, num_copies=128):
    """Preprocess a raw waveform for inference and replicate it into a batch.

    The trace is zero-padded or truncated to ``crop_length`` samples, then
    detrended, band-passed, tapered and detrended again. It is normalised to
    unit peak amplitude, padded to three channels, and stacked ``num_copies``
    times.

    Args:
        waveform: Array of shape ``(channels, samples)`` with at most 3 channels.
            A 1-D array is treated as a single channel.
        num_copies: Number of identical copies stacked along a new leading
            dimension. The default of 128 matches the previously hard-coded value.

    Returns:
        A ``torch.float`` tensor of shape ``(num_copies, 3, crop_length)``.

    Raises:
        AssertionError: If the input has more than 3 channels, or the preprocessed
            trace contains NaN or is entirely zero.
    """
    # SET PARAMETERS:
    crop_length = 6000

    # A 1-D trace is treated as a single channel (it used to fail the channel check).
    waveform = np.atleast_2d(waveform)
    # Explicit raises keep these checks active under `python -O`, which strips
    # `assert`, while preserving the AssertionError type existing callers may catch.
    if waveform.shape[0] > 3:
        raise AssertionError("Waveform has more than 3 channels")

    # Zero-pad short traces and truncate long ones to exactly crop_length samples.
    n_samples = waveform.shape[-1]
    if n_samples < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - n_samples)),
            mode="constant",
            constant_values=0,
        )
    else:
        waveform = waveform[:, :crop_length]

    # Preprocess: detrend, band-pass 0.2-40 Hz (fs=100 is hard-coded, so the trace
    # is assumed to be sampled at 100 Hz), taper the edges, then detrend again.
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform.shape[-1], alpha=0.1)
    waveform = detrend(waveform * window)

    if np.isnan(waveform).any():
        raise AssertionError("Nan in waveform")
    # Check for an all-zero trace directly. The former `np.sum(...) != 0` test is
    # unreliable because a linearly detrended trace sums to ~0 even when it is live.
    if not np.any(waveform):
        raise AssertionError("Sum of waveform sample is zero")

    # Normalize data (max_val > 0 is guaranteed by the check above).
    waveform = waveform / np.max(np.abs(waveform))

    # Pad to 3 channels (e.g. Z-only records). Existing channels fill the first rows.
    n_channels = waveform.shape[0]
    if n_channels < 3:
        padded = np.zeros((3, waveform.shape[-1]))
        padded[:n_channels] = waveform
        waveform = padded

    # Build one tensor and repeat it. Calling torch.tensor() on a Python list of
    # ndarrays is very slow.
    single = torch.tensor(waveform, dtype=torch.float)
    return single.unsqueeze(0).repeat(num_copies, 1, 1)