import unittest

from opencompass.utils.prompt import PromptList
|
|
|
|
|
class TestPromptList(unittest.TestCase): |
|
|
|
def test_initialization(self): |
|
pl = PromptList() |
|
self.assertEqual(pl, []) |
|
|
|
pl = PromptList(['test', '123']) |
|
self.assertEqual(pl, ['test', '123']) |
|
|
|
def test_format(self): |
|
pl = PromptList(['hi {a}{b}', {'prompt': 'hey {a}!'}, '123']) |
|
new_pl = pl.format(a=1, b=2, c=3) |
|
self.assertEqual(new_pl, ['hi 12', {'prompt': 'hey 1!'}, '123']) |
|
|
|
new_pl = pl.format(b=2) |
|
self.assertEqual(new_pl, ['hi {a}2', {'prompt': 'hey {a}!'}, '123']) |
|
|
|
new_pl = pl.format(d=1) |
|
self.assertEqual(new_pl, ['hi {a}{b}', {'prompt': 'hey {a}!'}, '123']) |
|
|
|
def test_replace(self): |
|
pl = PromptList(['hello world', {'prompt': 'hello world'}, '123']) |
|
new_pl = pl.replace('world', 'there') |
|
self.assertEqual(new_pl, |
|
['hello there', { |
|
'prompt': 'hello there' |
|
}, '123']) |
|
|
|
new_pl = pl.replace('123', PromptList(['p', {'role': 'BOT'}])) |
|
self.assertEqual( |
|
new_pl, |
|
['hello world', { |
|
'prompt': 'hello world' |
|
}, 'p', { |
|
'role': 'BOT' |
|
}]) |
|
|
|
new_pl = pl.replace('2', PromptList(['p', {'role': 'BOT'}])) |
|
self.assertEqual(new_pl, [ |
|
'hello world', { |
|
'prompt': 'hello world' |
|
}, '1', 'p', { |
|
'role': 'BOT' |
|
}, '3' |
|
]) |
|
|
|
with self.assertRaises(TypeError): |
|
new_pl = pl.replace('world', PromptList(['new', 'world'])) |
|
|
|
def test_add(self): |
|
pl = PromptList(['hello']) |
|
new_pl = pl + ' world' |
|
self.assertEqual(new_pl, ['hello', ' world']) |
|
|
|
pl2 = PromptList([' world']) |
|
new_pl = pl + pl2 |
|
self.assertEqual(new_pl, ['hello', ' world']) |
|
|
|
new_pl = 'hi, ' + pl |
|
self.assertEqual(new_pl, ['hi, ', 'hello']) |
|
|
|
pl += '!' |
|
self.assertEqual(pl, ['hello', '!']) |
|
|
|
def test_str(self): |
|
pl = PromptList(['hello', ' world', {'prompt': '!'}]) |
|
self.assertEqual(str(pl), 'hello world!') |
|
|
|
with self.assertRaises(TypeError): |
|
pl = PromptList(['hello', ' world', 123]) |
|
str(pl) |
|
|