Spaces:
Running
Running
david-oplatka
commited on
Commit
•
152dfa1
1
Parent(s):
fa3a890
Add test cases and fix formatting
Browse files- agent.py +15 -15
- test_agent.py +12 -2
agent.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
-
|
2 |
import os
|
3 |
import pandas as pd
|
4 |
import requests
|
|
|
5 |
|
6 |
from omegaconf import OmegaConf
|
7 |
|
|
|
|
|
|
|
8 |
from dotenv import load_dotenv
|
9 |
load_dotenv(override=True)
|
10 |
|
11 |
-
from pydantic import Field, BaseModel
|
12 |
-
from vectara_agent.agent import Agent, AgentStatusType
|
13 |
-
from vectara_agent.tools import ToolsFactory, VectaraToolFactory
|
14 |
|
15 |
tickers = {
|
16 |
"AAPL": "Apple Computer",
|
@@ -28,7 +28,7 @@ tickers = {
|
|
28 |
years = [2020, 2021, 2022, 2023, 2024]
|
29 |
initial_prompt = "How can I help you today?"
|
30 |
|
31 |
-
def create_assistant_tools(cfg):
|
32 |
|
33 |
def get_company_info() -> list[str]:
|
34 |
"""
|
@@ -44,7 +44,7 @@ def create_assistant_tools(cfg):
|
|
44 |
Always check this before using any other tool.
|
45 |
"""
|
46 |
return years
|
47 |
-
|
48 |
# Tool to get the income statement for a given company and year using the FMP API
|
49 |
def get_income_statement(
|
50 |
ticker=Field(description="the ticker symbol of the company."),
|
@@ -68,16 +68,16 @@ def create_assistant_tools(cfg):
|
|
68 |
]
|
69 |
values_dict = income_statement_specific_year.to_dict(orient="records")[0]
|
70 |
return f"Financial results: {', '.join([f'{key}: {value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
|
71 |
-
|
72 |
-
|
73 |
|
74 |
class QueryTranscriptsArgs(BaseModel):
|
75 |
query: str = Field(..., description="The user query.")
|
76 |
year: int = Field(..., description=f"The year. An integer between {min(years)} and {max(years)}.")
|
77 |
ticker: str = Field(..., description=f"The company ticker. Must be a valid ticket symbol from the list {tickers.keys()}.")
|
78 |
|
79 |
-
vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
|
80 |
-
vectara_customer_id=cfg.customer_id,
|
81 |
vectara_corpus_id=cfg.corpus_id)
|
82 |
tools_factory = ToolsFactory()
|
83 |
|
@@ -85,10 +85,10 @@ def create_assistant_tools(cfg):
|
|
85 |
tool_name = "ask_transcripts",
|
86 |
tool_description = """
|
87 |
Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
|
88 |
-
You can ask this tool any question about the
|
89 |
""",
|
90 |
tool_args_schema = QueryTranscriptsArgs,
|
91 |
-
reranker = "multilingual_reranker_v1", rerank_k = 100,
|
92 |
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
93 |
summary_num_results = 10,
|
94 |
vectara_summarizer = 'vectara-summary-ext-24-05-med-omni',
|
@@ -98,13 +98,13 @@ def create_assistant_tools(cfg):
|
|
98 |
return (
|
99 |
[tools_factory.create_tool(tool) for tool in
|
100 |
[
|
101 |
-
get_company_info,
|
102 |
get_valid_years,
|
103 |
get_income_statement,
|
104 |
]
|
105 |
] +
|
106 |
-
tools_factory.standard_tools() +
|
107 |
-
tools_factory.financial_tools() +
|
108 |
tools_factory.guardrail_tools() +
|
109 |
[ask_transcripts]
|
110 |
)
|
|
|
|
|
1 |
import os
|
2 |
import pandas as pd
|
3 |
import requests
|
4 |
+
from pydantic import Field, BaseModel
|
5 |
|
6 |
from omegaconf import OmegaConf
|
7 |
|
8 |
+
from vectara_agent.agent import Agent
|
9 |
+
from vectara_agent.tools import ToolsFactory, VectaraToolFactory
|
10 |
+
|
11 |
from dotenv import load_dotenv
|
12 |
load_dotenv(override=True)
|
13 |
|
|
|
|
|
|
|
14 |
|
15 |
tickers = {
|
16 |
"AAPL": "Apple Computer",
|
|
|
28 |
years = [2020, 2021, 2022, 2023, 2024]
|
29 |
initial_prompt = "How can I help you today?"
|
30 |
|
31 |
+
def create_assistant_tools(cfg):
|
32 |
|
33 |
def get_company_info() -> list[str]:
|
34 |
"""
|
|
|
44 |
Always check this before using any other tool.
|
45 |
"""
|
46 |
return years
|
47 |
+
|
48 |
# Tool to get the income statement for a given company and year using the FMP API
|
49 |
def get_income_statement(
|
50 |
ticker=Field(description="the ticker symbol of the company."),
|
|
|
68 |
]
|
69 |
values_dict = income_statement_specific_year.to_dict(orient="records")[0]
|
70 |
return f"Financial results: {', '.join([f'{key}: {value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
|
71 |
+
|
72 |
+
return "FMP API returned error. This tool does not work."
|
73 |
|
74 |
class QueryTranscriptsArgs(BaseModel):
|
75 |
query: str = Field(..., description="The user query.")
|
76 |
year: int = Field(..., description=f"The year. An integer between {min(years)} and {max(years)}.")
|
77 |
ticker: str = Field(..., description=f"The company ticker. Must be a valid ticket symbol from the list {tickers.keys()}.")
|
78 |
|
79 |
+
vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
|
80 |
+
vectara_customer_id=cfg.customer_id,
|
81 |
vectara_corpus_id=cfg.corpus_id)
|
82 |
tools_factory = ToolsFactory()
|
83 |
|
|
|
85 |
tool_name = "ask_transcripts",
|
86 |
tool_description = """
|
87 |
Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
|
88 |
+
You can ask this tool any question about the company including risks, opportunities, financial performance, competitors and more.
|
89 |
""",
|
90 |
tool_args_schema = QueryTranscriptsArgs,
|
91 |
+
reranker = "multilingual_reranker_v1", rerank_k = 100,
|
92 |
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
93 |
summary_num_results = 10,
|
94 |
vectara_summarizer = 'vectara-summary-ext-24-05-med-omni',
|
|
|
98 |
return (
|
99 |
[tools_factory.create_tool(tool) for tool in
|
100 |
[
|
101 |
+
get_company_info,
|
102 |
get_valid_years,
|
103 |
get_income_statement,
|
104 |
]
|
105 |
] +
|
106 |
+
tools_factory.standard_tools() +
|
107 |
+
tools_factory.financial_tools() +
|
108 |
tools_factory.guardrail_tools() +
|
109 |
[ask_transcripts]
|
110 |
)
|
test_agent.py
CHANGED
@@ -22,11 +22,21 @@ class TestAgentResponses(unittest.TestCase):
|
|
22 |
|
23 |
agent = initialize_agent(_cfg=cfg)
|
24 |
self.assertIsInstance(agent, Agent)
|
25 |
-
|
26 |
-
#
|
|
|
|
|
|
|
|
|
27 |
self.assertIn('274.52B', agent.chat('Was was the revenue for Apple in 2020? Just provide the number.'))
|
28 |
self.assertIn('amazon', agent.chat('Which company had the highest revenue in 2023? Just provide the name.').lower())
|
29 |
self.assertIn('amazon', agent.chat('Which company had the lowest profits in 2022?').lower())
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
if __name__ == "__main__":
|
|
|
22 |
|
23 |
agent = initialize_agent(_cfg=cfg)
|
24 |
self.assertIsInstance(agent, Agent)
|
25 |
+
|
26 |
+
# Test RAG tool
|
27 |
+
self.assertIn('elon musk', agent.chat('Who is the CEO of Tesla? Just give their name.').lower())
|
28 |
+
self.assertIn('david zinsner', agent.chat('Who is the CFO of Intel? Just give their name.').lower())
|
29 |
+
|
30 |
+
# Test Yahoo Finance tool
|
31 |
self.assertIn('274.52B', agent.chat('Was was the revenue for Apple in 2020? Just provide the number.'))
|
32 |
self.assertIn('amazon', agent.chat('Which company had the highest revenue in 2023? Just provide the name.').lower())
|
33 |
self.assertIn('amazon', agent.chat('Which company had the lowest profits in 2022?').lower())
|
34 |
+
self.assertIn('10.46B', agent.chat('What was AMD\'s gross profit in 2023?. Just provide the number.'))
|
35 |
+
|
36 |
+
# Test other custom tools
|
37 |
+
self.assertIn('2020', agent.chat('What is the earliest year you can provide information for? Just give the year.').lower())
|
38 |
+
self.assertIn('2024', agent.chat('What is the most recent year you can provide information for? Just give the year.').lower())
|
39 |
+
self.assertIn('AAPL', agent.chat('What is the stock ticker symbol for Apple? Just give the abbreiated ticker.'))
|
40 |
|
41 |
|
42 |
if __name__ == "__main__":
|