david-oplatka commited on
Commit
fa3a890
1 Parent(s): e2b12f2

Add test file

Browse files
Files changed (2) hide show
  1. agent.py +1 -1
  2. test_agent.py +33 -0
agent.py CHANGED
@@ -109,7 +109,7 @@ def create_assistant_tools(cfg):
109
  [ask_transcripts]
110
  )
111
 
112
- def initialize_agent(_cfg, update_func):
113
  financial_bot_instructions = """
114
  - You are a helpful financial assistant, with expertise in financial reporting, in conversation with a user.
115
  - Respond in a compact format by using appropriate units of measure (e.g., K for thousands, M for millions, B for billions).
 
109
  [ask_transcripts]
110
  )
111
 
112
+ def initialize_agent(_cfg, update_func=None):
113
  financial_bot_instructions = """
114
  - You are a helpful financial assistant, with expertise in financial reporting, in conversation with a user.
115
  - Respond in a compact format by using appropriate units of measure (e.g., K for thousands, M for millions, B for billions).
test_agent.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import os
3
+
4
+ from omegaconf import OmegaConf
5
+ from vectara_agent.agent import Agent
6
+
7
+ from agent import initialize_agent
8
+
9
+ from dotenv import load_dotenv
10
+ load_dotenv(override=True)
11
+
12
+ class TestAgentResponses(unittest.TestCase):
13
+
14
+ def test_responses(self):
15
+
16
+ cfg = OmegaConf.create({
17
+ 'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
18
+ 'corpus_id': str(os.environ['VECTARA_CORPUS_ID']),
19
+ 'api_key': str(os.environ['VECTARA_API_KEY']),
20
+ 'examples': os.environ.get('QUERY_EXAMPLES', None)
21
+ })
22
+
23
+ agent = initialize_agent(_cfg=cfg)
24
+ self.assertIsInstance(agent, Agent)
25
+
26
+ # Basic number questions
27
+ self.assertIn('274.52B', agent.chat('Was was the revenue for Apple in 2020? Just provide the number.'))
28
+ self.assertIn('amazon', agent.chat('Which company had the highest revenue in 2023? Just provide the name.').lower())
29
+ self.assertIn('amazon', agent.chat('Which company had the lowest profits in 2022?').lower())
30
+
31
+
32
+ if __name__ == "__main__":
33
+ unittest.main()