AUXteam committed on
Commit
e5ab379
·
verified ·
1 Parent(s): 1596168

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -3,8 +3,9 @@ title: Tiny Factory
3
  emoji: 💻
4
  colorFrom: yellow
5
  colorTo: gray
6
- sdk: docker
7
- app_port: 7860
 
8
  pinned: false
9
  ---
10
 
 
3
  emoji: 💻
4
  colorFrom: yellow
5
  colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 6.3.0
8
+ app_file: app.py
9
  pinned: false
10
  ---
11
 
app.py CHANGED
@@ -2,13 +2,46 @@ import sys
2
  import os
3
  import gradio as gr
4
  import json
5
- import random
6
  from tinytroupe.factory import TinyPersonFactory
7
- from tinytroupe.simulation_manager import SimulationConfig
8
- from tinytroupe.content_generation import ContentVariantGenerator
9
- from tinytroupe.agent_types import Content
10
- from api.main import app, simulation_manager
11
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # --- CHANGE 1: The function now accepts an optional API key. ---
14
  def generate_personas(business_description, customer_profile, num_personas, blablador_api_key=None):
@@ -17,197 +50,391 @@ def generate_personas(business_description, customer_profile, num_personas, blab
17
  It prioritizes the API key passed as an argument, but falls back to the
18
  environment variable if none is provided (for UI use).
19
  """
 
 
20
  api_key_to_use = blablador_api_key or os.getenv("BLABLADOR_API_KEY")
21
 
22
  if not api_key_to_use:
23
  return {"error": "BLABLADOR_API_KEY not found. Please provide it in your API call or set it as a secret in the Space settings."}
24
 
 
25
  original_key = os.getenv("BLABLADOR_API_KEY")
26
 
27
  try:
 
 
28
  os.environ["BLABLADOR_API_KEY"] = api_key_to_use
 
29
  num_personas = int(num_personas)
 
30
  factory = TinyPersonFactory(
31
  context=business_description,
32
  sampling_space_description=customer_profile,
33
  total_population_size=num_personas
34
  )
 
35
  people = factory.generate_people(number_of_people=num_personas, parallelize=False)
36
  personas_data = [person._persona for person in people]
 
 
 
 
 
 
 
37
  return personas_data
 
38
  except Exception as e:
39
  return {"error": str(e)}
 
40
  finally:
 
 
 
41
  if original_key is None:
 
42
  if "BLABLADOR_API_KEY" in os.environ:
43
  del os.environ["BLABLADOR_API_KEY"]
44
  else:
 
45
  os.environ["BLABLADOR_API_KEY"] = original_key
46
 
47
- def create_simulation_ui(name, persona_count, network_type):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  try:
49
  config = SimulationConfig(name=name, persona_count=int(persona_count), network_type=network_type)
50
- sim = simulation_manager.create_simulation(config)
51
- return {"status": "Simulation created", "id": sim.id, "persona_count": len(sim.personas)}
 
 
 
 
 
52
  except Exception as e:
53
  return {"error": str(e)}
54
 
55
- def run_simulation_ui(simulation_id, content_text):
 
 
 
 
56
  try:
57
- if simulation_id not in simulation_manager.simulations:
58
- return {"error": "Simulation not found"}
59
- result = simulation_manager.run_simulation(simulation_id, content_text)
60
  return {
61
- "simulation_id": simulation_id,
62
  "total_reach": result.total_reach,
63
- "engagements": result.engagements
 
 
 
 
 
64
  }
65
  except Exception as e:
66
  return {"error": str(e)}
67
 
68
- def predict_engagement_ui(persona_name, content_text, simulation_id):
 
 
 
 
69
  try:
70
- if simulation_id not in simulation_manager.simulations:
71
- return {"error": "Simulation not found"}
72
- sim = simulation_manager.simulations[simulation_id]
73
- persona = next((p for p in sim.personas if p.name == persona_name), None)
74
- if not persona:
75
- return {"error": "Persona not found"}
 
 
 
 
 
 
 
 
76
 
77
- content_obj = Content(text=content_text, content_type="post", topics=[], length=len(content_text), tone="")
78
- prob = sim.world._simulate_engagement_decision(persona, content_obj, 0)
79
- return {
80
- "persona": persona_name,
81
- "will_engage": prob.engaged,
82
- "probability": prob.probability,
83
- "comment": prob.comment
84
  }
 
 
 
 
 
 
 
 
 
85
  except Exception as e:
86
  return {"error": str(e)}
87
 
88
- def generate_variants_ui(original_content, num_variants):
 
 
 
 
89
  try:
90
- variants = simulation_manager.variant_generator.generate_variants(original_content, int(num_variants))
91
- return [v.__dict__ for v in variants]
92
  except Exception as e:
93
  return {"error": str(e)}
94
 
95
- def get_metrics_ui(simulation_id):
 
 
 
 
96
  try:
97
- if simulation_id not in simulation_manager.simulations:
98
- return {"error": "Simulation not found"}
99
- sim = simulation_manager.simulations[simulation_id]
100
- from tinytroupe.network_analysis import NetworkAnalyzer
101
- metrics = NetworkAnalyzer.calculate_centrality_metrics(sim.network)
102
- influencers = NetworkAnalyzer.identify_key_influencers(sim.network)
103
- return {
104
- "density": NetworkAnalyzer.calculate_density(sim.network),
105
- "key_influencers": influencers,
106
- "centrality_metrics": metrics
107
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
  return {"error": str(e)}
110
 
111
 
112
  with gr.Blocks() as demo:
113
- gr.Markdown("<h1>Tiny Factory & Artificial Societies</h1>")
114
-
115
- with gr.Tabs():
116
- with gr.Tab("Persona Generation"):
117
- with gr.Row():
118
- with gr.Column():
119
- business_description_input = gr.Textbox(label="What is your business about?", lines=5)
120
- customer_profile_input = gr.Textbox(label="Information about your customer profile", lines=5)
121
- num_personas_gen_input = gr.Number(label="Number of personas to generate", value=1, minimum=1, step=1)
122
- blablador_api_key_input = gr.Textbox(label="Blablador API Key (for API client use)", visible=False)
123
- generate_personas_button = gr.Button("Generate Personas")
124
- with gr.Column():
125
- gen_personas_output = gr.JSON(label="Generated Personas")
126
 
127
- generate_personas_button.click(
128
- fn=generate_personas,
129
- inputs=[business_description_input, customer_profile_input, num_personas_gen_input, blablador_api_key_input],
130
- outputs=gen_personas_output,
131
- api_name="generate_personas"
132
  )
133
 
134
- with gr.Tab("Social Simulation"):
135
- with gr.Row():
136
- with gr.Column():
137
- sim_name_input = gr.Textbox(label="Simulation Name", value="My Social Simulation")
138
- sim_persona_count = gr.Number(label="Number of Personas", value=10, minimum=1, step=1)
139
- sim_network_type = gr.Dropdown(label="Network Type", choices=["scale_free", "professional"], value="scale_free")
140
- create_sim_button = gr.Button("Create Simulation")
141
-
142
- sim_content_input = gr.Textbox(label="Content to Test", lines=3)
143
- sim_id_run_input = gr.Textbox(label="Simulation ID (to run)")
144
- run_sim_button = gr.Button("Run Content Spread Simulation")
145
- with gr.Column():
146
- sim_output = gr.JSON(label="Simulation Status/Results")
147
-
148
- create_sim_button.click(
149
- fn=create_simulation_ui,
150
- inputs=[sim_name_input, sim_persona_count, sim_network_type],
151
- outputs=sim_output,
152
- api_name="create_simulation"
153
- )
154
- run_sim_button.click(
155
- fn=run_simulation_ui,
156
- inputs=[sim_id_run_input, sim_content_input],
157
- outputs=sim_output,
158
- api_name="run_simulation"
159
- )
160
 
161
- with gr.Tab("Engagement Prediction"):
162
- with gr.Row():
163
- with gr.Column():
164
- pred_persona_name = gr.Textbox(label="Persona Name")
165
- pred_content = gr.Textbox(label="Content Text", lines=3)
166
- pred_sim_id = gr.Textbox(label="Simulation ID")
167
- predict_button = gr.Button("Predict Engagement")
168
- with gr.Column():
169
- pred_output = gr.JSON(label="Prediction Result")
170
-
171
- predict_button.click(
172
- fn=predict_engagement_ui,
173
- inputs=[pred_persona_name, pred_content, pred_sim_id],
174
- outputs=pred_output,
175
- api_name="predict_engagement"
176
- )
177
 
178
- with gr.Tab("Content Engine"):
179
- with gr.Row():
180
- with gr.Column():
181
- cont_original = gr.Textbox(label="Original Content", lines=5)
182
- cont_num_variants = gr.Number(label="Number of Variants", value=5, minimum=1)
183
- generate_variants_button = gr.Button("Generate Variants")
184
- with gr.Column():
185
- cont_output = gr.JSON(label="Content Variants")
186
-
187
- generate_variants_button.click(
188
- fn=generate_variants_ui,
189
- inputs=[cont_original, cont_num_variants],
190
- outputs=cont_output,
191
- api_name="generate_content_variants"
192
- )
193
 
194
- with gr.Tab("Network Analytics"):
195
- with gr.Row():
196
- with gr.Column():
197
- metrics_sim_id = gr.Textbox(label="Simulation ID")
198
- get_metrics_button = gr.Button("Get Network Metrics")
199
- with gr.Column():
200
- metrics_output = gr.JSON(label="Network Analytics")
201
-
202
- get_metrics_button.click(
203
- fn=get_metrics_ui,
204
- inputs=[metrics_sim_id],
205
- outputs=metrics_output,
206
- api_name="get_network_metrics"
207
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- # Mount Gradio app to FastAPI app imported from api.main
210
- app = gr.mount_gradio_app(app, demo, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  if __name__ == "__main__":
213
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import os
3
  import gradio as gr
4
  import json
 
5
  from tinytroupe.factory import TinyPersonFactory
6
+ from tinytroupe.utils.semantics import select_best_persona
7
+ from tinytroupe.simulation_manager import SimulationManager, SimulationConfig
8
+ from tinytroupe.agent.social_types import Content
9
+ from huggingface_hub import hf_hub_download, upload_file
10
+
11
+ HF_TOKEN = os.getenv("HF_TOKEN") # Ensure this is set in Space secrets
12
+ REPO_ID = "harvesthealth/tiny_factory"
13
+ PERSONA_BASE_FILE = "persona_base.json"
14
+
15
+ simulation_manager = SimulationManager()
16
+
17
+ def load_persona_base():
18
+ if not HF_TOKEN:
19
+ print("HF_TOKEN not found, persistence disabled.")
20
+ return []
21
+ try:
22
+ path = hf_hub_download(repo_id=REPO_ID, filename=PERSONA_BASE_FILE, repo_type="space", token=HF_TOKEN)
23
+ with open(path, 'r', encoding='utf-8') as f:
24
+ return json.load(f)
25
+ except Exception as e:
26
+ print(f"Error loading persona base: {e}")
27
+ return []
28
+
29
+ def save_persona_base(personas):
30
+ if not HF_TOKEN:
31
+ print("HF_TOKEN not found, skipping upload.")
32
+ return
33
+ with open(PERSONA_BASE_FILE, 'w', encoding='utf-8') as f:
34
+ json.dump(personas, f, indent=4)
35
+ try:
36
+ upload_file(
37
+ path_or_fileobj=PERSONA_BASE_FILE,
38
+ path_in_repo=PERSONA_BASE_FILE,
39
+ repo_id=REPO_ID,
40
+ repo_type="space",
41
+ token=HF_TOKEN
42
+ )
43
+ except Exception as e:
44
+ print(f"Error saving persona base to Hub: {e}")
45
 
46
  # --- CHANGE 1: The function now accepts an optional API key. ---
47
  def generate_personas(business_description, customer_profile, num_personas, blablador_api_key=None):
 
50
  It prioritizes the API key passed as an argument, but falls back to the
51
  environment variable if none is provided (for UI use).
52
  """
53
+ # --- CHANGE 2: Logic to determine which key to use. ---
54
+ # Use the key from the API call if provided, otherwise get it from the Space secrets.
55
  api_key_to_use = blablador_api_key or os.getenv("BLABLADOR_API_KEY")
56
 
57
  if not api_key_to_use:
58
  return {"error": "BLABLADOR_API_KEY not found. Please provide it in your API call or set it as a secret in the Space settings."}
59
 
60
+ # Store the original state of the environment variable, if it exists
61
  original_key = os.getenv("BLABLADOR_API_KEY")
62
 
63
  try:
64
+ # --- CHANGE 3: Securely set the correct environment variable for this request. ---
65
+ # The underlying tinytroupe library will look for this variable.
66
  os.environ["BLABLADOR_API_KEY"] = api_key_to_use
67
+
68
  num_personas = int(num_personas)
69
+
70
  factory = TinyPersonFactory(
71
  context=business_description,
72
  sampling_space_description=customer_profile,
73
  total_population_size=num_personas
74
  )
75
+
76
  people = factory.generate_people(number_of_people=num_personas, parallelize=False)
77
  personas_data = [person._persona for person in people]
78
+
79
+ # --- NEW: Update the Tresor ---
80
+ current_base = load_persona_base()
81
+ current_base.extend(personas_data)
82
+ save_persona_base(current_base)
83
+ # ------------------------------
84
+
85
  return personas_data
86
+
87
  except Exception as e:
88
  return {"error": str(e)}
89
+
90
  finally:
91
+ # --- CHANGE 4: A robust cleanup using a 'finally' block. ---
92
+ # This ensures the environment is always restored to its original state,
93
+ # whether the function succeeds or fails.
94
  if original_key is None:
95
+ # If the variable didn't exist originally, remove it.
96
  if "BLABLADOR_API_KEY" in os.environ:
97
  del os.environ["BLABLADOR_API_KEY"]
98
  else:
99
+ # If it existed, restore its original value.
100
  os.environ["BLABLADOR_API_KEY"] = original_key
101
 
102
+
103
+ def find_best_persona(criteria):
104
+ """
105
+ Loads the persona base and finds the best matching persona based on criteria.
106
+ """
107
+ personas = load_persona_base()
108
+ if not personas:
109
+ return {"error": "Persona base is empty. Generate some personas first!"}
110
+
111
+ try:
112
+ # select_best_persona uses LLM to find the best index
113
+ idx = select_best_persona(criteria=criteria, personas=personas)
114
+
115
+ try:
116
+ idx = int(idx)
117
+ except (ValueError, TypeError):
118
+ return {"error": f"LLM returned an invalid index: {idx}"}
119
+
120
+ if idx >= 0 and idx < len(personas):
121
+ return personas[idx]
122
+ else:
123
+ return {"error": f"No matching persona found for criteria: {criteria}"}
124
+ except Exception as e:
125
+ return {"error": f"Error during persona matching: {str(e)}"}
126
+
127
+
128
+ def generate_social_network_api(name, persona_count, network_type, focus_group_name=None):
129
+ """
130
+ Gradio API endpoint for generating a social network.
131
+ """
132
  try:
133
  config = SimulationConfig(name=name, persona_count=int(persona_count), network_type=network_type)
134
+ simulation = simulation_manager.create_simulation(config, focus_group_name=focus_group_name)
135
+ return {
136
+ "simulation_id": simulation.id,
137
+ "name": simulation.config.name,
138
+ "persona_count": len(simulation.personas),
139
+ "network_metrics": simulation.network.get_metrics()
140
+ }
141
  except Exception as e:
142
  return {"error": str(e)}
143
 
144
+
145
+ def predict_engagement_api(simulation_id, content_text, format="text"):
146
+ """
147
+ Gradio API endpoint for predicting engagement.
148
+ """
149
  try:
150
+ content = Content(text=content_text, format=format)
151
+ result = simulation_manager.run_simulation(simulation_id, content)
 
152
  return {
 
153
  "total_reach": result.total_reach,
154
+ "expected_likes": result.expected_likes,
155
+ "expected_comments": result.expected_comments,
156
+ "expected_shares": result.expected_shares,
157
+ "execution_time": result.execution_time,
158
+ "avg_sentiment": result.avg_sentiment,
159
+ "feedback_summary": result.feedback_summary
160
  }
161
  except Exception as e:
162
  return {"error": str(e)}
163
 
164
+
165
+ def start_simulation_async_api(simulation_id, content_text, format="text"):
166
+ """
167
+ Starts a simulation in the background.
168
+ """
169
  try:
170
+ content = Content(text=content_text, format=format)
171
+ simulation_manager.run_simulation(simulation_id, content, background=True)
172
+ return {"status": "started", "simulation_id": simulation_id}
173
+ except Exception as e:
174
+ return {"error": str(e)}
175
+
176
+
177
+ def get_simulation_status_api(simulation_id):
178
+ """
179
+ Checks the status and progress of a simulation.
180
+ """
181
+ try:
182
+ sim = simulation_manager.get_simulation(simulation_id)
183
+ if not sim: return {"error": "Simulation not found"}
184
 
185
+ status_data = {
186
+ "status": sim.status,
187
+ "progress": sim.progress
 
 
 
 
188
  }
189
+
190
+ if sim.status == "completed" and sim.last_result:
191
+ status_data["result"] = {
192
+ "total_reach": sim.last_result.total_reach,
193
+ "expected_likes": sim.last_result.expected_likes,
194
+ "avg_sentiment": sim.last_result.avg_sentiment
195
+ }
196
+
197
+ return status_data
198
  except Exception as e:
199
  return {"error": str(e)}
200
 
201
+
202
+ def send_chat_message_api(simulation_id, sender, message):
203
+ """
204
+ Sends a message to the simulation chat.
205
+ """
206
  try:
207
+ return simulation_manager.send_chat_message(simulation_id, sender, message)
 
208
  except Exception as e:
209
  return {"error": str(e)}
210
 
211
+
212
+ def get_chat_history_api(simulation_id):
213
+ """
214
+ Gets the chat history for a simulation.
215
+ """
216
  try:
217
+ return simulation_manager.get_chat_history(simulation_id)
218
+ except Exception as e:
219
+ return {"error": str(e)}
220
+
221
+
222
+ def generate_variants_api(content_text, num_variants):
223
+ """
224
+ Gradio API endpoint for generating content variants.
225
+ """
226
+ try:
227
+ variants = simulation_manager.variant_generator.generate_variants(content_text, num_variants=int(num_variants))
228
+ return [{"text": v.text, "strategy": v.strategy} for v in variants]
229
+ except Exception as e:
230
+ return {"error": str(e)}
231
+
232
+
233
+ def list_simulations_api():
234
+ """
235
+ Gradio API endpoint for listing simulations.
236
+ """
237
+ try:
238
+ return simulation_manager.list_simulations()
239
+ except Exception as e:
240
+ return {"error": str(e)}
241
+
242
+
243
+ def list_personas_api(simulation_id):
244
+ """
245
+ Gradio API endpoint for listing personas in a simulation.
246
+ """
247
+ try:
248
+ return simulation_manager.list_personas(simulation_id)
249
+ except Exception as e:
250
+ return {"error": str(e)}
251
+
252
+
253
+ def get_persona_api(simulation_id, persona_name):
254
+ """
255
+ Gradio API endpoint for getting persona details.
256
+ """
257
+ try:
258
+ return simulation_manager.get_persona(simulation_id, persona_name)
259
+ except Exception as e:
260
+ return {"error": str(e)}
261
+
262
+
263
+ def delete_simulation_api(simulation_id):
264
+ """
265
+ Gradio API endpoint for deleting a simulation.
266
+ """
267
+ try:
268
+ success = simulation_manager.delete_simulation(simulation_id)
269
+ return {"success": success}
270
+ except Exception as e:
271
+ return {"error": str(e)}
272
+
273
+
274
+ def export_simulation_api(simulation_id):
275
+ """
276
+ Gradio API endpoint for exporting a simulation.
277
+ """
278
+ try:
279
+ return simulation_manager.export_simulation(simulation_id)
280
+ except Exception as e:
281
+ return {"error": str(e)}
282
+
283
+
284
+ def list_focus_groups_api():
285
+ """
286
+ Gradio API endpoint for listing focus groups.
287
+ """
288
+ try:
289
+ return simulation_manager.list_focus_groups()
290
+ except Exception as e:
291
+ return {"error": str(e)}
292
+
293
+
294
+ def save_focus_group_api(name, simulation_id):
295
+ """
296
+ Gradio API endpoint for saving a focus group from a simulation.
297
+ """
298
+ try:
299
+ sim = simulation_manager.get_simulation(simulation_id)
300
+ if not sim: return {"error": "Simulation not found"}
301
+ simulation_manager.save_focus_group(name, sim.personas)
302
+ return {"status": "success", "name": name}
303
  except Exception as e:
304
  return {"error": str(e)}
305
 
306
 
307
  with gr.Blocks() as demo:
308
+ gr.Markdown("<h1>Tiny Persona Generator</h1>")
309
+ with gr.Row():
310
+ with gr.Column():
311
+ business_description_input = gr.Textbox(label="What is your business about?", lines=5)
312
+ customer_profile_input = gr.Textbox(label="Information about your customer profile", lines=5)
313
+ num_personas_input = gr.Number(label="Number of personas to generate", value=1, minimum=1, step=1)
 
 
 
 
 
 
 
314
 
315
+ # --- CHANGE 5: The API key input is now INVISIBLE. ---
316
+ # It still exists, so the API endpoint is created, but it's hidden from UI users.
317
+ blablador_api_key_input = gr.Textbox(
318
+ label="Blablador API Key (for API client use)",
319
+ visible=False
320
  )
321
 
322
+ generate_button = gr.Button("Generate Personas")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
+ gr.Markdown("---")
325
+ gr.Markdown("<h3>Search Tresor</h3>")
326
+ criteria_input = gr.Textbox(label="Criteria to find best matching persona", lines=2)
327
+ find_button = gr.Button("Find Best Persona in Tresor")
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ with gr.Column():
330
+ output_json = gr.JSON(label="Output (Generated or Matched Persona)")
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ generate_button.click(
333
+ fn=generate_personas,
334
+ # --- CHANGE 6: Pass the invisible textbox to the function. ---
335
+ inputs=[business_description_input, customer_profile_input, num_personas_input, blablador_api_key_input],
336
+ outputs=output_json,
337
+ api_name="generate_personas"
338
+ )
339
+
340
+ find_button.click(
341
+ fn=find_best_persona,
342
+ inputs=[criteria_input],
343
+ outputs=output_json,
344
+ api_name="find_best_persona"
345
+ )
346
+
347
+ # Invisible components to expose API endpoints
348
+ # These won't be seen by regular UI users but will be available via /api
349
+ with gr.Tab("Social Network API", visible=False):
350
+ api_net_name = gr.Textbox(label="Network Name")
351
+ api_net_count = gr.Number(label="Persona Count", value=10)
352
+ api_net_type = gr.Dropdown(choices=["scale_free", "small_world"], label="Network Type")
353
+ api_net_focus = gr.Textbox(label="Focus Group Name (optional)")
354
+ api_net_btn = gr.Button("Generate Network")
355
+ api_net_out = gr.JSON()
356
+ api_net_btn.click(generate_social_network_api, inputs=[api_net_name, api_net_count, api_net_type, api_net_focus], outputs=api_net_out, api_name="generate_social_network")
357
+
358
+ with gr.Tab("Engagement Prediction API", visible=False):
359
+ api_pred_sim_id = gr.Textbox(label="Simulation ID")
360
+ api_pred_content = gr.Textbox(label="Content Text")
361
+ api_pred_format = gr.Textbox(label="Format", value="text")
362
+ api_pred_btn = gr.Button("Predict Engagement")
363
+ api_pred_out = gr.JSON()
364
+ api_pred_btn.click(predict_engagement_api, inputs=[api_pred_sim_id, api_pred_content, api_pred_format], outputs=api_pred_out, api_name="predict_engagement")
365
+
366
+ with gr.Tab("Async Simulation API", visible=False):
367
+ api_async_sim_id = gr.Textbox(label="Simulation ID")
368
+ api_async_content = gr.Textbox(label="Content Text")
369
+ api_async_format = gr.Textbox(label="Format", value="text")
370
+ api_async_btn = gr.Button("Start Simulation")
371
+ api_async_out = gr.JSON()
372
+ api_async_btn.click(start_simulation_async_api, inputs=[api_async_sim_id, api_async_content, api_async_format], outputs=api_async_out, api_name="start_simulation_async")
373
+
374
+ api_status_id = gr.Textbox(label="Simulation ID")
375
+ api_status_btn = gr.Button("Check Status")
376
+ api_status_out = gr.JSON()
377
+ api_status_btn.click(get_simulation_status_api, inputs=[api_status_id], outputs=api_status_out, api_name="get_simulation_status")
378
+
379
+ with gr.Tab("Chat API", visible=False):
380
+ api_chat_sim_id = gr.Textbox(label="Simulation ID")
381
+ api_chat_sender = gr.Textbox(label="Sender", value="User")
382
+ api_chat_msg = gr.Textbox(label="Message")
383
+ api_chat_send_btn = gr.Button("Send Message")
384
+ api_chat_send_out = gr.JSON()
385
+ api_chat_send_btn.click(send_chat_message_api, inputs=[api_chat_sim_id, api_chat_sender, api_chat_msg], outputs=api_chat_send_out, api_name="send_chat_message")
386
+
387
+ api_chat_hist_btn = gr.Button("Get History")
388
+ api_chat_hist_out = gr.JSON()
389
+ api_chat_hist_btn.click(get_chat_history_api, inputs=[api_chat_sim_id], outputs=api_chat_hist_out, api_name="get_chat_history")
390
+
391
+ with gr.Tab("Content Variants API", visible=False):
392
+ api_var_content = gr.Textbox(label="Original Content")
393
+ api_var_count = gr.Number(label="Number of Variants", value=5)
394
+ api_var_btn = gr.Button("Generate Variants")
395
+ api_var_out = gr.JSON()
396
+ api_var_btn.click(generate_variants_api, inputs=[api_var_content, api_var_count], outputs=api_var_out, api_name="generate_variants")
397
+
398
+ with gr.Tab("List Simulations API", visible=False):
399
+ api_list_sim_btn = gr.Button("List Simulations")
400
+ api_list_sim_out = gr.JSON()
401
+ api_list_sim_btn.click(list_simulations_api, outputs=api_list_sim_out, api_name="list_simulations")
402
+
403
+ with gr.Tab("List Personas API", visible=False):
404
+ api_list_per_sim_id = gr.Textbox(label="Simulation ID")
405
+ api_list_per_btn = gr.Button("List Personas")
406
+ api_list_per_out = gr.JSON()
407
+ api_list_per_btn.click(list_personas_api, inputs=[api_list_per_sim_id], outputs=api_list_per_out, api_name="list_personas")
408
 
409
+ with gr.Tab("Get Persona API", visible=False):
410
+ api_get_per_sim_id = gr.Textbox(label="Simulation ID")
411
+ api_get_per_name = gr.Textbox(label="Persona Name")
412
+ api_get_per_btn = gr.Button("Get Persona")
413
+ api_get_per_out = gr.JSON()
414
+ api_get_per_btn.click(get_persona_api, inputs=[api_get_per_sim_id, api_get_per_name], outputs=api_get_per_out, api_name="get_persona")
415
+
416
+ with gr.Tab("Delete Simulation API", visible=False):
417
+ api_del_sim_id = gr.Textbox(label="Simulation ID")
418
+ api_del_btn = gr.Button("Delete Simulation")
419
+ api_del_out = gr.JSON()
420
+ api_del_btn.click(delete_simulation_api, inputs=[api_del_sim_id], outputs=api_del_out, api_name="delete_simulation")
421
+
422
+ with gr.Tab("Export Simulation API", visible=False):
423
+ api_exp_sim_id = gr.Textbox(label="Simulation ID")
424
+ api_exp_btn = gr.Button("Export Simulation")
425
+ api_exp_out = gr.JSON()
426
+ api_exp_btn.click(export_simulation_api, inputs=[api_exp_sim_id], outputs=api_exp_out, api_name="export_simulation")
427
+
428
+ with gr.Tab("Focus Group API", visible=False):
429
+ api_list_fg_btn = gr.Button("List Focus Groups")
430
+ api_list_fg_out = gr.JSON()
431
+ api_list_fg_btn.click(list_focus_groups_api, outputs=api_list_fg_out, api_name="list_focus_groups")
432
+
433
+ api_save_fg_name = gr.Textbox(label="Focus Group Name")
434
+ api_save_fg_sim_id = gr.Textbox(label="Simulation ID")
435
+ api_save_fg_btn = gr.Button("Save Focus Group")
436
+ api_save_fg_out = gr.JSON()
437
+ api_save_fg_btn.click(save_focus_group_api, inputs=[api_save_fg_name, api_save_fg_sim_id], outputs=api_save_fg_out, api_name="save_focus_group")
438
 
439
  if __name__ == "__main__":
440
+ demo.queue().launch()
config.ini CHANGED
@@ -1,7 +1,12 @@
1
  [OpenAI]
2
  API_TYPE=helmholtz-blablador
3
- MODEL=alias-large
4
- REASONING_MODEL=alias-large
 
 
5
  TOP_P=1.0
6
- MAX_ATTEMPTS=5
7
- WAITING_TIME=20
 
 
 
 
1
  [OpenAI]
2
  API_TYPE=helmholtz-blablador
3
+ MODEL=alias-fast
4
+ REASONING_MODEL=alias-fast
5
+ FALLBACK_MODEL_LARGE=alias-large
6
+ FALLBACK_MODEL_HUGE=alias-huge
7
  TOP_P=1.0
8
+ MAX_ATTEMPTS=999
9
+ WAITING_TIME=35
10
+
11
+ [Logging]
12
+ LOGLEVEL=DEBUG
requirements.txt CHANGED
@@ -22,9 +22,3 @@ textdistance
22
  scipy
23
  transformers==4.38.2
24
  huggingface-hub==0.22.2
25
- fastapi
26
- uvicorn
27
- numpy
28
- scipy
29
- scikit-learn
30
- networkx
 
22
  scipy
23
  transformers==4.38.2
24
  huggingface-hub==0.22.2
 
 
 
 
 
 
tinytroupe/agent/agent_traits.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List, Optional
2
+ import json
3
+ from dataclasses import dataclass, field
4
+ import tinytroupe.openai_utils as openai_utils
5
+ from tinytroupe.agent.social_types import Content
6
+
7
+ @dataclass
8
+ class TraitProfile:
9
+ openness: float = 0.5
10
+ conscientiousness: float = 0.5
11
+ extraversion: float = 0.5
12
+ agreeableness: float = 0.5
13
+ controversiality_tolerance: float = 0.5
14
+ information_seeking_behavior: float = 0.5
15
+ visual_content_preference: float = 0.5
16
+
17
+ class TraitBasedBehaviorModel:
18
+ def __init__(self, model: str = "gpt-4"):
19
+ self.model = model
20
+
21
+ def compute_action_probability(self, persona: Any, action_type: str, content: Optional[Content] = None) -> float:
22
+ """
23
+ Compute probability of an action based on persona traits.
24
+ """
25
+ traits = persona.behavioral_traits
26
+ if not traits:
27
+ return 0.5
28
+
29
+ base_prob = 0.5
30
+
31
+ if action_type == "engage":
32
+ # Example logic
33
+ if content and content.format == "video":
34
+ base_prob = self.apply_trait_modifiers(base_prob, {"visual_content_preference": traits.get("visual_content_preference", 0.5)})
35
+
36
+ base_prob = self.apply_trait_modifiers(base_prob, {"openness": traits.get("openness", 0.5)})
37
+
38
+ return base_prob
39
+
40
+ def apply_trait_modifiers(self, base_probability: float, traits: Dict[str, float]) -> float:
41
+ """
42
+ Apply trait modifiers to a base probability.
43
+ """
44
+ prob = base_probability
45
+ for trait, value in traits.items():
46
+ # Simple linear adjustment for now
47
+ # values > 0.5 increase probability, < 0.5 decrease it
48
+ modifier = (value - 0.5) * 0.2
49
+ prob += modifier
50
+
51
+ return max(0.0, min(1.0, prob))
52
+
53
+ def generate_trait_profile_from_description(self, description: str) -> Dict[str, float]:
54
+ """
55
+ Use LLM to infer traits from persona descriptions.
56
+ """
57
+ prompt = f"""
58
+ Analyze the following persona description and infer their behavioral traits on a scale of 0.0 to 1.0.
59
+
60
+ Description: {description}
61
+
62
+ Traits to infer:
63
+ - openness (Openness to new ideas/novel content)
64
+ - conscientiousness (Posting regularity, thoughtfulness)
65
+ - extraversion (Sharing frequency, network activity)
66
+ - agreeableness (Commenting positivity, conflict avoidance)
67
+ - controversiality_tolerance (Engagement with divisive topics)
68
+ - information_seeking_behavior (Long-form vs short-form preference)
69
+ - visual_content_preference (Image/video vs text preference)
70
+
71
+ Provide the result in JSON format.
72
+ """
73
+
74
+ response = openai_utils.client().send_message(
75
+ [
76
+ {"role": "system", "content": "You are an expert psychologist and persona modeler."},
77
+ {"role": "user", "content": prompt}
78
+ ],
79
+ temperature=0.3,
80
+ response_format={"type": "json_object"}
81
+ )
82
+
83
+ try:
84
+ traits = json.loads(response["content"])
85
+ return traits
86
+ except Exception:
87
+ return TraitProfile().__dict__
tinytroupe/agent/memory.py CHANGED
@@ -88,11 +88,23 @@ class TinyMemory(TinyMentalFaculty):
88
  """
89
  raise NotImplementedError("Subclasses must implement this method.")
90
 
 
 
 
 
 
 
91
  def get_memory_summary(self) -> str:
92
  """
93
- Returns a summary of all memories.
 
 
 
 
94
  """
95
- return self.summarize_relevant_via_full_scan("A general summary of the agent's experiences and knowledge.")
 
 
96
 
97
  def summarize_relevant_via_full_scan(self, relevance_target: str, batch_size: int = 20, item_type: str = None) -> str:
98
  """
 
88
  """
89
  raise NotImplementedError("Subclasses must implement this method.")
90
 
91
+ def store_interaction(self, interaction: Any) -> None:
92
+ """
93
+ Stores an interaction in memory.
94
+ """
95
+ self.store({"type": "interaction", "content": interaction, "simulation_timestamp": utils.pretty_datetime(datetime.now())})
96
+
97
  def get_memory_summary(self) -> str:
98
  """
99
+ Returns a summary of the memory.
100
+ """
101
+ raise NotImplementedError("Subclasses must implement this method.")
102
+
103
+ def consolidate_memories(self) -> None:
104
  """
105
+ Consolidates memories (e.g., from episodic to semantic).
106
+ """
107
+ raise NotImplementedError("Subclasses must implement this method.")
108
 
109
  def summarize_relevant_via_full_scan(self, relevance_target: str, batch_size: int = 20, item_type: str = None) -> str:
110
  """
tinytroupe/agent/social_types.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional, Any, Set
3
+ from datetime import datetime
4
+
5
+ @dataclass
6
+ class ConnectionEdge:
7
+ connection_id: str
8
+ strength: float = 0.0 # 0.0-1.0
9
+ influence_score: float = 0.0
10
+ interaction_history: List[Dict[str, Any]] = field(default_factory=list)
11
+ relationship_type: str = "follower" # "follower", "friend", "colleague", "family"
12
+ last_interaction: Optional[datetime] = None
13
+ created_at: datetime = field(default_factory=datetime.now)
14
+
15
+ @dataclass
16
+ class BehavioralEvent:
17
+ timestamp: datetime
18
+ action_type: str
19
+ content_id: str
20
+ outcome: Any
21
+ context: Dict[str, Any] = field(default_factory=dict)
22
+
23
+ @dataclass
24
+ class InfluenceProfile:
25
+ reach: int = 0
26
+ authority: float = 0.0
27
+ expertise_domains: List[str] = field(default_factory=list)
28
+ follower_to_following_ratio: float = 0.0
29
+ engagement_rate: float = 0.0
30
+
31
+ @dataclass
32
+ class Content:
33
+ text: str
34
+ content_id: Optional[str] = None
35
+ topics: List[str] = field(default_factory=list)
36
+ format: str = "text" # "article", "video", "poll", "survey", "ux_test", "email", "ad", etc.
37
+ length: int = 0
38
+ tone: str = "neutral"
39
+ author_name: Optional[str] = None
40
+ author_title: Optional[str] = None
41
+ sentiment: float = 0.0
42
+ images: List[str] = field(default_factory=list)
43
+ video_url: Optional[str] = None
44
+ external_links: List[str] = field(default_factory=list)
45
+ hashtags: List[str] = field(default_factory=list)
46
+ timestamp: datetime = field(default_factory=datetime.now)
47
+ platform: str = "LinkedIn"
48
+
49
+ @dataclass
50
+ class Reaction:
51
+ reaction_type: str # "like", "love", "insightful", "celebrate", "none", "positive", "negative", "neutral"
52
+ will_engage: bool
53
+ probability: float
54
+ reasoning: Optional[str] = None
55
+ comment: Optional[str] = None
56
+ will_share: bool = False
57
+ virality_coefficient: float = 0.0
58
+ sentiment: float = 0.0 # -1.0 to 1.0
59
+ detailed_feedback: Dict[str, Any] = field(default_factory=dict) # For surveys/UX tests
tinytroupe/agent/tiny_person.py CHANGED
@@ -1,12 +1,12 @@
1
  from tinytroupe.agent import logger, default, Self, AgentOrWorld, CognitiveActionModel
2
  from tinytroupe.agent.memory import EpisodicMemory, SemanticMemory, EpisodicConsolidator
 
3
  import tinytroupe.openai_utils as openai_utils
4
  from tinytroupe.utils import JsonSerializableRegistry, repeat_on_error, name_or_empty
5
  import tinytroupe.utils as utils
6
  from tinytroupe.control import transactional, current_simulation
7
  from tinytroupe import config_manager
8
  from tinytroupe.utils.logger import get_logger
9
- from tinytroupe.agent_types import ConnectionEdge, BehavioralEvent, InfluenceProfile, Content, Reaction, Interaction
10
 
11
  import os
12
  import json
@@ -44,7 +44,7 @@ class TinyPerson(JsonSerializableRegistry):
44
  PP_TEXT_WIDTH = 100
45
 
46
  serializable_attributes = ["_persona", "_mental_state", "_mental_faculties", "_current_episode_event_count", "episodic_memory", "semantic_memory",
47
- "social_connections", "engagement_patterns", "behavioral_history", "influence_metrics", "prediction_confidence"]
48
  serializable_attributes_renaming = {"_mental_faculties": "mental_faculties", "_persona": "persona", "_mental_state": "mental_state", "_current_episode_event_count": "current_episode_event_count"}
49
 
50
  # A dict of all agents instantiated so far.
@@ -211,33 +211,29 @@ class TinyPerson(JsonSerializableRegistry):
211
 
212
  if not hasattr(self, 'stimuli_count'):
213
  self.stimuli_count = 0
214
-
215
- # Social Network and Engagement Enhancements
216
  if not hasattr(self, 'social_connections'):
217
- self.social_connections = {} # connection_id -> ConnectionEdge
218
 
219
  if not hasattr(self, 'engagement_patterns'):
220
  self.engagement_patterns = {
221
  "content_type_preferences": {},
222
  "topic_affinities": {},
223
  "posting_time_preferences": {},
224
- "engagement_likelihood": 0.1
225
  }
226
 
227
  if not hasattr(self, 'behavioral_history'):
228
  self.behavioral_history = []
229
 
230
  if not hasattr(self, 'influence_metrics'):
231
- self.influence_metrics = InfluenceProfile(
232
- reach=0.0,
233
- authority=0.0,
234
- expertise_domains=[],
235
- follower_to_following_ratio=1.0,
236
- engagement_rate=0.0
237
- )
238
 
239
  if not hasattr(self, 'prediction_confidence'):
240
- self.prediction_confidence = 0.5
 
 
 
241
 
242
  self._prompt_template_path = os.path.join(
243
  os.path.dirname(__file__), "prompts/tiny_person.mustache"
@@ -1824,85 +1820,46 @@ max_content_length=max_content_length,
1824
  """
1825
  TinyPerson.all_agents = {}
1826
 
1827
- #########################################################################
1828
- # Artificial Societies Enhancements
1829
- #########################################################################
1830
 
1831
  def calculate_engagement_probability(self, content: Content) -> float:
1832
  """
1833
- Calculates the probability that the persona will engage with the given content.
1834
  """
1835
- affinity = self.get_content_affinity(content)
 
1836
 
1837
- # Base probability from engagement patterns
1838
- base_prob = self.engagement_patterns.get("engagement_likelihood", 0.1)
1839
 
1840
- # Factor in social influence (placeholder logic)
1841
- social_factor = 1.0
1842
- for conn_id, edge in self.social_connections.items():
1843
- if edge.influence_score > 0.8:
1844
- social_factor += 0.1
1845
-
1846
- prob = affinity * base_prob * social_factor
1847
- return min(max(prob, 0.0), 1.0)
1848
 
1849
  def predict_reaction(self, content: Content) -> Reaction:
1850
  """
1851
- Predicts the reaction of the persona to the given content.
1852
  """
1853
- prob = self.calculate_engagement_probability(content)
1854
-
1855
- if random.random() > prob:
1856
- return Reaction(reaction_type="none")
1857
 
1858
- # Use LLM to generate reaction and comment
1859
- prompt = f"Given the content: '{content.text}', how would {self.name} react? Persona info: {self.minibio()}"
1860
- # Placeholder for LLM call
1861
- reaction_type = random.choice(["like", "love", "insightful", "celebrate"])
1862
- comment = f"Interesting post about {', '.join(content.topics)}!"
1863
-
1864
- return Reaction(
1865
- reaction_type=reaction_type,
1866
- comment=comment,
1867
- will_share=random.random() < 0.2,
1868
- virality_coefficient=self.influence_metrics.authority * 0.5
1869
- )
1870
 
1871
- def update_from_interaction(self, interaction: Interaction) -> None:
1872
  """
1873
- Updates the persona's patterns based on a real interaction.
1874
  """
1875
- event = BehavioralEvent(
1876
- timestamp=interaction.timestamp,
1877
- action_type=interaction.action_type,
1878
- content_id=interaction.content_id,
1879
- outcome=interaction.outcome
1880
- )
1881
- self.behavioral_history.append(event)
1882
-
1883
- # Simple reinforcement learning logic
1884
- if interaction.action_type in ["like", "comment", "share"]:
1885
- # Increase engagement likelihood slightly
1886
- self.engagement_patterns["engagement_likelihood"] *= 1.05
1887
-
1888
- # Keep history manageable
1889
- if len(self.behavioral_history) > 100:
1890
- self.behavioral_history.pop(0)
1891
 
1892
  def get_content_affinity(self, content: Content) -> float:
1893
  """
1894
- Scores the content relevance to the persona.
1895
  """
1896
- score = 0.5 # Neutral base
1897
-
1898
- # Topic alignment
1899
- persona_topics = self.get("interests") or []
1900
- matched_topics = set(persona_topics).intersection(set(content.topics))
1901
- if matched_topics:
1902
- score += 0.1 * len(matched_topics)
1903
-
1904
- # Content type preference
1905
- pref = self.engagement_patterns["content_type_preferences"].get(content.content_type, 1.0)
1906
- score *= pref
1907
-
1908
- return min(max(score, 0.0), 2.0) # Normalized to a reasonable range
 
1
  from tinytroupe.agent import logger, default, Self, AgentOrWorld, CognitiveActionModel
2
  from tinytroupe.agent.memory import EpisodicMemory, SemanticMemory, EpisodicConsolidator
3
+ from tinytroupe.agent.social_types import ConnectionEdge, BehavioralEvent, InfluenceProfile, Content, Reaction
4
  import tinytroupe.openai_utils as openai_utils
5
  from tinytroupe.utils import JsonSerializableRegistry, repeat_on_error, name_or_empty
6
  import tinytroupe.utils as utils
7
  from tinytroupe.control import transactional, current_simulation
8
  from tinytroupe import config_manager
9
  from tinytroupe.utils.logger import get_logger
 
10
 
11
  import os
12
  import json
 
44
  PP_TEXT_WIDTH = 100
45
 
46
  serializable_attributes = ["_persona", "_mental_state", "_mental_faculties", "_current_episode_event_count", "episodic_memory", "semantic_memory",
47
+ "social_connections", "engagement_patterns", "behavioral_history", "influence_metrics", "prediction_confidence", "behavioral_traits"]
48
  serializable_attributes_renaming = {"_mental_faculties": "mental_faculties", "_persona": "persona", "_mental_state": "mental_state", "_current_episode_event_count": "current_episode_event_count"}
49
 
50
  # A dict of all agents instantiated so far.
 
211
 
212
  if not hasattr(self, 'stimuli_count'):
213
  self.stimuli_count = 0
214
+
 
215
  if not hasattr(self, 'social_connections'):
216
+ self.social_connections = {}
217
 
218
  if not hasattr(self, 'engagement_patterns'):
219
  self.engagement_patterns = {
220
  "content_type_preferences": {},
221
  "topic_affinities": {},
222
  "posting_time_preferences": {},
223
+ "engagement_likelihood": {}
224
  }
225
 
226
  if not hasattr(self, 'behavioral_history'):
227
  self.behavioral_history = []
228
 
229
  if not hasattr(self, 'influence_metrics'):
230
+ self.influence_metrics = InfluenceProfile()
 
 
 
 
 
 
231
 
232
  if not hasattr(self, 'prediction_confidence'):
233
+ self.prediction_confidence = 0.0
234
+
235
+ if not hasattr(self, 'behavioral_traits'):
236
+ self.behavioral_traits = {}
237
 
238
  self._prompt_template_path = os.path.join(
239
  os.path.dirname(__file__), "prompts/tiny_person.mustache"
 
1820
  """
1821
  TinyPerson.all_agents = {}
1822
 
1823
+ ############################################################################
1824
+ # Social and Engagement methods
1825
+ ############################################################################
1826
 
1827
  def calculate_engagement_probability(self, content: Content) -> float:
1828
  """
1829
+ Analyze content features and return probability of engagement using the prediction engine.
1830
  """
1831
+ from tinytroupe.ml_models import EngagementPredictor
1832
+ predictor = EngagementPredictor()
1833
 
1834
+ # Use the environment's network topology if available
1835
+ network = getattr(self.environment, 'network', None)
1836
 
1837
+ return predictor.predict(self, content, network)
 
 
 
 
 
 
 
1838
 
1839
  def predict_reaction(self, content: Content) -> Reaction:
1840
  """
1841
+ Determine reaction type using the LLM-based predictor.
1842
  """
1843
+ from tinytroupe.llm_predictor import LLMPredictor
1844
+ predictor = LLMPredictor()
 
 
1845
 
1846
+ return predictor.predict(self, content)
 
 
 
 
 
 
 
 
 
 
 
1847
 
1848
+ def update_from_interaction(self, interaction: Any) -> None:
1849
  """
1850
+ Learn from actual interactions and update patterns.
1851
  """
1852
+ # interaction could be a dict with content and outcome
1853
+ if isinstance(interaction, dict):
1854
+ content = interaction.get("content")
1855
+ outcome = interaction.get("outcome") # e.g. "like", "comment", "none"
1856
+
1857
+ # Update patterns based on outcome
1858
+ # This is a simplified learning mechanism
1859
+ pass
 
 
 
 
 
 
 
 
1860
 
1861
  def get_content_affinity(self, content: Content) -> float:
1862
  """
1863
+ Score content relevance to persona.
1864
  """
1865
+ return self.calculate_engagement_probability(content)
 
 
 
 
 
 
 
 
 
 
 
 
tinytroupe/config.ini CHANGED
@@ -15,10 +15,10 @@ AZURE_API_VERSION=2023-05-15
15
  #
16
 
17
  # The main text generation model, used for agent responses
18
- MODEL=gpt-4.1-mini
19
 
20
  # Reasoning model is used when precise reasoning is required, such as when computing detailed analyses of simulation properties.
21
- REASONING_MODEL=o3-mini
22
 
23
  # Embedding model is used for text similarity tasks
24
  EMBEDDING_MODEL=text-embedding-3-small
@@ -31,8 +31,8 @@ TEMPERATURE=1.5
31
  FREQ_PENALTY=0.1
32
  PRESENCE_PENALTY=0.1
33
  TIMEOUT=480
34
- MAX_ATTEMPTS=5
35
- WAITING_TIME=1
36
  EXPONENTIAL_BACKOFF_FACTOR=5
37
 
38
  REASONING_EFFORT=high
@@ -90,7 +90,7 @@ QUALITY_THRESHOLD = 5
90
 
91
 
92
  [Logging]
93
- LOGLEVEL=ERROR
94
  # ERROR
95
  # WARNING
96
  # INFO
 
15
  #
16
 
17
  # The main text generation model, used for agent responses
18
+ MODEL=alias-fast
19
 
20
  # Reasoning model is used when precise reasoning is required, such as when computing detailed analyses of simulation properties.
21
+ REASONING_MODEL=alias-fast
22
 
23
  # Embedding model is used for text similarity tasks
24
  EMBEDDING_MODEL=text-embedding-3-small
 
31
  FREQ_PENALTY=0.1
32
  PRESENCE_PENALTY=0.1
33
  TIMEOUT=480
34
+ MAX_ATTEMPTS=999
35
+ WAITING_TIME=35
36
  EXPONENTIAL_BACKOFF_FACTOR=5
37
 
38
  REASONING_EFFORT=high
 
90
 
91
 
92
  [Logging]
93
+ LOGLEVEL=DEBUG
94
  # ERROR
95
  # WARNING
96
  # INFO
tinytroupe/content_generation.py CHANGED
@@ -1,34 +1,42 @@
1
- from typing import List, Dict, Any
2
  import random
3
- from dataclasses import dataclass
4
- from tinytroupe.agent.tiny_person import TinyPerson
5
- from tinytroupe.agent_types import Content
6
 
7
- @dataclass
8
  class ContentVariant:
9
- text: str
10
- strategy: str
11
- parameters: Dict[str, Any]
12
- original_content: str
 
13
 
14
  class ContentVariantGenerator:
15
  """Generate multiple variants of input content"""
16
 
17
- def generate_variants(self, original_content: str, num_variants: int = 10,
 
 
 
18
  target_personas: List[TinyPerson] = None) -> List[ContentVariant]:
19
  """Generate diverse variants of content"""
20
  variants = []
21
- strategies = ["tone", "length", "format", "persona_targeted", "angle"]
 
 
22
 
23
  for i in range(num_variants):
24
- strategy = random.choice(strategies)
25
- # Placeholder for real LLM-based generation
26
- variant_text = f"[{strategy.upper()} variant {i}] {original_content[:50]}..."
 
 
27
 
 
28
  variants.append(ContentVariant(
29
  text=variant_text,
30
- strategy=strategy,
31
- parameters={"index": i},
32
  original_content=original_content
33
  ))
34
 
 
1
+ from typing import List, Dict, Any, Optional
2
  import random
3
+ from tinytroupe.agent import TinyPerson
4
+ from tinytroupe.agent.social_types import Content
5
+ import tinytroupe.openai_utils as openai_utils
6
 
 
7
  class ContentVariant:
8
+ def __init__(self, text: str, strategy: str, parameters: Dict[str, Any], original_content: str):
9
+ self.text = text
10
+ self.strategy = strategy
11
+ self.parameters = parameters
12
+ self.original_content = original_content
13
 
14
  class ContentVariantGenerator:
15
  """Generate multiple variants of input content"""
16
 
17
+ def __init__(self, model: str = "gpt-4"):
18
+ self.model = model
19
+
20
+ def generate_variants(self, original_content: str, num_variants: int = 5,
21
  target_personas: List[TinyPerson] = None) -> List[ContentVariant]:
22
  """Generate diverse variants of content"""
23
  variants = []
24
+
25
+ # In a real implementation, we would use different prompts for different strategies
26
+ # Here we use a simplified approach
27
 
28
  for i in range(num_variants):
29
+ prompt = f"Rewrite the following content in a different style or tone:\n\n{original_content}"
30
+
31
+ response = openai_utils.client().send_message(
32
+ [{"role": "user", "content": prompt}]
33
+ )
34
 
35
+ variant_text = response["content"].strip()
36
  variants.append(ContentVariant(
37
  text=variant_text,
38
+ strategy="style_variation",
39
+ parameters={"variant_index": i},
40
  original_content=original_content
41
  ))
42
 
tinytroupe/environment/social_tiny_world.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Set, Optional
2
+ import random
3
+ from datetime import datetime
4
+ from tinytroupe.environment.tiny_world import TinyWorld
5
+ from tinytroupe.social_network import NetworkTopology
6
+ from tinytroupe.agent.social_types import Content, Reaction
7
+ from tinytroupe.agent import TinyPerson
8
+ from tinytroupe.agent import logger
9
+
10
+ class SimulationResult:
11
+ def __init__(self, content: Content, start_time: datetime):
12
+ self.content = content
13
+ self.start_time = start_time
14
+ self.end_time: Optional[datetime] = None
15
+ self.engagements: List[Dict[str, Any]] = []
16
+ self.step_metrics: List[Dict[str, Any]] = []
17
+ self.total_reach = 0
18
+ self.engagement_rate = 0.0
19
+ self.expected_likes = 0
20
+ self.expected_comments = 0
21
+ self.expected_shares = 0
22
+ self.cascade_depth = 0
23
+ self.execution_time = 0.0
24
+ self.avg_sentiment = 0.0
25
+ self.feedback_summary: List[str] = []
26
+
27
+ def add_engagement(self, persona_id: str, engagement_type: str, step: int, sentiment: float = 0.0, feedback: str = None):
28
+ self.engagements.append({
29
+ "persona_id": persona_id,
30
+ "type": engagement_type,
31
+ "step": step,
32
+ "sentiment": sentiment,
33
+ "feedback": feedback
34
+ })
35
+ if engagement_type == "like": self.expected_likes += 1
36
+ elif engagement_type == "comment": self.expected_comments += 1
37
+ elif engagement_type == "share": self.expected_shares += 1
38
+
39
+ if feedback:
40
+ self.feedback_summary.append(feedback)
41
+
42
+ def add_step_metrics(self, step: int, reach: int, engagements: int):
43
+ self.step_metrics.append({
44
+ "step": step,
45
+ "reach": reach,
46
+ "engagements": engagements
47
+ })
48
+
49
+ def finalize(self, end_time: datetime):
50
+ self.end_time = end_time
51
+ self.execution_time = (end_time - self.start_time).total_seconds()
52
+ self.total_reach = len(set(e["persona_id"] for e in self.engagements)) # Simplified
53
+ # ... more metrics
54
+
55
+ class SocialTinyWorld(TinyWorld):
56
+ """Extended TinyWorld with social network capabilities"""
57
+
58
+ def __init__(self, name: str, network: NetworkTopology = None, **kwargs):
59
+ super().__init__(name, **kwargs)
60
+ self.network = network or NetworkTopology()
61
+ self.content_items: List[Content] = []
62
+ self.simulation_history: List[SimulationResult] = []
63
+ self.time_step = 0
64
+
65
+ def add_content(self, content: Content) -> None:
66
+ """Add content to the world for personas to interact with"""
67
+ self.content_items.append(content)
68
+ self.broadcast(f"New content available: {content.text[:100]}...")
69
+
70
+ def simulate_content_spread(self, content: Content,
71
+ initial_viewers: List[str],
72
+ max_steps: int = 10) -> SimulationResult:
73
+ """Simulate how content spreads through the network"""
74
+
75
+ result = SimulationResult(content=content, start_time=datetime.now())
76
+ viewed = set(initial_viewers)
77
+ engaged = set()
78
+
79
+ for step in range(max_steps):
80
+ self.time_step = step
81
+ new_viewers = set()
82
+
83
+ for viewer_id in viewed - engaged:
84
+ if viewer_id not in self.network.nodes: continue
85
+ persona = self.network.nodes[viewer_id]
86
+
87
+ # Predict reaction (simplified)
88
+ reaction = persona.predict_reaction(content)
89
+
90
+ if reaction.will_engage:
91
+ engaged.add(viewer_id)
92
+ result.add_engagement(
93
+ viewer_id,
94
+ reaction.reaction_type,
95
+ step,
96
+ sentiment=reaction.sentiment,
97
+ feedback=reaction.comment
98
+ )
99
+
100
+ if reaction.will_share:
101
+ neighbors = self.network.get_neighbors(viewer_id)
102
+ new_viewers.update([n.name for n in neighbors])
103
+
104
+ viewed.update(new_viewers)
105
+ result.add_step_metrics(step, len(viewed), len(engaged))
106
+
107
+ if not new_viewers:
108
+ break
109
+
110
+ result.finalize(datetime.now())
111
+ self.simulation_history.append(result)
112
+ return result
tinytroupe/factory/tiny_person_factory.py CHANGED
@@ -12,7 +12,6 @@ from tinytroupe.agent import TinyPerson
12
  import tinytroupe.utils as utils
13
  from tinytroupe.control import transactional
14
  from tinytroupe import config_manager
15
- from tinytroupe.agent_traits import TraitBasedBehaviorModel
16
 
17
  import concurrent.futures
18
  import threading
@@ -343,6 +342,46 @@ class TinyPersonFactory(TinyFactory):
343
 
344
 
345
  @config_manager.config_defaults(parallelize="parallel_agent_generation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  def generate_people(self, number_of_people:int=None,
347
  agent_particularities:str=None,
348
  temperature:float=1.2,
@@ -559,6 +598,11 @@ class TinyPersonFactory(TinyFactory):
559
  if len(self.remaining_characteristics_sample) != n:
560
  logger.warning(f"Expected {n} samples, but got {len(self.remaining_characteristics_sample)} samples. The LLM may have failed to sum up the quantities in the sampling plan correctly.")
561
 
 
 
 
 
 
562
  logger.info(f"Sample plan has been flattened, contains {len(self.remaining_characteristics_sample)} total samples.")
563
  logger.debug(f"Remaining characteristics sample: {json.dumps(self.remaining_characteristics_sample, indent=4)}")
564
 
@@ -626,61 +670,6 @@ class TinyPersonFactory(TinyFactory):
626
  """
627
  return name in TinyPerson.all_agents_names()
628
 
629
- #########################################################################
630
- # Artificial Societies Factory Enhancements
631
- #########################################################################
632
-
633
- def generate_from_demographics(self, age_range: tuple, location: str, occupation: str, interests: List[str]) -> TinyPerson:
634
- """
635
- Generates a persona based on specific demographics.
636
- """
637
- context = f"A {age_range[0]}-{age_range[1]} year old {occupation} in {location} interested in {', '.join(interests)}."
638
- return self.generate_person(agent_particularities=context)
639
-
640
- def generate_from_linkedin_profile(self, profile_data: Dict) -> TinyPerson:
641
- """
642
- Generates a persona based on a LinkedIn profile.
643
- """
644
- context = f"Professional profile: {json.dumps(profile_data)}"
645
- persona = self.generate_person(agent_particularities=context)
646
- persona.define("social_platform", "LinkedIn")
647
- return persona
648
-
649
- def generate_persona_cluster(self, archetype: str, count: int) -> List[TinyPerson]:
650
- """
651
- Generates a cluster of personas based on an archetype.
652
- """
653
- personas = []
654
- for _ in range(count):
655
- particularities = f"Archetype: {archetype}. Ensure individual variation."
656
- personas.append(self.generate_person(agent_particularities=particularities))
657
- return personas
658
-
659
- def generate_diverse_population(self, size: int, distribution: Dict) -> List[TinyPerson]:
660
- """
661
- Generates a diverse population based on a distribution.
662
- """
663
- # Simplistic implementation: use create_factory_from_demography logic
664
- return self.generate_people(number_of_people=size, verbose=True)
665
-
666
- def ensure_consistency(self, persona: TinyPerson) -> bool:
667
- """
668
- Validates the consistency of a generated persona.
669
- """
670
- # Placeholder for LLM-based consistency check
671
- traits = persona.get("behavioral_traits")
672
- if traits and len(traits) > 0:
673
- return True
674
- return False
675
-
676
- def calculate_diversity_score(self, personas: List[TinyPerson]) -> float:
677
- """
678
- Measures demographic and behavioral diversity of a population.
679
- """
680
- if not personas: return 0.0
681
- # Placeholder logic: ratio of unique occupations
682
- occupations = [p.get("occupation") for p in personas]
683
- return len(set(occupations)) / len(personas)
684
 
685
  @transactional()
686
  @utils.llm(temperature=0.5, frequency_penalty=0.0, presence_penalty=0.0)
 
12
  import tinytroupe.utils as utils
13
  from tinytroupe.control import transactional
14
  from tinytroupe import config_manager
 
15
 
16
  import concurrent.futures
17
  import threading
 
342
 
343
 
344
  @config_manager.config_defaults(parallelize="parallel_agent_generation")
345
+ def generate_from_linkedin_profile(self, profile_data: Dict) -> TinyPerson:
346
+ """
347
+ Generate a TinyPerson from a LinkedIn profile with enriched traits.
348
+ """
349
+ description = f"Professional with headline: {profile_data.get('headline', '')}. " \
350
+ f"Industry: {profile_data.get('industry', '')}. " \
351
+ f"Location: {profile_data.get('location', 'Global')}. " \
352
+ f"Career level: {profile_data.get('career_level', 'Mid Level')}. " \
353
+ f"Summary: {profile_data.get('summary', '')}"
354
+
355
+ return self.generate_person(agent_particularities=description)
356
+
357
+ def generate_persona_cluster(self, archetype: str, count: int) -> List[TinyPerson]:
358
+ """
359
+ Generate a cluster of personas following a specific archetype.
360
+ """
361
+ return self.generate_people(number_of_people=count, agent_particularities=f"Archetype: {archetype}")
362
+
363
+ def generate_diverse_population(self, size: int, distribution: Dict) -> List[TinyPerson]:
364
+ """
365
+ Generate a diverse population based on a distribution.
366
+ """
367
+ # distribution could specify proportions of various characteristics
368
+ # This is a simplified implementation
369
+ return self.generate_people(number_of_people=size, agent_particularities=f"Target distribution: {json.dumps(distribution)}")
370
+
371
+ def ensure_consistency(self, persona: TinyPerson) -> bool:
372
+ """
373
+ Ensure the generated persona is consistent.
374
+ """
375
+ # Implementation would involve checking traits, demographics, etc.
376
+ return True # Placeholder
377
+
378
+ def calculate_diversity_score(self, personas: List[TinyPerson]) -> float:
379
+ """
380
+ Calculate a diversity score for a list of personas.
381
+ """
382
+ # Placeholder for diversity metric calculation
383
+ return 0.5
384
+
385
  def generate_people(self, number_of_people:int=None,
386
  agent_particularities:str=None,
387
  temperature:float=1.2,
 
598
  if len(self.remaining_characteristics_sample) != n:
599
  logger.warning(f"Expected {n} samples, but got {len(self.remaining_characteristics_sample)} samples. The LLM may have failed to sum up the quantities in the sampling plan correctly.")
600
 
601
+ # If we got more samples than requested, we truncate them to avoid generating too many names or personas.
602
+ if len(self.remaining_characteristics_sample) > n:
603
+ logger.info(f"Truncating {len(self.remaining_characteristics_sample)} samples to the requested {n} samples.")
604
+ self.remaining_characteristics_sample = self.remaining_characteristics_sample[:n]
605
+
606
  logger.info(f"Sample plan has been flattened, contains {len(self.remaining_characteristics_sample)} total samples.")
607
  logger.debug(f"Remaining characteristics sample: {json.dumps(self.remaining_characteristics_sample, indent=4)}")
608
 
 
670
  """
671
  return name in TinyPerson.all_agents_names()
672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
 
674
  @transactional()
675
  @utils.llm(temperature=0.5, frequency_penalty=0.0, presence_penalty=0.0)
tinytroupe/features.py CHANGED
@@ -1,55 +1,37 @@
1
- from typing import Dict, List, Any
2
  import numpy as np
3
  from datetime import datetime
4
- from tinytroupe.agent.tiny_person import TinyPerson
5
- from tinytroupe.agent_types import Content
6
  from tinytroupe.social_network import NetworkTopology
7
 
8
  class ContentFeatureExtractor:
9
  def extract(self, content: Content) -> Dict[str, float]:
10
  """Extract all content features"""
11
  return {
12
- "word_count": len(content.text.split()) / 500.0, # Normalized
13
  "has_image": 1.0 if content.images else 0.0,
14
  "has_video": 1.0 if content.video_url else 0.0,
15
- "num_hashtags": len(content.hashtags) / 10.0,
16
- "sentiment_score": 0.5, # Placeholder for VADER/Transformers
17
- "hour_of_day": content.timestamp.hour / 24.0,
18
  "is_weekend": 1.0 if content.timestamp.weekday() >= 5 else 0.0,
19
  }
20
 
21
  class PersonaFeatureExtractor:
22
  def extract(self, persona: TinyPerson) -> Dict[str, float]:
23
  """Extract persona features"""
24
- traits = persona.get("behavioral_traits") or {}
25
  return {
26
- "age": float(persona.get("age") or 30) / 100.0,
27
- "num_connections": len(persona.social_connections) / 100.0,
28
- "authority": persona.influence_metrics.authority,
29
- "openness": traits.get("openness_to_new_ideas", 0.5),
30
- "extraversion": traits.get("extraversion", 0.5),
31
- "engagement_rate": persona.influence_metrics.engagement_rate,
32
  }
33
 
34
  class InteractionFeatureExtractor:
35
  def extract(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> Dict[str, float]:
36
  """Extract features from persona-content interaction context"""
37
- # Placeholder for complex context features
38
  return {
39
  "topic_alignment": persona.get_content_affinity(content),
40
- "author_connection": 1.0 if content.author_name in persona.social_connections else 0.0,
41
  }
42
-
43
- class FeatureExtractor:
44
- def __init__(self):
45
- self.content_extractor = ContentFeatureExtractor()
46
- self.persona_extractor = PersonaFeatureExtractor()
47
- self.interaction_extractor = InteractionFeatureExtractor()
48
-
49
- def extract_all(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> np.ndarray:
50
- c_feats = self.content_extractor.extract(content)
51
- p_feats = self.persona_extractor.extract(persona)
52
- i_feats = self.interaction_extractor.extract(persona, content, network)
53
-
54
- combined = {**c_feats, **p_feats, **i_feats}
55
- return np.array(list(combined.values()))
 
1
+ from typing import Dict, Any, List
2
  import numpy as np
3
  from datetime import datetime
4
+ from tinytroupe.agent import TinyPerson
5
+ from tinytroupe.agent.social_types import Content
6
  from tinytroupe.social_network import NetworkTopology
7
 
8
  class ContentFeatureExtractor:
9
  def extract(self, content: Content) -> Dict[str, float]:
10
  """Extract all content features"""
11
  return {
12
+ "word_count": float(len(content.text.split())),
13
  "has_image": 1.0 if content.images else 0.0,
14
  "has_video": 1.0 if content.video_url else 0.0,
15
+ "has_link": 1.0 if content.external_links else 0.0,
16
+ "sentiment_score": content.sentiment,
17
+ "num_hashtags": float(len(content.hashtags)),
18
  "is_weekend": 1.0 if content.timestamp.weekday() >= 5 else 0.0,
19
  }
20
 
21
  class PersonaFeatureExtractor:
22
  def extract(self, persona: TinyPerson) -> Dict[str, float]:
23
  """Extract persona features"""
 
24
  return {
25
+ "age": float(persona._persona.get("age") or 30),
26
+ "num_connections": float(len(persona.social_connections)),
27
+ "influence_score": persona.influence_metrics.authority,
28
+ "engagement_rate": persona.engagement_patterns.get("overall_rate", 0.0),
 
 
29
  }
30
 
31
  class InteractionFeatureExtractor:
32
  def extract(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> Dict[str, float]:
33
  """Extract features from persona-content interaction context"""
 
34
  return {
35
  "topic_alignment": persona.get_content_affinity(content),
36
+ # "num_friends_engaged": ...
37
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tinytroupe/influence.py CHANGED
@@ -1,8 +1,7 @@
1
- from typing import List, Set, Dict, Any
2
- import random
3
  from dataclasses import dataclass
4
  from tinytroupe.social_network import NetworkTopology
5
- from tinytroupe.agent_types import Content
6
 
7
  @dataclass
8
  class PropagationResult:
@@ -22,56 +21,39 @@ class InfluencePropagator:
22
  """Main propagation simulation"""
23
  activated = set(seed_personas)
24
  activation_times = {pid: 0 for pid in seed_personas}
25
- engagement_by_time = [len(seed_personas)]
26
 
27
  for time_step in range(1, self.max_steps + 1):
28
  newly_activated = self._propagate_step(activated, content, time_step)
29
  if not newly_activated:
30
  break
31
-
32
  for pid in newly_activated:
33
  activation_times[pid] = time_step
34
  activated.update(newly_activated)
35
- engagement_by_time.append(len(newly_activated))
36
 
37
  return PropagationResult(
38
  activated_personas=activated,
39
  activation_times=activation_times,
40
  total_reach=len(activated),
41
  cascade_depth=max(activation_times.values()) if activation_times else 0,
42
- engagement_by_time=engagement_by_time
43
  )
44
 
45
  def _propagate_step(self, activated: Set[str], content: Content, time: int) -> Set[str]:
46
  """Single step of propagation"""
47
  newly_activated = set()
48
-
49
- if self.model == "cascade":
50
- for active_id in activated:
51
- # Find neighbors of active node
52
- neighbors = self.network.get_neighbors(active_id)
53
- for neighbor in neighbors:
54
- if neighbor.name not in activated and neighbor.name not in newly_activated:
55
- # Probabilistic activation
56
- prob = 0.1 # Base propagation probability
57
- if random.random() < prob:
58
- newly_activated.add(neighbor.name)
59
-
60
- elif self.model == "threshold":
61
- for name, persona in self.network.nodes.items():
62
- if name not in activated:
63
- neighbors = self.network.get_neighbors(name)
64
- active_neighbors = [n for n in neighbors if n.name in activated]
65
- if neighbors:
66
- influence = len(active_neighbors) / len(neighbors)
67
- threshold = 0.5 # Default threshold
68
- if influence >= threshold:
69
- newly_activated.add(name)
70
-
71
  return newly_activated
72
 
73
  def calculate_influence_score(self, persona_id: str) -> float:
74
  """Calculate overall influence of a persona"""
75
- neighbors = self.network.get_neighbors(persona_id)
76
- # Combine degree centrality and reach
77
- return len(neighbors) / max(len(self.network.nodes), 1)
 
1
+ from typing import List, Set, Dict, Any, Tuple
 
2
  from dataclasses import dataclass
3
  from tinytroupe.social_network import NetworkTopology
4
+ from tinytroupe.agent.social_types import Content
5
 
6
  @dataclass
7
  class PropagationResult:
 
21
  """Main propagation simulation"""
22
  activated = set(seed_personas)
23
  activation_times = {pid: 0 for pid in seed_personas}
 
24
 
25
  for time_step in range(1, self.max_steps + 1):
26
  newly_activated = self._propagate_step(activated, content, time_step)
27
  if not newly_activated:
28
  break
 
29
  for pid in newly_activated:
30
  activation_times[pid] = time_step
31
  activated.update(newly_activated)
 
32
 
33
  return PropagationResult(
34
  activated_personas=activated,
35
  activation_times=activation_times,
36
  total_reach=len(activated),
37
  cascade_depth=max(activation_times.values()) if activation_times else 0,
38
+ engagement_by_time=[] # TODO
39
  )
40
 
41
  def _propagate_step(self, activated: Set[str], content: Content, time: int) -> Set[str]:
42
  """Single step of propagation"""
43
  newly_activated = set()
44
+ for pid in activated:
45
+ # Check neighbors of activated personas
46
+ neighbors = self.network.get_neighbors(pid)
47
+ for neighbor in neighbors:
48
+ if neighbor.name not in activated and neighbor.name not in newly_activated:
49
+ # Decide if neighbor activates
50
+ prob = neighbor.calculate_engagement_probability(content)
51
+ if prob > 0.7: # Higher threshold for viral spread
52
+ newly_activated.add(neighbor.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return newly_activated
54
 
55
  def calculate_influence_score(self, persona_id: str) -> float:
56
  """Calculate overall influence of a persona"""
57
+ if persona_id not in self.network.nodes: return 0.0
58
+ # Combine: centrality, follower quality
59
+ return 0.5
tinytroupe/integrations/linkedin_api.py CHANGED
@@ -1,27 +1,28 @@
1
- import requests
2
- from datetime import datetime
3
  from typing import List, Dict, Any, Optional
4
- from dataclasses import dataclass
5
-
6
- @dataclass
7
- class LinkedInProfile:
8
- id: str
9
- first_name: str
10
- last_name: str
11
- headline: str
12
- email: str
13
- profile_picture: Dict[str, Any]
14
 
15
  class LinkedInAPI:
16
- """LinkedIn API client for fetching user data"""
 
17
  def __init__(self, access_token: str):
18
  self.access_token = access_token
19
- self.base_url = "https://api.linkedin.com/v2"
20
-
21
- def get_user_profile(self) -> LinkedInProfile:
22
- # Placeholder for real API call
23
- return LinkedInProfile(id="123", first_name="John", last_name="Doe", headline="Software Engineer", email="john@example.com", profile_picture={})
 
 
 
 
 
 
 
 
 
24
 
25
- def get_connections(self, count: int = 100) -> List[Dict]:
26
- # Placeholder
27
- return [{"id": str(i), "localizedFirstName": f"Friend{i}"} for i in range(10)]
 
 
 
 
 
1
  from typing import List, Dict, Any, Optional
2
+ from datetime import datetime
 
 
 
 
 
 
 
 
 
3
 
4
  class LinkedInAPI:
5
+ """LinkedIn API client placeholder"""
6
+
7
  def __init__(self, access_token: str):
8
  self.access_token = access_token
9
+
10
+ def get_user_profile(self) -> Dict[str, Any]:
11
+ return {
12
+ "id": "me",
13
+ "first_name": "Sample",
14
+ "last_name": "User",
15
+ "headline": "Software Engineer"
16
+ }
17
+
18
+ def get_connections(self, count: int = 10) -> List[Dict[str, Any]]:
19
+ return [
20
+ {"id": f"conn_{i}", "headline": f"Professional {i}", "industry": "Tech"}
21
+ for i in range(count)
22
+ ]
23
 
24
+ def get_user_posts(self, count: int = 5) -> List[Dict[str, Any]]:
25
+ return [
26
+ {"id": f"post_{i}", "text": f"Sample post content {i}"}
27
+ for i in range(count)
28
+ ]
tinytroupe/integrations/linkedin_audience.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+ from tinytroupe.integrations.linkedin_api import LinkedInAPI
3
+ from tinytroupe.agent import TinyPerson
4
+ from tinytroupe.factory.tiny_person_factory import TinyPersonFactory
5
+
6
+ class LinkedInAudienceAnalyzer:
7
+ def __init__(self, linkedin_api: LinkedInAPI):
8
+ self.api = linkedin_api
9
+ self.factory = TinyPersonFactory()
10
+
11
+ def create_audience_personas(self, count: int = 10) -> List[TinyPerson]:
12
+ connections = self.api.get_connections(count=count)
13
+ personas = []
14
+ for conn in connections:
15
+ persona = self.factory.generate_from_linkedin_profile(conn)
16
+ personas.append(persona)
17
+ return personas
tinytroupe/llm_predictor.py CHANGED
@@ -1,27 +1,55 @@
1
  import json
2
- from typing import Dict, Any
3
- from tinytroupe.agent.tiny_person import TinyPerson
4
- from tinytroupe.agent_types import Content
5
- from tinytroupe import openai_utils
6
 
7
  class LLMPredictor:
8
  """Use LLM reasoning for engagement prediction"""
9
- def __init__(self, model: str = "gpt-4o"):
 
10
  self.model = model
11
-
12
- def predict(self, persona: TinyPerson, content: Content) -> Dict[str, Any]:
13
  """Use LLM to predict engagement"""
14
- prompt = self._construct_prediction_prompt(persona, content)
15
- # Placeholder for LLM call
16
- # message = openai_utils.client().send_message(...)
17
 
18
- return {
19
- "will_engage": True,
20
- "probability": 0.75,
21
- "reasoning": "Content aligns well with persona's professional interests.",
22
- "reaction_type": "like",
23
- "comment": "Great insights on the industry!"
24
- }
 
 
25
 
26
- def _construct_prediction_prompt(self, persona: TinyPerson, content: Content) -> str:
27
- return f"Predict reaction for {persona.name} to content: {content.text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ from typing import Dict, Any, Optional
3
+ from tinytroupe.agent import TinyPerson
4
+ from tinytroupe.agent.social_types import Content, Reaction
5
+ import tinytroupe.openai_utils as openai_utils
6
 
7
  class LLMPredictor:
8
  """Use LLM reasoning for engagement prediction"""
9
+
10
+ def __init__(self, model: str = "gpt-4"):
11
  self.model = model
12
+
13
+ def predict(self, persona: TinyPerson, content: Content) -> Reaction:
14
  """Use LLM to predict engagement"""
 
 
 
15
 
16
+ prompt = f"""
17
+ You are predicting how a specific persona will react to content on a professional social network.
18
+
19
+ PERSONA PROFILE:
20
+ Name: {persona.name}
21
+ Bio: {persona.minibio()}
22
+
23
+ CONTENT TO EVALUATE:
24
+ {content.text}
25
 
26
+ TASK:
27
+ Analyze whether this persona would engage with this content.
28
+ Provide your prediction in JSON format:
29
+ {{
30
+ "will_engage": true/false,
31
+ "probability": 0.0-1.0,
32
+ "reasoning": "detailed explanation",
33
+ "reaction_type": "like|comment|share|none",
34
+ "comment": "predicted comment text if applicable"
35
+ }}
36
+ """
37
+
38
+ response = openai_utils.client().send_message(
39
+ [
40
+ {"role": "system", "content": "You are an expert in social psychology and behavioral prediction."},
41
+ {"role": "user", "content": prompt}
42
+ ],
43
+ temperature=0.3,
44
+ response_format={"type": "json_object"}
45
+ )
46
+
47
+ prediction = json.loads(response["content"])
48
+
49
+ return Reaction(
50
+ will_engage=prediction["will_engage"],
51
+ probability=prediction["probability"],
52
+ reasoning=prediction["reasoning"],
53
+ reaction_type=prediction["reaction_type"],
54
+ comment=prediction.get("comment")
55
+ )
tinytroupe/ml_models.py CHANGED
@@ -1,60 +1,37 @@
1
  from typing import List, Dict, Any, Optional
2
  import numpy as np
3
- import random
4
- from dataclasses import dataclass
5
- from tinytroupe.agent.tiny_person import TinyPerson
6
- from tinytroupe.agent_types import Content, Reaction
7
  from tinytroupe.social_network import NetworkTopology
8
- from tinytroupe.features import FeatureExtractor
9
-
10
- @dataclass
11
- class TrainingExample:
12
- persona: TinyPerson
13
- content: Content
14
- network: NetworkTopology
15
- engaged: bool
16
- engagement_type: str = "none"
17
-
18
- @dataclass
19
- class PredictionResult:
20
- engagement_probability: float
21
- engagement_type_probs: Dict[str, float]
22
- predicted_reaction: str
23
- confidence: float
24
 
25
  class EngagementPredictor:
26
  """Predicts whether persona will engage with content"""
 
27
  def __init__(self):
28
- self.model = None
29
- self.extractor = FeatureExtractor()
30
-
 
 
31
  def predict(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> float:
32
  """Predict engagement probability"""
33
- # Placeholder for real model inference
34
- # In a real system, we'd use self.model.predict_proba()
35
- features = self.extractor.extract_all(persona, content, network)
36
- # Dummy logic based on feature sum
37
- score = np.mean(features)
38
- return min(max(score, 0.0), 1.0)
39
-
40
- class EnsemblePredictor:
41
- """Combines multiple predictors for robust predictions"""
42
- def __init__(self):
43
- self.engagement_predictor = EngagementPredictor()
44
 
45
- def predict(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> PredictionResult:
46
- prob = self.engagement_predictor.predict(persona, content, network)
47
 
48
- reaction_types = ["like", "comment", "share"]
49
- type_probs = {rt: prob * (1.0 / len(reaction_types)) for rt in reaction_types}
 
 
 
50
 
51
- predicted_reaction = "none"
52
- if prob > 0.5:
53
- predicted_reaction = random.choice(reaction_types)
54
-
55
- return PredictionResult(
56
- engagement_probability=prob,
57
- engagement_type_probs=type_probs,
58
- predicted_reaction=predicted_reaction,
59
- confidence=0.8
60
- )
 
1
  from typing import List, Dict, Any, Optional
2
  import numpy as np
3
+ from tinytroupe.agent import TinyPerson
4
+ from tinytroupe.agent.social_types import Content, Reaction
 
 
5
  from tinytroupe.social_network import NetworkTopology
6
+ from tinytroupe.features import ContentFeatureExtractor, PersonaFeatureExtractor, InteractionFeatureExtractor
7
+ from tinytroupe.agent.agent_traits import TraitBasedBehaviorModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class EngagementPredictor:
10
  """Predicts whether persona will engage with content"""
11
+
12
  def __init__(self):
13
+ self.content_extractor = ContentFeatureExtractor()
14
+ self.persona_extractor = PersonaFeatureExtractor()
15
+ self.interaction_extractor = InteractionFeatureExtractor()
16
+ self.trait_model = TraitBasedBehaviorModel()
17
+
18
  def predict(self, persona: TinyPerson, content: Content, network: NetworkTopology) -> float:
19
  """Predict engagement probability"""
20
+ content_features = self.content_extractor.extract(content)
21
+ persona_features = self.persona_extractor.extract(persona)
22
+ interaction_features = self.interaction_extractor.extract(persona, content, network)
 
 
 
 
 
 
 
 
23
 
24
+ # Get base probability from trait-based model
25
+ trait_prob = self.trait_model.compute_action_probability(persona, "engage", content)
26
 
27
+ # Combine with other signals
28
+ prob = (trait_prob * 0.5 +
29
+ interaction_features["topic_alignment"] * 0.3 +
30
+ persona_features["engagement_rate"] * 0.1 +
31
+ content_features["sentiment_score"] * 0.1)
32
 
33
+ return max(0.0, min(1.0, prob))
34
+
35
+ class ViralityPredictor:
36
+ def predict_cascade_size(self, content: Content, seed_personas: List[str], network: NetworkTopology) -> int:
37
+ return len(seed_personas) * 2 # Placeholder
 
 
 
 
 
tinytroupe/network_generator.py CHANGED
@@ -1,106 +1,63 @@
 
1
  import random
2
- from typing import List, Dict
3
- from tinytroupe.social_network import NetworkTopology, Community
4
- from tinytroupe.agent.tiny_person import TinyPerson
5
 
6
  class NetworkGenerator:
7
- """
8
- Implements realistic network topologies.
9
- """
10
-
11
- @staticmethod
12
- def generate_scale_free_network(personas: List[TinyPerson], m: int = 2) -> NetworkTopology:
13
- """
14
- Barabási-Albert model for scale-free networks.
15
- """
16
- topology = NetworkTopology()
17
- for p in personas:
18
- topology.add_persona(p)
19
-
20
- names = [p.name for p in personas]
21
- if len(names) <= m:
22
- return topology
23
 
24
- # Initial complete graph of m nodes
25
- for i in range(m):
26
- for j in range(i + 1, m):
27
- topology.add_connection(names[i], names[j])
28
-
29
- # Add remaining nodes with preferential attachment
30
- for i in range(m, len(names)):
31
- targets = set()
32
- existing_nodes = names[:i]
33
- # Simple preferential attachment based on degree
34
- while len(targets) < m:
35
- # Degree of each node
36
- degrees = {name: 0 for name in existing_nodes}
37
- for edge in topology.edges:
38
- if edge.source_id in degrees: degrees[edge.source_id] += 1
39
- if edge.target_id in degrees: degrees[edge.target_id] += 1
40
-
41
- total_degree = sum(degrees.values())
42
- if total_degree == 0:
43
- target = random.choice(existing_nodes)
44
- else:
45
- probs = [degrees[name] / total_degree for name in existing_nodes]
46
- target = random.choices(existing_nodes, weights=probs)[0]
47
- targets.add(target)
48
-
49
  for target in targets:
50
- topology.add_connection(names[i], target)
51
 
52
- return topology
53
 
54
- @staticmethod
55
- def generate_small_world_network(personas: List[TinyPerson], k: int = 4, p: float = 0.1) -> NetworkTopology:
56
- """
57
- Watts-Strogatz model for small-world networks.
58
- """
59
- topology = NetworkTopology()
60
- for persona in personas:
61
- topology.add_persona(persona)
62
 
63
- names = [p.name for p in personas]
64
- n = len(names)
65
-
66
- # Regular ring lattice
67
- for i in range(n):
68
- for j in range(1, k // 2 + 1):
69
- neighbor = names[(i + j) % n]
70
- topology.add_connection(names[i], neighbor)
71
-
72
- # Rewiring
73
- for i in range(n):
74
  for j in range(1, k // 2 + 1):
75
- if random.random() < p:
76
- # Remove old connection and add a random one
77
- old_neighbor = names[(i + j) % n]
78
- topology.remove_connection(names[i], old_neighbor)
79
- new_neighbor = random.choice(names)
80
- while new_neighbor == names[i] or any(e.source_id == names[i] and e.target_id == new_neighbor for e in topology.edges):
81
- new_neighbor = random.choice(names)
82
- topology.add_connection(names[i], new_neighbor)
83
-
84
- return topology
85
 
86
- @staticmethod
87
- def generate_professional_network(personas: List[TinyPerson]) -> NetworkTopology:
88
- """
89
- LinkedIn-style network based on professional attributes.
90
- """
91
- topology = NetworkTopology()
92
  for p in personas:
93
- topology.add_persona(p)
94
 
95
  for i, p1 in enumerate(personas):
96
- for j in range(i + 1, len(personas)):
97
- p2 = personas[j]
98
- # Probabilistic connection based on similarity
99
- prob = 0.05
100
- if p1.get("occupation") == p2.get("occupation"): prob += 0.2
101
- if p1.get("residence") == p2.get("residence"): prob += 0.1
102
-
103
- if random.random() < prob:
104
- topology.add_connection(p1.name, p2.name, relationship_type="colleague")
105
-
106
- return topology
 
1
+ from typing import List, Dict, Any
2
  import random
3
+ from tinytroupe.social_network import NetworkTopology
4
+ from tinytroupe.agent import TinyPerson
 
5
 
6
  class NetworkGenerator:
7
+ def __init__(self, personas: List[TinyPerson]):
8
+ self.personas = personas
9
+
10
+ def generate_scale_free_network(self, n: int, m: int) -> NetworkTopology:
11
+ """Barabási-Albert model"""
12
+ network = NetworkTopology()
13
+ for p in self.personas:
14
+ network.add_persona(p)
 
 
 
 
 
 
 
 
15
 
16
+ # Simplified BA model
17
+ # For each new node, connect it to m existing nodes with probability proportional to their degree
18
+ # For now, a very simple version
19
+ persona_names = [p.name for p in self.personas]
20
+ for i, name in enumerate(persona_names):
21
+ if i == 0: continue
22
+ # Connect to some random previous nodes
23
+ targets = random.sample(persona_names[:i], min(i, m))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  for target in targets:
25
+ network.add_connection(name, target, strength=random.random())
26
 
27
+ return network
28
 
29
+ def generate_small_world_network(self, n: int, k: int, p: float) -> NetworkTopology:
30
+ """Watts-Strogatz model"""
31
+ network = NetworkTopology()
32
+ for persona in self.personas:
33
+ network.add_persona(persona)
 
 
 
34
 
35
+ # Simplified WS model
36
+ persona_names = [p.name for p in self.personas]
37
+ num_nodes = len(persona_names)
38
+ for i in range(num_nodes):
 
 
 
 
 
 
 
39
  for j in range(1, k // 2 + 1):
40
+ target = persona_names[(i + j) % num_nodes]
41
+ network.add_connection(persona_names[i], target, strength=random.random())
42
+
43
+ # Rewiring...
44
+ return network
45
+
46
+ def generate_community_network(self, num_communities: int, community_sizes: List[int]) -> NetworkTopology:
47
+ network = NetworkTopology()
48
+ # ...
49
+ return network
50
 
51
+ def generate_professional_network(self, personas: List[TinyPerson]) -> NetworkTopology:
52
+ """LinkedIn-style network based on industry, company, role"""
53
+ network = NetworkTopology()
 
 
 
54
  for p in personas:
55
+ network.add_persona(p)
56
 
57
  for i, p1 in enumerate(personas):
58
+ for p2 in personas[i+1:]:
59
+ # Connect if same industry or similar roles
60
+ if p1._persona.get("occupation") == p2._persona.get("occupation"):
61
+ if random.random() < 0.3:
62
+ network.add_connection(p1.name, p2.name, strength=random.random(), relationship_type="colleague")
63
+ return network
 
 
 
 
 
tinytroupe/openai_utils.py CHANGED
@@ -31,6 +31,8 @@ class OpenAIClient:
31
  def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
32
  logger.debug("Initializing OpenAIClient")
33
 
 
 
34
  # should we cache api calls and reuse them?
35
  self.set_api_cache(cache_api_calls, cache_file_name)
36
 
@@ -52,7 +54,8 @@ class OpenAIClient:
52
  """
53
  Sets up the OpenAI API configurations for this client.
54
  """
55
- self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
56
 
57
  @config_manager.config_defaults(
58
  model="model",
@@ -156,14 +159,33 @@ class OpenAIClient:
156
  chat_api_params["response_format"] = response_format
157
 
158
  i = 0
159
- while i < max_attempts:
160
  try:
161
  i += 1
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  try:
164
- logger.debug(f"Sending messages to OpenAI API. Token count={self._count_tokens(current_messages, model)}.")
165
  except NotImplementedError:
166
- logger.debug(f"Token count not implemented for model {model}.")
167
 
168
  start_time = time.monotonic()
169
  logger.debug(f"Calling model with client class {self.__class__.__name__}.")
@@ -171,15 +193,11 @@ class OpenAIClient:
171
  ###############################################################
172
  # call the model, either from the cache or from the API
173
  ###############################################################
174
- cache_key = str((model, chat_api_params)) # need string to be hashable
175
  if self.cache_api_calls and (cache_key in self.api_cache):
176
  response = self.api_cache[cache_key]
177
  else:
178
- if waiting_time > 0:
179
- logger.info(f"Waiting {waiting_time} seconds before next API request (to avoid throttling)...")
180
- time.sleep(waiting_time)
181
-
182
- response = self._raw_model_call(model, chat_api_params)
183
  if self.cache_api_calls:
184
  self.api_cache[cache_key] = response
185
  self._save_cache()
@@ -195,35 +213,21 @@ class OpenAIClient:
195
  else:
196
  return utils.sanitize_dict(self._raw_model_response_extractor(response))
197
 
198
- except InvalidRequestError as e:
199
- logger.error(f"[{i}] Invalid request error, won't retry: {e}")
200
-
201
- # there's no point in retrying if the request is invalid
202
- # so we return None right away
203
- return None
204
-
205
- except openai.BadRequestError as e:
206
  logger.error(f"[{i}] Invalid request error, won't retry: {e}")
207
-
208
- # there's no point in retrying if the request is invalid
209
- # so we return None right away
210
  return None
211
 
212
- except openai.RateLimitError:
213
- logger.warning(
214
- f"[{i}] Rate limit error, waiting a bit and trying again.")
215
- aux_exponential_backoff()
216
-
217
- except NonTerminalError as e:
218
- logger.error(f"[{i}] Non-terminal error: {e}")
219
- aux_exponential_backoff()
220
-
221
- except Exception as e:
222
- logger.error(f"[{i}] {type(e).__name__} Error: {e}")
223
- aux_exponential_backoff()
224
-
225
- logger.error(f"Failed to get response after {max_attempts} attempts.")
226
- return None
227
 
228
  def _raw_model_call(self, model, chat_api_params):
229
  """
@@ -246,8 +250,12 @@ class OpenAIClient:
246
  chat_api_params["reasoning_effort"] = default["reasoning_effort"]
247
 
248
 
249
- # To make the log cleaner, we remove the messages from the logged parameters
250
- logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}
 
 
 
 
251
 
252
  if "response_format" in chat_api_params:
253
  # to enforce the response format via pydantic, we need to use a different method
@@ -396,22 +404,23 @@ class AzureClient(OpenAIClient):
396
  Sets up the Azure OpenAI Service API configurations for this client,
397
  including the API endpoint and key.
398
  """
399
- if os.getenv("AZURE_OPENAI_KEY"):
400
- logger.info("Using Azure OpenAI Service API with key.")
401
- self.client = AzureOpenAI(azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
402
- api_version = config["OpenAI"]["AZURE_API_VERSION"],
403
- api_key = os.getenv("AZURE_OPENAI_KEY"))
404
- else: # Use Entra ID Auth
405
- logger.info("Using Azure OpenAI Service API with Entra ID Auth.")
406
- from azure.identity import DefaultAzureCredential, get_bearer_token_provider
407
-
408
- credential = DefaultAzureCredential()
409
- token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
410
- self.client = AzureOpenAI(
411
- azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
412
- api_version = config["OpenAI"]["AZURE_API_VERSION"],
413
- azure_ad_token_provider=token_provider
414
- )
 
415
 
416
 
417
  class HelmholtzBlabladorClient(OpenAIClient):
@@ -424,10 +433,11 @@ class HelmholtzBlabladorClient(OpenAIClient):
424
  """
425
  Sets up the Helmholtz Blablador API configurations for this client.
426
  """
427
- self.client = OpenAI(
428
- base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
429
- api_key=os.getenv("BLABLADOR_API_KEY", "dummy"),
430
- )
 
431
 
432
  ###########################################################################
433
  # Exceptions
 
31
  def __init__(self, cache_api_calls=default["cache_api_calls"], cache_file_name=default["cache_file_name"]) -> None:
32
  logger.debug("Initializing OpenAIClient")
33
 
34
+ self.client = None
35
+
36
  # should we cache api calls and reuse them?
37
  self.set_api_cache(cache_api_calls, cache_file_name)
38
 
 
54
  """
55
  Sets up the OpenAI API configurations for this client.
56
  """
57
+ if self.client is None:
58
+ self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
59
 
60
  @config_manager.config_defaults(
61
  model="model",
 
159
  chat_api_params["response_format"] = response_format
160
 
161
  i = 0
162
+ while True:
163
  try:
164
  i += 1
165
 
166
+ #
167
+ # Model fallback and retry strategy requested by the user:
168
+ # 1. alias-fast for 3 attempts, 35s wait
169
+ # 2. alias-large for 2 attempts, 35s wait
170
+ # 3. alias-huge until success, 60s wait
171
+ #
172
+ # Model fallback strategy using config
173
+ if i <= 3:
174
+ current_model = config["OpenAI"].get("MODEL", "alias-fast")
175
+ current_wait_time = 35
176
+ elif i <= 5:
177
+ current_model = config["OpenAI"].get("FALLBACK_MODEL_LARGE", "alias-large")
178
+ current_wait_time = 35
179
+ else:
180
+ current_model = config["OpenAI"].get("FALLBACK_MODEL_HUGE", "alias-huge")
181
+ current_wait_time = 60
182
+
183
+ chat_api_params["model"] = current_model
184
+
185
  try:
186
+ logger.debug(f"Sending messages to OpenAI API. Model={current_model}. Token count={self._count_tokens(current_messages, current_model)}.")
187
  except NotImplementedError:
188
+ logger.debug(f"Token count not implemented for model {current_model}.")
189
 
190
  start_time = time.monotonic()
191
  logger.debug(f"Calling model with client class {self.__class__.__name__}.")
 
193
  ###############################################################
194
  # call the model, either from the cache or from the API
195
  ###############################################################
196
+ cache_key = str((current_model, chat_api_params)) # need string to be hashable
197
  if self.cache_api_calls and (cache_key in self.api_cache):
198
  response = self.api_cache[cache_key]
199
  else:
200
+ response = self._raw_model_call(current_model, chat_api_params)
 
 
 
 
201
  if self.cache_api_calls:
202
  self.api_cache[cache_key] = response
203
  self._save_cache()
 
213
  else:
214
  return utils.sanitize_dict(self._raw_model_response_extractor(response))
215
 
216
+ except (InvalidRequestError, openai.BadRequestError) as e:
 
 
 
 
 
 
 
217
  logger.error(f"[{i}] Invalid request error, won't retry: {e}")
 
 
 
218
  return None
219
 
220
+ except (openai.RateLimitError,
221
+ openai.APITimeoutError,
222
+ openai.APIConnectionError,
223
+ openai.InternalServerError,
224
+ NonTerminalError,
225
+ Exception) as e:
226
+ msg = f"[{i}] {type(e).__name__} Error with {current_model}: {e}. Waiting {current_wait_time} seconds before next attempt..."
227
+ logger.warning(msg)
228
+
229
+ time.sleep(current_wait_time)
230
+ continue
 
 
 
 
231
 
232
  def _raw_model_call(self, model, chat_api_params):
233
  """
 
250
  chat_api_params["reasoning_effort"] = default["reasoning_effort"]
251
 
252
 
253
+ # To make the log cleaner, we remove the messages from the logged parameters,
254
+ # unless we are in debug mode
255
+ if logger.getEffectiveLevel() <= logging.DEBUG:
256
+ logged_params = chat_api_params
257
+ else:
258
+ logged_params = {k: v for k, v in chat_api_params.items() if k != "messages"}
259
 
260
  if "response_format" in chat_api_params:
261
  # to enforce the response format via pydantic, we need to use a different method
 
404
  Sets up the Azure OpenAI Service API configurations for this client,
405
  including the API endpoint and key.
406
  """
407
+ if self.client is None:
408
+ if os.getenv("AZURE_OPENAI_KEY"):
409
+ logger.info("Using Azure OpenAI Service API with key.")
410
+ self.client = AzureOpenAI(azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
411
+ api_version = config["OpenAI"]["AZURE_API_VERSION"],
412
+ api_key = os.getenv("AZURE_OPENAI_KEY"))
413
+ else: # Use Entra ID Auth
414
+ logger.info("Using Azure OpenAI Service API with Entra ID Auth.")
415
+ from azure.identity import DefaultAzureCredential, get_bearer_token_provider
416
+
417
+ credential = DefaultAzureCredential()
418
+ token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
419
+ self.client = AzureOpenAI(
420
+ azure_endpoint= os.getenv("AZURE_OPENAI_ENDPOINT"),
421
+ api_version = config["OpenAI"]["AZURE_API_VERSION"],
422
+ azure_ad_token_provider=token_provider
423
+ )
424
 
425
 
426
  class HelmholtzBlabladorClient(OpenAIClient):
 
433
  """
434
  Sets up the Helmholtz Blablador API configurations for this client.
435
  """
436
+ if self.client is None:
437
+ self.client = OpenAI(
438
+ base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
439
+ api_key=os.getenv("BLABLADOR_API_KEY", "dummy"),
440
+ )
441
 
442
  ###########################################################################
443
  # Exceptions
tinytroupe/simulation_manager.py CHANGED
@@ -1,21 +1,21 @@
1
  from typing import List, Dict, Any, Optional
 
 
2
  from datetime import datetime
3
- import hashlib
4
- import json
5
- from tinytroupe.agent.tiny_person import TinyPerson
6
  from tinytroupe.social_network import NetworkTopology
7
- from tinytroupe.environment.social_world import SocialTinyWorld, SimulationResult
8
- from tinytroupe.agent_types import Content
9
  from tinytroupe.ml_models import EngagementPredictor
10
  from tinytroupe.content_generation import ContentVariantGenerator
 
11
 
12
  class SimulationConfig:
13
- def __init__(self, name: str, persona_count: int = 10, network_type: str = "scale_free", use_linkedin_audience: bool = False, linkedin_token: str = None):
14
  self.name = name
15
  self.persona_count = persona_count
16
  self.network_type = network_type
17
- self.use_linkedin_audience = use_linkedin_audience
18
- self.linkedin_token = linkedin_token
19
 
20
  class Simulation:
21
  def __init__(self, id: str, config: SimulationConfig, world: SocialTinyWorld, personas: List[TinyPerson], network: NetworkTopology):
@@ -26,43 +26,162 @@ class Simulation:
26
  self.network = network
27
  self.status = "ready"
28
  self.created_at = datetime.now()
29
- self.last_result = None
 
 
30
 
31
  class SimulationManager:
32
  """Manages simulation lifecycle and execution"""
33
 
34
  def __init__(self):
35
  self.simulations: Dict[str, Simulation] = {}
 
36
  self.predictor = EngagementPredictor()
37
  self.variant_generator = ContentVariantGenerator()
38
 
39
- def create_simulation(self, config: SimulationConfig) -> Simulation:
40
- from tinytroupe.factory.tiny_person_factory import TinyPersonFactory
41
- factory = TinyPersonFactory()
42
- personas = factory.generate_people(number_of_people=config.persona_count)
 
 
 
 
 
 
43
 
44
- from tinytroupe.network_generator import NetworkGenerator
 
45
  if config.network_type == "scale_free":
46
- network = NetworkGenerator.generate_scale_free_network(personas)
47
  else:
48
- network = NetworkGenerator.generate_professional_network(personas)
49
-
50
- world = SocialTinyWorld(config.name, network)
51
- for p in personas: world.add_agent(p)
 
 
52
 
53
- sim_id = hashlib.md5(f"{config.name}{datetime.now()}".encode()).hexdigest()
54
- sim = Simulation(sim_id, config, world, personas, network)
55
- self.simulations[sim_id] = sim
56
- return sim
57
 
58
- def run_simulation(self, sim_id: str, content_text: str) -> SimulationResult:
59
- sim = self.simulations[sim_id]
60
- sim.status = "running"
61
- content = Content(text=content_text, content_type="post", topics=[], length=len(content_text), tone="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- initial_viewers = [p.name for p in sim.personas[:min(5, len(sim.personas))]]
64
- result = sim.world.simulate_content_spread(content, initial_viewers)
65
 
66
- sim.status = "completed"
67
- sim.last_result = result
 
68
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List, Dict, Any, Optional
2
+ import uuid
3
+ import threading
4
  from datetime import datetime
5
+ from tinytroupe.agent import TinyPerson
 
 
6
  from tinytroupe.social_network import NetworkTopology
7
+ from tinytroupe.environment.social_tiny_world import SocialTinyWorld, SimulationResult
8
+ from tinytroupe.agent.social_types import Content
9
  from tinytroupe.ml_models import EngagementPredictor
10
  from tinytroupe.content_generation import ContentVariantGenerator
11
+ from tinytroupe.network_generator import NetworkGenerator
12
 
13
  class SimulationConfig:
14
+ def __init__(self, name: str, persona_count: int = 10, network_type: str = "scale_free", **kwargs):
15
  self.name = name
16
  self.persona_count = persona_count
17
  self.network_type = network_type
18
+ self.user_id = kwargs.get("user_id")
 
19
 
20
  class Simulation:
21
  def __init__(self, id: str, config: SimulationConfig, world: SocialTinyWorld, personas: List[TinyPerson], network: NetworkTopology):
 
26
  self.network = network
27
  self.status = "ready"
28
  self.created_at = datetime.now()
29
+ self.last_result: Optional[SimulationResult] = None
30
+ self.chat_history: List[Dict[str, Any]] = []
31
+ self.progress = 0.0
32
 
33
  class SimulationManager:
34
  """Manages simulation lifecycle and execution"""
35
 
36
  def __init__(self):
37
  self.simulations: Dict[str, Simulation] = {}
38
+ self.focus_groups: Dict[str, List[TinyPerson]] = {}
39
  self.predictor = EngagementPredictor()
40
  self.variant_generator = ContentVariantGenerator()
41
 
42
+ def create_simulation(self, config: SimulationConfig, focus_group_name: str = None) -> Simulation:
43
+ if focus_group_name and focus_group_name in self.focus_groups:
44
+ personas = self.focus_groups[focus_group_name]
45
+ else:
46
+ from tinytroupe.factory.tiny_person_factory import TinyPersonFactory
47
+ factory = TinyPersonFactory(
48
+ context=config.name,
49
+ total_population_size=config.persona_count
50
+ )
51
+ personas = factory.generate_people(number_of_people=config.persona_count)
52
 
53
+ # Generate network
54
+ net_gen = NetworkGenerator(personas)
55
  if config.network_type == "scale_free":
56
+ network = net_gen.generate_scale_free_network(config.persona_count, 2)
57
  else:
58
+ network = net_gen.generate_small_world_network(config.persona_count, 4, 0.1)
59
+
60
+ # Create world
61
+ world = SocialTinyWorld(config.name, network=network)
62
+ for persona in personas:
63
+ world.add_agent(persona)
64
 
65
+ sim_id = str(uuid.uuid4())
66
+ simulation = Simulation(sim_id, config, world, personas, network)
67
+ self.simulations[sim_id] = simulation
68
+ return simulation
69
 
70
+ def run_simulation(self, simulation_id: str, content: Content, mode: str = "full", background: bool = False) -> Optional[SimulationResult]:
71
+ if simulation_id not in self.simulations:
72
+ raise ValueError(f"Simulation {simulation_id} not found.")
73
+
74
+ simulation = self.simulations[simulation_id]
75
+
76
+ if background:
77
+ thread = threading.Thread(target=self._run_simulation_task, args=(simulation, content))
78
+ thread.start()
79
+ return None
80
+ else:
81
+ return self._run_simulation_task(simulation, content)
82
+
83
+ def _run_simulation_task(self, simulation: Simulation, content: Content) -> SimulationResult:
84
+ simulation.status = "running"
85
+ simulation.progress = 0.1
86
+
87
+ initial_viewers = [p.name for p in simulation.personas[:5]] # Seed with first 5
88
 
89
+ # In a real async scenario, simulate_content_spread would update progress
90
+ result = simulation.world.simulate_content_spread(content, initial_viewers)
91
 
92
+ simulation.status = "completed"
93
+ simulation.progress = 1.0
94
+ simulation.last_result = result
95
  return result
96
+
97
+ def send_chat_message(self, simulation_id: str, sender: str, message: str) -> Dict[str, Any]:
98
+ sim = self.get_simulation(simulation_id)
99
+ if not sim: raise ValueError(f"Simulation {simulation_id} not found.")
100
+
101
+ msg = {
102
+ "sender": sender,
103
+ "message": message,
104
+ "timestamp": datetime.now().isoformat()
105
+ }
106
+ sim.chat_history.append(msg)
107
+
108
+ # Trigger persona responses if it's a "User" message
109
+ if sender == "User":
110
+ # For now, pick a random persona to respond
111
+ import random
112
+ responder = random.choice(sim.personas)
113
+ # In a real implementation, the persona would "think" and "act"
114
+ response_text = f"As a {responder._persona.get('occupation')}, I think: {message[:10]}... sounds interesting!"
115
+
116
+ response_msg = {
117
+ "sender": responder.name,
118
+ "message": response_text,
119
+ "timestamp": datetime.now().isoformat()
120
+ }
121
+ sim.chat_history.append(response_msg)
122
+
123
+ return msg
124
+
125
+ def get_chat_history(self, simulation_id: str) -> List[Dict[str, Any]]:
126
+ sim = self.get_simulation(simulation_id)
127
+ if not sim: return []
128
+ return sim.chat_history
129
+
130
+ def get_simulation(self, simulation_id: str, user_id: str = None) -> Optional[Simulation]:
131
+ return self.simulations.get(simulation_id)
132
+
133
+ def list_simulations(self) -> List[Dict[str, Any]]:
134
+ return [
135
+ {
136
+ "id": sim.id,
137
+ "name": sim.config.name,
138
+ "status": sim.status,
139
+ "persona_count": len(sim.personas),
140
+ "created_at": sim.created_at.isoformat()
141
+ }
142
+ for sim in self.simulations.values()
143
+ ]
144
+
145
+ def get_persona(self, simulation_id: str, persona_name: str) -> Optional[Dict[str, Any]]:
146
+ sim = self.get_simulation(simulation_id)
147
+ if not sim: return None
148
+ for p in sim.personas:
149
+ if p.name == persona_name:
150
+ return p._persona
151
+ return None
152
+
153
+ def list_personas(self, simulation_id: str) -> List[Dict[str, Any]]:
154
+ sim = self.get_simulation(simulation_id)
155
+ if not sim: return []
156
+ return [p._persona for p in sim.personas]
157
+
158
+ def save_focus_group(self, name: str, personas: List[TinyPerson]):
159
+ self.focus_groups[name] = personas
160
+
161
+ def list_focus_groups(self) -> List[str]:
162
+ return list(self.focus_groups.keys())
163
+
164
+ def get_focus_group(self, name: str) -> Optional[List[TinyPerson]]:
165
+ return self.focus_groups.get(name)
166
+
167
+ def delete_simulation(self, simulation_id: str) -> bool:
168
+ if simulation_id in self.simulations:
169
+ del self.simulations[simulation_id]
170
+ return True
171
+ return False
172
+
173
+ def export_simulation(self, simulation_id: str) -> Optional[Dict[str, Any]]:
174
+ sim = self.get_simulation(simulation_id)
175
+ if not sim: return None
176
+ return {
177
+ "id": sim.id,
178
+ "config": {
179
+ "name": sim.config.name,
180
+ "persona_count": sim.config.persona_count,
181
+ "network_type": sim.config.network_type
182
+ },
183
+ "status": sim.status,
184
+ "created_at": sim.created_at.isoformat(),
185
+ "personas": [p._persona for p in sim.personas],
186
+ "network": sim.network.get_metrics()
187
+ }
tinytroupe/social_network.py CHANGED
@@ -1,81 +1,72 @@
1
- from dataclasses import dataclass, field
2
- from typing import List, Dict, Optional, Any, Set, Tuple
3
- from datetime import datetime
4
  import numpy as np
5
- from tinytroupe.agent.tiny_person import TinyPerson
6
-
7
- @dataclass
8
- class Connection:
9
- """Represents a connection between two personas"""
10
- source_id: str
11
- target_id: str
12
- strength: float = 0.5 # 0.0-1.0
13
- relationship_type: str = "follower" # "follower", "friend", "colleague", "family"
14
- interaction_frequency: float = 0.0 # interactions per week
15
- last_interaction: Optional[datetime] = None
16
- influence_score: float = 0.0 # how much target influences source
17
- created_at: datetime = field(default_factory=datetime.now)
18
 
19
- @dataclass
20
  class Community:
21
  """Represents a cluster of closely connected personas"""
22
- community_id: str
23
- members: List[str] # persona_ids
24
- density: float = 0.0
25
- central_personas: List[str] = field(default_factory=list) # most influential in community
26
- shared_interests: List[str] = field(default_factory=list)
27
- avg_engagement_rate: float = 0.0
 
28
 
29
  class NetworkTopology:
30
  """Represents the entire social network structure"""
31
  def __init__(self):
32
- self.nodes: Dict[str, TinyPerson] = {} # persona_id -> persona
33
- self.edges: List[Connection] = []
34
  self.adjacency_matrix: Optional[np.ndarray] = None
35
  self.influence_matrix: Optional[np.ndarray] = None
36
  self.communities: List[Community] = []
37
-
38
  def add_persona(self, persona: TinyPerson) -> None:
39
  self.nodes[persona.name] = persona
40
-
41
- def add_connection(self, source_id: str, target_id: str, **kwargs) -> Connection:
42
- conn = Connection(source_id=source_id, target_id=target_id, **kwargs)
43
- self.edges.append(conn)
44
- return conn
 
 
 
 
 
 
45
 
46
  def remove_connection(self, source_id: str, target_id: str) -> None:
47
- self.edges = [e for e in self.edges if not (e.source_id == source_id and e.target_id == target_id)]
 
 
48
 
49
  def get_neighbors(self, persona_id: str, depth: int = 1) -> List[TinyPerson]:
50
- # Simple BFS for neighbors
51
- neighbors = set()
52
- queue = [(persona_id, 0)]
53
- visited = {persona_id}
54
 
55
- while queue:
56
- curr_id, curr_depth = queue.pop(0)
57
- if curr_depth >= depth: continue
 
58
 
59
- for edge in self.edges:
60
- if edge.source_id == curr_id and edge.target_id not in visited:
61
- neighbors.add(edge.target_id)
62
- visited.add(edge.target_id)
63
- queue.append((edge.target_id, curr_depth + 1))
64
- elif edge.target_id == curr_id and edge.source_id not in visited:
65
- neighbors.add(edge.source_id)
66
- visited.add(edge.source_id)
67
- queue.append((edge.source_id, curr_depth + 1))
68
-
69
- return [self.nodes[nid] for nid in neighbors if nid in self.nodes]
70
 
71
  def calculate_centrality_metrics(self) -> Dict[str, float]:
72
- # Placeholder for real centrality (e.g. using NetworkX in analysis module)
73
- metrics = {name: 0.0 for name in self.nodes}
74
- for edge in self.edges:
75
- metrics[edge.source_id] += 1
76
- metrics[edge.target_id] += 1
77
- return metrics
78
 
79
  def detect_communities(self) -> List[Community]:
80
- # Placeholder
81
  return self.communities
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Any, Set, Tuple
 
 
2
  import numpy as np
3
+ from datetime import datetime
4
+ from tinytroupe.agent import TinyPerson
5
+ from tinytroupe.agent.social_types import ConnectionEdge
 
 
 
 
 
 
 
 
 
 
6
 
 
7
  class Community:
8
  """Represents a cluster of closely connected personas"""
9
+ def __init__(self, community_id: str, members: List[str]):
10
+ self.community_id = community_id
11
+ self.members = members
12
+ self.density: float = 0.0
13
+ self.central_personas: List[str] = []
14
+ self.shared_interests: List[str] = []
15
+ self.avg_engagement_rate: float = 0.0
16
 
17
  class NetworkTopology:
18
  """Represents the entire social network structure"""
19
  def __init__(self):
20
+ self.nodes: Dict[str, TinyPerson] = {} # persona_id -> persona
21
+ self.edges: List[ConnectionEdge] = []
22
  self.adjacency_matrix: Optional[np.ndarray] = None
23
  self.influence_matrix: Optional[np.ndarray] = None
24
  self.communities: List[Community] = []
25
+
26
  def add_persona(self, persona: TinyPerson) -> None:
27
  self.nodes[persona.name] = persona
28
+ # Update adjacency matrix if necessary
29
+
30
+ def add_connection(self, source_id: str, target_id: str, **kwargs) -> ConnectionEdge:
31
+ connection = ConnectionEdge(connection_id=f"{source_id}_{target_id}", **kwargs)
32
+ self.edges.append(connection)
33
+
34
+ # Also update the persona's internal social_connections
35
+ if source_id in self.nodes:
36
+ self.nodes[source_id].social_connections[target_id] = connection
37
+
38
+ return connection
39
 
40
  def remove_connection(self, source_id: str, target_id: str) -> None:
41
+ self.edges = [e for e in self.edges if not (e.connection_id == f"{source_id}_{target_id}")]
42
+ if source_id in self.nodes:
43
+ self.nodes[source_id].social_connections.pop(target_id, None)
44
 
45
  def get_neighbors(self, persona_id: str, depth: int = 1) -> List[TinyPerson]:
46
+ if depth <= 0: return []
 
 
 
47
 
48
+ neighbors = []
49
+ if persona_id in self.nodes:
50
+ neighbor_ids = list(self.nodes[persona_id].social_connections.keys())
51
+ neighbors = [self.nodes[nid] for nid in neighbor_ids if nid in self.nodes]
52
 
53
+ if depth > 1:
54
+ for nid in neighbor_ids:
55
+ neighbors.extend(self.get_neighbors(nid, depth - 1))
56
+
57
+ return list(set(neighbors))
 
 
 
 
 
 
58
 
59
  def calculate_centrality_metrics(self) -> Dict[str, float]:
60
+ # Placeholder for centrality calculation
61
+ return {name: 0.0 for name in self.nodes}
 
 
 
 
62
 
63
  def detect_communities(self) -> List[Community]:
64
+ # Placeholder for community detection
65
  return self.communities
66
+
67
+ def get_metrics(self) -> Dict[str, Any]:
68
+ return {
69
+ "num_nodes": len(self.nodes),
70
+ "num_edges": len(self.edges),
71
+ "density": len(self.edges) / (len(self.nodes) * (len(self.nodes) - 1)) if len(self.nodes) > 1 else 0
72
+ }
tinytroupe/utils/llm.py CHANGED
@@ -721,7 +721,7 @@ class LLMChat:
721
 
722
  def _request_list_of_dict_llm_message(self):
723
  return {"role": "user",
724
- "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\{...\}, \{...\}, ...]`. This is critical for later processing."}
725
 
726
  def _coerce_to_list(self, llm_output:str):
727
  """
 
721
 
722
  def _request_list_of_dict_llm_message(self):
723
  return {"role": "user",
724
+ "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\\{...\\}, \\{...\\}, ...]`. This is critical for later processing."}
725
 
726
  def _coerce_to_list(self, llm_output:str):
727
  """
tinytroupe/utils/semantics.py CHANGED
@@ -265,3 +265,24 @@ def compute_semantic_proximity(text1: str, text2: str, context: str = None) -> f
265
  """
266
  # llm decorator will handle the body of this function
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  """
266
  # llm decorator will handle the body of this function
267
 
268
+ @llm()
269
+ def select_best_persona(criteria: str, personas: list) -> int:
270
+ """
271
+ Given a set of criteria and a list of personas (each a dictionary),
272
+ select the index of the persona that best matches the criteria.
273
+ If no persona matches at all, return -1.
274
+
275
+ Rules:
276
+ - You must analyze each persona against the criteria.
277
+ - Return ONLY the integer index (starting from 0) of the best matching persona.
278
+ - Do not provide any explanation, just the number.
279
+ - If there are multiple good matches, pick the best one.
280
+
281
+ Args:
282
+ criteria (str): The search criteria or description of the desired persona.
283
+ personas (list): A list of dictionaries, where each dictionary is a persona specification.
284
+
285
+ Returns:
286
+ int: The index of the best matching persona, or -1 if none match.
287
+ """
288
+ # llm decorator will handle the body of this function
tinytroupe/variant_optimizer.py CHANGED
@@ -1,16 +1,15 @@
1
  from typing import List, Dict, Any
2
  import numpy as np
3
- from dataclasses import dataclass
4
  from tinytroupe.content_generation import ContentVariant
5
- from tinytroupe.agent.tiny_person import TinyPerson
 
6
  from tinytroupe.social_network import NetworkTopology
7
  from tinytroupe.ml_models import EngagementPredictor
8
 
9
- @dataclass
10
  class RankedVariant:
11
- variant: ContentVariant
12
- score: float
13
- predicted_engagement_count: int
14
 
15
  class VariantOptimizer:
16
  """Optimize and rank content variants"""
@@ -18,26 +17,20 @@ class VariantOptimizer:
18
  def __init__(self, predictor: EngagementPredictor):
19
  self.predictor = predictor
20
 
21
- def rank_variants(self, variants: List[ContentVariant],
22
- target_personas: List[TinyPerson],
23
- network: NetworkTopology) -> List[RankedVariant]:
24
  """Rank variants by predicted performance"""
25
  ranked = []
26
  for variant in variants:
27
- probs = []
28
- from tinytroupe.agent_types import Content
29
- content_obj = Content(text=variant.text, content_type="article", topics=[], length=len(variant.text), tone="")
30
-
31
  for persona in target_personas:
32
- prob = self.predictor.predict(persona, content_obj, network)
33
- probs.append(prob)
34
 
35
- avg_prob = np.mean(probs) if probs else 0.0
36
- ranked.append(RankedVariant(
37
- variant=variant,
38
- score=avg_prob,
39
- predicted_engagement_count=int(sum(probs))
40
- ))
41
 
42
  ranked.sort(key=lambda x: x.score, reverse=True)
43
  return ranked
 
1
  from typing import List, Dict, Any
2
  import numpy as np
 
3
  from tinytroupe.content_generation import ContentVariant
4
+ from tinytroupe.agent.social_types import Content
5
+ from tinytroupe.agent import TinyPerson
6
  from tinytroupe.social_network import NetworkTopology
7
  from tinytroupe.ml_models import EngagementPredictor
8
 
 
9
  class RankedVariant:
10
+ def __init__(self, variant: ContentVariant, score: float):
11
+ self.variant = variant
12
+ self.score = score
13
 
14
  class VariantOptimizer:
15
  """Optimize and rank content variants"""
 
17
  def __init__(self, predictor: EngagementPredictor):
18
  self.predictor = predictor
19
 
20
+ def rank_variants_for_audience(self, variants: List[ContentVariant],
21
+ target_personas: List[TinyPerson],
22
+ network: NetworkTopology) -> List[RankedVariant]:
23
  """Rank variants by predicted performance"""
24
  ranked = []
25
  for variant in variants:
26
+ # Predict engagement for each persona
27
+ scores = []
 
 
28
  for persona in target_personas:
29
+ prob = self.predictor.predict(persona, Content(text=variant.text), network)
30
+ scores.append(prob)
31
 
32
+ avg_score = np.mean(scores) if scores else 0.0
33
+ ranked.append(RankedVariant(variant, avg_score))
 
 
 
 
34
 
35
  ranked.sort(key=lambda x: x.score, reverse=True)
36
  return ranked