import argparse
import os.path as osp
import random
import time
from typing import Any

from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist

from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
                                  ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)


@TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
class OpenICLAttackTask(BaseTask):
    """OpenICL Attack Task.

    For every (model, dataset) pair in the config, this task scores the
    dataset's candidate prompts, ranks them by accuracy and runs a prompt
    attack on the top-ranked ones. Results are logged and appended to a
    plain-text report in the output directory under ``<work_dir>/attack``.
    """

    name_prefix = 'OpenICLAttack'
    log_subdir = 'logs/attack'
    output_subdir = 'attack'
    # File name of the plain-text report written inside each output dir.
    attack_log_filename = 'attacklog.txt'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        run_cfg = self.model_cfgs[0].get('run_cfg', {})
        self.num_gpus = run_cfg.get('num_gpus', 0)
        self.num_procs = run_cfg.get('num_procs', 1)
        self.logger = get_logger()

    def get_command(self, cfg_path, template):
        """Get the command template for the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): The template which has '{task_cmd}' to format
                the command.

        Returns:
            str: ``template`` with ``{task_cmd}`` replaced by a ``torchrun``
            launch of this script when GPUs are requested, otherwise by a
            plain ``python`` invocation of it.
        """
        script_path = __file__
        if self.num_gpus > 0:
            # Random master port, presumably to reduce the chance of clashes
            # between tasks launched on the same host.
            port = random.randint(12000, 32000)
            command = (f'torchrun --master_port={port} '
                       f'--nproc_per_node {self.num_procs} '
                       f'{script_path} {cfg_path}')
        else:
            command = f'python {script_path} {cfg_path}'

        return template.format(task_cmd=command)

    def prompt_selection(self, inferencer, prompts):
        """Score every candidate prompt and rank them by accuracy.

        Args:
            inferencer: Object whose ``predict(prompt)`` returns the accuracy
                obtained with ``prompt`` (presumably a fraction in [0, 1];
                it is multiplied by 100 for display).
            prompts (Iterable[str]): Candidate prompts. Duplicates collapse
                into one entry.

        Returns:
            list[tuple[str, float]]: ``(prompt, acc)`` pairs sorted by
            accuracy, highest first.
        """
        prompt_dict = {}
        for prompt in prompts:
            acc = inferencer.predict(prompt)
            prompt_dict[prompt] = acc
            self.logger.info('{:.2f}, {}\n'.format(acc * 100, prompt))

        return sorted(prompt_dict.items(), key=lambda x: x[1], reverse=True)

    def run(self):
        """Run the attack for every (model, dataset) pair in the config.

        A pair is skipped when its output path under ``<work_dir>/attack``
        already exists.
        """
        self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
        for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
            self.max_out_len = model_cfg.get('max_out_len', None)
            self.batch_size = model_cfg.get('batch_size', None)
            self.model = build_model_from_cfg(model_cfg)

            for dataset_cfg in dataset_cfgs:
                self.model_cfg = model_cfg
                self.dataset_cfg = dataset_cfg
                self.infer_cfg = self.dataset_cfg['infer_cfg']
                self.dataset = build_dataset_from_cfg(self.dataset_cfg)
                self.sub_cfg = {
                    'models': [self.model_cfg],
                    'datasets': [[self.dataset_cfg]],
                }
                out_path = get_infer_output_path(
                    self.model_cfg, self.dataset_cfg,
                    osp.join(self.work_dir, 'attack'))
                if osp.exists(out_path):
                    continue
                self._inference()

    def _inference(self):
        """Rank the current dataset's prompts and attack the best ones.

        Builds the prompt templates, retriever and inferencer from
        ``self.infer_cfg``. It then scores ``original_prompt_list`` with
        :meth:`prompt_selection` and attacks the top ``prompt_topk`` prompts
        whose accuracy is positive. Every result is logged and appended to
        ``attacklog.txt`` inside the output directory.
        """
        self.logger.info(
            f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')

        assert ('ice_template' in self.infer_cfg
                or 'prompt_template' in self.infer_cfg), \
            'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501
        ice_template = None
        if 'ice_template' in self.infer_cfg:
            ice_template = ICL_PROMPT_TEMPLATES.build(
                self.infer_cfg['ice_template'])

        prompt_template = None
        if 'prompt_template' in self.infer_cfg:
            prompt_template = ICL_PROMPT_TEMPLATES.build(
                self.infer_cfg['prompt_template'])

        retriever_cfg = self.infer_cfg['retriever'].copy()
        retriever_cfg['dataset'] = self.dataset
        retriever = ICL_RETRIEVERS.build(retriever_cfg)

        # Set the inferencer's defaults from the model's config. This updates
        # ``self.infer_cfg['inferencer']`` in place.
        inferencer_cfg = self.infer_cfg['inferencer']
        inferencer_cfg['model'] = self.model
        self._set_default_value(inferencer_cfg, 'max_out_len',
                                self.max_out_len)
        self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
        inferencer_cfg['max_seq_len'] = self.model_cfg['max_seq_len']
        inferencer_cfg['dataset_cfg'] = self.dataset_cfg
        inferencer = ICL_INFERENCERS.build(inferencer_cfg)

        out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
                                         osp.join(self.work_dir, 'attack'))
        out_dir, out_file = osp.split(out_path)
        mkdir_or_exist(out_dir)
        # ``osp.split`` yields a directory without a trailing separator, so
        # the report path must be joined rather than string-concatenated.
        log_path = osp.join(out_dir, self.attack_log_filename)

        # Imported lazily: these modules are not part of opencompass and are
        # presumably only importable when the attack toolkit is on sys.path.
        from config import LABEL_SET
        from prompt_attack.attack import create_attack
        from prompt_attack.goal_function import PromptGoalFunction

        inferencer.retriever = retriever
        inferencer.prompt_template = prompt_template
        inferencer.ice_template = ice_template
        inferencer.output_json_filepath = out_dir
        inferencer.output_json_filename = out_file
        goal_function = PromptGoalFunction(
            inference=inferencer,
            query_budget=self.cfg['attack'].query_budget,
            logger=self.logger,
            model_wrapper=None,
            # NOTE(review): a string, not a bool; confirm what
            # PromptGoalFunction expects.
            verbose='True')
        if self.cfg['attack']['dataset'] not in LABEL_SET:
            # Fall back to the 'mmlu' entry for datasets absent from LABEL_SET.
            self.cfg['attack']['dataset'] = 'mmlu'
        attack = create_attack(self.cfg['attack'], goal_function)

        prompts = self.infer_cfg['inferencer']['original_prompt_list']
        sorted_prompts = self.prompt_selection(inferencer, prompts)
        # Record the accuracy of every candidate prompt.
        for prompt, acc in sorted_prompts:
            line = 'Prompt: {}, acc: {:.2f}%\n'.format(prompt, acc * 100)
            self.logger.info(line)
            self._append_log(log_path, line)

        topk = self.cfg['attack'].prompt_topk
        for init_prompt, init_acc in sorted_prompts[:topk]:
            if init_acc > 0:
                init_acc, attacked_prompt, attacked_acc, dropped_acc = \
                    attack.attack(init_prompt)
                # The attacked prompt is shown as its UTF-8 bytes repr.
                attacked_repr = attacked_prompt.encode('utf-8')
                summary = ('Original acc: {:.2f}%, attacked acc: {:.2f}%, '
                           'dropped acc: {:.2f}%').format(
                               init_acc * 100, attacked_acc * 100,
                               dropped_acc * 100)
                self.logger.info('Original prompt: {}'.format(init_prompt))
                self.logger.info('Attacked prompt: {}'.format(attacked_repr))
                self.logger.info(summary)
                # Written right after each (possibly slow) attack so that
                # finished results are on disk even if a later one fails.
                self._append_log(
                    log_path,
                    'Original prompt: {}\n'.format(init_prompt)
                    + 'Attacked prompt: {}\n'.format(attacked_repr)
                    + summary + '\n\n')
            else:
                self._append_log(
                    log_path,
                    'Init acc is 0, skip this prompt\n'
                    + 'Original prompt: {}\n'.format(init_prompt)
                    + 'Original acc: {:.2f}% \n\n'.format(init_acc * 100))

    @staticmethod
    def _append_log(log_path: str, text: str):
        """Append ``text`` to the plain-text attack report at ``log_path``."""
        with open(log_path, 'a+', encoding='utf-8') as f:
            f.write(text)

    def _set_default_value(self, cfg: ConfigDict, key: str, value: Any):
        """Set ``cfg[key] = value`` unless ``key`` is already present.

        Raises:
            AssertionError: If ``key`` is missing and ``value`` is falsy.
        """
        if key not in cfg:
            assert value, (f'{key} must be specified!')
            cfg[key] = value


def parse_args():
    """Parse the command line, which takes a single config file path."""
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)
    start_time = time.time()
    inferencer = OpenICLAttackTask(cfg)
    inferencer.run()
    end_time = time.time()
    get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')