Spaces:
Runtime error
Runtime error
import torch
from transformers import AutoTokenizer, AutoModel

from Database import Database
class GraphCodeBert:
    """Compute mean-pooled GraphCodeBERT embeddings for code snippets.

    The tokenizer and model are loaded once in the constructor and reused
    for every snippet that is embedded afterwards.
    """

    def __init__(self, model_name: str = "microsoft/graphcodebert-base") -> None:
        """Load the tokenizer and model weights for ``model_name``.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
                Defaults to the GraphCodeBERT base checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def generate_embeddings(self) -> None:
        """Embed every document's code snippets and store the vectors.

        Reads ``method_refactored`` and ``meth_rf_neg`` from each document
        returned by the ``refactoring_details_neg`` database wrapper, embeds
        both, and writes the results back to the same document under
        ``embedding_pos`` and ``embedding_neg``. Progress is printed to
        stdout before and after each document.
        """
        database = Database("refactoring_details_neg")
        # Only used for the "Remaining" progress message below.
        # NOTE(review): presumably an estimate rather than an exact count,
        # so the countdown may not end at exactly zero - confirm in Database.
        remaining = database.estimated_doc_count()

        # Fetch only the fields needed to compute the embeddings.
        projection = {"_id": 1, "method_refactored": 1, "meth_rf_neg": 1}
        for doc_count, doc in enumerate(database.find_docs({}, projection), start=1):
            doc_id = doc["_id"]
            print(f'Generating embedding for doc_id:{doc_id} | Count-{doc_count}...')

            # Compute both vectors before writing, so a failure on the
            # negative sample leaves the document untouched.
            embedding_pos = self.generate_individual_embedding(doc["method_refactored"])
            embedding_neg = self.generate_individual_embedding(doc["meth_rf_neg"])

            # Update document in MongoDB with embedding
            database.update_by_id(doc_id, "embedding_pos", embedding_pos)
            database.update_by_id(doc_id, "embedding_neg", embedding_neg)

            remaining -= 1
            print(f'Embedding added for doc_id:{doc_id} | Remaining: {remaining}.')

    def generate_individual_embedding(self, code_snippet) -> list:
        """Return the mean-pooled embedding of ``code_snippet``.

        Args:
            code_snippet: Source text to embed. It is tokenized with
                truncation enabled, so over-long input is cut by the
                tokenizer rather than rejected.

        Returns:
            The model's ``last_hidden_state`` averaged over dimension 1,
            with size-1 dimensions squeezed away and converted to plain
            Python lists. For a single snippet this is a flat list of floats.
        """
        tokenized_input = self.tokenizer(
            code_snippet, return_tensors="pt", padding=True, truncation=True
        )
        # Inference only: disabling autograd avoids building a computation
        # graph for every forward pass, which costs memory and time.
        with torch.no_grad():
            output = self.model(**tokenized_input)
        return output.last_hidden_state.mean(dim=1).squeeze().tolist()
if __name__ == "__main__":
    # Script entry point: build the embedder, then embed and persist
    # vectors for every document in the database.
    embedder = GraphCodeBert()
    embedder.generate_embeddings()