import torch
from transformers import AutoTokenizer, AutoModel

from Database import Database


class GraphCodeBert:
    def __init__(self) -> None:
        model_name = "microsoft/graphcodebert-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def generate_embeddings(self):
        database = Database("refactoring_details_neg")
        # database.connect_db()
        # collection = database.fetch_collection("refactoring_information")
        # collection_len = collection.estimated_document_count()
        collection_len = database.estimated_doc_count()
        doc_count = 1
        for doc in database.find_docs({}, {"_id": 1, "method_refactored": 1, "meth_rf_neg": 1}):
            doc_id = doc["_id"]
            code_snippet = doc["method_refactored"]
            code_snippet_neg = doc["meth_rf_neg"]
            print(f'Generating embedding for doc_id:{doc_id} | Count-{doc_count}...')
            # Compute embeddings (inference only, so no gradients are needed)
            with torch.no_grad():
                tokenized_input_pos = self.tokenizer(code_snippet, return_tensors="pt", padding=True, truncation=True)
                output = self.model(**tokenized_input_pos)
                embedding_pos = output.last_hidden_state.mean(dim=1).squeeze().tolist()
                # Negative embedding
                tokenized_input_neg = self.tokenizer(code_snippet_neg, return_tensors="pt", padding=True, truncation=True)
                output = self.model(**tokenized_input_neg)
                embedding_neg = output.last_hidden_state.mean(dim=1).squeeze().tolist()
            # Update document in MongoDB with both embeddings
            database.update_by_id(doc_id, "embedding_pos", embedding_pos)
            database.update_by_id(doc_id, "embedding_neg", embedding_neg)
            collection_len -= 1
            doc_count += 1
            print(f'Embedding added for doc_id:{doc_id} | Remaining: {collection_len}.')

    def generate_individual_embedding(self, code_snippet):
        # Mean-pool the last hidden state into a single fixed-size vector
        with torch.no_grad():
            tokenized_input = self.tokenizer(code_snippet, return_tensors="pt", padding=True, truncation=True)
            output = self.model(**tokenized_input)
            embedding = output.last_hidden_state.mean(dim=1).squeeze().tolist()
        return embedding
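
# Example (hypothetical) standalone usage of generate_individual_embedding:
#   model = GraphCodeBert()
#   vector = model.generate_individual_embedding("public int add(int a, int b) { return a + b; }")
#   # vector is a 768-dimensional list (hidden size of graphcodebert-base)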


if __name__ == "__main__":
    GraphCodeBert().generate_embeddings()
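
# NOTE: The Database helper imported above lives in a separate Database.py that is
# not shown here. The commented sketch below is a hypothetical, minimal pymongo-backed
# implementation inferred from the calls made in generate_embeddings()
# (estimated_doc_count, find_docs, update_by_id); the connection URI and database
# name are placeholders, and the real module may differ.
#
# from pymongo import MongoClient
#
# class Database:
#     def __init__(self, collection_name, uri="mongodb://localhost:27017", db_name="refactoring_db"):
#         self.collection = MongoClient(uri)[db_name][collection_name]
#
#     def estimated_doc_count(self):
#         return self.collection.estimated_document_count()
#
#     def find_docs(self, query, projection):
#         return self.collection.find(query, projection)
#
#     def update_by_id(self, doc_id, field, value):
#         self.collection.update_one({"_id": doc_id}, {"$set": {field: value}})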