Delete isopro/examples
Browse files
isopro/examples/__init__.py
DELETED
@@ -1,34 +0,0 @@
"""
ISOPRO Examples

This package contains example Jupyter notebooks demonstrating various features and use cases of the ISOPRO package.

Available examples:
- adversarial_example: Illustrates how to run an adversarial simulation against a Claude agent and analyze its results.
- conversation_simulation_example: Shows how to use the conversation simulation with a Claude or GPT-4 agent for customer service.
- orchestrator_example: Demonstrates the orchestration environment and evaluator across the parallel, sequence, and node modes.
- run_cartpole_example: Demonstrates an LLM-based reinforcement learning agent in the CartPole environment.

To run these examples, open the respective .ipynb files in a Jupyter notebook environment.
"""

# Import any shared utilities or constants used across notebooks here
# For example:
# from .utils import plot_results, load_sample_data

# List available example notebooks.
# Each entry is the stem of an .ipynb file that lives next to this module, so
# that "<name>.ipynb" is the file a user should open.
# NOTE(review): the names mirror the notebooks visible in this directory
# listing; confirm no additional notebooks exist before relying on this list
# being exhaustive.
AVAILABLE_EXAMPLES = [
    "adversarial_example",
    "conversation_simulation_example",
    "orchestrator_example",
    "run_cartpole_example",
]


def list_examples() -> None:
    """
    Print a list of available example notebooks.

    Writes a header line, one "- <name>" line per entry of
    ``AVAILABLE_EXAMPLES`` (in list order), and a closing hint on how to open
    a notebook. Output goes to standard output; nothing is returned.
    """
    print("Available ISOPRO example notebooks:")
    for example in AVAILABLE_EXAMPLES:
        print(f"- {example}")
    # The leading "\n" leaves a blank line between the list and the hint.
    print("\nTo run an example, open the corresponding .ipynb file in a Jupyter notebook environment.")

