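"""Gradio demo: fault classification for a nano-quadcopter from cfusdlog flight logs.

A feed-forward PyTorch network, trained offline, predicts one of four conditions
(extra weight, propeller cut, tape on propeller, normal flight) from a single
telemetry sample drawn from an on-board log file.
"""
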
from zlib import crc32
import struct
import gradio as gr
import os
import pandas as pd
import numpy as np
import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F

# Features the model was trained on (the training-time column order is
# restored later from feature_names.pkl)
top_features = {
    'pm.vbatMV', 'stateEstimate.z', 'motor.m3', 'stateEstimate.yaw', 'yaw_cos',
    'motor.m2', 'stateEstimate.y', 'stateEstimate.x', 'motor.m1', 'theta',
    'motor.m4', 'position_magnitude', 'combined_orientation', 'pwm.m3_pwm',
    'stateEstimate.roll', 'phi', 'pwm.m2_pwm', 'roll_cos', 'vx_cosine',
    'stateEstimate.vx', 'velocity_magnitude', 'stateEstimate.vy', 'pwm.m4_pwm',
    'stateEstimate.vz', 'pwm.m1_pwm'
}

# Load the per-feature median values once; they are used to impute missing readings
feature_medians = pd.read_csv("model/feature_medians.csv")
medians_dict = feature_medians.set_index('Feature')['Median'].to_dict()

# Load the label encoder, scaler, and saved feature names
label_encoder = joblib.load('model/label_encoder.pkl')
scaler = joblib.load('model/scaler.pkl')
saved_feature_names = joblib.load('model/feature_names.pkl')

# Define the EnhancedFaultDetectionNN model
class EnhancedFaultDetectionNN(nn.Module):
    def __init__(self, input_size, output_size, dropout_prob=0.08):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, output_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = F.relu(self.bn3(self.fc3(x)))
        x = self.dropout(x)
        return self.fc4(x)

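# Shape sanity check (illustrative sketch only, not part of the app; the sizes
# here are hypothetical — the real ones come from the saved artifacts below):
#   m = EnhancedFaultDetectionNN(input_size=25, output_size=4).eval()
#   m(torch.randn(2, 25)).shape  # -> torch.Size([2, 4])
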
# Load the trained PyTorch weights
model_path = 'model/best_model_without_oversampling128.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = len(saved_feature_names)
output_size = len(label_encoder.classes_)
model = EnhancedFaultDetectionNN(input_size, output_size).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Mapping of fault types to corresponding images and comments
defect_image_map = {
    "Extra Weight": {
        "image": "images/weight.png",
        "comment": "A weight added near the M3 motor causes lift imbalance."
    },
    "Propeller Cut": {
        "image": "images/propeller_cut.png",
        "comment": "A cut on the M2 propeller reduces thrust and causes instability."
    },
    "Tape on Propeller": {
        "image": "images/tape.png",
        "comment": "Tape on the M3 propeller leads to imbalance, drag, and vibrations, reducing stability."
    },
    "Normal Flight": {
        "image": "images/normal_flight.png",
        "comment": "The quadcopter operates normally with balanced thrust and stability."
    },
}

# Sample log files, one per fault type
log_files = [
    "Logs_Samples/add_weight_W1_near_M3_E9_log04",
    "Logs_Samples/cut_M2_0.5мм_46.5мм_E9_log02",
    "Logs_Samples/tape_on_propeller_M3_E9_log01",
    "Logs_Samples/normal_flight_E8_log03"
]

# Mapping simplified labels to the file-name fragments used to locate the sample logs
LabelsMap = {
    "Extra Weight": "add_weight_W1_near_M3",
    "Propeller Cut": "cut_M2_0.5мм_46.5мм",
    "Tape on Propeller": "tape_on_propeller_M3",
    "Normal Flight": "normal_flight"
}

# Retrieve the log file path whose name contains the label's fragment
def get_log_file_path(label_key):
    label_value = LabelsMap[label_key]
    for log_file in log_files:
        if label_value in log_file:
            return log_file
    return None  # No matching file found

def get_name(data, idx):
    """Read a NUL-terminated UTF-8 string starting at idx; return (string, next_idx)."""
    end_idx = idx
    while data[end_idx] != 0:
        end_idx += 1
    return data[idx:end_idx].decode("utf-8"), end_idx + 1

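# Layout of a log file as this decoder reads it (inferred from the code; it
# appears to match the Crazyflie micro-SD-card "cfusdlog" format):
#   byte 0        magic 0xBC
#   bytes 1..4    uint16 version, uint16 number of event types
#   per event     uint16 id, NUL-terminated name, uint16 variable count,
#                 then "name(t)" strings giving each variable's struct code t
#   records       uint16 event id + timestamp (uint32 ms in v1, uint64 µs in v2),
#                 followed by the packed variable values
#   last 4 bytes  CRC32 over everything before them
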
def cfusdlog_decode(file):
    data = file.read()
    if data[0] != 0xBC:
        raise gr.Error("Invalid file format: Magic header not found.")
    crc = crc32(data[0:-4])
    expected_crc, = struct.unpack('<I', data[-4:])
    if crc != expected_crc:
        raise gr.Error("File integrity check failed: CRC mismatch.")
    version, num_event_types = struct.unpack('<HH', data[1:5])
    if version not in [1, 2]:
        raise gr.Error(f"Unsupported log file version: {version}")
    result = {}
    event_by_id = {}
    idx = 5
    # Parse the event-type descriptions
    for _ in range(num_event_types):
        event_id, = struct.unpack('<H', data[idx:idx+2])
        idx += 2
        event_name, idx = get_name(data, idx)
        result[event_name] = {'timestamp': []}
        num_variables, = struct.unpack('<H', data[idx:idx+2])
        idx += 2
        fmt_str = "<"
        variables = []
        for _ in range(num_variables):
            var_name_and_type, idx = get_name(data, idx)
            var_name = var_name_and_type[:-3]   # strip the trailing "(t)"
            var_type = var_name_and_type[-2]    # struct code between the parentheses
            result[event_name][var_name] = []
            fmt_str += var_type
            variables.append(var_name)
        event_by_id[event_id] = {
            'name': event_name,
            'fmt_str': fmt_str,
            'num_bytes': struct.calcsize(fmt_str),
            'variables': variables,
        }
    # Parse the data records (the trailing 4 bytes are the CRC)
    while idx < len(data) - 4:
        if version == 1:
            event_id, timestamp = struct.unpack('<HI', data[idx:idx+6])
            idx += 6
        elif version == 2:
            event_id, timestamp = struct.unpack('<HQ', data[idx:idx+10])
            timestamp /= 1000.0  # microseconds -> milliseconds
            idx += 10
        event = event_by_id[event_id]
        event_data = struct.unpack(event['fmt_str'], data[idx:idx+event['num_bytes']])
        idx += event['num_bytes']
        for var, value in zip(event['variables'], event_data):
            result[event['name']][var].append(value)
        result[event['name']]['timestamp'].append(timestamp)
    for event_name, event_data in result.items():
        for var_name, var_data in event_data.items():
            result[event_name][var_name] = np.array(var_data)
    # Keep only event types that actually produced records
    return {k: v for k, v in result.items() if len(v['timestamp']) > 0}

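# Usage sketch (illustrative; assumes the bundled sample logs are present):
#   with open("Logs_Samples/normal_flight_E8_log03", "rb") as f:
#       decoded = cfusdlog_decode(f)
#   decoded["fixedFrequency"]["stateEstimate.z"]  # -> np.ndarray of z positions
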
def fix_time(log_data):
    # Rebase timestamps so the log starts at t = 0
    try:
        timestamps = log_data["timestamp"]
    except KeyError:
        raise gr.Error("Timestamp key not found in the log data.")
    if len(timestamps) == 0:
        raise gr.Error("Timestamp data is empty.")
    first_value = timestamps[0]
    log_data["timestamp"] = [t - first_value for t in timestamps]

def process_log_file(file):
    try:
        log_data = cfusdlog_decode(file)
        log_data = log_data.get('fixedFrequency', {})
        if not log_data:
            raise gr.Error("No 'fixedFrequency' data found in the log file.")
        fix_time(log_data)
        # The parent directory name is kept as the nominal ground-truth label
        parent_dir_name = os.path.basename(os.path.dirname(file.name))
        log_data["true_label"] = [parent_dir_name] * len(log_data.get("timestamp", []))
        return pd.DataFrame(log_data)
    except Exception as e:
        raise gr.Error(f"Failed to process log file: {e}")

def preprocess_single_data_point(single_data_point):
    try:
        if 'timestamp' in single_data_point.columns:
            single_data_point.drop(columns=["timestamp"], inplace=True)
        single_data_point.fillna(medians_dict, inplace=True)
        # Spherical position features
        state_x, state_y, state_z = single_data_point[['stateEstimate.x', 'stateEstimate.y', 'stateEstimate.z']].values.T
        single_data_point['r'] = np.sqrt(state_x**2 + state_y**2 + state_z**2)
        single_data_point['theta'] = np.arccos(np.clip(single_data_point['stateEstimate.z'] / single_data_point['r'], -1.0, 1.0))  # clip keeps arccos defined
        single_data_point['phi'] = np.arctan2(single_data_point['stateEstimate.y'], single_data_point['stateEstimate.x'])
        single_data_point['position_magnitude'] = single_data_point['r']
        # Velocity magnitude and direction cosines (zero when stationary)
        velocity_x, velocity_y, velocity_z = single_data_point[['stateEstimate.vx', 'stateEstimate.vy', 'stateEstimate.vz']].values.T
        single_data_point['velocity_magnitude'] = np.sqrt(velocity_x**2 + velocity_y**2 + velocity_z**2)
        single_data_point['vx_cosine'] = np.divide(velocity_x, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_x), where=single_data_point['velocity_magnitude'] != 0)
        single_data_point['vy_cosine'] = np.divide(velocity_y, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_y), where=single_data_point['velocity_magnitude'] != 0)
        single_data_point['vz_cosine'] = np.divide(velocity_z, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_z), where=single_data_point['velocity_magnitude'] != 0)
        # Orientation features
        roll, yaw = single_data_point[['stateEstimate.roll', 'stateEstimate.yaw']].values.T
        single_data_point['combined_orientation'] = roll + yaw
        single_data_point['roll_sin'] = np.sin(np.radians(roll))
        single_data_point['roll_cos'] = np.cos(np.radians(roll))
        single_data_point['yaw_sin'] = np.sin(np.radians(yaw))
        single_data_point['yaw_cos'] = np.cos(np.radians(yaw))
        # Keep only the features the model was trained on
        features_to_keep = list(top_features.intersection(single_data_point.columns))
        return single_data_point[features_to_keep + ['true_label']]
    except Exception as e:
        raise gr.Error(f"Failed to preprocess single data point: {e}")

def predict(file_path):
    try:
        with open(file_path, 'rb') as file:
            log_df = process_log_file(file)
        if log_df is None or log_df.empty:
            raise gr.Error("Log file processing returned no data.")
        # Classify a single randomly sampled telemetry row
        single_data_point = log_df.sample(1).copy()
        preprocessed_data_point = preprocess_single_data_point(single_data_point)
        X = preprocessed_data_point.drop(columns=['true_label'])
        X_ordered = X[saved_feature_names]  # restore the training-time column order
        X_scaled = scaler.transform(X_ordered)
        X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
        with torch.no_grad():
            logits = model(X_tensor)
            probabilities = F.softmax(logits, dim=1)
            confidence_scores, predicted_classes = torch.max(probabilities, dim=1)
        predicted_labels = label_encoder.inverse_transform(predicted_classes.cpu().numpy())
        confidence_scores = confidence_scores.cpu().numpy()
        predicted_label_value = predicted_labels[0]
        predicted_label_key = [k for k, v in LabelsMap.items() if v == predicted_label_value][0]
        label_confidence_pairs = f"{predicted_label_key}: {predicted_label_value} (Confidence: {confidence_scores[0]:.4f})"
        # Retrieve the corresponding image and comment using the key name
        defect_info = defect_image_map.get(predicted_label_key, {"image": "images/Placeholder.png", "comment": "No information available."})
        image_path = defect_info["image"]
        comment = defect_info["comment"]
        return image_path, f"{label_confidence_pairs}\n\nComment: {comment}"
    except Exception as e:
        raise gr.Error(f"Failed to process file: {e}")

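# Direct call sketch (illustrative; uses one of the bundled sample logs):
#   image_path, summary = predict("Logs_Samples/normal_flight_E8_log03")
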
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Fault Detection in Nano-Quadcopter")
    gr.Markdown("This interface classifies faults in a nano-quadcopter using a deep neural network model.")
    with gr.Row():
        with gr.Column():
            example_dropdown = gr.Dropdown(
                choices=["Extra Weight", "Propeller Cut", "Tape on Propeller", "Normal Flight"],
                label="Select Fault Type"
            )
            submit_btn = gr.Button("Classify")
        with gr.Column():
            image_output = gr.Image(type="filepath", label="Corresponding Image")
            label_output = gr.Textbox(label="Predicted Label and Confidence Score")

    def classify_example(example):
        try:
            file_path = get_log_file_path(example)
            if file_path is None:
                raise gr.Error("No matching log file found.")
            return predict(file_path)
        except KeyError as e:
            raise gr.Error(f"Error: {e}")

    submit_btn.click(
        fn=classify_example,
        inputs=[example_dropdown],
        outputs=[image_output, label_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True, debug=True)