Edward J. Schwartz
Try again
5e7f50e
import gradio as gr
import shap
import transformers
import os
import re
import subprocess
import sys
import tempfile
# Hugging Face model id shared by both loaders below, kept in one place so the
# classifier and the interpreter can never point at different models.
MODEL_ID = "ejschwartz/oo-method-test-model-bylibrary"

# Classifier used for predictions; loaded through Gradio from the "models" source.
model = gr.load(MODEL_ID, src="models")

# Local transformers pipeline for the same model id; it is the object handed to
# shap.Explainer when the user asks for an interpretation.
model_interp = transformers.pipeline("text-classification", MODEL_ID)
def _record_function(func_dis, addr, lines):
    """Store *lines* as the disassembly of *addr*, keeping the first entry on duplicates."""
    if addr is None:
        # No function header has been seen yet; nothing to record.
        return
    if addr in func_dis:
        print("Warning: Ignoring multiple functions at the same address")
        return
    func_dis[addr] = b"\n".join(lines)


def get_all_dis(bname, addrs=None):
    """Disassemble the binary at path *bname* and split the listing per function.

    Runs ``bat-ana`` to write an analysis file into a temporary file, then
    ``bat-dis`` on that file, and parses the textual listing.

    Args:
        bname: Path of the binary to analyze.
        addrs: Optional iterable of function addresses; each one is passed to
            ``bat-ana`` as ``--function-at <addr>``.

    Returns:
        A dict mapping each function's address (int) to its disassembly
        (bytes, lines joined with ``\\n``).  Runs of spaces are collapsed to
        a single space and every line containing ``;;`` is dropped.  If the
        same address appears more than once, the first occurrence is kept and
        a warning is printed.

    Raises:
        subprocess.CalledProcessError: If either tool exits with a non-zero
            status.
        FileNotFoundError: If ``bat-ana`` or ``bat-dis`` is not on ``PATH``.
    """
    # Build argument vectors instead of shell strings: bname comes from an
    # uploaded file, so it must never be interpreted by a shell.
    ana_cmd = ["bat-ana"]
    if addrs is not None:
        for addr in addrs:
            ana_cmd += ["--function-at", str(addr)]

    # The context manager removes the temporary analysis file as soon as the
    # listing has been captured, instead of whenever the object is collected.
    with tempfile.NamedTemporaryFile(
        prefix=os.path.basename(bname) + "_", suffix=".bat_ana"
    ) as anafile:
        ananame = anafile.name
        ana_cmd += ["--no-post-analysis", "-o", ananame, bname]
        subprocess.check_output(ana_cmd, stderr=subprocess.DEVNULL)
        output = subprocess.check_output(
            [
                "bat-dis",
                "--no-insn-address",
                "--no-bb-cfg-arrows",
                "--color=off",
                ananame,
            ],
            stderr=subprocess.DEVNULL,
        )

    # Collapse column padding so the text is whitespace-normalized.
    output = re.sub(b" +", b" ", output)

    func_dis = {}
    last_func = None
    current_output = []
    for line in output.splitlines():
        if line.startswith(b";;; function 0x"):
            # A new function header closes the previous function, if any.
            _record_function(func_dis, last_func, current_output)
            last_func = int(line.split()[2], 16)
            current_output = []
        elif b";;" not in line:
            # Lines seen before the first header are discarded when that
            # header resets the buffer.
            current_output.append(line)

    # Flush the final function, which has no following header to close it.
    _record_function(func_dis, last_func, current_output)
    return func_dis
def get_funs(f):
    """Return the address of every function found in file object *f*.

    The file is disassembled via ``get_all_dis(f.name)``; the result is a
    single string with one ``0x``-prefixed hexadecimal address per line.
    """
    disassembly_by_addr = get_all_dis(f.name)
    hex_addresses = [f"{addr:#x}" for addr in disassembly_by_addr]
    return "\n".join(hex_addresses)