# You can add any other shared functions or variables here that might be useful across multiple notebooks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
isopro/examples/adversarial_example.ipynb
DELETED
@@ -1,242 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# Adversarial Simulation Notebook\n",
|
8 |
-
"\n",
|
9 |
-
"This notebook demonstrates how to run an adversarial simulation against a language model (in this case, Claude) and analyze the results.\n",
|
10 |
-
"\n",
|
11 |
-
"## Setup\n",
|
12 |
-
"\n",
|
13 |
-
"First, we'll import the necessary libraries and set up our environment."
|
14 |
-
]
|
15 |
-
},
|
16 |
-
{
|
17 |
-
"cell_type": "code",
|
18 |
-
"execution_count": null,
|
19 |
-
"metadata": {},
|
20 |
-
"outputs": [],
|
21 |
-
"source": [
|
22 |
-
"import logging\n",
|
23 |
-
"from typing import List\n",
|
24 |
-
"from isopro.adversarial_simulation import AdversarialSimulator, AdversarialEnvironment\n",
|
25 |
-
"from isopro.utils.analyze_adversarial_sim import analyze_adversarial_results, summarize_adversarial_impact\n",
|
26 |
-
"from isopro.agents.ai_agent import AI_Agent\n",
|
27 |
-
"import anthropic\n",
|
28 |
-
"import os\n",
|
29 |
-
"from dotenv import load_dotenv\n",
|
30 |
-
"import json\n",
|
31 |
-
"from datetime import datetime\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"import torch\n",
|
34 |
-
"import matplotlib.pyplot as plt\n",
|
35 |
-
"import seaborn as sns\n",
|
36 |
-
"\n",
|
37 |
-
"# Load environment variables\n",
|
38 |
-
"load_dotenv()\n",
|
39 |
-
"\n",
|
40 |
-
"# Set up logging\n",
|
41 |
-
"logging.basicConfig(level=logging.INFO)\n",
|
42 |
-
"logger = logging.getLogger(__name__)"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "markdown",
|
47 |
-
"metadata": {},
|
48 |
-
"source": [
|
49 |
-
"## Define Helper Classes and Functions\n",
|
50 |
-
"\n",
|
51 |
-
"Now, we'll define our ClaudeAgent class and some helper functions."
|
52 |
-
]
|
53 |
-
},
|
54 |
-
{
|
55 |
-
"cell_type": "code",
|
56 |
-
"execution_count": null,
|
57 |
-
"metadata": {},
|
58 |
-
"outputs": [],
|
59 |
-
"source": [
|
60 |
-
"class ClaudeAgent(AI_Agent):\n",
|
61 |
-
" def __init__(self, name):\n",
|
62 |
-
" super().__init__(name)\n",
|
63 |
-
" self.client = anthropic.Anthropic(api_key=os.getenv(\"ANTHROPIC_API_KEY\"))\n",
|
64 |
-
"\n",
|
65 |
-
" def run(self, input_data):\n",
|
66 |
-
" response = self.client.messages.create(\n",
|
67 |
-
" model=\"claude-3-opus-20240229\",\n",
|
68 |
-
" max_tokens=100,\n",
|
69 |
-
" messages=[{\"role\": \"user\", \"content\": input_data['text']}]\n",
|
70 |
-
" )\n",
|
71 |
-
" return response.content[0].text\n",
|
72 |
-
"\n",
|
73 |
-
" def step(self, sim_state):\n",
|
74 |
-
" sim_state['output'] = self.run(sim_state)\n",
|
75 |
-
" return sim_state\n",
|
76 |
-
"\n",
|
77 |
-
"class NumpyEncoder(json.JSONEncoder):\n",
|
78 |
-
" def default(self, obj):\n",
|
79 |
-
" if isinstance(obj, np.floating):\n",
|
80 |
-
" return float(obj)\n",
|
81 |
-
" if isinstance(obj, torch.Tensor):\n",
|
82 |
-
" return obj.tolist()\n",
|
83 |
-
" return super(NumpyEncoder, self).default(obj)\n",
|
84 |
-
"\n",
|
85 |
-
"def save_scores(output_dir: str, run_id: str, analysis_results: dict) -> None:\n",
|
86 |
-
" os.makedirs(output_dir, exist_ok=True)\n",
|
87 |
-
" output_file = os.path.join(output_dir, f\"adv-{run_id}.json\")\n",
|
88 |
-
" \n",
|
89 |
-
" try:\n",
|
90 |
-
" with open(output_file, 'w') as f:\n",
|
91 |
-
" json.dump(analysis_results, f, indent=2, cls=NumpyEncoder)\n",
|
92 |
-
" \n",
|
93 |
-
" logging.info(f\"Saved analysis results to {output_file}\")\n",
|
94 |
-
" except Exception as e:\n",
|
95 |
-
" logging.error(f\"Error saving analysis results: {str(e)}\")\n",
|
96 |
-
"\n",
|
97 |
-
"def get_sample_inputs() -> List[str]:\n",
|
98 |
-
" return [\n",
|
99 |
-
" \"What is the capital of France?\",\n",
|
100 |
-
" \"How does photosynthesis work?\",\n",
|
101 |
-
" \"Explain the theory of relativity.\"\n",
|
102 |
-
" ]"
|
103 |
-
]
|
104 |
-
},
|
105 |
-
{
|
106 |
-
"cell_type": "markdown",
|
107 |
-
"metadata": {},
|
108 |
-
"source": [
|
109 |
-
"## Run the Adversarial Simulation\n",
|
110 |
-
"\n",
|
111 |
-
"Now we'll set up and run our adversarial simulation."
|
112 |
-
]
|
113 |
-
},
|
114 |
-
{
|
115 |
-
"cell_type": "code",
|
116 |
-
"execution_count": null,
|
117 |
-
"metadata": {},
|
118 |
-
"outputs": [],
|
119 |
-
"source": [
|
120 |
-
"def run_simulation():\n",
|
121 |
-
" run_id = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
|
122 |
-
" logger.info(f\"Starting adversarial simulation run {run_id}\")\n",
|
123 |
-
"\n",
|
124 |
-
" claude_agent = ClaudeAgent(\"Claude Agent\")\n",
|
125 |
-
"\n",
|
126 |
-
" # Create the AdversarialEnvironment\n",
|
127 |
-
" adv_env = AdversarialEnvironment(\n",
|
128 |
-
" agent_wrapper=claude_agent,\n",
|
129 |
-
" num_adversarial_agents=2,\n",
|
130 |
-
" attack_types=[\"textbugger\", \"deepwordbug\"],\n",
|
131 |
-
" attack_targets=[\"input\", \"output\"]\n",
|
132 |
-
" )\n",
|
133 |
-
"\n",
|
134 |
-
" # Set up the adversarial simulator with the environment\n",
|
135 |
-
" simulator = AdversarialSimulator(adv_env)\n",
|
136 |
-
"\n",
|
137 |
-
" input_data = get_sample_inputs()\n",
|
138 |
-
"\n",
|
139 |
-
" logger.info(\"Starting adversarial simulation...\")\n",
|
140 |
-
" simulation_results = simulator.run_simulation(input_data, num_steps=1)\n",
|
141 |
-
"\n",
|
142 |
-
" logger.info(\"Analyzing simulation results...\")\n",
|
143 |
-
" analysis_results = analyze_adversarial_results(simulation_results)\n",
|
144 |
-
"\n",
|
145 |
-
" summary = summarize_adversarial_impact(analysis_results)\n",
|
146 |
-
"\n",
|
147 |
-
" print(\"\\nAdversarial Simulation Summary:\")\n",
|
148 |
-
" print(summary)\n",
|
149 |
-
"\n",
|
150 |
-
" output_dir = \"output\"\n",
|
151 |
-
" save_scores(output_dir, run_id, analysis_results)\n",
|
152 |
-
"\n",
|
153 |
-
" logger.info(\"Simulation complete.\")\n",
|
154 |
-
" \n",
|
155 |
-
" return simulation_results, analysis_results\n",
|
156 |
-
"\n",
|
157 |
-
"# Run the simulation\n",
|
158 |
-
"simulation_results, analysis_results = run_simulation()"
|
159 |
-
]
|
160 |
-
},
|
161 |
-
{
|
162 |
-
"cell_type": "markdown",
|
163 |
-
"metadata": {},
|
164 |
-
"source": [
|
165 |
-
"## Analyze and Visualize Results\n",
|
166 |
-
"\n",
|
167 |
-
"Now that we have our results, let's analyze and visualize them."
|
168 |
-
]
|
169 |
-
},
|
170 |
-
{
|
171 |
-
"cell_type": "code",
|
172 |
-
"execution_count": null,
|
173 |
-
"metadata": {},
|
174 |
-
"outputs": [],
|
175 |
-
"source": [
|
176 |
-
"def plot_metric_changes(analysis_results):\n",
|
177 |
-
" metrics = ['bleu', 'rouge-1', 'rouge-2', 'rouge-l', 'perplexity', 'coherence']\n",
|
178 |
-
" changes = [analysis_results[f'{metric}_change'] for metric in metrics]\n",
|
179 |
-
" \n",
|
180 |
-
" plt.figure(figsize=(12, 6))\n",
|
181 |
-
" sns.barplot(x=metrics, y=changes)\n",
|
182 |
-
" plt.title('Changes in Metrics After Adversarial Attacks')\n",
|
183 |
-
" plt.xlabel('Metrics')\n",
|
184 |
-
" plt.ylabel('Percentage Change')\n",
|
185 |
-
" plt.xticks(rotation=45)\n",
|
186 |
-
" plt.show()\n",
|
187 |
-
"\n",
|
188 |
-
"plot_metric_changes(analysis_results)\n",
|
189 |
-
"\n",
|
190 |
-
"# Display original and perturbed inputs and outputs\n",
|
191 |
-
"for i, result in enumerate(simulation_results):\n",
|
192 |
-
" print(f\"\\nExample {i+1}:\")\n",
|
193 |
-
" print(f\"Original Input: {result['original_input']}\")\n",
|
194 |
-
" print(f\"Perturbed Input: {result['perturbed_input']}\")\n",
|
195 |
-
" print(f\"Original Output: {result['original_output']}\")\n",
|
196 |
-
" print(f\"Perturbed Output: {result['perturbed_output']}\")\n",
|
197 |
-
" print(\"-\" * 50)"
|
198 |
-
]
|
199 |
-
},
|
200 |
-
{
|
201 |
-
"cell_type": "markdown",
|
202 |
-
"metadata": {},
|
203 |
-
"source": [
|
204 |
-
"## Conclusion\n",
|
205 |
-
"\n",
|
206 |
-
"This notebook demonstrates how to run an adversarial simulation against a language model and analyze the results. The simulation applies various adversarial attacks to the input or output of the model and measures the impact on different metrics.\n",
|
207 |
-
"\n",
|
208 |
-
"Key observations:\n",
|
209 |
-
"1. The changes in different metrics (BLEU, ROUGE, perplexity, coherence) show how the adversarial attacks affect the model's performance.\n",
|
210 |
-
"2. By comparing the original and perturbed inputs and outputs, we can see how the attacks modify the text and how the model's responses change as a result.\n",
|
211 |
-
"\n",
|
212 |
-
"This information can be used to assess the robustness of the language model against adversarial attacks and identify areas for improvement in the model's defenses."
|
213 |
-
]
|
214 |
-
}
|
215 |
-
],
|
216 |
-
"metadata": {
|
217 |
-
"kernelspec": {
|
218 |
-
"display_name": "smooth_env",
|
219 |
-
"language": "python",
|
220 |
-
"name": "python3"
|
221 |
-
},
|
222 |
-
"language_info": {
|
223 |
-
"codemirror_mode": {
|
224 |
-
"name": "ipython",
|
225 |
-
"version": 3
|
226 |
-
},
|
227 |
-
"file_extension": ".py",
|
228 |
-
"mimetype": "text/x-python",
|
229 |
-
"name": "python",
|
230 |
-
"nbconvert_exporter": "python",
|
231 |
-
"pygments_lexer": "ipython3",
|
232 |
-
"version": "3.9.18"
|
233 |
-
},
|
234 |
-
"vscode": {
|
235 |
-
"interpreter": {
|
236 |
-
"hash": "e35b4d35af899f01dc238e082b97509c22792197b4b3ae814b774a24a240ad24"
|
237 |
-
}
|
238 |
-
}
|
239 |
-
},
|
240 |
-
"nbformat": 4,
|
241 |
-
"nbformat_minor": 4
|
242 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
isopro/examples/conversation_simulation_example.ipynb
DELETED
@@ -1,258 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# Conversation Simulator\n",
|
8 |
-
"\n",
|
9 |
-
"This notebook demonstrates the usage of the Conversation Simulator from the isopro package. It simulates conversations between an AI assistant (either Claude or GPT-4) and various user personas."
|
10 |
-
]
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"cell_type": "markdown",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"## Setup\n",
|
17 |
-
"\n",
|
18 |
-
"First, let's import the necessary modules and set up our environment."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": null,
|
24 |
-
"metadata": {},
|
25 |
-
"outputs": [],
|
26 |
-
"source": [
|
27 |
-
"import logging\n",
|
28 |
-
"from logging.handlers import RotatingFileHandler\n",
|
29 |
-
"import os\n",
|
30 |
-
"from datetime import datetime\n",
|
31 |
-
"from dotenv import load_dotenv\n",
|
32 |
-
"from isopro.conversation_simulation.conversation_simulator import ConversationSimulator\n",
|
33 |
-
"from isopro.conversation_simulation.custom_persona import create_custom_persona\n",
|
34 |
-
"\n",
|
35 |
-
"# Load environment variables\n",
|
36 |
-
"load_dotenv()\n",
|
37 |
-
"\n",
|
38 |
-
"# Set up logging\n",
|
39 |
-
"log_directory = \"logs\"\n",
|
40 |
-
"os.makedirs(log_directory, exist_ok=True)\n",
|
41 |
-
"log_file = os.path.join(log_directory, \"conversation_simulator.log\")\n",
|
42 |
-
"\n",
|
43 |
-
"# Create a rotating file handler\n",
|
44 |
-
"file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024, backupCount=5)\n",
|
45 |
-
"file_handler.setLevel(logging.DEBUG)\n",
|
46 |
-
"file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')\n",
|
47 |
-
"file_handler.setFormatter(file_formatter)\n",
|
48 |
-
"\n",
|
49 |
-
"# Create a console handler\n",
|
50 |
-
"console_handler = logging.StreamHandler()\n",
|
51 |
-
"console_handler.setLevel(logging.INFO)\n",
|
52 |
-
"console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')\n",
|
53 |
-
"console_handler.setFormatter(console_formatter)\n",
|
54 |
-
"\n",
|
55 |
-
"# Set up the logger\n",
|
56 |
-
"logger = logging.getLogger()\n",
|
57 |
-
"logger.setLevel(logging.DEBUG)\n",
|
58 |
-
"logger.addHandler(file_handler)\n",
|
59 |
-
"logger.addHandler(console_handler)\n",
|
60 |
-
"\n",
|
61 |
-
"print(\"Setup complete.\")"
|
62 |
-
]
|
63 |
-
},
|
64 |
-
{
|
65 |
-
"cell_type": "markdown",
|
66 |
-
"metadata": {},
|
67 |
-
"source": [
|
68 |
-
"## Helper Functions\n",
|
69 |
-
"\n",
|
70 |
-
"Next, let's define some helper functions."
|
71 |
-
]
|
72 |
-
},
|
73 |
-
{
|
74 |
-
"cell_type": "code",
|
75 |
-
"execution_count": null,
|
76 |
-
"metadata": {},
|
77 |
-
"outputs": [],
|
78 |
-
"source": [
|
79 |
-
"def save_output(content, filename):\n",
|
80 |
-
" \"\"\"Save the output content to a file.\"\"\"\n",
|
81 |
-
" with open(filename, 'w', encoding='utf-8') as f:\n",
|
82 |
-
" f.write(content)\n",
|
83 |
-
"\n",
|
84 |
-
"def get_user_choice():\n",
|
85 |
-
" \"\"\"Get user's choice of AI model.\"\"\"\n",
|
86 |
-
" while True:\n",
|
87 |
-
" choice = input(\"Choose AI model (claude/openai): \").lower()\n",
|
88 |
-
" if choice in ['claude', 'openai']:\n",
|
89 |
-
" return choice\n",
|
90 |
-
" print(\"Invalid choice. Please enter 'claude' or 'openai'.\")\n",
|
91 |
-
"\n",
|
92 |
-
"print(\"Helper functions defined.\")"
|
93 |
-
]
|
94 |
-
},
|
95 |
-
{
|
96 |
-
"cell_type": "markdown",
|
97 |
-
"metadata": {},
|
98 |
-
"source": [
|
99 |
-
"## Main Simulation Function\n",
|
100 |
-
"\n",
|
101 |
-
"Now, let's define our main simulation function."
|
102 |
-
]
|
103 |
-
},
|
104 |
-
{
|
105 |
-
"cell_type": "code",
|
106 |
-
"execution_count": null,
|
107 |
-
"metadata": {},
|
108 |
-
"outputs": [],
|
109 |
-
"source": [
|
110 |
-
"def run_simulation():\n",
|
111 |
-
" # Get user's choice of AI model\n",
|
112 |
-
" ai_choice = get_user_choice()\n",
|
113 |
-
"\n",
|
114 |
-
" # Set up the appropriate model and API key\n",
|
115 |
-
" if ai_choice == 'claude':\n",
|
116 |
-
" model = \"claude-3-opus-20240229\"\n",
|
117 |
-
" os.environ[\"ANTHROPIC_API_KEY\"] = os.getenv(\"ANTHROPIC_API_KEY\")\n",
|
118 |
-
" ai_name = \"Claude\"\n",
|
119 |
-
" else: # openai\n",
|
120 |
-
" model = \"gpt-4-1106-preview\"\n",
|
121 |
-
" os.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\")\n",
|
122 |
-
" ai_name = \"GPT-4 Turbo\"\n",
|
123 |
-
"\n",
|
124 |
-
" # Initialize the ConversationSimulator\n",
|
125 |
-
" simulator = ConversationSimulator(\n",
|
126 |
-
" ai_prompt=f\"You are {ai_name}, an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally.\"\n",
|
127 |
-
" )\n",
|
128 |
-
"\n",
|
129 |
-
" output_content = f\"Conversation Simulator using {ai_name} model: {model}\\n\\n\"\n",
|
130 |
-
"\n",
|
131 |
-
" # Run simulations with different personas\n",
|
132 |
-
" personas = [\"upset\", \"human_request\", \"inappropriate\", \"incomplete_info\"]\n",
|
133 |
-
" \n",
|
134 |
-
" for persona in personas:\n",
|
135 |
-
" logger.info(f\"Running simulation with {persona} persona using {ai_name}\")\n",
|
136 |
-
" conversation_history = simulator.run_simulation(persona, num_turns=3)\n",
|
137 |
-
" \n",
|
138 |
-
" output_content += f\"\\nConversation with {persona} persona:\\n\"\n",
|
139 |
-
" for message in conversation_history:\n",
|
140 |
-
" output_line = f\"{message['role'].capitalize()}: {message['content']}\\n\"\n",
|
141 |
-
" output_content += output_line\n",
|
142 |
-
" logger.debug(output_line.strip())\n",
|
143 |
-
" output_content += \"\\n\" + \"-\"*50 + \"\\n\"\n",
|
144 |
-
"\n",
|
145 |
-
" # Create and run a simulation with a custom persona\n",
|
146 |
-
" custom_persona_name = \"Techie Customer\"\n",
|
147 |
-
" custom_characteristics = [\"tech-savvy\", \"impatient\", \"detail-oriented\"]\n",
|
148 |
-
" custom_message_templates = [\n",
|
149 |
-
" \"I've tried rebooting my device, but the error persists. Can you help?\",\n",
|
150 |
-
" \"What's the latest update on the cloud service outage?\",\n",
|
151 |
-
" \"I need specifics on the API rate limits for the enterprise plan.\",\n",
|
152 |
-
" \"The latency on your servers is unacceptable. What's being done about it?\",\n",
|
153 |
-
" \"Can you explain the technical details of your encryption method?\"\n",
|
154 |
-
" ]\n",
|
155 |
-
"\n",
|
156 |
-
" logger.info(f\"Running simulation with custom persona: {custom_persona_name} using {ai_name}\")\n",
|
157 |
-
" custom_conversation = simulator.run_custom_simulation(\n",
|
158 |
-
" custom_persona_name,\n",
|
159 |
-
" custom_characteristics,\n",
|
160 |
-
" custom_message_templates,\n",
|
161 |
-
" num_turns=3\n",
|
162 |
-
" )\n",
|
163 |
-
"\n",
|
164 |
-
" output_content += f\"\\nConversation with {custom_persona_name}:\\n\"\n",
|
165 |
-
" for message in custom_conversation:\n",
|
166 |
-
" output_line = f\"{message['role'].capitalize()}: {message['content']}\\n\"\n",
|
167 |
-
" output_content += output_line\n",
|
168 |
-
" logger.debug(output_line.strip())\n",
|
169 |
-
"\n",
|
170 |
-
" # Save the output to a file\n",
|
171 |
-
" timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
|
172 |
-
" output_directory = \"output\"\n",
|
173 |
-
" os.makedirs(output_directory, exist_ok=True)\n",
|
174 |
-
" output_file = os.path.join(output_directory, f\"{ai_name.lower()}_conversation_output_{timestamp}.txt\")\n",
|
175 |
-
" save_output(output_content, output_file)\n",
|
176 |
-
" logger.info(f\"Output saved to {output_file}\")\n",
|
177 |
-
"\n",
|
178 |
-
" return output_content\n",
|
179 |
-
"\n",
|
180 |
-
"print(\"Main simulation function defined.\")"
|
181 |
-
]
|
182 |
-
},
|
183 |
-
{
|
184 |
-
"cell_type": "markdown",
|
185 |
-
"metadata": {},
|
186 |
-
"source": [
|
187 |
-
"## Run the Simulation\n",
|
188 |
-
"\n",
|
189 |
-
"Now we're ready to run the simulation. This cell will prompt you to choose between Claude and GPT-4, then run the simulation and display the results."
|
190 |
-
]
|
191 |
-
},
|
192 |
-
{
|
193 |
-
"cell_type": "code",
|
194 |
-
"execution_count": null,
|
195 |
-
"metadata": {},
|
196 |
-
"outputs": [],
|
197 |
-
"source": [
|
198 |
-
"simulation_output = run_simulation()\n",
|
199 |
-
"print(simulation_output)"
|
200 |
-
]
|
201 |
-
},
|
202 |
-
{
|
203 |
-
"cell_type": "markdown",
|
204 |
-
"metadata": {},
|
205 |
-
"source": [
|
206 |
-
"## Analyze the Results\n",
|
207 |
-
"\n",
|
208 |
-
"After running the simulation, you can analyze the results here. For example, you might want to count the number of times certain phrases or words were used, or calculate the average length of responses."
|
209 |
-
]
|
210 |
-
},
|
211 |
-
{
|
212 |
-
"cell_type": "code",
|
213 |
-
"execution_count": null,
|
214 |
-
"metadata": {},
|
215 |
-
"outputs": [],
|
216 |
-
"source": [
|
217 |
-
"# Example analysis: Count the number of apologies\n",
|
218 |
-
"apology_count = simulation_output.lower().count(\"sorry\") + simulation_output.lower().count(\"apologi\")\n",
|
219 |
-
"print(f\"Number of apologies: {apology_count}\")\n",
|
220 |
-
"\n",
|
221 |
-
"# Example analysis: Average length of AI responses\n",
|
222 |
-
"ai_responses = [line.split(\": \", 1)[1] for line in simulation_output.split(\"\\n\") if line.startswith(\"Assistant: \")]\n",
|
223 |
-
"avg_response_length = sum(len(response.split()) for response in ai_responses) / len(ai_responses)\n",
|
224 |
-
"print(f\"Average length of AI responses: {avg_response_length:.2f} words\")"
|
225 |
-
]
|
226 |
-
},
|
227 |
-
{
|
228 |
-
"cell_type": "markdown",
|
229 |
-
"metadata": {},
|
230 |
-
"source": [
|
231 |
-
"## Conclusion\n",
|
232 |
-
"\n",
|
233 |
-
"This notebook demonstrates how to use the Conversation Simulator from the isopro package. You can modify the personas, adjust the number of turns, or add your own analysis to further explore the capabilities of the AI models in customer service scenarios."
|
234 |
-
]
|
235 |
-
}
|
236 |
-
],
|
237 |
-
"metadata": {
|
238 |
-
"kernelspec": {
|
239 |
-
"display_name": "Python 3",
|
240 |
-
"language": "python",
|
241 |
-
"name": "python3"
|
242 |
-
},
|
243 |
-
"language_info": {
|
244 |
-
"codemirror_mode": {
|
245 |
-
"name": "ipython",
|
246 |
-
"version": 3
|
247 |
-
},
|
248 |
-
"file_extension": ".py",
|
249 |
-
"mimetype": "text/x-python",
|
250 |
-
"name": "python",
|
251 |
-
"nbconvert_exporter": "python",
|
252 |
-
"pygments_lexer": "ipython3",
|
253 |
-
"version": "3.8.5"
|
254 |
-
}
|
255 |
-
},
|
256 |
-
"nbformat": 4,
|
257 |
-
"nbformat_minor": 4
|
258 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
isopro/examples/orchestrator_example.ipynb
DELETED
@@ -1,245 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# isopro Tutorial: Orchestrator, Evaluator, and Evaluation Modules\n",
|
8 |
-
"\n",
|
9 |
-
"This notebook will guide you through using the `isopro` package, focusing on the orchestrator, evaluator, and evaluation modules. We'll cover installation, setup, and usage examples."
|
10 |
-
]
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"cell_type": "markdown",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"## 1. Installation\n",
|
17 |
-
"\n",
|
18 |
-
"First, let's install the `isopro` package. Run the following cell to install it using pip:"
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": null,
|
24 |
-
"metadata": {},
|
25 |
-
"outputs": [],
|
26 |
-
"source": [
|
27 |
-
"!pip install isopro"
|
28 |
-
]
|
29 |
-
},
|
30 |
-
{
|
31 |
-
"cell_type": "markdown",
|
32 |
-
"metadata": {},
|
33 |
-
"source": [
|
34 |
-
"## 2. Setup\n",
|
35 |
-
"\n",
|
36 |
-
"Now, let's import the necessary modules and set up our environment. The agents need API keys for the model providers they call. Never hard-code these keys in a notebook: instead, store them in a local `.env` file, which the next cell loads with `load_dotenv()` and reads through environment variables (this notebook reads `OPENAI_API_KEY`; add `ANTHROPIC_API_KEY` to the same file if your agents need it). Remember not to commit or share your `.env` file."
|
37 |
-
]
|
38 |
-
},
|
39 |
-
{
|
40 |
-
"cell_type": "code",
|
41 |
-
"execution_count": null,
|
42 |
-
"metadata": {},
|
43 |
-
"outputs": [
|
44 |
-
{
|
45 |
-
"ename": "",
|
46 |
-
"evalue": "",
|
47 |
-
"output_type": "error",
|
48 |
-
"traceback": [
|
49 |
-
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
50 |
-
]
|
51 |
-
},
|
52 |
-
{
|
53 |
-
"ename": "",
|
54 |
-
"evalue": "",
|
55 |
-
"output_type": "error",
|
56 |
-
"traceback": [
|
57 |
-
"\u001b[1;31mCanceled future for execute_request message before replies were done"
|
58 |
-
]
|
59 |
-
}
|
60 |
-
],
|
61 |
-
"source": [
|
62 |
-
"import os\n",
|
63 |
-
"from isopro.orchestration_simulation import OrchestrationEnv\n",
|
64 |
-
"from isopro.orchestration_simulation.components import LLaMAAgent, AnalysisAgent, WritingAgent\n",
|
65 |
-
"from isopro.orchestration_simulation.evaluator import Evaluator\n",
|
66 |
-
"from dotenv import load_dotenv\n",
|
67 |
-
"\n",
|
68 |
-
"# Load environment variables from .env file\n",
|
69 |
-
"load_dotenv()\n",
|
70 |
-
"\n",
|
71 |
-
"# Access API keys from environment variables\n",
|
72 |
-
"openai_api_key = os.getenv(\"OPENAI_API_KEY\")"
|
73 |
-
]
|
74 |
-
},
|
75 |
-
{
|
76 |
-
"cell_type": "markdown",
|
77 |
-
"metadata": {},
|
78 |
-
"source": [
|
79 |
-
"## 3. Creating the Orchestration Environment\n",
|
80 |
-
"\n",
|
81 |
-
"Let's create our orchestration environment and add our agents to it."
|
82 |
-
]
|
83 |
-
},
|
84 |
-
{
|
85 |
-
"cell_type": "code",
|
86 |
-
"execution_count": null,
|
87 |
-
"metadata": {},
|
88 |
-
"outputs": [],
|
89 |
-
"source": [
|
90 |
-
"# Create the orchestration environment\n",
|
91 |
-
"env = OrchestrationEnv()\n",
|
92 |
-
"\n",
|
93 |
-
"# Add agents to the environment\n",
|
94 |
-
"env.add_component(LLaMAAgent(\"Research\", \"conduct thorough research on the impact of artificial intelligence on job markets in the next decade\"))\n",
|
95 |
-
"env.add_component(AnalysisAgent(\"Analysis\"))\n",
|
96 |
-
"env.add_component(WritingAgent(\"Writing\"))\n",
|
97 |
-
"\n",
|
98 |
-
"print(\"Orchestration environment created with agents added!\")"
|
99 |
-
]
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"cell_type": "markdown",
|
103 |
-
"metadata": {},
|
104 |
-
"source": [
|
105 |
-
"## 4. Defining the Task\n",
|
106 |
-
"\n",
|
107 |
-
"Now, let's define the task that our agents will work on."
|
108 |
-
]
|
109 |
-
},
|
110 |
-
{
|
111 |
-
"cell_type": "code",
|
112 |
-
"execution_count": null,
|
113 |
-
"metadata": {},
|
114 |
-
"outputs": [],
|
115 |
-
"source": [
|
116 |
-
"task = \"Prepare a comprehensive report on the impact of artificial intelligence on job markets in the next decade.\"\n",
|
117 |
-
"print(f\"Task defined: {task}\")"
|
118 |
-
]
|
119 |
-
},
|
120 |
-
{
|
121 |
-
"cell_type": "markdown",
|
122 |
-
"metadata": {},
|
123 |
-
"source": [
|
124 |
-
"## 5. Running Simulations in Different Modes\n",
|
125 |
-
"\n",
|
126 |
-
"We'll now run our simulation in different modes: parallel, sequence, and node."
|
127 |
-
]
|
128 |
-
},
|
129 |
-
{
|
130 |
-
"cell_type": "code",
|
131 |
-
"execution_count": null,
|
132 |
-
"metadata": {},
|
133 |
-
"outputs": [],
|
134 |
-
"source": [
|
135 |
-
"modes = ['parallel', 'sequence', 'node']\n",
|
136 |
-
"results = {}\n",
|
137 |
-
"\n",
|
138 |
-
"for mode in modes:\n",
|
139 |
-
" print(f\"\\nRunning simulation in {mode} mode...\")\n",
|
140 |
-
" result = env.run_simulation(mode=mode, input_data={'task': task, 'run_order': 'first'})\n",
|
141 |
-
" results[mode] = result\n",
|
142 |
-
" print(f\"Simulation in {mode} mode completed.\")\n",
|
143 |
-
"\n",
|
144 |
-
"print(\"\\nAll simulations completed!\")"
|
145 |
-
]
|
146 |
-
},
|
147 |
-
{
|
148 |
-
"cell_type": "markdown",
|
149 |
-
"metadata": {},
|
150 |
-
"source": [
|
151 |
-
"## 6. Evaluating the Results\n",
|
152 |
-
"\n",
|
153 |
-
"Now that we have our results, let's use the Evaluator to determine which mode performed best."
|
154 |
-
]
|
155 |
-
},
|
156 |
-
{
|
157 |
-
"cell_type": "code",
|
158 |
-
"execution_count": 1,
|
159 |
-
"metadata": {},
|
160 |
-
"outputs": [
|
161 |
-
{
|
162 |
-
"ename": "NameError",
|
163 |
-
"evalue": "name 'Evaluator' is not defined",
|
164 |
-
"output_type": "error",
|
165 |
-
"traceback": [
|
166 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
167 |
-
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
168 |
-
"\u001b[0;32m<ipython-input-1-a86bfe25b9d1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mevaluator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEvaluator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mbest_mode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"\\nEvaluation complete. The best execution mode for this task was: {best_mode}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
169 |
-
"\u001b[0;31mNameError\u001b[0m: name 'Evaluator' is not defined"
|
170 |
-
]
|
171 |
-
}
|
172 |
-
],
|
173 |
-
"source": [
|
174 |
-
"evaluator = Evaluator()\n",
|
175 |
-
"best_mode = evaluator.evaluate(results)\n",
|
176 |
-
"\n",
|
177 |
-
"print(f\"\\nEvaluation complete. The best execution mode for this task was: {best_mode}\")"
|
178 |
-
]
|
179 |
-
},
|
180 |
-
{
|
181 |
-
"cell_type": "markdown",
|
182 |
-
"metadata": {},
|
183 |
-
"source": [
|
184 |
-
"## 7. Examining the Results\n",
|
185 |
-
"\n",
|
186 |
-
"Let's take a closer look at the results from each mode."
|
187 |
-
]
|
188 |
-
},
|
189 |
-
{
|
190 |
-
"cell_type": "code",
|
191 |
-
"execution_count": null,
|
192 |
-
"metadata": {},
|
193 |
-
"outputs": [],
|
194 |
-
"source": [
|
195 |
-
"for mode, result in results.items():\n",
|
196 |
-
" print(f\"\\nResults for {mode} mode:\")\n",
|
197 |
-
" print(f\"Execution Time: {result.get('execution_time', 'N/A')} seconds\")\n",
|
198 |
-
" print(f\"Memory Usage: {result.get('memory_usage', 'N/A')} MB\")\n",
|
199 |
-
" print(f\"Output Sample: {result.get('output', 'N/A')[:200]}...\")"
|
200 |
-
]
|
201 |
-
},
|
202 |
-
{
|
203 |
-
"cell_type": "markdown",
|
204 |
-
"metadata": {},
|
205 |
-
"source": [
|
206 |
-
"## 8. Conclusion\n",
|
207 |
-
"\n",
|
208 |
-
"In this tutorial, we've learned how to:\n",
|
209 |
-
"1. Set up the isopro package\n",
|
210 |
-
"2. Create an orchestration environment and add agents\n",
|
211 |
-
"3. Run simulations in different modes\n",
|
212 |
-
"4. Use the Evaluator to determine the best execution mode\n",
|
213 |
-
"5. Examine the results of our simulations\n",
|
214 |
-
"\n",
|
215 |
-
"This demonstrates the power and flexibility of the isopro package for orchestrating AI agents and evaluating their performance in different execution modes."
|
216 |
-
]
|
217 |
-
}
|
218 |
-
],
|
219 |
-
"metadata": {
|
220 |
-
"kernelspec": {
|
221 |
-
"display_name": "smooth_env",
|
222 |
-
"language": "python",
|
223 |
-
"name": "python3"
|
224 |
-
},
|
225 |
-
"language_info": {
|
226 |
-
"codemirror_mode": {
|
227 |
-
"name": "ipython",
|
228 |
-
"version": 3
|
229 |
-
},
|
230 |
-
"file_extension": ".py",
|
231 |
-
"mimetype": "text/x-python",
|
232 |
-
"name": "python",
|
233 |
-
"nbconvert_exporter": "python",
|
234 |
-
"pygments_lexer": "ipython3",
|
235 |
-
"version": "3.9.18"
|
236 |
-
},
|
237 |
-
"vscode": {
|
238 |
-
"interpreter": {
|
239 |
-
"hash": "e35b4d35af899f01dc238e082b97509c22792197b4b3ae814b774a24a240ad24"
|
240 |
-
}
|
241 |
-
}
|
242 |
-
},
|
243 |
-
"nbformat": 4,
|
244 |
-
"nbformat_minor": 4
|
245 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
isopro/examples/run_cartpole_example.ipynb
DELETED
@@ -1,403 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# LLM-based CartPole Reinforcement Learning Agent\n",
|
8 |
-
"\n",
|
9 |
-
"This notebook demonstrates how to create and train a Reinforcement Learning agent that uses a Large Language Model (LLM) to make decisions in the CartPole environment.\n",
|
10 |
-
"\n",
|
11 |
-
"## Setup\n",
|
12 |
-
"\n",
|
13 |
-
"First, let's import the necessary libraries and set up our environment."
|
14 |
-
]
|
15 |
-
},
|
16 |
-
{
|
17 |
-
"cell_type": "code",
|
18 |
-
"execution_count": null,
|
19 |
-
"metadata": {},
|
20 |
-
"outputs": [
|
21 |
-
{
|
22 |
-
"name": "stdout",
|
23 |
-
"output_type": "stream",
|
24 |
-
"text": [
|
25 |
-
"Collecting isopro\n",
|
26 |
-
" Downloading isopro-0.1.2-py3-none-any.whl (60 kB)\n",
|
27 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.4/60.4 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
28 |
-
"\u001b[?25hCollecting tqdm\n",
|
29 |
-
" Using cached tqdm-4.66.5-py3-none-any.whl (78 kB)\n",
|
30 |
-
"Collecting gymnasium\n",
|
31 |
-
" Using cached gymnasium-0.29.1-py3-none-any.whl (953 kB)\n",
|
32 |
-
"Collecting transformers\n",
|
33 |
-
" Downloading transformers-4.45.0-py3-none-any.whl (9.9 MB)\n",
|
34 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.9/9.9 MB\u001b[0m \u001b[31m36.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
35 |
-
"\u001b[?25hCollecting seaborn\n",
|
36 |
-
" Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n",
|
37 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.9/294.9 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
38 |
-
"\u001b[?25hCollecting rouge\n",
|
39 |
-
" Using cached rouge-1.0.1-py3-none-any.whl (13 kB)\n",
|
40 |
-
"Collecting langchain-openai\n",
|
41 |
-
" Downloading langchain_openai-0.1.25-py3-none-any.whl (51 kB)\n",
|
42 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.5/51.5 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
43 |
-
"\u001b[?25hCollecting nltk\n",
|
44 |
-
" Using cached nltk-3.9.1-py3-none-any.whl (1.5 MB)\n",
|
45 |
-
"Collecting scikit-learn\n",
|
46 |
-
" Downloading scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl (10.1 MB)\n",
|
47 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.1/10.1 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n",
|
48 |
-
"\u001b[?25hCollecting matplotlib\n",
|
49 |
-
" Downloading matplotlib-3.7.5-cp38-cp38-macosx_10_12_x86_64.whl (7.4 MB)\n",
|
50 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m53.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
51 |
-
"\u001b[?25hCollecting openai\n",
|
52 |
-
" Downloading openai-1.48.0-py3-none-any.whl (376 kB)\n",
|
53 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m376.1/376.1 kB\u001b[0m \u001b[31m38.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
54 |
-
"\u001b[?25hCollecting stable-baselines3\n",
|
55 |
-
" Using cached stable_baselines3-2.3.2-py3-none-any.whl (182 kB)\n",
|
56 |
-
"Collecting torch\n",
|
57 |
-
" Using cached torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl (150.6 MB)\n",
|
58 |
-
"Collecting anthropic\n",
|
59 |
-
" Downloading anthropic-0.34.2-py3-none-any.whl (891 kB)\n",
|
60 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m891.9/891.9 kB\u001b[0m \u001b[31m46.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
61 |
-
"\u001b[?25hCollecting iso-adverse\n",
|
62 |
-
" Using cached iso_adverse-0.2.0-py3-none-any.whl (12 kB)\n",
|
63 |
-
"Requirement already satisfied: numpy in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from isopro) (1.24.2)\n",
|
64 |
-
"Collecting langchain\n",
|
65 |
-
" Downloading langchain-0.2.16-py3-none-any.whl (1.0 MB)\n",
|
66 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
67 |
-
"\u001b[?25hCollecting sentence-transformers\n",
|
68 |
-
" Downloading sentence_transformers-3.1.1-py3-none-any.whl (245 kB)\n",
|
69 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m245.3/245.3 kB\u001b[0m \u001b[31m23.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
70 |
-
"\u001b[?25hCollecting python-dotenv\n",
|
71 |
-
" Using cached python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n",
|
72 |
-
"Collecting distro<2,>=1.7.0\n",
|
73 |
-
" Using cached distro-1.9.0-py3-none-any.whl (20 kB)\n",
|
74 |
-
"Collecting tokenizers>=0.13.0\n",
|
75 |
-
" Downloading tokenizers-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl (2.6 MB)\n",
|
76 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m51.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
|
77 |
-
"\u001b[?25hRequirement already satisfied: anyio<5,>=3.5.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from anthropic->isopro) (3.6.2)\n",
|
78 |
-
"Collecting jiter<1,>=0.4.0\n",
|
79 |
-
" Using cached jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl (284 kB)\n",
|
80 |
-
"Collecting typing-extensions<5,>=4.7\n",
|
81 |
-
" Using cached typing_extensions-4.12.2-py3-none-any.whl (37 kB)\n",
|
82 |
-
"Requirement already satisfied: sniffio in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from anthropic->isopro) (1.3.0)\n",
|
83 |
-
"Collecting httpx<1,>=0.23.0\n",
|
84 |
-
" Downloading httpx-0.27.2-py3-none-any.whl (76 kB)\n",
|
85 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.4/76.4 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
86 |
-
"\u001b[?25hCollecting pydantic<3,>=1.9.0\n",
|
87 |
-
" Downloading pydantic-2.9.2-py3-none-any.whl (434 kB)\n",
|
88 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.9/434.9 kB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
89 |
-
"\u001b[?25hCollecting cloudpickle>=1.2.0\n",
|
90 |
-
" Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)\n",
|
91 |
-
"Collecting farama-notifications>=0.0.1\n",
|
92 |
-
" Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
|
93 |
-
"Requirement already satisfied: importlib-metadata>=4.8.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from gymnasium->isopro) (6.0.0)\n",
|
94 |
-
"Collecting filelock\n",
|
95 |
-
" Downloading filelock-3.16.1-py3-none-any.whl (16 kB)\n",
|
96 |
-
"Requirement already satisfied: jinja2 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from torch->isopro) (3.1.2)\n",
|
97 |
-
"Collecting sympy\n",
|
98 |
-
" Downloading sympy-1.13.3-py3-none-any.whl (6.2 MB)\n",
|
99 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m51.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
100 |
-
"\u001b[?25hCollecting fsspec\n",
|
101 |
-
" Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n",
|
102 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m22.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
103 |
-
"\u001b[?25hCollecting networkx\n",
|
104 |
-
" Using cached networkx-3.1-py3-none-any.whl (2.1 MB)\n",
|
105 |
-
"Requirement already satisfied: packaging>=20.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from transformers->isopro) (23.0)\n",
|
106 |
-
"Collecting safetensors>=0.4.1\n",
|
107 |
-
" Downloading safetensors-0.4.5-cp38-cp38-macosx_10_12_x86_64.whl (392 kB)\n",
|
108 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m392.9/392.9 kB\u001b[0m \u001b[31m21.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
109 |
-
"\u001b[?25hCollecting huggingface-hub<1.0,>=0.23.2\n",
|
110 |
-
" Downloading huggingface_hub-0.25.1-py3-none-any.whl (436 kB)\n",
|
111 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m436.4/436.4 kB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
112 |
-
"\u001b[?25hCollecting regex!=2019.12.17\n",
|
113 |
-
" Downloading regex-2024.9.11-cp38-cp38-macosx_10_9_x86_64.whl (287 kB)\n",
|
114 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m287.5/287.5 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
115 |
-
"\u001b[?25hRequirement already satisfied: requests in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from transformers->isopro) (2.28.2)\n",
|
116 |
-
"Requirement already satisfied: pyyaml>=5.1 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from transformers->isopro) (6.0)\n",
|
117 |
-
"Collecting langchain-core<0.3.0,>=0.2.38\n",
|
118 |
-
" Downloading langchain_core-0.2.41-py3-none-any.whl (397 kB)\n",
|
119 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m397.0/397.0 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
120 |
-
"\u001b[?25hCollecting SQLAlchemy<3,>=1.4\n",
|
121 |
-
" Downloading SQLAlchemy-2.0.35-cp38-cp38-macosx_10_9_x86_64.whl (2.1 MB)\n",
|
122 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m60.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
123 |
-
"\u001b[?25hCollecting async-timeout<5.0.0,>=4.0.0\n",
|
124 |
-
" Using cached async_timeout-4.0.3-py3-none-any.whl (5.7 kB)\n",
|
125 |
-
"Collecting aiohttp<4.0.0,>=3.8.3\n",
|
126 |
-
" Downloading aiohttp-3.10.6-cp38-cp38-macosx_10_9_x86_64.whl (401 kB)\n",
|
127 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m401.0/401.0 kB\u001b[0m \u001b[31m28.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
128 |
-
"\u001b[?25hCollecting tenacity!=8.4.0,<9.0.0,>=8.1.0\n",
|
129 |
-
" Using cached tenacity-8.5.0-py3-none-any.whl (28 kB)\n",
|
130 |
-
"Collecting langsmith<0.2.0,>=0.1.17\n",
|
131 |
-
" Downloading langsmith-0.1.128-py3-none-any.whl (292 kB)\n",
|
132 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m292.1/292.1 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
133 |
-
"\u001b[?25hCollecting langchain-text-splitters<0.3.0,>=0.2.0\n",
|
134 |
-
" Downloading langchain_text_splitters-0.2.4-py3-none-any.whl (25 kB)\n",
|
135 |
-
"Collecting tiktoken<1,>=0.7\n",
|
136 |
-
" Using cached tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl (961 kB)\n",
|
137 |
-
"Collecting pyparsing>=2.3.1\n",
|
138 |
-
" Using cached pyparsing-3.1.4-py3-none-any.whl (104 kB)\n",
|
139 |
-
"Requirement already satisfied: importlib-resources>=3.2.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from matplotlib->isopro) (5.10.2)\n",
|
140 |
-
"Collecting fonttools>=4.22.0\n",
|
141 |
-
" Downloading fonttools-4.54.1-cp38-cp38-macosx_10_9_universal2.whl (2.8 MB)\n",
|
142 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.8/2.8 MB\u001b[0m \u001b[31m58.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
|
143 |
-
"\u001b[?25hCollecting contourpy>=1.0.1\n",
|
144 |
-
" Downloading contourpy-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl (247 kB)\n",
|
145 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m247.0/247.0 kB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
146 |
-
"\u001b[?25hCollecting pillow>=6.2.0\n",
|
147 |
-
" Using cached pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl (3.5 MB)\n",
|
148 |
-
"Requirement already satisfied: python-dateutil>=2.7 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from matplotlib->isopro) (2.8.2)\n",
|
149 |
-
"Collecting kiwisolver>=1.0.1\n",
|
150 |
-
" Downloading kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl (65 kB)\n",
|
151 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.7/65.7 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
152 |
-
"\u001b[?25hCollecting cycler>=0.10\n",
|
153 |
-
" Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
|
154 |
-
"Collecting click\n",
|
155 |
-
" Using cached click-8.1.7-py3-none-any.whl (97 kB)\n",
|
156 |
-
"Collecting joblib\n",
|
157 |
-
" Using cached joblib-1.4.2-py3-none-any.whl (301 kB)\n",
|
158 |
-
"Requirement already satisfied: six in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from rouge->isopro) (1.16.0)\n",
|
159 |
-
"Collecting scipy>=1.5.0\n",
|
160 |
-
" Downloading scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl (35.0 MB)\n",
|
161 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.0/35.0 MB\u001b[0m \u001b[31m43.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
162 |
-
"\u001b[?25hCollecting threadpoolctl>=2.0.0\n",
|
163 |
-
" Using cached threadpoolctl-3.5.0-py3-none-any.whl (18 kB)\n",
|
164 |
-
"Collecting pandas>=1.2\n",
|
165 |
-
" Downloading pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl (11.7 MB)\n",
|
166 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.7/11.7 MB\u001b[0m \u001b[31m54.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
167 |
-
"\u001b[?25hCollecting multidict<7.0,>=4.5\n",
|
168 |
-
" Downloading multidict-6.1.0-cp38-cp38-macosx_10_9_x86_64.whl (29 kB)\n",
|
169 |
-
"Collecting yarl<2.0,>=1.12.0\n",
|
170 |
-
" Downloading yarl-1.12.1-cp38-cp38-macosx_10_9_x86_64.whl (116 kB)\n",
|
171 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.6/116.6 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
172 |
-
"\u001b[?25hCollecting aiosignal>=1.1.2\n",
|
173 |
-
" Using cached aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
|
174 |
-
"Requirement already satisfied: attrs>=17.3.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain->isopro) (22.2.0)\n",
|
175 |
-
"Collecting aiohappyeyeballs>=2.3.0\n",
|
176 |
-
" Using cached aiohappyeyeballs-2.4.0-py3-none-any.whl (12 kB)\n",
|
177 |
-
"Collecting frozenlist>=1.1.1\n",
|
178 |
-
" Using cached frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl (55 kB)\n",
|
179 |
-
"Requirement already satisfied: idna>=2.8 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from anyio<5,>=3.5.0->anthropic->isopro) (3.4)\n",
|
180 |
-
"Collecting httpcore==1.*\n",
|
181 |
-
" Using cached httpcore-1.0.5-py3-none-any.whl (77 kB)\n",
|
182 |
-
"Requirement already satisfied: certifi in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from httpx<1,>=0.23.0->anthropic->isopro) (2022.12.7)\n",
|
183 |
-
"Collecting h11<0.15,>=0.13\n",
|
184 |
-
" Using cached h11-0.14.0-py3-none-any.whl (58 kB)\n",
|
185 |
-
"Requirement already satisfied: zipp>=0.5 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from importlib-metadata>=4.8.0->gymnasium->isopro) (3.13.0)\n",
|
186 |
-
"Collecting packaging>=20.0\n",
|
187 |
-
" Using cached packaging-24.1-py3-none-any.whl (53 kB)\n",
|
188 |
-
"Collecting jsonpatch<2.0,>=1.33\n",
|
189 |
-
" Using cached jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n",
|
190 |
-
"Collecting orjson<4.0.0,>=3.9.14\n",
|
191 |
-
" Downloading orjson-3.10.7-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl (251 kB)\n",
|
192 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.1/251.1 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
193 |
-
"\u001b[?25hCollecting tzdata>=2022.1\n",
|
194 |
-
" Downloading tzdata-2024.2-py2.py3-none-any.whl (346 kB)\n",
|
195 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m346.6/346.6 kB\u001b[0m \u001b[31m27.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
196 |
-
"\u001b[?25hCollecting pytz>=2020.1\n",
|
197 |
-
" Using cached pytz-2024.2-py2.py3-none-any.whl (508 kB)\n",
|
198 |
-
"Collecting annotated-types>=0.6.0\n",
|
199 |
-
" Using cached annotated_types-0.7.0-py3-none-any.whl (13 kB)\n",
|
200 |
-
"Collecting pydantic-core==2.23.4\n",
|
201 |
-
" Downloading pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl (1.9 MB)\n",
|
202 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.9/1.9 MB\u001b[0m \u001b[31m39.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
|
203 |
-
"\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from requests->transformers->isopro) (3.0.1)\n",
|
204 |
-
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from requests->transformers->isopro) (1.26.14)\n",
|
205 |
-
"Collecting greenlet!=0.4.17\n",
|
206 |
-
" Downloading greenlet-3.1.1.tar.gz (186 kB)\n",
|
207 |
-
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m186.0/186.0 kB\u001b[0m \u001b[31m22.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
208 |
-
"\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n",
|
209 |
-
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
|
210 |
-
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
211 |
-
"\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /Users/jazmiahenry/toy_genai/env/lib/python3.8/site-packages (from jinja2->torch->isopro) (2.1.2)\n",
|
212 |
-
"Collecting mpmath<1.4,>=1.1.0\n",
|
213 |
-
" Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)\n",
|
214 |
-
"Collecting jsonpointer>=1.9\n",
|
215 |
-
" Using cached jsonpointer-3.0.0-py2.py3-none-any.whl (7.6 kB)\n",
|
216 |
-
"Building wheels for collected packages: greenlet\n",
|
217 |
-
" Building wheel for greenlet (pyproject.toml) ... \u001b[?25ldone\n",
|
218 |
-
"\u001b[?25h Created wheel for greenlet: filename=greenlet-3.1.1-cp38-cp38-macosx_10_9_x86_64.whl size=228270 sha256=61660bb35fa5416d14ab65bd473051c7c9f723b524837a5fd0c58d21fb4818bd\n",
|
219 |
-
" Stored in directory: /Users/jazmiahenry/Library/Caches/pip/wheels/ba/f9/e2/f8e444bf385c014fea09ef24bde9b85486657505f51396875f\n",
|
220 |
-
"Successfully built greenlet\n",
|
221 |
-
"Installing collected packages: pytz, mpmath, farama-notifications, tzdata, typing-extensions, tqdm, threadpoolctl, tenacity, sympy, scipy, safetensors, rouge, regex, python-dotenv, pyparsing, pillow, packaging, orjson, networkx, kiwisolver, jsonpointer, joblib, jiter, h11, greenlet, fsspec, frozenlist, fonttools, filelock, distro, cycler, contourpy, cloudpickle, click, async-timeout, aiohappyeyeballs, torch, tiktoken, SQLAlchemy, scikit-learn, pydantic-core, pandas, nltk, multidict, matplotlib, jsonpatch, huggingface-hub, httpcore, gymnasium, annotated-types, aiosignal, yarl, tokenizers, stable-baselines3, seaborn, pydantic, httpx, transformers, openai, langsmith, anthropic, aiohttp, sentence-transformers, langchain-core, iso-adverse, langchain-text-splitters, langchain-openai, langchain, isopro\n",
|
222 |
-
" Attempting uninstall: typing-extensions\n",
|
223 |
-
" Found existing installation: typing_extensions 4.4.0\n",
|
224 |
-
" Uninstalling typing_extensions-4.4.0:\n",
|
225 |
-
" Successfully uninstalled typing_extensions-4.4.0\n",
|
226 |
-
" Attempting uninstall: packaging\n",
|
227 |
-
" Found existing installation: packaging 23.0\n",
|
228 |
-
" Uninstalling packaging-23.0:\n",
|
229 |
-
" Successfully uninstalled packaging-23.0\n",
|
230 |
-
"Successfully installed SQLAlchemy-2.0.35 aiohappyeyeballs-2.4.0 aiohttp-3.10.6 aiosignal-1.3.1 annotated-types-0.7.0 anthropic-0.34.2 async-timeout-4.0.3 click-8.1.7 cloudpickle-3.0.0 contourpy-1.1.1 cycler-0.12.1 distro-1.9.0 farama-notifications-0.0.4 filelock-3.16.1 fonttools-4.54.1 frozenlist-1.4.1 fsspec-2024.9.0 greenlet-3.1.1 gymnasium-0.29.1 h11-0.14.0 httpcore-1.0.5 httpx-0.27.2 huggingface-hub-0.25.1 iso-adverse-0.2.0 isopro-0.1.2 jiter-0.5.0 joblib-1.4.2 jsonpatch-1.33 jsonpointer-3.0.0 kiwisolver-1.4.7 langchain-0.2.16 langchain-core-0.2.41 langchain-openai-0.1.25 langchain-text-splitters-0.2.4 langsmith-0.1.128 matplotlib-3.7.5 mpmath-1.3.0 multidict-6.1.0 networkx-3.1 nltk-3.9.1 openai-1.48.0 orjson-3.10.7 packaging-24.1 pandas-2.0.3 pillow-10.4.0 pydantic-2.9.2 pydantic-core-2.23.4 pyparsing-3.1.4 python-dotenv-1.0.1 pytz-2024.2 regex-2024.9.11 rouge-1.0.1 safetensors-0.4.5 scikit-learn-1.3.2 scipy-1.10.1 seaborn-0.13.2 sentence-transformers-3.1.1 stable-baselines3-2.3.2 sympy-1.13.3 tenacity-8.5.0 threadpoolctl-3.5.0 tiktoken-0.7.0 tokenizers-0.20.0 torch-2.2.2 tqdm-4.66.5 transformers-4.45.0 typing-extensions-4.12.2 tzdata-2024.2 yarl-1.12.1\n",
|
231 |
-
"\n",
|
232 |
-
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n",
|
233 |
-
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
234 |
-
]
|
235 |
-
}
|
236 |
-
],
|
237 |
-
"source": [
|
238 |
-
"!pip install isopro"
|
239 |
-
]
|
240 |
-
},
|
241 |
-
{
|
242 |
-
"cell_type": "code",
|
243 |
-
"execution_count": 1,
|
244 |
-
"metadata": {},
|
245 |
-
"outputs": [
|
246 |
-
{
|
247 |
-
"ename": "",
|
248 |
-
"evalue": "",
|
249 |
-
"output_type": "error",
|
250 |
-
"traceback": [
|
251 |
-
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
252 |
-
]
|
253 |
-
},
|
254 |
-
{
|
255 |
-
"ename": "",
|
256 |
-
"evalue": "",
|
257 |
-
"output_type": "error",
|
258 |
-
"traceback": [
|
259 |
-
"\u001b[1;31mCanceled future for execute_request message before replies were done"
|
260 |
-
]
|
261 |
-
}
|
262 |
-
],
|
263 |
-
"source": [
|
264 |
-
"import gymnasium as gym\n",
|
265 |
-
"from isopro.rl.rl_agent import RLAgent\n",
|
266 |
-
"from isopro.rl.rl_environment import LLMRLEnvironment\n",
|
267 |
-
"from stable_baselines3 import PPO\n",
|
268 |
-
"import numpy as np\n",
|
269 |
-
"import anthropic\n",
|
270 |
-
"import os\n",
|
271 |
-
"import logging\n",
|
272 |
-
"from typing import Optional, Dict, Any\n",
|
273 |
-
"from tqdm import tqdm\n",
|
274 |
-
"import json\n",
|
275 |
-
"from datetime import datetime\n",
|
276 |
-
"from llm_cartpole_wrapper import LLMCartPoleWrapper\n",
|
277 |
-
"\n",
|
278 |
-
"# Set up logging\n",
|
279 |
-
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
|
280 |
-
"logger = logging.getLogger(__name__)"
|
281 |
-
]
|
282 |
-
},
|
283 |
-
{
|
284 |
-
"cell_type": "markdown",
|
285 |
-
"metadata": {},
|
286 |
-
"source": [
|
287 |
-
"## Create and Train the RL Agent\n",
|
288 |
-
"\n",
|
289 |
-
"Now, let's create our RL agent and train it using the LLM-based CartPole environment."
|
290 |
-
]
|
291 |
-
},
|
292 |
-
{
|
293 |
-
"cell_type": "code",
|
294 |
-
"execution_count": null,
|
295 |
-
"metadata": {},
|
296 |
-
"outputs": [],
|
297 |
-
"source": [
|
298 |
-
"agent_prompt = \"\"\"You are an AI trained to play the CartPole game. \n",
|
299 |
-
"Your goal is to balance a pole on a moving cart for as long as possible. \n",
|
300 |
-
"You will receive observations about the cart's position, velocity, pole angle, and angular velocity. \n",
|
301 |
-
"Based on these, you should decide whether to move the cart left or right. \n",
|
302 |
-
"Respond with 'Move left' or 'Move right' for each decision.\"\"\"\n",
|
303 |
-
"\n",
|
304 |
-
"env = LLMCartPoleWrapper(agent_prompt)\n",
|
305 |
-
"model = PPO(\"MlpPolicy\", env, verbose=1)\n",
|
306 |
-
"\n",
|
307 |
-
"logger.info(\"Starting training\")\n",
|
308 |
-
"model.learn(total_timesteps=10000)\n",
|
309 |
-
"logger.info(\"Training completed\")"
|
310 |
-
]
|
311 |
-
},
|
312 |
-
{
|
313 |
-
"cell_type": "markdown",
|
314 |
-
"metadata": {},
|
315 |
-
"source": [
|
316 |
-
"## Test the Trained Agent\n",
|
317 |
-
"\n",
|
318 |
-
"Now that we've trained our agent, let's test it for 2 episodes and see how it performs."
|
319 |
-
]
|
320 |
-
},
|
321 |
-
{
|
322 |
-
"cell_type": "code",
|
323 |
-
"execution_count": null,
|
324 |
-
"metadata": {},
|
325 |
-
"outputs": [],
|
326 |
-
"source": [
|
327 |
-
"test_episodes = 2\n",
|
328 |
-
"results = []\n",
|
329 |
-
"\n",
|
330 |
-
"logger.info(\"Starting test episodes\")\n",
|
331 |
-
"for episode in tqdm(range(test_episodes), desc=\"Test Episodes\"):\n",
|
332 |
-
" obs, _ = env.reset()\n",
|
333 |
-
" done = False\n",
|
334 |
-
" total_reward = 0\n",
|
335 |
-
" episode_length = 0\n",
|
336 |
-
" while not done:\n",
|
337 |
-
" action, _ = model.predict(obs, deterministic=True)\n",
|
338 |
-
" obs, reward, terminated, truncated, _ = env.step(action)\n",
|
339 |
-
" total_reward += reward\n",
|
340 |
-
" episode_length += 1\n",
|
341 |
-
" done = terminated or truncated\n",
|
342 |
-
" \n",
|
343 |
-
" logger.info(f\"Episode {episode + 1} completed. Total reward: {total_reward}, Length: {episode_length}\")\n",
|
344 |
-
" results.append({\"episode\": episode + 1, \"total_reward\": total_reward, \"length\": episode_length})\n",
|
345 |
-
"\n",
|
346 |
-
"# Save results to file\n",
|
347 |
-
"timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
|
348 |
-
"output_file = os.path.join(output_folder, f\"cartpole_results_{timestamp}.json\")\n",
|
349 |
-
"with open(output_file, 'w') as f:\n",
|
350 |
-
" json.dump(results, f, indent=2)\n",
|
351 |
-
"logger.info(f\"Results saved to {output_file}\")\n",
|
352 |
-
"\n",
|
353 |
-
"# Print summary\n",
|
354 |
-
"average_reward = sum(r['total_reward'] for r in results) / len(results)\n",
|
355 |
-
"average_length = sum(r['length'] for r in results) / len(results)\n",
|
356 |
-
"logger.info(f\"Test completed. Average reward: {average_reward:.2f}, Average length: {average_length:.2f}\")"
|
357 |
-
]
|
358 |
-
},
|
359 |
-
{
|
360 |
-
"cell_type": "markdown",
|
361 |
-
"metadata": {},
|
362 |
-
"source": [
|
363 |
-
"## Conclusion\n",
|
364 |
-
"\n",
|
365 |
-
"In this notebook, we've demonstrated how to:\n",
|
366 |
-
"\n",
|
367 |
-
"1. Set up an LLM-based wrapper for the CartPole environment\n",
|
368 |
-
"2. Train a reinforcement learning agent using this environment\n",
|
369 |
-
"3. Test the trained agent and collect performance metrics\n",
|
370 |
-
"\n",
|
371 |
-
"This approach combines the decision-making capabilities of a large language model with the learning process of reinforcement learning, potentially leading to interesting and novel solutions to the CartPole problem.\n",
|
372 |
-
"\n",
|
373 |
-
"Feel free to experiment with different prompts, training parameters, or even different environments to see how this approach can be applied in various scenarios!"
|
374 |
-
]
|
375 |
-
}
|
376 |
-
],
|
377 |
-
"metadata": {
|
378 |
-
"kernelspec": {
|
379 |
-
"display_name": "Python 3",
|
380 |
-
"language": "python",
|
381 |
-
"name": "python3"
|
382 |
-
},
|
383 |
-
"language_info": {
|
384 |
-
"codemirror_mode": {
|
385 |
-
"name": "ipython",
|
386 |
-
"version": 3
|
387 |
-
},
|
388 |
-
"file_extension": ".py",
|
389 |
-
"mimetype": "text/x-python",
|
390 |
-
"name": "python",
|
391 |
-
"nbconvert_exporter": "python",
|
392 |
-
"pygments_lexer": "ipython3",
|
393 |
-
"version": "3.9.6"
|
394 |
-
},
|
395 |
-
"vscode": {
|
396 |
-
"interpreter": {
|
397 |
-
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
398 |
-
}
|
399 |
-
}
|
400 |
-
},
|
401 |
-
"nbformat": 4,
|
402 |
-
"nbformat_minor": 4
|
403 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
isopro/examples/workflow_example.ipynb
DELETED
@@ -1,316 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"attachments": {},
|
5 |
-
"cell_type": "markdown",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# Automating a Meme Generator with Workflow Simulation\n",
|
9 |
-
"This notebook demonstrates how to use isopro.workflow_simulation to automate a meme generation workflow. We'll train an agent to:\n",
|
10 |
-
"\n",
|
11 |
-
"1. Navigate a meme generator website\n",
|
12 |
-
"2. Upload images\n",
|
13 |
-
"3. Add captions\n",
|
14 |
-
"4. Generate and download memes\n",
|
15 |
-
"\n",
|
16 |
-
"And do it all automatically!"
|
17 |
-
]
|
18 |
-
},
|
19 |
-
{
|
20 |
-
"attachments": {},
|
21 |
-
"cell_type": "markdown",
|
22 |
-
"metadata": {},
|
23 |
-
"source": [
|
24 |
-
"## Setup\n",
|
25 |
-
"First, let's import our required libraries and set up our environment:"
|
26 |
-
]
|
27 |
-
},
|
28 |
-
{
|
29 |
-
"cell_type": "code",
|
30 |
-
"execution_count": null,
|
31 |
-
"metadata": {},
|
32 |
-
"outputs": [],
|
33 |
-
"source": [
|
34 |
-
"import os\n",
|
35 |
-
"from pathlib import Path\n",
|
36 |
-
"from isopro.workflow_simulation import (\n",
|
37 |
-
" WorkflowSimulator,\n",
|
38 |
-
" AgentConfig,\n",
|
39 |
-
" VisualizationConfig,\n",
|
40 |
-
" ValidationConfig\n",
|
41 |
-
")\n",
|
42 |
-
"import matplotlib.pyplot as plt\n",
|
43 |
-
"from IPython.display import Image, HTML"
|
44 |
-
]
|
45 |
-
},
|
46 |
-
{
|
47 |
-
"attachments": {},
|
48 |
-
"cell_type": "markdown",
|
49 |
-
"metadata": {},
|
50 |
-
"source": [
|
51 |
-
"## Configuration\n",
|
52 |
-
"Let's create a fun configuration for our meme generator automation:"
|
53 |
-
]
|
54 |
-
},
|
55 |
-
{
|
56 |
-
"cell_type": "code",
|
57 |
-
"execution_count": null,
|
58 |
-
"metadata": {},
|
59 |
-
"outputs": [],
|
60 |
-
"source": [
|
61 |
-
"# Create output directory for our memes\n",
|
62 |
-
"output_dir = Path(\"meme_generator.mp4\")\n",
|
63 |
-
"output_dir.mkdir(exist_ok=True)\n",
|
64 |
-
"\n",
|
65 |
-
"# Configure our agent with some fun parameters\n",
|
66 |
-
"agent_config = AgentConfig(\n",
|
67 |
-
" learning_rate=3e-4, # Not too fast, not too slow - just right for meme making\n",
|
68 |
-
" pretrain_epochs=10, # Give it some time to learn the art of memes\n",
|
69 |
-
" use_demonstration=True, # Learn from the meme masters\n",
|
70 |
-
" use_reasoning=True, # Think before you meme\n",
|
71 |
-
" reward_threshold=0.8 # High standards for our memes!\n",
|
72 |
-
")\n",
|
73 |
-
"\n",
|
74 |
-
"# Set up visualization so we can watch the magic happen\n",
|
75 |
-
"viz_config = VisualizationConfig(\n",
|
76 |
-
" show_ui_elements=True, # See what the agent sees\n",
|
77 |
-
" show_cursor=True, # Watch the cursor dance\n",
|
78 |
-
" show_actions=True, # Understand what's happening\n",
|
79 |
-
" save_frames=True, # Save the best moments\n",
|
80 |
-
" real_time_display=True # Watch it live!\n",
|
81 |
-
")\n",
|
82 |
-
"\n",
|
83 |
-
"# Define what makes a successful meme\n",
|
84 |
-
"validation_config = ValidationConfig.from_dict({\n",
|
85 |
-
" \"success_criteria\": [\n",
|
86 |
-
" \"image_uploaded\",\n",
|
87 |
-
" \"captions_added\",\n",
|
88 |
-
" \"meme_generated\",\n",
|
89 |
-
" \"meme_downloaded\"\n",
|
90 |
-
" ],\n",
|
91 |
-
" \"error_tolerance\": 0.1 # Some memes are meant to be a little off...\n",
|
92 |
-
"})"
|
93 |
-
]
|
94 |
-
},
|
95 |
-
{
|
96 |
-
"attachments": {},
|
97 |
-
"cell_type": "markdown",
|
98 |
-
"metadata": {},
|
99 |
-
"source": [
|
100 |
-
"## Recording a Demonstration\n",
|
101 |
-
"Before we can train our agent, we need to show it how to make memes. The simulator takes a pre-recorded workflow video, so here we point it at that recording and preview it:"
|
102 |
-
]
|
103 |
-
},
|
104 |
-
{
|
105 |
-
"cell_type": "code",
|
106 |
-
"execution_count": null,
|
107 |
-
"metadata": {},
|
108 |
-
"outputs": [],
|
109 |
-
"source": [
|
110 |
-
"# Initialize our simulator\n",
|
111 |
-
"simulator = WorkflowSimulator(\n",
|
112 |
-
" video_path=\"meme_tutorial.mp4\", # Your recorded workflow video\n",
|
113 |
-
" agent_config=agent_config,\n",
|
114 |
-
" viz_config=viz_config,\n",
|
115 |
-
" validation_config=validation_config,\n",
|
116 |
-
" output_dir=str(output_dir)\n",
|
117 |
-
")\n",
|
118 |
-
"\n",
|
119 |
-
"# Let's see what our demonstration video looks like\n",
|
120 |
-
"display(HTML(f\"\"\"\n",
|
121 |
-
"<video width=\"640\" height=\"480\" controls>\n",
|
122 |
-
" <source src=\"meme_tutorial.mp4\" type=\"video/mp4\">\n",
|
123 |
-
" Your browser does not support the video tag.\n",
|
124 |
-
"</video>\n",
|
125 |
-
"\"\"\"))"
|
126 |
-
]
|
127 |
-
},
|
128 |
-
{
|
129 |
-
"attachments": {},
|
130 |
-
"cell_type": "markdown",
|
131 |
-
"metadata": {},
|
132 |
-
"source": [
|
133 |
-
"## Training Our Meme Master\n",
|
134 |
-
"Now that we have our demonstration, let's train our agent to become a meme master:"
|
135 |
-
]
|
136 |
-
},
|
137 |
-
{
|
138 |
-
"cell_type": "code",
|
139 |
-
"execution_count": null,
|
140 |
-
"metadata": {},
|
141 |
-
"outputs": [],
|
142 |
-
"source": [
|
143 |
-
"# Time to learn!\n",
|
144 |
-
"print(\"🎓 Training our agent to become a meme master...\")\n",
|
145 |
-
"training_results = simulator.train_agents()\n",
|
146 |
-
"\n",
|
147 |
-
"# Show the learning progress\n",
|
148 |
-
"plt.figure(figsize=(10, 5))\n",
|
149 |
-
"plt.plot(training_results['episode_rewards'])\n",
|
150 |
-
"plt.title(\"Learning to Meme\")\n",
|
151 |
-
"plt.xlabel(\"Episode\")\n",
|
152 |
-
"plt.ylabel(\"Reward\")\n",
|
153 |
-
"plt.show()\n",
|
154 |
-
"\n",
|
155 |
-
"# Print some fun stats\n",
|
156 |
-
"print(\"\\n🎯 Training Results:\")\n",
|
157 |
-
"print(f\"Average Reward: {training_results['mean_reward']:.2f}\")\n",
|
158 |
-
"print(f\"Success Rate: {training_results['success_rate']*100:.1f}%\")\n",
|
159 |
-
"print(f\"Best Episode Reward: {max(training_results['episode_rewards']):.2f}\")"
|
160 |
-
]
|
161 |
-
},
|
162 |
-
{
|
163 |
-
"attachments": {},
|
164 |
-
"cell_type": "markdown",
|
165 |
-
"metadata": {},
|
166 |
-
"source": [
|
167 |
-
"## Unleashing the Meme Generator\n",
|
168 |
-
"Let's use our trained agent to generate some memes!"
|
169 |
-
]
|
170 |
-
},
|
171 |
-
{
|
172 |
-
"cell_type": "code",
|
173 |
-
"execution_count": null,
|
174 |
-
"metadata": {},
|
175 |
-
"outputs": [],
|
176 |
-
"source": [
|
177 |
-
"# Prepare some fun meme templates and captions\n",
|
178 |
-
"meme_tasks = [\n",
|
179 |
-
" {\n",
|
180 |
-
" \"template\": \"distracted_boyfriend.jpg\",\n",
|
181 |
-
" \"captions\": [\n",
|
182 |
-
" \"Python\",\n",
|
183 |
-
" \"Me\",\n",
|
184 |
-
" \"JavaScript\"\n",
|
185 |
-
" ]\n",
|
186 |
-
" },\n",
|
187 |
-
" {\n",
|
188 |
-
" \"template\": \"drake.jpg\",\n",
|
189 |
-
" \"captions\": [\n",
|
190 |
-
" \"Writing code without comments\",\n",
|
191 |
-
" \"Writing comments without code\"\n",
|
192 |
-
" ]\n",
|
193 |
-
" },\n",
|
194 |
-
" {\n",
|
195 |
-
" \"template\": \"expanding_brain.jpg\",\n",
|
196 |
-
" \"captions\": [\n",
|
197 |
-
" \"print('debug')\",\n",
|
198 |
-
" \"console.log('debug')\",\n",
|
199 |
-
" \"Using a debugger\",\n",
|
200 |
-
" \"Adding random print statements and hoping for the best\"\n",
|
201 |
-
" ]\n",
|
202 |
-
" }\n",
|
203 |
-
"]\n",
|
204 |
-
"\n",
|
205 |
-
"# Generate memes!\n",
|
206 |
-
"print(\"🎨 Generating memes...\")\n",
|
207 |
-
"for i, task in enumerate(meme_tasks):\n",
|
208 |
-
" print(f\"\\n✨ Creating meme {i+1}/{len(meme_tasks)}\")\n",
|
209 |
-
" \n",
|
210 |
-
" # Let our agent work its magic\n",
|
211 |
-
" observation = simulator.reset()\n",
|
212 |
-
" done = False\n",
|
213 |
-
" \n",
|
214 |
-
" while not done:\n",
|
215 |
-
" action, _ = simulator.predict(observation)\n",
|
216 |
-
" observation, reward, done, info = simulator.step(action)\n",
|
217 |
-
" \n",
|
218 |
-
" if info.get('meme_generated'):\n",
|
219 |
-
" print(\"🎉 Meme created successfully!\")\n",
|
220 |
-
" \n",
|
221 |
-
" # Display the generated meme\n",
|
222 |
-
" meme_path = output_dir / f\"meme_{i+1}.png\"\n",
|
223 |
-
" display(Image(filename=str(meme_path)))"
|
224 |
-
]
|
225 |
-
},
|
226 |
-
{
|
227 |
-
"attachments": {},
|
228 |
-
"cell_type": "markdown",
|
229 |
-
"metadata": {},
|
230 |
-
"source": [
|
231 |
-
"## Analyzing Our Meme Factory\n",
|
232 |
-
"Let's look at some fun statistics about our meme generation:"
|
233 |
-
]
|
234 |
-
},
|
235 |
-
{
|
236 |
-
"cell_type": "code",
|
237 |
-
"execution_count": null,
|
238 |
-
"metadata": {},
|
239 |
-
"outputs": [],
|
240 |
-
"source": [
|
241 |
-
"# Get evaluation results\n",
|
242 |
-
"eval_results = simulator.evaluate_agents()\n",
|
243 |
-
"\n",
|
244 |
-
"# Create a fun visualization of our meme factory stats\n",
|
245 |
-
"stats = {\n",
|
246 |
-
" \"Memes Generated\": len(meme_tasks),\n",
|
247 |
-
" \"Success Rate\": f\"{eval_results['success_rate']*100:.1f}%\",\n",
|
248 |
-
" \"Average Generation Time\": f\"{eval_results['mean_length']:.1f}s\",\n",
|
249 |
-
" \"Quality Score\": f\"{eval_results['mean_reward']:.2f}/1.0\"\n",
|
250 |
-
"}\n",
|
251 |
-
"\n",
|
252 |
-
"print(\"📊 Meme Factory Statistics:\")\n",
|
253 |
-
"for stat, value in stats.items():\n",
|
254 |
-
" print(f\"{stat}: {value}\")\n",
|
255 |
-
"\n",
|
256 |
-
"# Plot a fun pie chart of time spent on each step (illustrative, hard-coded percentages)\n",
|
257 |
-
"steps = [\n",
|
258 |
-
" \"Finding Templates\",\n",
|
259 |
-
" \"Adding Captions\",\n",
|
260 |
-
" \"Adjusting Layout\",\n",
|
261 |
-
" \"Generating Meme\",\n",
|
262 |
-
" \"Saving Masterpiece\"\n",
|
263 |
-
"]\n",
|
264 |
-
"times = [15, 30, 25, 20, 10] # Example percentages\n",
|
265 |
-
"\n",
|
266 |
-
"plt.figure(figsize=(10, 8))\n",
|
267 |
-
"plt.pie(times, labels=steps, autopct='%1.1f%%', \n",
|
268 |
-
" colors=['#FF9999', '#66B2FF', '#99FF99', '#FFCC99', '#FF99CC'])\n",
|
269 |
-
"plt.title(\"Time Spent Making Memes\")\n",
|
270 |
-
"plt.show()"
|
271 |
-
]
|
272 |
-
},
|
273 |
-
{
|
274 |
-
"attachments": {},
|
275 |
-
"cell_type": "markdown",
|
276 |
-
"metadata": {},
|
277 |
-
"source": [
|
278 |
-
"## Conclusion\n",
|
279 |
-
"Congratulations! You've successfully created an automated meme factory! 🎉\n",
|
280 |
-
"Some fun things we learned:\n",
|
281 |
-
"\n",
|
282 |
-
"- Our agent can learn to navigate UI elements and create memes\n",
|
283 |
-
"- The power of combining computer vision with reinforcement learning\n",
|
284 |
-
"- How to make our code more entertaining with emojis 😄\n",
|
285 |
-
"\n",
|
286 |
-
"## Next Steps\n",
|
287 |
-
"Want to make your meme factory even better? Here are some fun ideas:\n",
|
288 |
-
"\n",
|
289 |
-
"- Train on different meme templates\n",
|
290 |
-
"- Add text effects and styling\n",
|
291 |
-
"- Create a meme recommendation system\n",
|
292 |
-
"- Build a Discord bot using this automation\n",
|
293 |
-
"- Generate captions using Claude"
|
294 |
-
]
|
295 |
-
}
|
296 |
-
],
|
297 |
-
"metadata": {
|
298 |
-
"kernelspec": {
|
299 |
-
"display_name": "Python 3",
|
300 |
-
"language": "python",
|
301 |
-
"name": "python3"
|
302 |
-
},
|
303 |
-
"language_info": {
|
304 |
-
"name": "python",
|
305 |
-
"version": "3.12.7"
|
306 |
-
},
|
307 |
-
"orig_nbformat": 4,
|
308 |
-
"vscode": {
|
309 |
-
"interpreter": {
|
310 |
-
"hash": "7500c3e1c7c786e4ba1e4b4eb7588219b4e35d5153674f92eb3a82672b534f6e"
|
311 |
-
}
|
312 |
-
}
|
313 |
-
},
|
314 |
-
"nbformat": 4,
|
315 |
-
"nbformat_minor": 2
|
316 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|