Upload IsoPro Package
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +201 -0
- README.md +276 -3
- isopro/.DS_Store +0 -0
- isopro/__init__.py +84 -0
- isopro/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__init__.py +18 -0
- isopro/adversarial_simulation/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_agent.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_environment.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_envrionment.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_simulator.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/attack_utils.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/adversarial_agent.py +51 -0
- isopro/adversarial_simulation/adversarial_environment.py +81 -0
- isopro/adversarial_simulation/adversarial_simulator.py +47 -0
- isopro/adversarial_simulation/attack_utils.py +65 -0
- isopro/adversarial_simulation/main.py +124 -0
- isopro/agents/__init__.py +7 -0
- isopro/agents/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/agents/__pycache__/ai_agent.cpython-38.pyc +0 -0
- isopro/agents/ai_agent.py +44 -0
- isopro/base/__init__.py +8 -0
- isopro/base/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/base/__pycache__/base_component.cpython-38.pyc +0 -0
- isopro/base/__pycache__/base_wrapper.cpython-38.pyc +0 -0
- isopro/base/base_component.py +34 -0
- isopro/base/base_wrapper.py +82 -0
- isopro/car_simulator/__init__.py +12 -0
- isopro/car_simulator/car_llm_agent.py +143 -0
- isopro/car_simulator/car_rl_environment.py +155 -0
- isopro/car_simulator/car_rl_model.zip +3 -0
- isopro/car_simulator/car_rl_training.py +38 -0
- isopro/car_simulator/carviz.py +227 -0
- isopro/car_simulator/llm_main.py +74 -0
- isopro/car_simulator/main.py +48 -0
- isopro/conversation_simulation/README.md +252 -0
- isopro/conversation_simulation/__init__.py +19 -0
- isopro/conversation_simulation/conversation_agent.py +41 -0
- isopro/conversation_simulation/conversation_environment.py +78 -0
- isopro/conversation_simulation/conversation_simulator.py +67 -0
- isopro/conversation_simulation/custom_persona.py +58 -0
- isopro/conversation_simulation/main.py +117 -0
- isopro/conversation_simulation/user_personas.py +112 -0
- isopro/environments/__init__.py +9 -0
- isopro/environments/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/custom_environment.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/llm_orchestrator.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/simulation_environment.cpython-38.pyc +0 -0
- isopro/environments/custom_environment.py +108 -0
- isopro/environments/llm_orchestrator.py +194 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,276 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ISOPro: Pro Tools for Intelligent Simulation Orchestration for Large Language Models
|
2 |
+
|
3 |
+
ISOPRO is a powerful and flexible Python package designed for creating, managing, and analyzing simulations involving Large Language Models (LLMs). It provides a comprehensive suite of tools for reinforcement learning, conversation simulations, adversarial testing, custom environment creation, and advanced orchestration of multi-agent systems.
|
4 |
+
|
5 |
+
## Features
|
6 |
+
|
7 |
+
- **Custom Environment Creation**: Easily create and manage custom simulation environments for LLMs
|
8 |
+
- **Conversation Simulation**: Simulate and analyze conversations with AI agents using various user personas
|
9 |
+
- **Adversarial Testing**: Conduct adversarial simulations to test the robustness of LLM-based systems
|
10 |
+
- **Reinforcement Learning**: Implement and experiment with RL algorithms in LLM contexts
|
11 |
+
- **Workflow Automation**: Learn and replicate UI workflows from video demonstrations
|
12 |
+
- **Car Environment Simulation**: Train and evaluate RL agents in driving scenarios
|
13 |
+
- **Utility Functions**: Analyze simulation results, calculate LLM metrics, and more
|
14 |
+
- **Flexible Integration**: Works with popular LLM platforms like OpenAI's GPT models, Claude (Anthropic), and Hugging Face models
|
15 |
+
- **Orchestration Simulation**: Manage and execute complex multi-agent simulations with different execution modes
|
16 |
+
|
17 |
+
## Installation
|
18 |
+
|
19 |
+
You can install isopro using pip:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
pip install isopro
|
23 |
+
```
|
24 |
+
|
25 |
+
For workflow simulation features, ensure you have the required dependencies:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
pip install opencv-python numpy torch stable-baselines3 gymnasium tqdm
|
29 |
+
```
|
30 |
+
|
31 |
+
If you plan to use Claude capabilities:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
export ANTHROPIC_API_KEY=your_api_key_here
|
35 |
+
```
|
36 |
+
|
37 |
+
## Usage
|
38 |
+
|
39 |
+
### Adversarial Simulation
|
40 |
+
|
41 |
+
Test the robustness of AI models against adversarial attacks.
|
42 |
+
|
43 |
+
```python
|
44 |
+
from isopro.adversarial_simulation import AdversarialSimulator, AdversarialEnvironment
|
45 |
+
from isopro.agents.ai_agent import AI_Agent
|
46 |
+
import os
import anthropic
|
47 |
+
|
48 |
+
class ClaudeAgent(AI_Agent):
|
49 |
+
def __init__(self, name):
|
50 |
+
super().__init__(name)
|
51 |
+
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
52 |
+
|
53 |
+
def run(self, input_data):
|
54 |
+
response = self.client.messages.create(
|
55 |
+
model="claude-3-opus-20240229",
|
56 |
+
max_tokens=100,
|
57 |
+
messages=[{"role": "user", "content": input_data['text']}]
|
58 |
+
)
|
59 |
+
return response.content[0].text
|
60 |
+
|
61 |
+
# Create the AdversarialEnvironment
|
62 |
+
adv_env = AdversarialEnvironment(
|
63 |
+
agent_wrapper=ClaudeAgent("Claude Agent"),
|
64 |
+
num_adversarial_agents=2,
|
65 |
+
attack_types=["textbugger", "deepwordbug"],
|
66 |
+
attack_targets=["input", "output"]
|
67 |
+
)
|
68 |
+
|
69 |
+
# Set up the adversarial simulator
|
70 |
+
simulator = AdversarialSimulator(adv_env)
|
71 |
+
|
72 |
+
# Run the simulation
|
73 |
+
input_data = ["What is the capital of France?", "How does photosynthesis work?"]
|
74 |
+
simulation_results = simulator.run_simulation(input_data, num_steps=1)
|
75 |
+
```
|
76 |
+
|
77 |
+
### Conversation Simulation
|
78 |
+
|
79 |
+
Simulate conversations between an AI assistant and various user personas.
|
80 |
+
|
81 |
+
```python
|
82 |
+
from isopro.conversation_simulation.conversation_simulator import ConversationSimulator
|
83 |
+
|
84 |
+
# Initialize the ConversationSimulator
|
85 |
+
simulator = ConversationSimulator(
|
86 |
+
ai_prompt="You are an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
|
87 |
+
)
|
88 |
+
|
89 |
+
# Run a simulation with a predefined persona
|
90 |
+
conversation_history = simulator.run_simulation("upset", num_turns=3)
|
91 |
+
|
92 |
+
# Run a simulation with a custom persona
|
93 |
+
custom_persona = {
|
94 |
+
"name": "Techie Customer",
|
95 |
+
"characteristics": ["tech-savvy", "impatient", "detail-oriented"],
|
96 |
+
"message_templates": [
|
97 |
+
"I've tried rebooting my device, but the error persists. Can you help?",
|
98 |
+
"What's the latest update on the cloud service outage?",
|
99 |
+
"I need specifics on the API rate limits for the enterprise plan."
|
100 |
+
]
|
101 |
+
}
|
102 |
+
|
103 |
+
custom_conversation = simulator.run_custom_simulation(**custom_persona, num_turns=3)
|
104 |
+
```
|
105 |
+
|
106 |
+
### Workflow Simulation
|
107 |
+
|
108 |
+
Automate UI workflows by learning from video demonstrations.
|
109 |
+
|
110 |
+
```python
|
111 |
+
from isopro.workflow_simulation import WorkflowAutomation, WorkflowSimulator, AgentConfig
|
112 |
+
|
113 |
+
# Basic workflow automation
|
114 |
+
automation = WorkflowAutomation(
|
115 |
+
video="path/to/workflow.mp4",
|
116 |
+
config="config.json",
|
117 |
+
output="output_dir",
|
118 |
+
logs="logs_dir"
|
119 |
+
)
|
120 |
+
automation.run()
|
121 |
+
|
122 |
+
# Advanced configuration
|
123 |
+
agent_config = AgentConfig(
|
124 |
+
learning_rate=3e-4,
|
125 |
+
pretrain_epochs=10,
|
126 |
+
use_demonstration=True,
|
127 |
+
use_reasoning=True
|
128 |
+
)
|
129 |
+
|
130 |
+
simulator = WorkflowSimulator(
|
131 |
+
video_path="path/to/video.mp4",
|
132 |
+
agent_config=agent_config,
|
133 |
+
viz_config=visualization_config,
|
134 |
+
validation_config=validation_config,
|
135 |
+
output_dir="output"
|
136 |
+
)
|
137 |
+
|
138 |
+
training_results = simulator.train_agents()
|
139 |
+
evaluation_results = simulator.evaluate_agents()
|
140 |
+
```
|
141 |
+
|
142 |
+
### Car Reinforcement Learning
|
143 |
+
|
144 |
+
Train and evaluate RL agents in driving scenarios.
|
145 |
+
|
146 |
+
```python
|
147 |
+
from isopro.car_simulator import CarRLEnvironment, LLMCarRLWrapper, CarVisualization
|
148 |
+
|
149 |
+
# Create the car environment with LLM integration
|
150 |
+
env = CarRLEnvironment()
|
151 |
+
llm_env = LLMCarRLWrapper(env)
|
152 |
+
|
153 |
+
# Initialize visualization
|
154 |
+
viz = CarVisualization(env)
|
155 |
+
|
156 |
+
# Train and visualize
|
157 |
+
observation = llm_env.reset()
|
158 |
+
for step in range(1000):
|
159 |
+
action = llm_env.get_action(observation)
|
160 |
+
observation, reward, done, info = llm_env.step(action)
|
161 |
+
viz.render(observation)
|
162 |
+
|
163 |
+
if done:
|
164 |
+
observation = llm_env.reset()
|
165 |
+
```
|
166 |
+
|
167 |
+
### Reinforcement Learning with LLM
|
168 |
+
|
169 |
+
Integrate Large Language Models with reinforcement learning environments.
|
170 |
+
|
171 |
+
```python
|
172 |
+
import os
import gymnasium as gym
|
173 |
+
from isopro.rl.rl_agent import RLAgent
|
174 |
+
from isopro.rl.rl_environment import LLMRLEnvironment
|
175 |
+
from stable_baselines3 import PPO
|
176 |
+
from isopro.rl.llm_cartpole_wrapper import LLMCartPoleWrapper
|
177 |
+
|
178 |
+
agent_prompt = """You are an AI trained to play the CartPole game.
|
179 |
+
Your goal is to balance a pole on a moving cart for as long as possible.
|
180 |
+
You will receive observations about the cart's position, velocity, pole angle, and angular velocity.
|
181 |
+
Based on these, you should decide whether to move the cart left or right."""
|
182 |
+
|
183 |
+
env = LLMCartPoleWrapper(agent_prompt, llm_call_limit=100, api_key=os.getenv("ANTHROPIC_API_KEY"))
|
184 |
+
model = RLAgent("LLM_CartPole_Agent", env, algorithm='PPO')
|
185 |
+
|
186 |
+
# Train the model
|
187 |
+
model.learn(total_timesteps=2)
|
188 |
+
|
189 |
+
# Test the model
|
190 |
+
obs, _ = env.reset()
|
191 |
+
for _ in range(1000):
|
192 |
+
action, _ = model.predict(obs, deterministic=True)
|
193 |
+
obs, reward, done, _, _ = env.step(action)
|
194 |
+
if done:
|
195 |
+
obs, _ = env.reset()
|
196 |
+
```
|
197 |
+
|
198 |
+
### AI Orchestration
|
199 |
+
|
200 |
+
Orchestrate multiple AI agents to work together on complex tasks.
|
201 |
+
|
202 |
+
```python
|
203 |
+
from isopro.orchestration_simulation import OrchestrationEnv
|
204 |
+
from isopro.orchestration_simulation.components import LLaMAAgent, AnalysisAgent, WritingAgent
|
205 |
+
from isopro.orchestration_simulation.evaluator import Evaluator
|
206 |
+
|
207 |
+
# Create the orchestration environment
|
208 |
+
env = OrchestrationEnv()
|
209 |
+
|
210 |
+
# Add agents to the environment
|
211 |
+
env.add_component(LLaMAAgent("Research", "conduct thorough research on the impact of artificial intelligence on job markets"))
|
212 |
+
env.add_component(AnalysisAgent("Analysis"))
|
213 |
+
env.add_component(WritingAgent("Writing"))
|
214 |
+
|
215 |
+
# Define the task
|
216 |
+
task = "Prepare a comprehensive report on the impact of artificial intelligence on job markets in the next decade."
|
217 |
+
|
218 |
+
# Run simulations in different modes
|
219 |
+
modes = ['parallel', 'sequence', 'node']
|
220 |
+
results = {}
|
221 |
+
|
222 |
+
for mode in modes:
|
223 |
+
result = env.run_simulation(mode=mode, input_data={'task': task, 'run_order': 'first'})
|
224 |
+
results[mode] = result
|
225 |
+
|
226 |
+
# Evaluate the results
|
227 |
+
evaluator = Evaluator()
|
228 |
+
best_mode = evaluator.evaluate(results)
|
229 |
+
print(f"The best execution mode for this task was: {best_mode}")
|
230 |
+
```
|
231 |
+
|
232 |
+
## Documentation
|
233 |
+
|
234 |
+
For more detailed information on each module and its usage, please refer to the [full documentation](https://isopro.readthedocs.io).
|
235 |
+
|
236 |
+
## Examples
|
237 |
+
|
238 |
+
The [isopro examples](https://github.com/iso-ai/isopro_examples) repository contains Jupyter notebooks with detailed examples:
|
239 |
+
|
240 |
+
- `adversarial_example.ipynb`: Demonstrates adversarial testing of language models
|
241 |
+
- `conversation_simulation_example.ipynb`: Shows how to simulate conversations with various user personas
|
242 |
+
- `workflow_automation_example.ipynb`: Illustrates automated UI workflow learning
|
243 |
+
- `car_rl_example.ipynb`: Demonstrates car environment training scenarios
|
244 |
+
- `run_cartpole_example.ipynb`: Illustrates the integration of LLMs with reinforcement learning
|
245 |
+
- `orchestrator_example.ipynb`: Provides a tutorial on using the AI orchestration capabilities
|
246 |
+
|
247 |
+
## Contributing
|
248 |
+
|
249 |
+
We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for more details.
|
250 |
+
|
251 |
+
## License
|
252 |
+
|
253 |
+
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
|
254 |
+
|
255 |
+
## Support
|
256 |
+
|
257 |
+
If you encounter any problems or have any questions, please [open an issue](https://github.com/iso-ai/isopro/issues) on our GitHub repository.
|
258 |
+
|
259 |
+
## Citation
|
260 |
+
|
261 |
+
If you use ISOPRO in your research, please cite it as follows:
|
262 |
+
|
263 |
+
```
|
264 |
+
@software{isopro2024,
|
265 |
+
author = {Jazmia Henry},
|
266 |
+
title = {ISOPRO: Intelligent Simulation Orchestration for Large Language Models},
|
267 |
+
year = {2024},
|
268 |
+
publisher = {GitHub},
|
269 |
+
journal = {GitHub repository},
|
270 |
+
howpublished = {\url{https://github.com/iso-ai/isopro}}
|
271 |
+
}
|
272 |
+
```
|
273 |
+
|
274 |
+
## Contact
|
275 |
+
|
276 |
+
For questions or support, please open an issue on our [GitHub issue tracker](https://github.com/iso-ai/isopro/issues).
|
isopro/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
isopro/__init__.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# isopro/__init__.py
|
2 |
+
|
3 |
+
"""
|
4 |
+
isopro: Intelligent Simulation Orchestration for LLMs
|
5 |
+
|
6 |
+
This package provides tools for creating, managing, and analyzing simulations
|
7 |
+
involving Large Language Models (LLMs), including reinforcement learning,
|
8 |
+
conversation simulations, and adversarial testing.
|
9 |
+
"""
|
10 |
+
|
11 |
+
__version__ = "0.1.5"
|
12 |
+
|
13 |
+
# Core components
|
14 |
+
from .environments.simulation_environment import SimulationEnvironment
|
15 |
+
from .environments.custom_environment import CustomEnvironment
|
16 |
+
from .environments.llm_orchestrator import LLMOrchestrator
|
17 |
+
from .agents.ai_agent import AI_Agent
|
18 |
+
from .base.base_component import BaseComponent
|
19 |
+
from .wrappers.simulation_wrapper import SimulationWrapper
|
20 |
+
from .rl.rl_environment import BaseRLEnvironment
|
21 |
+
from .rl.rl_agent import RLAgent
|
22 |
+
from .conversation_simulation import ConversationSimulator, ConversationEnvironment, ConversationAgent
|
23 |
+
from .adversarial_simulation import AdversarialSimulator, AdversarialEnvironment, AdversarialAgent
|
24 |
+
from .orchestration_simulation import LLaMAAgent, SubAgent, OrchestrationEnv, AI_AgentException, ComponentException, AI_Agent
|
25 |
+
|
26 |
+
# Workflow simulation components
|
27 |
+
from .workflow_simulation import (
|
28 |
+
WorkflowSimulator,
|
29 |
+
WorkflowEnvironment,
|
30 |
+
WorkflowState,
|
31 |
+
UIElement,
|
32 |
+
UIElementDetector,
|
33 |
+
MotionDetector,
|
34 |
+
EpisodeMetrics,
|
35 |
+
AgentConfig,
|
36 |
+
VisualizationConfig,
|
37 |
+
ValidationConfig,
|
38 |
+
WorkflowAutomation
|
39 |
+
)
|
40 |
+
|
41 |
+
# Car RL components
|
42 |
+
from .car_simulator import CarRLEnvironment, LLMCarRLWrapper, CarVisualization
|
43 |
+
|
44 |
+
__all__ = [
|
45 |
+
# Core components
|
46 |
+
"LLaMAAgent",
|
47 |
+
"SubAgent",
|
48 |
+
"OrchestrationEnv",
|
49 |
+
"AI_AgentException",
|
50 |
+
"ComponentException",
|
51 |
+
"AI_Agent",
|
52 |
+
"SimulationEnvironment",
|
53 |
+
"CustomEnvironment",
|
54 |
+
"LLMOrchestrator",
|
55 |
+
"AI_Agent",
|
56 |
+
"BaseComponent",
|
57 |
+
"SimulationWrapper",
|
58 |
+
"BaseRLEnvironment",
|
59 |
+
"RLAgent",
|
60 |
+
"ConversationSimulator",
|
61 |
+
"ConversationEnvironment",
|
62 |
+
"ConversationAgent",
|
63 |
+
"AdversarialSimulator",
|
64 |
+
"AdversarialEnvironment",
|
65 |
+
"AdversarialAgent",
|
66 |
+
|
67 |
+
# Workflow components
|
68 |
+
"WorkflowSimulator",
|
69 |
+
"WorkflowEnvironment",
|
70 |
+
"WorkflowState",
|
71 |
+
"UIElement",
|
72 |
+
"UIElementDetector",
|
73 |
+
"MotionDetector",
|
74 |
+
"EpisodeMetrics",
|
75 |
+
"AgentConfig",
|
76 |
+
"VisualizationConfig",
|
77 |
+
"ValidationConfig",
|
78 |
+
"WorkflowAutomation",
|
79 |
+
|
80 |
+
# Car RL components
|
81 |
+
"CarRLEnvironment",
|
82 |
+
"LLMCarRLWrapper",
|
83 |
+
"CarVisualization"
|
84 |
+
]
|
isopro/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.56 kB). View file
|
|
isopro/adversarial_simulation/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adversarial Simulation Module
|
3 |
+
|
4 |
+
This module provides tools for simulating adversarial attacks on AI models.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from .adversarial_environment import AdversarialEnvironment
|
8 |
+
from .adversarial_agent import AdversarialAgent
|
9 |
+
from .adversarial_simulator import AdversarialSimulator
|
10 |
+
from .attack_utils import get_available_attacks, create_attack
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
"AdversarialEnvironment",
|
14 |
+
"AdversarialAgent",
|
15 |
+
"AdversarialSimulator",
|
16 |
+
"get_available_attacks",
|
17 |
+
"create_attack",
|
18 |
+
]
|
isopro/adversarial_simulation/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (621 Bytes). View file
|
|
isopro/adversarial_simulation/__pycache__/adversarial_agent.cpython-38.pyc
ADDED
Binary file (1.87 kB). View file
|
|
isopro/adversarial_simulation/__pycache__/adversarial_environment.cpython-38.pyc
ADDED
Binary file (4.88 kB). View file
|
|
isopro/adversarial_simulation/__pycache__/adversarial_envrionment.cpython-38.pyc
ADDED
Binary file (4.88 kB). View file
|
|
isopro/adversarial_simulation/__pycache__/adversarial_simulator.cpython-38.pyc
ADDED
Binary file (2.48 kB). View file
|
|
isopro/adversarial_simulation/__pycache__/attack_utils.cpython-38.pyc
ADDED
Binary file (2.85 kB). View file
|
|
isopro/adversarial_simulation/adversarial_agent.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adversarial Agent
|
3 |
+
|
4 |
+
This module defines the AdversarialAgent class, which can apply various attacks to input or output text.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from typing import Dict, Any
|
8 |
+
from isopro.agents.ai_agent import AI_Agent
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
class AdversarialAgent(AI_Agent):
|
14 |
+
def __init__(self, name: str, attack, target: str = "input"):
|
15 |
+
"""
|
16 |
+
Initialize the AdversarialAgent.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
name (str): The name of the agent.
|
20 |
+
attack (callable): The attack function to apply.
|
21 |
+
target (str): The target of the attack, either "input" or "output".
|
22 |
+
"""
|
23 |
+
super().__init__(name)
|
24 |
+
self.attack = attack
|
25 |
+
self.target = target
|
26 |
+
logger.info(f"Initialized AdversarialAgent '{name}' targeting {target}")
|
27 |
+
|
28 |
+
def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
29 |
+
"""
|
30 |
+
Apply the adversarial attack to the input or output data.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
input_data (Dict[str, Any]): The input data containing 'text' and 'output' keys.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Dict[str, Any]: The perturbed data.
|
37 |
+
"""
|
38 |
+
logger.info(f"Running adversarial agent: {self.name}")
|
39 |
+
if self.target == "input":
|
40 |
+
if input_data.get('text'):
|
41 |
+
input_data['text'] = self.attack(input_data['text'])
|
42 |
+
else:
|
43 |
+
logger.warning("Input text is empty or missing. Skipping attack.")
|
44 |
+
elif self.target == "output":
|
45 |
+
if input_data.get('output'):
|
46 |
+
input_data['output'] = self.attack(input_data['output'])
|
47 |
+
else:
|
48 |
+
logger.warning("Output text is empty or missing. Skipping attack.")
|
49 |
+
else:
|
50 |
+
raise ValueError(f"Invalid target: {self.target}")
|
51 |
+
return input_data
|
isopro/adversarial_simulation/adversarial_environment.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adversarial Environment
|
3 |
+
|
4 |
+
This module defines the AdversarialEnvironment class, which manages adversarial agents and applies attacks to the simulation state.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import random
|
8 |
+
from typing import List, Dict, Any
|
9 |
+
from isopro.environments.simulation_environment import SimulationEnvironment
|
10 |
+
from .adversarial_agent import AdversarialAgent
|
11 |
+
from .attack_utils import get_model_and_tokenizer, create_attack, get_available_attacks
|
12 |
+
import logging
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
class AdversarialEnvironment(SimulationEnvironment):
|
17 |
+
def __init__(self, agent_wrapper, num_adversarial_agents: int = 1, attack_types: List[str] = None, attack_targets: List[str] = None):
|
18 |
+
"""
|
19 |
+
Initialize the AdversarialEnvironment.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
agent_wrapper: The wrapped agent to pass the adversarially modified state to.
|
23 |
+
num_adversarial_agents (int): The number of adversarial agents to create.
|
24 |
+
attack_types (List[str], optional): The types of attacks to use. If None, all available attacks will be used.
|
25 |
+
attack_targets (List[str], optional): The targets for the attacks ("input", "output", or both). If None, both will be used.
|
26 |
+
"""
|
27 |
+
super().__init__()
|
28 |
+
self.agent_wrapper = agent_wrapper
|
29 |
+
self.num_adversarial_agents = num_adversarial_agents
|
30 |
+
self.attack_types = attack_types or get_available_attacks()
|
31 |
+
self.attack_targets = attack_targets or ["input", "output"]
|
32 |
+
self.model, self.tokenizer = get_model_and_tokenizer()
|
33 |
+
self._create_adversarial_agents()
|
34 |
+
logger.info(f"Initialized AdversarialEnvironment with {num_adversarial_agents} agents")
|
35 |
+
|
36 |
+
def _create_adversarial_agents(self):
|
37 |
+
"""Create adversarial agents with random attack types and targets."""
|
38 |
+
for i in range(self.num_adversarial_agents):
|
39 |
+
attack_type = random.choice(self.attack_types)
|
40 |
+
attack_target = random.choice(self.attack_targets)
|
41 |
+
attack = create_attack(attack_type, self.model, self.tokenizer)
|
42 |
+
agent = AdversarialAgent(name=f"Adversarial Agent {i+1} ({attack_type}, {attack_target})", attack=attack, target=attack_target)
|
43 |
+
self.add_agent(agent)
|
44 |
+
logger.info(f"Created {self.num_adversarial_agents} adversarial agents")
|
45 |
+
|
46 |
+
def step(self, sim_state: Dict[str, Any]) -> Dict[str, Any]:
|
47 |
+
"""
|
48 |
+
Apply adversarial attacks and step the environment.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
sim_state (Dict[str, Any]): The current simulation state.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Dict[str, Any]: The updated simulation state after applying attacks and stepping the wrapped agent.
|
55 |
+
"""
|
56 |
+
# Apply adversarial attacks
|
57 |
+
for agent in self.agents:
|
58 |
+
sim_state = agent.run(sim_state)
|
59 |
+
|
60 |
+
# Pass the adversarially modified state to the wrapped agent
|
61 |
+
return self.agent_wrapper.step(sim_state)
|
62 |
+
|
63 |
+
def reset(self):
|
64 |
+
"""Reset the environment and recreate adversarial agents."""
|
65 |
+
super().reset()
|
66 |
+
self._create_adversarial_agents()
|
67 |
+
logger.info("Reset AdversarialEnvironment and recreated agents")
|
68 |
+
|
69 |
+
def get_attack_distribution(self) -> Dict[str, int]:
|
70 |
+
"""
|
71 |
+
Get the distribution of attack types and targets among the adversarial agents.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
Dict[str, int]: A dictionary containing the count of each attack type and target.
|
75 |
+
"""
|
76 |
+
attack_counts = {f"{attack_type}_{target}": 0 for attack_type in self.attack_types for target in self.attack_targets}
|
77 |
+
for agent in self.agents:
|
78 |
+
attack_type, target = agent.name.split('(')[-1].split(')')[0].split(', ')
|
79 |
+
attack_counts[f"{attack_type}_{target}"] += 1
|
80 |
+
logger.info(f"Current attack distribution: {attack_counts}")
|
81 |
+
return attack_counts
|
isopro/adversarial_simulation/adversarial_simulator.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adversarial Simulator
|
3 |
+
|
4 |
+
This module provides a high-level interface for running adversarial simulations.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from typing import List, Dict, Any
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
class AdversarialSimulator:
|
13 |
+
def __init__(self, environment):
|
14 |
+
"""
|
15 |
+
Initialize the AdversarialSimulator.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
environment: The AdversarialEnvironment to use in the simulation.
|
19 |
+
"""
|
20 |
+
self.environment = environment
|
21 |
+
logger.info("Initialized AdversarialSimulator")
|
22 |
+
|
23 |
+
def run_simulation(self, input_data: List[str], num_steps: int = 1) -> List[Dict[str, Any]]:
|
24 |
+
"""
|
25 |
+
Run an adversarial simulation.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
input_data (List[str]): The list of input texts to use in the simulation.
|
29 |
+
num_steps (int): The number of steps to run the simulation for each input.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
List[Dict[str, Any]]: A list of simulation results, including original and perturbed inputs and outputs.
|
33 |
+
"""
|
34 |
+
results = []
|
35 |
+
for text in input_data:
|
36 |
+
sim_state = {"text": text, "output": ""}
|
37 |
+
original_output = self.environment.agent_wrapper.run({"text": text})
|
38 |
+
for _ in range(num_steps):
|
39 |
+
sim_state = self.environment.step(sim_state)
|
40 |
+
results.append({
|
41 |
+
"original_input": text,
|
42 |
+
"perturbed_input": sim_state["text"],
|
43 |
+
"original_output": original_output,
|
44 |
+
"perturbed_output": sim_state["output"]
|
45 |
+
})
|
46 |
+
logger.info(f"Completed simulation with {len(input_data)} inputs and {num_steps} steps each")
|
47 |
+
return results
|
isopro/adversarial_simulation/attack_utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Attack Utilities
|
3 |
+
|
4 |
+
This module provides utility functions for creating and managing adversarial attacks.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from typing import Tuple, Callable
|
9 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
10 |
+
from isoadverse.attacks.text_fgsm import text_fgsm_attack
|
11 |
+
from isoadverse.attacks.text_pgd import text_pgd_attack
|
12 |
+
from isoadverse.attacks.textbugger import textbugger_attack
|
13 |
+
from isoadverse.attacks.deepwordbug import deepwordbug_attack
|
14 |
+
import logging
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
def get_model_and_tokenizer(model_name: str = 'bert-base-uncased') -> Tuple[torch.nn.Module, torch.nn.Module]:
|
19 |
+
"""
|
20 |
+
Load a pre-trained model and tokenizer.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
model_name (str): The name of the model to load.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Tuple[torch.nn.Module, torch.nn.Module]: The loaded model and tokenizer.
|
27 |
+
"""
|
28 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
model.to(device)
|
32 |
+
logger.info(f"Loaded model {model_name} on {device}")
|
33 |
+
return model, tokenizer
|
34 |
+
|
35 |
+
def create_attack(attack_type: str, model: torch.nn.Module, tokenizer: torch.nn.Module) -> Callable:
|
36 |
+
"""
|
37 |
+
Create an attack function based on the specified attack type.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
attack_type (str): The type of attack to create.
|
41 |
+
model (torch.nn.Module): The model to use for the attack.
|
42 |
+
tokenizer (torch.nn.Module): The tokenizer to use for the attack.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Callable: The attack function.
|
46 |
+
"""
|
47 |
+
if attack_type == "fgsm":
|
48 |
+
return lambda x: text_fgsm_attack(model, tokenizer, x, torch.tensor([1]), epsilon=0.3)
|
49 |
+
elif attack_type == "pgd":
|
50 |
+
return lambda x: text_pgd_attack(model, tokenizer, x, torch.tensor([1]), epsilon=0.3, alpha=0.1, num_steps=10)
|
51 |
+
elif attack_type == "textbugger":
|
52 |
+
return lambda x: textbugger_attack(x, num_bugs=5)
|
53 |
+
elif attack_type == "deepwordbug":
|
54 |
+
return lambda x: deepwordbug_attack(x, num_bugs=5)
|
55 |
+
else:
|
56 |
+
raise ValueError(f"Unknown attack type: {attack_type}")
|
57 |
+
|
58 |
+
def get_available_attacks() -> list:
|
59 |
+
"""
|
60 |
+
Get a list of available attack types.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
list: A list of available attack types.
|
64 |
+
"""
|
65 |
+
return ["fgsm", "pgd", "textbugger", "deepwordbug"]
|
isopro/adversarial_simulation/main.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import List
|
3 |
+
from .adversarial_simulator import AdversarialSimulator
|
4 |
+
from .adversarial_environment import AdversarialEnvironment
|
5 |
+
from isopro.utils.analyze_adversarial_sim import analyze_adversarial_results, summarize_adversarial_impact
|
6 |
+
from isopro.agents.ai_agent import AI_Agent
|
7 |
+
import anthropic
|
8 |
+
import os
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
import json
|
11 |
+
from datetime import datetime
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
load_dotenv()
|
16 |
+
|
17 |
+
# Set up logging
|
18 |
+
logging.basicConfig(level=logging.INFO)
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
class ClaudeAgent(AI_Agent):
    """AI_Agent backed by Anthropic's Claude chat API."""

    def __init__(self, name):
        super().__init__(name)
        # API key is read from the environment (loaded via dotenv above).
        self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def run(self, input_data):
        """Send input_data['text'] to Claude and return the reply text."""
        reply = self.client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=100,
            messages=[{"role": "user", "content": input_data['text']}],
        )
        return reply.content[0].text

    def step(self, sim_state):
        """Fill sim_state['output'] with the model response and return it."""
        sim_state['output'] = self.run(sim_state)
        return sim_state
|
37 |
+
|
38 |
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder for NumPy scalars/arrays and torch tensors.

    The analysis results may contain NumPy numeric types, which the stock
    JSONEncoder rejects. The original version only handled np.floating and
    torch.Tensor; np.integer scalars and ndarrays are now covered too.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        return super().default(obj)
|
45 |
+
|
46 |
+
def setup_logging(log_dir: str, run_id: str) -> None:
    """Route root logging to a run-specific file and to stderr.

    Args:
        log_dir (str): Directory for the log file (created if missing).
        run_id (str): Run identifier embedded in the log filename.

    Note:
        ``force=True`` is required here: this module already calls
        ``logging.basicConfig`` at import time, and without ``force`` a
        second ``basicConfig`` call is a silent no-op, so the file handler
        was never actually installed.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"adv-{run_id}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ],
        force=True,  # replace any pre-existing root handlers (Python 3.8+)
    )
|
58 |
+
|
59 |
+
def save_scores(output_dir: str, run_id: str, analysis_results: dict) -> None:
    """Persist analysis results as JSON at output_dir/adv-<run_id>.json.

    Errors are logged rather than raised so a failed save does not abort
    the run.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"adv-{run_id}.json")

    try:
        with open(output_file, 'w') as handle:
            # NumpyEncoder handles the NumPy/torch values in the results.
            json.dump(analysis_results, handle, indent=2, cls=NumpyEncoder)

        logging.info(f"Saved analysis results to {output_file}")
    except Exception as e:
        logging.error(f"Error saving analysis results: {str(e)}")
|
70 |
+
|
71 |
+
def get_sample_inputs() -> List[str]:
    """Return a small fixed batch of demo questions for the simulation."""
    questions = (
        "What is the capital of France?",
        "How does photosynthesis work?",
        "Explain the theory of relativity.",
    )
    return list(questions)
|
77 |
+
|
78 |
+
def main():
    """Run a full adversarial simulation: attack, analyze, summarize, save.

    Uses the module-level ``logger`` throughout. The original rebound
    ``logger = logging.getLogger(__name__)`` inside the ``try`` block,
    which made ``logger`` a function-local name — an exception raised
    before that assignment would have produced an UnboundLocalError in
    the ``except`` handler instead of the real error.
    """
    try:
        run_id = datetime.now().strftime("%Y%m%d-%H%M%S")

        log_dir = "logs"
        setup_logging(log_dir, run_id)

        logger.info(f"Starting adversarial simulation run {run_id}")

        claude_agent = ClaudeAgent("Claude Agent")

        # Create the AdversarialEnvironment wrapping the Claude agent.
        adv_env = AdversarialEnvironment(
            agent_wrapper=claude_agent,
            num_adversarial_agents=2,
            attack_types=["textbugger", "deepwordbug"],
            attack_targets=["input", "output"]
        )

        # Set up the adversarial simulator with the environment.
        simulator = AdversarialSimulator(adv_env)

        input_data = get_sample_inputs()

        logger.info("Starting adversarial simulation...")
        simulation_results = simulator.run_simulation(input_data, num_steps=1)

        logger.info("Analyzing simulation results...")
        analysis_results = analyze_adversarial_results(simulation_results)

        summary = summarize_adversarial_impact(analysis_results)

        print("\nAdversarial Simulation Summary:")
        print(summary)

        output_dir = "output"
        save_scores(output_dir, run_id, analysis_results)

        logger.info("Simulation complete.")

    except Exception as e:
        logger.error(f"An error occurred during the simulation: {str(e)}", exc_info=True)
        raise

if __name__ == "__main__":
    main()
|
isopro/agents/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Agent classes for the isopro package.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .ai_agent import AI_Agent
|
6 |
+
|
7 |
+
__all__ = ["AI_Agent"]
|
isopro/agents/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (263 Bytes). View file
|
|
isopro/agents/__pycache__/ai_agent.cpython-38.pyc
ADDED
Binary file (1.62 kB). View file
|
|
isopro/agents/ai_agent.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""AI Agent for Simulation Environment."""
|
2 |
+
from ..base.base_component import BaseComponent, agent_component
|
3 |
+
|
4 |
+
@agent_component
|
5 |
+
class AI_Agent(BaseComponent):
|
6 |
+
"""AI Agent for Simulation Environment."""
|
7 |
+
|
8 |
+
def __init__(self, name):
|
9 |
+
"""
|
10 |
+
Initialize the AI_Agent.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
name (str): The name of the agent.
|
14 |
+
"""
|
15 |
+
super().__init__(name)
|
16 |
+
self.components = []
|
17 |
+
|
18 |
+
def add_component(self, component):
|
19 |
+
"""
|
20 |
+
Add a component to the agent.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
component (BaseComponent): The component to add.
|
24 |
+
"""
|
25 |
+
if getattr(component, '_is_agent_component', False):
|
26 |
+
self.components.append(component)
|
27 |
+
else:
|
28 |
+
raise ValueError(f"Component {component} is not decorated with @agent_component")
|
29 |
+
|
30 |
+
def run(self, input_data):
|
31 |
+
"""
|
32 |
+
Run the agent's components and process input data.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
input_data (dict): The input data for the agent.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
dict: The processed output data.
|
39 |
+
"""
|
40 |
+
self.logger.info(f"Running agent: {self.name}")
|
41 |
+
output = input_data
|
42 |
+
for component in self.components:
|
43 |
+
output = component.run(output)
|
44 |
+
return output
|
isopro/base/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Base classes for the isopro package.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .base_wrapper import BaseWrapper
|
6 |
+
from .base_component import BaseComponent
|
7 |
+
|
8 |
+
__all__ = ["BaseWrapper", "BaseComponent"]
|
isopro/base/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (326 Bytes). View file
|
|
isopro/base/__pycache__/base_component.cpython-38.pyc
ADDED
Binary file (1.44 kB). View file
|
|
isopro/base/__pycache__/base_wrapper.cpython-38.pyc
ADDED
Binary file (2.86 kB). View file
|
|
isopro/base/base_component.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base Component for Simulation Environment."""
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from ..utils.logging_utils import setup_logger
|
4 |
+
|
5 |
+
class BaseComponent(ABC):
|
6 |
+
"""Base Component for Simulation Environment."""
|
7 |
+
|
8 |
+
def __init__(self, name):
|
9 |
+
"""
|
10 |
+
Initialize the BaseComponent.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
name (str): The name of the component.
|
14 |
+
"""
|
15 |
+
self.name = name
|
16 |
+
self.logger = setup_logger(f"{self.__class__.__name__}_{self.name}")
|
17 |
+
|
18 |
+
@abstractmethod
|
19 |
+
def run(self):
|
20 |
+
"""Execute the component's main functionality."""
|
21 |
+
pass
|
22 |
+
|
23 |
+
def __str__(self):
|
24 |
+
return f"{self.__class__.__name__}({self.name})"
|
25 |
+
|
26 |
+
def agent_component(cls):
|
27 |
+
"""
|
28 |
+
Decorator to mark a class as an agent component.
|
29 |
+
|
30 |
+
This decorator can be used to add metadata or perform
|
31 |
+
additional setup for agent components.
|
32 |
+
"""
|
33 |
+
cls._is_agent_component = True
|
34 |
+
return cls
|
isopro/base/base_wrapper.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base Wrapper for Simulation Environment."""
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
import logging
|
4 |
+
from ..utils.logging_utils import setup_logger
|
5 |
+
|
6 |
+
class BaseWrapper(ABC):
|
7 |
+
"""Base Wrapper for Simulation Environment."""
|
8 |
+
|
9 |
+
def __init__(self, agent):
|
10 |
+
"""
|
11 |
+
Initialize the BaseWrapper.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
agent: The agent to be wrapped.
|
15 |
+
"""
|
16 |
+
self.agent = agent
|
17 |
+
self.logger = setup_logger(self.__class__.__name__)
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def step(self):
|
21 |
+
"""Execute one time step within the environment."""
|
22 |
+
pass
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def reset(self):
|
26 |
+
"""Reset the state of the environment to an initial state."""
|
27 |
+
pass
|
28 |
+
|
29 |
+
@abstractmethod
|
30 |
+
def render(self):
|
31 |
+
"""Render the environment."""
|
32 |
+
pass
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def close(self):
|
36 |
+
"""Close the environment, clean up any resources."""
|
37 |
+
pass
|
38 |
+
|
39 |
+
@abstractmethod
|
40 |
+
def convert_to_agent_input(self, sim_state):
|
41 |
+
"""
|
42 |
+
Convert simulation state to agent input format.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
sim_state (dict): The current state of the simulation.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
dict: The converted input for the agent.
|
49 |
+
"""
|
50 |
+
pass
|
51 |
+
|
52 |
+
@abstractmethod
|
53 |
+
def convert_from_agent_output(self, agent_output):
|
54 |
+
"""
|
55 |
+
Convert agent output to simulation input format.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
agent_output (dict): The output from the agent.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
dict: The converted input for the simulation.
|
62 |
+
"""
|
63 |
+
pass
|
64 |
+
|
65 |
+
def __getattr__(self, name):
|
66 |
+
"""
|
67 |
+
Attempt to get an attribute from the agent if it's not found in the wrapper.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
name (str): The name of the attribute.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
The requested attribute.
|
74 |
+
|
75 |
+
Raises:
|
76 |
+
AttributeError: If the attribute is not found in the agent or wrapper.
|
77 |
+
"""
|
78 |
+
try:
|
79 |
+
return getattr(self.agent, name)
|
80 |
+
except AttributeError:
|
81 |
+
self.logger.warning(f"Attribute '{name}' not found in agent or wrapper")
|
82 |
+
raise
|
isopro/car_simulator/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Car Reinforcement Learning Package
|
3 |
+
|
4 |
+
This package contains modules for simulating and visualizing
|
5 |
+
reinforcement learning agents in a car driving environment.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from .car_rl_environment import CarRLEnvironment
|
9 |
+
from .car_llm_agent import LLMCarRLWrapper
|
10 |
+
from .carviz import CarVisualization
|
11 |
+
|
12 |
+
__all__ = ['CarRLEnvironment', 'LLMCarRLWrapper', 'CarVisualization']
|
isopro/car_simulator/car_llm_agent.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gymnasium as gym
|
2 |
+
from stable_baselines3 import PPO
|
3 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
4 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
5 |
+
import numpy as np
|
6 |
+
import anthropic
|
7 |
+
import logging
|
8 |
+
from typing import List, Dict, Any
|
9 |
+
from .car_rl_environment import CarRLEnvironment
|
10 |
+
import os
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
|
13 |
+
# Load environment variables from .env file
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
# Set up logging
|
17 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
class LLMCarRLWrapper(CarRLEnvironment):
    """Car RL environment that periodically asks Claude for driving guidance
    and nudges the RL agent's actions accordingly.

    Fixes over the original:
      * ``step`` previously computed the guidance-adjusted action *after*
        stepping and discarded it, so the LLM guidance never influenced the
        environment; the adjustment is now applied before stepping.
      * ``_adjust_action_based_on_guidance`` indexed ``guidance["action"]``,
        but ``_parse_llm_guidance`` can return environment-style dicts with
        no ``"action"`` key; ``.get`` avoids the KeyError.
    """

    def __init__(self, num_cars=1, time_of_day="12:00", is_rainy=False, is_weekday=True,
                 agent_prompt="You are an expert driving instructor. Provide concise guidance to improve the RL agent's driving performance.",
                 llm_call_limit=100, llm_call_frequency=100):
        super().__init__(num_cars, time_of_day, is_rainy, is_weekday)
        self.agent_prompt = agent_prompt
        api_key = os.getenv('ANTHROPIC_API_KEY')
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        self.client = anthropic.Anthropic(api_key=api_key)
        self.llm_call_count = 0                        # total LLM calls made
        self.llm_call_limit = llm_call_limit           # hard budget of calls
        self.llm_call_frequency = llm_call_frequency   # steps between calls
        # NOTE(review): history grows without bound across episodes — may
        # eventually exceed the model's context window; confirm intent.
        self.conversation_history: List[Dict[str, str]] = []
        self.step_count = 0
        self.current_guidance = {"action": "unknown"}

    def reset(self, seed=None, options=None):
        """Reset per-episode counters and guidance, then the base env."""
        self.step_count = 0
        self.current_guidance = {"action": "unknown"}
        return super().reset(seed=seed)

    def step(self, action):
        """Apply the latest LLM guidance to the action, then step the env.

        Returns the standard Gymnasium 5-tuple from the base environment.
        """
        self.step_count += 1

        # Nudge the action using the most recent guidance BEFORE executing it.
        adjusted_action = self._adjust_action_based_on_guidance(action, self.current_guidance)
        observation, reward, terminated, truncated, info = super().step(adjusted_action)

        # Periodically refresh the guidance, within the API-call budget.
        if self.step_count % self.llm_call_frequency == 0 and self.llm_call_count < self.llm_call_limit:
            self.current_guidance = self._get_llm_guidance(observation, reward, terminated)
            self.llm_call_count += 1

        return observation, reward, terminated, truncated, info

    def _get_llm_guidance(self, observation, reward, terminated):
        """Ask Claude for advice on the current state; return parsed guidance.

        Falls back to {"action": "unknown"} on any API error.
        """
        user_message = f"Current state: {observation}, Reward: {reward}, Terminated: {terminated}. Provide brief driving advice."

        messages = self.conversation_history + [
            {"role": "user", "content": user_message},
        ]

        try:
            response = self.client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=50,
                system=self.agent_prompt,
                messages=messages
            )

            ai_response = response.content[0].text
            self.conversation_history.append({"role": "user", "content": user_message})
            self.conversation_history.append({"role": "assistant", "content": ai_response})
            logger.debug(f"LLM guidance: {ai_response}")
            return self._parse_llm_guidance(ai_response)
        except Exception as e:
            logger.error(f"Error getting LLM guidance: {e}")
            return {"action": "unknown"}

    def _parse_llm_guidance(self, guidance):
        """Map free-text advice to a structured guidance dict (first match wins)."""
        guidance_lower = guidance.lower()
        actions = {
            "increase speed": {"action": "increase_speed"},
            "decrease speed": {"action": "decrease_speed"},
            "slow down": {"action": "decrease_speed"},
            "turn left": {"action": "turn_left"},
            "turn right": {"action": "turn_right"},
            "stop": {"action": "stop"},
            "start raining": {"environment": "rain", "status": True},
            "increase traffic": {"environment": "traffic", "density": "high"}
        }

        for key, value in actions.items():
            if key in guidance_lower:
                return value

        return {"action": "unknown"}

    def _adjust_action_based_on_guidance(self, action, guidance):
        """Nudge the raw action in place according to the guidance dict.

        Environment-style guidance (no "action" key) and unrecognized
        actions leave the action unchanged.
        """
        adjustments = {
            "increase_speed": (0, 0.1),
            "decrease_speed": (0, -0.1),
            "turn_left": (1, -0.1),
            "turn_right": (1, 0.1),
        }

        # .get: guidance may describe an environment change with no "action".
        directive = guidance.get("action")
        if directive in adjustments:
            index, adjustment = adjustments[directive]
            action[index] = np.clip(action[index] + adjustment, -1.0, 1.0)

        return action
|
112 |
+
|
113 |
+
def make_env(llm_call_limit):
    """Return a zero-argument factory for DummyVecEnv that builds an
    LLMCarRLWrapper with the given LLM call budget."""
    def _init():
        return LLMCarRLWrapper(
            num_cars=3,
            time_of_day="08:00",
            is_rainy=False,
            is_weekday=True,
            llm_call_limit=llm_call_limit,
        )
    return _init
|
118 |
+
|
119 |
+
def train_and_evaluate(env, total_timesteps=100000, eval_episodes=10):
    """Train a PPO policy on *env* and report its mean evaluation reward.

    Args:
        env: The (vectorized) training environment.
        total_timesteps (int): Training budget in environment steps.
        eval_episodes (int): Episodes used for the post-training evaluation.

    Returns:
        tuple: (trained PPO model, mean reward over eval_episodes).
    """
    model = PPO(
        "MlpPolicy", env, verbose=1, learning_rate=0.0003, n_steps=2048,
        batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2,
    )
    model.learn(total_timesteps=total_timesteps, progress_bar=True)

    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=eval_episodes)
    logger.info(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    return model, mean_reward
|
129 |
+
|
130 |
+
def main():
    """Entry point: build the vectorized env, train, evaluate, and save."""
    llm_call_limit = int(os.getenv('LLM_CALL_LIMIT', '10'))  # default to 10 if unset

    env = DummyVecEnv([make_env(llm_call_limit)])
    model, mean_reward = train_and_evaluate(env)
    model.save("car_rl_llm_ppo_model")

    logger.info("Training and evaluation completed.")
    logger.info(f"Final mean reward: {mean_reward:.2f}")

if __name__ == "__main__":
    main()
|
isopro/car_simulator/car_rl_environment.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gymnasium as gym
|
2 |
+
from gymnasium import spaces
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from typing import List, Dict, Tuple, Union
|
7 |
+
|
8 |
+
class CarRLEnvironment(gym.Env):
    """Simple multi-car 2D driving environment with friction physics.

    Observation layout: ``[x, y, vx, vy, angle]`` for each car, followed by
    ``[time_of_day, is_rainy, is_weekday]``.
    Action layout: ``[acceleration, steering]`` for each car, each in [-1, 1].
    """

    def __init__(self, num_cars=1, time_of_day="12:00", is_rainy=False, is_weekday=True):
        """Set up the spaces, weather-dependent friction, and initial cars.

        Args:
            num_cars: Number of cars simulated simultaneously.
            time_of_day: Clock time as "HH:MM" or a float hour in [0, 24).
            is_rainy: Rain halves the friction coefficient.
            is_weekday: Exposed in the observation; not used by the physics.
        """
        super().__init__()
        self.num_cars = num_cars
        self.time_of_day = self.convert_time(time_of_day)
        self.is_rainy = is_rainy
        self.is_weekday = is_weekday
        # Rain reduces grip: friction coefficient drops from 0.8 to 0.4.
        self.friction = 0.4 if is_rainy else 0.8

        # Define action and observation spaces
        self.action_space = spaces.Box(low=-1, high=1, shape=(num_cars * 2,), dtype=np.float32)

        # Observation space: [x, y, vx, vy, angle] for each car + [time_of_day, is_rainy, is_weekday]
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(num_cars * 5 + 3,),
            dtype=np.float32
        )

        self.cars = self.initialize_cars()

    def convert_time(self, time_of_day: Union[str, float]) -> float:
        """Convert time to a float between 0 and 24.

        Accepts "HH:MM" strings or numeric hours; falls back to 12.0 (noon)
        on malformed input rather than raising.
        """
        if isinstance(time_of_day, str):
            try:
                hours, minutes = map(int, time_of_day.split(':'))
                return float(hours + minutes / 60.0)
            except ValueError:
                print(f"Invalid time format: {time_of_day}. Using default value of 12:00.")
                return 12.0
        elif isinstance(time_of_day, (int, float)):
            # Wrap numeric input onto the 24-hour clock.
            return float(time_of_day) % 24.0
        else:
            print(f"Invalid time format: {time_of_day}. Using default value of 12:00.")
            return 12.0

    def initialize_cars(self) -> List[Dict[str, torch.Tensor]]:
        """Initialize car parameters with random position, velocity, and heading.

        NOTE(review): draws from the global ``random`` RNG, so ``reset(seed=...)``
        does not make episodes reproducible — consider using ``self.np_random``.
        """
        return [
            {
                "position": torch.tensor([random.uniform(-1, 1), random.uniform(-1, 1)], dtype=torch.float32),
                "velocity": torch.tensor([random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)], dtype=torch.float32),
                "angle": torch.tensor([random.uniform(-np.pi, np.pi)], dtype=torch.float32)
            } for _ in range(self.num_cars)
        ]

    def reset(self, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
        """Re-randomize all cars and return (observation, info).

        Fix: accepts the standard gymnasium ``options`` keyword, which was
        previously missing and broke callers following the gymnasium reset
        API. ``options`` is currently ignored.
        """
        super().reset(seed=seed)
        self.cars = self.initialize_cars()
        return self.get_observation(), {}

    def get_observation(self) -> np.ndarray:
        """Get the current observation of the environment.

        Returns a float32 vector: per-car [x, y, vx, vy, angle] blocks
        followed by [time_of_day, is_rainy, is_weekday].
        """
        car_obs = np.concatenate([
            np.concatenate([
                car["position"].numpy(),
                car["velocity"].numpy(),
                car["angle"].numpy()
            ]) for car in self.cars
        ])
        env_obs = np.array([
            self.time_of_day,
            float(self.is_rainy),
            float(self.is_weekday)
        ], dtype=np.float32)
        return np.concatenate([car_obs, env_obs]).astype(np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Take a step in the environment.

        Args:
            action (np.ndarray): Array of actions for all cars [acceleration1, steering1, acceleration2, steering2, ...]

        Returns:
            observation, reward, terminated, truncated, info

        Raises:
            ValueError: If the flattened action does not have 2 entries per car.
        """
        # Ensure action is the correct shape
        action = np.array(action).flatten()
        if action.shape[0] != self.num_cars * 2:
            raise ValueError(f"Action shape {action.shape} does not match expected shape ({self.num_cars * 2},)")

        # Apply each car's (acceleration, steering) pair, then advance physics.
        for i in range(self.num_cars):
            car_action = action[i*2:(i+1)*2]
            self.apply_action(self.cars[i], car_action)
            self.update_physics(self.cars[i])

        observation = self.get_observation()
        reward = self.calculate_reward()
        terminated = self.is_terminated()
        truncated = False  # no time limit is enforced by this environment
        info = {}

        return observation, reward, terminated, truncated, info

    def apply_action(self, car: Dict[str, torch.Tensor], action: np.ndarray):
        """Apply the RL agent's action to the car.

        Raises:
            ValueError: If the action does not contain exactly 2 values.
        """
        if len(action) != 2:
            raise ValueError(f"Expected action to have 2 values, got {len(action)}")

        acceleration, steering = action
        car["velocity"] += torch.tensor([acceleration, 0.0], dtype=torch.float32) * 0.1  # Scale down the acceleration
        car["angle"] += torch.tensor([steering], dtype=torch.float32) * 0.1  # Scale down the steering

    def update_physics(self, car: Dict[str, torch.Tensor], dt: float = 0.1):
        """Update car position and velocity using physics simulation.

        Applies friction, integrates position, rotates the velocity by the
        car's heading, and clamps the position to the [-1, 1] playfield.
        """
        # Update velocity (apply friction)
        car["velocity"] *= (1 - self.friction * dt)

        # Update position
        car["position"] += car["velocity"] * dt

        # Apply steering: rotate the velocity vector by the current heading.
        angle = car["angle"].item()
        rotation_matrix = torch.tensor([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle), np.cos(angle)]
        ], dtype=torch.float32)
        car["velocity"] = torch.matmul(rotation_matrix, car["velocity"])

        # Bound the position to keep cars on the screen
        car["position"] = torch.clamp(car["position"], -1, 1)

    def calculate_reward(self) -> float:
        """Calculate the reward based on the current state.

        Rewards speed and penalizes distance from the center, summed over
        all cars.
        """
        reward = 0.0
        for car in self.cars:
            # Reward for moving
            speed = torch.norm(car["velocity"]).item()
            reward += speed * 0.1

            # Penalty for being close to the edge
            distance_from_center = torch.norm(car["position"]).item()
            reward -= distance_from_center * 0.1

        return reward

    def is_terminated(self) -> bool:
        """Check if the episode should be terminated (any car outside [-1, 1])."""
        for car in self.cars:
            if torch.any(torch.abs(car["position"]) > 1):
                return True
        return False

    def render(self):
        """Render the environment (placeholder for potential future implementation)."""
        pass
|
isopro/car_simulator/car_rl_model.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85303b6b7e544f04d04cb949709ee37ac956a78f098c0390e2b210448bc446bb
|
3 |
+
size 164031
|
isopro/car_simulator/car_rl_training.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gymnasium as gym
|
2 |
+
from stable_baselines3 import PPO
|
3 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
4 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
5 |
+
import numpy as np
|
6 |
+
from .car_rl_environment import CarRLEnvironment
|
7 |
+
|
8 |
+
def make_env():
    """Create and return an instance of the CarRLEnvironment."""
    return CarRLEnvironment(num_cars=3, time_of_day="08:00", is_rainy=False, is_weekday=True)

def main():
    """Train, evaluate, save, and demo a PPO policy on the car environment.

    Fix: this code previously ran at module top level, so merely importing
    the module kicked off a 1,000,000-timestep training run as a side
    effect. It is now wrapped in main() behind a __main__ guard.
    """
    # Create a vectorized environment
    env = DummyVecEnv([make_env])

    # Initialize the PPO agent
    model = PPO("MlpPolicy", env, verbose=1, learning_rate=0.0003, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, ent_coef=0.0)

    # Train the agent
    total_timesteps = 1_000_000
    model.learn(total_timesteps=total_timesteps, progress_bar=True)

    # Evaluate the trained agent
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    # Save the trained model
    model.save("car_rl_ppo_model")

    # Test the trained agent for up to 1000 steps, resetting between episodes.
    obs = env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        env.render()
        if dones.any():
            obs = env.reset()

    env.close()

if __name__ == "__main__":
    main()
|
isopro/car_simulator/carviz.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
import random
from datetime import datetime, timedelta

import numpy as np
import pygame
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from .car_rl_environment import CarRLEnvironment
|
8 |
+
|
9 |
+
# Initialize Pygame
# NOTE(review): pygame.init() runs at import time, so importing this module
# initializes the pygame subsystems as a side effect.
pygame.init()

# Constants
# Window and road geometry, in pixels.
SCREEN_WIDTH = 1000
SCREEN_HEIGHT = 800
ROAD_WIDTH = 800
ROAD_HEIGHT = 600
CAR_WIDTH = 40
CAR_HEIGHT = 20
INFO_BOX_WIDTH = 200
INFO_BOX_HEIGHT = 120
UI_PANEL_WIDTH = 200

# Colors (RGB tuples)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (200, 200, 200)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
|
31 |
+
|
32 |
+
class CarVisualization:
    """Pygame viewer that replays a trained policy in the car environment.

    Expects *env* to be a VecEnv-style wrapper: the constructor reads
    ``env.envs[0]`` for direct access to car/weather state, and
    ``run_visualization`` uses the vectorized reset()/step() API
    (batched observations, rewards, and done flags).
    """

    def __init__(self, env, model):
        self.env = env
        # Direct handle on the underlying (non-vectorized) environment so the
        # drawing code can read car state and weather flags.
        self.unwrapped_env = env.envs[0]
        self.model = model
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Enhanced Car RL Visualization")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 24)
        # Pre-create rain drops and static obstacles for the overlays.
        self.rain = [self.RainDrop() for _ in range(100)]
        self.obstacles = [self.Obstacle() for _ in range(5)]
        self.time_of_day = self.float_to_datetime(self.unwrapped_env.time_of_day)

    def float_to_datetime(self, time_float):
        """Convert a float time (0-24) to a datetime object."""
        hours = int(time_float)
        minutes = int((time_float - hours) * 60)
        return datetime.min + timedelta(hours=hours, minutes=minutes)

    def datetime_to_string(self, dt):
        """Convert a datetime object to a string in HH:MM format."""
        return dt.strftime("%H:%M")

    def draw_road(self):
        """Draw the road surface (tinted by time of day) and lane markings."""
        road_rect = pygame.Rect((SCREEN_WIDTH - ROAD_WIDTH) // 2, (SCREEN_HEIGHT - ROAD_HEIGHT) // 2, ROAD_WIDTH, ROAD_HEIGHT)
        road_color = self.get_road_color()
        pygame.draw.rect(self.screen, road_color, road_rect)

        # Draw lane markings: two horizontal lines splitting the road in thirds.
        for i in range(1, 3):
            y = (SCREEN_HEIGHT - ROAD_HEIGHT) // 2 + i * (ROAD_HEIGHT // 3)
            pygame.draw.line(self.screen, WHITE, (road_rect.left, y), (road_rect.right, y), 2)

    def get_road_color(self):
        """Pick a road tint from the current hour: day, dawn/dusk, or night."""
        hour = self.time_of_day.hour
        if 6 <= hour < 18:  # Daytime
            return GRAY
        elif 18 <= hour < 20 or 4 <= hour < 6:  # Dawn/Dusk
            return (150, 150, 170)
        else:  # Night
            return (100, 100, 120)

    def draw_car(self, position, angle, color):
        """Draw one car at *position* (env coords in [-1, 1]), rotated by *angle* radians."""
        x, y = position
        # Map env coordinates ([-1, 1] on both axes) to road pixel coordinates.
        x = (x + 1) * ROAD_WIDTH / 2 + (SCREEN_WIDTH - ROAD_WIDTH) // 2
        y = (y + 1) * ROAD_HEIGHT / 2 + (SCREEN_HEIGHT - ROAD_HEIGHT) // 2

        car_surface = pygame.Surface((CAR_WIDTH, CAR_HEIGHT), pygame.SRCALPHA)
        pygame.draw.rect(car_surface, color, (0, 0, CAR_WIDTH, CAR_HEIGHT))
        # Black triangle marks the car's nose so the heading is visible.
        pygame.draw.polygon(car_surface, BLACK, [(0, 0), (CAR_WIDTH // 2, 0), (0, CAR_HEIGHT)])
        rotated_car = pygame.transform.rotate(car_surface, -math.degrees(angle))
        self.screen.blit(rotated_car, rotated_car.get_rect(center=(x, y)))

    def draw_info_box(self, car_index, position, action, reward):
        """Draw a stats box (action, reward, speed) floating above one car."""
        x, y = position
        x = (x + 1) * ROAD_WIDTH / 2 + (SCREEN_WIDTH - ROAD_WIDTH) // 2
        y = (y + 1) * ROAD_HEIGHT / 2 + (SCREEN_HEIGHT - ROAD_HEIGHT) // 2

        info_box = pygame.Surface((INFO_BOX_WIDTH, INFO_BOX_HEIGHT))
        info_box.fill(WHITE)
        pygame.draw.rect(info_box, BLACK, info_box.get_rect(), 2)

        texts = [
            f"Car {car_index + 1}",
            f"Acceleration: {action[0]:.2f}",
            f"Steering: {action[1]:.2f}",
            f"Reward: {reward:.2f}",
            f"Speed: {np.linalg.norm(self.unwrapped_env.cars[car_index]['velocity']):.2f}"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            info_box.blit(text_surface, (10, 10 + i * 25))

        self.screen.blit(info_box, (x - INFO_BOX_WIDTH // 2, y - INFO_BOX_HEIGHT - 30))


    def draw_rain(self):
        """Draw every raindrop and advance its fall animation one frame."""
        for drop in self.rain:
            pygame.draw.line(self.screen, (200, 200, 255), (drop.x, drop.y), (drop.x, drop.y + drop.size), drop.size)
            drop.fall()

    def draw_obstacles(self):
        """Draw the static obstacles, offset into road pixel coordinates."""
        for obstacle in self.obstacles:
            pygame.draw.rect(self.screen, YELLOW, ((SCREEN_WIDTH - ROAD_WIDTH) // 2 + obstacle.x,
                                                   (SCREEN_HEIGHT - ROAD_HEIGHT) // 2 + obstacle.y,
                                                   obstacle.width, obstacle.height))

    def draw_ui_panel(self):
        """Draw the right-hand panel: current state readout plus key bindings."""
        panel = pygame.Surface((UI_PANEL_WIDTH, SCREEN_HEIGHT))
        panel.fill(WHITE)
        pygame.draw.rect(panel, BLACK, panel.get_rect(), 2)

        texts = [
            f"Time: {self.datetime_to_string(self.time_of_day)}",
            f"Rainy: {'Yes' if self.unwrapped_env.is_rainy else 'No'}",
            f"Weekday: {'Yes' if self.unwrapped_env.is_weekday else 'No'}",
            "Press keys to change:",
            "T: Time +1 hour",
            "R: Toggle Rain",
            "W: Toggle Weekday"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            panel.blit(text_surface, (10, 10 + i * 30))

        self.screen.blit(panel, (SCREEN_WIDTH - UI_PANEL_WIDTH, 0))

    def handle_events(self):
        """Process pygame events; return False when the window is closed.

        Key bindings mutate the live environment: T advances the clock one
        hour, R toggles rain, W toggles the weekday flag.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_t:
                    self.time_of_day += timedelta(hours=1)
                    # Write the new time back to the env as a float hour in [0, 24).
                    self.unwrapped_env.time_of_day = (self.time_of_day.hour + self.time_of_day.minute / 60) % 24
                elif event.key == pygame.K_r:
                    self.unwrapped_env.is_rainy = not self.unwrapped_env.is_rainy
                elif event.key == pygame.K_w:
                    self.unwrapped_env.is_weekday = not self.unwrapped_env.is_weekday
        return True

    class RainDrop:
        # A single falling raindrop; respawns at the top once off-screen.
        def __init__(self):
            self.x = random.randint(0, SCREEN_WIDTH)
            self.y = random.randint(0, SCREEN_HEIGHT)
            self.speed = random.randint(5, 15)
            self.size = random.randint(1, 3)

        def fall(self):
            # Advance downward; wrap to a fresh random column at the top edge.
            self.y += self.speed
            if self.y > SCREEN_HEIGHT:
                self.y = 0
                self.x = random.randint(0, SCREEN_WIDTH)

    class Obstacle:
        # A randomly sized and placed rectangle in road-relative pixels.
        def __init__(self):
            self.width = random.randint(30, 60)
            self.height = random.randint(30, 60)
            self.x = random.randint(0, ROAD_WIDTH - self.width)
            self.y = random.randint(0, ROAD_HEIGHT - self.height)

    def run_visualization(self, num_episodes=5):
        """Roll out the policy for *num_episodes* episodes, rendering at 30 FPS.

        Uses the VecEnv API: reset() returns a batched observation and
        step() returns batched (obs, reward, done, info) arrays; index [0]
        selects the single wrapped environment.
        """
        for episode in range(num_episodes):
            obs = self.env.reset()
            done = False
            total_reward = 0
            step = 0

            while not done:
                # Stop immediately if the window was closed.
                if not self.handle_events():
                    return

                self.screen.fill(WHITE)
                self.draw_road()
                self.draw_obstacles()
                if self.unwrapped_env.is_rainy:
                    self.draw_rain()

                action, _ = self.model.predict(obs, deterministic=True)
                obs, reward, done, info = self.env.step(action)
                total_reward += reward[0]

                for i, car in enumerate(self.unwrapped_env.cars):
                    position = car["position"].numpy()
                    angle = car["angle"].item()
                    color = (RED, GREEN, BLUE)[i % 3]  # Cycle through colors for different cars
                    self.draw_car(position, angle, color)
                    # Each car's (acceleration, steering) pair is a 2-slice of the batch action.
                    self.draw_info_box(i, position, action[0][i*2:(i+1)*2], reward[0])

                self.draw_ui_panel()
                pygame.display.flip()
                self.clock.tick(30)
                step += 1

                if done[0]:
                    break

            print(f"Episode {episode + 1} finished. Total reward: {total_reward:.2f}")

        pygame.quit()
|
214 |
+
|
215 |
+
|
216 |
+
def main():
    """Train a quick PPO policy and launch the visualization demo.

    Fix: the environment is now wrapped in a DummyVecEnv. CarVisualization
    reads ``env.envs[0]`` and relies on the vectorized reset()/step() API,
    so passing a raw CarRLEnvironment (as before) raised AttributeError at
    runtime.
    """
    # Create and train the model (you might want to load a pre-trained model instead)
    env = DummyVecEnv([lambda: CarRLEnvironment(num_cars=3, time_of_day="08:00", is_rainy=False, is_weekday=True)])
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)  # Adjust as needed

    # Create and run the visualization
    viz = CarVisualization(env, model)
    viz.run_visualization()

if __name__ == "__main__":
    main()
|
isopro/car_simulator/llm_main.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from stable_baselines3 import PPO
|
4 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
5 |
+
from .car_llm_agent import LLMCarRLWrapper
|
6 |
+
from .car_rl_environment import CarRLEnvironment
|
7 |
+
from .carviz import CarVisualization
|
8 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
# Load environment variables from a local .env file (e.g. the
# ANTHROPIC_API_KEY checked in main() below).
load_dotenv()
|
13 |
+
|
14 |
+
def parse_arguments(argv=None):
    """Parse command-line options for the LLM-integrated car RL simulation.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` (unchanged behavior);
            passing a list makes the function testable and reusable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Car RL Simulation with LLM Integration and Visualization")
    parser.add_argument("--num_cars", type=int, default=3, help="Number of cars in the simulation")
    parser.add_argument("--time_of_day", type=str, default="08:00", help="Initial time of day (HH:MM format)")
    parser.add_argument("--is_rainy", action="store_true", help="Set initial weather to rainy")
    parser.add_argument("--is_weekday", action="store_true", help="Set initial day to weekday")
    parser.add_argument("--train_steps", type=int, default=100000, help="Number of training steps")
    parser.add_argument("--visualize_episodes", type=int, default=5, help="Number of episodes to visualize")
    parser.add_argument("--load_model", type=str, help="Path to a pre-trained model to load")
    parser.add_argument("--llm_call_limit", type=int, default=1000, help="Maximum number of LLM API calls")
    parser.add_argument("--llm_call_frequency", type=int, default=100, help="Frequency of LLM calls (in steps)")
    return parser.parse_args(argv)
|
26 |
+
|
27 |
+
def make_env(num_cars, time_of_day, is_rainy, is_weekday, llm_call_limit, llm_call_frequency):
    """Return a zero-argument factory producing a configured LLMCarRLWrapper.

    DummyVecEnv takes a list of callables and defers construction, so the
    scenario configuration is captured in a closure here.
    """
    def _build():
        return LLMCarRLWrapper(
            num_cars=num_cars,
            time_of_day=time_of_day,
            is_rainy=is_rainy,
            is_weekday=is_weekday,
            llm_call_limit=llm_call_limit,
            llm_call_frequency=llm_call_frequency,
        )
    return _build
|
33 |
+
|
34 |
+
def train_and_evaluate(env, total_timesteps, eval_episodes=10):
    """Train a PPO agent on *env*, print its evaluation score, and return it.

    Args:
        env: Vectorized environment to train on.
        total_timesteps: Number of environment steps to train for.
        eval_episodes: Number of episodes used by the evaluation pass.

    Returns:
        Tuple of (trained model, mean evaluation reward).
    """
    # PPO hyperparameters gathered in one place for readability.
    hyperparams = dict(
        learning_rate=0.0003,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
    )
    model = PPO("MlpPolicy", env, verbose=1, **hyperparams)

    model.learn(total_timesteps=total_timesteps, progress_bar=True)

    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=eval_episodes)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    return model, mean_reward
|
44 |
+
|
45 |
+
def main():
    """Build the LLM-wrapped env, obtain a model, save it, and visualize.

    Fix: ``mean_reward`` was previously printed unconditionally but only
    assigned in the training branch, so running with ``--load_model``
    raised NameError. It is now initialized and the print is guarded.
    """
    args = parse_arguments()

    # Ensure the ANTHROPIC_API_KEY is set
    if not os.getenv('ANTHROPIC_API_KEY'):
        raise ValueError("ANTHROPIC_API_KEY not found in environment variables")

    # Create the vectorized environment with LLM integration
    env = DummyVecEnv([make_env(args.num_cars, args.time_of_day, args.is_rainy, args.is_weekday,
                                args.llm_call_limit, args.llm_call_frequency)])

    # Create or load the RL agent
    mean_reward = None  # only known when we train a fresh model
    if args.load_model and os.path.exists(args.load_model):
        print(f"Loading pre-trained model from {args.load_model}")
        model = PPO.load(args.load_model, env=env)
    else:
        print("Creating and training a new model")
        model, mean_reward = train_and_evaluate(env, total_timesteps=args.train_steps)

    # Save the trained model
    model.save("car_rl_llm_model")
    print("Model saved as car_rl_llm_model")
    if mean_reward is not None:
        print(f"Final mean reward: {mean_reward:.2f}")

    # Run the visualization
    viz = CarVisualization(env, model)
    viz.run_visualization(num_episodes=args.visualize_episodes)

if __name__ == "__main__":
    main()
|
isopro/car_simulator/main.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from stable_baselines3 import PPO
|
4 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
5 |
+
from .car_rl_environment import CarRLEnvironment
|
6 |
+
from .carviz import CarVisualization
|
7 |
+
|
8 |
+
def parse_arguments(argv=None):
    """Parse command-line options for the car RL simulation.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` (unchanged behavior);
            passing a list makes the function testable and reusable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Car RL Simulation and Visualization")
    parser.add_argument("--num_cars", type=int, default=3, help="Number of cars in the simulation")
    parser.add_argument("--time_of_day", type=str, default="08:00", help="Initial time of day (HH:MM format)")
    parser.add_argument("--is_rainy", action="store_true", help="Set initial weather to rainy")
    parser.add_argument("--is_weekday", action="store_true", help="Set initial day to weekday")
    parser.add_argument("--train_steps", type=int, default=10000, help="Number of training steps")
    parser.add_argument("--visualize_episodes", type=int, default=5, help="Number of episodes to visualize")
    parser.add_argument("--load_model", type=str, help="Path to a pre-trained model to load")
    return parser.parse_args(argv)
|
18 |
+
|
19 |
+
def make_env(num_cars, time_of_day, is_rainy, is_weekday):
    """Return a zero-argument factory that builds a configured CarRLEnvironment.

    DummyVecEnv takes a list of callables and defers construction, so the
    configuration is captured in a closure here.
    """
    def _factory():
        return CarRLEnvironment(
            num_cars=num_cars,
            time_of_day=time_of_day,
            is_rainy=is_rainy,
            is_weekday=is_weekday,
        )
    return _factory
|
23 |
+
|
24 |
+
def main():
    """Parse options, obtain a PPO model (loaded or freshly trained), save it, and visualize."""
    args = parse_arguments()

    # Create the vectorized environment
    vec_env = DummyVecEnv([make_env(args.num_cars, args.time_of_day, args.is_rainy, args.is_weekday)])

    # Load a pre-trained agent when a valid path was given; otherwise train one.
    if args.load_model and os.path.exists(args.load_model):
        print(f"Loading pre-trained model from {args.load_model}")
        model = PPO.load(args.load_model, env=vec_env)
    else:
        print("Creating and training a new model")
        model = PPO("MlpPolicy", vec_env, verbose=1)
        model.learn(total_timesteps=args.train_steps)

    # Save the trained model
    model.save("car_rl_model")
    print("Model saved as car_rl_model")

    # Run the visualization
    viz = CarVisualization(vec_env, model)
    viz.run_visualization(num_episodes=args.visualize_episodes)

if __name__ == "__main__":
    main()
|
isopro/conversation_simulation/README.md
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Conversation Simulator
|
2 |
+
|
3 |
+
This module is part of the `isopro` package and simulates conversations between an AI assistant (either Claude or GPT-4) and various user personas. It's designed to test and demonstrate how the AI handles different types of customer service scenarios.
|
4 |
+
|
5 |
+
## Project Structure
|
6 |
+
|
7 |
+
The Conversation Simulator is located in the `conversation_simulation` folder within the `isopro` package:
|
8 |
+
|
9 |
+
```
|
10 |
+
isopro/
|
11 |
+
└── conversation_simulation/
|
12 |
+
├── main.py
|
13 |
+
├── conversation_simulator.ipynb
|
14 |
+
├── conversation_agent.py
|
15 |
+
├── conversation_environment.py
|
16 |
+
├── custom_persona.py
|
17 |
+
└── user_personas.py
|
18 |
+
```
|
19 |
+
|
20 |
+
## Prerequisites
|
21 |
+
|
22 |
+
Before you begin, ensure you have met the following requirements:
|
23 |
+
|
24 |
+
* You have installed Python 3.7 or later.
|
25 |
+
* You have an Anthropic API key (for Claude) and/or an OpenAI API key (for GPT-4).
|
26 |
+
* You have installed the `isopro` package.
|
27 |
+
* For the Jupyter notebook, you have Jupyter Notebook or JupyterLab installed.
|
28 |
+
|
29 |
+
## Setting up the Conversation Simulator
|
30 |
+
|
31 |
+
1. If you haven't already, install the `isopro` package:
|
32 |
+
```
|
33 |
+
pip install isopro
|
34 |
+
```
|
35 |
+
|
36 |
+
2. Create a `.env` file in your project root and add your API keys:
|
37 |
+
```
|
38 |
+
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
39 |
+
OPENAI_API_KEY=your_openai_api_key_here
|
40 |
+
```
|
41 |
+
|
42 |
+
## Running the Conversation Simulator
|
43 |
+
|
44 |
+
You can run the Conversation Simulator either as a Python script or interactively using a Jupyter notebook.
|
45 |
+
|
46 |
+
### Using the Python Script
|
47 |
+
|
48 |
+
1. Basic usage:
|
49 |
+
```python
|
50 |
+
from isopro.conversation_simulation.main import main
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
main()
|
54 |
+
```
|
55 |
+
|
56 |
+
2. Running from the command line:
|
57 |
+
```
|
58 |
+
python -m isopro.conversation_simulation.main
|
59 |
+
```
|
60 |
+
|
61 |
+
### Using the Jupyter Notebook
|
62 |
+
|
63 |
+
Navigate to the `isopro/conversation_simulator/` directory and open the `conversation_simulator.ipynb` file using Jupyter Notebook or JupyterLab. Here's what you'll find in the notebook:
|
64 |
+
|
65 |
+
```python
|
66 |
+
# Conversation Simulator Jupyter Notebook
|
67 |
+
|
68 |
+
## Setup
|
69 |
+
|
70 |
+
import logging
|
71 |
+
from logging.handlers import RotatingFileHandler
|
72 |
+
import os
|
73 |
+
from datetime import datetime
|
74 |
+
from dotenv import load_dotenv
|
75 |
+
from isopro.conversation_simulation.conversation_simulator import ConversationSimulator
|
76 |
+
from isopro.conversation_simulation.custom_persona import create_custom_persona
|
77 |
+
|
78 |
+
# Load environment variables
|
79 |
+
load_dotenv()
|
80 |
+
|
81 |
+
# Set up logging
|
82 |
+
log_directory = "logs"
|
83 |
+
os.makedirs(log_directory, exist_ok=True)
|
84 |
+
log_file = os.path.join(log_directory, "conversation_simulator.log")
|
85 |
+
|
86 |
+
# Create a rotating file handler
|
87 |
+
file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024, backupCount=5)
|
88 |
+
file_handler.setLevel(logging.DEBUG)
|
89 |
+
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
90 |
+
file_handler.setFormatter(file_formatter)
|
91 |
+
|
92 |
+
# Create a console handler
|
93 |
+
console_handler = logging.StreamHandler()
|
94 |
+
console_handler.setLevel(logging.INFO)
|
95 |
+
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
96 |
+
console_handler.setFormatter(console_formatter)
|
97 |
+
|
98 |
+
# Set up the logger
|
99 |
+
logger = logging.getLogger()
|
100 |
+
logger.setLevel(logging.DEBUG)
|
101 |
+
logger.addHandler(file_handler)
|
102 |
+
logger.addHandler(console_handler)
|
103 |
+
|
104 |
+
print("Setup complete.")
|
105 |
+
|
106 |
+
## Helper Functions
|
107 |
+
|
108 |
+
def save_output(content, filename):
|
109 |
+
"""Save the output content to a file."""
|
110 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
111 |
+
f.write(content)
|
112 |
+
|
113 |
+
def get_user_choice():
|
114 |
+
"""Get user's choice of AI model."""
|
115 |
+
while True:
|
116 |
+
choice = input("Choose AI model (claude/openai): ").lower()
|
117 |
+
if choice in ['claude', 'openai']:
|
118 |
+
return choice
|
119 |
+
print("Invalid choice. Please enter 'claude' or 'openai'.")
|
120 |
+
|
121 |
+
print("Helper functions defined.")
|
122 |
+
|
123 |
+
## Main Simulation Function
|
124 |
+
|
125 |
+
def run_simulation():
|
126 |
+
# Get user's choice of AI model
|
127 |
+
ai_choice = get_user_choice()
|
128 |
+
|
129 |
+
# Set up the appropriate model and API key
|
130 |
+
if ai_choice == 'claude':
|
131 |
+
model = "claude-3-opus-20240229"
|
132 |
+
os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
|
133 |
+
ai_name = "Claude"
|
134 |
+
else: # openai
|
135 |
+
model = "gpt-4-1106-preview"
|
136 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
|
137 |
+
ai_name = "GPT-4 Turbo"
|
138 |
+
|
139 |
+
# Initialize the ConversationSimulator
|
140 |
+
simulator = ConversationSimulator(
|
141 |
+
ai_prompt=f"You are {ai_name}, an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
|
142 |
+
)
|
143 |
+
|
144 |
+
output_content = f"Conversation Simulator using {ai_name} model: {model}\n\n"
|
145 |
+
|
146 |
+
# Run simulations with different personas
|
147 |
+
personas = ["upset", "human_request", "inappropriate", "incomplete_info"]
|
148 |
+
|
149 |
+
for persona in personas:
|
150 |
+
logger.info(f"Running simulation with {persona} persona using {ai_name}")
|
151 |
+
conversation_history = simulator.run_simulation(persona, num_turns=3)
|
152 |
+
|
153 |
+
output_content += f"\nConversation with {persona} persona:\n"
|
154 |
+
for message in conversation_history:
|
155 |
+
output_line = f"{message['role'].capitalize()}: {message['content']}\n"
|
156 |
+
output_content += output_line
|
157 |
+
logger.debug(output_line.strip())
|
158 |
+
output_content += "\n" + "-"*50 + "\n"
|
159 |
+
|
160 |
+
# Create and run a simulation with a custom persona
|
161 |
+
custom_persona_name = "Techie Customer"
|
162 |
+
custom_characteristics = ["tech-savvy", "impatient", "detail-oriented"]
|
163 |
+
custom_message_templates = [
|
164 |
+
"I've tried rebooting my device, but the error persists. Can you help?",
|
165 |
+
"What's the latest update on the cloud service outage?",
|
166 |
+
"I need specifics on the API rate limits for the enterprise plan.",
|
167 |
+
"The latency on your servers is unacceptable. What's being done about it?",
|
168 |
+
"Can you explain the technical details of your encryption method?"
|
169 |
+
]
|
170 |
+
|
171 |
+
logger.info(f"Running simulation with custom persona: {custom_persona_name} using {ai_name}")
|
172 |
+
custom_conversation = simulator.run_custom_simulation(
|
173 |
+
custom_persona_name,
|
174 |
+
custom_characteristics,
|
175 |
+
custom_message_templates,
|
176 |
+
num_turns=3
|
177 |
+
)
|
178 |
+
|
179 |
+
output_content += f"\nConversation with {custom_persona_name}:\n"
|
180 |
+
for message in custom_conversation:
|
181 |
+
output_line = f"{message['role'].capitalize()}: {message['content']}\n"
|
182 |
+
output_content += output_line
|
183 |
+
logger.debug(output_line.strip())
|
184 |
+
|
185 |
+
# Save the output to a file
|
186 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
187 |
+
output_directory = "output"
|
188 |
+
os.makedirs(output_directory, exist_ok=True)
|
189 |
+
output_file = os.path.join(output_directory, f"{ai_name.lower()}_conversation_output_{timestamp}.txt")
|
190 |
+
save_output(output_content, output_file)
|
191 |
+
logger.info(f"Output saved to {output_file}")
|
192 |
+
|
193 |
+
return output_content
|
194 |
+
|
195 |
+
print("Main simulation function defined.")
|
196 |
+
|
197 |
+
## Run the Simulation
|
198 |
+
|
199 |
+
simulation_output = run_simulation()
|
200 |
+
print(simulation_output)
|
201 |
+
|
202 |
+
## Analyze the Results
|
203 |
+
|
204 |
+
# Example analysis: Count the number of apologies
|
205 |
+
apology_count = simulation_output.lower().count("sorry") + simulation_output.lower().count("apologi")
|
206 |
+
print(f"Number of apologies: {apology_count}")
|
207 |
+
|
208 |
+
# Example analysis: Average length of AI responses
|
209 |
+
ai_responses = [line.split(": ", 1)[1] for line in simulation_output.split("\n") if line.startswith("Assistant: ")]
|
210 |
+
avg_response_length = sum(len(response.split()) for response in ai_responses) / len(ai_responses)
|
211 |
+
print(f"Average length of AI responses: {avg_response_length:.2f} words")
|
212 |
+
|
213 |
+
## Conclusion
|
214 |
+
|
215 |
+
# This notebook demonstrates how to use the Conversation Simulator from the isopro package.
|
216 |
+
# You can modify the personas, adjust the number of turns, or add your own analysis to
|
217 |
+
# further explore the capabilities of the AI models in customer service scenarios.
|
218 |
+
```
|
219 |
+
|
220 |
+
## Output and Logs
|
221 |
+
|
222 |
+
- Simulation outputs are saved in the `output` directory within your current working directory.
|
223 |
+
- Logs are saved in the `logs` directory within your current working directory.
|
224 |
+
|
225 |
+
## Customizing the Simulation
|
226 |
+
|
227 |
+
You can customize the simulation by modifying the `main.py` file or the Jupyter notebook:
|
228 |
+
|
229 |
+
- To change the predefined personas, modify the `personas` list.
|
230 |
+
- To adjust the custom persona, modify the `custom_persona_name`, `custom_characteristics`, and `custom_message_templates` variables.
|
231 |
+
- To change the number of turns in each conversation, modify the `num_turns` parameter in the `run_simulation` and `run_custom_simulation` method calls.
|
232 |
+
|
233 |
+
In the Jupyter notebook, you can also add new cells for additional analysis or visualization of the results.
|
234 |
+
|
235 |
+
## Troubleshooting
|
236 |
+
|
237 |
+
If you encounter any issues:
|
238 |
+
|
239 |
+
1. Make sure your API keys are correctly set in the `.env` file or environment variables.
|
240 |
+
2. Check the logs in the `logs` directory for detailed error messages.
|
241 |
+
3. Ensure you have the latest version of the `isopro` package installed.
|
242 |
+
4. For Jupyter notebook issues, make sure you have Jupyter installed and are running the notebook from the correct directory.
|
243 |
+
|
244 |
+
If problems persist, please open an issue in the project repository.
|
245 |
+
|
246 |
+
## Contributing
|
247 |
+
|
248 |
+
Contributions to the Conversation Simulator are welcome. Please feel free to submit a Pull Request to the `isopro` repository.
|
249 |
+
|
250 |
+
## License
|
251 |
+
|
252 |
+
This project is licensed under the MIT License - see the LICENSE file in the `isopro` package for details.
|
isopro/conversation_simulation/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation Simulation Module
|
3 |
+
|
4 |
+
This module provides tools for simulating conversations with AI agents.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from .conversation_environment import ConversationEnvironment
|
8 |
+
from .conversation_agent import ConversationAgent
|
9 |
+
from .user_personas import UserPersona
|
10 |
+
from .custom_persona import create_custom_persona
|
11 |
+
from .conversation_simulator import ConversationSimulator
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"ConversationEnvironment",
|
15 |
+
"ConversationAgent",
|
16 |
+
"UserPersona",
|
17 |
+
"create_custom_persona",
|
18 |
+
"ConversationSimulator",
|
19 |
+
]
|
isopro/conversation_simulation/conversation_agent.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation Agent
|
3 |
+
|
4 |
+
This module defines the AI agent used in the conversation simulation, using Anthropic's Claude API.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import anthropic
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
from ..agents.ai_agent import AI_Agent
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
load_dotenv()
|
16 |
+
|
17 |
+
class ConversationAgent(AI_Agent):
|
18 |
+
def __init__(self, name, prompt, model="claude-3-opus-20240229"):
|
19 |
+
super().__init__(name)
|
20 |
+
self.prompt = prompt
|
21 |
+
self.model = model
|
22 |
+
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
23 |
+
logger.info(f"Initialized ConversationAgent '{name}' with Claude model {model}")
|
24 |
+
|
25 |
+
def generate_response(self, conversation_history):
|
26 |
+
try:
|
27 |
+
messages = [{"role": "user" if msg["role"] != "assistant" else "assistant", "content": msg["content"]}
|
28 |
+
for msg in conversation_history]
|
29 |
+
|
30 |
+
response = self.client.messages.create(
|
31 |
+
model=self.model,
|
32 |
+
max_tokens=1000,
|
33 |
+
system=self.prompt,
|
34 |
+
messages=messages
|
35 |
+
)
|
36 |
+
ai_message = response.content[0].text.strip()
|
37 |
+
logger.debug(f"Generated response: {ai_message}")
|
38 |
+
return ai_message
|
39 |
+
except Exception as e:
|
40 |
+
logger.error(f"Error generating response: {e}")
|
41 |
+
return "I apologize, but I'm having trouble responding at the moment."
|
isopro/conversation_simulation/conversation_environment.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation Environment
|
3 |
+
|
4 |
+
This module defines the environment for simulating conversations between a Claude-based AI agent and users with various personas.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from ..environments.simulation_environment import SimulationEnvironment
|
9 |
+
from .conversation_agent import ConversationAgent
|
10 |
+
from .user_personas import UserPersona
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
class ConversationEnvironment(SimulationEnvironment):
|
15 |
+
"""
|
16 |
+
ConversationEnvironment
|
17 |
+
|
18 |
+
This class provides an environment for simulating conversations between Claude-based AI agents and users with various personas.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, ai_prompt="You are a helpful customer service agent. Respond politely and professionally."):
|
22 |
+
"""
|
23 |
+
Initialize the ConversationEnvironment.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
ai_prompt (str): The prompt to guide the AI agent's behavior.
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
self.ai_prompt = ai_prompt
|
30 |
+
self.ai_agent = None
|
31 |
+
self.user_persona = None
|
32 |
+
logger.info("Initialized ConversationEnvironment")
|
33 |
+
|
34 |
+
def set_ai_agent(self, model="claude-3-opus-20240229"):
|
35 |
+
"""
|
36 |
+
Set up the Claude-based AI agent for the conversation.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
model (str): The name of the Claude model to use.
|
40 |
+
"""
|
41 |
+
self.ai_agent = ConversationAgent("Customer Service AI", self.ai_prompt, model)
|
42 |
+
logger.info(f"Set AI agent with Claude model: {model}")
|
43 |
+
def set_user_persona(self, persona_type, **kwargs):
|
44 |
+
"""
|
45 |
+
Set the user persona for the conversation.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
persona_type (str): The type of user persona to use.
|
49 |
+
**kwargs: Additional arguments for the user persona.
|
50 |
+
"""
|
51 |
+
self.user_persona = UserPersona.create(persona_type, **kwargs)
|
52 |
+
logger.info(f"Set user persona: {persona_type}")
|
53 |
+
|
54 |
+
def run_conversation(self, num_turns=5):
|
55 |
+
"""
|
56 |
+
Run a conversation between the AI agent and the user persona.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
num_turns (int): The number of conversation turns to simulate.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
list: A list of dictionaries containing the conversation history.
|
63 |
+
"""
|
64 |
+
if not self.ai_agent or not self.user_persona:
|
65 |
+
raise ValueError("Both AI agent and user persona must be set before running a conversation.")
|
66 |
+
|
67 |
+
conversation_history = []
|
68 |
+
for _ in range(num_turns):
|
69 |
+
user_message = self.user_persona.generate_message(conversation_history)
|
70 |
+
conversation_history.append({"role": "user", "content": user_message})
|
71 |
+
logger.debug(f"User: {user_message}")
|
72 |
+
|
73 |
+
ai_response = self.ai_agent.generate_response(conversation_history)
|
74 |
+
conversation_history.append({"role": "assistant", "content": ai_response})
|
75 |
+
logger.debug(f"AI: {ai_response}")
|
76 |
+
|
77 |
+
logger.info("Completed conversation simulation")
|
78 |
+
return conversation_history
|
isopro/conversation_simulation/conversation_simulator.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation Simulator
|
3 |
+
|
4 |
+
This module provides a high-level interface for running conversation simulations
|
5 |
+
with different personas and analyzing the results using Anthropic's Claude API.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from .conversation_environment import ConversationEnvironment
|
10 |
+
from .custom_persona import create_custom_persona
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
class ConversationSimulator:
|
15 |
+
"""
|
16 |
+
ConversationSimulator orchestrates conversation simulations with various personas using Claude.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, ai_prompt="You are a helpful customer service agent. Respond politely and professionally."):
|
20 |
+
"""
|
21 |
+
Initialize the ConversationSimulator.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
ai_prompt (str): The prompt to guide the Claude-based AI agent's behavior.
|
25 |
+
"""
|
26 |
+
self.environment = ConversationEnvironment(ai_prompt)
|
27 |
+
logger.info("Initialized ConversationSimulator with Claude")
|
28 |
+
|
29 |
+
def run_simulation(self, persona_type, num_turns=5, claude_model="claude-3-opus-20240229", **persona_kwargs):
|
30 |
+
"""
|
31 |
+
Run a conversation simulation with a specified persona using Claude.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
persona_type (str): The type of persona to use in the simulation.
|
35 |
+
num_turns (int): The number of conversation turns to simulate.
|
36 |
+
claude_model (str): The specific Claude model to use for the simulation.
|
37 |
+
**persona_kwargs: Additional arguments for creating the persona.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
list: A list of dictionaries containing the conversation history.
|
41 |
+
"""
|
42 |
+
self.environment.set_ai_agent(model=claude_model)
|
43 |
+
self.environment.set_user_persona(persona_type, **persona_kwargs)
|
44 |
+
conversation_history = self.environment.run_conversation(num_turns)
|
45 |
+
logger.info(f"Completed simulation with {persona_type} persona using Claude model {claude_model}")
|
46 |
+
return conversation_history
|
47 |
+
|
48 |
+
def run_custom_simulation(self, name, characteristics, message_templates, num_turns=5, claude_model="claude-3-opus-20240229"):
|
49 |
+
"""
|
50 |
+
Run a conversation simulation with a custom persona using Claude.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
name (str): The name of the custom persona.
|
54 |
+
characteristics (list): A list of characteristics that define the persona.
|
55 |
+
message_templates (list): A list of message templates the persona can use.
|
56 |
+
num_turns (int): The number of conversation turns to simulate.
|
57 |
+
claude_model (str): The specific Claude model to use for the simulation.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
list: A list of dictionaries containing the conversation history.
|
61 |
+
"""
|
62 |
+
custom_persona = create_custom_persona(name, characteristics, message_templates)
|
63 |
+
self.environment.set_ai_agent(model=claude_model)
|
64 |
+
self.environment.user_persona = custom_persona
|
65 |
+
conversation_history = self.environment.run_conversation(num_turns)
|
66 |
+
logger.info(f"Completed simulation with custom persona: {name} using Claude model {claude_model}")
|
67 |
+
return conversation_history
|
isopro/conversation_simulation/custom_persona.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Custom Persona
|
3 |
+
|
4 |
+
This module allows users to create custom personas for the conversation simulation.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from .user_personas import UserPersona
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
class CustomPersona(UserPersona):
|
13 |
+
"""
|
14 |
+
CustomPersona allows users to create their own persona with specific characteristics.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, name, characteristics, message_templates):
|
18 |
+
"""
|
19 |
+
Initialize the CustomPersona.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
name (str): The name of the custom persona.
|
23 |
+
characteristics (list): A list of characteristics that define the persona.
|
24 |
+
message_templates (list): A list of message templates the persona can use.
|
25 |
+
"""
|
26 |
+
super().__init__(name)
|
27 |
+
self.characteristics = characteristics
|
28 |
+
self.message_templates = message_templates
|
29 |
+
logger.info(f"Created CustomPersona: {name}")
|
30 |
+
|
31 |
+
def generate_message(self, conversation_history):
|
32 |
+
"""
|
33 |
+
Generate a message based on the custom persona's characteristics and templates.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
conversation_history (list): A list of dictionaries containing the conversation history.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
str: The generated message.
|
40 |
+
"""
|
41 |
+
import random
|
42 |
+
message = random.choice(self.message_templates)
|
43 |
+
logger.debug(f"CustomPersona '{self.name}' generated message: {message}")
|
44 |
+
return message
|
45 |
+
|
46 |
+
def create_custom_persona(name, characteristics, message_templates):
|
47 |
+
"""
|
48 |
+
Create a custom persona with the given characteristics and message templates.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
name (str): The name of the custom persona.
|
52 |
+
characteristics (list): A list of characteristics that define the persona.
|
53 |
+
message_templates (list): A list of message templates the persona can use.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
CustomPersona: An instance of the custom persona.
|
57 |
+
"""
|
58 |
+
return CustomPersona(name, characteristics, message_templates)
|
isopro/conversation_simulation/main.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
from logging.handlers import RotatingFileHandler
import os
from datetime import datetime
from dotenv import load_dotenv
from .conversation_simulator import ConversationSimulator
from .custom_persona import create_custom_persona

# Load environment variables (API keys) from a local .env file, if present.
load_dotenv()

# Set up logging: files go under ./logs relative to the current working directory.
log_directory = "logs"
os.makedirs(log_directory, exist_ok=True)
log_file = os.path.join(log_directory, "conversation_simulator.log")

# Create a rotating file handler: rotate at 1 MiB, keep at most 5 old files.
file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024, backupCount=5)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

# Create a console handler: the console shows only INFO and above,
# while full DEBUG detail goes to the log file.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)

# Set up the logger.
# NOTE(review): handlers are attached to the ROOT logger at import time, so
# importing this module more than once would duplicate every log line —
# confirm this module is only ever run as a script.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
|
34 |
+
|
35 |
+
def save_output(content, filename):
    """Write *content* to *filename* as UTF-8 text, replacing any existing file."""
    with open(filename, mode='w', encoding='utf-8') as handle:
        handle.write(content)
|
39 |
+
|
40 |
+
def get_user_choice():
    """Prompt repeatedly until the user types 'claude' or 'openai'; return it lowercased."""
    valid_models = ('claude', 'openai')
    while True:
        answer = input("Choose AI model (claude/openai): ").lower()
        if answer in valid_models:
            return answer
        print("Invalid choice. Please enter 'claude' or 'openai'.")
|
47 |
+
|
48 |
+
def main():
    """Run the end-to-end conversation demo.

    Prompts for a backend (claude/openai), runs simulations with the four
    built-in personas plus one custom persona, and writes the full transcript
    to ./output/<ai_name>_conversation_output_<timestamp>.txt.

    Raises:
        EnvironmentError: If the API key for the chosen backend is not set.
    """
    # Get user's choice of AI model
    ai_choice = get_user_choice()

    # Set up the appropriate model and check that its API key is available.
    # Fix: the previous code did os.environ[KEY] = os.getenv(KEY), which is a
    # no-op when the key exists and raises TypeError (os.environ rejects None)
    # when it does not; fail fast with an actionable message instead.
    if ai_choice == 'claude':
        model = "claude-3-opus-20240229"
        key_name = "ANTHROPIC_API_KEY"
        ai_name = "Claude"
    else:  # openai
        model = "gpt-4-1106-preview"
        key_name = "OPENAI_API_KEY"
        ai_name = "GPT-4 Turbo"
    if not os.getenv(key_name):
        raise EnvironmentError(
            f"{key_name} is not set. Add it to your environment or .env file."
        )

    # Initialize the ConversationSimulator
    simulator = ConversationSimulator(
        ai_prompt=f"You are {ai_name}, an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
    )

    output_content = f"Conversation Simulator using {ai_name} model: {model}\n\n"

    # Run simulations with the built-in personas.
    personas = ["upset", "human_request", "inappropriate", "incomplete_info"]

    for persona in personas:
        logger.info(f"Running simulation with {persona} persona using {ai_name}")
        conversation_history = simulator.run_simulation(persona, num_turns=3)

        output_content += f"\nConversation with {persona} persona:\n"
        for message in conversation_history:
            output_line = f"{message['role'].capitalize()}: {message['content']}\n"
            output_content += output_line
            logger.debug(output_line.strip())
        output_content += "\n" + "-"*50 + "\n"

    # Create and run a simulation with a custom persona
    custom_persona_name = "Techie Customer"
    custom_characteristics = ["tech-savvy", "impatient", "detail-oriented"]
    custom_message_templates = [
        "I've tried rebooting my device, but the error persists. Can you help?",
        "What's the latest update on the cloud service outage?",
        "I need specifics on the API rate limits for the enterprise plan.",
        "The latency on your servers is unacceptable. What's being done about it?",
        "Can you explain the technical details of your encryption method?"
    ]

    logger.info(f"Running simulation with custom persona: {custom_persona_name} using {ai_name}")
    custom_conversation = simulator.run_custom_simulation(
        custom_persona_name,
        custom_characteristics,
        custom_message_templates,
        num_turns=3
    )

    output_content += f"\nConversation with {custom_persona_name}:\n"
    for message in custom_conversation:
        output_line = f"{message['role'].capitalize()}: {message['content']}\n"
        output_content += output_line
        logger.debug(output_line.strip())

    # Save the transcript under ./output with a timestamped filename.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_directory = "output"
    os.makedirs(output_directory, exist_ok=True)
    output_file = os.path.join(output_directory, f"{ai_name.lower()}_conversation_output_{timestamp}.txt")
    save_output(output_content, output_file)
    logger.info(f"Output saved to {output_file}")

if __name__ == "__main__":
    main()
|
isopro/conversation_simulation/user_personas.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
User Personas
|
3 |
+
|
4 |
+
This module defines various user personas for the conversation simulation.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import random
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
class UserPersona:
|
13 |
+
"""
|
14 |
+
Base class for user personas in the conversation simulation.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, name):
|
18 |
+
self.name = name
|
19 |
+
|
20 |
+
def generate_message(self, conversation_history):
|
21 |
+
"""
|
22 |
+
Generate a message based on the persona and conversation history.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
conversation_history (list): A list of dictionaries containing the conversation history.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
str: The generated message.
|
29 |
+
"""
|
30 |
+
raise NotImplementedError("Subclasses must implement generate_message method")
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def create(persona_type, **kwargs):
|
34 |
+
"""
|
35 |
+
Factory method to create user personas.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
persona_type (str): The type of user persona to create.
|
39 |
+
**kwargs: Additional arguments for the user persona.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
UserPersona: An instance of the specified user persona.
|
43 |
+
"""
|
44 |
+
persona_classes = {
|
45 |
+
"upset": UpsetCustomer,
|
46 |
+
"human_request": HumanRequestCustomer,
|
47 |
+
"inappropriate": InappropriateCustomer,
|
48 |
+
"incomplete_info": IncompleteInfoCustomer,
|
49 |
+
}
|
50 |
+
|
51 |
+
if persona_type not in persona_classes:
|
52 |
+
raise ValueError(f"Unknown persona type: {persona_type}")
|
53 |
+
|
54 |
+
return persona_classes[persona_type](**kwargs)
|
55 |
+
|
56 |
+
class UpsetCustomer(UserPersona):
|
57 |
+
def __init__(self):
|
58 |
+
super().__init__("Upset Customer")
|
59 |
+
self.complaints = [
|
60 |
+
"This is unacceptable!",
|
61 |
+
"I've been waiting for hours!",
|
62 |
+
"I want to speak to your manager!",
|
63 |
+
"This is the worst service I've ever experienced!",
|
64 |
+
"I'm extremely disappointed with your company!",
|
65 |
+
]
|
66 |
+
|
67 |
+
def generate_message(self, conversation_history):
|
68 |
+
message = random.choice(self.complaints)
|
69 |
+
logger.debug(f"UpsetCustomer generated message: {message}")
|
70 |
+
return message
|
71 |
+
|
72 |
+
class HumanRequestCustomer(UserPersona):
|
73 |
+
def __init__(self):
|
74 |
+
super().__init__("Human Request Customer")
|
75 |
+
self.requests = [
|
76 |
+
"Can I speak to a human representative?",
|
77 |
+
"I don't want to talk to a bot. Get me a real person.",
|
78 |
+
"Is there a way to talk to an actual employee?",
|
79 |
+
"I need to speak with a human agent, not an AI.",
|
80 |
+
"Please transfer me to a live representative.",
|
81 |
+
]
|
82 |
+
|
83 |
+
def generate_message(self, conversation_history):
|
84 |
+
message = random.choice(self.requests)
|
85 |
+
logger.debug(f"HumanRequestCustomer generated message: {message}")
|
86 |
+
return message
|
87 |
+
|
88 |
+
class InappropriateCustomer(UserPersona):
|
89 |
+
def __init__(self):
|
90 |
+
super().__init__("Inappropriate Customer")
|
91 |
+
self.inappropriate_words = ["[INAPPROPRIATE1]", "[INAPPROPRIATE2]", "[INAPPROPRIATE3]"]
|
92 |
+
|
93 |
+
def generate_message(self, conversation_history):
|
94 |
+
message = f"You're a {random.choice(self.inappropriate_words)} and this service is {random.choice(self.inappropriate_words)}!"
|
95 |
+
logger.debug(f"InappropriateCustomer generated message: {message}")
|
96 |
+
return message
|
97 |
+
|
98 |
+
class IncompleteInfoCustomer(UserPersona):
|
99 |
+
def __init__(self):
|
100 |
+
super().__init__("Incomplete Info Customer")
|
101 |
+
self.vague_requests = [
|
102 |
+
"I need help with my account.",
|
103 |
+
"There's a problem with my order.",
|
104 |
+
"Something's not working right.",
|
105 |
+
"I have a question about your service.",
|
106 |
+
"Can you check on the status of my thing?",
|
107 |
+
]
|
108 |
+
|
109 |
+
def generate_message(self, conversation_history):
|
110 |
+
message = random.choice(self.vague_requests)
|
111 |
+
logger.debug(f"IncompleteInfoCustomer generated message: {message}")
|
112 |
+
return message
|
isopro/environments/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Environment classes for the isopro package.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .simulation_environment import SimulationEnvironment
|
6 |
+
from .custom_environment import CustomEnvironment
|
7 |
+
from .llm_orchestrator import LLMOrchestrator
|
8 |
+
|
9 |
+
__all__ = ["SimulationEnvironment", "CustomEnvironment", "LLMOrchestrator"]
|
isopro/environments/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (432 Bytes). View file
|
|
isopro/environments/__pycache__/custom_environment.cpython-38.pyc
ADDED
Binary file (4.18 kB). View file
|
|
isopro/environments/__pycache__/llm_orchestrator.cpython-38.pyc
ADDED
Binary file (7.06 kB). View file
|
|
isopro/environments/__pycache__/simulation_environment.cpython-38.pyc
ADDED
Binary file (2.04 kB). View file
|
|
isopro/environments/custom_environment.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Custom Environment for creating user-defined simulation environments."""
|
2 |
+
from ..environments.simulation_environment import SimulationEnvironment
|
3 |
+
from ..agents.ai_agent import AI_Agent
|
4 |
+
from ..base.base_component import BaseComponent, agent_component
|
5 |
+
|
6 |
+
class CustomAgent(AI_Agent):
|
7 |
+
"""
|
8 |
+
CustomAgent
|
9 |
+
|
10 |
+
This class defines a custom agent. Users can extend this class to implement their own agents.
|
11 |
+
"""
|
12 |
+
def __init__(self, name, custom_param):
|
13 |
+
"""
|
14 |
+
Initialize the CustomAgent.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
name (str): The name of the agent.
|
18 |
+
custom_param: A custom parameter for the agent.
|
19 |
+
"""
|
20 |
+
super().__init__(name)
|
21 |
+
self.custom_param = custom_param
|
22 |
+
|
23 |
+
def run(self, input_data):
|
24 |
+
"""
|
25 |
+
Run the custom agent.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
input_data (dict): The input data for the agent.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
dict: The processed output data.
|
32 |
+
"""
|
33 |
+
self.logger.info(f"Running custom agent: {self.name} with parameter: {self.custom_param}")
|
34 |
+
# Implement custom behavior here
|
35 |
+
return super().run(input_data)
|
36 |
+
|
37 |
+
@agent_component
class CustomComponent(BaseComponent):
    """
    CustomComponent

    This class defines a custom component. Users can extend this class to implement their own components.
    """
    def __init__(self, name, custom_param):
        """
        Initialize the CustomComponent.

        Args:
            name (str): The name of the component.
            custom_param: A custom parameter for the component.
        """
        super().__init__(name)
        # Stored verbatim; only referenced for logging here — subclasses are
        # expected to give it real meaning.
        self.custom_param = custom_param

    def run(self, input_data):
        """
        Run the custom component.

        Args:
            input_data (dict): The input data for the component.

        Returns:
            dict: The processed output data.
        """
        # NOTE(review): self.logger is presumably provided by BaseComponent —
        # confirm before relying on it in subclasses.
        self.logger.info(f"Running custom component: {self.name} with parameter: {self.custom_param}")
        # Implement custom behavior here
        # Default behavior is the identity transform: input passes through.
        return input_data
|
68 |
+
|
69 |
+
class CustomEnvironment(SimulationEnvironment):
    """
    CustomEnvironment

    This class provides a template for creating a custom training environment.
    Users can define their own agents and components, and integrate them into
    the simulation environment.
    """

    def __init__(self, num_agents=1, custom_param=None):
        """
        Initialize the CustomEnvironment.

        Args:
            num_agents (int): The number of agents to create.
            custom_param: A custom parameter shared by the initial agents.
        """
        super().__init__()
        self.num_agents = num_agents
        self.custom_param = custom_param
        self._create_custom_agents()

    def _register_agent_with_component(self, agent_name, component_name, custom_param):
        """Single construction path: build an agent/component pair and register it.

        Previously this logic was duplicated between _create_custom_agents and
        add_custom_agent; both now delegate here.

        Args:
            agent_name (str): Name for the new CustomAgent.
            component_name (str): Name for the CustomComponent attached to it.
            custom_param: Custom parameter passed to both agent and component.
        """
        agent = CustomAgent(name=agent_name, custom_param=custom_param)
        component = CustomComponent(name=component_name, custom_param=custom_param)
        agent.add_component(component)
        self.add_agent(agent)

    def _create_custom_agents(self):
        """Create the initial batch of custom agents and add them to the environment."""
        # Names are 1-based ("Custom Agent 1", ...) to stay human-friendly.
        for i in range(self.num_agents):
            self._register_agent_with_component(
                f"Custom Agent {i+1}",
                f"Custom Component {i+1}",
                self.custom_param,
            )

    def add_custom_agent(self, agent_name, custom_param):
        """
        Add a custom agent to the environment.

        Args:
            agent_name (str): The name of the agent.
            custom_param: A custom parameter for the agent.
        """
        self._register_agent_with_component(
            agent_name,
            f"Component for {agent_name}",
            custom_param,
        )
|
isopro/environments/llm_orchestrator.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM Orchestrator for managing and executing LLM components in various modes.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import heapq
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
from typing import List, Any, Optional, Callable
|
9 |
+
from ..base.base_component import BaseComponent
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
class ComponentException(Exception):
    """Raised when a component cannot be executed by the orchestrator."""
|
16 |
+
|
17 |
+
class LLMOrchestrator:
    """
    Coordinates execution of LLM components.

    Three execution strategies are supported: sequential chaining,
    parallel fan-out, and priority-ordered ("node") execution.
    """

    def __init__(self):
        """Start with no registered components and no priority function."""
        self.components: List[BaseComponent] = []
        self.priority_function: Optional[Callable[[BaseComponent, Any], int]] = None

    def add_component(self, component: BaseComponent) -> None:
        """
        Register a component for orchestration.

        Args:
            component (BaseComponent): The component to register.

        Raises:
            ValueError: If the component is None or not a BaseComponent instance.
        """
        if component is None:
            raise ValueError("Cannot add None as a component")
        if not isinstance(component, BaseComponent):
            raise ValueError(f"Only BaseComponent instances can be added, got {type(component)}")
        self.components.append(component)

    def set_priority_function(self, priority_func: Callable[[BaseComponent, Any], int]) -> None:
        """
        Install the priority function used by node-mode execution.

        Args:
            priority_func (Callable[[BaseComponent, Any], int]): Maps a component
                and its input data to an integer priority (lower runs first,
                per heapq min-heap ordering).
        """
        self.priority_function = priority_func

    def run_orchestration(self, mode: str = 'sequence', input_data: Optional[Any] = None) -> List[Any]:
        """
        Execute all registered components using the requested strategy.

        Args:
            mode (str): One of 'sequence', 'parallel', or 'node'.
            input_data (Any, optional): Initial input handed to the components.

        Returns:
            List[Any]: One entry per component execution — the component's
            result, or the stringified error when it raised ComponentException.

        Raises:
            ValueError: If mode names no known strategy.
        """
        if not self.components:
            logger.warning("No components to run")
            return []

        # Dispatch table instead of an if/elif chain.
        strategies = {
            'sequence': self._run_in_sequence,
            'parallel': self._run_in_parallel,
            'node': self._run_as_node,
        }
        strategy = strategies.get(mode)
        if strategy is None:
            raise ValueError("Invalid execution mode")
        return strategy(input_data)

    def _run_in_sequence(self, input_data: Any) -> List[Any]:
        """
        Chain components: each consumes the previous component's output.

        Args:
            input_data (Any): Input handed to the first component.

        Returns:
            List[Any]: Results (or error strings) in registration order.
        """
        logger.info("Running in sequence mode")
        outputs: List[Any] = []
        payload = input_data

        for component in self.components:
            try:
                produced = self._run_component(component, payload)
            except ComponentException as e:
                # On failure, record the message; the next component still
                # receives the last successful output as its input.
                logger.error(f"Error: {e}")
                outputs.append(str(e))
            else:
                outputs.append(produced)
                payload = produced  # feed forward

        return outputs

    def _run_in_parallel(self, input_data: Any) -> List[Any]:
        """
        Fan the same input out to every component concurrently.

        Args:
            input_data (Any): Input handed to every component.

        Returns:
            List[Any]: Results (or error strings) in registration order.
        """
        logger.info("Running in parallel mode")
        outputs: List[Any] = []

        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(self._run_component, component, input_data)
                       for component in self.components]
            # Iterating the submission list (not as_completed) keeps results
            # in registration order.
            for task in pending:
                try:
                    outputs.append(task.result())
                except ComponentException as e:
                    logger.error(f"Error: {e}")
                    outputs.append(str(e))

        return outputs

    def _run_as_node(self, input_data: Any) -> List[Any]:
        """
        Execute components ordered by priority (lowest value first).

        Priorities come from the installed priority function; without one,
        every component gets priority 0 and registration order is kept via
        the index tiebreaker. After a component runs, its priority is
        recomputed from its result; if it changed, the component is pushed
        back onto the queue and will execute again.

        Args:
            input_data (Any): Input handed to every component.

        Returns:
            List[Any]: Results (or error strings) in execution order.
        """
        logger.info("Running in node mode (priority-based)")
        outputs: List[Any] = []

        if self.priority_function is None:
            logger.warning("No priority function set. Using default priority (0) for all components.")
            heap = [(0, idx, comp) for idx, comp in enumerate(self.components)]
        else:
            heap = [(self.priority_function(comp, input_data), idx, comp)
                    for idx, comp in enumerate(self.components)]
        heapq.heapify(heap)

        while heap:
            priority, _, component = heapq.heappop(heap)
            logger.info(f"Running component {component} with priority {priority}")
            try:
                produced = self._run_component(component, input_data)
                outputs.append(produced)

                # Re-queue the component when its recomputed priority moved.
                if self.priority_function:
                    new_priority = self.priority_function(component, produced)
                    if new_priority != priority:
                        heapq.heappush(heap, (new_priority, len(outputs), component))
                        logger.info(f"Updated priority for component {component}: {priority} -> {new_priority}")

            except ComponentException as e:
                logger.error(f"Error: {e}")
                outputs.append(str(e))

        return outputs

    def _run_component(self, component: BaseComponent, input_data: Any) -> Any:
        """
        Invoke one component's run method.

        Args:
            component (BaseComponent): The component to invoke.
            input_data (Any): Input passed to the component.

        Returns:
            Any: Whatever the component's run method returns.

        Raises:
            ComponentException: If the component lacks a callable 'run' method.
        """
        run_method = getattr(component, 'run', None)
        if not callable(run_method):
            raise ComponentException(f"Component {component} does not have a callable 'run' method")
        return run_method(input_data)
|