with gr.Blocks() as demo:
    # NOTE(review): the nesting of the widgets below was reconstructed from a
    # copy of this file whose indentation had been lost -- confirm the layout.

    # Per-session mapping of function address (int) -> disassembly (bytes),
    # filled in by file_change_fn and read by function_change_fn.
    all_dis_state = gr.State()

    gr.Markdown(
        """
# Function/Method Detector
First, upload a binary.
This model was only trained on 32-bit MSVC++ binaries. You can provide
other types of binaries, but the result will probably be gibberish.
"""
    )

    file_widget = gr.File(label="Binary file")

    # Hidden until a binary has been disassembled successfully.
    with gr.Column(visible=False) as col:
        gr.Markdown("""
Great, you selected an executable! Now pick the function you would like to analyze.
""")

        # The placeholder choice is replaced by file_change_fn.
        fun_dropdown = gr.Dropdown(label="Select a function", choices=["Woohoo!"], interactive=True)

        gr.Markdown("""
Below you can find the selected function's disassembly, and the model's
prediction of whether the function is an object-oriented method or a
regular function.
""")

        with gr.Row(visible=True) as result:
            disassembly = gr.Textbox(label="Disassembly", lines=20)
            with gr.Column():
                clazz = gr.Label()
                interpret_button = gr.Button("Interpret (very slow)")
                interpretation = gr.components.Interpretation(disassembly)

    # Sorted so the examples appear in a stable order; os.scandir yields
    # entries in arbitrary order.
    example_widget = gr.Examples(
        examples=sorted(
            f.path
            for f in os.scandir(os.path.join(os.path.dirname(__file__), "examples"))
        ),
        inputs=file_widget,
        outputs=[all_dis_state, disassembly, clazz],
    )

    def file_change_fn(file, progress=gr.Progress()):
        """Disassemble the uploaded binary and populate the function dropdown.

        Returns a dict of component updates.  When no file is selected, or the
        disassembly contains no functions, the analysis column is hidden and
        the stored disassembly is reset to None.
        """
        if file is None:
            return {col: gr.update(visible=False),
                    all_dis_state: None}

        progress(0, desc="Disassembling executable")
        fun_data = get_all_dis(file.name)

        if not fun_data:
            # Nothing to select: treat it like "no file" rather than failing
            # on addrs[0] below.
            return {col: gr.update(visible=False),
                    all_dis_state: None}

        addrs = ["%#x" % addr for addr in fun_data]
        return {col: gr.update(visible=True),
                fun_dropdown: gr.Dropdown.update(choices=addrs, value=addrs[0]),
                all_dis_state: fun_data}

    def function_change_fn(selected_fun, fun_data):
        """Show the selected function's disassembly and the model's prediction.

        Args:
            selected_fun: Hexadecimal address string chosen in the dropdown.
            fun_data: Mapping of address (int) -> disassembly (bytes) stored
                in all_dis_state.

        Returns:
            A dict of updates for the disassembly textbox and the label.  Both
            are cleared when there is no selection, no stored disassembly, or
            the address is unknown.
        """
        raw = None
        if selected_fun is not None and fun_data is not None:
            raw = fun_data.get(int(selected_fun, 16))
        if raw is None:
            return {disassembly: gr.Textbox.update(value=""),
                    clazz: gr.Label.update(value=None)}

        # errors="replace" keeps a stray non-UTF-8 byte in the tool output
        # from aborting the whole request.
        disassembly_str = raw.decode("utf-8", errors="replace")
        load_results = model.fn(disassembly_str)
        top_k = {e['label']: e['confidence'] for e in load_results['confidences']}
        # TODO: the interpretation widget is also listed as an output of this
        # event but is not updated here; a way to hide or reset it has not
        # been found yet.
        return {disassembly: gr.Textbox.update(value=disassembly_str),
                clazz: gr.Label.update(top_k)}

    # XXX: Ideally we'd use the gr.load model, which uses the huggingface
    # inference API. But the shap library appears to use information in the
    # transformers pipeline, and I don't feel like figuring out how to
    # reimplement that, so we'll just use a regular transformers pipeline here
    # for interpretation.
    def interpretation_function(text, progress=gr.Progress(track_tqdm=True)):
        """Compute per-token SHAP scores for *text* with the local pipeline.

        Returns a dict with the keys "original" (the input text) and
        "interpretation" (a list of (token, score) pairs), which is the format
        expected by gr.components.Interpretation.
        """
        progress(0, desc="Interpreting function")
        explainer = shap.Explainer(model_interp)
        shap_values = explainer([text])

        # Dimensions are (batch size, text size, number of classes).
        # Index 1 selects the scores for the model's second class.
        # NOTE(review): presumably that is the "method" class -- confirm
        # against the model's label order.
        scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))

        return {"original": text, "interpretation": scores}

    # Event wiring; it has to happen inside the Blocks context.
    file_widget.change(file_change_fn, file_widget, [col, fun_dropdown, all_dis_state])
    fun_dropdown.change(function_change_fn, [fun_dropdown, all_dis_state], [disassembly, clazz, interpretation])
    interpret_button.click(interpretation_function, disassembly, interpretation)

# Enable the request queue before starting the server.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)