|
import time |
|
|
|
from grazie.api.client.chat.prompt import ChatPrompt |
|
from grazie.api.client.endpoints import GrazieApiGatewayUrls |
|
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType |
|
from grazie.api.client.profiles import LLMProfile |
|
from tqdm import tqdm |
|
|
|
import config |
|
import hf_data_loader |
|
|
|
# Shared Grazie API gateway client, created once at import time and used by
# generate_initial_msg. It targets the STAGING gateway with service auth,
# using the JWT token from the local ``config`` module.
_grazie_agent = GrazieAgent(name="commit-rewriting-summary-generation", version="dev")
client = GrazieApiGatewayClient(
    grazie_agent=_grazie_agent,
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.SERVICE,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
|
|
|
|
|
def get_example_prompt(start_msg, end_msg):
    """Render one few-shot example for the initial-message reconstruction prompt.

    The example shows the edited commit message first and then the initial
    message the model is expected to output. Each message is wrapped in
    START/END markers.

    Args:
        start_msg: the initial commit message (what the model should output).
        end_msg: the edited commit message (what the model is given).

    Returns:
        The formatted example text.
    """
    return f"""START OF THE EXAMPLE

For the following edited message:
START OF THE EDITED COMMIT MESSAGE
{end_msg}
END OF THE EDITED COMMIT MESSAGE

You would output the following initial commit message:
START OF THE INITIAL COMMIT MESSAGE
{start_msg}
END OF THE INITIAL COMMIT MESSAGE

END OF THE EXAMPLE"""
|
|
|
|
|
def generate_examples():
    """Render every manual rewriting sample as a few-shot example.

    Loads the raw rewriting dataset through ``hf_data_loader``. Each
    (``commit_msg_start``, ``commit_msg_end``) pair is formatted with
    ``get_example_prompt``.

    Returns:
        All rendered examples joined by single newlines.
    """
    raw_df = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
    pairs = raw_df[['commit_msg_start', 'commit_msg_end']]
    rendered = map(
        get_example_prompt,
        pairs['commit_msg_start'],
        pairs['commit_msg_end'],
    )
    return "\n".join(rendered)


# Built once at import time (this loads the raw rewriting dataset).
# build_prompt embeds the result into every prompt.
EXAMPLES = generate_examples()
|
|
|
|
|
def build_prompt(reference, diff, examples=None):
    """Build the LLM prompt that asks to recover the initial commit message.

    The prompt describes the diff and the developer-edited message, includes
    the few-shot examples, and asks the model to print the initial,
    LLM-generated message after an ``OUTPUT`` token.

    Args:
        reference: the edited commit message.
        diff: the source code changes the message describes.
        examples: few-shot examples text to embed. Defaults to the
            module-level ``EXAMPLES`` (looked up only when not supplied).

    Returns:
        The complete prompt string, ending with ``OUTPUT``.
    """
    if examples is None:
        examples = EXAMPLES
    return f"""A software developer uses an LLM to generate commit messages.

They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES

After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE

Your task is to print the initial, LLM-generated commit message. Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples}
END OF THE EXAMPLES LIST

Print only the initial commit message's text after the
token "OUTPUT".

OUTPUT"""
|
|
|
|
|
def generate_prompt_for_row(row):
    """Build the reconstruction prompt for one dataset row.

    The row's ``reference`` field is passed as the edited message and its
    ``mods`` field as the diff. Designed to be used with
    ``DataFrame.apply(..., axis=1)``.
    """
    return build_prompt(row['reference'], row['mods'])
|
|
|
|
|
def generate_initial_msg(prompt):
    """Ask the LLM to reconstruct the initial commit message for *prompt*.

    Sends a chat consisting of a fixed system message and *prompt* as the
    user turn. The request goes through the module-level ``client`` with the
    ``gpt-4-1106-preview`` profile.

    Returns:
        The ``content`` of the chat response.
    """
    conversation = ChatPrompt().add_system("You are a helpful assistant.").add_user(prompt)
    response = client.chat(chat=conversation, profile=LLMProfile("gpt-4-1106-preview"))
    return response.content
|
|
|
|
|
def _generate_with_retry(prompt, retry_delay=0.5):
    """Call generate_initial_msg until it returns a non-None message.

    Any ``Exception`` raised by the call is reported and retried after
    *retry_delay* seconds. ``KeyboardInterrupt``/``SystemExit`` are not
    caught, so a long run can still be stopped with Ctrl-C.
    """
    while True:
        try:
            output = generate_initial_msg(prompt)
        except Exception as err:  # API/network failure: report and retry
            tqdm.write(f"Generation failed ({err!r}); retrying in {retry_delay}s")
        else:
            if output is not None:
                return output
        time.sleep(retry_delay)


def generate_synthetic_dataset():
    """Generate a synthetic initial commit message for every commit and save it.

    For each row of the full commit dataset, this:

    - builds a prompt from the row's edited message (``reference``) and diff
      (``mods``);
    - asks the LLM for the initial message, retrying on failure.

    The DataFrame, with the added ``initial_msg_prompt`` and
    ``initial_msg_pred`` columns, is written as CSV to
    ``config.SYNTHETIC_DATASET_ARTIFACT``.
    """
    df = hf_data_loader.load_full_commit_dataset_as_pandas()
    df['initial_msg_prompt'] = df.apply(generate_prompt_for_row, axis=1)

    # Read back exactly the column created above.
    df['initial_msg_pred'] = [
        _generate_with_retry(prompt) for prompt in tqdm(df['initial_msg_prompt'])
    ]

    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)


if __name__ == '__main__':
    generate_synthetic_dataset()
|
|