Virtual Compiler Is All You Need For Assembly Code Search


This repo contains the models and the corresponding evaluation datasets for the ACL 2024 paper "Virtual Compiler Is All You Need For Assembly Code Search".

A virtual compiler is an LLM that is capable of compiling any programming language into the underlying assembly code. The virtual compiler model is available at elsagranger/VirtualCompiler and is based on the 34B CodeLlama.

We evaluate the similarity between the virtual assembly code generated by the virtual compiler and the real assembly code using force execution, implemented in the script force-exec.py. The corresponding evaluation dataset is available at virtual_assembly_and_ground_truth.

We evaluate the effectiveness of the virtual compiler through a downstream task, assembly code search. The evaluation dataset is available at elsagranger/AssemblyCodeSearchEval.


We use FastChat with a vLLM worker to host the model. Run the following commands in separate terminals, for example in separate tmux panes.

LOGDIR="" python3 -m fastchat.serve.openai_api_server \
    --host --port 8080 \
    --controller-address http://localhost:21000

LOGDIR="" python3 -m fastchat.serve.controller \
    --host --port 21000

python3 -m fastchat.serve.vllm_worker \
    --model-path ./VirtualCompiler \
    --num-gpus 8 \
    --controller http://localhost:21000 \
    --max-num-batched-tokens 40960 \
    --disable-log-requests \
    --host --port 22000 \
    --worker-address http://localhost:22000 \
    --model-names "VirtualCompiler"

Once the model is hosted, use do_request.py to send a request to it.

~/C/VirtualCompiler (main)> python3 do_request.py
test rdx, rdx
setz al
movzx eax, al
neg eax
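
The OpenAI-compatible server started above can also be queried without do_request.py. The following is a minimal sketch, not the repository's script: it assumes the API server is reachable at http://localhost:8080, that the model is registered as "VirtualCompiler" (the value passed to --model-names), and that a plain /v1/completions request is acceptable. The helper names build_completion_request and complete are hypothetical, and the prompt format the model actually expects is the one used in do_request.py.

```python
import json
import urllib.request

# Assumptions: port taken from the openai_api_server command,
# model name taken from --model-names in the worker command.
API_BASE = "http://localhost:8080/v1"
MODEL_NAME = "VirtualCompiler"


def build_completion_request(prompt, max_tokens=512, temperature=0.0):
    """Build (but do not send) an OpenAI-style /v1/completions request."""
    payload = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{API_BASE}/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt, **kwargs):
    """Send the request and return the text of the first returned choice."""
    req = build_completion_request(prompt, **kwargs)
    with urllib.request.urlopen(req, timeout=600) as resp:
        body = json.load(resp)
    return body["choices"][0]["text"]

# Usage, once the three services are running (the prompt is a placeholder):
#   print(complete("<prompt in the format used by do_request.py>"))
```

Building the request separately from sending it makes it possible to inspect the payload before the server is up.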

Assembly Code Search Encoder

Because Hugging Face does not support loading a remote model from inside a folder, we host the model trained on the assembly code search dataset augmented by the Virtual Compiler in vic-encoder. You can use model.py to test loading the custom model.

Here is an example that uses the text encoder and the asm encoder. To see how to extract the assembly code from a binary, refer to the script process_asm.py.

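import torch


def calc_map_at_k(logits, pos_cnt, ks=(10,)):
    # Sort candidates by descending score; indices holds the original column ids.
    _, indices = torch.sort(logits, dim=1, descending=True)

    # Positives occupy columns [0, pos_cnt), so their 0-based ranks are the
    # positions where the sorted column id is below pos_cnt.
    # ranks: [batch_size, pos_cnt]
    ranks = torch.nonzero(
        indices < pos_cnt,
    )[:, 1].reshape(logits.shape[0], -1)

    # mrr: [batch_size]
    mrr = torch.mean(1 / (ranks + 1), dim=1)

    res = {}

    for k in ks:
        # Fraction of positives ranked within the top k, per query.
        res[k] = (
            torch.sum((ranks < k).float(), dim=1) / min(k, pos_cnt)
        )

    return ranks.cpu().numpy(), res, mrr.cpu().numpy()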

pos_asm_cnt = 1

query = ["List all files in a directory"]
anchor_asm = [...]
neg_anchor_asm = [...]

query_embs = text_encoder(**text_tokenizer(query))
asm_embs = asm_encoder(**asm_tokenizer(anchor_asm))
asm_neg_emb = asm_encoder(**asm_tokenizer(neg_anchor_asm))

# query_embs: [query_cnt, emb_dim]
# asm_embs: [pos_asm_cnt, emb_dim]

# logits_pos: [query_cnt, pos_asm_cnt]
logits_pos = torch.einsum(
    "ic,jc->ij", [query_embs, asm_embs])
# logits_neg: [query_cnt, neg_asm_cnt]
logits_neg = torch.einsum(
    "ic,jc->ij", [query_embs, asm_neg_emb[pos_asm_cnt:]])
logits = torch.cat([logits_pos, logits_neg], dim=1)

ranks, map_at_k, mrr = calc_map_at_k(
    logits, pos_asm_cnt, [1, 5, 10, 20, 50, 100])
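
To make the metric concrete, here is a minimal pure-Python sketch of what calc_map_at_k computes for a single query. It assumes, as in the snippet above, that the first pos_cnt entries of the score row are the positives. The function name rank_metrics and the toy scores are illustrative and not part of the repository; tie-breaking may differ from torch.sort.

```python
def rank_metrics(scores, pos_cnt, ks=(1, 5, 10)):
    """Mirror calc_map_at_k for one query, using plain lists.

    scores: similarity scores; the first pos_cnt entries are positives.
    Returns (ranks, at_k, mrr), where ranks are 0-based.
    """
    # Candidate ids sorted by descending score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Positions in the ranking at which a positive candidate appears.
    ranks = [pos for pos, idx in enumerate(order) if idx < pos_cnt]
    # Mean of the reciprocal (1-based) ranks of the positives.
    mrr = sum(1.0 / (r + 1) for r in ranks) / len(ranks)
    # Fraction of positives that land in the top k.
    at_k = {k: sum(r < k for r in ranks) / min(k, pos_cnt) for k in ks}
    return ranks, at_k, mrr


# Toy example: one positive (index 0) and four negatives.
scores = [0.7, 0.9, 0.1, 0.3, 0.2]
ranks, at_k, mrr = rank_metrics(scores, pos_cnt=1, ks=(1, 5))
print(ranks, at_k, mrr)  # [1] {1: 0.0, 5: 1.0} 0.5
```

In this toy case one negative outscores the positive, so the positive sits at 0-based rank 1: it misses the top 1, is found within the top 5, and contributes a reciprocal rank of 0.5.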
