|
from itertools import chain |
|
|
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
import config |
|
import generate_annotated_diffs |
|
import statistics |
|
from api_wrappers import grazie_wrapper, hf_data_loader |
|
from generation_steps import examples |
|
|
|
# Number of synthetic rows generated for every input row (outer loop in `transform`).
GENERATION_MULTIPLIER: int = 3
# A generated message is accepted as soon as its "insertions" statistic is
# strictly below this value (see `generate_start_msg`). The "REL" prefix
# suggests a relative share rather than a count -- TODO confirm against
# `statistics.get_statistics`.
REL_INSERTIONS_THRESHOLD: float = 0.5
# Maximum number of LLM calls spent on producing one synthetic message.
GENERATION_ATTEMPTS: int = 3
|
|
|
|
|
def build_prompt(reference, diff):
    """Build the prompt asking the LLM to reconstruct a pre-edit commit message.

    The prompt shows the source code changes (``diff``) and the
    developer-edited commit message (``reference``), embeds the few-shot
    examples from ``examples.EXAMPLES_END_TO_START``, and ends with the
    ``OUTPUT`` marker after which the model should print the initial message.

    Args:
        reference: The edited ("end") commit message.
        diff: The source code changes the commit message describes.

    Returns:
        The complete prompt text.
    """
    # NOTE: the blank lines inside the template are part of the prompt text.
    return f"""A software developer uses an LLM to generate commit messages.



They generated a commit message for the following source code changes:

START OF THE SOURCE CODE CHANGES

{diff}

END OF THE SOURCE CODE CHANGES



After generating the commit message the developer understands that it is not perfect. After making some changes,

they come up with an edited version of the message. Here is this edited message:

START OF THE COMMIT MESSAGE

{reference}

END OF THE COMMIT MESSAGE



Your task is to print the initial, LLM-generated commit message.

The message you print must share some fragments with the edited message.

Here are some examples of what you should output:

START OF THE EXAMPLES LIST

{examples.EXAMPLES_END_TO_START}

END OF THE EXAMPLES LIST





Print only the initial commit message's text after the

token "OUTPUT".



OUTPUT"""
|
|
|
|
|
def generate_start_msg(end_msg, diff):
    """Generate a plausible LLM-produced "start" message for an edited commit message.

    The prompt built from ``end_msg`` and ``diff`` is sent to the LLM up to
    ``GENERATION_ATTEMPTS`` times. For each candidate, the ``"insertions"``
    statistic is computed against ``end_msg``. The first candidate whose value
    is below ``REL_INSERTIONS_THRESHOLD`` is returned immediately. If no
    candidate meets the threshold, the one with the smallest ``"insertions"``
    value is returned.

    Args:
        end_msg: The edited ("end") commit message.
        diff: The source code changes the message describes.

    Returns:
        The selected candidate start message.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    candidates = []

    for _ in range(GENERATION_ATTEMPTS):
        start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)

        annotated_msg = generate_annotated_diffs.get_annotated_diff(start_msg_pred, end_msg)
        # NOTE: `statistics` here is the project module (the stdlib module of
        # the same name has no `get_statistics`).
        stats = statistics.get_statistics(start_msg=start_msg_pred, end_msg=end_msg,
                                          annotated_msg=annotated_msg)
        insertions = stats["insertions"]

        if insertions < REL_INSERTIONS_THRESHOLD:
            return start_msg_pred
        candidates.append((insertions, start_msg_pred))

    # No candidate met the threshold: fall back to the fewest insertions.
    # Keying on the score only avoids comparing message texts on ties (which
    # would fail for non-string results); the earliest tied candidate wins.
    _, best_msg = min(candidates, key=lambda candidate: candidate[0])
    return best_msg
|
|
|
|
|
# Columns copied verbatim from a source row into each synthetic row built by `transform`.
COLS_TO_KEEP = [
    "hash",
    "repo",
    "commit_msg_end",
    "mods",
    "session",
]

# Columns that every synthetic row receives with a fixed value (column name -> value).
COLS_TO_DEFAULT = {
    "edit_time": None,
}
|
|
|
|
|
def transform(df):
    """Append LLM-synthesized "end -> start" rows to ``df`` and save the result.

    For every row of ``df``, ``GENERATION_MULTIPLIER`` synthetic rows are
    produced. Each synthetic row contains:

    - a ``commit_msg_start`` generated by ``generate_start_msg`` from the
      row's ``commit_msg_end`` and ``mods``;
    - copies of the ``COLS_TO_KEEP`` columns;
    - the fixed values from ``COLS_TO_DEFAULT``.

    Original rows are flagged with ``end_to_start=False`` and synthetic rows
    with ``end_to_start=True``. The combined frame is written to
    ``config.END_TO_START_ARTIFACT`` via ``DataFrame.to_csv`` and returned.
    ``df`` itself is not modified.

    Args:
        df: Input frame; must contain every column listed in ``COLS_TO_KEEP``.

    Returns:
        A new DataFrame with the original rows followed by the synthetic
        rows, re-indexed from 0.
    """
    print("End -> start synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")

    # Synthetic frame column order: generated message, copied columns, defaulted columns.
    generated_columns = ["commit_msg_start", *COLS_TO_KEEP, *COLS_TO_DEFAULT]
    generated_rows = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        kept_values = {col: row[col] for col in COLS_TO_KEEP}
        for _ in range(GENERATION_MULTIPLIER):
            commit_msg_start_pred = generate_start_msg(end_msg=row["commit_msg_end"],
                                                       diff=row["mods"])
            generated_rows.append({
                "commit_msg_start": commit_msg_start_pred,
                **kept_values,
                **COLS_TO_DEFAULT,
            })

    # Passing `columns` keeps the schema even when no rows were generated.
    generated_df = pd.DataFrame(generated_rows, columns=generated_columns)
    generated_df["end_to_start"] = True

    # `assign` flags the original rows on a copy instead of mutating the caller's frame.
    result = pd.concat([df.assign(end_to_start=False), generated_df], ignore_index=True)
    result.to_csv(config.END_TO_START_ARTIFACT)

    print("Done")
    return result
|
|
|
|
|
def main():
    """Script entry point: run ``transform`` on the frame from ``hf_data_loader.load_processed_rewriting_as_pandas``."""
    transform(hf_data_loader.load_processed_rewriting_as_pandas())


if __name__ == '__main__':
    main()
|
|