david-oplatka commited on
Commit
152dfa1
1 Parent(s): fa3a890

Add test cases and fix formatting

Browse files
Files changed (2) hide show
  1. agent.py +15 -15
  2. test_agent.py +12 -2
agent.py CHANGED
@@ -1,16 +1,16 @@
1
-
2
  import os
3
  import pandas as pd
4
  import requests
 
5
 
6
  from omegaconf import OmegaConf
7
 
 
 
 
8
  from dotenv import load_dotenv
9
  load_dotenv(override=True)
10
 
11
- from pydantic import Field, BaseModel
12
- from vectara_agent.agent import Agent, AgentStatusType
13
- from vectara_agent.tools import ToolsFactory, VectaraToolFactory
14
 
15
  tickers = {
16
  "AAPL": "Apple Computer",
@@ -28,7 +28,7 @@ tickers = {
28
  years = [2020, 2021, 2022, 2023, 2024]
29
  initial_prompt = "How can I help you today?"
30
 
31
- def create_assistant_tools(cfg):
32
 
33
  def get_company_info() -> list[str]:
34
  """
@@ -44,7 +44,7 @@ def create_assistant_tools(cfg):
44
  Always check this before using any other tool.
45
  """
46
  return years
47
-
48
  # Tool to get the income statement for a given company and year using the FMP API
49
  def get_income_statement(
50
  ticker=Field(description="the ticker symbol of the company."),
@@ -68,16 +68,16 @@ def create_assistant_tools(cfg):
68
  ]
69
  values_dict = income_statement_specific_year.to_dict(orient="records")[0]
70
  return f"Financial results: {', '.join([f'{key}: {value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
71
- else:
72
- return "FMP API returned error. This tool does not work."
73
 
74
  class QueryTranscriptsArgs(BaseModel):
75
  query: str = Field(..., description="The user query.")
76
  year: int = Field(..., description=f"The year. An integer between {min(years)} and {max(years)}.")
77
  ticker: str = Field(..., description=f"The company ticker. Must be a valid ticket symbol from the list {tickers.keys()}.")
78
 
79
- vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
80
- vectara_customer_id=cfg.customer_id,
81
  vectara_corpus_id=cfg.corpus_id)
82
  tools_factory = ToolsFactory()
83
 
@@ -85,10 +85,10 @@ def create_assistant_tools(cfg):
85
  tool_name = "ask_transcripts",
86
  tool_description = """
87
  Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
88
- You can ask this tool any question about the compaany including risks, opportunities, financial performance, competitors and more.
89
  """,
90
  tool_args_schema = QueryTranscriptsArgs,
91
- reranker = "multilingual_reranker_v1", rerank_k = 100,
92
  n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
93
  summary_num_results = 10,
94
  vectara_summarizer = 'vectara-summary-ext-24-05-med-omni',
@@ -98,13 +98,13 @@ def create_assistant_tools(cfg):
98
  return (
99
  [tools_factory.create_tool(tool) for tool in
100
  [
101
- get_company_info,
102
  get_valid_years,
103
  get_income_statement,
104
  ]
105
  ] +
106
- tools_factory.standard_tools() +
107
- tools_factory.financial_tools() +
108
  tools_factory.guardrail_tools() +
109
  [ask_transcripts]
110
  )
 
 
1
  import os
2
  import pandas as pd
3
  import requests
4
+ from pydantic import Field, BaseModel
5
 
6
  from omegaconf import OmegaConf
7
 
8
+ from vectara_agent.agent import Agent
9
+ from vectara_agent.tools import ToolsFactory, VectaraToolFactory
10
+
11
  from dotenv import load_dotenv
12
  load_dotenv(override=True)
13
 
 
 
 
14
 
15
  tickers = {
16
  "AAPL": "Apple Computer",
 
28
  years = [2020, 2021, 2022, 2023, 2024]
29
  initial_prompt = "How can I help you today?"
30
 
31
+ def create_assistant_tools(cfg):
32
 
33
  def get_company_info() -> list[str]:
34
  """
 
44
  Always check this before using any other tool.
45
  """
46
  return years
47
+
48
  # Tool to get the income statement for a given company and year using the FMP API
49
  def get_income_statement(
50
  ticker=Field(description="the ticker symbol of the company."),
 
68
  ]
69
  values_dict = income_statement_specific_year.to_dict(orient="records")[0]
70
  return f"Financial results: {', '.join([f'{key}: {value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
71
+
72
+ return "FMP API returned error. This tool does not work."
73
 
74
  class QueryTranscriptsArgs(BaseModel):
75
  query: str = Field(..., description="The user query.")
76
  year: int = Field(..., description=f"The year. An integer between {min(years)} and {max(years)}.")
77
  ticker: str = Field(..., description=f"The company ticker. Must be a valid ticket symbol from the list {tickers.keys()}.")
78
 
79
+ vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
80
+ vectara_customer_id=cfg.customer_id,
81
  vectara_corpus_id=cfg.corpus_id)
82
  tools_factory = ToolsFactory()
83
 
 
85
  tool_name = "ask_transcripts",
86
  tool_description = """
87
  Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
88
+ You can ask this tool any question about the company including risks, opportunities, financial performance, competitors and more.
89
  """,
90
  tool_args_schema = QueryTranscriptsArgs,
91
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
92
  n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
93
  summary_num_results = 10,
94
  vectara_summarizer = 'vectara-summary-ext-24-05-med-omni',
 
98
  return (
99
  [tools_factory.create_tool(tool) for tool in
100
  [
101
+ get_company_info,
102
  get_valid_years,
103
  get_income_statement,
104
  ]
105
  ] +
106
+ tools_factory.standard_tools() +
107
+ tools_factory.financial_tools() +
108
  tools_factory.guardrail_tools() +
109
  [ask_transcripts]
110
  )
test_agent.py CHANGED
@@ -22,11 +22,21 @@ class TestAgentResponses(unittest.TestCase):
22
 
23
  agent = initialize_agent(_cfg=cfg)
24
  self.assertIsInstance(agent, Agent)
25
-
26
- # Basic number questions
 
 
 
 
27
  self.assertIn('274.52B', agent.chat('Was was the revenue for Apple in 2020? Just provide the number.'))
28
  self.assertIn('amazon', agent.chat('Which company had the highest revenue in 2023? Just provide the name.').lower())
29
  self.assertIn('amazon', agent.chat('Which company had the lowest profits in 2022?').lower())
 
 
 
 
 
 
30
 
31
 
32
  if __name__ == "__main__":
 
22
 
23
  agent = initialize_agent(_cfg=cfg)
24
  self.assertIsInstance(agent, Agent)
25
+
26
+ # Test RAG tool
27
+ self.assertIn('elon musk', agent.chat('Who is the CEO of Tesla? Just give their name.').lower())
28
+ self.assertIn('david zinsner', agent.chat('Who is the CFO of Intel? Just give their name.').lower())
29
+
30
+ # Test Yahoo Finance tool
31
  self.assertIn('274.52B', agent.chat('Was was the revenue for Apple in 2020? Just provide the number.'))
32
  self.assertIn('amazon', agent.chat('Which company had the highest revenue in 2023? Just provide the name.').lower())
33
  self.assertIn('amazon', agent.chat('Which company had the lowest profits in 2022?').lower())
34
+ self.assertIn('10.46B', agent.chat('What was AMD\'s gross profit in 2023?. Just provide the number.'))
35
+
36
+ # Test other custom tools
37
+ self.assertIn('2020', agent.chat('What is the earliest year you can provide information for? Just give the year.').lower())
38
+ self.assertIn('2024', agent.chat('What is the most recent year you can provide information for? Just give the year.').lower())
39
+ self.assertIn('AAPL', agent.chat('What is the stock ticker symbol for Apple? Just give the abbreiated ticker.'))
40
 
41
 
42
  if __name__ == "__main__":