import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class AutofixCodeAILLModel(torch.nn.Module):
    """A pretrained causal LM extended with an extra Transformer decoder stack."""

    # AutoModelForCausalLM is a factory class and is not meant to be subclassed directly,
    # so the extra decoder is attached by composition around a loaded base model instead.
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        config = base_model.config
        # decoder_hidden_size / decoder_num_layers are custom config fields; fall back to
        # sensible defaults when the loaded checkpoint's config does not define them.
        hidden_size = getattr(config, "decoder_hidden_size", config.hidden_size)
        num_layers = getattr(config, "decoder_num_layers", 2)
        self.decoder = AutoDecoder(hidden_size, num_layers)

    @property
    def decoder(self):
        return self._decoder

    @decoder.setter
    def decoder(self, value):
        self._decoder = value
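
    # One possible forward pass, assuming the decoder width matches the base model's hidden
    # size: refine the final hidden states with the extra decoder stack, then reuse the base
    # LM head to project back to vocabulary logits.
    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = self.decoder(outputs.hidden_states[-1])
        return self.base_model.get_output_embeddings()(hidden)
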

class AutoDecoder(torch.nn.Module):
    def __init__(self, hidden_size, num_layers):
        super().__init__()
        # batch_first=True so the layers accept the (batch, seq, hidden) tensors produced by
        # Hugging Face models.
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.TransformerEncoderLayer(
                    d_model=hidden_size, nhead=8, dim_feedforward=hidden_size,
                    dropout=0.1, batch_first=True,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model_name_or_path = "autofixcodeai-base"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Many causal-LM tokenizers ship without a padding token; reuse EOS for padding if needed.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
ll_model = AutofixCodeAILLModel(base_model)


class CodeFixDataset(torch.utils.data.Dataset):
    def __init__(self, code_snippets, fix_snippets):
        self.code_snippets = code_snippets
        self.fix_snippets = fix_snippets

    def __len__(self):
        return len(self.code_snippets)

    def __getitem__(self, idx):
        code = self.code_snippets[idx]["code"]
        fix = self.fix_snippets[idx]["fix"]
        # The buggy code is the model input; pad to a fixed length so examples batch cleanly.
        encoded = tokenizer(
            code, max_length=512, padding="max_length", truncation=True, return_tensors="pt"
        )
        # The fixed code provides the labels; padding positions are set to -100 so the
        # cross-entropy loss ignores them.
        labels = tokenizer(
            fix, max_length=512, padding="max_length", truncation=True, return_tensors="pt"
        )["input_ids"].squeeze(0)
        labels[labels == tokenizer.pad_token_id] = -100
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": labels,
        }

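
# Example buggy/fixed pairs so the script runs end to end; in practice code_snippets and
# fix_snippets would be loaded from a real code-repair dataset.
code_snippets = [{"code": "def add(a, b):\n    return a - b"}]
fix_snippets = [{"fix": "def add(a, b):\n    return a + b"}]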

dataset = CodeFixDataset(code_snippets, fix_snippets)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)


class Trainer(torch.nn.Module):
    def __init__(self, model, data_loader, device="cuda"):
        super().__init__()
        # Move the model to the training device once, up front.
        self.model = model.to(device)
        self.data_loader = data_loader
        self.device = device

    def forward(self, input_ids, attention_mask, labels):
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Flatten (batch, seq, vocab) logits and (batch, seq) labels for token-level cross-entropy.
        loss = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss

    @property
    def loss_fn(self):
        # Padding positions carry a label of -100 and are excluded from the loss.
        return torch.nn.CrossEntropyLoss(ignore_index=-100)


device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(ll_model, data_loader, device=device)
# Create the optimizer once, outside the training loop, rather than once per batch.
optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-4)

for epoch in range(5):
    trainer.model.train()
    total_loss = 0
    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        loss = trainer(input_ids, attention_mask, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}")


# Evaluation pass (reusing the same loader; a held-out split would normally be used here).
trainer.model.eval()
test_loss = 0
correct = 0
total_tokens = 0
with torch.no_grad():
    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        logits = trainer.model(input_ids=input_ids, attention_mask=attention_mask)
        loss = trainer.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        test_loss += loss.item()
        # Token-level accuracy over non-padding positions (labels of -100 are ignored).
        predicted = logits.argmax(dim=-1)
        mask = labels != -100
        correct += ((predicted == labels) & mask).sum().item()
        total_tokens += mask.sum().item()

accuracy = correct / max(total_tokens, 1)
print(f"Test Loss: {test_loss / len(data_loader)}, Accuracy: {accuracy:.2f}")
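
# Quick smoke test: predict a "fix" for one snippet by taking the per-position argmax of the
# logits. This is not proper autoregressive generation, just a rough sanity check.
sample = tokenizer("def add(a, b):\n    return a - b", return_tensors="pt").to(device)
with torch.no_grad():
    sample_logits = trainer.model(
        input_ids=sample["input_ids"], attention_mask=sample["attention_mask"]
    )
print(tokenizer.decode(sample_logits.argmax(dim=-1)[0], skip_special_tokens=True))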