TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
2.34 kB
import unittest
from opencompass.utils.prompt import PromptList
class TestPromptList(unittest.TestCase):
    """Unit tests for ``PromptList``.

    Covers construction, ``format``, ``replace``, the ``+`` / ``+=``
    operators and ``str()`` conversion.
    """

    def test_initialization(self):
        """A ``PromptList`` compares equal to the list it was built from."""
        pl = PromptList()
        self.assertEqual(pl, [])

        pl = PromptList(['test', '123'])
        self.assertEqual(pl, ['test', '123'])

    def test_format(self):
        """``format`` fills supplied fields and keeps the others verbatim."""
        pl = PromptList(['hi {a}{b}', {'prompt': 'hey {a}!'}, '123'])

        # A keyword with no matching placeholder (``c``) must be ignored.
        new_pl = pl.format(a=1, b=2, c=3)
        self.assertEqual(new_pl, ['hi 12', {'prompt': 'hey 1!'}, '123'])

        # Placeholders without a supplied value stay in the output as-is.
        new_pl = pl.format(b=2)
        self.assertEqual(new_pl, ['hi {a}2', {'prompt': 'hey {a}!'}, '123'])

        # No placeholder matches: the result equals the original content.
        new_pl = pl.format(d=1)
        self.assertEqual(new_pl, ['hi {a}{b}', {'prompt': 'hey {a}!'}, '123'])

    def test_replace(self):
        """``replace`` substitutes text, or splices in a ``PromptList``."""
        pl = PromptList(['hello world', {'prompt': 'hello world'}, '123'])

        # Plain string replacement reaches both str items and 'prompt' values.
        new_pl = pl.replace('world', 'there')
        self.assertEqual(
            new_pl,
            ['hello there', {'prompt': 'hello there'}, '123'],
        )

        # Replacing a whole string item with a PromptList splices its items.
        new_pl = pl.replace('123', PromptList(['p', {'role': 'BOT'}]))
        self.assertEqual(
            new_pl,
            ['hello world', {'prompt': 'hello world'}, 'p', {'role': 'BOT'}],
        )

        # Replacing part of a string item splits it around the spliced items.
        new_pl = pl.replace('2', PromptList(['p', {'role': 'BOT'}]))
        self.assertEqual(
            new_pl,
            [
                'hello world',
                {'prompt': 'hello world'},
                '1',
                'p',
                {'role': 'BOT'},
                '3',
            ],
        )

        # NOTE(review): 'world' also occurs inside the dict item's 'prompt'
        # value; presumably that is what makes a PromptList replacement
        # invalid here -- confirm against PromptList.replace.
        # The replacement is built outside the block so that only the
        # ``replace`` call itself can satisfy the assertion.
        replacement = PromptList(['new', 'world'])
        with self.assertRaises(TypeError):
            pl.replace('world', replacement)

    def test_add(self):
        """``+``, reflected ``+`` and ``+=`` all append prompt items."""
        pl = PromptList(['hello'])

        new_pl = pl + ' world'
        self.assertEqual(new_pl, ['hello', ' world'])

        pl2 = PromptList([' world'])
        new_pl = pl + pl2
        self.assertEqual(new_pl, ['hello', ' world'])

        # A str on the left-hand side exercises the reflected operator.
        new_pl = 'hi, ' + pl
        self.assertEqual(new_pl, ['hi, ', 'hello'])

        pl += '!'
        self.assertEqual(pl, ['hello', '!'])

    def test_str(self):
        """``str()`` joins str items and 'prompt' values; other types fail."""
        pl = PromptList(['hello', ' world', {'prompt': '!'}])
        self.assertEqual(str(pl), 'hello world!')

        # Construct outside the context manager: the TypeError under test
        # must come from ``str()``, not from building the list.
        pl = PromptList(['hello', ' world', 123])
        with self.assertRaises(TypeError):
            str(pl)