"""Agent Inferencer."""
import os.path as osp
import types
from typing import List, Optional

from opencompass.models.lagent import LagentAgent
from opencompass.registry import ICL_INFERENCERS

from ..utils.logging import get_logger
from .icl_base_inferencer import dump_results_dict
from .icl_chat_inferencer import ChatInferencer

logger = get_logger(__name__)


class AgentInferencerOutputHandler:
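    """Collect per-sample agent results and dump them to a JSON file."""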

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     steps: list,
                     idx: int,
                     gold: Optional[str] = None):
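        """Record a single-round result for the sample at ``idx``."""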
        result_dict = {}
        if gold is not None:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
            'steps': steps,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                steps: list,
                                idx: int,
                                gold: Optional[str] = None):
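        """Append one round's result to the per-sample lists for ``idx``."""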
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
            'steps': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        result_dict['steps'].append(steps)
        self.results_dict[str(idx)] = result_dict


def model_adapter(model):
    """Modify the generate method to accept and return single item."""
    if getattr(model, '_generate_is_wrapped', False):
        # Avoid wrap twice.
        return model

    origin_generate = model.generate

    def generate(self, inputs, *args, **kwargs):
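        # Send a batch of one to the original method and unwrap the result.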
        return origin_generate([inputs], *args, **kwargs)[0]

    model.generate = types.MethodType(generate, model)
    setattr(model, '_generate_is_wrapped', True)
    return model


@ICL_INFERENCERS.register_module()
class AgentInferencer(ChatInferencer):
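    """Chat-style inferencer for Lagent agents that records the agent's
    intermediate steps alongside each prediction."""
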
    HandlerType = AgentInferencerOutputHandler

    def __init__(self, model, **kwargs) -> None:
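        # Adapt the agent's inner LLM so each chat turn passes a single item.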
        model.agent._llm = model_adapter(model.agent._llm)
        super().__init__(model, **kwargs)
        self.model: LagentAgent

    def infer_last(self, chat: List[dict], index: int, output_handler):
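        """Infer a response only for the last assistant turn, with all
        preceding turns as fixed history."""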
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

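        # The user message immediately preceding the final assistant turn.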
        user_idx = assistant_indices[-1] - 1
        self.model.set_history(chat[:user_idx])
        answer, steps, _ = self.model.chat(chat[user_idx]['content'])
        output_handler.save_results(
            origin_prompt=chat[user_idx]['content'],
            prediction=answer,
            steps=steps,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )
        self.model.reset()

    def infer_every(self, chat: List[dict], index: int, output_handler):
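        """Infer a response at every assistant turn, feeding the agent's own
        intermediate steps back into the history."""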
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

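        # Everything before the first user query seeds the history.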
        history = chat[:assistant_indices[0] - 1]
        for i in assistant_indices:
            answer, steps, inner_steps = self.model.chat(
                chat[i - 1]['content'], history)
            history += inner_steps
            output_handler.save_multiround_results(
                origin_prompt=chat[i - 1]['content'],
                prediction=answer,
                steps=steps,
                idx=index,
                gold=chat[i]['content'],
            )
        self.model.reset()

    def infer_every_with_gt(self, chat: List[dict], index: int,
                            output_handler):
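        """Infer a response at every assistant turn, rebuilding the history
        from ground-truth assistant replies between turns."""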
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[0] - 1]
        prev_idx = 0
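        # Extend the history with ground-truth replies (not the agent's own
        # predictions) for every turn already passed.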
        for i in assistant_indices:
            for j in range(prev_idx, i - 1):
                if chat[j]['role'] == 'assistant':
                    history += self.model.gt_response(chat[j]['content'])
                elif chat[j]['role'] == 'user':
                    history += [chat[j]]
            self.model.set_history(history)
            answer, steps, _ = self.model.chat(chat[i - 1]['content'])
            output_handler.save_multiround_results(
                origin_prompt=chat[i - 1]['content'],
                prediction=answer,
                steps=steps,
                idx=index,
                gold=chat[i]['content'],
            )
            history += [chat[i - 1]]
            prev_idx = i
        self.model.reset()