File size: 1,869 Bytes
fa3a890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152dfa1
 
 
 
 
 
fa3a890
 
 
152dfa1
 
 
 
 
 
fa3a890
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import unittest
import os

from omegaconf import OmegaConf
from vectara_agent.agent import Agent

from agent import initialize_agent

from dotenv import load_dotenv
# Pull VECTARA_* (and optional QUERY_EXAMPLES) settings from a .env file.
# override=True lets values from the .env file replace variables that are
# already set in the process environment.
load_dotenv(override=True)

class TestAgentResponses(unittest.TestCase):
    """End-to-end checks of the agent built by ``initialize_agent``.

    Each case sends a prompt to a live agent and asserts that the reply
    contains an expected substring. The prompts exercise the RAG tool, the
    Yahoo Finance tool and other custom tools against real services, so the
    expected values depend on that live data.
    """

    # Credentials needed to build the agent config; when any is absent the
    # test is skipped with a clear message instead of erroring on KeyError.
    _REQUIRED_ENV = ('VECTARA_CUSTOMER_ID', 'VECTARA_CORPUS_ID', 'VECTARA_API_KEY')

    # (prompt, expected substring, compare case-insensitively?)
    # Order is preserved from the original test: all prompts go to the same
    # agent instance, which may carry conversation state between calls.
    _CASES = (
        # RAG tool
        ('Who is the CEO of Tesla? Just give their name.', 'elon musk', True),
        ('Who is the CFO of Intel? Just give their name.', 'david zinsner', True),
        # Yahoo Finance tool
        ('What was the revenue for Apple in 2020? Just provide the number.', '274.52B', False),
        ('Which company had the highest revenue in 2023? Just provide the name.', 'amazon', True),
        ('Which company had the lowest profits in 2022?', 'amazon', True),
        ("What was AMD's gross profit in 2023? Just provide the number.", '10.46B', False),
        # Other custom tools
        ('What is the earliest year you can provide information for? Just give the year.', '2020', True),
        ('What is the most recent year you can provide information for? Just give the year.', '2024', True),
        ('What is the stock ticker symbol for Apple? Just give the abbreviated ticker.', 'AAPL', False),
    )

    def test_responses(self):
        """Build the agent from environment config and check every prompt's reply."""
        missing = [name for name in self._REQUIRED_ENV if name not in os.environ]
        if missing:
            self.skipTest('missing environment variables: ' + ', '.join(missing))

        cfg = OmegaConf.create({
            'customer_id': os.environ['VECTARA_CUSTOMER_ID'],
            'corpus_id': os.environ['VECTARA_CORPUS_ID'],
            'api_key': os.environ['VECTARA_API_KEY'],
            # Optional; None when QUERY_EXAMPLES is not set.
            'examples': os.environ.get('QUERY_EXAMPLES'),
        })

        agent = initialize_agent(_cfg=cfg)
        self.assertIsInstance(agent, Agent)

        for prompt, expected, ignore_case in self._CASES:
            # subTest reports each failing prompt separately and keeps going,
            # rather than letting the first failure hide the remaining checks.
            with self.subTest(prompt=prompt):
                response = agent.chat(prompt)
                if ignore_case:
                    response = response.lower()
                self.assertIn(expected, response)


if __name__ == '__main__':
    # Executed directly rather than through a test runner: load and run the
    # test cases defined in this module.
    unittest.main()