|
import unittest |
|
|
|
from opencompass.models.base_api import APITemplateParser |
|
from opencompass.utils.prompt import PromptList |
|
|
|
|
|
class TestAPITemplateParser(unittest.TestCase): |
|
|
|
def setUp(self): |
|
self.parser = APITemplateParser() |
|
self.prompt = PromptList([ |
|
{ |
|
'section': 'begin', |
|
'pos': 'begin' |
|
}, |
|
'begin', |
|
{ |
|
'role': 'SYSTEM', |
|
'fallback_role': 'HUMAN', |
|
'prompt': 'system msg' |
|
}, |
|
{ |
|
'section': 'ice', |
|
'pos': 'begin' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U0' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B0' |
|
}, |
|
{ |
|
'section': 'ice', |
|
'pos': 'end' |
|
}, |
|
{ |
|
'section': 'begin', |
|
'pos': 'end' |
|
}, |
|
{ |
|
'section': 'round', |
|
'pos': 'begin' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U1' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B1' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U2' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B2' |
|
}, |
|
{ |
|
'section': 'round', |
|
'pos': 'end' |
|
}, |
|
{ |
|
'section': 'end', |
|
'pos': 'begin' |
|
}, |
|
'end', |
|
{ |
|
'section': 'end', |
|
'pos': 'end' |
|
}, |
|
]) |
|
|
|
def test_parse_template_str_input(self): |
|
prompt = self.parser.parse_template('Hello, world!', mode='gen') |
|
self.assertEqual(prompt, 'Hello, world!') |
|
prompt = self.parser.parse_template('Hello, world!', mode='ppl') |
|
self.assertEqual(prompt, 'Hello, world!') |
|
|
|
def test_parse_template_list_input(self): |
|
prompt = self.parser.parse_template(['Hello', 'world'], mode='gen') |
|
self.assertEqual(prompt, ['Hello', 'world']) |
|
prompt = self.parser.parse_template(['Hello', 'world'], mode='ppl') |
|
self.assertEqual(prompt, ['Hello', 'world']) |
|
|
|
def test_parse_template_PromptList_input_no_meta_template(self): |
|
prompt = self.parser.parse_template(self.prompt, mode='gen') |
|
self.assertEqual(prompt, |
|
'begin\nsystem msg\nU0\nB0\nU1\nB1\nU2\nB2\nend') |
|
prompt = self.parser.parse_template(self.prompt, mode='ppl') |
|
self.assertEqual(prompt, |
|
'begin\nsystem msg\nU0\nB0\nU1\nB1\nU2\nB2\nend') |
|
|
|
def test_parse_template_PromptList_input_with_meta_template(self): |
|
parser = APITemplateParser(meta_template=dict(round=[ |
|
dict(role='HUMAN', api_role='HUMAN'), |
|
dict(role='BOT', api_role='BOT', generate=True) |
|
], )) |
|
with self.assertWarns(Warning): |
|
prompt = parser.parse_template(self.prompt, mode='gen') |
|
self.assertEqual( |
|
prompt, |
|
PromptList([ |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'system msg\nU0' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B0' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U1' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B1' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U2' |
|
}, |
|
])) |
|
with self.assertWarns(Warning): |
|
prompt = parser.parse_template(self.prompt, mode='ppl') |
|
self.assertEqual( |
|
prompt, |
|
PromptList([ |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'system msg\nU0' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B0' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U1' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B1' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U2' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B2' |
|
}, |
|
])) |
|
|
|
parser = APITemplateParser(meta_template=dict( |
|
round=[ |
|
dict(role='HUMAN', api_role='HUMAN'), |
|
dict(role='BOT', api_role='BOT', generate=True) |
|
], |
|
reserved_roles=[ |
|
dict(role='SYSTEM', api_role='SYSTEM'), |
|
], |
|
)) |
|
with self.assertWarns(Warning): |
|
prompt = parser.parse_template(self.prompt, mode='gen') |
|
self.assertEqual( |
|
prompt, |
|
PromptList([ |
|
{ |
|
'role': 'SYSTEM', |
|
'prompt': 'system msg' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U0' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B0' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U1' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B1' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U2' |
|
}, |
|
])) |
|
with self.assertWarns(Warning): |
|
prompt = parser.parse_template(self.prompt, mode='ppl') |
|
self.assertEqual( |
|
prompt, |
|
PromptList([ |
|
{ |
|
'role': 'SYSTEM', |
|
'prompt': 'system msg' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U0' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B0' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U1' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B1' |
|
}, |
|
{ |
|
'role': 'HUMAN', |
|
'prompt': 'U2' |
|
}, |
|
{ |
|
'role': 'BOT', |
|
'prompt': 'B2' |
|
}, |
|
])) |
|
|