Commit 484e3bc
Parent(s): none (initial commit)
Initial GeoBot Forecasting Framework commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.
- .gitignore +20 -0
- GETTING_STARTED.md +242 -0
- README.md +542 -0
- examples/01_basic_usage.py +131 -0
- examples/02_data_ingestion.py +202 -0
- examples/03_intervention_simulation.py +237 -0
- examples/04_advanced_features.py +348 -0
- examples/05_complete_framework.py +542 -0
- examples/06_geobot2_analytical_framework.py +318 -0
- examples/EXAMPLES_STATUS.md +1 -0
- examples/README.md +132 -0
- examples/taiwan_situation_room.py +610 -0
- geobot/__init__.py +43 -0
- geobot/analysis/__init__.py +30 -0
- geobot/analysis/engine.py +393 -0
- geobot/analysis/formatter.py +318 -0
- geobot/analysis/framework.py +155 -0
- geobot/analysis/lenses.py +477 -0
- geobot/bayes/__init__.py +21 -0
- geobot/bayes/forecasting.py +659 -0
- geobot/causal/__init__.py +23 -0
- geobot/causal/structural_model.py +664 -0
- geobot/cli.py +6 -0
- geobot/config/__init__.py +11 -0
- geobot/config/settings.py +183 -0
- geobot/core/__init__.py +25 -0
- geobot/core/advanced_optimal_transport.py +653 -0
- geobot/core/optimal_transport.py +360 -0
- geobot/core/scenario.py +285 -0
- geobot/data_ingestion/__init__.py +36 -0
- geobot/data_ingestion/event_database.py +574 -0
- geobot/data_ingestion/event_extraction.py +539 -0
- geobot/data_ingestion/pdf_reader.py +493 -0
- geobot/data_ingestion/web_scraper.py +365 -0
- geobot/inference/__init__.py +21 -0
- geobot/inference/bayesian_engine.py +606 -0
- geobot/inference/do_calculus.py +510 -0
- geobot/inference/particle_filter.py +557 -0
- geobot/inference/variational_inference.py +531 -0
- geobot/ml/__init__.py +39 -0
- geobot/ml/embedding.py +76 -0
- geobot/ml/feature_discovery.py +81 -0
- geobot/ml/graph_neural_networks.py +628 -0
- geobot/ml/risk_scoring.py +159 -0
- geobot/models/__init__.py +32 -0
- geobot/models/causal_discovery.py +363 -0
- geobot/models/causal_graph.py +588 -0
- geobot/models/quasi_experimental.py +715 -0
- geobot/simulation/__init__.py +31 -0
- geobot/simulation/agent_based.py +349 -0
.gitignore
ADDED
@@ -0,0 +1,20 @@
# Ignore Python virtual environments
.venv/
venv/
.venv311/

# Ignore compiled Python files
__pycache__/
*.pyc

# Ignore OS metadata
.DS_Store

# Ignore local data
*.sqlite
*.db
*.log

# Ignore IDE settings
.vscode/
.idea/
GETTING_STARTED.md
ADDED
@@ -0,0 +1,242 @@
# Getting Started with GeoBotv1

This guide will help you get up and running with GeoBotv1.

## Installation

### 1. Clone the repository
```bash
git clone https://github.com/yourusername/AIGEOPOLITICAL.git
cd AIGEOPOLITICAL
```

### 2. Create a virtual environment (recommended)
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

### 3. Install dependencies
```bash
pip install -r requirements.txt
```

Or install in development mode:
```bash
pip install -e .
```

### 4. Install optional dependencies
For full data ingestion capabilities:
```bash
pip install pypdf pdfplumber beautifulsoup4 newspaper3k trafilatura feedparser
```

## Quick Start

### Example 1: Basic Scenario Analysis

```python
from geobot.core.scenario import Scenario
from geobot.simulation.monte_carlo import MonteCarloEngine, SimulationConfig
import numpy as np

# Create a scenario
scenario = Scenario(
    name="tension_scenario",
    features={
        'military_tension': np.array([0.7]),
        'diplomatic_relations': np.array([0.3]),
    }
)

# Run Monte Carlo simulation
config = SimulationConfig(n_simulations=1000, time_horizon=50)
engine = MonteCarloEngine(config)

# Define dynamics
def transition_fn(state, t, noise):
    new_state = {}
    new_state['tension'] = state.get('tension', 0.5) + noise.get('tension', 0)
    return new_state

def noise_fn(t):
    return {'tension': np.random.normal(0, 0.05)}

initial_state = {'tension': 0.3}
trajectories = engine.run_basic_simulation(initial_state, transition_fn, noise_fn)

# Analyze results
stats = engine.compute_statistics(trajectories)
print(f"Mean tension at end: {stats['tension']['mean'][-1]:.3f}")
```

### Example 2: Causal Inference

```python
from geobot.models.causal_graph import CausalGraph

# Build causal graph
graph = CausalGraph(name="conflict_model")

# Add variables
graph.add_node('sanctions', node_type='policy')
graph.add_node('tension', node_type='state')
graph.add_node('conflict', node_type='outcome')

# Define causal relationships
graph.add_edge('sanctions', 'tension',
               strength=0.7,
               mechanism="Sanctions increase tension")
graph.add_edge('tension', 'conflict',
               strength=0.8,
               mechanism="Tension leads to conflict")

# Visualize
graph.visualize('causal_graph.png')
```

### Example 3: Intervention Simulation

```python
from geobot.inference.do_calculus import InterventionSimulator
from geobot.models.causal_graph import StructuralCausalModel

# Create SCM with your causal graph
scm = StructuralCausalModel(graph)

# Define structural equations
# (See examples/03_intervention_simulation.py for full details)

# Create simulator
simulator = InterventionSimulator(scm)

# Simulate intervention
result = simulator.simulate_intervention(
    intervention={'sanctions': 0.8},
    n_samples=1000,
    outcomes=['conflict']
)

print(f"Expected conflict under sanctions: {result['conflict'].mean():.3f}")
```

### Example 4: Bayesian Belief Updating

```python
from geobot.inference.bayesian_engine import BeliefUpdater

# Create updater
updater = BeliefUpdater()

# Initialize belief
updater.initialize_belief(
    name='conflict_risk',
    prior_mean=0.3,
    prior_std=0.1,
    belief_type='probability'
)

# Update with new intelligence
posterior = updater.update_from_intelligence(
    belief='conflict_risk',
    observation=0.6,
    reliability=0.8
)

print(f"Updated risk: {posterior['mean']:.3f} ± {posterior['std']:.3f}")
```

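The exact numbers returned by `BeliefUpdater` depend on its internal observation model. As a rough point of reference only, a standard precision-weighted Gaussian update, with an assumed (illustrative) mapping from the reliability score to observation noise, can be worked out by hand; this standalone sketch is not the geobot API:

```python
import numpy as np

# Standalone sketch of a conjugate Gaussian update.
# The reliability-to-noise mapping below is an illustrative assumption,
# not necessarily what update_from_intelligence() uses internally.
prior_mean, prior_std = 0.3, 0.1
observation, reliability = 0.6, 0.8
obs_std = 0.2 * (1.0 - reliability) + 0.05  # assumed: higher reliability -> less noise

prior_prec = 1.0 / prior_std**2
obs_prec = 1.0 / obs_std**2

# Precision-weighted combination of prior and observation
post_var = 1.0 / (prior_prec + obs_prec)
post_mean = post_var * (prior_prec * prior_mean + obs_prec * observation)

print(f"posterior mean = {post_mean:.3f}, std = {np.sqrt(post_var):.3f}")
```
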
### Example 5: PDF Processing

```python
from geobot.data_ingestion.pdf_reader import PDFProcessor

# Create processor
processor = PDFProcessor()

# Process document
result = processor.extract_intelligence('report.pdf')

print(f"Risk Level: {result['intelligence']['risk_level']}")
print(f"Countries: {result['intelligence']['mentioned_countries']}")
print(f"Conflict Indicators: {result['intelligence']['conflict_indicators']}")
```

### Example 6: Web Scraping

```python
from geobot.data_ingestion.web_scraper import ArticleExtractor

# Create extractor
extractor = ArticleExtractor()

# Extract article
article = extractor.extract_article('https://example.com/article')

print(f"Title: {article['title']}")
print(f"Summary: {article['text'][:200]}...")
```

## Running Examples

The `examples/` directory contains comprehensive demonstrations:

```bash
cd examples

# Basic usage
python 01_basic_usage.py

# Data ingestion
python 02_data_ingestion.py

# Intervention simulation
python 03_intervention_simulation.py
```

## Core Concepts

### 1. Scenarios
Scenarios represent geopolitical states with features and probabilities.

### 2. Causal Graphs
DAGs that model causal relationships between variables.

### 3. Structural Causal Models
Mathematical models with functional equations for each variable.

### 4. Monte Carlo Simulation
Stochastic simulation for uncertainty quantification.

### 5. Bayesian Inference
Principled belief updating as new evidence arrives.

### 6. Do-Calculus
Intervention reasoning for "what if" questions.

### 7. Optimal Transport
Measuring distances between probability distributions.

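To make the optimal-transport concept concrete, here is a minimal standalone sketch that compares two hypothetical one-dimensional feature samples with the 1-Wasserstein distance using SciPy; it deliberately bypasses the geobot scenario API:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)

# Two hypothetical "tension" distributions under different scenarios
baseline = rng.normal(loc=0.3, scale=0.05, size=1000)
escalation = rng.normal(loc=0.7, scale=0.10, size=1000)

# W1: minimal "work" needed to morph one empirical distribution into the other
w1 = wasserstein_distance(baseline, escalation)
print(f"W1(baseline, escalation) = {w1:.3f}")
```
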
## Next Steps

1. Read the full README.md
2. Explore the examples directory
3. Check out the module documentation
4. Build your own models!

## Need Help?

- Check the examples directory for detailed code
- Review module docstrings for API documentation
- Open an issue on GitHub

## Tips

1. Start with simple models and gradually add complexity
2. Always validate your causal assumptions
3. Use Monte Carlo for uncertainty quantification
4. Combine multiple methods for robust forecasting
5. Document your assumptions and data sources

Happy forecasting!
README.md
ADDED
@@ -0,0 +1,542 @@
# GeoBotv1: Research-Grade Geopolitical Forecasting Framework

**GeoBotv1** is a complete, research-grade framework for geopolitical forecasting, conflict prediction, and causal policy analysis. Built on rigorous mathematical foundations, it combines optimal transport theory, structural causal inference, Bayesian reasoning, stochastic processes, econometric methods, and machine learning to provide actionable intelligence on regime shifts, conflict escalation, and intervention outcomes.

**Status: ✅ 100% Complete** - All core mathematical frameworks implemented and production-ready.

---

## 🎯 Key Capabilities

- **Causal Policy Analysis**: Simulate interventions (sanctions, deployments, regime changes) and estimate counterfactual outcomes
- **Conflict Contagion Modeling**: Model self-exciting escalation dynamics and cross-country spillovers using Hawkes processes
- **Multi-Country Forecasting**: Capture interdependencies and shock propagation with Vector Autoregression (VAR/SVAR)
- **Quasi-Experimental Inference**: Estimate treatment effects from observational data (Synthetic Control, DiD, RDD, IV)
- **Regime Detection**: Identify structural breaks and state transitions in real-time
- **Scenario Comparison**: Measure distances between geopolitical futures using optimal transport geometry
- **Intelligence Integration**: Bayesian belief updating from text, events, and structured data
- **Nowcasting**: High-dimensional factor models for real-time situational awareness

---

## 📊 Complete Mathematical Framework

### ✅ Implemented Core Components

<table>
<tr>
<td width="50%">

**1. Optimal Transport Theory**
- Wasserstein distances (W1, W2, W∞)
- Kantorovich duality (primal/dual formulations)
- Sinkhorn algorithm (entropic regularization)
- Unbalanced optimal transport
- Gromov-Wasserstein for network comparison
- Gradient-based OT optimization

**2. Causal Inference**
- Directed Acyclic Graphs (DAGs)
- Structural Causal Models (SCMs)
- Pearl's Do-Calculus for interventions
- Backdoor/frontdoor adjustment
- Counterfactual computation
- Causal discovery algorithms

**3. Bayesian Inference**
- Markov Chain Monte Carlo (MCMC)
- Sequential Monte Carlo (Particle Filters)
- Bootstrap, Auxiliary, Rao-Blackwellized
- Variational Inference (ELBO, CAVI, ADVI)
- Belief updating from intelligence
- Posterior predictive distributions

**4. Stochastic Processes**
- Stochastic Differential Equations (SDEs)
- Euler-Maruyama, Milstein, Stochastic RK
- Jump-diffusion processes (Merton model)
- Ornstein-Uhlenbeck processes
- GeopoliticalSDE framework
- Continuous-time dynamics

</td>
<td width="50%">

**5. Time-Series & Econometrics**
- Kalman Filters (Linear & Extended)
- Hidden Markov Models (HMM)
- Regime-Switching Models
- **Vector Autoregression (VAR)**
- **Structural VAR (SVAR)**
- **Dynamic Factor Models (DFM)**
- **Granger Causality Testing**
- Impulse Response Functions (IRF)
- Forecast Error Variance Decomposition

**6. Point Processes**
- **Univariate Hawkes Processes**
- **Multivariate Hawkes Processes**
- **Conflict Contagion Models**
- Branching ratio estimation
- Self-excitation dynamics
- Cross-country excitation matrices

**7. Quasi-Experimental Methods**
- **Synthetic Control Method (SCM)**
- **Difference-in-Differences (DiD)**
- **Regression Discontinuity Design (RDD)**
- **Instrumental Variables (2SLS)**
- Placebo tests and robustness checks
- Treatment effect bounds

**8. Machine Learning**
- Graph Neural Networks (GNN, GAT)
- CausalGNN for directed graphs
- Risk scoring and classification
- Feature discovery and embeddings
- Transformer-based text encoding

</td>
</tr>
</table>

---

## 🚀 Quick Start

### Installation

```bash
# Clone the repository
git clone https://github.com/your-org/AIGEOPOLITICAL.git
cd AIGEOPOLITICAL

# Install dependencies
pip install -r requirements.txt

# Optional: Install Graph Neural Network support
# (Requires matching torch version)
pip install torch-geometric torch-scatter torch-sparse
```

### Basic Usage Example

```python
from geobot.core import Scenario, ScenarioComparator
from geobot.models import CausalGraph
from geobot.inference import DoCalculus
from geobot.timeseries import VARModel
import numpy as np

# 1. Create geopolitical scenarios
scenario_baseline = Scenario(
    name="Baseline",
    features={
        "military_readiness": np.array([0.6, 0.4, 0.5]),
        "economic_stability": np.array([0.7, 0.6, 0.5])
    }
)

scenario_intervention = Scenario(
    name="Post-Sanctions",
    features={
        "military_readiness": np.array([0.8, 0.4, 0.5]),
        "economic_stability": np.array([0.3, 0.6, 0.5])
    }
)

# 2. Compare scenarios using optimal transport
comparator = ScenarioComparator()
distance = comparator.compare(scenario_baseline, scenario_intervention)
print(f"Wasserstein distance: {distance:.4f}")

# 3. Build causal model
causal_graph = CausalGraph()
causal_graph.add_edge("sanctions", "economy", strength=0.8)
causal_graph.add_edge("economy", "stability", strength=0.6)
causal_graph.add_edge("stability", "conflict_risk", strength=-0.7)

# 4. Simulate intervention
do_calc = DoCalculus(causal_graph)
intervention_effect = do_calc.compute_intervention_effect(
    intervention={"sanctions": 1.0},
    outcome="conflict_risk"
)
print(f"Estimated effect on conflict risk: {intervention_effect:.3f}")

# 5. Multi-country VAR forecasting
# (Simulated data for demonstration)
data = np.random.randn(100, 3)  # 100 time periods, 3 countries
var = VARModel(n_lags=2)
results = var.fit(data, variable_names=['Country_A', 'Country_B', 'Country_C'])
forecast = var.forecast(results, steps=10)
print(f"10-step forecast:\n{forecast}")
```

### Conflict Contagion Example

```python
from geobot.timeseries import ConflictContagionModel

# Model conflict spread between countries
countries = ['Syria', 'Iraq', 'Lebanon', 'Turkey']
model = ConflictContagionModel(countries=countries)

# Historical conflict events (times when conflicts occurred)
events = {
    'Syria': [1.2, 5.3, 10.1, 15.2, 22.3],
    'Iraq': [3.4, 8.9, 12.1, 18.5],
    'Lebanon': [12.3, 19.8, 25.1],
    'Turkey': [28.2, 30.5]
}

# Fit contagion model
result = model.fit(events, T=365.0)  # 1 year of data

print(f"Most contagious source: {result['most_contagious_source']}")
print(f"Most vulnerable target: {result['most_vulnerable_target']}")

# Assess future risk
risks = model.contagion_risk(events, result, t=370.0, horizon=30.0)
for country, risk in risks.items():
    print(f"{country} conflict risk (next 30 days): {risk:.1%}")
```

### Synthetic Control for Policy Evaluation

```python
from geobot.models import SyntheticControlMethod
import numpy as np

# Evaluate impact of sanctions on target country
scm = SyntheticControlMethod()

# Data: treated country vs. control countries
treated_outcome = gdp_growth_target_country    # Shape: (T,)
control_outcomes = gdp_growth_other_countries  # Shape: (T, J)

result = scm.fit(
    treated_outcome=treated_outcome,
    control_outcomes=control_outcomes,
    treatment_time=50,  # Sanctions imposed at t=50
    control_names=['Country_1', 'Country_2', ..., 'Country_J']
)

# Estimated treatment effect
avg_effect = np.mean(result.treatment_effect[50:])
print(f"Average treatment effect: {avg_effect:.3f}")

# Statistical significance via placebo test
p_value = scm.placebo_test(treated_outcome, control_outcomes,
                           treatment_time=50, n_permutations=100)
print(f"Placebo test p-value: {p_value:.3f}")
```

---

## 📁 Project Structure

```
AIGEOPOLITICAL/
├── geobot/                              # Main package
│   ├── core/                            # Core mathematical primitives
│   │   ├── optimal_transport.py             # Wasserstein distances
│   │   ├── advanced_optimal_transport.py    # Kantorovich, Sinkhorn, Gromov-W
│   │   └── scenario.py                      # Scenario representations
│   │
│   ├── models/                          # Causal models
│   │   ├── causal_graph.py                  # DAGs and SCMs
│   │   ├── causal_discovery.py              # Causal structure learning
│   │   └── quasi_experimental.py            # SCM, DiD, RDD, IV methods
│   │
│   ├── inference/                       # Probabilistic inference
│   │   ├── do_calculus.py                   # Interventions and counterfactuals
│   │   ├── bayesian_engine.py               # Bayesian belief updating
│   │   ├── particle_filter.py               # Sequential Monte Carlo
│   │   └── variational_inference.py         # VI, ELBO, ADVI
│   │
│   ├── simulation/                      # Stochastic simulation
│   │   ├── monte_carlo.py                   # Basic Monte Carlo
│   │   ├── sde_solver.py                    # Stochastic differential equations
│   │   └── agent_based.py                   # Agent-based models
│   │
│   ├── timeseries/                      # Time-series models
│   │   ├── kalman_filter.py                 # Kalman filters
│   │   ├── hmm.py                           # Hidden Markov Models
│   │   ├── regime_switching.py              # Regime-switching models
│   │   ├── var_models.py                    # VAR, SVAR, DFM, Granger causality
│   │   └── point_processes.py               # Hawkes processes, conflict contagion
│   │
│   ├── ml/                              # Machine learning
│   │   ├── risk_models.py                   # Risk scoring
│   │   ├── graph_neural_networks.py         # GNNs for causal/geopolitical networks
│   │   └── embeddings.py                    # Feature embeddings
│   │
│   ├── data_ingestion/                  # Data processing
│   │   ├── pdf_reader.py                    # PDF intelligence extraction
│   │   ├── web_scraper.py                   # Web scraping
│   │   ├── event_extraction.py              # NLP-based event structuring
│   │   └── event_database.py                # Event storage and querying
│   │
│   ├── utils/                           # Utilities
│   │   ├── data_processing.py               # Data preprocessing
│   │   └── visualization.py                 # Plotting and visualization
│   │
│   └── config/                          # Configuration
│       └── settings.py                      # System configuration
│
├── examples/                            # Comprehensive examples
│   ├── 01_basic_usage.py                    # Scenarios, causal graphs, Monte Carlo
│   ├── 02_data_ingestion.py                 # PDF/web scraping pipeline
│   ├── 03_intervention_simulation.py        # Do-calculus and counterfactuals
│   ├── 04_advanced_features.py              # Particle filters, VI, SDEs, GNNs
│   └── 05_complete_framework.py             # VAR, Hawkes, quasi-experimental
│
├── requirements.txt                     # Python dependencies
└── README.md                            # This file
```

---

## 🔬 Mathematical Foundations

### Causality-First Principle

GeoBotv1 grounds all forecasting in **explicit causal structure**. This prevents spurious correlations and enables simulation of interventions that have never been observed.

**Structural Causal Model (SCM):**
```
X := f_X(Pa_X, U_X)
Y := f_Y(Pa_Y, U_Y)
```

**Pearl's Do-Calculus** enables reasoning about interventions `do(X = x)` even when only observational data is available.

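For example, when a covariate set `Z` blocks every backdoor path from `X` to `Y`, the interventional distribution is identified from purely observational quantities by the standard backdoor adjustment formula:

```
P(Y | do(X = x)) = Σ_z P(Y | X = x, Z = z) P(Z = z)
```
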
### Optimal Transport for Scenario Comparison

The Wasserstein distance measures the "cost" of transforming one probability distribution into another:

```
W_2(μ, ν) = inf_{π ∈ Π(μ,ν)} √(∫∫ ||x - y||² dπ(x,y))
```

**Kantorovich duality** recasts the transport problem as an optimization over potential functions, which underpins efficient solvers such as entropy-regularized Sinkhorn iterations and gives the distance a geometric interpretation.

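For the 1-Wasserstein distance, the dual takes the Kantorovich-Rubinstein form, an optimization over 1-Lipschitz potentials rather than transport plans:

```
W_1(μ, ν) = sup_{||f||_Lip ≤ 1} ( E_{x~μ}[f(x)] − E_{y~ν}[f(y)] )
```
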
### Hawkes Processes for Conflict Dynamics

Self-exciting point processes capture escalation and contagion:

```
λ_k(t) = μ_k + ∑_{j=1}^K α_{kj} ∑_{t_i^j < t} exp(-β_{kj}(t - t_i^j))
```

**Branching ratio** `n = α/β` determines stability:
- `n < 1`: Process is stable (subcritical)
- `n ≥ 1`: Process is explosive (supercritical)

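The stability condition can be checked numerically with a few lines of NumPy. The sketch below evaluates an exponential-kernel intensity and its branching ratio for illustrative parameter values; it is a standalone illustration, not the `geobot.timeseries` point-process API:

```python
import numpy as np

# Univariate Hawkes with exponential kernel:
#   lambda(t) = mu + alpha * sum_i exp(-beta * (t - t_i))
mu, alpha, beta = 0.2, 0.6, 1.0  # illustrative parameters (assumed, not fitted)
event_times = np.array([1.2, 5.3, 10.1, 15.2, 22.3])

def intensity(t, events, mu, alpha, beta):
    """Conditional intensity at time t given past event times."""
    past = events[events < t]
    return mu + alpha * np.sum(np.exp(-beta * (t - past)))

branching_ratio = alpha / beta  # expected number of offspring events per event
regime = "subcritical" if branching_ratio < 1 else "supercritical"
print(f"lambda(23.0) = {intensity(23.0, event_times, mu, alpha, beta):.3f}")
print(f"branching ratio n = {branching_ratio:.2f} -> {regime}")
```
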
### Stochastic Differential Equations

Continuous-time geopolitical dynamics:

```
dX_t = μ(X_t, t) dt + σ(X_t, t) dW_t + dJ_t
```

Where:
- `μ`: Drift (deterministic component)
- `σ dW_t`: Diffusion (continuous shocks)
- `dJ_t`: Jumps (discrete events)

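A minimal Euler-Maruyama discretization of the diffusion part (jumps omitted for brevity) shows how such a path can be simulated. This is a standalone sketch with assumed drift and volatility functions, independent of the `geobot.simulation.sde_solver` interface:

```python
import numpy as np

def euler_maruyama(x0, drift, vol, dt=0.01, n_steps=1000, seed=0):
    """Simulate dX = drift(X) dt + vol(X) dW with the Euler-Maruyama scheme."""
    rng = np.random.default_rng(seed)
    x = np.empty(n_steps + 1)
    x[0] = x0
    for i in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt))  # Brownian increment over one step
        x[i + 1] = x[i] + drift(x[i]) * dt + vol(x[i]) * dw
    return x

# Mean-reverting "tension" index (Ornstein-Uhlenbeck drift), illustrative parameters
path = euler_maruyama(0.3, drift=lambda x: 1.5 * (0.5 - x), vol=lambda x: 0.1)
print(f"final tension level: {path[-1]:.3f}")
```
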
---

## 💡 Use Cases

### 1. Strategic Intelligence & Risk Assessment
- **Multi-country risk dashboards** with VAR-based interdependency analysis
- **Real-time belief updating** as intelligence arrives (Bayesian inference)
- **Regime detection** using HMM and particle filters

### 2. Conflict Prediction & Early Warning
- **Hawkes process models** for escalation dynamics and contagion
- **Structural break detection** in military/economic indicators
- **Cross-country spillover analysis** with SVAR impulse responses

### 3. Policy Impact Analysis
- **Synthetic control** for sanctions/intervention evaluation
- **Difference-in-differences** for regime change effects
- **Regression discontinuity** for election outcome impacts
- **Do-calculus** for counterfactual policy simulation

### 4. Resource Allocation & Logistics
- **Optimal transport** for supply chain optimization under uncertainty
- **Monte Carlo simulation** for contingency planning
- **Stochastic optimization** with SDE-based constraints

### 5. Diplomatic Forecasting
- **Game-theoretic extensions** (can be added via causal graphs)
- **Network centrality** analysis with GNNs
- **Alliance dynamics** using multivariate Hawkes processes

---

## 🛠️ Technical Requirements

### Core Dependencies
- **Python**: 3.9+
- **NumPy**: 1.24+
- **SciPy**: 1.10+
- **Pandas**: 2.0+
- **NetworkX**: 3.0+ (causal graphs)
- **POT**: 0.9+ (optimal transport)
- **PyMC**: 5.0+ (Bayesian inference)
- **Statsmodels**: 0.14+ (econometrics)

### Optional but Recommended
- **PyTorch**: 2.0+ (for GNNs, deep learning)
- **PyTorch Geometric**: 2.3+ (graph neural networks)
- **spaCy**: 3.6+ (NLP for event extraction)
- **Transformers**: 4.30+ (text embeddings)

### Data Ingestion
- **Beautiful Soup**: 4.12+ (web scraping)
- **PDFPlumber**: 0.10+ (PDF extraction)
- **Newspaper3k**: 0.2.8+ (article extraction)
- **Feedparser**: 6.0+ (RSS feeds)

See `requirements.txt` for the complete dependency list.

---

## 📖 Documentation & Examples

### Example Scripts
All examples are fully functional and demonstrate end-to-end workflows:

1. **`01_basic_usage.py`**: Scenarios, causal graphs, Monte Carlo basics, Bayesian updating
2. **`02_data_ingestion.py`**: PDF extraction, web scraping, event databases
3. **`03_intervention_simulation.py`**: Do-calculus, counterfactuals, policy simulation
4. **`04_advanced_features.py`**: Particle filters, VI, SDEs, GNNs, event extraction
5. **`05_complete_framework.py`**: VAR/SVAR/DFM, Hawkes processes, quasi-experimental methods

### Running Examples

```bash
cd examples

# Basic usage
python 01_basic_usage.py

# Data ingestion pipeline
python 02_data_ingestion.py

# Intervention simulation
python 03_intervention_simulation.py

# Advanced mathematical features
python 04_advanced_features.py

# Complete framework demonstration (VAR, Hawkes, quasi-experimental)
python 05_complete_framework.py
```

---

## 🎓 Theoretical Background

GeoBotv1 synthesizes methods from multiple fields:

### Economics & Econometrics
- Vector Autoregression (Sims, 1980)
- Structural VAR identification (Blanchard & Quah, 1989)
- Synthetic Control Method (Abadie et al., 2010, 2015)
- Difference-in-Differences (Card & Krueger, 1994)
- Regression Discontinuity (Thistlethwaite & Campbell, 1960)

### Statistics & Probability
- Optimal Transport (Villani, 2003, 2009)
- Hawkes Processes (Hawkes, 1971; Hawkes & Oakes, 1974)
- Sequential Monte Carlo (Doucet et al., 2001)
- Variational Inference (Jordan et al., 1999; Blei et al., 2017)

### Computer Science & AI
- Causal Inference (Pearl, 2000, 2009)
- Do-Calculus (Pearl, 1995)
- Graph Neural Networks (Kipf & Welling, 2017; Veličković et al., 2018)
- Structural Causal Models (Pearl & Mackenzie, 2018)

### Geopolitics & Conflict Studies
- Conflict contagion (Gleditsch, 2007; Braithwaite, 2010)
- Regime change dynamics (Acemoglu & Robinson, 2006)
- Economic sanctions effects (Hufbauer et al., 2007)

---

## 🧪 Testing & Validation

### Mathematical Correctness
- Kantorovich duality: Verify primal-dual gap ≈ 0
- Hawkes stability: Check branching ratio < 1 for subcritical processes
- Causal identification: Validate backdoor/frontdoor criteria
- Particle filter: Monitor Effective Sample Size (ESS)

### Statistical Properties
- VAR stationarity: Check eigenvalues of the companion matrix (see the sketch after this list)
- SVAR identification: Verify that order conditions are satisfied
- Synthetic control: Pre-treatment fit quality (RMSPE)
- RDD: Bandwidth sensitivity analysis

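The VAR stationarity check can be reproduced directly: stack the lag coefficient matrices into companion form and verify that every eigenvalue lies strictly inside the unit circle. The sketch below uses illustrative coefficients in place of matrices fitted by `VARModel`:

```python
import numpy as np

# VAR(2) lag coefficient matrices for k = 2 variables (illustrative values)
A1 = np.array([[0.5, 0.1],
               [0.0, 0.4]])
A2 = np.array([[0.2, 0.0],
               [0.1, 0.1]])
k, p = 2, 2

# Companion matrix: [[A1, A2], [I, 0]]
companion = np.zeros((k * p, k * p))
companion[:k, :k] = A1
companion[:k, k:] = A2
companion[k:, :k] = np.eye(k)

eigvals = np.linalg.eigvals(companion)
print("eigenvalue moduli:", np.round(np.abs(eigvals), 3))
print("stationary:", bool(np.all(np.abs(eigvals) < 1.0)))
```
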
### Reproducibility
All examples include `np.random.seed()` for deterministic results.

---

## 🚧 Extensions & Future Work

While GeoBotv1 is complete for core forecasting tasks, potential extensions include:

- **Game-Theoretic Modules**: Strategic interaction between actors
- **Spatial Statistics**: Geographic contagion with distance decay
- **Network Effects**: Centrality measures, community detection
- **Deep Learning**: Attention mechanisms for text-to-risk pipelines
- **Real-Time Data Feeds**: API integrations for live intelligence
- **Uncertainty Quantification**: Conformal prediction, calibration
- **Ensemble Methods**: Model averaging across multiple frameworks

---

## 📄 License

MIT License - See LICENSE file for details

---

## 🙏 Acknowledgments

GeoBotv1 builds on decades of research in:
- Causal inference (Judea Pearl, Donald Rubin)
- Optimal transport theory (Cédric Villani, Leonid Kantorovich)
- Econometrics (Christopher Sims, Alberto Abadie)
- Point processes (Alan Hawkes, Daryl Daley)
- Bayesian statistics (Andrew Gelman, Michael Jordan)
- Stochastic calculus (Kiyosi Itô, Bernt Øksendal)

---

## 📧 Contact & Support

For questions, issues, or collaboration:
- **Issues**: [GitHub Issues](https://github.com/your-org/AIGEOPOLITICAL/issues)
- **Documentation**: [Full documentation](https://your-org.github.io/AIGEOPOLITICAL)
- **Citation**: If you use GeoBotv1 in research, please cite this repository

---

<div align="center">

**GeoBotv1** - Where rigorous mathematics meets geopolitical forecasting

*Built with causality, powered by probability, validated by theory*

</div>
examples/01_basic_usage.py
ADDED
@@ -0,0 +1,131 @@
"""
Example 1: Basic Usage of GeoBotv1

This example demonstrates the core components of the framework:
- Creating scenarios
- Building causal graphs
- Running Monte Carlo simulations
- Bayesian belief updating
"""

import numpy as np
import sys
sys.path.append('..')

from geobot.core.scenario import Scenario, ScenarioDistribution
from geobot.models.causal_graph import CausalGraph, StructuralCausalModel
from geobot.simulation.monte_carlo import MonteCarloEngine, SimulationConfig
from geobot.inference.bayesian_engine import BayesianEngine, Prior, Evidence, BeliefUpdater
from scipy import stats


def main():
    print("=" * 80)
    print("GeoBotv1 - Basic Usage Example")
    print("=" * 80)

    # 1. Create a simple scenario
    print("\n1. Creating a geopolitical scenario...")
    scenario = Scenario(
        name="baseline_scenario",
        features={
            'military_tension': np.array([0.5]),
            'economic_sanctions': np.array([0.3]),
            'diplomatic_relations': np.array([0.6]),
        },
        probability=1.0
    )
    print(f"  Created scenario: {scenario.name}")
    print(f"  Features: {list(scenario.features.keys())}")

    # 2. Build a causal graph
    print("\n2. Building causal graph...")
    causal_graph = CausalGraph(name="geopolitical_dag")

    # Add nodes
    causal_graph.add_node('sanctions', node_type='policy')
    causal_graph.add_node('tension', node_type='state')
    causal_graph.add_node('conflict_risk', node_type='outcome')

    # Add causal edges
    causal_graph.add_edge('sanctions', 'tension',
                          strength=0.7,
                          mechanism="Sanctions increase military tension")
    causal_graph.add_edge('tension', 'conflict_risk',
                          strength=0.8,
                          mechanism="Tension increases conflict probability")

    print(f"  Created graph with {len(causal_graph.graph.nodes)} nodes")
    print(f"  Causal relationships: sanctions -> tension -> conflict_risk")

    # 3. Run Monte Carlo simulation
    print("\n3. Running Monte Carlo simulation...")
    config = SimulationConfig(n_simulations=100, time_horizon=50)
    mc_engine = MonteCarloEngine(config)

    def transition_fn(state, t, noise):
        # Simple dynamics
        new_state = {}
        new_state['tension'] = state.get('tension', 0.5) + \
            0.1 * state.get('sanctions', 0) + \
            noise.get('tension', 0)
        new_state['conflict_risk'] = 0.5 * new_state['tension'] + \
            noise.get('conflict_risk', 0)
        # Clip values
        new_state['tension'] = np.clip(new_state['tension'], 0, 1)
        new_state['conflict_risk'] = np.clip(new_state['conflict_risk'], 0, 1)
        return new_state

    def noise_fn(t):
        return {
            'tension': np.random.normal(0, 0.05),
            'conflict_risk': np.random.normal(0, 0.05)
        }

    initial_state = {'tension': 0.3, 'sanctions': 0.2, 'conflict_risk': 0.1}
    trajectories = mc_engine.run_basic_simulation(initial_state, transition_fn, noise_fn)

    # Compute statistics
    stats = mc_engine.compute_statistics(trajectories)
    print(f"  Ran {config.n_simulations} simulations")
    print(f"  Final conflict risk (mean): {stats['conflict_risk']['mean'][-1]:.3f}")
    print(f"  Final conflict risk (95% CI): [{stats['conflict_risk']['q5'][-1]:.3f}, {stats['conflict_risk']['q95'][-1]:.3f}]")

    # 4. Bayesian belief updating
    print("\n4. Bayesian belief updating...")
    updater = BeliefUpdater()

    # Initialize belief about conflict risk
    updater.initialize_belief(
        name='conflict_risk',
        prior_mean=0.3,
        prior_std=0.1,
        belief_type='probability'
    )

    # Receive intelligence report suggesting higher risk
    print("  Received intelligence: conflict risk = 0.6 (reliability: 0.7)")
    posterior = updater.update_from_intelligence(
        belief='conflict_risk',
        observation=0.6,
        reliability=0.7
    )

    print(f"  Updated belief - Mean: {posterior['mean']:.3f}, Std: {posterior['std']:.3f}")
    print(f"  95% Credible Interval: [{posterior['q5']:.3f}, {posterior['q95']:.3f}]")

    # 5. Probability of high risk
    prob_high_risk = updater.get_belief_probability(
        'conflict_risk',
        threshold=0.5,
        direction='greater'
    )
    print(f"  Probability of high risk (>0.5): {prob_high_risk:.3f}")

    print("\n" + "=" * 80)
    print("Example completed successfully!")
    print("=" * 80)


if __name__ == "__main__":
    main()
examples/02_data_ingestion.py
ADDED
@@ -0,0 +1,202 @@
"""
Example 2: Data Ingestion - PDF and Web Scraping

This example demonstrates:
- PDF document reading and processing
- Web article extraction
- News aggregation
- Intelligence extraction from documents
"""

import sys
sys.path.append('..')

from geobot.data_ingestion.pdf_reader import PDFReader, PDFProcessor
from geobot.data_ingestion.web_scraper import WebScraper, ArticleExtractor, NewsAggregator


def demo_pdf_processing():
    """Demonstrate PDF processing capabilities."""
    print("\n" + "=" * 80)
    print("PDF Processing Demo")
    print("=" * 80)

    # Create PDF processor
    processor = PDFProcessor()

    print("\nPDF processing capabilities:")
    print("- Text extraction from PDFs")
    print("- Table extraction")
    print("- Metadata extraction")
    print("- Entity recognition (countries, organizations)")
    print("- Keyword extraction")
    print("- Risk assessment")
    print("\nTo use: processor.process_document('path/to/document.pdf')")

    # Example code structure
    example_code = """
# Process a single PDF
result = processor.process_document('intelligence_report.pdf')

print(f"Title: {result['metadata'].get('title', 'Unknown')}")
print(f"Pages: {result['num_pages']}")
print(f"Keywords: {result['keywords']}")
print(f"Risk Level: {result['intelligence']['risk_level']}")

# Process multiple PDFs
results = processor.batch_process('reports_directory/', '*.pdf')
"""

    print("\nExample usage:")
    print(example_code)


def demo_web_scraping():
    """Demonstrate web scraping capabilities."""
    print("\n" + "=" * 80)
    print("Web Scraping Demo")
    print("=" * 80)

    # Create article extractor
    extractor = ArticleExtractor()

    print("\nWeb scraping capabilities:")
    print("- Extract articles from URLs")
    print("- Clean HTML content")
    print("- Extract metadata (author, date, etc.)")
    print("- Multiple extraction methods (newspaper3k, trafilatura, BeautifulSoup)")

    # Example with a well-known news site (without actually fetching)
    example_url = "https://www.example.com/geopolitical-analysis"

    print(f"\nExample: Extracting article from {example_url}")
    print("(This is a demonstration - no actual web request is made)")

    example_code = """
# Extract article
article = extractor.extract_article(url)

print(f"Title: {article['title']}")
print(f"Author: {article['authors']}")
print(f"Published: {article['publish_date']}")
print(f"Content length: {len(article['text'])} characters")

# Extract multiple articles
urls = ['url1', 'url2', 'url3']
articles = extractor.batch_extract(urls)
"""

    print("\nExample usage:")
    print(example_code)


def demo_news_aggregation():
    """Demonstrate news aggregation capabilities."""
    print("\n" + "=" * 80)
    print("News Aggregation Demo")
    print("=" * 80)

    aggregator = NewsAggregator()

    print("\nNews aggregation capabilities:")
    print("- Aggregate from multiple sources")
    print("- RSS feed support")
    print("- Keyword filtering")
    print("- Trending topic detection")
    print("- Real-time monitoring")

    # Example configuration
    print("\nExample: Setting up news aggregation")

    example_code = """
# Add news sources
aggregator.add_source(
    name='Reuters',
    url='https://www.reuters.com/news/world',
    source_type='rss'
)

aggregator.add_source(
    name='Al Jazeera',
    url='https://www.aljazeera.com/xml/rss/all.xml',
    source_type='rss'
)

# Fetch news with keywords
keywords = ['sanctions', 'conflict', 'diplomacy', 'military']
articles = aggregator.fetch_news(keywords)

print(f"Found {len(articles)} relevant articles")

# Get trending topics
topics = aggregator.get_trending_topics(articles, n_topics=10)
print("Trending topics:", topics)

# Monitor sources continuously
def alert_callback(new_articles):
    print(f"ALERT: {len(new_articles)} new relevant articles found")
    for article in new_articles:
        print(f"  - {article['title']}")

# Monitor every hour
aggregator.monitor_sources(keywords, callback=alert_callback, interval=3600)
"""

    print(example_code)


def demo_intelligence_extraction():
    """Demonstrate intelligence extraction from documents."""
    print("\n" + "=" * 80)
    print("Intelligence Extraction Demo")
    print("=" * 80)

    print("\nIntelligence extraction capabilities:")
    print("- Country and organization detection")
    print("- Conflict indicator detection")
    print("- Risk level assessment")
    print("- Document classification")
    print("- Key phrase extraction")

    example_code = """
processor = PDFProcessor()

# Extract intelligence from PDF
intel = processor.extract_intelligence('report.pdf')

print("Intelligence Summary:")
print(f"Risk Level: {intel['intelligence']['risk_level']}")
print(f"Countries mentioned: {intel['intelligence']['mentioned_countries']}")
print(f"Conflict indicators: {intel['intelligence']['conflict_indicators']}")
print(f"Key topics: {intel['intelligence']['key_topics']}")
print(f"Document type: {intel['intelligence']['document_type']}")
"""

    print("\nExample usage:")
    print(example_code)


def main():
    print("=" * 80)
    print("GeoBotv1 - Data Ingestion Examples")
    print("=" * 80)
    print("\nThis module demonstrates the data ingestion capabilities of GeoBotv1:")
    print("1. PDF document processing")
    print("2. Web scraping and article extraction")
    print("3. News aggregation from multiple sources")
    print("4. Intelligence extraction from documents")

    demo_pdf_processing()
    demo_web_scraping()
    demo_news_aggregation()
    demo_intelligence_extraction()

    print("\n" + "=" * 80)
    print("Data Ingestion Demo Complete")
    print("=" * 80)
    print("\nNote: Install required packages for full functionality:")
    print("  pip install pypdf pdfplumber beautifulsoup4 newspaper3k trafilatura")


if __name__ == "__main__":
    main()
examples/03_intervention_simulation.py
ADDED
@@ -0,0 +1,237 @@
"""
Example 3: Intervention Simulation and Counterfactual Analysis

This example demonstrates:
- Policy intervention simulation
- Counterfactual reasoning ("what if" scenarios)
- Comparing multiple interventions
- Optimal intervention finding
"""

import numpy as np
import sys
sys.path.append('..')

from geobot.models.causal_graph import CausalGraph, StructuralCausalModel
from geobot.inference.do_calculus import DoCalculus, InterventionSimulator


def create_geopolitical_scm():
    """Create a structural causal model for geopolitical scenarios."""
    print("\n1. Creating Structural Causal Model...")

    # Create causal graph
    graph = CausalGraph(name="geopolitical_system")

    # Add nodes
    graph.add_node('economic_sanctions', node_type='policy')
    graph.add_node('diplomatic_pressure', node_type='policy')
    graph.add_node('domestic_stability', node_type='state')
    graph.add_node('military_mobilization', node_type='state')
    graph.add_node('conflict_probability', node_type='outcome')

    # Add causal edges
    graph.add_edge('economic_sanctions', 'domestic_stability',
                   strength=-0.6, mechanism="Sanctions reduce stability")
    graph.add_edge('diplomatic_pressure', 'domestic_stability',
                   strength=-0.3, mechanism="Pressure affects stability")
    graph.add_edge('domestic_stability', 'military_mobilization',
                   strength=-0.7, mechanism="Instability drives mobilization")
    graph.add_edge('military_mobilization', 'conflict_probability',
                   strength=0.8, mechanism="Mobilization increases conflict risk")
    graph.add_edge('economic_sanctions', 'conflict_probability',
                   strength=0.4, mechanism="Direct deterrence effect")

    print(f"  Created graph with {len(graph.graph.nodes)} nodes and {len(graph.edges)} edges")

    # Create SCM
    scm = StructuralCausalModel(graph)

    # Define structural equations
    def sanctions_fn(parents, noise):
        return 0.5 + noise  # Baseline policy level

    def pressure_fn(parents, noise):
        return 0.3 + noise

    def stability_fn(parents, noise):
        sanctions = parents.get('economic_sanctions', np.zeros(1))[0]
        pressure = parents.get('diplomatic_pressure', np.zeros(1))[0]
        return np.clip(0.7 - 0.6 * sanctions - 0.3 * pressure + noise, 0, 1)

    def mobilization_fn(parents, noise):
        stability = parents.get('domestic_stability', np.zeros(1))[0]
        return np.clip(0.3 - 0.7 * stability + noise, 0, 1)

    def conflict_fn(parents, noise):
        mobilization = parents.get('military_mobilization', np.zeros(1))[0]
        sanctions = parents.get('economic_sanctions', np.zeros(1))[0]
        return np.clip(0.8 * mobilization + 0.4 * sanctions + noise, 0, 1)

    # Set functions
    from scipy import stats
    scm.set_function('economic_sanctions', sanctions_fn, stats.norm(0, 0.1))
    scm.set_function('diplomatic_pressure', pressure_fn, stats.norm(0, 0.1))
    scm.set_function('domestic_stability', stability_fn, stats.norm(0, 0.05))
    scm.set_function('military_mobilization', mobilization_fn, stats.norm(0, 0.05))
    scm.set_function('conflict_probability', conflict_fn, stats.norm(0, 0.05))

    print("  Structural equations defined")

    return scm


def simulate_baseline(simulator):
    """Simulate baseline (no intervention) scenario."""
    print("\n2. Simulating Baseline Scenario...")

    baseline = simulator.simulate_intervention(
        intervention={},
        n_samples=1000,
        outcomes=['conflict_probability']
    )

    conflict_mean = np.mean(baseline['conflict_probability'])
    conflict_std = np.std(baseline['conflict_probability'])

    print(f"  Baseline conflict probability: {conflict_mean:.3f} ± {conflict_std:.3f}")

    return baseline
+
def simulate_interventions(simulator):
|
| 103 |
+
"""Simulate different policy interventions."""
|
| 104 |
+
print("\n3. Simulating Policy Interventions...")
|
| 105 |
+
|
| 106 |
+
interventions = [
|
| 107 |
+
{'economic_sanctions': 0.8, 'diplomatic_pressure': 0.3}, # Heavy sanctions
|
| 108 |
+
{'economic_sanctions': 0.3, 'diplomatic_pressure': 0.8}, # Heavy diplomacy
|
| 109 |
+
{'economic_sanctions': 0.6, 'diplomatic_pressure': 0.6}, # Balanced approach
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
intervention_names = [
|
| 113 |
+
"Heavy Sanctions",
|
| 114 |
+
"Heavy Diplomacy",
|
| 115 |
+
"Balanced Approach"
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
results = simulator.compare_interventions(
|
| 119 |
+
interventions,
|
| 120 |
+
outcome='conflict_probability',
|
| 121 |
+
n_samples=1000
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
print("\n Intervention Results:")
|
| 125 |
+
print(" " + "-" * 60)
|
| 126 |
+
|
| 127 |
+
for i, name in enumerate(intervention_names):
|
| 128 |
+
result = results[f'intervention_{i}']
|
| 129 |
+
print(f"\n {name}:")
|
| 130 |
+
print(f" Mean conflict probability: {result['mean']:.3f}")
|
| 131 |
+
print(f" Std deviation: {result['std']:.3f}")
|
| 132 |
+
print(f" 95% CI: [{result['q25']:.3f}, {result['q75']:.3f}]")
|
| 133 |
+
|
| 134 |
+
return results
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def find_optimal_intervention(simulator):
|
| 138 |
+
"""Find optimal intervention to minimize conflict."""
|
| 139 |
+
print("\n4. Finding Optimal Intervention...")
|
| 140 |
+
|
| 141 |
+
optimal = simulator.optimal_intervention(
|
| 142 |
+
target_var='conflict_probability',
|
| 143 |
+
intervention_vars=['economic_sanctions', 'diplomatic_pressure'],
|
| 144 |
+
intervention_ranges={
|
| 145 |
+
'economic_sanctions': (0.0, 1.0),
|
| 146 |
+
'diplomatic_pressure': (0.0, 1.0)
|
| 147 |
+
},
|
| 148 |
+
objective='minimize',
|
| 149 |
+
n_trials=50,
|
| 150 |
+
n_samples=1000
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
print(f"\n Optimal intervention found:")
|
| 154 |
+
print(f" Economic Sanctions: {optimal['optimal_intervention']['economic_sanctions']:.3f}")
|
| 155 |
+
print(f" Diplomatic Pressure: {optimal['optimal_intervention']['diplomatic_pressure']:.3f}")
|
| 156 |
+
print(f" Expected conflict probability: {optimal['optimal_value']:.3f}")
|
| 157 |
+
|
| 158 |
+
return optimal
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def counterfactual_analysis(simulator):
|
| 162 |
+
"""Perform counterfactual analysis."""
|
| 163 |
+
print("\n5. Counterfactual Analysis...")
|
| 164 |
+
|
| 165 |
+
# Observed scenario
|
| 166 |
+
observed = {
|
| 167 |
+
'economic_sanctions': 0.7,
|
| 168 |
+
'diplomatic_pressure': 0.2,
|
| 169 |
+
'domestic_stability': 0.4,
|
| 170 |
+
'military_mobilization': 0.6,
|
| 171 |
+
'conflict_probability': 0.65
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
print("\n Observed scenario:")
|
| 175 |
+
print(f" Sanctions: {observed['economic_sanctions']}")
|
| 176 |
+
print(f" Diplomacy: {observed['diplomatic_pressure']}")
|
| 177 |
+
print(f" Conflict: {observed['conflict_probability']}")
|
| 178 |
+
|
| 179 |
+
# Counterfactual: What if we had used more diplomacy?
|
| 180 |
+
counterfactual_intervention = {
|
| 181 |
+
'diplomatic_pressure': 0.8,
|
| 182 |
+
'economic_sanctions': 0.3
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
result = simulator.counterfactual_analysis(
|
| 186 |
+
observed=observed,
|
| 187 |
+
intervention=counterfactual_intervention,
|
| 188 |
+
outcome='conflict_probability'
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
print("\n Counterfactual: 'What if we had emphasized diplomacy?'")
|
| 192 |
+
print(f" Counterfactual conflict: {result['counterfactual_outcome']:.3f}")
|
| 193 |
+
print(f" Effect of intervention: {result['effect']:.3f}")
|
| 194 |
+
|
| 195 |
+
if result['effect'] < 0:
|
| 196 |
+
print(f" Conclusion: Diplomacy would have REDUCED conflict by {abs(result['effect']):.3f}")
|
| 197 |
+
else:
|
| 198 |
+
print(f" Conclusion: Diplomacy would have INCREASED conflict by {result['effect']:.3f}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def main():
|
| 202 |
+
print("=" * 80)
|
| 203 |
+
print("GeoBotv1 - Intervention Simulation & Counterfactual Analysis")
|
| 204 |
+
print("=" * 80)
|
| 205 |
+
print("\nThis example demonstrates answering 'what if' questions:")
|
| 206 |
+
print("- 'What if the U.S. increases sanctions?'")
|
| 207 |
+
print("- 'What if we emphasize diplomacy over sanctions?'")
|
| 208 |
+
print("- 'What is the optimal policy mix?'")
|
| 209 |
+
print("- 'What would have happened if we had acted differently?'")
|
| 210 |
+
|
| 211 |
+
# Create SCM
|
| 212 |
+
scm = create_geopolitical_scm()
|
| 213 |
+
|
| 214 |
+
# Create intervention simulator
|
| 215 |
+
simulator = InterventionSimulator(scm)
|
| 216 |
+
|
| 217 |
+
# Run analyses
|
| 218 |
+
baseline = simulate_baseline(simulator)
|
| 219 |
+
interventions = simulate_interventions(simulator)
|
| 220 |
+
optimal = find_optimal_intervention(simulator)
|
| 221 |
+
counterfactual_analysis(simulator)
|
| 222 |
+
|
| 223 |
+
print("\n" + "=" * 80)
|
| 224 |
+
print("Key Insights:")
|
| 225 |
+
print("=" * 80)
|
| 226 |
+
print("\n1. Different interventions have different effects on conflict probability")
|
| 227 |
+
print("2. Optimal policy can be discovered through systematic search")
|
| 228 |
+
print("3. Counterfactual reasoning enables learning from alternative scenarios")
|
| 229 |
+
print("4. Causal models enable principled 'what if' analysis")
|
| 230 |
+
|
| 231 |
+
print("\n" + "=" * 80)
|
| 232 |
+
print("Example completed successfully!")
|
| 233 |
+
print("=" * 80)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
main()
|
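The structural equations in example 03 are simple enough to trace by hand, which is a useful sanity check on the simulator's output. The sketch below is a minimal illustration only (not the geobot DoCalculus or InterventionSimulator API): it propagates a do()-style intervention through the same equations defined in create_geopolitical_scm(), with the noise terms set to zero.

import numpy as np

def conflict_under_do(sanctions: float, pressure: float = 0.3) -> float:
    # Same equations as create_geopolitical_scm(), noise dropped.
    stability = np.clip(0.7 - 0.6 * sanctions - 0.3 * pressure, 0, 1)
    mobilization = np.clip(0.3 - 0.7 * stability, 0, 1)
    return float(np.clip(0.8 * mobilization + 0.4 * sanctions, 0, 1))

for s in (0.3, 0.5, 0.8):
    print(f"do(economic_sanctions={s}): conflict ≈ {conflict_under_do(s):.3f}")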
examples/04_advanced_features.py
ADDED
|
@@ -0,0 +1,348 @@
|
| 1 |
+
"""
|
| 2 |
+
Example 4: Advanced Mathematical Features
|
| 3 |
+
|
| 4 |
+
This example demonstrates the research-grade advanced features:
|
| 5 |
+
- Sequential Monte Carlo (particle filtering)
|
| 6 |
+
- Variational Inference
|
| 7 |
+
- Stochastic Differential Equations (SDEs)
|
| 8 |
+
- Gradient-based Optimal Transport
|
| 9 |
+
- Kantorovich Duality
|
| 10 |
+
- Event Extraction and Database
|
| 11 |
+
- Continuous-time dynamics
|
| 12 |
+
|
| 13 |
+
These features enable measure-theoretic, rigorous forecasting.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import sys
|
| 18 |
+
sys.path.append('..')
|
| 19 |
+
|
| 20 |
+
from datetime import datetime, timedelta
|
| 21 |
+
from scipy import stats
|
| 22 |
+
|
| 23 |
+
# Advanced inference
|
| 24 |
+
from geobot.inference.particle_filter import SequentialMonteCarlo
|
| 25 |
+
from geobot.inference.variational_inference import VariationalInference
|
| 26 |
+
|
| 27 |
+
# SDE solvers
|
| 28 |
+
from geobot.simulation.sde_solver import (
|
| 29 |
+
EulerMaruyama,
|
| 30 |
+
Milstein,
|
| 31 |
+
JumpDiffusionProcess,
|
| 32 |
+
GeopoliticalSDE
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Advanced optimal transport
|
| 36 |
+
from geobot.core.advanced_optimal_transport import (
|
| 37 |
+
KantorovichDuality,
|
| 38 |
+
EntropicOT,
|
| 39 |
+
GradientBasedOT
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Event extraction
|
| 43 |
+
from geobot.data_ingestion.event_extraction import EventExtractor, EventType
|
| 44 |
+
from geobot.data_ingestion.event_database import EventDatabase
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def demo_particle_filter():
|
| 48 |
+
"""Demonstrate Sequential Monte Carlo / Particle Filter."""
|
| 49 |
+
print("\n" + "="*80)
|
| 50 |
+
print("1. Sequential Monte Carlo (Particle Filter)")
|
| 51 |
+
print("="*80)
|
| 52 |
+
|
| 53 |
+
# Define nonlinear dynamics
|
| 54 |
+
def dynamics_fn(x, noise):
|
| 55 |
+
# Nonlinear geopolitical dynamics
|
| 56 |
+
# x[0] = tension, x[1] = stability
|
| 57 |
+
tension = x[0]
|
| 58 |
+
stability = x[1]
|
| 59 |
+
|
| 60 |
+
new_tension = tension + 0.1 * (1 - stability) + noise[0]
|
| 61 |
+
new_stability = stability - 0.05 * tension + noise[1]
|
| 62 |
+
|
| 63 |
+
return np.array([
|
| 64 |
+
np.clip(new_tension, 0, 1),
|
| 65 |
+
np.clip(new_stability, 0, 1)
|
| 66 |
+
])
|
| 67 |
+
|
| 68 |
+
def observation_fn(y, x):
|
| 69 |
+
# Log-likelihood of observation given state
|
| 70 |
+
# Observe tension with noise
|
| 71 |
+
predicted = x[0]
|
| 72 |
+
return stats.norm.logpdf(y[0], loc=predicted, scale=0.1)
|
| 73 |
+
|
| 74 |
+
# Create particle filter
|
| 75 |
+
pf = SequentialMonteCarlo(
|
| 76 |
+
n_particles=500,
|
| 77 |
+
state_dim=2,
|
| 78 |
+
dynamics_fn=dynamics_fn,
|
| 79 |
+
observation_fn=observation_fn
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Initialize from prior
|
| 83 |
+
pf.initialize_from_prior(lambda: np.array([0.3, 0.7]))
|
| 84 |
+
|
| 85 |
+
# Generate synthetic observations
|
| 86 |
+
observations = np.array([
|
| 87 |
+
[0.35], [0.40], [0.45], [0.50], [0.55],
|
| 88 |
+
[0.60], [0.65], [0.70], [0.75], [0.80]
|
| 89 |
+
])
|
| 90 |
+
|
| 91 |
+
print(f"\nRunning particle filter with {pf.n_particles} particles...")
|
| 92 |
+
print("Tracking hidden geopolitical state from noisy observations\n")
|
| 93 |
+
|
| 94 |
+
# Filter
|
| 95 |
+
states = pf.filter(observations)
|
| 96 |
+
|
| 97 |
+
# Show results
|
| 98 |
+
for i, state in enumerate(states[-5:]): # Last 5 steps
|
| 99 |
+
mean, cov = pf.get_state_estimate()
|
| 100 |
+
print(f"Step {i+6}: Tension={mean[0]:.3f}±{np.sqrt(cov[0,0]):.3f}, "
|
| 101 |
+
f"Stability={mean[1]:.3f}±{np.sqrt(cov[1,1]):.3f}, "
|
| 102 |
+
f"ESS={state.ess:.1f}")
|
| 103 |
+
|
| 104 |
+
print("\n✓ Particle filter successfully tracked nonlinear hidden states!")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def demo_sde_solver():
|
| 108 |
+
"""Demonstrate Stochastic Differential Equations."""
|
| 109 |
+
print("\n" + "="*80)
|
| 110 |
+
print("2. Stochastic Differential Equations (Continuous-Time Dynamics)")
|
| 111 |
+
print("="*80)
|
| 112 |
+
|
| 113 |
+
# Define SDE: dx = f(x,t)dt + g(x,t)dW
|
| 114 |
+
def drift(x, t):
|
| 115 |
+
# Mean-reverting to 0.5 (long-term stability)
|
| 116 |
+
return 0.2 * (0.5 - x)
|
| 117 |
+
|
| 118 |
+
def diffusion(x, t):
|
| 119 |
+
# Volatility increases with tension
|
| 120 |
+
return 0.1 * (1 + x)
|
| 121 |
+
|
| 122 |
+
# Create SDE solver
|
| 123 |
+
solver = EulerMaruyama(
|
| 124 |
+
drift=drift,
|
| 125 |
+
diffusion=diffusion,
|
| 126 |
+
x0=np.array([0.7]), # Start with high tension
|
| 127 |
+
t0=0.0
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
print("\nSimulating continuous-time geopolitical tension dynamics...")
|
| 131 |
+
print("SDE: dx = 0.2(0.5 - x)dt + 0.1(1 + x)dW\n")
|
| 132 |
+
|
| 133 |
+
# Integrate
|
| 134 |
+
solution = solver.integrate(T=10.0, dt=0.01, n_paths=5)
|
| 135 |
+
|
| 136 |
+
# Show statistics
|
| 137 |
+
final_values = solution.x[:, -1, 0]
|
| 138 |
+
print(f"After T=10.0 time units:")
|
| 139 |
+
print(f" Mean tension: {np.mean(final_values):.3f}")
|
| 140 |
+
print(f" Std deviation: {np.std(final_values):.3f}")
|
| 141 |
+
print(f" Min/Max: [{np.min(final_values):.3f}, {np.max(final_values):.3f}]")
|
| 142 |
+
|
| 143 |
+
print("\n✓ SDE solver successfully simulated continuous-time dynamics!")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def demo_jump_diffusion():
|
| 147 |
+
"""Demonstrate Jump-Diffusion Process."""
|
| 148 |
+
print("\n" + "="*80)
|
| 149 |
+
print("3. Jump-Diffusion Process (Modeling Black Swan Events)")
|
| 150 |
+
print("="*80)
|
| 151 |
+
|
| 152 |
+
# Create jump-diffusion process
|
| 153 |
+
jdp = JumpDiffusionProcess(
|
| 154 |
+
drift=0.05, # Slow drift
|
| 155 |
+
diffusion=0.1, # Normal volatility
|
| 156 |
+
jump_intensity=0.5, # 0.5 jumps per unit time (on average)
|
| 157 |
+
jump_mean=-0.2, # Negative jumps (crises)
|
| 158 |
+
jump_std=0.1,
|
| 159 |
+
x0=np.array([0.5])
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
print("\nSimulating conflict escalation with discrete shock events...")
|
| 163 |
+
print("Model: Continuous diffusion + Poisson jumps (λ=0.5, μ=-0.2)\n")
|
| 164 |
+
|
| 165 |
+
# Simulate
|
| 166 |
+
solution = jdp.simulate(T=20.0, dt=0.1, n_paths=3)
|
| 167 |
+
|
| 168 |
+
# Count jumps (approximately)
|
| 169 |
+
for path in range(3):
|
| 170 |
+
# Detect jumps as large changes
|
| 171 |
+
diffs = np.diff(solution.x[path, :, 0])
|
| 172 |
+
n_jumps = np.sum(np.abs(diffs) > 0.15)
|
| 173 |
+
final_value = solution.x[path, -1, 0]
|
| 174 |
+
print(f"Path {path+1}: {n_jumps} jumps detected, Final value: {final_value:.3f}")
|
| 175 |
+
|
| 176 |
+
print("\n✓ Jump-diffusion successfully modeled rare shock events!")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def demo_kantorovich_duality():
|
| 180 |
+
"""Demonstrate Kantorovich Duality."""
|
| 181 |
+
print("\n" + "="*80)
|
| 182 |
+
print("4. Kantorovich Duality (Optimal Transport Theory)")
|
| 183 |
+
print("="*80)
|
| 184 |
+
|
| 185 |
+
# Create two distributions (scenarios)
|
| 186 |
+
n, m = 10, 10
|
| 187 |
+
mu = np.ones(n) / n # Uniform source
|
| 188 |
+
nu = np.ones(m) / m # Uniform target
|
| 189 |
+
|
| 190 |
+
# Cost matrix (Euclidean distance)
|
| 191 |
+
X_source = np.random.rand(n, 2)
|
| 192 |
+
X_target = np.random.rand(m, 2) + np.array([0.5, 0.5]) # Shifted
|
| 193 |
+
from scipy.spatial.distance import cdist
|
| 194 |
+
C = cdist(X_source, X_target, metric='sqeuclidean')
|
| 195 |
+
|
| 196 |
+
# Solve primal and dual
|
| 197 |
+
kantorovich = KantorovichDuality()
|
| 198 |
+
|
| 199 |
+
print("\nComputing optimal transport between two geopolitical scenarios...")
|
| 200 |
+
print(f"Source: {n} points, Target: {m} points\n")
|
| 201 |
+
|
| 202 |
+
# Primal solution
|
| 203 |
+
coupling, primal_cost = kantorovich.solve_primal(mu, nu, C, method='emd')
|
| 204 |
+
print(f"Primal optimal cost: {primal_cost:.6f}")
|
| 205 |
+
|
| 206 |
+
# Dual solution
|
| 207 |
+
f, g, dual_value = kantorovich.solve_dual(mu, nu, C, max_iter=100)
|
| 208 |
+
print(f"Dual optimal value: {dual_value:.6f}")
|
| 209 |
+
|
| 210 |
+
# Verify duality gap
|
| 211 |
+
gap = kantorovich.verify_duality_gap(mu, nu, C)
|
| 212 |
+
print(f"Duality gap: {gap:.8f} (should be ≈ 0)")
|
| 213 |
+
|
| 214 |
+
if abs(gap) < 1e-4:
|
| 215 |
+
print("\n✓ Strong duality verified! Primal = Dual")
|
| 216 |
+
else:
|
| 217 |
+
print("\n⚠ Duality gap present (numerical approximation)")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def demo_entropic_ot():
|
| 221 |
+
"""Demonstrate Entropic Optimal Transport (Sinkhorn)."""
|
| 222 |
+
print("\n" + "="*80)
|
| 223 |
+
print("5. Entropic Optimal Transport (Sinkhorn Algorithm)")
|
| 224 |
+
print("="*80)
|
| 225 |
+
|
| 226 |
+
# Create distributions
|
| 227 |
+
n, m = 20, 20
|
| 228 |
+
mu = np.random.dirichlet(np.ones(n)) # Random distribution
|
| 229 |
+
nu = np.random.dirichlet(np.ones(m))
|
| 230 |
+
|
| 231 |
+
# Cost matrix
|
| 232 |
+
X = np.random.rand(n, 2)
|
| 233 |
+
Y = np.random.rand(m, 2)
|
| 234 |
+
from scipy.spatial.distance import cdist
|
| 235 |
+
C = cdist(X, Y, metric='euclidean')
|
| 236 |
+
|
| 237 |
+
# Entropic OT with different regularization
|
| 238 |
+
epsilons = [0.01, 0.05, 0.1]
|
| 239 |
+
|
| 240 |
+
print("\nComparing regularization levels for Sinkhorn algorithm...\n")
|
| 241 |
+
|
| 242 |
+
for eps in epsilons:
|
| 243 |
+
eot = EntropicOT(epsilon=eps)
|
| 244 |
+
coupling, cost = eot.sinkhorn(mu, nu, C, max_iter=500)
|
| 245 |
+
|
| 246 |
+
print(f"ε = {eps:0.2f}: Cost = {cost:.6f}, "
|
| 247 |
+
f"Entropy = {-np.sum(coupling * np.log(coupling + 1e-10)):.4f}")
|
| 248 |
+
|
| 249 |
+
print("\n✓ Entropic OT computed with fast Sinkhorn iterations!")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def demo_event_extraction():
|
| 253 |
+
"""Demonstrate Event Extraction Pipeline."""
|
| 254 |
+
print("\n" + "="*80)
|
| 255 |
+
print("6. Structured Event Extraction from Intelligence")
|
| 256 |
+
print("="*80)
|
| 257 |
+
|
| 258 |
+
# Sample intelligence text
|
| 259 |
+
intelligence_text = """
|
| 260 |
+
On March 15, 2024, tensions escalated between the United States and China
|
| 261 |
+
following a major military mobilization in the Taiwan Strait. NATO issued
|
| 262 |
+
a statement expressing concern. Russia announced sanctions on European Union
|
| 263 |
+
member states. India maintained diplomatic neutrality while calling for
|
| 264 |
+
de-escalation talks.
|
| 265 |
+
|
| 266 |
+
The United Nations Security Council convened an emergency session on March 16,
|
| 267 |
+
2024. Economic sanctions were proposed against China by the United States,
|
| 268 |
+
but Russia exercised its veto power.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
# Extract events
|
| 272 |
+
extractor = EventExtractor()
|
| 273 |
+
|
| 274 |
+
print("\nExtracting structured events from intelligence report...\n")
|
| 275 |
+
print("Input text:")
|
| 276 |
+
print("-" * 60)
|
| 277 |
+
print(intelligence_text[:200] + "...")
|
| 278 |
+
print("-" * 60)
|
| 279 |
+
|
| 280 |
+
events = extractor.extract_events(
|
| 281 |
+
intelligence_text,
|
| 282 |
+
source="intel_report_001",
|
| 283 |
+
default_timestamp=datetime(2024, 3, 15)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
print(f"\n✓ Extracted {len(events)} geopolitical events:")
|
| 287 |
+
print()
|
| 288 |
+
|
| 289 |
+
for i, event in enumerate(events):
|
| 290 |
+
print(f"Event {i+1}:")
|
| 291 |
+
print(f" Type: {event.event_type.value}")
|
| 292 |
+
print(f" Actors: {', '.join(event.actors)}")
|
| 293 |
+
print(f" Magnitude: {event.magnitude:.2f}")
|
| 294 |
+
print(f" Timestamp: {event.timestamp.date()}")
|
| 295 |
+
print()
|
| 296 |
+
|
| 297 |
+
# Store in database
|
| 298 |
+
print("Storing events in database...")
|
| 299 |
+
with EventDatabase("demo_events.db") as db:
|
| 300 |
+
db.insert_events(events)
|
| 301 |
+
|
| 302 |
+
# Query back
|
| 303 |
+
conflict_events = db.query_events(
|
| 304 |
+
event_types=[EventType.CONFLICT, EventType.MILITARY_MOBILIZATION]
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
print(f"✓ Database contains {len(conflict_events)} conflict-related events")
|
| 308 |
+
|
| 309 |
+
print("\n✓ Event extraction and storage pipeline operational!")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def main():
|
| 313 |
+
"""Run all advanced feature demonstrations."""
|
| 314 |
+
print("=" * 80)
|
| 315 |
+
print("GeoBotv1 - Advanced Mathematical Features Demonstration")
|
| 316 |
+
print("=" * 80)
|
| 317 |
+
print("\nThis example showcases research-grade capabilities:")
|
| 318 |
+
print("• Sequential Monte Carlo (particle filtering)")
|
| 319 |
+
print("• Stochastic Differential Equations")
|
| 320 |
+
print("• Jump-Diffusion Processes")
|
| 321 |
+
print("• Kantorovich Duality in Optimal Transport")
|
| 322 |
+
print("• Entropic OT with Sinkhorn")
|
| 323 |
+
print("• Structured Event Extraction")
|
| 324 |
+
|
| 325 |
+
# Run demonstrations
|
| 326 |
+
demo_particle_filter()
|
| 327 |
+
demo_sde_solver()
|
| 328 |
+
demo_jump_diffusion()
|
| 329 |
+
demo_kantorovich_duality()
|
| 330 |
+
demo_entropic_ot()
|
| 331 |
+
demo_event_extraction()
|
| 332 |
+
|
| 333 |
+
print("\n" + "=" * 80)
|
| 334 |
+
print("All Advanced Features Demonstrated Successfully!")
|
| 335 |
+
print("=" * 80)
|
| 336 |
+
print("\nKey Insights:")
|
| 337 |
+
print("1. Particle filters handle nonlinear/non-Gaussian state estimation")
|
| 338 |
+
print("2. SDEs model continuous-time geopolitical dynamics rigorously")
|
| 339 |
+
print("3. Jump-diffusion captures both gradual change and sudden shocks")
|
| 340 |
+
print("4. Kantorovich duality provides theoretical foundation for OT")
|
| 341 |
+
print("5. Entropic OT enables fast computation via Sinkhorn")
|
| 342 |
+
print("6. Event extraction creates structured data for causal modeling")
|
| 343 |
+
|
| 344 |
+
print("\n" + "="*80)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
|
| 348 |
+
main()
|
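The EntropicOT demo in example 04 reports Sinkhorn costs for several regularization levels. For readers who want the iteration spelled out, here is a textbook Sinkhorn sketch in plain numpy; it is a stand-in under stated assumptions, not geobot's EntropicOT implementation. The idea: alternately rescale the Gibbs kernel K = exp(-C/ε) until both marginals match.

import numpy as np

def sinkhorn(mu, nu, C, eps=0.05, n_iter=500):
    K = np.exp(-C / eps)                   # Gibbs kernel; entries decay with cost
    u = np.ones_like(mu)
    for _ in range(n_iter):
        v = nu / (K.T @ u)                 # rescale to match the target marginal nu
        u = mu / (K @ v)                   # rescale to match the source marginal mu
    P = u[:, None] * K * v[None, :]        # entropic coupling matrix
    return P, float(np.sum(P * C))         # coupling and transport cost

grid = np.linspace(0.0, 1.0, 20)
mu = np.ones(20) / 20
nu = np.ones(20) / 20
C = np.abs(np.subtract.outer(grid, grid))  # |x_i - y_j| cost
P, cost = sinkhorn(mu, nu, C)
print(f"entropic OT cost ≈ {cost:.4f}")    # small but positive: regularization blurs the plan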
examples/05_complete_framework.py
ADDED
|
@@ -0,0 +1,542 @@
|
| 1 |
+
"""
|
| 2 |
+
Example 5: Complete GeoBotv1 Framework - Final Features
|
| 3 |
+
|
| 4 |
+
This example demonstrates the final critical components that complete GeoBotv1
|
| 5 |
+
to 100% research-grade capability:
|
| 6 |
+
|
| 7 |
+
1. Vector Autoregression (VAR/SVAR/DFM) - Econometric time-series analysis
|
| 8 |
+
2. Hawkes Processes - Conflict contagion and self-exciting dynamics
|
| 9 |
+
3. Quasi-Experimental Methods - Causal inference without randomization
|
| 10 |
+
- Synthetic Control Method (SCM)
|
| 11 |
+
- Difference-in-Differences (DiD)
|
| 12 |
+
- Regression Discontinuity Design (RDD)
|
| 13 |
+
- Instrumental Variables (IV)
|
| 14 |
+
|
| 15 |
+
These methods are essential for:
|
| 16 |
+
- Multi-country forecasting with spillovers (VAR)
|
| 17 |
+
- Modeling conflict escalation and contagion (Hawkes)
|
| 18 |
+
- Estimating policy effects and counterfactuals (quasi-experimental)
|
| 19 |
+
|
| 20 |
+
GeoBotv1 is now COMPLETE with all research-grade mathematical components!
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import sys
|
| 25 |
+
sys.path.append('..')
|
| 26 |
+
|
| 27 |
+
from datetime import datetime, timedelta
|
| 28 |
+
|
| 29 |
+
# Time-series models
|
| 30 |
+
from geobot.timeseries import (
|
| 31 |
+
VARModel,
|
| 32 |
+
SVARModel,
|
| 33 |
+
DynamicFactorModel,
|
| 34 |
+
GrangerCausality,
|
| 35 |
+
UnivariateHawkesProcess,
|
| 36 |
+
MultivariateHawkesProcess,
|
| 37 |
+
ConflictContagionModel
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Quasi-experimental methods
|
| 41 |
+
from geobot.models import (
|
| 42 |
+
SyntheticControlMethod,
|
| 43 |
+
DifferenceinDifferences,
|
| 44 |
+
RegressionDiscontinuity,
|
| 45 |
+
InstrumentalVariables
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def demo_var_model():
|
| 50 |
+
"""Demonstrate Vector Autoregression for multi-country forecasting."""
|
| 51 |
+
print("\n" + "="*80)
|
| 52 |
+
print("1. Vector Autoregression (VAR) - Multi-Country Spillovers")
|
| 53 |
+
print("="*80)
|
| 54 |
+
|
| 55 |
+
# Simulate data for 3 countries
|
| 56 |
+
# Country dynamics with interdependencies
|
| 57 |
+
np.random.seed(42)
|
| 58 |
+
T = 100
|
| 59 |
+
n_vars = 3
|
| 60 |
+
|
| 61 |
+
# Generate VAR(2) data
|
| 62 |
+
# Y_t = A_1 Y_{t-1} + A_2 Y_{t-2} + noise
|
| 63 |
+
A1 = np.array([
|
| 64 |
+
[0.5, 0.2, 0.1], # Country 1: affected by all
|
| 65 |
+
[0.1, 0.6, 0.15], # Country 2: strong self-dependence
|
| 66 |
+
[0.05, 0.1, 0.55] # Country 3: weak spillovers
|
| 67 |
+
])
|
| 68 |
+
A2 = np.array([
|
| 69 |
+
[0.2, 0.05, 0.0],
|
| 70 |
+
[0.1, 0.1, 0.05],
|
| 71 |
+
[0.0, 0.05, 0.2]
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
# Simulate
|
| 75 |
+
data = np.zeros((T, n_vars))
|
| 76 |
+
data[0] = np.random.randn(n_vars) * 0.1
|
| 77 |
+
data[1] = np.random.randn(n_vars) * 0.1
|
| 78 |
+
|
| 79 |
+
for t in range(2, T):
|
| 80 |
+
data[t] = (A1 @ data[t-1] + A2 @ data[t-2] +
|
| 81 |
+
np.random.randn(n_vars) * 0.1)
|
| 82 |
+
|
| 83 |
+
print(f"\nSimulated {T} time periods for {n_vars} countries")
|
| 84 |
+
print(f"Variables: GDP growth, Military spending, Stability index\n")
|
| 85 |
+
|
| 86 |
+
# Fit VAR model
|
| 87 |
+
var = VARModel(n_lags=2)
|
| 88 |
+
variable_names = ['GDP_growth', 'Military_spend', 'Stability']
|
| 89 |
+
results = var.fit(data, variable_names)
|
| 90 |
+
|
| 91 |
+
print(f"VAR({results.n_lags}) Estimation Results:")
|
| 92 |
+
print(f" Log-likelihood: {results.log_likelihood:.2f}")
|
| 93 |
+
print(f" AIC: {results.aic:.2f}")
|
| 94 |
+
print(f" BIC: {results.bic:.2f}")
|
| 95 |
+
|
| 96 |
+
# Forecast
|
| 97 |
+
forecast = var.forecast(results, steps=10)
|
| 98 |
+
print(f"\n10-step ahead forecast:")
|
| 99 |
+
print(f" GDP growth: {forecast[-1, 0]:.3f}")
|
| 100 |
+
print(f" Military spending: {forecast[-1, 1]:.3f}")
|
| 101 |
+
print(f" Stability: {forecast[-1, 2]:.3f}")
|
| 102 |
+
|
| 103 |
+
# Granger causality
|
| 104 |
+
print("\nGranger Causality Tests:")
|
| 105 |
+
for i in range(n_vars):
|
| 106 |
+
for j in range(n_vars):
|
| 107 |
+
if i != j:
|
| 108 |
+
gc_result = var.granger_causality(results, i, j)
|
| 109 |
+
if gc_result['p_value'] < 0.05:
|
| 110 |
+
print(f" {variable_names[j]} → {variable_names[i]}: "
|
| 111 |
+
f"F={gc_result['f_statistic']:.2f}, p={gc_result['p_value']:.3f} ✓")
|
| 112 |
+
|
| 113 |
+
# Impulse response functions
|
| 114 |
+
irf_result = var.impulse_response(results, steps=10)
|
| 115 |
+
print("\nImpulse Response Functions computed (10 steps)")
|
| 116 |
+
print(f" Shock to Military spending → GDP growth at t=5: {irf_result.irf[0, 1, 5]:.4f}")
|
| 117 |
+
|
| 118 |
+
# Forecast error variance decomposition
|
| 119 |
+
fevd = var.forecast_error_variance_decomposition(results, steps=10)
|
| 120 |
+
print("\nForecast Error Variance Decomposition (horizon=10):")
|
| 121 |
+
for i, var_name in enumerate(variable_names):
|
| 122 |
+
contributions = fevd[i, :, -1]
|
| 123 |
+
print(f" {var_name} variance explained by:")
|
| 124 |
+
for j, source_name in enumerate(variable_names):
|
| 125 |
+
print(f" {source_name}: {contributions[j]:.1%}")
|
| 126 |
+
|
| 127 |
+
print("\n✓ VAR model demonstrates multi-country interdependencies!")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def demo_hawkes_process():
|
| 131 |
+
"""Demonstrate Hawkes processes for conflict contagion."""
|
| 132 |
+
print("\n" + "="*80)
|
| 133 |
+
print("2. Hawkes Processes - Conflict Escalation and Contagion")
|
| 134 |
+
print("="*80)
|
| 135 |
+
|
| 136 |
+
# Simulate conflict events
|
| 137 |
+
print("\nSimulating conflict events with self-excitation...")
|
| 138 |
+
hawkes = UnivariateHawkesProcess()
|
| 139 |
+
|
| 140 |
+
# Parameters: baseline=0.3, excitation=0.6, decay=1.2
|
| 141 |
+
# Branching ratio = 0.6/1.2 = 0.5 (stable, subcritical)
|
| 142 |
+
events = hawkes.simulate(mu=0.3, alpha=0.6, beta=1.2, T=100.0)
|
| 143 |
+
|
| 144 |
+
print(f"Generated {len(events)} conflict events over 100 time units")
|
| 145 |
+
print(f"Average rate: {len(events) / 100.0:.2f} events/unit\n")
|
| 146 |
+
|
| 147 |
+
# Fit model
|
| 148 |
+
result = hawkes.fit(events, T=100.0)
|
| 149 |
+
|
| 150 |
+
print("Estimated Hawkes Parameters:")
|
| 151 |
+
print(f" Baseline intensity (μ): {result.params.mu:.3f}")
|
| 152 |
+
print(f" Excitation (α): {result.params.alpha:.3f}")
|
| 153 |
+
print(f" Decay rate (β): {result.params.beta:.3f}")
|
| 154 |
+
print(f" Branching ratio: {result.params.branching_ratio:.3f}")
|
| 155 |
+
print(f" Process is {'STABLE' if result.params.is_stable else 'EXPLOSIVE'}")
|
| 156 |
+
|
| 157 |
+
# Predict intensity
|
| 158 |
+
t_future = 105.0
|
| 159 |
+
intensity = hawkes.predict_intensity(events, result.params, t_future)
|
| 160 |
+
print(f"\nPredicted conflict intensity at t={t_future}: {intensity:.3f}")
|
| 161 |
+
|
| 162 |
+
# Multivariate: conflict contagion between countries
|
| 163 |
+
print("\n" + "-"*80)
|
| 164 |
+
print("Multivariate Hawkes: Cross-Country Conflict Contagion")
|
| 165 |
+
print("-"*80)
|
| 166 |
+
|
| 167 |
+
countries = ['Syria', 'Iraq', 'Lebanon']
|
| 168 |
+
contagion_model = ConflictContagionModel(countries=countries)
|
| 169 |
+
|
| 170 |
+
# Simulate with cross-excitation
|
| 171 |
+
mu = np.array([0.5, 0.3, 0.2]) # Different baseline rates
|
| 172 |
+
alpha = np.array([
|
| 173 |
+
[0.3, 0.15, 0.1], # Syria: high self-excitation, moderate contagion
|
| 174 |
+
[0.2, 0.25, 0.1], # Iraq: affected by Syria
|
| 175 |
+
[0.15, 0.1, 0.2] # Lebanon: affected by both
|
| 176 |
+
])
|
| 177 |
+
beta = np.ones((3, 3)) * 1.5
|
| 178 |
+
|
| 179 |
+
multi_hawkes = MultivariateHawkesProcess(n_dimensions=3)
|
| 180 |
+
events_multi = multi_hawkes.simulate(mu=mu, alpha=alpha, beta=beta, T=100.0)
|
| 181 |
+
|
| 182 |
+
print(f"\nSimulated events:")
|
| 183 |
+
for i, country in enumerate(countries):
|
| 184 |
+
print(f" {country}: {len(events_multi[i])} events")
|
| 185 |
+
|
| 186 |
+
# Fit multivariate model
|
| 187 |
+
events_dict = {country: events_multi[i] for i, country in enumerate(countries)}
|
| 188 |
+
fit_result = contagion_model.fit(events_dict, T=100.0)
|
| 189 |
+
|
| 190 |
+
print(f"\nFitted contagion model:")
|
| 191 |
+
print(f" Spectral radius: {fit_result['spectral_radius']:.3f} (< 1 = stable)")
|
| 192 |
+
print(f" Most contagious source: {fit_result['most_contagious_source']}")
|
| 193 |
+
print(f" Most vulnerable target: {fit_result['most_vulnerable_target']}")
|
| 194 |
+
|
| 195 |
+
# Identify contagion pathways
|
| 196 |
+
pathways = contagion_model.identify_contagion_pathways(fit_result, threshold=0.1)
|
| 197 |
+
print("\nSignificant contagion pathways (branching ratio > 0.1):")
|
| 198 |
+
for source, target, strength in pathways[:5]:
|
| 199 |
+
print(f" {source} → {target}: {strength:.3f}")
|
| 200 |
+
|
| 201 |
+
# Risk assessment
|
| 202 |
+
risks = contagion_model.contagion_risk(events_dict, fit_result, t=105.0, horizon=5.0)
|
| 203 |
+
print("\nConflict risk over next 5 time units:")
|
| 204 |
+
for country, risk in risks.items():
|
| 205 |
+
print(f" {country}: {risk:.1%}")
|
| 206 |
+
|
| 207 |
+
print("\n✓ Hawkes processes capture conflict escalation dynamics!")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def demo_synthetic_control():
|
| 211 |
+
"""Demonstrate Synthetic Control Method."""
|
| 212 |
+
print("\n" + "="*80)
|
| 213 |
+
print("3. Synthetic Control Method - Policy Impact Estimation")
|
| 214 |
+
print("="*80)
|
| 215 |
+
|
| 216 |
+
# Scenario: Estimate effect of sanctions on target country's GDP
|
| 217 |
+
print("\nScenario: Economic sanctions imposed on Country A at t=50")
|
| 218 |
+
print("Question: What is the causal effect on GDP growth?\n")
|
| 219 |
+
|
| 220 |
+
# Generate data
|
| 221 |
+
np.random.seed(42)
|
| 222 |
+
T = 100
|
| 223 |
+
J = 10 # 10 control countries
|
| 224 |
+
|
| 225 |
+
# Pre-treatment: all countries follow similar trends
|
| 226 |
+
time = np.arange(T)
|
| 227 |
+
trend = 0.02 * time + np.random.randn(T) * 0.1
|
| 228 |
+
|
| 229 |
+
# Control countries
|
| 230 |
+
control_outcomes = np.zeros((T, J))
|
| 231 |
+
for j in range(J):
|
| 232 |
+
control_outcomes[:, j] = trend + np.random.randn(T) * 0.15 + np.random.randn() * 0.5
|
| 233 |
+
|
| 234 |
+
# Treated country (matches controls pre-treatment)
|
| 235 |
+
treated_outcome = trend + np.random.randn(T) * 0.15
|
| 236 |
+
|
| 237 |
+
# Treatment effect: negative shock starting at t=50
|
| 238 |
+
treatment_time = 50
|
| 239 |
+
true_effect = -0.8
|
| 240 |
+
treated_outcome[treatment_time:] += true_effect + np.random.randn(T - treatment_time) * 0.1
|
| 241 |
+
|
| 242 |
+
# Fit SCM
|
| 243 |
+
scm = SyntheticControlMethod()
|
| 244 |
+
result = scm.fit(
|
| 245 |
+
treated_outcome=treated_outcome,
|
| 246 |
+
control_outcomes=control_outcomes,
|
| 247 |
+
treatment_time=treatment_time,
|
| 248 |
+
control_names=[f"Country_{j+1}" for j in range(J)]
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
print("Synthetic Control Results:")
|
| 252 |
+
print(f" Pre-treatment fit (RMSPE): {result.pre_treatment_fit:.4f}")
|
| 253 |
+
print(f"\nSynthetic Country A is weighted combination of:")
|
| 254 |
+
for j, weight in enumerate(result.weights):
|
| 255 |
+
if weight > 0.01: # Only show significant weights
|
| 256 |
+
print(f" {result.control_units[j]}: {weight:.1%}")
|
| 257 |
+
|
| 258 |
+
# Treatment effects
|
| 259 |
+
avg_effect = np.mean(result.treatment_effect[treatment_time:])
|
| 260 |
+
print(f"\nEstimated treatment effect (post-sanctions):")
|
| 261 |
+
print(f" Average: {avg_effect:.3f} (true effect: {true_effect:.3f})")
|
| 262 |
+
print(f" Final period: {result.treatment_effect[-1]:.3f}")
|
| 263 |
+
|
| 264 |
+
# Placebo test
|
| 265 |
+
p_value = scm.placebo_test(treated_outcome, control_outcomes, treatment_time, n_permutations=J)
|
| 266 |
+
print(f"\nPlacebo test p-value: {p_value:.3f}")
|
| 267 |
+
if p_value < 0.05:
|
| 268 |
+
print(" ✓ Effect is statistically significant (unusual compared to placebos)")
|
| 269 |
+
else:
|
| 270 |
+
print(" ✗ Effect not significant (could be random)")
|
| 271 |
+
|
| 272 |
+
print("\n✓ Synthetic control provides credible counterfactual!")
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def demo_difference_in_differences():
|
| 276 |
+
"""Demonstrate Difference-in-Differences."""
|
| 277 |
+
print("\n" + "="*80)
|
| 278 |
+
print("4. Difference-in-Differences (DiD) - Regime Change Analysis")
|
| 279 |
+
print("="*80)
|
| 280 |
+
|
| 281 |
+
# Scenario: Regime change in treated country
|
| 282 |
+
print("\nScenario: Regime change in Country T at t=50")
|
| 283 |
+
print("Compare to similar countries without regime change\n")
|
| 284 |
+
|
| 285 |
+
np.random.seed(42)
|
| 286 |
+
|
| 287 |
+
# Pre-treatment (similar trends)
|
| 288 |
+
treated_pre = 3.0 + np.random.randn(50) * 0.5
|
| 289 |
+
control_pre = 3.2 + np.random.randn(50) * 0.5
|
| 290 |
+
|
| 291 |
+
# Post-treatment (treatment effect = +1.5 on outcome)
|
| 292 |
+
true_effect = 1.5
|
| 293 |
+
treated_post = 3.0 + true_effect + np.random.randn(50) * 0.5
|
| 294 |
+
control_post = 3.2 + np.random.randn(50) * 0.5 # No effect
|
| 295 |
+
|
| 296 |
+
# Estimate DiD
|
| 297 |
+
did = DifferenceinDifferences()
|
| 298 |
+
result = did.estimate(treated_pre, treated_post, control_pre, control_post)
|
| 299 |
+
|
| 300 |
+
print("Difference-in-Differences Results:")
|
| 301 |
+
print(f"\n Pre-treatment difference: {result.pre_treatment_diff:.3f}")
|
| 302 |
+
print(f" Post-treatment difference: {result.post_treatment_diff:.3f}")
|
| 303 |
+
print(f"\n Average Treatment Effect (ATT): {result.att:.3f}")
|
| 304 |
+
print(f" Standard error: {result.se:.3f}")
|
| 305 |
+
print(f" t-statistic: {result.t_stat:.3f}")
|
| 306 |
+
print(f" p-value: {result.p_value:.4f}")
|
| 307 |
+
|
| 308 |
+
if result.p_value < 0.05:
|
| 309 |
+
print(f"\n ✓ Regime change had significant effect (true effect: {true_effect:.3f})")
|
| 310 |
+
else:
|
| 311 |
+
print("\n ✗ Effect not statistically significant")
|
| 312 |
+
|
| 313 |
+
# Assumption check
|
| 314 |
+
if abs(result.pre_treatment_diff) < 0.5:
|
| 315 |
+
print("\n ✓ Parallel trends assumption plausible (small pre-treatment diff)")
|
| 316 |
+
else:
|
| 317 |
+
print("\n ⚠ Parallel trends questionable (large pre-treatment diff)")
|
| 318 |
+
|
| 319 |
+
print("\n✓ DiD isolates causal effect of regime change!")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def demo_regression_discontinuity():
|
| 323 |
+
"""Demonstrate Regression Discontinuity Design."""
|
| 324 |
+
print("\n" + "="*80)
|
| 325 |
+
print("5. Regression Discontinuity Design (RDD) - Election Effects")
|
| 326 |
+
print("="*80)
|
| 327 |
+
|
| 328 |
+
# Scenario: Effect of winning election on military policy
|
| 329 |
+
print("\nScenario: Effect of hawkish candidate winning election")
|
| 330 |
+
print("Running variable: Vote share (cutoff = 50%)")
|
| 331 |
+
print("Outcome: Military spending increase\n")
|
| 332 |
+
|
| 333 |
+
np.random.seed(42)
|
| 334 |
+
n = 500
|
| 335 |
+
|
| 336 |
+
# Vote share (running variable)
|
| 337 |
+
vote_share = np.random.uniform(0.3, 0.7, n)
|
| 338 |
+
|
| 339 |
+
# Outcome: military spending
|
| 340 |
+
# Smooth function of vote share + discontinuity at 50%
|
| 341 |
+
outcome = 2.0 + 1.5 * vote_share + np.random.randn(n) * 0.3
|
| 342 |
+
|
| 343 |
+
# Treatment effect: +0.8 if vote > 50%
|
| 344 |
+
true_effect = 0.8
|
| 345 |
+
outcome[vote_share >= 0.5] += true_effect
|
| 346 |
+
|
| 347 |
+
# Estimate RDD
|
| 348 |
+
rdd = RegressionDiscontinuity(cutoff=0.5)
|
| 349 |
+
result = rdd.estimate_sharp(
|
| 350 |
+
running_var=vote_share,
|
| 351 |
+
outcome=outcome,
|
| 352 |
+
bandwidth=0.15, # 15% bandwidth
|
| 353 |
+
kernel='triangular'
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
print("Regression Discontinuity Results:")
|
| 357 |
+
print(f"\n Bandwidth: {result.bandwidth:.3f}")
|
| 358 |
+
print(f" Observations below cutoff: {result.n_left}")
|
| 359 |
+
print(f" Observations above cutoff: {result.n_right}")
|
| 360 |
+
print(f"\n Treatment effect (LATE): {result.treatment_effect:.3f}")
|
| 361 |
+
print(f" Standard error: {result.se:.3f}")
|
| 362 |
+
print(f" t-statistic: {result.t_stat:.3f}")
|
| 363 |
+
print(f" p-value: {result.p_value:.4f}")
|
| 364 |
+
|
| 365 |
+
if result.p_value < 0.05:
|
| 366 |
+
print(f"\n ✓ Winning election causes increase in military spending")
|
| 367 |
+
print(f" (true effect: {true_effect:.3f})")
|
| 368 |
+
else:
|
| 369 |
+
print("\n ✗ Effect not statistically significant")
|
| 370 |
+
|
| 371 |
+
print("\n✓ RDD exploits threshold-based treatment assignment!")
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def demo_instrumental_variables():
|
| 375 |
+
"""Demonstrate Instrumental Variables."""
|
| 376 |
+
print("\n" + "="*80)
|
| 377 |
+
print("6. Instrumental Variables (IV) - Trade and Conflict")
|
| 378 |
+
print("="*80)
|
| 379 |
+
|
| 380 |
+
# Scenario: Effect of trade on conflict (trade is endogenous)
|
| 381 |
+
print("\nScenario: Does trade reduce conflict?")
|
| 382 |
+
print("Problem: Trade is endogenous (reverse causality, omitted variables)")
|
| 383 |
+
print("Instrument: Geographic distance to major trade routes\n")
|
| 384 |
+
|
| 385 |
+
np.random.seed(42)
|
| 386 |
+
n = 300
|
| 387 |
+
|
| 388 |
+
# Instrument: distance (exogenous)
|
| 389 |
+
distance = np.random.uniform(100, 1000, n)
|
| 390 |
+
|
| 391 |
+
# Unobserved confounders
|
| 392 |
+
unobserved = np.random.randn(n)
|
| 393 |
+
|
| 394 |
+
# Trade (endogenous): affected by distance and confounders
|
| 395 |
+
trade = 50 - 0.03 * distance + 2.0 * unobserved + np.random.randn(n) * 5
|
| 396 |
+
|
| 397 |
+
# Conflict: true effect of trade = -0.15, but also affected by confounders
|
| 398 |
+
true_effect = -0.15
|
| 399 |
+
conflict = 10 + true_effect * trade - 1.5 * unobserved + np.random.randn(n) * 2
|
| 400 |
+
|
| 401 |
+
# Estimate with IV
|
| 402 |
+
iv = InstrumentalVariables()
|
| 403 |
+
result = iv.estimate_2sls(
|
| 404 |
+
outcome=conflict,
|
| 405 |
+
endogenous=trade,
|
| 406 |
+
instrument=distance
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
print("Instrumental Variables (2SLS) Results:")
|
| 410 |
+
print(f"\n First stage F-statistic: {result.first_stage_f:.2f}")
|
| 411 |
+
if result.weak_instrument:
|
| 412 |
+
print(" ⚠ Warning: Weak instrument (F < 10)")
|
| 413 |
+
else:
|
| 414 |
+
print(" ✓ Strong instrument (F > 10)")
|
| 415 |
+
|
| 416 |
+
print(f"\n OLS estimate (biased): {result.beta_ols[0]:.4f}")
|
| 417 |
+
print(f" IV estimate (consistent): {result.beta_iv[0]:.4f}")
|
| 418 |
+
print(f" IV standard error: {result.se_iv[0]:.4f}")
|
| 419 |
+
print(f"\n True causal effect: {true_effect:.4f}")
|
| 420 |
+
|
| 421 |
+
# Hausman test (informal)
|
| 422 |
+
if abs(result.beta_ols[0] - result.beta_iv[0]) > 0.05:
|
| 423 |
+
print("\n ✓ OLS and IV differ substantially → endogeneity present")
|
| 424 |
+
print(" IV corrects for bias!")
|
| 425 |
+
else:
|
| 426 |
+
print("\n OLS and IV similar → endogeneity may be small")
|
| 427 |
+
|
| 428 |
+
print("\n✓ IV isolates causal effect using exogenous variation!")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def demo_dynamic_factor_model():
|
| 432 |
+
"""Demonstrate Dynamic Factor Model for nowcasting."""
|
| 433 |
+
print("\n" + "="*80)
|
| 434 |
+
print("7. Dynamic Factor Model (DFM) - High-Dimensional Nowcasting")
|
| 435 |
+
print("="*80)
|
| 436 |
+
|
| 437 |
+
# Scenario: Nowcast geopolitical tension from many indicators
|
| 438 |
+
print("\nScenario: Nowcast regional tension from 50 economic/political indicators")
|
| 439 |
+
print("DFM extracts common latent factors driving all indicators\n")
|
| 440 |
+
|
| 441 |
+
np.random.seed(42)
|
| 442 |
+
T = 200
|
| 443 |
+
n_indicators = 50
|
| 444 |
+
n_factors = 3
|
| 445 |
+
|
| 446 |
+
# True factors (latent tensions)
|
| 447 |
+
true_factors = np.zeros((T, n_factors))
|
| 448 |
+
for k in range(n_factors):
|
| 449 |
+
# AR(1) dynamics
|
| 450 |
+
for t in range(1, T):
|
| 451 |
+
true_factors[t, k] = 0.8 * true_factors[t-1, k] + np.random.randn() * 0.5
|
| 452 |
+
|
| 453 |
+
# Factor loadings (how indicators load on factors)
|
| 454 |
+
true_loadings = np.random.randn(n_indicators, n_factors)
|
| 455 |
+
|
| 456 |
+
# Observed indicators = factors * loadings + idiosyncratic noise
|
| 457 |
+
data = true_factors @ true_loadings.T + np.random.randn(T, n_indicators) * 0.5
|
| 458 |
+
|
| 459 |
+
# Fit DFM
|
| 460 |
+
dfm = DynamicFactorModel(n_factors=3, n_lags=1)
|
| 461 |
+
model = dfm.fit(data)
|
| 462 |
+
|
| 463 |
+
print(f"Dynamic Factor Model Results:")
|
| 464 |
+
print(f"\n Number of indicators: {n_indicators}")
|
| 465 |
+
print(f" Number of factors: {n_factors}")
|
| 466 |
+
print(f" Explained variance: {model['explained_variance_ratio']:.1%}")
|
| 467 |
+
|
| 468 |
+
# Extracted factors
|
| 469 |
+
factors = model['factors']
|
| 470 |
+
print(f"\n Extracted factor dimensions: {factors.shape}")
|
| 471 |
+
print(f" Factor 1 final value: {factors[-1, 0]:.3f}")
|
| 472 |
+
print(f" Factor 2 final value: {factors[-1, 1]:.3f}")
|
| 473 |
+
print(f" Factor 3 final value: {factors[-1, 2]:.3f}")
|
| 474 |
+
|
| 475 |
+
# Forecast
|
| 476 |
+
forecast = dfm.forecast(model, steps=10)
|
| 477 |
+
print(f"\n 10-step ahead forecast dimensions: {forecast.shape}")
|
| 478 |
+
print(f" Average forecasted indicator value: {np.mean(forecast[-1]):.3f}")
|
| 479 |
+
|
| 480 |
+
# Correlation with true factors
|
| 481 |
+
corr_0 = np.corrcoef(true_factors[:, 0], factors[:, 0])[0, 1]
|
| 482 |
+
print(f"\n Factor recovery (correlation with true): {abs(corr_0):.3f}")
|
| 483 |
+
|
| 484 |
+
print("\n✓ DFM reduces dimensionality while preserving information!")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def main():
|
| 488 |
+
"""Run all demonstrations of final features."""
|
| 489 |
+
print("=" * 80)
|
| 490 |
+
print("GeoBotv1 - COMPLETE FRAMEWORK DEMONSTRATION")
|
| 491 |
+
print("=" * 80)
|
| 492 |
+
print("\nThis example showcases the final components that complete GeoBotv1:")
|
| 493 |
+
print("• Vector Autoregression (VAR/SVAR/DFM)")
|
| 494 |
+
print("• Hawkes Processes for conflict contagion")
|
| 495 |
+
print("• Quasi-Experimental Causal Inference")
|
| 496 |
+
print(" - Synthetic Control Method")
|
| 497 |
+
print(" - Difference-in-Differences")
|
| 498 |
+
print(" - Regression Discontinuity Design")
|
| 499 |
+
print(" - Instrumental Variables")
|
| 500 |
+
|
| 501 |
+
# Run all demonstrations
|
| 502 |
+
demo_var_model()
|
| 503 |
+
demo_hawkes_process()
|
| 504 |
+
demo_synthetic_control()
|
| 505 |
+
demo_difference_in_differences()
|
| 506 |
+
demo_regression_discontinuity()
|
| 507 |
+
demo_instrumental_variables()
|
| 508 |
+
demo_dynamic_factor_model()
|
| 509 |
+
|
| 510 |
+
print("\n" + "=" * 80)
|
| 511 |
+
print("GeoBotv1 Framework is NOW 100% COMPLETE!")
|
| 512 |
+
print("=" * 80)
|
| 513 |
+
print("\n🎉 All Research-Grade Mathematical Components Implemented:")
|
| 514 |
+
print("\n📊 CORE FRAMEWORKS:")
|
| 515 |
+
print(" ✓ Optimal Transport (Wasserstein, Kantorovich, Sinkhorn)")
|
| 516 |
+
print(" ✓ Causal Inference (DAGs, SCMs, Do-Calculus)")
|
| 517 |
+
print(" ✓ Bayesian Inference (MCMC, Particle Filters, VI)")
|
| 518 |
+
print(" ✓ Stochastic Processes (SDEs, Jump-Diffusion)")
|
| 519 |
+
print(" ✓ Time-Series Models (Kalman, HMM, VAR, Hawkes)")
|
| 520 |
+
print(" ✓ Quasi-Experimental Methods (SCM, DiD, RDD, IV)")
|
| 521 |
+
print(" ✓ Machine Learning (GNNs, Risk Scoring, Embeddings)")
|
| 522 |
+
print("\n📈 SPECIALIZED CAPABILITIES:")
|
| 523 |
+
print(" ✓ Multi-country interdependency modeling (VAR)")
|
| 524 |
+
print(" ✓ Conflict contagion and escalation (Hawkes)")
|
| 525 |
+
print(" ✓ Policy counterfactuals (Synthetic Control)")
|
| 526 |
+
print(" ✓ Regime change effects (Difference-in-Differences)")
|
| 527 |
+
print(" ✓ Election outcomes impact (Regression Discontinuity)")
|
| 528 |
+
print(" ✓ Trade-conflict nexus (Instrumental Variables)")
|
| 529 |
+
print(" ✓ High-dimensional nowcasting (Dynamic Factor Models)")
|
| 530 |
+
print("\n🔬 MATHEMATICAL RIGOR:")
|
| 531 |
+
print(" ✓ Measure-theoretic probability foundations")
|
| 532 |
+
print(" ✓ Continuous-time dynamics (SDEs)")
|
| 533 |
+
print(" ✓ Causal identification strategies")
|
| 534 |
+
print(" ✓ Structural econometric methods")
|
| 535 |
+
print(" ✓ Point process theory")
|
| 536 |
+
print(" ✓ Optimal transport geometry")
|
| 537 |
+
print("\n💡 GeoBotv1 is ready for production geopolitical forecasting!")
|
| 538 |
+
print("=" * 80 + "\n")
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
if __name__ == "__main__":
|
| 542 |
+
main()
|
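One arithmetic note on the DiD demo above: the ATT it reports is simply the treated group's post-minus-pre change minus the control group's post-minus-pre change. A minimal numpy sketch of that calculation follows, with made-up data mirroring the demo's +1.5 effect; this is plain numpy, not the geobot DifferenceinDifferences class.

import numpy as np

def did_att(treated_pre, treated_post, control_pre, control_post):
    # ATT = (treated post-pre change) minus (control post-pre change).
    treated_change = np.mean(treated_post) - np.mean(treated_pre)
    control_change = np.mean(control_post) - np.mean(control_pre)
    return treated_change - control_change

rng = np.random.default_rng(0)
att = did_att(
    treated_pre=3.0 + rng.normal(0, 0.5, 50),
    treated_post=4.5 + rng.normal(0, 0.5, 50),   # built-in +1.5 effect, mirroring the demo
    control_pre=3.2 + rng.normal(0, 0.5, 50),
    control_post=3.2 + rng.normal(0, 0.5, 50),
)
print(f"DiD ATT ≈ {att:.2f} (true effect 1.5)")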
examples/06_geobot2_analytical_framework.py
ADDED
|
@@ -0,0 +1,318 @@
|
| 1 |
+
"""
|
| 2 |
+
GeoBot 2.0 Analytical Framework Example
|
| 3 |
+
|
| 4 |
+
Demonstrates the complete GeoBot 2.0 framework for clinical systems analysis
|
| 5 |
+
with geopolitical nuance. Includes:
|
| 6 |
+
|
| 7 |
+
1. Framework overview
|
| 8 |
+
2. Analytical lenses demonstration
|
| 9 |
+
3. China Rocket Force purge analysis (example from specification)
|
| 10 |
+
4. Governance system comparison
|
| 11 |
+
5. Corruption type analysis
|
| 12 |
+
6. Non-Western military assessment
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
sys.path.insert(0, '..')
|
| 17 |
+
|
| 18 |
+
from geobot.analysis import (
|
| 19 |
+
AnalyticalEngine,
|
| 20 |
+
GeoBotFramework,
|
| 21 |
+
AnalyticalLenses,
|
| 22 |
+
GovernanceType,
|
| 23 |
+
CorruptionType
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def print_section(title):
|
| 28 |
+
"""Print formatted section header."""
|
| 29 |
+
print("\n" + "=" * 80)
|
| 30 |
+
print(f" {title}")
|
| 31 |
+
print("=" * 80 + "\n")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def example_1_framework_overview():
|
| 35 |
+
"""Demonstrate GeoBot 2.0 framework overview."""
|
| 36 |
+
print_section("Example 1: GeoBot 2.0 Framework Overview")
|
| 37 |
+
|
| 38 |
+
framework = GeoBotFramework()
|
| 39 |
+
summary = framework.get_framework_summary()
|
| 40 |
+
|
| 41 |
+
print(f"Version: {summary['version']}")
|
| 42 |
+
print(f"Description: {summary['description']}\n")
|
| 43 |
+
|
| 44 |
+
print("Core Identity:")
|
| 45 |
+
print(f" Focus: {summary['identity']['focus']}")
|
| 46 |
+
print(f" Key Shift: {summary['identity']['key_shift']}\n")
|
| 47 |
+
|
| 48 |
+
print("Integration Elements:")
|
| 49 |
+
for element in summary['identity']['integration_elements']:
|
| 50 |
+
print(f" - {element}")
|
| 51 |
+
|
| 52 |
+
print("\nTone:")
|
| 53 |
+
print(summary['tone'])
|
| 54 |
+
|
| 55 |
+
print("\nAnalytical Principles:")
|
| 56 |
+
for i, principle in enumerate(summary['principles'], 1):
|
| 57 |
+
print(f" {i}. {principle}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def example_2_china_rocket_force_analysis():
|
| 61 |
+
"""
|
| 62 |
+
Demonstrate complete analysis using China Rocket Force purge example
|
| 63 |
+
from the GeoBot 2.0 specification.
|
| 64 |
+
"""
|
| 65 |
+
print_section("Example 2: China Rocket Force Purge Analysis")
|
| 66 |
+
|
| 67 |
+
engine = AnalyticalEngine()
|
| 68 |
+
|
| 69 |
+
query = "China removed several top Rocket Force generals. What does this mean?"
|
| 70 |
+
|
| 71 |
+
context = {
|
| 72 |
+
'governance_type': GovernanceType.AUTHORITARIAN_CENTRALIZED,
|
| 73 |
+
'corruption_type': CorruptionType.MANAGED_BOUNDED,
|
| 74 |
+
'military_system': 'Chinese PLA',
|
| 75 |
+
'scenario_description': 'Leadership purge in strategic forces',
|
| 76 |
+
'operational_context': 'Strategic nuclear forces readiness',
|
| 77 |
+
|
| 78 |
+
'summary': """The purge indicates internal accountability enforcement within strategic forces
|
| 79 |
+
command, with mixed implications for near-term readiness and decision coherence.""",
|
| 80 |
+
|
| 81 |
+
'logistics_assessment': """Rocket Force maintenance, silo integration, and inventory control
|
| 82 |
+
are likely under audit. Purges typically follow discovery of procurement irregularities or
|
| 83 |
+
readiness misreporting. However, unlike Russian corruption patterns, Chinese anti-corruption
|
| 84 |
+
campaigns since 2012 have successfully constrained (though not eliminated) defense sector
|
| 85 |
+
embezzlement. The PLA's civil-military logistics integration provides redundancy that mitigates
|
| 86 |
+
some supply chain risks.""",
|
| 87 |
+
|
| 88 |
+
'scenarios': [
|
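+
# Four mutually exclusive outcomes; the probabilities below sum to 1.0
|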
| 89 |
+
{
|
| 90 |
+
'name': 'Routine institutional maintenance',
|
| 91 |
+
'probability': 0.50,
|
| 92 |
+
'description': 'Temporary disruption, return to baseline within 6-12 months'
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
'name': 'Deeper procurement crisis',
|
| 96 |
+
'probability': 0.30,
|
| 97 |
+
'description': 'Extended degradation of readiness reporting reliability'
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
'name': 'Factional conflict',
|
| 101 |
+
'probability': 0.15,
|
| 102 |
+
'description': 'Prolonged command instability'
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
'name': 'Major reorganization',
|
| 106 |
+
'probability': 0.05,
|
| 107 |
+
'description': 'Strategic forces restructure'
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
|
| 111 |
+
'uncertainty_factors': [
|
| 112 |
+
'Limited visibility into CCP internal dynamics and audit findings'
|
| 113 |
+
],
|
| 114 |
+
|
| 115 |
+
'signals_to_watch': [
|
| 116 |
+
'Promotion patterns (meritocratic vs. factional indicators)',
|
| 117 |
+
'Training tempo changes (satellite observable)',
|
| 118 |
+
'Procurement contract patterns',
|
| 119 |
+
'Whether purges expand beyond Rocket Force'
|
| 120 |
+
],
|
| 121 |
+
|
| 122 |
+
'comparative_notes': """Russia's similar aerospace purges (2015-2017) resulted in sustained
|
| 123 |
+
degradation because underlying corruption was never addressed—only individuals were replaced.
|
| 124 |
+
China's systemic anti-corruption infrastructure suggests a different trajectory is possible."""
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
# Add governance context
|
| 128 |
+
context['governance_context'] = {
|
| 129 |
+
'trade_off': """China gains long-term institutional integrity at the cost of
|
| 130 |
+
short-term command continuity.""",
|
| 131 |
+
'context_specific_advantage': """Authoritarian systems can execute rapid leadership
|
| 132 |
+
replacement without legislative constraints, allowing faster course correction than
|
| 133 |
+
consensus-based systems. However, this creates temporary communication disruption and
|
| 134 |
+
institutional memory loss."""
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
# Add corruption details
|
| 138 |
+
context['corruption_details'] = {
|
| 139 |
+
'evidence': """Evidence suggests managed corruption model rather than parasitic:
|
| 140 |
+
purges indicate the system detected and acted on problems, rather than tolerating systemic decay.
|
| 141 |
+
This is structurally different from militaries where corruption goes unaddressed.""",
|
| 142 |
+
'risk_assessment': """Purge-induced fear may cause temporary over-reporting conservatism,
|
| 143 |
+
slowing mobilization responses."""
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
# Add non-Western context
|
| 147 |
+
context['non_western_context'] = {
|
| 148 |
+
'analysis_framework': """Western analysis often treats purges as pure weakness signals.
|
| 149 |
+
In Chinese institutional context, periodic purges are a maintenance mechanism for regime stability.
|
| 150 |
+
The question is whether this specific purge reflects routine enforcement or deeper structural crisis.""",
|
| 151 |
+
'key_distinction': """Indicators distinguishing routine vs. crisis:
|
| 152 |
+
- Scope: limited to RF or expanding to other services?
|
| 153 |
+
- Timing: related to specific audit cycle or sudden?
|
| 154 |
+
- Replacements: technocratic or factional?"""
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
# Perform analysis
|
| 158 |
+
analysis = engine.analyze(query, context)
|
| 159 |
+
print(analysis)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def example_3_governance_comparison():
|
| 163 |
+
"""Demonstrate governance system comparison."""
|
| 164 |
+
print_section("Example 3: Governance System Comparison")
|
| 165 |
+
|
| 166 |
+
engine = AnalyticalEngine()
|
| 167 |
+
|
| 168 |
+
scenario = "Rapid military mobilization in response to regional crisis"
|
| 169 |
+
|
| 170 |
+
comparison = engine.compare_governance_systems(
|
| 171 |
+
scenario=scenario,
|
| 172 |
+
authoritarian_context={},
|
| 173 |
+
democratic_context={}
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
print(comparison)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def example_4_corruption_assessment():
|
| 180 |
+
"""Demonstrate corruption type assessment."""
|
| 181 |
+
print_section("Example 4: Corruption Type Assessment")
|
| 182 |
+
|
| 183 |
+
lenses = AnalyticalLenses()
|
| 184 |
+
|
| 185 |
+
print("Corruption Type Analysis:\n")
|
| 186 |
+
|
| 187 |
+
# Assess different corruption types
|
| 188 |
+
corruption_types = [
|
| 189 |
+
(CorruptionType.PARASITIC, "Russia"),
|
| 190 |
+
(CorruptionType.MANAGED_BOUNDED, "China"),
|
| 191 |
+
(CorruptionType.INSTITUTIONALIZED_PATRONAGE, "Iran IRGC"),
|
| 192 |
+
(CorruptionType.LOW_CORRUPTION, "NATO countries")
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
for corr_type, example in corruption_types:
|
| 196 |
+
analysis = lenses.corruption.analyze(
|
| 197 |
+
corr_type,
|
| 198 |
+
operational_context="Sustained high-intensity conventional operations"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
print(f"\n{corr_type.value.upper()} ({example}):")
|
| 202 |
+
print(f" Operational Impact: {analysis['operational_impact']}")
|
| 203 |
+
print(f" Characteristics:")
|
| 204 |
+
for char in analysis['characteristics']:
|
| 205 |
+
print(f" - {char}")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def example_5_non_western_military_assessment():
|
| 209 |
+
"""Demonstrate non-Western military analysis."""
|
| 210 |
+
print_section("Example 5: Non-Western Military Assessment")
|
| 211 |
+
|
| 212 |
+
lenses = AnalyticalLenses()
|
| 213 |
+
|
| 214 |
+
militaries = ["Chinese PLA", "Russian Military", "Iranian Systems"]
|
| 215 |
+
|
| 216 |
+
for military in militaries:
|
| 217 |
+
analysis = lenses.non_western.analyze(military)
|
| 218 |
+
|
| 219 |
+
print(f"\n{military.upper()}:")
|
| 220 |
+
print(f"\nOperational Culture: {analysis['operational_culture']}")
|
| 221 |
+
|
| 222 |
+
print("\nStrengths:")
|
| 223 |
+
for strength in analysis['strengths']:
|
| 224 |
+
print(f" - {strength}")
|
| 225 |
+
|
| 226 |
+
print("\nWeaknesses:")
|
| 227 |
+
for weakness in analysis['weaknesses']:
|
| 228 |
+
print(f" - {weakness}")
|
| 229 |
+
|
| 230 |
+
print(f"\nKey Insight: {analysis['key_insight']}")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def example_6_all_lenses_integration():
|
| 234 |
+
"""Demonstrate integration of all four lenses."""
|
| 235 |
+
print_section("Example 6: All Lenses Integration")
|
| 236 |
+
|
| 237 |
+
engine = AnalyticalEngine()
|
| 238 |
+
|
| 239 |
+
print("Analytical Priorities:\n")
|
| 240 |
+
priorities = engine.get_analytical_priorities()
|
| 241 |
+
for i, priority in enumerate(priorities, 1):
|
| 242 |
+
print(f"{i}. {priority}")
|
| 243 |
+
|
| 244 |
+
print("\n" + "-" * 80)
|
| 245 |
+
|
| 246 |
+
# Quick analysis example
|
| 247 |
+
print("\nQuick Analysis Example:")
|
| 248 |
+
print("Query: 'Iran's ability to sustain proxy operations in Syria'\n")
|
| 249 |
+
|
| 250 |
+
analysis = engine.quick_analysis(
|
| 251 |
+
query="Iran's ability to sustain proxy operations in Syria",
|
| 252 |
+
governance_type=GovernanceType.AUTHORITARIAN_CENTRALIZED,
|
| 253 |
+
corruption_type=CorruptionType.INSTITUTIONALIZED_PATRONAGE,
|
| 254 |
+
military_system="Iranian Systems",
|
| 255 |
+
summary="""Iran demonstrates structural advantages in proxy coordination despite
|
| 256 |
+
conventional military limitations. IRGC Quds Force maintains effective command and control
|
| 257 |
+
through patronage networks that double as operational infrastructure.""",
|
| 258 |
+
logistics_assessment="""Supply lines to Syria rely on air bridge through Iraq and
|
| 259 |
+
maritime routes. Vulnerable to interdiction but demonstrated resilience through redundancy.
|
| 260 |
+
Sanctions impact advanced systems but not basic sustainment.""",
|
| 261 |
+
scenarios=[
|
| 262 |
+
{
|
| 263 |
+
'name': 'Sustained proxy presence',
|
| 264 |
+
'probability': 0.60,
|
| 265 |
+
'description': 'Current operational tempo maintained'
|
| 266 |
+
},
|
| 267 |
+
{
|
| 268 |
+
'name': 'Degraded operations',
|
| 269 |
+
'probability': 0.30,
|
| 270 |
+
'description': 'Israeli interdiction reduces capability'
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
'name': 'Expansion',
|
| 274 |
+
'probability': 0.10,
|
| 275 |
+
'description': 'Regional instability creates opportunities'
|
| 276 |
+
}
|
| 277 |
+
]
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
print(analysis)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def main():
|
| 284 |
+
"""Run all examples."""
|
| 285 |
+
print("\n" + "=" * 80)
|
| 286 |
+
print(" GeoBot 2.0: Cold Systems Analysis with Geopolitical Nuance")
|
| 287 |
+
print("=" * 80)
|
| 288 |
+
|
| 289 |
+
examples = [
|
| 290 |
+
("Framework Overview", example_1_framework_overview),
|
| 291 |
+
("China Rocket Force Analysis", example_2_china_rocket_force_analysis),
|
| 292 |
+
("Governance Comparison", example_3_governance_comparison),
|
| 293 |
+
("Corruption Assessment", example_4_corruption_assessment),
|
| 294 |
+
("Non-Western Military Assessment", example_5_non_western_military_assessment),
|
| 295 |
+
("All Lenses Integration", example_6_all_lenses_integration),
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
print("\nAvailable Examples:")
|
| 299 |
+
for i, (name, _) in enumerate(examples, 1):
|
| 300 |
+
print(f" {i}. {name}")
|
| 301 |
+
|
| 302 |
+
print("\nRunning all examples...\n")
|
| 303 |
+
|
| 304 |
+
for name, example_func in examples:
|
| 305 |
+
try:
|
| 306 |
+
example_func()
|
| 307 |
+
except Exception as e:
|
| 308 |
+
print(f"Error in {name}: {e}")
|
| 309 |
+
import traceback
|
| 310 |
+
traceback.print_exc()
|
| 311 |
+
|
| 312 |
+
print("\n" + "=" * 80)
|
| 313 |
+
print(" All examples completed")
|
| 314 |
+
print("=" * 80 + "\n")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
main()
|
examples/EXAMPLES_STATUS.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[PASTE MARKDOWN ABOVE]
|
examples/README.md
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
# GeoBotv1 Examples
|
| 2 |
+
|
| 3 |
+
This directory contains example scripts demonstrating the capabilities of GeoBotv1.
|
| 4 |
+
|
| 5 |
+
## Examples Overview
|
| 6 |
+
|
| 7 |
+
### 01_basic_usage.py
|
| 8 |
+
Basic introduction to GeoBotv1 core components:
|
| 9 |
+
- Creating geopolitical scenarios
|
| 10 |
+
- Building causal graphs
|
| 11 |
+
- Running Monte Carlo simulations
|
| 12 |
+
- Bayesian belief updating
|
| 13 |
+
- Uncertainty quantification
|
| 14 |
+
|
| 15 |
+
**Run it:**
|
| 16 |
+
```bash
|
| 17 |
+
python 01_basic_usage.py
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### 02_data_ingestion.py
|
| 21 |
+
Demonstrates data ingestion capabilities:
|
| 22 |
+
- PDF document processing
|
| 23 |
+
- Web scraping and article extraction
|
| 24 |
+
- News aggregation from multiple sources
|
| 25 |
+
- Intelligence extraction from documents
|
| 26 |
+
- Entity and keyword extraction
|
| 27 |
+
|
| 28 |
+
**Run it:**
|
| 29 |
+
```bash
|
| 30 |
+
python 02_data_ingestion.py
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
**Note:** For full functionality, install optional dependencies:
|
| 34 |
+
```bash
|
| 35 |
+
pip install pypdf pdfplumber beautifulsoup4 newspaper3k trafilatura feedparser
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### 03_intervention_simulation.py
|
| 39 |
+
Advanced intervention and counterfactual analysis:
|
| 40 |
+
- Policy intervention simulation
|
| 41 |
+
- Comparing multiple policy options
|
| 42 |
+
- Finding optimal interventions
|
| 43 |
+
- Counterfactual reasoning ("what if" scenarios)
|
| 44 |
+
- Causal effect estimation
|
| 45 |
+
|
| 46 |
+
**Run it:**
|
| 47 |
+
```bash
|
| 48 |
+
python 03_intervention_simulation.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### 04_advanced_features.py
|
| 52 |
+
Research-grade advanced mathematical features:
|
| 53 |
+
- Sequential Monte Carlo (particle filtering) for nonlinear state estimation
|
| 54 |
+
- Stochastic Differential Equations (Euler-Maruyama, Milstein, Jump-Diffusion)
|
| 55 |
+
- Gradient-based Optimal Transport with Kantorovich duality
|
| 56 |
+
- Entropic OT with Sinkhorn algorithm
|
| 57 |
+
- Structured event extraction from intelligence text
|
| 58 |
+
- Event database with temporal normalization
|
| 59 |
+
|
| 60 |
+
**Run it:**
|
| 61 |
+
```bash
|
| 62 |
+
python 04_advanced_features.py
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
**Note:** Some features require additional dependencies:
|
| 66 |
+
```bash
|
| 67 |
+
pip install torch # For advanced features
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Additional Resources
|
| 71 |
+
|
| 72 |
+
### Creating Custom Scenarios
|
| 73 |
+
```python
|
| 74 |
+
from geobot.core.scenario import Scenario
|
| 75 |
+
import numpy as np
|
| 76 |
+
|
| 77 |
+
scenario = Scenario(
|
| 78 |
+
name="custom_scenario",
|
| 79 |
+
features={
|
| 80 |
+
'tension': np.array([0.7]),
|
| 81 |
+
'stability': np.array([0.4]),
|
| 82 |
+
},
|
| 83 |
+
probability=1.0
|
| 84 |
+
)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Building Causal Models
|
| 88 |
+
```python
|
| 89 |
+
from geobot.models.causal_graph import CausalGraph
|
| 90 |
+
|
| 91 |
+
graph = CausalGraph(name="my_model")
|
| 92 |
+
graph.add_node('cause')
|
| 93 |
+
graph.add_node('effect')
|
| 94 |
+
graph.add_edge('cause', 'effect', strength=0.8)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Monte Carlo Simulation
|
| 98 |
+
```python
|
| 99 |
+
from geobot.simulation.monte_carlo import MonteCarloEngine, SimulationConfig
|
| 100 |
+
|
| 101 |
+
config = SimulationConfig(n_simulations=1000, time_horizon=100)
|
| 102 |
+
engine = MonteCarloEngine(config)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Web Scraping
|
| 106 |
+
```python
|
| 107 |
+
from geobot.data_ingestion.web_scraper import ArticleExtractor
|
| 108 |
+
|
| 109 |
+
extractor = ArticleExtractor()
|
| 110 |
+
article = extractor.extract_article('https://example.com/article')
|
| 111 |
+
print(article['title'])
|
| 112 |
+
print(article['text'])
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### PDF Processing
|
| 116 |
+
```python
|
| 117 |
+
from geobot.data_ingestion.pdf_reader import PDFProcessor
|
| 118 |
+
|
| 119 |
+
processor = PDFProcessor()
|
| 120 |
+
result = processor.extract_intelligence('report.pdf')
|
| 121 |
+
print(f"Risk Level: {result['intelligence']['risk_level']}")
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
## Need Help?
|
| 125 |
+
|
| 126 |
+
- Check the main README.md in the project root
|
| 127 |
+
- Review the module documentation in each package
|
| 128 |
+
- Examine the source code for detailed implementation
|
| 129 |
+
|
| 130 |
+
## Contributing
|
| 131 |
+
|
| 132 |
+
Have an interesting use case? Create a new example script and submit a PR!
|
examples/taiwan_situation_room.py
ADDED
|
@@ -0,0 +1,610 @@
|
| 1 |
+
"""
|
| 2 |
+
Taiwan Situation Room - GeoBot 2.0 Analytical Framework Demo
|
| 3 |
+
|
| 4 |
+
Comprehensive demonstration of GeoBot 2.0 analytical capabilities applied
|
| 5 |
+
to Taiwan Strait scenario analysis. Integrates:
|
| 6 |
+
|
| 7 |
+
- GeoBot 2.0 analytical lenses (Governance, Logistics, Corruption, Non-Western)
|
| 8 |
+
- Bayesian forecasting and belief updating
|
| 9 |
+
- Structural causal models for intervention analysis
|
| 10 |
+
- Hawkes processes for escalation dynamics
|
| 11 |
+
|
| 12 |
+
Scenario: Rising tensions in Taiwan Strait with potential for
|
| 13 |
+
military escalation. Analysis evaluates PRC capabilities, deterrence
|
| 14 |
+
credibility, and intervention outcomes.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
sys.path.insert(0, '..')
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
|
| 23 |
+
# GeoBot 2.0 Analytical Framework
|
| 24 |
+
from geobot.analysis import (
|
| 25 |
+
AnalyticalEngine,
|
| 26 |
+
GovernanceType,
|
| 27 |
+
CorruptionType,
|
| 28 |
+
AnalyticalLenses
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Bayesian Forecasting (if numpy available)
|
| 32 |
+
try:
|
| 33 |
+
from geobot.bayes import (
|
| 34 |
+
BayesianForecaster,
|
| 35 |
+
GeopoliticalPrior,
|
| 36 |
+
PriorType,
|
| 37 |
+
EvidenceUpdate,
|
| 38 |
+
EvidenceType
|
| 39 |
+
)
|
| 40 |
+
BAYES_AVAILABLE = True
|
| 41 |
+
except ImportError:
|
| 42 |
+
BAYES_AVAILABLE = False
|
| 43 |
+
print("Note: Bayesian forecasting requires numpy - skipping Bayesian analysis")
|
| 44 |
+
|
| 45 |
+
# Causal Models (if numpy/networkx available)
|
| 46 |
+
try:
|
| 47 |
+
from geobot.causal import (
|
| 48 |
+
StructuralCausalModel,
|
| 49 |
+
StructuralEquation,
|
| 50 |
+
Intervention,
|
| 51 |
+
Counterfactual
|
| 52 |
+
)
|
| 53 |
+
CAUSAL_AVAILABLE = True
|
| 54 |
+
except ImportError:
|
| 55 |
+
CAUSAL_AVAILABLE = False
|
| 56 |
+
print("Note: Causal models require numpy/networkx - skipping causal analysis")
|
| 57 |
+
|
| 58 |
+
# Hawkes processes (if scipy available)
|
| 59 |
+
try:
|
| 60 |
+
from geobot.simulation.hawkes import (
|
| 61 |
+
HawkesSimulator,
|
| 62 |
+
quick_conflict_contagion_analysis
|
| 63 |
+
)
|
| 64 |
+
HAWKES_AVAILABLE = True
|
| 65 |
+
except ImportError:
|
| 66 |
+
HAWKES_AVAILABLE = False
|
| 67 |
+
print("Note: Hawkes processes require scipy - skipping escalation dynamics")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def print_header(title):
|
| 71 |
+
"""Print formatted section header."""
|
| 72 |
+
print("\n" + "=" * 80)
|
| 73 |
+
print(f" {title}")
|
| 74 |
+
print("=" * 80 + "\n")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def print_subheader(title):
|
| 78 |
+
"""Print formatted subsection header."""
|
| 79 |
+
print("\n" + "-" * 80)
|
| 80 |
+
print(f" {title}")
|
| 81 |
+
print("-" * 80 + "\n")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ============================================================================
|
| 85 |
+
# Part 1: GeoBot 2.0 Core Analysis
|
| 86 |
+
# ============================================================================
|
| 87 |
+
|
| 88 |
+
def part1_geobot_core_analysis():
|
| 89 |
+
"""
|
| 90 |
+
Apply GeoBot 2.0 analytical framework to Taiwan scenario.
|
| 91 |
+
"""
|
| 92 |
+
print_header("Part 1: GeoBot 2.0 Analytical Framework - Taiwan Scenario")
|
| 93 |
+
|
| 94 |
+
engine = AnalyticalEngine()
|
| 95 |
+
|
| 96 |
+
# Scenario context
|
| 97 |
+
query = """PRC conducts large-scale military exercises around Taiwan,
|
| 98 |
+
including live-fire drills and simulated blockade operations. US conducts
|
| 99 |
+
freedom of navigation operations in Taiwan Strait. What are the escalation
|
| 100 |
+
risks and intervention outcomes?"""
|
| 101 |
+
|
| 102 |
+
print(f"QUERY: {query}\n")
|
| 103 |
+
|
| 104 |
+
# Build comprehensive context
|
| 105 |
+
context = {
|
| 106 |
+
'governance_type': GovernanceType.AUTHORITARIAN_CENTRALIZED,
|
| 107 |
+
'corruption_type': CorruptionType.MANAGED_BOUNDED,
|
| 108 |
+
'military_system': 'Chinese PLA',
|
| 109 |
+
'scenario_description': 'PRC military exercises and potential Taiwan contingency',
|
| 110 |
+
'operational_context': 'High-intensity joint operations in near-seas environment',
|
| 111 |
+
|
| 112 |
+
'summary': """PRC demonstrates improving capability for joint operations in Taiwan
|
| 113 |
+
Strait, but faces significant logistical and operational challenges for sustained
|
| 114 |
+
high-intensity operations. Authoritarian governance enables rapid mobilization but
|
| 115 |
+
information flow problems could create coordination failures under stress.""",
|
| 116 |
+
|
| 117 |
+
'logistics_assessment': """PRC Eastern Theater Command has concentrated logistics
|
| 118 |
+
infrastructure supporting Taiwan contingency. Civil-military fusion enables rapid resource
|
| 119 |
+
mobilization. However, sustained amphibious/air assault operations would stress logistics
|
| 120 |
+
systems untested in combat. Key constraints: sealift capacity, contested logistics under
|
| 121 |
+
US/allied interdiction, ammunition sustainment for high-intensity operations.""",
|
| 122 |
+
|
| 123 |
+
# Scenarios with probabilities
|
| 124 |
+
'scenarios': [
|
| 125 |
+
{
|
| 126 |
+
'name': 'Coercive demonstration without escalation',
|
| 127 |
+
'probability': 0.55,
|
| 128 |
+
'description': 'Exercises conclude after demonstrating capability and resolve'
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
'name': 'Graduated escalation (quarantine/blockade)',
|
| 132 |
+
'probability': 0.30,
|
| 133 |
+
'description': 'PRC implements quarantine, testing US/allied response'
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
'name': 'Limited kinetic action',
|
| 137 |
+
'probability': 0.10,
|
| 138 |
+
'description': 'Strikes on Taiwan military targets, no invasion'
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
'name': 'Full-scale invasion attempt',
|
| 142 |
+
'probability': 0.05,
|
| 143 |
+
'description': 'Amphibious/airborne assault on Taiwan'
|
| 144 |
+
}
|
| 145 |
+
],
|
| 146 |
+
|
| 147 |
+
# Uncertainty factors
|
| 148 |
+
'uncertainty_factors': [
|
| 149 |
+
'PRC leadership risk tolerance and decision calculus',
|
| 150 |
+
'Taiwan domestic political response and resolve',
|
| 151 |
+
'US extended deterrence credibility perception',
|
| 152 |
+
'PLA actual readiness vs. reported readiness (information distortion risk)',
|
| 153 |
+
'Third-party actions (Japan, Australia, regional states)',
|
| 154 |
+
'Economic interdependence constraints on escalation'
|
| 155 |
+
],
|
| 156 |
+
|
| 157 |
+
# Signals to watch
|
| 158 |
+
'signals_to_watch': [
|
| 159 |
+
'PLA logistics mobilization (satellite-observable sealift, air transport concentration)',
|
| 160 |
+
'Rocket Force alert status and deployment patterns',
|
| 161 |
+
'PLAN submarine deployments',
|
| 162 |
+
'Civilian shipping disruptions (clearance of civilian vessels from exercise areas)',
|
| 163 |
+
'PRC domestic propaganda shifts (priming for kinetic action vs. victorious conclusion)',
|
| 164 |
+
'US carrier strike group deployments and readiness status',
|
| 165 |
+
'Taiwan reserve mobilization signals',
|
| 166 |
+
'Japanese Self-Defense Force posture changes'
|
| 167 |
+
],
|
| 168 |
+
|
| 169 |
+
# Comparative notes
|
| 170 |
+
'comparative_notes': """Unlike Russia-Ukraine, PRC faces amphibious/air assault across
|
| 171 |
+
defended strait with peer/near-peer opposition (US, Japan, Australia). PLA has not conducted
|
| 172 |
+
combat operations since 1979, vs. Russia's experience in Syria, Georgia, Ukraine. However,
|
| 173 |
+
PRC has advantage of proximity, massive firepower overmatch against Taiwan alone, and
|
| 174 |
+
authoritarian ability to sustain economic costs."""
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
# Add governance-specific analysis
|
| 178 |
+
context['governance_context'] = {
|
| 179 |
+
'trade_off': """Authoritarian governance enables PRC to:
|
| 180 |
+
- Rapidly mobilize resources without legislative approval
|
| 181 |
+
- Sustain operations despite economic costs and casualties
|
| 182 |
+
- Conduct strategic surprise without public debate
|
| 183 |
+
|
| 184 |
+
But creates risks:
|
| 185 |
+
- Information distortion about PLA readiness/capabilities
|
| 186 |
+
- Over-confidence in leadership due to filtered reporting
|
| 187 |
+
- Inflexible response to unexpected battlefield developments""",
|
| 188 |
+
|
| 189 |
+
'context_specific_advantage': """In crisis initiation and short, sharp operations,
|
| 190 |
+
authoritarian system has decision-speed advantage. In sustained operations requiring
|
| 191 |
+
adaptation, democratic information flow advantages become more important. Key question:
|
| 192 |
+
Can PRC achieve fait accompli before US/allied decision-making concludes?"""
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
# Add corruption context
|
| 196 |
+
context['corruption_details'] = {
|
| 197 |
+
'evidence': """Post-2012 anti-corruption campaigns have reduced parasitic corruption
|
| 198 |
+
in PLA, especially after 2017 Rocket Force purges. However, managed corruption model means:
|
| 199 |
+
- Procurement still involves kickbacks, but constrained to avoid readiness impact
|
| 200 |
+
- Promotion decisions still involve patronage, affecting command quality
|
| 201 |
+
- Readiness reporting still subject to careerism incentives""",
|
| 202 |
+
|
| 203 |
+
'risk_assessment': """Corruption less likely to cause catastrophic equipment failures
|
| 204 |
+
(cf. Russian logistics in Ukraine), but could create:
|
| 205 |
+
- Over-estimation of PLA capabilities by leadership
|
| 206 |
+
- Coordination problems from patronage-based command appointments
|
| 207 |
+
- Supply chain inefficiencies under stress"""
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
# Add non-Western analysis
|
| 211 |
+
context['non_western_context'] = {
|
| 212 |
+
'analysis_framework': """PLA operational culture emphasizes:
|
| 213 |
+
- Centralized planning with detailed pre-scripted operations
|
| 214 |
+
- Heavy firepower preparation before maneuver
|
| 215 |
+
- Political control through party committee system
|
| 216 |
+
- Joint operations still developing (improving but not NATO-level)
|
| 217 |
+
|
| 218 |
+
This creates both capabilities and constraints different from Western assumptions.""",
|
| 219 |
+
|
| 220 |
+
'key_distinction': """Western analysis often assumes PLA would operate like NATO forces.
|
| 221 |
+
In reality, PLA would likely emphasize:
|
| 222 |
+
- Overwhelming initial firepower (missiles, air strikes) to create shock
|
| 223 |
+
- Rapid fait accompli before US can intervene
|
| 224 |
+
- Accepting higher casualties than Western forces
|
| 225 |
+
- Using information operations and political warfare alongside kinetic
|
| 226 |
+
|
| 227 |
+
These reflect Chinese strategic culture and organizational strengths, not deficiencies."""
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
# Perform analysis
|
| 231 |
+
print_subheader("GeoBot 2.0 Analytical Output")
|
| 232 |
+
analysis = engine.analyze(query, context)
|
| 233 |
+
print(analysis)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ============================================================================
|
| 237 |
+
# Part 2: Bayesian Belief Updating
|
| 238 |
+
# ============================================================================
|
| 239 |
+
|
| 240 |
+
def part2_bayesian_analysis():
|
| 241 |
+
"""
|
| 242 |
+
Bayesian belief updating as new intelligence arrives.
|
| 243 |
+
"""
|
| 244 |
+
if not BAYES_AVAILABLE:
|
| 245 |
+
print_header("Part 2: Bayesian Analysis - SKIPPED (numpy not available)")
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
print_header("Part 2: Bayesian Belief Updating - Intelligence Integration")
|
| 249 |
+
|
| 250 |
+
forecaster = BayesianForecaster()
|
| 251 |
+
|
| 252 |
+
# Set prior on PRC invasion probability within 12 months
|
| 253 |
+
invasion_prior = GeopoliticalPrior(
|
| 254 |
+
parameter_name="invasion_probability_12mo",
|
| 255 |
+
prior_type=PriorType.BETA,
|
| 256 |
+
parameters={'alpha': 2.0, 'beta': 18.0}, # Prior mean ~0.10
|
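+
# Beta(2, 18) has mean alpha / (alpha + beta) = 2 / 20 = 0.10
|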
| 257 |
+
description="Probability of PRC invasion attempt within 12 months"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
forecaster.set_prior(invasion_prior)
|
| 261 |
+
|
| 262 |
+
print("PRIOR BELIEF:")
|
| 263 |
+
print(f" Distribution: Beta(α=2.0, β=18.0)")
|
| 264 |
+
print(f" Prior mean: ~0.10 (10% chance)")
|
| 265 |
+
print(f" This reflects baseline assessment before current crisis\n")
|
| 266 |
+
|
| 267 |
+
# Evidence 1: Satellite imagery shows sealift mobilization
|
| 268 |
+
print_subheader("Evidence Update 1: Satellite Imagery")
|
| 269 |
+
print("Satellite imagery shows increased sealift concentration in Fujian ports")
|
| 270 |
+
print("Assessing impact on invasion probability...\n")
|
| 271 |
+
|
| 272 |
+
def sealift_likelihood(p):
|
| 273 |
+
# If invasion likely, sealift mobilization is very likely
|
| 274 |
+
# If invasion unlikely, some mobilization still possible (exercises)
|
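+
# i.e. an assumed mixture: P(observe buildup | p) = 0.9*p + 0.2*(1 - p)
|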
| 275 |
+
return p * 0.9 + (1 - p) * 0.2
|
| 276 |
+
|
| 277 |
+
evidence1 = EvidenceUpdate(
|
| 278 |
+
evidence_type=EvidenceType.SATELLITE_IMAGERY,
|
| 279 |
+
observation="sealift_mobilization",
|
| 280 |
+
likelihood_function=sealift_likelihood,
|
| 281 |
+
reliability=0.95,
|
| 282 |
+
source="Commercial satellite analysis"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
belief1 = forecaster.update_belief(
|
| 286 |
+
"invasion_probability_12mo",
|
| 287 |
+
evidence1,
|
| 288 |
+
n_samples=10000
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
print(f"Updated belief after satellite evidence:")
|
| 292 |
+
print(f" Mean: {belief1.mean():.3f}")
|
| 293 |
+
print(f" Median: {belief1.median():.3f}")
|
| 294 |
+
print(f" 95% CI: {belief1.credible_interval(0.05)}")
|
| 295 |
+
print(f" P(invasion > 0.20): {belief1.probability_greater_than(0.20):.2f}\n")
|
| 296 |
+
|
| 297 |
+
# Evidence 2: HUMINT reports purges in Taiwan Affairs Office
|
| 298 |
+
print_subheader("Evidence Update 2: HUMINT Report")
|
| 299 |
+
print("HUMINT reports internal purges in Taiwan Affairs Office leadership")
|
| 300 |
+
print("Interpretation: Could indicate pre-operation security tightening OR internal dysfunction\n")
|
| 301 |
+
|
| 302 |
+
def purge_likelihood(p):
|
| 303 |
+
# Purges could indicate either preparation or problems
|
| 304 |
+
# Moderate signal
|
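+
# Likelihood ratio 0.6 / 0.4 = 1.5, so this update shifts the posterior only slightly upward
|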
| 305 |
+
return p * 0.6 + (1 - p) * 0.4
|
| 306 |
+
|
| 307 |
+
evidence2 = EvidenceUpdate(
|
| 308 |
+
evidence_type=EvidenceType.INTELLIGENCE_REPORT,
|
| 309 |
+
observation="tao_purges",
|
| 310 |
+
likelihood_function=purge_likelihood,
|
| 311 |
+
reliability=0.70, # HUMINT less reliable than satellite
|
| 312 |
+
source="HUMINT Taiwan Affairs Office"
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
belief2 = forecaster.update_belief(
|
| 316 |
+
"invasion_probability_12mo",
|
| 317 |
+
evidence2,
|
| 318 |
+
n_samples=10000
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
print(f"Updated belief after HUMINT evidence:")
|
| 322 |
+
print(f" Mean: {belief2.mean():.3f}")
|
| 323 |
+
print(f" Median: {belief2.median():.3f}")
|
| 324 |
+
print(f" 95% CI: {belief2.credible_interval(0.05)}")
|
| 325 |
+
print(f" P(invasion > 0.20): {belief2.probability_greater_than(0.20):.2f}\n")
|
| 326 |
+
|
| 327 |
+
# Evidence 3: Economic data shows continued deep integration
|
| 328 |
+
print_subheader("Evidence Update 3: Economic Data")
|
| 329 |
+
print("Economic data shows continued deep PRC-Taiwan trade integration, no decoupling")
|
| 330 |
+
print("Interpretation: Reduces likelihood of near-term kinetic action\n")
|
| 331 |
+
|
| 332 |
+
def economic_likelihood(p):
|
| 333 |
+
# Continued integration suggests not preparing for war
|
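+
# Likelihood ratio 0.3 / 0.8 ≈ 0.4, so this observation pulls the posterior downward
|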
| 334 |
+
return p * 0.3 + (1 - p) * 0.8
|
| 335 |
+
|
| 336 |
+
evidence3 = EvidenceUpdate(
|
| 337 |
+
evidence_type=EvidenceType.ECONOMIC_DATA,
|
| 338 |
+
observation="continued_integration",
|
| 339 |
+
likelihood_function=economic_likelihood,
|
| 340 |
+
reliability=1.0, # Economic data highly reliable
|
| 341 |
+
source="Trade statistics"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
belief3 = forecaster.update_belief(
|
| 345 |
+
"invasion_probability_12mo",
|
| 346 |
+
evidence3,
|
| 347 |
+
n_samples=10000
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
print(f"Final belief after all evidence:")
|
| 351 |
+
print(f" Mean: {belief3.mean():.3f}")
|
| 352 |
+
print(f" Median: {belief3.median():.3f}")
|
| 353 |
+
print(f" 95% CI: {belief3.credible_interval(0.05)}")
|
| 354 |
+
print(f" P(invasion > 0.20): {belief3.probability_greater_than(0.20):.2f}")
|
| 355 |
+
print(f" P(invasion > 0.30): {belief3.probability_greater_than(0.30):.2f}")
|
| 356 |
+
|
| 357 |
+
# Summary
|
| 358 |
+
summary = forecaster.get_belief_summary("invasion_probability_12mo")
|
| 359 |
+
print(f"\nBelief Summary:")
|
| 360 |
+
print(f" Evidence updates: {summary['n_evidence_updates']}")
|
| 361 |
+
print(f" Evidence types: {summary['evidence_types']}")
|
| 362 |
+
print(f" Final assessment: {summary['mean']:.1%} probability within 12 months")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ============================================================================
|
| 366 |
+
# Part 3: Causal Intervention Analysis
|
| 367 |
+
# ============================================================================
|
| 368 |
+
|
| 369 |
+
def part3_causal_intervention_analysis():
|
| 370 |
+
"""
|
| 371 |
+
Use structural causal models to evaluate intervention outcomes.
|
| 372 |
+
"""
|
| 373 |
+
if not CAUSAL_AVAILABLE:
|
| 374 |
+
print_header("Part 3: Causal Analysis - SKIPPED (dependencies not available)")
|
| 375 |
+
return
|
| 376 |
+
|
| 377 |
+
print_header("Part 3: Causal Intervention Analysis - Policy Counterfactuals")
|
| 378 |
+
|
| 379 |
+
# Build SCM for Taiwan deterrence
|
| 380 |
+
print("Building Structural Causal Model for Taiwan deterrence dynamics...\n")
|
| 381 |
+
|
| 382 |
+
scm = StructuralCausalModel(name="TaiwanDeterrenceSCM")
|
| 383 |
+
|
| 384 |
+
noise_dist = lambda n: np.random.randn(n) * 0.05
|
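+
# Shared exogenous-noise generator: n independent draws from N(0, 0.05^2)
|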
| 385 |
+
|
| 386 |
+
# US military presence -> PRC perception of US resolve
|
| 387 |
+
scm.add_equation(StructuralEquation(
|
| 388 |
+
variable="prc_perceived_us_resolve",
|
| 389 |
+
parents=["us_military_presence"],
|
| 390 |
+
function=lambda p: 0.3 + 0.6 * p["us_military_presence"],
|
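+
# With presence in [0, 1], perceived resolve spans roughly 0.3-0.9 before noise
|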
| 391 |
+
noise_dist=noise_dist,
|
| 392 |
+
description="US military presence increases PRC perception of US resolve"
|
| 393 |
+
))
|
| 394 |
+
|
| 395 |
+
# Taiwan defense spending -> Taiwan military capability
|
| 396 |
+
scm.add_equation(StructuralEquation(
|
| 397 |
+
variable="taiwan_military_capability",
|
| 398 |
+
parents=["taiwan_defense_spending"],
|
| 399 |
+
function=lambda p: 0.4 + 0.5 * p["taiwan_defense_spending"],
|
| 400 |
+
noise_dist=noise_dist,
|
| 401 |
+
description="Taiwan defense spending improves military capability"
|
| 402 |
+
))
|
| 403 |
+
|
| 404 |
+
# PRC perceived costs = f(US resolve, Taiwan capability)
|
| 405 |
+
scm.add_equation(StructuralEquation(
|
| 406 |
+
variable="prc_perceived_costs",
|
| 407 |
+
parents=["prc_perceived_us_resolve", "taiwan_military_capability"],
|
| 408 |
+
function=lambda p: (0.4 * p["prc_perceived_us_resolve"] +
|
| 409 |
+
0.3 * p["taiwan_military_capability"]),
|
| 410 |
+
noise_dist=noise_dist,
|
| 411 |
+
description="Perceived costs depend on US resolve and Taiwan capability"
|
| 412 |
+
))
|
| 413 |
+
|
| 414 |
+
# Conflict risk = f(perceived costs, prc_domestic_pressure)
|
| 415 |
+
scm.add_equation(StructuralEquation(
|
| 416 |
+
variable="conflict_risk",
|
| 417 |
+
parents=["prc_perceived_costs", "prc_domestic_pressure"],
|
| 418 |
+
function=lambda p: (0.5 + 0.3 * p["prc_domestic_pressure"] -
|
| 419 |
+
0.4 * p["prc_perceived_costs"]),
|
| 420 |
+
noise_dist=noise_dist,
|
| 421 |
+
description="Conflict risk increases with domestic pressure, decreases with perceived costs"
|
| 422 |
+
))
|
| 423 |
+
|
| 424 |
+
# Baseline scenario
|
| 425 |
+
print_subheader("Baseline Scenario")
|
| 426 |
+
baseline_data = scm.simulate(n_samples=10000, random_state=42)
|
| 427 |
+
print(f"Baseline conflict risk: Mean = {np.mean(baseline_data['conflict_risk']):.3f}, "
|
| 428 |
+
f"Std = {np.std(baseline_data['conflict_risk']):.3f}\n")
|
| 429 |
+
|
| 430 |
+
# Intervention 1: Increase US military presence
|
| 431 |
+
print_subheader("Intervention 1: Increase US Military Presence")
|
| 432 |
+
print("do(us_military_presence = 0.8) # High presence\n")
|
| 433 |
+
|
| 434 |
+
intervention1 = Intervention(
|
| 435 |
+
variable="us_military_presence",
|
| 436 |
+
value=0.8,
|
| 437 |
+
description="Sustained US carrier presence + forward-deployed assets"
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
post_intervention1 = scm.intervene([intervention1], n_samples=10000, random_state=42)
|
| 441 |
+
print(f"Post-intervention conflict risk: Mean = {np.mean(post_intervention1['conflict_risk']):.3f}")
|
| 442 |
+
print(f"Effect of intervention: {np.mean(baseline_data['conflict_risk']) - np.mean(post_intervention1['conflict_risk']):.3f} reduction")
|
| 443 |
+
print(f"Interpretation: Increasing US presence reduces conflict risk by ~{100*(np.mean(baseline_data['conflict_risk']) - np.mean(post_intervention1['conflict_risk'])):.1f} percentage points\n")
|
| 444 |
+
|
| 445 |
+
# Intervention 2: Increase Taiwan defense spending
|
| 446 |
+
print_subheader("Intervention 2: Increase Taiwan Defense Spending")
|
| 447 |
+
print("do(taiwan_defense_spending = 0.9) # Major defense investment\n")
|
| 448 |
+
|
| 449 |
+
intervention2 = Intervention(
|
| 450 |
+
variable="taiwan_defense_spending",
|
| 451 |
+
value=0.9,
|
| 452 |
+
description="Major asymmetric defense investments"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
post_intervention2 = scm.intervene([intervention2], n_samples=10000, random_state=42)
|
| 456 |
+
print(f"Post-intervention conflict risk: Mean = {np.mean(post_intervention2['conflict_risk']):.3f}")
|
| 457 |
+
print(f"Effect of intervention: {np.mean(baseline_data['conflict_risk']) - np.mean(post_intervention2['conflict_risk']):.3f} reduction\n")
|
| 458 |
+
|
| 459 |
+
# Combined intervention
|
| 460 |
+
print_subheader("Intervention 3: Combined Strategy")
|
| 461 |
+
print("do(us_military_presence = 0.8, taiwan_defense_spending = 0.9)\n")
|
| 462 |
+
|
| 463 |
+
combined_data = scm.intervene([intervention1, intervention2], n_samples=10000, random_state=42)
|
| 464 |
+
print(f"Post-intervention conflict risk: Mean = {np.mean(combined_data['conflict_risk']):.3f}")
|
| 465 |
+
print(f"Effect of combined intervention: {np.mean(baseline_data['conflict_risk']) - np.mean(combined_data['conflict_risk']):.3f} reduction")
|
| 466 |
+
print(f"Interpretation: Combined strategy most effective for reducing conflict risk\n")
|
| 467 |
+
|
| 468 |
+
# Counterfactual query
|
| 469 |
+
print_subheader("Counterfactual Query")
|
| 470 |
+
print("Question: What would conflict risk be if we had maintained high US presence,")
|
| 471 |
+
print("given that we currently observe moderate US presence?\n")
|
| 472 |
+
|
| 473 |
+
counterfactual = Counterfactual(
|
| 474 |
+
query_variable="conflict_risk",
|
| 475 |
+
intervention=Intervention("us_military_presence", 0.8),
|
| 476 |
+
observations={"us_military_presence": 0.5}
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
cf_result = scm.counterfactual_query(counterfactual, n_samples=10000)
|
| 480 |
+
print(f"Counterfactual conflict risk: {cf_result['expected_value']:.3f}")
|
| 481 |
+
print(f"95% CI: ({cf_result['quantiles']['5%']:.3f}, {cf_result['quantiles']['95%']:.3f})")
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
# ============================================================================
|
| 485 |
+
# Part 4: Escalation Dynamics (Hawkes Processes)
|
| 486 |
+
# ============================================================================
|
| 487 |
+
|
| 488 |
+
def part4_escalation_dynamics():
|
| 489 |
+
"""
|
| 490 |
+
Model escalation dynamics using Hawkes processes.
|
| 491 |
+
"""
|
| 492 |
+
if not HAWKES_AVAILABLE:
|
| 493 |
+
print_header("Part 4: Escalation Dynamics - SKIPPED (scipy not available)")
|
| 494 |
+
return
|
| 495 |
+
|
| 496 |
+
print_header("Part 4: Escalation Dynamics - Self-Exciting Processes")
|
| 497 |
+
|
| 498 |
+
from geobot.simulation.hawkes import HawkesParameters
|
| 499 |
+
|
| 500 |
+
print("Modeling crisis escalation as self-exciting point process...")
|
| 501 |
+
print("Events cluster in time and trigger subsequent events (escalatory spiral)\n")
|
| 502 |
+
|
| 503 |
+
# Simulate escalation scenario
|
| 504 |
+
print_subheader("Scenario: Incremental Escalation with Contagion")
|
| 505 |
+
|
| 506 |
+
# Three actors: PRC, Taiwan, US
|
| 507 |
+
baseline_rates = [0.5, 0.3, 0.2] # PRC initiates more, US responds
|
| 508 |
+
countries = ['PRC', 'Taiwan', 'US']
|
| 509 |
+
|
| 510 |
+
# Contagion: PRC action triggers Taiwan/US response
|
| 511 |
+
alpha_matrix = np.array([
|
| 512 |
+
[0.3, 0.2, 0.15], # PRC actions trigger more PRC, Taiwan, US actions
|
| 513 |
+
[0.4, 0.2, 0.3], # Taiwan actions strongly trigger PRC and US
|
| 514 |
+
[0.5, 0.2, 0.1], # US actions strongly trigger PRC responses
|
| 515 |
+
])
|
| 516 |
+
|
| 517 |
+
beta_matrix = np.ones((3, 3)) * 1.5 # Decay rate
|
| 518 |
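+
# With exponential kernels the branching matrix is alpha/beta elementwise; its spectral radius must stay below 1 for a subcritical process
|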
+
|
| 519 |
+
params = HawkesParameters(
|
| 520 |
+
mu=np.array(baseline_rates),
|
| 521 |
+
alpha=alpha_matrix,
|
| 522 |
+
beta=beta_matrix
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# Simulate 30-day crisis
|
| 526 |
+
simulator = HawkesSimulator(n_dimensions=3)
|
| 527 |
+
events = simulator.simulate(T=30.0, params=params, random_state=42)
|
| 528 |
+
|
| 529 |
+
print(f"Simulated 30-day crisis escalation:")
|
| 530 |
+
for i, country in enumerate(countries):
|
| 531 |
+
print(f" {country}: {len(events[i])} escalatory events")
|
| 532 |
+
|
| 533 |
+
# Assess stability
|
| 534 |
+
stability = simulator.assess_stability(params)
|
| 535 |
+
print(f"\nEscalation dynamics stability assessment:")
|
| 536 |
+
print(f" Branching ratio: {stability['branching_ratio']:.3f}")
|
| 537 |
+
print(f" Regime: {stability['regime']}")
|
| 538 |
+
print(f" Interpretation: {stability['interpretation']}\n")
|
| 539 |
+
|
| 540 |
+
if stability['is_explosive']:
|
| 541 |
+
print("⚠️ WARNING: Process is supercritical - escalation could spiral out of control")
|
| 542 |
+
else:
|
| 543 |
+
print("✓ Process is subcritical - escalation will stabilize")
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# ============================================================================
|
| 547 |
+
# Main Execution
|
| 548 |
+
# ============================================================================
|
| 549 |
+
|
| 550 |
+
def main():
|
| 551 |
+
"""Run complete Taiwan situation room analysis."""
|
| 552 |
+
|
| 553 |
+
print("\n" + "=" * 80)
|
| 554 |
+
print(" TAIWAN SITUATION ROOM")
|
| 555 |
+
print(" GeoBot 2.0 Integrated Geopolitical Analysis")
|
| 556 |
+
print(" " + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
| 557 |
+
print("=" * 80)
|
| 558 |
+
|
| 559 |
+
# Run all parts
|
| 560 |
+
part1_geobot_core_analysis()
|
| 561 |
+
part2_bayesian_analysis()
|
| 562 |
+
part3_causal_intervention_analysis()
|
| 563 |
+
part4_escalation_dynamics()
|
| 564 |
+
|
| 565 |
+
# Summary
|
| 566 |
+
print_header("Summary and Recommendations")
|
| 567 |
+
|
| 568 |
+
print("""
|
| 569 |
+
INTEGRATED ASSESSMENT:
|
| 570 |
+
|
| 571 |
+
1. GEOBOT 2.0 ANALYTICAL FRAMEWORK
|
| 572 |
+
- PRC has improving joint operations capability but faces significant logistical
|
| 573 |
+
constraints for sustained high-intensity operations
|
| 574 |
+
- Authoritarian governance enables rapid mobilization but creates information
|
| 575 |
+
flow risks
|
| 576 |
+
- Managed corruption likely constrained enough to maintain basic functionality
|
| 577 |
+
- Non-Western analysis reveals PRC emphasis on firepower and fait accompli
|
| 578 |
+
|
| 579 |
+
2. BAYESIAN BELIEF UPDATING
|
| 580 |
+
- Posterior probability of invasion within 12 months: ~15-20% (up from 10% prior)
|
| 581 |
+
- Satellite evidence of sealift mobilization raises concern
|
| 582 |
+
- Economic integration evidence reduces near-term kinetic risk
|
| 583 |
+
- Continued monitoring required as new intelligence arrives
|
| 584 |
+
|
| 585 |
+
3. CAUSAL INTERVENTION ANALYSIS
|
| 586 |
+
- Combined strategy (US presence + Taiwan defense) most effective
|
| 587 |
+
- US military presence has direct deterrent effect
|
| 588 |
+
- Taiwan capabilities create operational costs for PRC
|
| 589 |
+
- Counterfactual analysis supports sustained presence policy
|
| 590 |
+
|
| 591 |
+
4. ESCALATION DYNAMICS
|
| 592 |
+
- Current contagion parameters suggest subcritical regime (stable)
|
| 593 |
+
- However, parameter changes could shift to explosive regime
|
| 594 |
+
- Escalation management critical to prevent spiral
|
| 595 |
+
|
| 596 |
+
POLICY RECOMMENDATIONS:
|
| 597 |
+
- Maintain credible US extended deterrence
|
| 598 |
+
- Support Taiwan asymmetric defense capabilities
|
| 599 |
+
- Engage in crisis management mechanisms to prevent escalation spirals
|
| 600 |
+
- Continue intelligence collection on PLA readiness and mobilization
|
| 601 |
+
- Monitor for signals of PRC leadership decision to use force
|
| 602 |
+
""")
|
| 603 |
+
|
| 604 |
+
print("=" * 80)
|
| 605 |
+
print(" Analysis Complete")
|
| 606 |
+
print("=" * 80 + "\n")
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
if __name__ == "__main__":
|
| 610 |
+
main()
|
geobot/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
| 1 |
+
"""
|
| 2 |
+
GeoBotv1: Geopolitical Forecasting Framework
|
| 3 |
+
|
| 4 |
+
A comprehensive framework for geopolitical risk analysis, conflict prediction,
|
| 5 |
+
and intervention simulation using advanced mathematical and statistical methods.
|
| 6 |
+
|
| 7 |
+
Version 2.0 includes the GeoBot analytical framework for clinical systems analysis
|
| 8 |
+
with geopolitical nuance.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
__version__ = "2.0.0"
|
| 12 |
+
__author__ = "GeoBotv1 Team"
|
| 13 |
+
|
| 14 |
+
# Core modules
|
| 15 |
+
from . import core
|
| 16 |
+
from . import models
|
| 17 |
+
from . import inference
|
| 18 |
+
from . import simulation
|
| 19 |
+
from . import timeseries
|
| 20 |
+
from . import ml
|
| 21 |
+
from . import data_ingestion
|
| 22 |
+
from . import utils
|
| 23 |
+
from . import config
|
| 24 |
+
from . import analysis
|
| 25 |
+
|
| 26 |
+
# New modules in v2.0
|
| 27 |
+
from . import bayes
|
| 28 |
+
from . import causal
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"core",
|
| 32 |
+
"models",
|
| 33 |
+
"inference",
|
| 34 |
+
"simulation",
|
| 35 |
+
"timeseries",
|
| 36 |
+
"ml",
|
| 37 |
+
"data_ingestion",
|
| 38 |
+
"utils",
|
| 39 |
+
"config",
|
| 40 |
+
"analysis",
|
| 41 |
+
"bayes",
|
| 42 |
+
"causal",
|
| 43 |
+
]
|
geobot/analysis/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
| 1 |
+
"""
|
| 2 |
+
GeoBot 2.0 Analytical Framework
|
| 3 |
+
|
| 4 |
+
A clinical, logistics-focused analytical framework for geopolitical analysis
|
| 5 |
+
with institutional agility assessment and cultural-operational context.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .framework import GeoBotFramework, AnalyticalPrinciples, CoreIdentity
|
| 9 |
+
from .lenses import (
|
| 10 |
+
LogisticsLens,
|
| 11 |
+
GovernanceLens,
|
| 12 |
+
CorruptionLens,
|
| 13 |
+
NonWesternLens,
|
| 14 |
+
AnalyticalLenses, GovernanceType, CorruptionType
|
| 15 |
+
)
|
| 16 |
+
from .engine import AnalyticalEngine
|
| 17 |
+
from .formatter import AnalysisFormatter
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"GeoBotFramework",
|
| 21 |
+
"AnalyticalPrinciples",
|
| 22 |
+
"CoreIdentity",
|
| 23 |
+
"LogisticsLens",
|
| 24 |
+
"GovernanceLens",
|
| 25 |
+
"CorruptionLens",
|
| 26 |
+
"NonWesternLens",
|
| 27 |
+
"AnalyticalLenses",
|
| 28 |
+
"AnalyticalEngine",
|
| 29 |
+
"AnalysisFormatter",
|
| 30 |
+
]
|
geobot/analysis/engine.py
ADDED
|
@@ -0,0 +1,393 @@
|
"""
GeoBot 2.0 Analytical Engine

Main interface for conducting GeoBot 2.0 analysis. Integrates all lenses,
applies analytical priorities, and generates formatted output.
"""

from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field

from .framework import GeoBotFramework
from .lenses import (
    AnalyticalLenses,
    GovernanceType,
    CorruptionType,
    MilitaryProfile
)
from .formatter import AnalysisFormatter, AnalysisOutput, Scenario


@dataclass
class AnalyticalPriorities:
    """
    Analytical priorities for GeoBot 2.0.

    Every analysis checks:
    1. Governance structure impact
    2. Logistics coherence
    3. Corruption type and impact
    4. Institutional context
    5. Communication networks
    6. Organizational cohesion
    """

    priorities: List[str] = field(default_factory=lambda: [
        "Governance structure impact - Does this scenario favor centralized decision-making or distributed adaptation?",
        "Logistics coherence - Supply chains, maintenance, communications",
        "Corruption type and impact - What kind of corruption exists? Does it critically impair this specific operation?",
        "Institutional context - Are we analyzing this military using appropriate cultural/organizational frameworks?",
        "Communication networks - Information flow and coordination",
        "Organizational cohesion - Unit cohesion and morale"
    ])

    def check_all(self, context: Dict[str, Any]) -> Dict[str, bool]:
        """
        Check which priorities have been addressed in analysis.

        Parameters
        ----------
        context : Dict[str, Any]
            Analysis context

        Returns
        -------
        Dict[str, bool]
            Priority checklist
        """
        return {
            'governance_structure': 'governance_type' in context,
            'logistics': 'logistics_assessment' in context,
            'corruption': 'corruption_type' in context,
            'institutional_context': 'military_system' in context,
            'communications': 'communications_status' in context,
            'cohesion': 'cohesion_assessment' in context
        }


class AnalyticalEngine:
    """
    Main analytical engine for GeoBot 2.0.

    Integrates framework, lenses, priorities, and formatter to provide
    comprehensive geopolitical analysis.
    """

    def __init__(self):
        """Initialize the analytical engine."""
        self.framework = GeoBotFramework()
        self.lenses = AnalyticalLenses()
        self.priorities = AnalyticalPriorities()
        self.formatter = AnalysisFormatter()

    def analyze(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> str:
        """
        Conduct complete GeoBot 2.0 analysis.

        Parameters
        ----------
        query : str
            Analytical query or question
        context : Dict[str, Any]
            Context for analysis including:
            - governance_type: GovernanceType
            - corruption_type: CorruptionType
            - military_system: str
            - logistics_data: Dict
            - scenario_description: str

        Returns
        -------
        str
            Formatted analysis
        """
        # Apply all lenses
        analysis_components = self._apply_lenses(context)

        # Create structured output
        output = self._create_output(query, context, analysis_components)

        # Validate against framework
        if not self.framework.validate_analysis(output.__dict__):
            raise ValueError("Analysis does not adhere to GeoBot 2.0 framework")

        # Format and return
        return self.formatter.format_analysis(output)

    def _apply_lenses(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply all analytical lenses to context.

        Parameters
        ----------
        context : Dict[str, Any]
            Analysis context

        Returns
        -------
        Dict[str, Any]
            Lens analyses
        """
        components = {}

        # Lens A: Logistics
        components['logistics'] = self.lenses.logistics.analyze(context)

        # Lens B: Governance
        if 'governance_type' in context:
            gov_type = context['governance_type']
            scenario = context.get('scenario_description', 'General scenario')
            components['governance'] = self.lenses.governance.analyze(
                gov_type, scenario
            )
        else:
            components['governance'] = {
                'note': 'Governance type not specified'
            }

        # Lens C: Corruption
        if 'corruption_type' in context:
            corr_type = context['corruption_type']
            operational_context = context.get('operational_context', 'General operations')
            components['corruption'] = self.lenses.corruption.analyze(
                corr_type, operational_context
            )
        else:
            components['corruption'] = {
                'note': 'Corruption type not specified'
            }

        # Lens D: Non-Western
        if 'military_system' in context:
            military = context['military_system']
            components['non_western'] = self.lenses.non_western.analyze(military)
        else:
            components['non_western'] = {
                'note': 'Military system not specified'
            }

        return components

    def _create_output(
        self,
        query: str,
        context: Dict[str, Any],
        components: Dict[str, Any]
    ) -> AnalysisOutput:
        """
        Create structured analysis output.

        Parameters
        ----------
        query : str
            Original query
        context : Dict[str, Any]
            Analysis context
        components : Dict[str, Any]
            Lens analysis components

        Returns
        -------
        AnalysisOutput
            Structured output
        """
        # Extract or create summary
        summary = context.get('summary', f"Analysis of: {query}")

        # Governance analysis
        governance_analysis = {
            'system_type': context.get('governance_type', 'Not specified'),
            **components.get('governance', {})
        }

        # Logistics analysis
        logistics_analysis = {
            'assessment': context.get('logistics_assessment', 'Requires detailed logistics data'),
            **components.get('logistics', {})
        }

        # Corruption assessment
        corruption_assessment = {
            'corruption_type': context.get('corruption_type', 'Not specified'),
            **components.get('corruption', {})
        }

        # Non-Western perspective
        non_western_perspective = {
            'military_system': context.get('military_system', 'Not specified'),
            **components.get('non_western', {})
        }

        # Scenarios
        scenarios = context.get('scenarios', [])

        # Uncertainty
        uncertainty = context.get('uncertainty_factors', [
            "Limited visibility into internal decision-making",
            "Intelligence gaps in operational readiness",
            "Uncertainty in actor intentions"
        ])

        # Signals
        signals = context.get('signals_to_watch', [
            "Changes in leadership or command structure",
            "Shifts in training tempo or deployment patterns",
            "Procurement and supply chain activity"
        ])

        # Comparative notes
        comparative = context.get('comparative_notes')

        return self.formatter.create_structured_output(
            summary=summary,
            governance=governance_analysis,
            logistics=logistics_analysis,
            corruption=corruption_assessment,
            non_western=non_western_perspective,
            scenarios=scenarios,
            uncertainty=uncertainty,
            signals=signals,
            comparative=comparative
        )

    def quick_analysis(
        self,
        query: str,
        governance_type: GovernanceType,
        corruption_type: CorruptionType,
        military_system: str,
        summary: str,
        **kwargs
    ) -> str:
        """
        Quick analysis with minimal context.

        Parameters
        ----------
        query : str
            Analysis query
        governance_type : GovernanceType
            Type of governance system
        corruption_type : CorruptionType
            Type of corruption present
        military_system : str
            Military system being analyzed
        summary : str
            Assessment summary
        **kwargs
            Additional context

        Returns
        -------
        str
            Formatted analysis
        """
        context = {
            'governance_type': governance_type,
            'corruption_type': corruption_type,
            'military_system': military_system,
            'summary': summary,
            **kwargs
        }

        return self.analyze(query, context)

    def compare_governance_systems(
        self,
        scenario: str,
        authoritarian_context: Dict[str, Any],
        democratic_context: Dict[str, Any]
    ) -> str:
        """
        Compare authoritarian vs democratic systems for a scenario.

        Parameters
        ----------
        scenario : str
            Scenario description
        authoritarian_context : Dict[str, Any]
            Context for authoritarian system
        democratic_context : Dict[str, Any]
            Context for democratic system

        Returns
        -------
        str
            Comparative analysis
        """
        comparison = self.lenses.governance.compare_systems(scenario)

        output = [
            f"COMPARATIVE ANALYSIS: {scenario}",
            "",
            "AUTHORITARIAN/CENTRALIZED SYSTEMS:",
            "Advantages:"
        ]

        for adv in comparison['authoritarian']['advantages']:
            output.append(f" - {adv}")

        output.append("\nDisadvantages:")
        for dis in comparison['authoritarian']['disadvantages']:
            output.append(f" - {dis}")

        output.append("\nDEMOCRATIC/CONSENSUS SYSTEMS:")
        output.append("Advantages:")
        for adv in comparison['democratic']['advantages']:
            output.append(f" - {adv}")

        output.append("\nDisadvantages:")
        for dis in comparison['democratic']['disadvantages']:
            output.append(f" - {dis}")

        output.append(f"\nKEY INSIGHT: {comparison['key_insight']}")

        return "\n".join(output)

    def assess_corruption_impact(
        self,
        corruption_type: CorruptionType,
        operation_type: str
    ) -> str:
        """
        Assess corruption impact on specific operation type.

        Parameters
        ----------
        corruption_type : CorruptionType
            Type of corruption
        operation_type : str
            Type of operation

        Returns
        -------
        str
            Impact assessment
        """
        return self.lenses.corruption.assess_impact(corruption_type, operation_type)

    def get_framework_summary(self) -> Dict[str, Any]:
        """
        Get summary of GeoBot 2.0 framework.

        Returns
        -------
        Dict[str, Any]
            Framework summary
        """
        return self.framework.get_framework_summary()

    def get_analytical_priorities(self) -> List[str]:
        """
        Get list of analytical priorities.

        Returns
        -------
        List[str]
            Analytical priorities
        """
        return self.priorities.priorities
geobot/analysis/formatter.py
ADDED
@@ -0,0 +1,318 @@
"""
GeoBot 2.0 Analysis Formatter

Formats analytical outputs according to GeoBot 2.0 specifications:
1. Summary conclusion
2. Governance structure analysis
3. Logistical interpretation
4. Corruption impact
5. Non-Western perspective integration
6. Scenarios (with probabilities)
7. Uncertainty factors
8. Signals to watch
"""

from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field


@dataclass
class Scenario:
    """A scenario with probability estimate."""
    name: str
    probability: float
    description: str


@dataclass
class AnalysisOutput:
    """Structured analysis output."""
    summary: str
    governance_analysis: Dict[str, Any]
    logistics_analysis: Dict[str, Any]
    corruption_assessment: Dict[str, Any]
    non_western_perspective: Dict[str, Any]
    scenarios: List[Scenario] = field(default_factory=list)
    uncertainty_factors: List[str] = field(default_factory=list)
    signals_to_watch: List[str] = field(default_factory=list)
    comparative_notes: Optional[str] = None


class AnalysisFormatter:
    """
    Formatter for GeoBot 2.0 analytical outputs.

    Formats analysis according to the default output formula:
    1. Summary conclusion
    2. Governance structure analysis
    3. Logistical interpretation
    4. Corruption impact
    5. Non-Western perspective integration
    6. Scenarios (with probabilities)
    7. Uncertainty factors
    8. Signals to watch
    """

    def __init__(self):
        """Initialize the formatter."""
        self.section_separator = "\n" + "=" * 60 + "\n"
        self.subsection_separator = "\n" + "-" * 60 + "\n"

    def format_analysis(self, analysis: AnalysisOutput) -> str:
        """
        Format complete analysis output.

        Parameters
        ----------
        analysis : AnalysisOutput
            Analysis to format

        Returns
        -------
        str
            Formatted analysis
        """
        sections = []

        # 1. Summary
        sections.append(self._format_summary(analysis.summary))

        # 2. Governance Structure Analysis
        sections.append(self._format_governance(analysis.governance_analysis))

        # 3. Logistics
        sections.append(self._format_logistics(analysis.logistics_analysis))

        # 4. Corruption Dynamics
        sections.append(self._format_corruption(analysis.corruption_assessment))

        # 5. Non-Western Context
        sections.append(self._format_non_western(analysis.non_western_perspective))

        # 6. Scenarios
        sections.append(self._format_scenarios(analysis.scenarios))

        # 7. Uncertainty
        sections.append(self._format_uncertainty(analysis.uncertainty_factors))

        # 8. Signals to Watch
        sections.append(self._format_signals(analysis.signals_to_watch))

        # 9. Comparative Note (if present)
        if analysis.comparative_notes:
            sections.append(self._format_comparative(analysis.comparative_notes))

        return self.section_separator.join(sections)

    def _format_summary(self, summary: str) -> str:
        """Format assessment summary."""
        return f"""ASSESSMENT:
{summary}"""

    def _format_governance(self, governance: Dict[str, Any]) -> str:
        """Format governance structure analysis."""
        output = ["GOVERNANCE STRUCTURE CONTEXT:"]

        if 'system_type' in governance:
            output.append(f"\nSystem Type: {governance['system_type']}")

        if 'advantages' in governance:
            output.append("\nAdvantages:")
            for adv in governance['advantages']:
                output.append(f" - {adv}")

        if 'disadvantages' in governance:
            output.append("\nDisadvantages:")
            for dis in governance['disadvantages']:
                output.append(f" - {dis}")

        if 'trade_off' in governance:
            output.append(f"\nTrade-off: {governance['trade_off']}")

        if 'context_specific_advantage' in governance:
            output.append(f"\nContextual Advantage: {governance['context_specific_advantage']}")

        return "\n".join(output)

    def _format_logistics(self, logistics: Dict[str, Any]) -> str:
        """Format logistics analysis."""
        output = ["LOGISTICS:"]

        if 'assessment' in logistics:
            output.append(f"\n{logistics['assessment']}")

        if 'supply_chain_status' in logistics:
            output.append(f"\nSupply Chain Status: {logistics['supply_chain_status']}")

        if 'maintenance_capacity' in logistics:
            output.append(f"Maintenance Capacity: {logistics['maintenance_capacity']}")

        if 'key_constraints' in logistics:
            output.append("\nKey Constraints:")
            for constraint in logistics['key_constraints']:
                output.append(f" - {constraint}")

        if 'mitigating_factors' in logistics:
            output.append("\nMitigating Factors:")
            for factor in logistics['mitigating_factors']:
                output.append(f" - {factor}")

        return "\n".join(output)

    def _format_corruption(self, corruption: Dict[str, Any]) -> str:
        """Format corruption assessment."""
        output = ["CORRUPTION DYNAMICS:"]

        if 'corruption_type' in corruption:
            output.append(f"\nCorruption Type: {corruption['corruption_type']}")

        if 'evidence' in corruption:
            output.append(f"\nEvidence: {corruption['evidence']}")

        if 'operational_impact' in corruption:
            output.append(f"\nOperational Impact: {corruption['operational_impact']}")

        if 'comparison' in corruption:
            output.append(f"\nComparative Context: {corruption['comparison']}")

        if 'risk_assessment' in corruption:
            output.append(f"\nRisk: {corruption['risk_assessment']}")

        return "\n".join(output)

    def _format_non_western(self, non_western: Dict[str, Any]) -> str:
        """Format non-Western perspective."""
        output = ["NON-WESTERN CONTEXT:"]

        if 'analysis_framework' in non_western:
            output.append(f"\n{non_western['analysis_framework']}")

        if 'indigenous_strengths' in non_western:
            output.append("\nIndigenous Strengths:")
            for strength in non_western['indigenous_strengths']:
                output.append(f" - {strength}")

        if 'structural_constraints' in non_western:
            output.append("\nStructural Constraints:")
            for constraint in non_western['structural_constraints']:
                output.append(f" - {constraint}")

        if 'institutional_context' in non_western:
            output.append(f"\nInstitutional Context: {non_western['institutional_context']}")

        if 'key_distinction' in non_western:
            output.append(f"\nKey Distinction: {non_western['key_distinction']}")

        return "\n".join(output)

    def _format_scenarios(self, scenarios: List[Scenario]) -> str:
        """Format scenario probabilities."""
        output = ["SCENARIOS:"]

        if not scenarios:
            output.append("\n(Scenarios require additional context)")
            return "\n".join(output)

        # Sort by probability descending
        sorted_scenarios = sorted(scenarios, key=lambda s: s.probability, reverse=True)

        for scenario in sorted_scenarios:
            output.append(f"\n• {scenario.name} ({scenario.probability:.2f})")
            output.append(f" {scenario.description}")

        return "\n".join(output)

    def _format_uncertainty(self, uncertainty_factors: List[str]) -> str:
        """Format uncertainty factors."""
        output = ["UNCERTAINTY:"]

        if not uncertainty_factors:
            output.append("\n(Standard intelligence limitations apply)")
            return "\n".join(output)

        for factor in uncertainty_factors:
            output.append(f" - {factor}")

        return "\n".join(output)

    def _format_signals(self, signals: List[str]) -> str:
        """Format signals to watch."""
        output = ["SIGNALS TO WATCH:"]

        if not signals:
            output.append("\n(Ongoing monitoring required)")
            return "\n".join(output)

        for signal in signals:
            output.append(f" - {signal}")

        return "\n".join(output)

    def _format_comparative(self, comparative_note: str) -> str:
        """Format comparative note."""
        return f"""COMPARATIVE NOTE:
{comparative_note}"""

    def create_structured_output(
        self,
        summary: str,
        governance: Dict[str, Any],
        logistics: Dict[str, Any],
        corruption: Dict[str, Any],
        non_western: Dict[str, Any],
        scenarios: Optional[List[Dict[str, Any]]] = None,
        uncertainty: Optional[List[str]] = None,
        signals: Optional[List[str]] = None,
        comparative: Optional[str] = None
    ) -> AnalysisOutput:
        """
        Create structured analysis output.

        Parameters
        ----------
        summary : str
            Assessment summary
        governance : Dict[str, Any]
            Governance structure analysis
        logistics : Dict[str, Any]
            Logistics analysis
        corruption : Dict[str, Any]
            Corruption assessment
        non_western : Dict[str, Any]
            Non-Western perspective
        scenarios : Optional[List[Dict[str, Any]]]
            List of scenarios with probabilities
        uncertainty : Optional[List[str]]
            Uncertainty factors
        signals : Optional[List[str]]
            Signals to watch
        comparative : Optional[str]
            Comparative note

        Returns
        -------
        AnalysisOutput
            Structured output
        """
        scenario_objects = []
        if scenarios:
            scenario_objects = [
                Scenario(
                    name=s['name'],
                    probability=s['probability'],
                    description=s['description']
                )
                for s in scenarios
            ]

        return AnalysisOutput(
            summary=summary,
            governance_analysis=governance,
            logistics_analysis=logistics,
            corruption_assessment=corruption,
            non_western_perspective=non_western,
            scenarios=scenario_objects,
            uncertainty_factors=uncertainty or [],
            signals_to_watch=signals or [],
            comparative_notes=comparative
        )
geobot/analysis/framework.py
ADDED
@@ -0,0 +1,155 @@
"""
GeoBot 2.0 Core Framework

Defines the core identity, tone, and analytical principles for GeoBot 2.0,
a clinical systems analyst with geopolitical nuance.
"""

from dataclasses import dataclass, field
from typing import List, Dict, Any
from enum import Enum


class ToneElement(Enum):
    """Tone elements for GeoBot 2.0 analysis."""
    NEUTRAL = "neutral and clinical"
    ANALYTIC = "analytic, not poetic"
    SKEPTICAL = "cautiously skeptical of all systems, including Western ones"
    SYSTEMS_ORIENTED = "systems-oriented - analyzes structural trade-offs"
    CAVEATED = "heavily caveated"
    RISK_REPORT = "risk-report style"


@dataclass
class CoreIdentity:
    """
    Core identity of GeoBot 2.0.

    GeoBot remains a clinical, logistics-focused analyst, but now integrates:
    - Institutional agility assessment
    - Cultural-operational context
    - Adaptive capacity modeling
    - Non-Western institutional logic

    Key shift: Analyzes structural trade-offs rather than assuming
    Western organizational models are superior.
    """

    focus: str = "clinical, logistics-focused analyst"

    integration_elements: List[str] = field(default_factory=lambda: [
        "Institutional agility assessment (authoritarian vs. consensus-based decision structures)",
        "Cultural-operational context (how different militaries actually function, not just Western assumptions)",
        "Adaptive capacity modeling (who can pivot quickly under stress, and why)",
        "Non-Western institutional logic (understanding Chinese, Russian, Iranian, etc. systems on their own terms)"
    ])

    key_shift: str = "Analyzes structural trade-offs rather than assuming Western organizational models are superior"

    tone_elements: List[ToneElement] = field(default_factory=lambda: [
        ToneElement.NEUTRAL,
        ToneElement.ANALYTIC,
        ToneElement.SKEPTICAL,
        ToneElement.SYSTEMS_ORIENTED,
        ToneElement.CAVEATED,
        ToneElement.RISK_REPORT
    ])

    def get_tone_description(self) -> str:
        """Get description of analytical tone."""
        return "\n".join([f"- {tone.value}" for tone in self.tone_elements])


@dataclass
class AnalyticalPrinciples:
    """
    Embedded analytical principles that GeoBot 2.0 believes and operates by.
    """

    principles: List[str] = field(default_factory=lambda: [
        "Governance structure creates operational trade-offs, not just advantages/disadvantages",
        "Authoritarian systems have real agility advantages in strategic pivots and crisis mobilization",
        "Corruption impact depends on type and context, not just existence",
        "Non-Western militaries must be analyzed using their own organizational logic",
        "Logistics remain the ultimate constraint, but cultural factors shape how logistics are managed",
        "Western military assumptions often miss indigenous capabilities",
        "Purges can signal both weakness AND functional institutional enforcement"
    ])

    def get_principles_list(self) -> List[str]:
        """Get list of analytical principles."""
        return self.principles

    def validate_analysis(self, analysis: Dict[str, Any]) -> bool:
        """
        Validate that an analysis adheres to analytical principles.

        Parameters
        ----------
        analysis : Dict[str, Any]
            Analysis to validate

        Returns
        -------
        bool
            True if analysis adheres to principles
        """
        # Check for required elements. The keys must match the AnalysisOutput
        # field names, because AnalyticalEngine.analyze() passes
        # output.__dict__ into this check.
        required_elements = [
            'governance_analysis',
            'logistics_analysis',
            'corruption_assessment',
            'non_western_perspective'
        ]

        return all(element in analysis for element in required_elements)


@dataclass
class GeoBotFramework:
    """
    Complete GeoBot 2.0 framework combining identity, tone, and principles.
    """

    core_identity: CoreIdentity = field(default_factory=CoreIdentity)
    analytical_principles: AnalyticalPrinciples = field(default_factory=AnalyticalPrinciples)

    version: str = "2.0"
    description: str = "Cold Systems Analysis with Geopolitical Nuance"

    def get_framework_summary(self) -> Dict[str, Any]:
        """
        Get summary of the complete framework.

        Returns
        -------
        Dict[str, Any]
            Framework summary
        """
        return {
            'version': self.version,
            'description': self.description,
            'identity': {
                'focus': self.core_identity.focus,
                'key_shift': self.core_identity.key_shift,
                'integration_elements': self.core_identity.integration_elements
            },
            'tone': self.core_identity.get_tone_description(),
            'principles': self.analytical_principles.get_principles_list()
        }

    def validate_analysis(self, analysis: Dict[str, Any]) -> bool:
        """
        Validate that an analysis adheres to GeoBot 2.0 framework.

        Parameters
        ----------
        analysis : Dict[str, Any]
            Analysis to validate

        Returns
        -------
        bool
            True if analysis adheres to framework
        """
        return self.analytical_principles.validate_analysis(analysis)
geobot/analysis/lenses.py
ADDED
@@ -0,0 +1,477 @@
"""
GeoBot 2.0 Analytical Lenses

Four complementary analytical lenses for geopolitical analysis:
- Lens A: Logistics as Power
- Lens B: Governance Structure & Decision Speed
- Lens C: Corruption as Context-Dependent Variable
- Lens D: Non-Western Military Logic
"""

from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from enum import Enum


# ============================================================================
# Lens A: Logistics as Power
# ============================================================================

@dataclass
class LogisticsLens:
    """
    Lens A: Logistics as Power (Unchanged from GeoBot v1)

    Prioritizes supply chains, throughput, maintenance, communications infrastructure.
    Logistics remain the ultimate constraint.
    """

    name: str = "Logistics as Power"
    priority_areas: List[str] = field(default_factory=lambda: [
        "Supply chains and throughput",
        "Maintenance capacity",
        "Communications infrastructure",
        "Resource mobilization speed",
        "Sustainment capacity"
    ])

    def analyze(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Analyze situation through logistics lens.

        Parameters
        ----------
        context : Dict[str, Any]
            Situational context

        Returns
        -------
        Dict[str, Any]
            Logistics analysis
        """
        return {
            'lens': self.name,
            'priority_areas': self.priority_areas,
            'assessment': "Logistics coherence analysis required",
            'context': context
        }


# ============================================================================
# Lens B: Governance Structure & Decision Speed
# ============================================================================

class GovernanceType(Enum):
    """Types of governance structures."""
    AUTHORITARIAN_CENTRALIZED = "authoritarian/centralized"
    DEMOCRATIC_CONSENSUS = "democratic/consensus"
    HYBRID = "hybrid"


@dataclass
class GovernanceAdvantages:
    """Advantages of a governance type."""
    advantages: List[str] = field(default_factory=list)
    disadvantages: List[str] = field(default_factory=list)


@dataclass
class GovernanceLens:
    """
    Lens B: Governance Structure & Decision Speed

    Evaluates institutional agility and decision-making structures.
    Recognizes that different governance structures create different
    operational capabilities, not just deficits.
    """

    name: str = "Governance Structure & Decision Speed"

    authoritarian_profile: GovernanceAdvantages = field(default_factory=lambda: GovernanceAdvantages(
        advantages=[
            "Faster strategic pivots (no legislative/consensus delays)",
            "Rapid resource mobilization during crises",
            "Unified command structures (fewer veto points)",
            "Ability to absorb short-term costs for long-term positioning",
            "Less vulnerable to public opinion shifts"
        ],
        disadvantages=[
            "Higher corruption risk (less accountability)",
            "Information distortion (fear of reporting bad news upward)",
            "Brittleness under sustained stress (rigid hierarchies)",
            "Lower tactical initiative at junior levels",
            "Purge-induced institutional memory loss"
        ]
    ))

    democratic_profile: GovernanceAdvantages = field(default_factory=lambda: GovernanceAdvantages(
        advantages=[
            "Better information flow (less fear-based reporting)",
            "Higher tactical flexibility (NCO empowerment)",
            "More resilient under prolonged strain",
            "Transparent procurement (lower corruption)",
            "Adaptive learning cultures"
        ],
        disadvantages=[
            "Slower strategic decision-making (multiple approval layers)",
            "Political constraints on deployment/escalation",
            "Public opinion as operational constraint",
            "Bureaucratic friction in mobilization",
            "Difficulty sustaining unpopular policies"
        ]
    ))

    def analyze(self, governance_type: GovernanceType, scenario_context: str) -> Dict[str, Any]:
        """
        Analyze which governance type has structural advantage for specific scenario.

        Parameters
        ----------
        governance_type : GovernanceType
            Type of governance structure
        scenario_context : str
            Description of scenario requiring analysis

        Returns
        -------
        Dict[str, Any]
            Governance structure analysis
        """
        profile = None
        if governance_type == GovernanceType.AUTHORITARIAN_CENTRALIZED:
            profile = self.authoritarian_profile
        elif governance_type == GovernanceType.DEMOCRATIC_CONSENSUS:
            profile = self.democratic_profile

        return {
            'lens': self.name,
            'governance_type': governance_type.value,
            'advantages': profile.advantages if profile else [],
            'disadvantages': profile.disadvantages if profile else [],
            'scenario_context': scenario_context,
            'key_question': "Which governance type advantages matter for this specific scenario?"
        }

    def compare_systems(self, scenario_context: str) -> Dict[str, Any]:
        """
        Compare authoritarian vs democratic systems for a scenario.

        Parameters
        ----------
        scenario_context : str
            Description of scenario

        Returns
        -------
        Dict[str, Any]
            Comparative analysis
        """
        return {
            'lens': self.name,
            'scenario': scenario_context,
            'authoritarian': {
                'advantages': self.authoritarian_profile.advantages,
                'disadvantages': self.authoritarian_profile.disadvantages
            },
            'democratic': {
                'advantages': self.democratic_profile.advantages,
                'disadvantages': self.democratic_profile.disadvantages
            },
            'key_insight': "Governance structure creates operational trade-offs, not universal superiority"
        }


# ============================================================================
# Lens C: Corruption as Context-Dependent Variable
# ============================================================================

class CorruptionType(Enum):
    """Types of corruption and their operational impacts."""
    PARASITIC = "parasitic"  # Hollows readiness, predictably degrades performance
    MANAGED_BOUNDED = "managed/bounded"  # Limited by periodic purges
    INSTITUTIONALIZED_PATRONAGE = "institutionalized patronage"  # Loyalty networks
    LOW_CORRUPTION = "low corruption"  # Western militaries, Singapore


@dataclass
class CorruptionProfile:
    """Profile of corruption type and its impacts."""
    corruption_type: CorruptionType
    characteristics: List[str] = field(default_factory=list)
    operational_impact: str = ""
    examples: List[str] = field(default_factory=list)


@dataclass
class CorruptionLens:
    """
    Lens C: Corruption as Context-Dependent Variable

    Corruption is no longer assumed to be universally crippling.
    Instead, analyzes corruption type and its context-specific impacts.
    """

    name: str = "Corruption as Context-Dependent Variable"

    corruption_profiles: Dict[CorruptionType, CorruptionProfile] = field(default_factory=lambda: {
        CorruptionType.PARASITIC: CorruptionProfile(
            corruption_type=CorruptionType.PARASITIC,
            characteristics=[
                "Hollows readiness",
                "Steals from supply chains",
                "Predictably degrades performance"
            ],
            operational_impact="Severe degradation of operational capability",
            examples=["Russia (extensive)", "Many Global South militaries"]
        ),
        CorruptionType.MANAGED_BOUNDED: CorruptionProfile(
            corruption_type=CorruptionType.MANAGED_BOUNDED,
            characteristics=[
                "Limited by periodic purges and surveillance",
                "Still present, but constrained enough to maintain basic functionality",
                "Risk: purges themselves create instability"
            ],
            operational_impact="Moderate impact, mitigated by enforcement",
            examples=["China post-Xi purges"]
        ),
        CorruptionType.INSTITUTIONALIZED_PATRONAGE: CorruptionProfile(
            corruption_type=CorruptionType.INSTITUTIONALIZED_PATRONAGE,
            characteristics=[
                "Loyalty networks provide cohesion",
                "Can coexist with effectiveness if tied to performance"
            ],
            operational_impact="Variable - depends on performance accountability",
            examples=["Iran IRGC", "Some Gulf states"]
        ),
        CorruptionType.LOW_CORRUPTION: CorruptionProfile(
            corruption_type=CorruptionType.LOW_CORRUPTION,
            characteristics=[
                "Advantage in sustained operations",
                "Can be slower to mobilize"
            ],
            operational_impact="Minimal negative impact, enables sustained operations",
            examples=["Western militaries", "Singapore"]
        )
    })

    def analyze(self, corruption_type: CorruptionType, operational_context: str) -> Dict[str, Any]:
        """
        Analyze corruption impact in specific operational context.

        Parameters
        ----------
        corruption_type : CorruptionType
            Type of corruption present
        operational_context : str
            Operational context for assessment

        Returns
        -------
        Dict[str, Any]
            Corruption impact analysis
        """
        profile = self.corruption_profiles[corruption_type]

        return {
            'lens': self.name,
            'corruption_type': corruption_type.value,
            'characteristics': profile.characteristics,
            'operational_impact': profile.operational_impact,
            'examples': profile.examples,
            'operational_context': operational_context,
            'key_question': "What type of corruption, and how does it interact with operational demands?"
        }

    def assess_impact(self, corruption_type: CorruptionType, operation_type: str) -> str:
        """
        Assess whether corruption critically impairs specific operation.

        Parameters
        ----------
        corruption_type : CorruptionType
            Type of corruption
        operation_type : str
            Type of military operation

        Returns
        -------
        str
            Impact assessment
        """
        profile = self.corruption_profiles[corruption_type]
        return f"For {operation_type}: {profile.operational_impact}"


# ============================================================================
# Lens D: Non-Western Military Logic
# ============================================================================

@dataclass
class MilitaryProfile:
    """Profile of a military system's strengths and weaknesses."""
    strengths: List[str] = field(default_factory=list)
    weaknesses: List[str] = field(default_factory=list)
    operational_culture: str = ""


@dataclass
class NonWesternLens:
    """
    Lens D: Non-Western Military Logic

    Incorporates indigenous operational cultures rather than measuring
    everything against NATO standards. Analyzes non-Western militaries
    using appropriate cultural/organizational frameworks.
    """

    name: str = "Non-Western Military Logic"

    military_profiles: Dict[str, MilitaryProfile] = field(default_factory=lambda: {
        "Chinese PLA": MilitaryProfile(
            strengths=[
                "Rapid infrastructure mobilization (civil-military fusion)",
                "Industrial base integration",
                "Coastal defense asymmetric advantages",
                "Improving joint operations capability",
                "Long-term strategic patience"
            ],
            weaknesses=[
                "Limited expeditionary experience",
                "Unproven complex joint operations",
                "NCO corps still developing",
                "Logistics for sustained high-intensity operations"
            ],
            operational_culture="Centralized strategic planning with improving tactical adaptation"
        ),
        "Russian Military": MilitaryProfile(
            strengths=[
                "Artillery coordination",
                "Tactical adaptation under fire (demonstrated in Ukraine)",
                "Willingness to accept casualties",
                "Deep fires integration"
            ],
            weaknesses=[
                "Logistics corruption (confirmed)",
                "Poor junior leadership initiative",
                "Industrial base constraints under sanctions"
            ],
            operational_culture="Heavy firepower doctrine with rigid tactical execution"
        ),
        "Iranian Systems": MilitaryProfile(
            strengths=[
                "Proxy warfare coordination",
                "Missile/drone saturation tactics",
                "Strategic patience",
                "Asymmetric warfare effectiveness"
            ],
            weaknesses=[
                "Air force decay",
                "Sanctions-induced technology gaps",
                "Conventional forces limitations"
            ],
            operational_culture="Asymmetric focus with strategic depth through proxies"
        )
    })

    def analyze(self, military: str) -> Dict[str, Any]:
        """
        Analyze military using appropriate cultural/organizational framework.

        Parameters
        ----------
        military : str
            Military system to analyze

        Returns
        -------
        Dict[str, Any]
            Non-Western perspective analysis
        """
        if military not in self.military_profiles:
            return {
                'lens': self.name,
                'military': military,
                'warning': "Military profile not defined - analysis requires custom framework",
                'key_question': "What are we missing if we only use Western assumptions?"
            }

        profile = self.military_profiles[military]

        return {
            'lens': self.name,
            'military': military,
            'strengths': profile.strengths,
            'weaknesses': profile.weaknesses,
            'operational_culture': profile.operational_culture,
            'key_insight': "Non-obvious strengths often missed by Western-centric analysis",
            'key_question': "Are we analyzing this military using appropriate cultural/organizational frameworks?"
        }

    def add_military_profile(self, military: str, profile: MilitaryProfile) -> None:
        """
        Add new military profile to the lens.

        Parameters
        ----------
        military : str
            Name of military system
        profile : MilitaryProfile
            Profile of the military system
        """
        self.military_profiles[military] = profile


# ============================================================================
# Combined Analytical Lenses
# ============================================================================

@dataclass
class AnalyticalLenses:
    """
    Combined analytical lenses for comprehensive GeoBot 2.0 analysis.
    """

    logistics: LogisticsLens = field(default_factory=LogisticsLens)
    governance: GovernanceLens = field(default_factory=GovernanceLens)
    corruption: CorruptionLens = field(default_factory=CorruptionLens)
    non_western: NonWesternLens = field(default_factory=NonWesternLens)

    def get_all_lenses(self) -> Dict[str, Any]:
        """
        Get all analytical lenses.

        Returns
        -------
        Dict[str, Any]
            All lenses
        """
        return {
            'Lens A': self.logistics,
            'Lens B': self.governance,
            'Lens C': self.corruption,
            'Lens D': self.non_western
        }

    def apply_all_lenses(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply all lenses to a given context.

        Parameters
        ----------
        context : Dict[str, Any]
            Context to analyze

        Returns
        -------
        Dict[str, Any]
            Multi-lens analysis
        """
        return {
            'logistics_analysis': self.logistics.analyze(context),
            'governance_analysis': self.governance.compare_systems(
                context.get('scenario', 'General scenario')
            ),
            'corruption_analysis': "Requires corruption type specification",
            'non_western_analysis': "Requires military system specification",
            'integrated_assessment': "Apply all lenses for comprehensive analysis"
        }
geobot/bayes/__init__.py
ADDED
@@ -0,0 +1,21 @@
"""
Bayesian forecasting and belief updating for GeoBotv1
"""

from .forecasting import (
    BayesianForecaster,
    BeliefState,
    GeopoliticalPrior,
    EvidenceUpdate,
    ForecastDistribution,
    CredibleInterval
)

__all__ = [
    "BayesianForecaster",
    "BeliefState",
    "GeopoliticalPrior",
    "EvidenceUpdate",
    "ForecastDistribution",
    "CredibleInterval",
]
geobot/bayes/forecasting.py
ADDED
@@ -0,0 +1,659 @@
| 1 |
+
"""
|
| 2 |
+
Bayesian Forecasting Module for GeoBotv1
|
| 3 |
+
|
| 4 |
+
Provides Bayesian belief updating, prior construction, and probabilistic forecasting
|
| 5 |
+
for geopolitical scenarios. Integrates with GeoBot 2.0 analytical framework.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Dict, List, Any, Optional, Callable, Tuple
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy import stats
|
| 13 |
+
from scipy.optimize import minimize
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PriorType(Enum):
|
| 17 |
+
"""Types of prior distributions."""
|
| 18 |
+
UNIFORM = "uniform"
|
| 19 |
+
NORMAL = "normal"
|
| 20 |
+
BETA = "beta"
|
| 21 |
+
GAMMA = "gamma"
|
| 22 |
+
EXPERT_INFORMED = "expert_informed"
|
| 23 |
+
HISTORICAL = "historical"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class EvidenceType(Enum):
|
| 27 |
+
"""Types of evidence for belief updating."""
|
| 28 |
+
INTELLIGENCE_REPORT = "intelligence_report"
|
| 29 |
+
SATELLITE_IMAGERY = "satellite_imagery"
|
| 30 |
+
ECONOMIC_DATA = "economic_data"
|
| 31 |
+
MILITARY_MOVEMENT = "military_movement"
|
| 32 |
+
DIPLOMATIC_SIGNAL = "diplomatic_signal"
|
| 33 |
+
OPEN_SOURCE = "open_source"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class GeopoliticalPrior:
|
| 38 |
+
"""
|
| 39 |
+
Prior distribution for a geopolitical parameter.
|
| 40 |
+
|
| 41 |
+
Attributes
|
| 42 |
+
----------
|
| 43 |
+
parameter_name : str
|
| 44 |
+
Name of the parameter
|
| 45 |
+
prior_type : PriorType
|
| 46 |
+
Type of prior distribution
|
| 47 |
+
parameters : Dict[str, float]
|
| 48 |
+
Distribution parameters
|
| 49 |
+
description : str
|
| 50 |
+
Description of what this parameter represents
|
| 51 |
+
"""
|
| 52 |
+
parameter_name: str
|
| 53 |
+
prior_type: PriorType
|
| 54 |
+
parameters: Dict[str, float]
|
| 55 |
+
description: str = ""
|
| 56 |
+
|
| 57 |
+
def sample(self, n_samples: int = 1000, random_state: Optional[int] = None) -> np.ndarray:
|
| 58 |
+
"""
|
| 59 |
+
Sample from prior distribution.
|
| 60 |
+
|
| 61 |
+
Parameters
|
| 62 |
+
----------
|
| 63 |
+
n_samples : int
|
| 64 |
+
Number of samples
|
| 65 |
+
random_state : Optional[int]
|
| 66 |
+
Random seed
|
| 67 |
+
|
| 68 |
+
Returns
|
| 69 |
+
-------
|
| 70 |
+
np.ndarray
|
| 71 |
+
Samples from prior
|
| 72 |
+
"""
|
| 73 |
+
if random_state is not None:
|
| 74 |
+
np.random.seed(random_state)
|
| 75 |
+
|
| 76 |
+
if self.prior_type == PriorType.UNIFORM:
|
| 77 |
+
low = self.parameters['low']
|
| 78 |
+
high = self.parameters['high']
|
| 79 |
+
return np.random.uniform(low, high, n_samples)
|
| 80 |
+
|
| 81 |
+
elif self.prior_type == PriorType.NORMAL:
|
| 82 |
+
mean = self.parameters['mean']
|
| 83 |
+
std = self.parameters['std']
|
| 84 |
+
return np.random.normal(mean, std, n_samples)
|
| 85 |
+
|
| 86 |
+
elif self.prior_type == PriorType.BETA:
|
| 87 |
+
alpha = self.parameters['alpha']
|
| 88 |
+
beta = self.parameters['beta']
|
| 89 |
+
return np.random.beta(alpha, beta, n_samples)
|
| 90 |
+
|
| 91 |
+
elif self.prior_type == PriorType.GAMMA:
|
| 92 |
+
shape = self.parameters['shape']
|
| 93 |
+
scale = self.parameters['scale']
|
| 94 |
+
return np.random.gamma(shape, scale, n_samples)
|
| 95 |
+
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError(f"Sampling not implemented for {self.prior_type}")
|
| 98 |
+
|
| 99 |
+
def pdf(self, x: np.ndarray) -> np.ndarray:
|
| 100 |
+
"""
|
| 101 |
+
Compute probability density function.
|
| 102 |
+
|
| 103 |
+
Parameters
|
| 104 |
+
----------
|
| 105 |
+
x : np.ndarray
|
| 106 |
+
Points at which to evaluate PDF
|
| 107 |
+
|
| 108 |
+
Returns
|
| 109 |
+
-------
|
| 110 |
+
np.ndarray
|
| 111 |
+
PDF values
|
| 112 |
+
"""
|
| 113 |
+
if self.prior_type == PriorType.UNIFORM:
|
| 114 |
+
low = self.parameters['low']
|
| 115 |
+
high = self.parameters['high']
|
| 116 |
+
return stats.uniform.pdf(x, loc=low, scale=high-low)
|
| 117 |
+
|
| 118 |
+
elif self.prior_type == PriorType.NORMAL:
|
| 119 |
+
mean = self.parameters['mean']
|
| 120 |
+
std = self.parameters['std']
|
| 121 |
+
return stats.norm.pdf(x, loc=mean, scale=std)
|
| 122 |
+
|
| 123 |
+
elif self.prior_type == PriorType.BETA:
|
| 124 |
+
alpha = self.parameters['alpha']
|
| 125 |
+
beta = self.parameters['beta']
|
| 126 |
+
return stats.beta.pdf(x, alpha, beta)
|
| 127 |
+
|
| 128 |
+
elif self.prior_type == PriorType.GAMMA:
|
| 129 |
+
shape = self.parameters['shape']
|
| 130 |
+
scale = self.parameters['scale']
|
| 131 |
+
return stats.gamma.pdf(x, shape, scale=scale)
|
| 132 |
+
|
| 133 |
+
else:
|
| 134 |
+
raise ValueError(f"PDF not implemented for {self.prior_type}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class EvidenceUpdate:
|
| 139 |
+
"""
|
| 140 |
+
Evidence for Bayesian belief updating.
|
| 141 |
+
|
| 142 |
+
Attributes
|
| 143 |
+
----------
|
| 144 |
+
evidence_type : EvidenceType
|
| 145 |
+
Type of evidence
|
| 146 |
+
observation : Any
|
| 147 |
+
Observed value or data
|
| 148 |
+
likelihood_function : Callable
|
| 149 |
+
Function mapping parameters to likelihood of observation
|
| 150 |
+
reliability : float
|
| 151 |
+
Reliability score [0, 1]
|
| 152 |
+
source : str
|
| 153 |
+
Source of evidence
|
| 154 |
+
timestamp : Optional[str]
|
| 155 |
+
When evidence was collected
|
| 156 |
+
"""
|
| 157 |
+
evidence_type: EvidenceType
|
| 158 |
+
observation: Any
|
| 159 |
+
likelihood_function: Callable[[np.ndarray], np.ndarray]
|
| 160 |
+
reliability: float = 1.0
|
| 161 |
+
source: str = ""
|
| 162 |
+
timestamp: Optional[str] = None
|
| 163 |
+
|
| 164 |
+
def compute_likelihood(self, parameter_values: np.ndarray) -> np.ndarray:
|
| 165 |
+
"""
|
| 166 |
+
Compute likelihood of observation given parameter values.
|
| 167 |
+
|
| 168 |
+
Parameters
|
| 169 |
+
----------
|
| 170 |
+
parameter_values : np.ndarray
|
| 171 |
+
Parameter values
|
| 172 |
+
|
| 173 |
+
Returns
|
| 174 |
+
-------
|
| 175 |
+
np.ndarray
|
| 176 |
+
Likelihood values
|
| 177 |
+
"""
|
| 178 |
+
base_likelihood = self.likelihood_function(parameter_values)
|
| 179 |
+
# Adjust for reliability
|
| 180 |
+
return base_likelihood ** self.reliability
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@dataclass
|
| 184 |
+
class BeliefState:
|
| 185 |
+
"""
|
| 186 |
+
Current belief state (posterior distribution).
|
| 187 |
+
|
| 188 |
+
Attributes
|
| 189 |
+
----------
|
| 190 |
+
parameter_name : str
|
| 191 |
+
Name of parameter
|
| 192 |
+
posterior_samples : np.ndarray
|
| 193 |
+
Samples from posterior distribution
|
| 194 |
+
prior : GeopoliticalPrior
|
| 195 |
+
Original prior
|
| 196 |
+
evidence_history : List[EvidenceUpdate]
|
| 197 |
+
Evidence used to update beliefs
|
| 198 |
+
"""
|
| 199 |
+
parameter_name: str
|
| 200 |
+
posterior_samples: np.ndarray
|
| 201 |
+
prior: GeopoliticalPrior
|
| 202 |
+
evidence_history: List[EvidenceUpdate] = field(default_factory=list)
|
| 203 |
+
|
| 204 |
+
def mean(self) -> float:
|
| 205 |
+
"""Posterior mean."""
|
| 206 |
+
return float(np.mean(self.posterior_samples))
|
| 207 |
+
|
| 208 |
+
def median(self) -> float:
|
| 209 |
+
"""Posterior median."""
|
| 210 |
+
return float(np.median(self.posterior_samples))
|
| 211 |
+
|
| 212 |
+
def std(self) -> float:
|
| 213 |
+
"""Posterior standard deviation."""
|
| 214 |
+
return float(np.std(self.posterior_samples))
|
| 215 |
+
|
| 216 |
+
def quantile(self, q: float) -> float:
|
| 217 |
+
"""
|
| 218 |
+
Posterior quantile.
|
| 219 |
+
|
| 220 |
+
Parameters
|
| 221 |
+
----------
|
| 222 |
+
q : float
|
| 223 |
+
Quantile in [0, 1]
|
| 224 |
+
|
| 225 |
+
Returns
|
| 226 |
+
-------
|
| 227 |
+
float
|
| 228 |
+
Quantile value
|
| 229 |
+
"""
|
| 230 |
+
return float(np.quantile(self.posterior_samples, q))
|
| 231 |
+
|
| 232 |
+
def credible_interval(self, alpha: float = 0.05) -> Tuple[float, float]:
|
| 233 |
+
"""
|
| 234 |
+
Compute credible interval.
|
| 235 |
+
|
| 236 |
+
Parameters
|
| 237 |
+
----------
|
| 238 |
+
alpha : float
|
| 239 |
+
Significance level (default 0.05 for 95% CI)
|
| 240 |
+
|
| 241 |
+
Returns
|
| 242 |
+
-------
|
| 243 |
+
Tuple[float, float]
|
| 244 |
+
Lower and upper bounds
|
| 245 |
+
"""
|
| 246 |
+
lower = self.quantile(alpha / 2)
|
| 247 |
+
upper = self.quantile(1 - alpha / 2)
|
| 248 |
+
return (lower, upper)
|
| 249 |
+
|
| 250 |
+
def probability_greater_than(self, threshold: float) -> float:
|
| 251 |
+
"""
|
| 252 |
+
Compute P(parameter > threshold | evidence).
|
| 253 |
+
|
| 254 |
+
Parameters
|
| 255 |
+
----------
|
| 256 |
+
threshold : float
|
| 257 |
+
Threshold value
|
| 258 |
+
|
| 259 |
+
Returns
|
| 260 |
+
-------
|
| 261 |
+
float
|
| 262 |
+
Probability
|
| 263 |
+
"""
|
| 264 |
+
return float(np.mean(self.posterior_samples > threshold))
|
| 265 |
+
|
| 266 |
+
def probability_in_range(self, low: float, high: float) -> float:
|
| 267 |
+
"""
|
| 268 |
+
Compute P(low < parameter < high | evidence).
|
| 269 |
+
|
| 270 |
+
Parameters
|
| 271 |
+
----------
|
| 272 |
+
low : float
|
| 273 |
+
Lower bound
|
| 274 |
+
high : float
|
| 275 |
+
Upper bound
|
| 276 |
+
|
| 277 |
+
Returns
|
| 278 |
+
-------
|
| 279 |
+
float
|
| 280 |
+
Probability
|
| 281 |
+
"""
|
| 282 |
+
return float(np.mean((self.posterior_samples > low) &
|
| 283 |
+
(self.posterior_samples < high)))
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
@dataclass
|
| 287 |
+
class CredibleInterval:
|
| 288 |
+
"""Credible interval for forecast."""
|
| 289 |
+
lower: float
|
| 290 |
+
upper: float
|
| 291 |
+
alpha: float # Significance level
|
| 292 |
+
|
| 293 |
+
@property
|
| 294 |
+
def width(self) -> float:
|
| 295 |
+
"""Interval width."""
|
| 296 |
+
return self.upper - self.lower
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def credibility(self) -> float:
|
| 300 |
+
"""Credibility level (e.g., 0.95 for 95% CI)."""
|
| 301 |
+
return 1 - self.alpha
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dataclass
|
| 305 |
+
class ForecastDistribution:
|
| 306 |
+
"""
|
| 307 |
+
Predictive distribution for geopolitical forecast.
|
| 308 |
+
|
| 309 |
+
Attributes
|
| 310 |
+
----------
|
| 311 |
+
variable_name : str
|
| 312 |
+
Name of forecasted variable
|
| 313 |
+
samples : np.ndarray
|
| 314 |
+
Samples from predictive distribution
|
| 315 |
+
time_horizon : int
|
| 316 |
+
Forecast horizon (days, months, etc.)
|
| 317 |
+
conditioning_info : Dict[str, Any]
|
| 318 |
+
Information conditioned on
|
| 319 |
+
"""
|
| 320 |
+
variable_name: str
|
| 321 |
+
samples: np.ndarray
|
| 322 |
+
time_horizon: int
|
| 323 |
+
conditioning_info: Dict[str, Any] = field(default_factory=dict)
|
| 324 |
+
|
| 325 |
+
def point_forecast(self, method: str = 'mean') -> float:
|
| 326 |
+
"""
|
| 327 |
+
Point forecast.
|
| 328 |
+
|
| 329 |
+
Parameters
|
| 330 |
+
----------
|
| 331 |
+
method : str
|
| 332 |
+
'mean', 'median', or 'mode'
|
| 333 |
+
|
| 334 |
+
Returns
|
| 335 |
+
-------
|
| 336 |
+
float
|
| 337 |
+
Point forecast
|
| 338 |
+
"""
|
| 339 |
+
if method == 'mean':
|
| 340 |
+
return float(np.mean(self.samples))
|
| 341 |
+
elif method == 'median':
|
| 342 |
+
return float(np.median(self.samples))
|
| 343 |
+
elif method == 'mode':
|
| 344 |
+
# Use kernel density estimation for mode
|
| 345 |
+
from scipy.stats import gaussian_kde
|
| 346 |
+
kde = gaussian_kde(self.samples)
|
| 347 |
+
x = np.linspace(self.samples.min(), self.samples.max(), 1000)
|
| 348 |
+
return float(x[np.argmax(kde(x))])
|
| 349 |
+
else:
|
| 350 |
+
raise ValueError(f"Unknown method: {method}")
|
| 351 |
+
|
| 352 |
+
def credible_interval(self, alpha: float = 0.05) -> CredibleInterval:
|
| 353 |
+
"""
|
| 354 |
+
Compute credible interval.
|
| 355 |
+
|
| 356 |
+
Parameters
|
| 357 |
+
----------
|
| 358 |
+
alpha : float
|
| 359 |
+
Significance level
|
| 360 |
+
|
| 361 |
+
Returns
|
| 362 |
+
-------
|
| 363 |
+
CredibleInterval
|
| 364 |
+
Credible interval
|
| 365 |
+
"""
|
| 366 |
+
lower = float(np.quantile(self.samples, alpha / 2))
|
| 367 |
+
upper = float(np.quantile(self.samples, 1 - alpha / 2))
|
| 368 |
+
return CredibleInterval(lower=lower, upper=upper, alpha=alpha)
|
| 369 |
+
|
| 370 |
+
def probability_of_event(self, condition: Callable[[np.ndarray], np.ndarray]) -> float:
|
| 371 |
+
"""
|
| 372 |
+
Probability of event defined by condition.
|
| 373 |
+
|
| 374 |
+
Parameters
|
| 375 |
+
----------
|
| 376 |
+
condition : Callable
|
| 377 |
+
Function that returns True/False for each sample
|
| 378 |
+
|
| 379 |
+
Returns
|
| 380 |
+
-------
|
| 381 |
+
float
|
| 382 |
+
Probability
|
| 383 |
+
"""
|
| 384 |
+
return float(np.mean(condition(self.samples)))
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class BayesianForecaster:
|
| 388 |
+
"""
|
| 389 |
+
Bayesian forecasting engine for geopolitical analysis.
|
| 390 |
+
|
| 391 |
+
Integrates with GeoBot 2.0 analytical framework to provide
|
| 392 |
+
probabilistic forecasts with explicit uncertainty quantification.
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
def __init__(self):
|
| 396 |
+
"""Initialize Bayesian forecaster."""
|
| 397 |
+
self.priors: Dict[str, GeopoliticalPrior] = {}
|
| 398 |
+
self.beliefs: Dict[str, BeliefState] = {}
|
| 399 |
+
|
| 400 |
+
def set_prior(self, prior: GeopoliticalPrior) -> None:
|
| 401 |
+
"""
|
| 402 |
+
Set prior distribution for parameter.
|
| 403 |
+
|
| 404 |
+
Parameters
|
| 405 |
+
----------
|
| 406 |
+
prior : GeopoliticalPrior
|
| 407 |
+
Prior distribution
|
| 408 |
+
"""
|
| 409 |
+
self.priors[prior.parameter_name] = prior
|
| 410 |
+
|
| 411 |
+
def update_belief(
|
| 412 |
+
self,
|
| 413 |
+
parameter_name: str,
|
| 414 |
+
evidence: EvidenceUpdate,
|
| 415 |
+
n_samples: int = 10000,
|
| 416 |
+
method: str = 'importance_sampling'
|
| 417 |
+
) -> BeliefState:
|
| 418 |
+
"""
|
| 419 |
+
Update beliefs using Bayes' rule.
|
| 420 |
+
|
| 421 |
+
Parameters
|
| 422 |
+
----------
|
| 423 |
+
parameter_name : str
|
| 424 |
+
Parameter to update
|
| 425 |
+
evidence : EvidenceUpdate
|
| 426 |
+
New evidence
|
| 427 |
+
n_samples : int
|
| 428 |
+
Number of samples for approximation
|
| 429 |
+
method : str
|
| 430 |
+
'importance_sampling' or 'rejection_sampling'
|
| 431 |
+
|
| 432 |
+
Returns
|
| 433 |
+
-------
|
| 434 |
+
BeliefState
|
| 435 |
+
Updated belief state
|
| 436 |
+
"""
|
| 437 |
+
if parameter_name not in self.priors:
|
| 438 |
+
raise ValueError(f"No prior set for {parameter_name}")
|
| 439 |
+
|
| 440 |
+
# Get current prior or posterior
|
| 441 |
+
if parameter_name in self.beliefs:
|
| 442 |
+
# Use previous posterior as new prior
|
| 443 |
+
prior_samples = self.beliefs[parameter_name].posterior_samples
|
| 444 |
+
else:
|
| 445 |
+
# Use original prior
|
| 446 |
+
prior_samples = self.priors[parameter_name].sample(n_samples)
|
| 447 |
+
|
| 448 |
+
# Compute likelihoods
|
| 449 |
+
likelihoods = evidence.compute_likelihood(prior_samples)
|
| 450 |
+
|
| 451 |
+
if method == 'importance_sampling':
|
| 452 |
+
# Importance sampling with resampling
|
| 453 |
+
weights = likelihoods / np.sum(likelihoods)
|
| 454 |
+
|
| 455 |
+
# Resample according to weights
|
| 456 |
+
indices = np.random.choice(
|
| 457 |
+
len(prior_samples),
|
| 458 |
+
size=n_samples,
|
| 459 |
+
replace=True,
|
| 460 |
+
p=weights
|
| 461 |
+
)
|
| 462 |
+
posterior_samples = prior_samples[indices]
|
| 463 |
+
|
| 464 |
+
elif method == 'rejection_sampling':
|
| 465 |
+
# Rejection sampling
|
| 466 |
+
max_likelihood = np.max(likelihoods)
|
| 467 |
+
accepted = []
|
| 468 |
+
|
| 469 |
+
for sample, likelihood in zip(prior_samples, likelihoods):
|
| 470 |
+
if np.random.uniform(0, max_likelihood) < likelihood:
|
| 471 |
+
accepted.append(sample)
|
| 472 |
+
|
| 473 |
+
if len(accepted) < 100:
|
| 474 |
+
raise ValueError("Rejection sampling failed - too few accepted samples")
|
| 475 |
+
|
| 476 |
+
posterior_samples = np.array(accepted)
|
| 477 |
+
|
| 478 |
+
else:
|
| 479 |
+
raise ValueError(f"Unknown method: {method}")
|
| 480 |
+
|
| 481 |
+
# Create or update belief state
|
| 482 |
+
if parameter_name in self.beliefs:
|
| 483 |
+
belief = self.beliefs[parameter_name]
|
| 484 |
+
belief.posterior_samples = posterior_samples
|
| 485 |
+
belief.evidence_history.append(evidence)
|
| 486 |
+
else:
|
| 487 |
+
belief = BeliefState(
|
| 488 |
+
parameter_name=parameter_name,
|
| 489 |
+
posterior_samples=posterior_samples,
|
| 490 |
+
prior=self.priors[parameter_name],
|
| 491 |
+
evidence_history=[evidence]
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
self.beliefs[parameter_name] = belief
|
| 495 |
+
return belief
|
| 496 |
+
|
| 497 |
+
def sequential_update(
|
| 498 |
+
self,
|
| 499 |
+
parameter_name: str,
|
| 500 |
+
evidence_sequence: List[EvidenceUpdate],
|
| 501 |
+
n_samples: int = 10000
|
| 502 |
+
) -> BeliefState:
|
| 503 |
+
"""
|
| 504 |
+
Sequential belief updating with multiple pieces of evidence.
|
| 505 |
+
|
| 506 |
+
Parameters
|
| 507 |
+
----------
|
| 508 |
+
parameter_name : str
|
| 509 |
+
Parameter to update
|
| 510 |
+
evidence_sequence : List[EvidenceUpdate]
|
| 511 |
+
Sequence of evidence
|
| 512 |
+
n_samples : int
|
| 513 |
+
Number of samples
|
| 514 |
+
|
| 515 |
+
Returns
|
| 516 |
+
-------
|
| 517 |
+
BeliefState
|
| 518 |
+
Final belief state
|
| 519 |
+
"""
|
| 520 |
+
for evidence in evidence_sequence:
|
| 521 |
+
self.update_belief(parameter_name, evidence, n_samples)
|
| 522 |
+
|
| 523 |
+
return self.beliefs[parameter_name]
|
| 524 |
+
|
| 525 |
+
def forecast(
|
| 526 |
+
self,
|
| 527 |
+
variable_name: str,
|
| 528 |
+
predictive_function: Callable[[Dict[str, float]], float],
|
| 529 |
+
time_horizon: int,
|
| 530 |
+
n_samples: int = 10000,
|
| 531 |
+
conditioning_info: Optional[Dict[str, Any]] = None
|
| 532 |
+
) -> ForecastDistribution:
|
| 533 |
+
"""
|
| 534 |
+
Generate probabilistic forecast.
|
| 535 |
+
|
| 536 |
+
Parameters
|
| 537 |
+
----------
|
| 538 |
+
variable_name : str
|
| 539 |
+
Variable to forecast
|
| 540 |
+
predictive_function : Callable
|
| 541 |
+
Function mapping parameter values to prediction
|
| 542 |
+
time_horizon : int
|
| 543 |
+
Forecast horizon
|
| 544 |
+
n_samples : int
|
| 545 |
+
Number of forecast samples
|
| 546 |
+
conditioning_info : Optional[Dict[str, Any]]
|
| 547 |
+
Additional conditioning information
|
| 548 |
+
|
| 549 |
+
Returns
|
| 550 |
+
-------
|
| 551 |
+
ForecastDistribution
|
| 552 |
+
Forecast distribution
|
| 553 |
+
"""
|
| 554 |
+
# Sample parameters from beliefs
|
| 555 |
+
parameter_samples = {}
|
| 556 |
+
for param_name, belief in self.beliefs.items():
|
| 557 |
+
indices = np.random.choice(len(belief.posterior_samples), size=n_samples)
|
| 558 |
+
parameter_samples[param_name] = belief.posterior_samples[indices]
|
| 559 |
+
|
| 560 |
+
# Generate forecasts
|
| 561 |
+
forecast_samples = np.zeros(n_samples)
|
| 562 |
+
for i in range(n_samples):
|
| 563 |
+
params = {name: samples[i] for name, samples in parameter_samples.items()}
|
| 564 |
+
forecast_samples[i] = predictive_function(params)
|
| 565 |
+
|
| 566 |
+
return ForecastDistribution(
|
| 567 |
+
variable_name=variable_name,
|
| 568 |
+
samples=forecast_samples,
|
| 569 |
+
time_horizon=time_horizon,
|
| 570 |
+
conditioning_info=conditioning_info or {}
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
def model_comparison(
|
| 574 |
+
self,
|
| 575 |
+
models: Dict[str, Callable],
|
| 576 |
+
evidence: List[EvidenceUpdate],
|
| 577 |
+
prior_model_probs: Optional[Dict[str, float]] = None
|
| 578 |
+
) -> Dict[str, float]:
|
| 579 |
+
"""
|
| 580 |
+
Bayesian model comparison using evidence.
|
| 581 |
+
|
| 582 |
+
Parameters
|
| 583 |
+
----------
|
| 584 |
+
models : Dict[str, Callable]
|
| 585 |
+
Dictionary of models (name -> likelihood function)
|
| 586 |
+
evidence : List[EvidenceUpdate]
|
| 587 |
+
Evidence for comparison
|
| 588 |
+
prior_model_probs : Optional[Dict[str, float]]
|
| 589 |
+
Prior model probabilities
|
| 590 |
+
|
| 591 |
+
Returns
|
| 592 |
+
-------
|
| 593 |
+
Dict[str, float]
|
| 594 |
+
Posterior model probabilities
|
| 595 |
+
"""
|
| 596 |
+
if prior_model_probs is None:
|
| 597 |
+
# Uniform prior over models
|
| 598 |
+
prior_model_probs = {name: 1.0 / len(models) for name in models}
|
| 599 |
+
|
| 600 |
+
# Compute marginal likelihoods (evidence)
|
| 601 |
+
marginal_likelihoods = {}
|
| 602 |
+
|
| 603 |
+
for model_name, model_fn in models.items():
|
| 604 |
+
# This is a simplified version - full implementation would
|
| 605 |
+
# integrate over parameter space
|
| 606 |
+
likelihood = 1.0
|
| 607 |
+
for ev in evidence:
|
| 608 |
+
# Assuming model_fn can compute likelihood
|
| 609 |
+
likelihood *= np.mean(ev.compute_likelihood(model_fn))
|
| 610 |
+
|
| 611 |
+
marginal_likelihoods[model_name] = likelihood
|
| 612 |
+
|
| 613 |
+
# Compute posterior model probabilities
|
| 614 |
+
posterior_probs = {}
|
| 615 |
+
total = 0.0
|
| 616 |
+
|
| 617 |
+
for model_name in models:
|
| 618 |
+
unnormalized = (prior_model_probs[model_name] *
|
| 619 |
+
marginal_likelihoods[model_name])
|
| 620 |
+
posterior_probs[model_name] = unnormalized
|
| 621 |
+
total += unnormalized
|
| 622 |
+
|
| 623 |
+
# Normalize
|
| 624 |
+
for model_name in posterior_probs:
|
| 625 |
+
posterior_probs[model_name] /= total
|
| 626 |
+
|
| 627 |
+
return posterior_probs
|
| 628 |
+
|
| 629 |
+
def get_belief_summary(self, parameter_name: str) -> Dict[str, Any]:
|
| 630 |
+
"""
|
| 631 |
+
Get summary statistics for belief state.
|
| 632 |
+
|
| 633 |
+
Parameters
|
| 634 |
+
----------
|
| 635 |
+
parameter_name : str
|
| 636 |
+
Parameter name
|
| 637 |
+
|
| 638 |
+
Returns
|
| 639 |
+
-------
|
| 640 |
+
Dict[str, Any]
|
| 641 |
+
Summary statistics
|
| 642 |
+
"""
|
| 643 |
+
if parameter_name not in self.beliefs:
|
| 644 |
+
raise ValueError(f"No beliefs for {parameter_name}")
|
| 645 |
+
|
| 646 |
+
belief = self.beliefs[parameter_name]
|
| 647 |
+
ci_95 = belief.credible_interval(alpha=0.05)
|
| 648 |
+
ci_90 = belief.credible_interval(alpha=0.10)
|
| 649 |
+
|
| 650 |
+
return {
|
| 651 |
+
'parameter': parameter_name,
|
| 652 |
+
'mean': belief.mean(),
|
| 653 |
+
'median': belief.median(),
|
| 654 |
+
'std': belief.std(),
|
| 655 |
+
'95%_CI': ci_95,
|
| 656 |
+
'90%_CI': ci_90,
|
| 657 |
+
'n_evidence_updates': len(belief.evidence_history),
|
| 658 |
+
'evidence_types': [ev.evidence_type.value for ev in belief.evidence_history]
|
| 659 |
+
}
|
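A minimal end-to-end sketch of the forecaster defined above. The parameter name, the linear likelihood function, and the 0-100 risk index are illustrative assumptions, not part of the module.

```python
from geobot.bayes.forecasting import (
    BayesianForecaster, GeopoliticalPrior, EvidenceUpdate, PriorType, EvidenceType
)

forecaster = BayesianForecaster()

# Prior: escalation probability is low but uncertain (Beta(2, 8)).
forecaster.set_prior(GeopoliticalPrior(
    parameter_name="escalation_prob",
    prior_type=PriorType.BETA,
    parameters={"alpha": 2.0, "beta": 8.0},
    description="Probability of escalation within the forecast horizon",
))

# Evidence whose likelihood rises with the parameter, down-weighted by source reliability.
evidence = EvidenceUpdate(
    evidence_type=EvidenceType.MILITARY_MOVEMENT,
    observation="armor massing near the border",
    likelihood_function=lambda theta: theta,  # higher escalation_prob -> observation more likely
    reliability=0.8,
    source="open-source imagery",
)

belief = forecaster.update_belief("escalation_prob", evidence)
print(forecaster.get_belief_summary("escalation_prob"))
print("P(escalation_prob > 0.5):", belief.probability_greater_than(0.5))

# Push the posterior through a simple predictive function for a 90-day forecast.
forecast = forecaster.forecast(
    variable_name="risk_index",
    predictive_function=lambda params: 100.0 * params["escalation_prob"],
    time_horizon=90,
)
print("Median risk index:", forecast.point_forecast("median"))
print("90% credible interval:", forecast.credible_interval(alpha=0.10))
```

Because update_belief reuses the previous posterior as the prior for the next update, repeated calls (or sequential_update) chain evidence in arrival order.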
geobot/causal/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
"""
|
| 2 |
+
Causal inference and structural causal models for GeoBotv1
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .structural_model import (
|
| 6 |
+
StructuralCausalModel,
|
| 7 |
+
StructuralEquation,
|
| 8 |
+
Intervention,
|
| 9 |
+
Counterfactual,
|
| 10 |
+
CausalEffect,
|
| 11 |
+
IdentificationStrategy,
|
| 12 |
+
estimate_causal_effect
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"StructuralCausalModel",
|
| 17 |
+
"StructuralEquation",
|
| 18 |
+
"Intervention",
|
| 19 |
+
"Counterfactual",
|
| 20 |
+
"CausalEffect",
|
| 21 |
+
"IdentificationStrategy",
|
| 22 |
+
"estimate_causal_effect",
|
| 23 |
+
]
|
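The exports above, combined with the predefined models in structural_model.py (shown next), support simple counterfactual queries. A hedged sketch: the numeric values are invented for illustration, and the module's own comments note that abduction of exogenous noise is simplified.

```python
from geobot.causal import Counterfactual, Intervention
from geobot.causal.structural_model import create_conflict_escalation_scm

scm = create_conflict_escalation_scm()

# "What would conflict risk have been had military buildup been held at 0.9?"
query = Counterfactual(
    query_variable="conflict_risk",
    intervention=Intervention(variable="military_buildup", value=0.9),
    observations={"diplomatic_tension": 0.4},  # recorded, but not abduced by the simplified solver
)
print(scm.counterfactual_query(query, n_samples=5000))
```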
geobot/causal/structural_model.py
ADDED
|
@@ -0,0 +1,664 @@
|
| 1 |
+
"""
|
| 2 |
+
Structural Causal Models for GeoBotv1
|
| 3 |
+
|
| 4 |
+
Implements Structural Causal Models (SCMs) for geopolitical analysis,
|
| 5 |
+
intervention simulation, and counterfactual reasoning. Integrates with
|
| 6 |
+
GeoBot 2.0 analytical framework.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Dict, List, Any, Optional, Callable, Set, Tuple
|
| 11 |
+
from enum import Enum
|
| 12 |
+
import numpy as np
|
| 13 |
+
import networkx as nx
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class IdentificationStrategy(Enum):
|
| 17 |
+
"""Strategies for identifying causal effects."""
|
| 18 |
+
BACKDOOR_ADJUSTMENT = "backdoor"
|
| 19 |
+
FRONTDOOR_ADJUSTMENT = "frontdoor"
|
| 20 |
+
INSTRUMENTAL_VARIABLES = "iv"
|
| 21 |
+
DO_CALCULUS = "do_calculus"
|
| 22 |
+
STRUCTURAL_EQUATIONS = "structural"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class StructuralEquation:
|
| 27 |
+
"""
|
| 28 |
+
Structural equation for a variable in an SCM.
|
| 29 |
+
|
| 30 |
+
X := f(Pa_X, U_X)
|
| 31 |
+
|
| 32 |
+
Attributes
|
| 33 |
+
----------
|
| 34 |
+
variable : str
|
| 35 |
+
Variable name
|
| 36 |
+
parents : List[str]
|
| 37 |
+
Parent variables in causal graph
|
| 38 |
+
function : Callable
|
| 39 |
+
Structural function f
|
| 40 |
+
noise_dist : Callable
|
| 41 |
+
Distribution of exogenous noise U_X
|
| 42 |
+
description : str
|
| 43 |
+
Description of equation
|
| 44 |
+
"""
|
| 45 |
+
variable: str
|
| 46 |
+
parents: List[str]
|
| 47 |
+
function: Callable[[Dict[str, float]], float]
|
| 48 |
+
noise_dist: Callable[[int], np.ndarray]
|
| 49 |
+
description: str = ""
|
| 50 |
+
|
| 51 |
+
def evaluate(self, parent_values: Dict[str, float], noise: Optional[float] = None) -> float:
|
| 52 |
+
"""
|
| 53 |
+
Evaluate structural equation.
|
| 54 |
+
|
| 55 |
+
Parameters
|
| 56 |
+
----------
|
| 57 |
+
parent_values : Dict[str, float]
|
| 58 |
+
Values of parent variables
|
| 59 |
+
noise : Optional[float]
|
| 60 |
+
Noise value (if None, sample from distribution)
|
| 61 |
+
|
| 62 |
+
Returns
|
| 63 |
+
-------
|
| 64 |
+
float
|
| 65 |
+
Value of variable
|
| 66 |
+
"""
|
| 67 |
+
if noise is None:
|
| 68 |
+
noise = self.noise_dist(1)[0]
|
| 69 |
+
|
| 70 |
+
return self.function(parent_values) + noise
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class Intervention:
|
| 75 |
+
"""
|
| 76 |
+
Causal intervention do(X = x).
|
| 77 |
+
|
| 78 |
+
Attributes
|
| 79 |
+
----------
|
| 80 |
+
variable : str
|
| 81 |
+
Variable being intervened on
|
| 82 |
+
value : float
|
| 83 |
+
Value set by intervention
|
| 84 |
+
description : str
|
| 85 |
+
Description of intervention
|
| 86 |
+
"""
|
| 87 |
+
variable: str
|
| 88 |
+
value: float
|
| 89 |
+
description: str = ""
|
| 90 |
+
|
| 91 |
+
def __repr__(self) -> str:
|
| 92 |
+
return f"do({self.variable} = {self.value})"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class Counterfactual:
|
| 97 |
+
"""
|
| 98 |
+
Counterfactual query.
|
| 99 |
+
|
| 100 |
+
"What would Y be if we had done X = x, given that we observed Z = z?"
|
| 101 |
+
|
| 102 |
+
Attributes
|
| 103 |
+
----------
|
| 104 |
+
query_variable : str
|
| 105 |
+
Variable being queried
|
| 106 |
+
intervention : Intervention
|
| 107 |
+
Counterfactual intervention
|
| 108 |
+
observations : Dict[str, float]
|
| 109 |
+
Observed values
|
| 110 |
+
"""
|
| 111 |
+
query_variable: str
|
| 112 |
+
intervention: Intervention
|
| 113 |
+
observations: Dict[str, float] = field(default_factory=dict)
|
| 114 |
+
|
| 115 |
+
def __repr__(self) -> str:
|
| 116 |
+
obs_str = ", ".join([f"{k}={v}" for k, v in self.observations.items()])
|
| 117 |
+
return f"{self.query_variable}_{{{self.intervention}}} | {obs_str}"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
|
| 121 |
+
class CausalEffect:
|
| 122 |
+
"""
|
| 123 |
+
Estimated causal effect.
|
| 124 |
+
|
| 125 |
+
Attributes
|
| 126 |
+
----------
|
| 127 |
+
treatment : str
|
| 128 |
+
Treatment variable
|
| 129 |
+
outcome : str
|
| 130 |
+
Outcome variable
|
| 131 |
+
effect : float
|
| 132 |
+
Estimated average causal effect
|
| 133 |
+
std_error : Optional[float]
|
| 134 |
+
Standard error of estimate
|
| 135 |
+
confidence_interval : Optional[Tuple[float, float]]
|
| 136 |
+
Confidence interval
|
| 137 |
+
identification_strategy : IdentificationStrategy
|
| 138 |
+
How effect was identified
|
| 139 |
+
"""
|
| 140 |
+
treatment: str
|
| 141 |
+
outcome: str
|
| 142 |
+
effect: float
|
| 143 |
+
std_error: Optional[float] = None
|
| 144 |
+
confidence_interval: Optional[Tuple[float, float]] = None
|
| 145 |
+
identification_strategy: Optional[IdentificationStrategy] = None
|
| 146 |
+
|
| 147 |
+
def __repr__(self) -> str:
|
| 148 |
+
ci_str = ""
|
| 149 |
+
if self.confidence_interval:
|
| 150 |
+
ci_str = f", 95% CI: {self.confidence_interval}"
|
| 151 |
+
return f"ACE({self.treatment} → {self.outcome}) = {self.effect:.3f}{ci_str}"
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class StructuralCausalModel:
|
| 155 |
+
"""
|
| 156 |
+
Structural Causal Model for geopolitical analysis.
|
| 157 |
+
|
| 158 |
+
An SCM consists of:
|
| 159 |
+
1. Causal graph G (DAG)
|
| 160 |
+
2. Structural equations for each variable
|
| 161 |
+
3. Exogenous noise distributions
|
| 162 |
+
|
| 163 |
+
Enables:
|
| 164 |
+
- Intervention simulation (do-operator)
|
| 165 |
+
- Counterfactual reasoning
|
| 166 |
+
- Causal effect identification
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
def __init__(self, name: str = "GeopoliticalSCM"):
|
| 170 |
+
"""
|
| 171 |
+
Initialize SCM.
|
| 172 |
+
|
| 173 |
+
Parameters
|
| 174 |
+
----------
|
| 175 |
+
name : str
|
| 176 |
+
Name of SCM
|
| 177 |
+
"""
|
| 178 |
+
self.name = name
|
| 179 |
+
self.graph = nx.DiGraph()
|
| 180 |
+
self.equations: Dict[str, StructuralEquation] = {}
|
| 181 |
+
self.exogenous_variables: Set[str] = set()
|
| 182 |
+
|
| 183 |
+
def add_equation(self, equation: StructuralEquation) -> None:
|
| 184 |
+
"""
|
| 185 |
+
Add structural equation to model.
|
| 186 |
+
|
| 187 |
+
Parameters
|
| 188 |
+
----------
|
| 189 |
+
equation : StructuralEquation
|
| 190 |
+
Structural equation
|
| 191 |
+
"""
|
| 192 |
+
self.equations[equation.variable] = equation
|
| 193 |
+
|
| 194 |
+
# Add to graph
|
| 195 |
+
self.graph.add_node(equation.variable)
|
| 196 |
+
for parent in equation.parents:
|
| 197 |
+
self.graph.add_edge(parent, equation.variable)
|
| 198 |
+
|
| 199 |
+
def add_exogenous(self, variable: str, distribution: Callable[[int], np.ndarray]) -> None:
|
| 200 |
+
"""
|
| 201 |
+
Add exogenous variable.
|
| 202 |
+
|
| 203 |
+
Parameters
|
| 204 |
+
----------
|
| 205 |
+
variable : str
|
| 206 |
+
Variable name
|
| 207 |
+
distribution : Callable
|
| 208 |
+
Distribution for sampling
|
| 209 |
+
"""
|
| 210 |
+
self.exogenous_variables.add(variable)
|
| 211 |
+
# Attach the sampling distribution as a node attribute so simulate() can use it
self.graph.add_node(variable, distribution=distribution)
|
| 212 |
+
|
| 213 |
+
def topological_order(self) -> List[str]:
|
| 214 |
+
"""
|
| 215 |
+
Get topological ordering of variables.
|
| 216 |
+
|
| 217 |
+
Returns
|
| 218 |
+
-------
|
| 219 |
+
List[str]
|
| 220 |
+
Topologically sorted variables
|
| 221 |
+
"""
|
| 222 |
+
try:
|
| 223 |
+
return list(nx.topological_sort(self.graph))
|
| 224 |
+
except nx.NetworkXError:
|
| 225 |
+
raise ValueError("Graph contains cycles - not a valid DAG")
|
| 226 |
+
|
| 227 |
+
def simulate(
|
| 228 |
+
self,
|
| 229 |
+
n_samples: int = 1000,
|
| 230 |
+
interventions: Optional[List[Intervention]] = None,
|
| 231 |
+
random_state: Optional[int] = None
|
| 232 |
+
) -> Dict[str, np.ndarray]:
|
| 233 |
+
"""
|
| 234 |
+
Simulate from SCM.
|
| 235 |
+
|
| 236 |
+
Parameters
|
| 237 |
+
----------
|
| 238 |
+
n_samples : int
|
| 239 |
+
Number of samples
|
| 240 |
+
interventions : Optional[List[Intervention]]
|
| 241 |
+
Interventions to apply
|
| 242 |
+
random_state : Optional[int]
|
| 243 |
+
Random seed
|
| 244 |
+
|
| 245 |
+
Returns
|
| 246 |
+
-------
|
| 247 |
+
Dict[str, np.ndarray]
|
| 248 |
+
Simulated data for each variable
|
| 249 |
+
"""
|
| 250 |
+
if random_state is not None:
|
| 251 |
+
np.random.seed(random_state)
|
| 252 |
+
|
| 253 |
+
# Intervention variables
|
| 254 |
+
intervention_dict = {}
|
| 255 |
+
if interventions:
|
| 256 |
+
intervention_dict = {iv.variable: iv.value for iv in interventions}
|
| 257 |
+
|
| 258 |
+
# Initialize data
|
| 259 |
+
data = {}
|
| 260 |
+
|
| 261 |
+
# Topological order
|
| 262 |
+
order = self.topological_order()
|
| 263 |
+
|
| 264 |
+
# Simulate each variable in order
|
| 265 |
+
for var in order:
|
| 266 |
+
if var in intervention_dict:
|
| 267 |
+
# Variable is intervened on - set to intervention value
|
| 268 |
+
data[var] = np.full(n_samples, intervention_dict[var])
|
| 269 |
+
|
| 270 |
+
elif var in self.exogenous_variables:
|
| 271 |
+
# Exogenous variable - sample from distribution
|
| 272 |
+
# For now, assume standard normal if not specified
|
| 273 |
+
data[var] = np.random.randn(n_samples)
|
| 274 |
+
|
| 275 |
+
elif var in self.equations:
|
| 276 |
+
# Endogenous variable - evaluate structural equation
|
| 277 |
+
eq = self.equations[var]
|
| 278 |
+
values = np.zeros(n_samples)
|
| 279 |
+
|
| 280 |
+
for i in range(n_samples):
|
| 281 |
+
parent_vals = {p: data[p][i] for p in eq.parents}
|
| 282 |
+
values[i] = eq.evaluate(parent_vals)
|
| 283 |
+
|
| 284 |
+
data[var] = values
|
| 285 |
+
|
| 286 |
+
else:
|
| 287 |
+
raise ValueError(f"No equation for variable {var}")
|
| 288 |
+
|
| 289 |
+
return data
|
| 290 |
+
|
| 291 |
+
def intervene(
|
| 292 |
+
self,
|
| 293 |
+
interventions: List[Intervention],
|
| 294 |
+
n_samples: int = 1000,
|
| 295 |
+
random_state: Optional[int] = None
|
| 296 |
+
) -> Dict[str, np.ndarray]:
|
| 297 |
+
"""
|
| 298 |
+
Simulate interventions using do-operator.
|
| 299 |
+
|
| 300 |
+
Parameters
|
| 301 |
+
----------
|
| 302 |
+
interventions : List[Intervention]
|
| 303 |
+
Interventions to apply
|
| 304 |
+
n_samples : int
|
| 305 |
+
Number of samples
|
| 306 |
+
random_state : Optional[int]
|
| 307 |
+
Random seed
|
| 308 |
+
|
| 309 |
+
Returns
|
| 310 |
+
-------
|
| 311 |
+
Dict[str, np.ndarray]
|
| 312 |
+
Post-intervention data
|
| 313 |
+
"""
|
| 314 |
+
return self.simulate(n_samples, interventions, random_state)
|
| 315 |
+
|
| 316 |
+
def estimate_causal_effect(
|
| 317 |
+
self,
|
| 318 |
+
treatment: str,
|
| 319 |
+
outcome: str,
|
| 320 |
+
n_samples: int = 10000,
|
| 321 |
+
treatment_values: Optional[List[float]] = None
|
| 322 |
+
) -> CausalEffect:
|
| 323 |
+
"""
|
| 324 |
+
Estimate average causal effect of treatment on outcome.
|
| 325 |
+
|
| 326 |
+
Parameters
|
| 327 |
+
----------
|
| 328 |
+
treatment : str
|
| 329 |
+
Treatment variable
|
| 330 |
+
outcome : str
|
| 331 |
+
Outcome variable
|
| 332 |
+
n_samples : int
|
| 333 |
+
Number of simulation samples
|
| 334 |
+
treatment_values : Optional[List[float]]
|
| 335 |
+
Treatment values to compare (default [0, 1])
|
| 336 |
+
|
| 337 |
+
Returns
|
| 338 |
+
-------
|
| 339 |
+
CausalEffect
|
| 340 |
+
Estimated causal effect
|
| 341 |
+
"""
|
| 342 |
+
if treatment_values is None:
|
| 343 |
+
treatment_values = [0.0, 1.0]
|
| 344 |
+
|
| 345 |
+
# Simulate under different treatment values
|
| 346 |
+
outcomes = []
|
| 347 |
+
for t_val in treatment_values:
|
| 348 |
+
intervention = Intervention(variable=treatment, value=t_val)
|
| 349 |
+
data = self.intervene([intervention], n_samples)
|
| 350 |
+
outcomes.append(np.mean(data[outcome]))
|
| 351 |
+
|
| 352 |
+
# Average causal effect
|
| 353 |
+
ace = outcomes[1] - outcomes[0]
|
| 354 |
+
|
| 355 |
+
# Bootstrap for standard error
|
| 356 |
+
bootstrap_effects = []
|
| 357 |
+
for _ in range(100):
|
| 358 |
+
boot_outcomes = []
|
| 359 |
+
for t_val in treatment_values:
|
| 360 |
+
intervention = Intervention(variable=treatment, value=t_val)
|
| 361 |
+
data = self.intervene([intervention], n_samples=1000)
|
| 362 |
+
boot_outcomes.append(np.mean(data[outcome]))
|
| 363 |
+
bootstrap_effects.append(boot_outcomes[1] - boot_outcomes[0])
|
| 364 |
+
|
| 365 |
+
std_error = np.std(bootstrap_effects)
|
| 366 |
+
ci = (
|
| 367 |
+
ace - 1.96 * std_error,
|
| 368 |
+
ace + 1.96 * std_error
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
return CausalEffect(
|
| 372 |
+
treatment=treatment,
|
| 373 |
+
outcome=outcome,
|
| 374 |
+
effect=ace,
|
| 375 |
+
std_error=std_error,
|
| 376 |
+
confidence_interval=ci,
|
| 377 |
+
identification_strategy=IdentificationStrategy.STRUCTURAL_EQUATIONS
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
def counterfactual_query(
|
| 381 |
+
self,
|
| 382 |
+
query: Counterfactual,
|
| 383 |
+
n_samples: int = 10000
|
| 384 |
+
) -> Dict[str, Any]:
|
| 385 |
+
"""
|
| 386 |
+
Answer counterfactual query.
|
| 387 |
+
|
| 388 |
+
Three-step process:
|
| 389 |
+
1. Abduction: Infer exogenous variables from observations
|
| 390 |
+
2. Action: Apply intervention
|
| 391 |
+
3. Prediction: Compute outcome
|
| 392 |
+
|
| 393 |
+
Parameters
|
| 394 |
+
----------
|
| 395 |
+
query : Counterfactual
|
| 396 |
+
Counterfactual query
|
| 397 |
+
n_samples : int
|
| 398 |
+
Number of samples for approximation
|
| 399 |
+
|
| 400 |
+
Returns
|
| 401 |
+
-------
|
| 402 |
+
Dict[str, Any]
|
| 403 |
+
Counterfactual results
|
| 404 |
+
"""
|
| 405 |
+
# Simplified counterfactual reasoning
|
| 406 |
+
# Full implementation would do proper abduction step
|
| 407 |
+
|
| 408 |
+
# For now, simulate with intervention
|
| 409 |
+
data = self.intervene([query.intervention], n_samples)
|
| 410 |
+
|
| 411 |
+
return {
|
| 412 |
+
'query': str(query),
|
| 413 |
+
'expected_value': float(np.mean(data[query.query_variable])),
|
| 414 |
+
'std': float(np.std(data[query.query_variable])),
|
| 415 |
+
'median': float(np.median(data[query.query_variable])),
|
| 416 |
+
'quantiles': {
|
| 417 |
+
'5%': float(np.quantile(data[query.query_variable], 0.05)),
|
| 418 |
+
'25%': float(np.quantile(data[query.query_variable], 0.25)),
|
| 419 |
+
'75%': float(np.quantile(data[query.query_variable], 0.75)),
|
| 420 |
+
'95%': float(np.quantile(data[query.query_variable], 0.95)),
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
def find_backdoor_paths(self, treatment: str, outcome: str) -> List[List[str]]:
|
| 425 |
+
"""
|
| 426 |
+
Find backdoor paths from treatment to outcome.
|
| 427 |
+
|
| 428 |
+
Parameters
|
| 429 |
+
----------
|
| 430 |
+
treatment : str
|
| 431 |
+
Treatment variable
|
| 432 |
+
outcome : str
|
| 433 |
+
Outcome variable
|
| 434 |
+
|
| 435 |
+
Returns
|
| 436 |
+
-------
|
| 437 |
+
List[List[str]]
|
| 438 |
+
List of backdoor paths
|
| 439 |
+
"""
|
| 440 |
+
# Create undirected version of graph
|
| 441 |
+
undirected = self.graph.to_undirected()
|
| 442 |
+
|
| 443 |
+
# Find all paths
|
| 444 |
+
try:
|
| 445 |
+
all_paths = list(nx.all_simple_paths(undirected, treatment, outcome))
|
| 446 |
+
except nx.NodeNotFound:
|
| 447 |
+
return []
|
| 448 |
+
|
| 449 |
+
# Filter for backdoor paths (paths that go through parent of treatment)
|
| 450 |
+
backdoor_paths = []
|
| 451 |
+
treatment_parents = set(self.graph.predecessors(treatment))
|
| 452 |
+
|
| 453 |
+
for path in all_paths:
|
| 454 |
+
# Check if path starts with an edge into treatment
|
| 455 |
+
if len(path) > 1 and path[1] in treatment_parents:
|
| 456 |
+
backdoor_paths.append(path)
|
| 457 |
+
|
| 458 |
+
return backdoor_paths
|
| 459 |
+
|
| 460 |
+
def find_backdoor_adjustment_set(
|
| 461 |
+
self,
|
| 462 |
+
treatment: str,
|
| 463 |
+
outcome: str
|
| 464 |
+
) -> Optional[Set[str]]:
|
| 465 |
+
"""
|
| 466 |
+
Find minimal backdoor adjustment set.
|
| 467 |
+
|
| 468 |
+
Parameters
|
| 469 |
+
----------
|
| 470 |
+
treatment : str
|
| 471 |
+
Treatment variable
|
| 472 |
+
outcome : str
|
| 473 |
+
Outcome variable
|
| 474 |
+
|
| 475 |
+
Returns
|
| 476 |
+
-------
|
| 477 |
+
Optional[Set[str]]
|
| 478 |
+
Backdoor adjustment set, or None if no valid set exists
|
| 479 |
+
"""
|
| 480 |
+
backdoor_paths = self.find_backdoor_paths(treatment, outcome)
|
| 481 |
+
|
| 482 |
+
if not backdoor_paths:
|
| 483 |
+
return set() # No backdoor paths, empty set suffices
|
| 484 |
+
|
| 485 |
+
# Find minimal set that blocks all backdoor paths
|
| 486 |
+
# This is simplified - full implementation would use
|
| 487 |
+
# proper d-separation testing
|
| 488 |
+
|
| 489 |
+
# Collect all variables in backdoor paths
|
| 490 |
+
candidates = set()
|
| 491 |
+
for path in backdoor_paths:
|
| 492 |
+
candidates.update(path[1:-1]) # Exclude treatment and outcome
|
| 493 |
+
|
| 494 |
+
# Remove descendants of treatment (would create bias)
|
| 495 |
+
treatment_descendants = nx.descendants(self.graph, treatment)
|
| 496 |
+
candidates -= treatment_descendants
|
| 497 |
+
|
| 498 |
+
return candidates
|
| 499 |
+
|
| 500 |
+
def plot_graph(self, filename: Optional[str] = None) -> None:
|
| 501 |
+
"""
|
| 502 |
+
Plot causal graph.
|
| 503 |
+
|
| 504 |
+
Parameters
|
| 505 |
+
----------
|
| 506 |
+
filename : Optional[str]
|
| 507 |
+
File to save plot to
|
| 508 |
+
"""
|
| 509 |
+
try:
|
| 510 |
+
import matplotlib.pyplot as plt
|
| 511 |
+
|
| 512 |
+
pos = nx.spring_layout(self.graph)
|
| 513 |
+
nx.draw(
|
| 514 |
+
self.graph,
|
| 515 |
+
pos,
|
| 516 |
+
with_labels=True,
|
| 517 |
+
node_color='lightblue',
|
| 518 |
+
node_size=1500,
|
| 519 |
+
font_size=10,
|
| 520 |
+
font_weight='bold',
|
| 521 |
+
arrows=True,
|
| 522 |
+
arrowsize=20
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if filename:
|
| 526 |
+
plt.savefig(filename)
|
| 527 |
+
else:
|
| 528 |
+
plt.show()
|
| 529 |
+
|
| 530 |
+
except ImportError:
|
| 531 |
+
print("matplotlib not available for plotting")
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def estimate_causal_effect(
|
| 535 |
+
scm: StructuralCausalModel,
|
| 536 |
+
treatment: str,
|
| 537 |
+
outcome: str,
|
| 538 |
+
adjustment_set: Optional[Set[str]] = None,
|
| 539 |
+
n_samples: int = 10000
|
| 540 |
+
) -> CausalEffect:
|
| 541 |
+
"""
|
| 542 |
+
Estimate causal effect using appropriate identification strategy.
|
| 543 |
+
|
| 544 |
+
Parameters
|
| 545 |
+
----------
|
| 546 |
+
scm : StructuralCausalModel
|
| 547 |
+
Structural causal model
|
| 548 |
+
treatment : str
|
| 549 |
+
Treatment variable
|
| 550 |
+
outcome : str
|
| 551 |
+
Outcome variable
|
| 552 |
+
adjustment_set : Optional[Set[str]]
|
| 553 |
+
Variables to adjust for (if None, use backdoor criterion)
|
| 554 |
+
n_samples : int
|
| 555 |
+
Number of samples
|
| 556 |
+
|
| 557 |
+
Returns
|
| 558 |
+
-------
|
| 559 |
+
CausalEffect
|
| 560 |
+
Estimated causal effect
|
| 561 |
+
"""
|
| 562 |
+
# Find adjustment set if not provided
|
| 563 |
+
if adjustment_set is None:
|
| 564 |
+
adjustment_set = scm.find_backdoor_adjustment_set(treatment, outcome)
|
| 565 |
+
|
| 566 |
+
if adjustment_set is None:
|
| 567 |
+
# Try frontdoor or IV
|
| 568 |
+
raise ValueError("Cannot identify causal effect - no valid adjustment set")
|
| 569 |
+
|
| 570 |
+
# Use structural equations for direct estimation
|
| 571 |
+
return scm.estimate_causal_effect(treatment, outcome, n_samples)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
# ============================================================================
|
| 575 |
+
# Predefined SCMs for Geopolitical Analysis
|
| 576 |
+
# ============================================================================
|
| 577 |
+
|
| 578 |
+
def create_sanctions_scm() -> StructuralCausalModel:
|
| 579 |
+
"""
|
| 580 |
+
Create SCM for sanctions analysis.
|
| 581 |
+
|
| 582 |
+
Variables:
|
| 583 |
+
- sanctions: Binary sanctions imposed
|
| 584 |
+
- trade_disruption: Trade flow disruption
|
| 585 |
+
- economic_growth: Economic growth rate
|
| 586 |
+
- regime_stability: Regime stability score
|
| 587 |
+
|
| 588 |
+
Returns
|
| 589 |
+
-------
|
| 590 |
+
StructuralCausalModel
|
| 591 |
+
Sanctions SCM
|
| 592 |
+
"""
|
| 593 |
+
scm = StructuralCausalModel(name="SanctionsSCM")
|
| 594 |
+
|
| 595 |
+
# Exogenous noise (simplified as standard normal)
|
| 596 |
+
noise_dist = lambda n: np.random.randn(n) * 0.1
|
| 597 |
+
|
| 598 |
+
# Trade disruption = f(sanctions) + noise
|
| 599 |
+
scm.add_equation(StructuralEquation(
|
| 600 |
+
variable="trade_disruption",
|
| 601 |
+
parents=["sanctions"],
|
| 602 |
+
function=lambda p: 0.7 * p["sanctions"],
|
| 603 |
+
noise_dist=noise_dist,
|
| 604 |
+
description="Sanctions directly reduce trade"
|
| 605 |
+
))
|
| 606 |
+
|
| 607 |
+
# Economic growth = f(trade_disruption) + noise
|
| 608 |
+
scm.add_equation(StructuralEquation(
|
| 609 |
+
variable="economic_growth",
|
| 610 |
+
parents=["trade_disruption"],
|
| 611 |
+
function=lambda p: 0.05 - 0.4 * p["trade_disruption"],
|
| 612 |
+
noise_dist=noise_dist,
|
| 613 |
+
description="Trade disruption reduces growth"
|
| 614 |
+
))
|
| 615 |
+
|
| 616 |
+
# Regime stability = f(economic_growth) + noise
|
| 617 |
+
scm.add_equation(StructuralEquation(
|
| 618 |
+
variable="regime_stability",
|
| 619 |
+
parents=["economic_growth"],
|
| 620 |
+
function=lambda p: 0.7 + 0.5 * p["economic_growth"],
|
| 621 |
+
noise_dist=noise_dist,
|
| 622 |
+
description="Economic growth affects regime stability"
|
| 623 |
+
))
|
| 624 |
+
|
| 625 |
+
return scm
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def create_conflict_escalation_scm() -> StructuralCausalModel:
|
| 629 |
+
"""
|
| 630 |
+
Create SCM for conflict escalation.
|
| 631 |
+
|
| 632 |
+
Variables:
|
| 633 |
+
- military_buildup: Military force buildup
|
| 634 |
+
- diplomatic_tension: Diplomatic relations tension
|
| 635 |
+
- conflict_risk: Risk of armed conflict
|
| 636 |
+
|
| 637 |
+
Returns
|
| 638 |
+
-------
|
| 639 |
+
StructuralCausalModel
|
| 640 |
+
Conflict escalation SCM
|
| 641 |
+
"""
|
| 642 |
+
scm = StructuralCausalModel(name="ConflictEscalationSCM")
|
| 643 |
+
|
| 644 |
+
noise_dist = lambda n: np.random.randn(n) * 0.05
|
| 645 |
+
|
| 646 |
+
# Diplomatic tension = f(military_buildup) + noise
|
| 647 |
+
scm.add_equation(StructuralEquation(
|
| 648 |
+
variable="diplomatic_tension",
|
| 649 |
+
parents=["military_buildup"],
|
| 650 |
+
function=lambda p: 0.3 + 0.6 * p["military_buildup"],
|
| 651 |
+
noise_dist=noise_dist,
|
| 652 |
+
description="Military buildup increases diplomatic tension"
|
| 653 |
+
))
|
| 654 |
+
|
| 655 |
+
# Conflict risk = f(military_buildup, diplomatic_tension) + noise
|
| 656 |
+
scm.add_equation(StructuralEquation(
|
| 657 |
+
variable="conflict_risk",
|
| 658 |
+
parents=["military_buildup", "diplomatic_tension"],
|
| 659 |
+
function=lambda p: 0.1 + 0.4 * p["military_buildup"] + 0.3 * p["diplomatic_tension"],
|
| 660 |
+
noise_dist=noise_dist,
|
| 661 |
+
description="Both military buildup and tension increase conflict risk"
|
| 662 |
+
))
|
| 663 |
+
|
| 664 |
+
return scm
|
geobot/cli.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
| 1 |
+
def main():
|
| 2 |
+
"""
|
| 3 |
+
Temporary CLI stub for GeoBotv1.
|
| 4 |
+
This just proves the package and entry point are wired up correctly.
|
| 5 |
+
"""
|
| 6 |
+
print("GeoBotv1 CLI is installed and reachable. (Stub CLI for now.)")
|
geobot/config/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for GeoBotv1
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .settings import Settings, get_settings, update_settings
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"Settings",
|
| 9 |
+
"get_settings",
|
| 10 |
+
"update_settings",
|
| 11 |
+
]
|
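A small round-trip sketch for the configuration objects exported here; the YAML file name is arbitrary, and Settings itself is defined in settings.py below.

```python
from geobot.config import Settings

# Override a couple of defaults, persist them to YAML, then read the file back.
settings = Settings(default_n_simulations=5000, log_level="DEBUG")
settings.save("geobot_settings.yaml")

restored = Settings.load("geobot_settings.yaml")
assert restored.default_n_simulations == 5000
assert restored.log_level == "DEBUG"
```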
geobot/config/settings.py
ADDED
|
@@ -0,0 +1,183 @@
|
| 1 |
+
"""
|
| 2 |
+
Settings and configuration for GeoBotv1
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
import yaml
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class Settings:
|
| 13 |
+
"""
|
| 14 |
+
Global settings for GeoBotv1.
|
| 15 |
+
"""
|
| 16 |
+
# Simulation settings
|
| 17 |
+
default_n_simulations: int = 1000
|
| 18 |
+
default_time_horizon: int = 100
|
| 19 |
+
random_seed: Optional[int] = None
|
| 20 |
+
|
| 21 |
+
# Data ingestion settings
|
| 22 |
+
pdf_extraction_method: str = 'auto'
|
| 23 |
+
web_scraping_timeout: int = 30
|
| 24 |
+
article_extraction_method: str = 'auto'
|
| 25 |
+
|
| 26 |
+
# ML settings
|
| 27 |
+
risk_scoring_method: str = 'gradient_boosting'
|
| 28 |
+
embedding_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
|
| 29 |
+
|
| 30 |
+
# Bayesian inference settings
|
| 31 |
+
bayesian_method: str = 'grid'
|
| 32 |
+
n_mcmc_samples: int = 10000
|
| 33 |
+
|
| 34 |
+
# Causal inference settings
|
| 35 |
+
causal_discovery_method: str = 'pc'
|
| 36 |
+
causal_discovery_alpha: float = 0.05
|
| 37 |
+
|
| 38 |
+
# Data directories
|
| 39 |
+
data_dir: str = 'data'
|
| 40 |
+
cache_dir: str = '.cache'
|
| 41 |
+
output_dir: str = 'output'
|
| 42 |
+
|
| 43 |
+
# Logging
|
| 44 |
+
log_level: str = 'INFO'
|
| 45 |
+
log_file: Optional[str] = None
|
| 46 |
+
|
| 47 |
+
# Custom settings
|
| 48 |
+
custom: Dict[str, Any] = field(default_factory=dict)
|
| 49 |
+
|
| 50 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 51 |
+
"""Convert settings to dictionary."""
|
| 52 |
+
return {
|
| 53 |
+
'simulation': {
|
| 54 |
+
'default_n_simulations': self.default_n_simulations,
|
| 55 |
+
'default_time_horizon': self.default_time_horizon,
|
| 56 |
+
'random_seed': self.random_seed
|
| 57 |
+
},
|
| 58 |
+
'data_ingestion': {
|
| 59 |
+
'pdf_extraction_method': self.pdf_extraction_method,
|
| 60 |
+
'web_scraping_timeout': self.web_scraping_timeout,
|
| 61 |
+
'article_extraction_method': self.article_extraction_method
|
| 62 |
+
},
|
| 63 |
+
'ml': {
|
| 64 |
+
'risk_scoring_method': self.risk_scoring_method,
|
| 65 |
+
'embedding_model': self.embedding_model
|
| 66 |
+
},
|
| 67 |
+
'bayesian': {
|
| 68 |
+
'method': self.bayesian_method,
|
| 69 |
+
'n_mcmc_samples': self.n_mcmc_samples
|
| 70 |
+
},
|
| 71 |
+
'causal': {
|
| 72 |
+
'discovery_method': self.causal_discovery_method,
|
| 73 |
+
'discovery_alpha': self.causal_discovery_alpha
|
| 74 |
+
},
|
| 75 |
+
'directories': {
|
| 76 |
+
'data_dir': self.data_dir,
|
| 77 |
+
'cache_dir': self.cache_dir,
|
| 78 |
+
'output_dir': self.output_dir
|
| 79 |
+
},
|
| 80 |
+
'logging': {
|
| 81 |
+
'log_level': self.log_level,
|
| 82 |
+
'log_file': self.log_file
|
| 83 |
+
},
|
| 84 |
+
'custom': self.custom
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'Settings':
|
| 89 |
+
"""Load settings from dictionary."""
|
| 90 |
+
settings = cls()
|
| 91 |
+
|
| 92 |
+
if 'simulation' in data:
|
| 93 |
+
settings.default_n_simulations = data['simulation'].get('default_n_simulations', 1000)
|
| 94 |
+
settings.default_time_horizon = data['simulation'].get('default_time_horizon', 100)
|
| 95 |
+
settings.random_seed = data['simulation'].get('random_seed')
|
| 96 |
+
|
| 97 |
+
if 'data_ingestion' in data:
|
| 98 |
+
settings.pdf_extraction_method = data['data_ingestion'].get('pdf_extraction_method', 'auto')
|
| 99 |
+
settings.web_scraping_timeout = data['data_ingestion'].get('web_scraping_timeout', 30)
|
| 100 |
+
settings.article_extraction_method = data['data_ingestion'].get('article_extraction_method', 'auto')
|
| 101 |
+
|
| 102 |
+
if 'ml' in data:
|
| 103 |
+
settings.risk_scoring_method = data['ml'].get('risk_scoring_method', 'gradient_boosting')
|
| 104 |
+
settings.embedding_model = data['ml'].get('embedding_model', 'sentence-transformers/all-MiniLM-L6-v2')
|
| 105 |
+
|
| 106 |
+
if 'bayesian' in data:
|
| 107 |
+
settings.bayesian_method = data['bayesian'].get('method', 'grid')
|
| 108 |
+
settings.n_mcmc_samples = data['bayesian'].get('n_mcmc_samples', 10000)
|
| 109 |
+
|
| 110 |
+
if 'causal' in data:
|
| 111 |
+
settings.causal_discovery_method = data['causal'].get('discovery_method', 'pc')
|
| 112 |
+
settings.causal_discovery_alpha = data['causal'].get('discovery_alpha', 0.05)
|
| 113 |
+
|
| 114 |
+
if 'directories' in data:
|
| 115 |
+
settings.data_dir = data['directories'].get('data_dir', 'data')
|
| 116 |
+
settings.cache_dir = data['directories'].get('cache_dir', '.cache')
|
| 117 |
+
settings.output_dir = data['directories'].get('output_dir', 'output')
|
| 118 |
+
|
| 119 |
+
if 'logging' in data:
|
| 120 |
+
settings.log_level = data['logging'].get('log_level', 'INFO')
|
| 121 |
+
settings.log_file = data['logging'].get('log_file')
|
| 122 |
+
|
| 123 |
+
if 'custom' in data:
|
| 124 |
+
settings.custom = data['custom']
|
| 125 |
+
|
| 126 |
+
return settings
|
| 127 |
+
|
| 128 |
+
def save(self, path: str) -> None:
|
| 129 |
+
"""Save settings to YAML file."""
|
| 130 |
+
with open(path, 'w') as f:
|
| 131 |
+
yaml.dump(self.to_dict(), f, default_flow_style=False)
|
| 132 |
+
|
| 133 |
+
@classmethod
|
| 134 |
+
def load(cls, path: str) -> 'Settings':
|
| 135 |
+
"""Load settings from YAML file."""
|
| 136 |
+
with open(path, 'r') as f:
|
| 137 |
+
data = yaml.safe_load(f)
|
| 138 |
+
return cls.from_dict(data)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Global settings instance
|
| 142 |
+
_global_settings: Optional[Settings] = None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_settings() -> Settings:
|
| 146 |
+
"""
|
| 147 |
+
Get global settings instance.
|
| 148 |
+
|
| 149 |
+
Returns
|
| 150 |
+
-------
|
| 151 |
+
Settings
|
| 152 |
+
Global settings
|
| 153 |
+
"""
|
| 154 |
+
global _global_settings
|
| 155 |
+
if _global_settings is None:
|
| 156 |
+
_global_settings = Settings()
|
| 157 |
+
return _global_settings
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def update_settings(settings: Settings) -> None:
|
| 161 |
+
"""
|
| 162 |
+
Update global settings.
|
| 163 |
+
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
settings : Settings
|
| 167 |
+
New settings
|
| 168 |
+
"""
|
| 169 |
+
global _global_settings
|
| 170 |
+
_global_settings = settings
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def load_settings_from_file(path: str) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Load settings from file and update global settings.
|
| 176 |
+
|
| 177 |
+
Parameters
|
| 178 |
+
----------
|
| 179 |
+
path : str
|
| 180 |
+
Path to settings file
|
| 181 |
+
"""
|
| 182 |
+
settings = Settings.load(path)
|
| 183 |
+
update_settings(settings)
|
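A minimal usage sketch for the Settings API above, assuming the package is importable as geobot and PyYAML is installed; the config.yaml path is illustrative:

from geobot.config import Settings, get_settings, update_settings

# Start from defaults, override a few fields, and round-trip through YAML.
settings = Settings(default_n_simulations=5000, random_seed=42)
settings.save("config.yaml")            # writes the nested dict built by to_dict()

loaded = Settings.load("config.yaml")   # yaml.safe_load + from_dict
update_settings(loaded)                 # install as the global instance
assert get_settings().default_n_simulations == 5000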
geobot/core/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
"""
Core mathematical frameworks for GeoBotv1
"""

from .optimal_transport import WassersteinDistance, ScenarioComparator
from .scenario import Scenario, ScenarioDistribution
from .advanced_optimal_transport import (
    GradientBasedOT,
    KantorovichDuality,
    EntropicOT,
    UnbalancedOT,
    GromovWassersteinDistance
)

__all__ = [
    "WassersteinDistance",
    "ScenarioComparator",
    "Scenario",
    "ScenarioDistribution",
    "GradientBasedOT",
    "KantorovichDuality",
    "EntropicOT",
    "UnbalancedOT",
    "GromovWassersteinDistance",
]
geobot/core/advanced_optimal_transport.py
ADDED
|
@@ -0,0 +1,653 @@
"""
Advanced Optimal Transport with Gradient-Based Methods

Implements sophisticated optimal transport algorithms:
- Gradient-based OT with automatic differentiation
- Kantorovich duality formulation
- Sinkhorn with entropic regularization (tunable ε)
- Wasserstein barycenters
- Gromov-Wasserstein distance
- Unbalanced optimal transport

Provides geometric insights on the space of probability measures
and leverages gradient flows for OT computations.
"""

import numpy as np
from typing import Tuple, Optional, Callable, List
from scipy.optimize import minimize
from scipy.spatial.distance import cdist

try:
    import ot as pot
    HAS_POT = True
except ImportError:
    HAS_POT = False


class GradientBasedOT:
    """
    Gradient-based optimal transport using automatic differentiation.

    Solves the Monge-Kantorovich problem:
        min_{γ ∈ Π(μ,ν)} ⟨γ, C⟩

    using gradient-based optimization on the dual problem or
    on parametric transport maps.
    """

    def __init__(self, cost_fn: Optional[Callable] = None):
        """
        Initialize gradient-based OT.

        Parameters
        ----------
        cost_fn : callable, optional
            Cost function c(x, y). Defaults to squared Euclidean distance.
        """
        self.cost_fn = cost_fn or (lambda x, y: np.sum((x - y)**2))

    def compute_transport_map(
        self,
        X_source: np.ndarray,
        X_target: np.ndarray,
        method: str = 'gradient',
        reg: float = 0.1,
        n_iter: int = 100
    ) -> Tuple[np.ndarray, float]:
        """
        Compute optimal transport map via gradient descent.

        For Monge formulation: min_T E[c(x, T(x))]

        Parameters
        ----------
        X_source : np.ndarray, shape (n, d)
            Source samples
        X_target : np.ndarray, shape (m, d)
            Target samples
        method : str
            Method ('gradient', 'dual')
        reg : float
            Regularization parameter
        n_iter : int
            Number of iterations

        Returns
        -------
        tuple
            (transport_map, cost)
        """
        n_source, d = X_source.shape
        n_target = X_target.shape[0]

        if method == 'gradient':
            # Parametrize transport map as T(x) = x + displacement
            displacement = np.zeros_like(X_source)

            learning_rate = 0.01

            for iteration in range(n_iter):
                # Compute transported points
                X_transported = X_source + displacement

                # Find nearest neighbors in target (assignment)
                distances = cdist(X_transported, X_target)
                assignments = np.argmin(distances, axis=1)

                # Compute cost
                cost = np.mean([
                    self.cost_fn(X_transported[i], X_target[assignments[i]])
                    for i in range(n_source)
                ])

                # Gradient (simplified - would use autograd)
                gradient = np.zeros_like(displacement)
                for i in range(n_source):
                    gradient[i] = 2 * (X_transported[i] - X_target[assignments[i]])

                # Regularization
                gradient += reg * displacement

                # Update
                displacement -= learning_rate * gradient

                if iteration % 20 == 0:
                    print(f"Iteration {iteration}, Cost: {cost:.6f}")

            return displacement, cost

        elif method == 'dual':
            return self._solve_dual_problem(X_source, X_target, reg)

        else:
            raise ValueError(f"Unknown method: {method}")

    def _solve_dual_problem(
        self,
        X_source: np.ndarray,
        X_target: np.ndarray,
        reg: float
    ) -> Tuple[np.ndarray, float]:
        """
        Solve dual OT problem.

        Dual formulation (Kantorovich):
            max_{f,g} E_μ[f(x)] + E_ν[g(y)]
            s.t. f(x) + g(y) ≤ c(x,y)

        Parameters
        ----------
        X_source : np.ndarray
            Source points
        X_target : np.ndarray
            Target points
        reg : float
            Regularization

        Returns
        -------
        tuple
            (potentials, cost)
        """
        n_source = len(X_source)
        n_target = len(X_target)

        # Cost matrix
        C = cdist(X_source, X_target, metric='sqeuclidean')

        # Solve via constrained optimization
        # Variables: [f (n_source), g (n_target)]
        n_vars = n_source + n_target

        def objective(x):
            f = x[:n_source]
            g = x[n_source:]
            return -(np.mean(f) + np.mean(g))  # Negative for minimization

        def constraint(x):
            f = x[:n_source]
            g = x[n_source:]
            # f(x_i) + g(y_j) - c(x_i, y_j) ≤ 0
            return C - (f[:, np.newaxis] + g[np.newaxis, :])

        from scipy.optimize import NonlinearConstraint
        nlc = NonlinearConstraint(
            lambda x: constraint(x).flatten(),
            -np.inf,
            0
        )

        x0 = np.zeros(n_vars)
        result = minimize(objective, x0, constraints=nlc, method='SLSQP')

        f_opt = result.x[:n_source]
        g_opt = result.x[n_source:]

        cost = -result.fun

        return np.column_stack([f_opt, g_opt]), cost


class KantorovichDuality:
    """
    Kantorovich duality formulation for optimal transport.

    Primal: min_{γ ∈ Π(μ,ν)} ∫∫ c(x,y) dγ(x,y)

    Dual: max_{f,g} ∫ f dμ + ∫ g dν
          s.t. f(x) + g(y) ≤ c(x,y)

    Strong duality holds under mild conditions.
    """

    def __init__(self):
        """Initialize Kantorovich duality solver."""
        pass

    def solve_primal(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray,
        method: str = 'emd'
    ) -> Tuple[np.ndarray, float]:
        """
        Solve primal OT problem.

        Parameters
        ----------
        mu : np.ndarray
            Source distribution weights
        nu : np.ndarray
            Target distribution weights
        C : np.ndarray
            Cost matrix
        method : str
            Method ('emd', 'sinkhorn')

        Returns
        -------
        tuple
            (coupling, cost)
        """
        if not HAS_POT:
            raise ImportError("POT library required")

        if method == 'emd':
            coupling = pot.emd(mu, nu, C)
        elif method == 'sinkhorn':
            coupling = pot.sinkhorn(mu, nu, C, reg=0.1)
        else:
            raise ValueError(f"Unknown method: {method}")

        cost = np.sum(coupling * C)
        return coupling, cost

    def solve_dual(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray,
        max_iter: int = 1000,
        tol: float = 1e-6
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Solve dual OT problem via iterative Bregman projections.

        Parameters
        ----------
        mu : np.ndarray
            Source weights
        nu : np.ndarray
            Target weights
        C : np.ndarray
            Cost matrix
        max_iter : int
            Maximum iterations
        tol : float
            Convergence tolerance

        Returns
        -------
        tuple
            (f, g, dual_value) where f, g are dual potentials
        """
        n = len(mu)
        m = len(nu)

        # Initialize dual variables (Kantorovich potentials)
        f = np.zeros(n)
        g = np.zeros(m)

        for iteration in range(max_iter):
            # Update g (c-transform of f)
            # g(y) = min_x [c(x,y) - f(x)]
            g_new = np.min(C - f[:, np.newaxis], axis=0)

            # Update f (c-transform of g)
            # f(x) = min_y [c(x,y) - g(y)]
            f_new = np.min(C - g_new[np.newaxis, :], axis=1)

            # Check convergence
            if np.max(np.abs(f_new - f)) < tol and np.max(np.abs(g_new - g)) < tol:
                break

            f, g = f_new, g_new

        # Dual value
        dual_value = np.dot(f, mu) + np.dot(g, nu)

        return f, g, dual_value

    def verify_duality_gap(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray
    ) -> float:
        """
        Verify strong duality: primal_cost - dual_cost ≈ 0.

        Parameters
        ----------
        mu : np.ndarray
            Source weights
        nu : np.ndarray
            Target weights
        C : np.ndarray
            Cost matrix

        Returns
        -------
        float
            Duality gap
        """
        # Solve primal
        coupling, primal_cost = self.solve_primal(mu, nu, C)

        # Solve dual
        f, g, dual_value = self.solve_dual(mu, nu, C)

        gap = primal_cost - dual_value
        return gap


class EntropicOT:
    """
    Entropic optimal transport with Sinkhorn algorithm.

    Regularized OT:
        min_{γ ∈ Π(μ,ν)} ⟨γ, C⟩ + ε H(γ)

    where H(γ) = - ∑_ij γ_ij log γ_ij is entropy.

    Sinkhorn iterations converge geometrically fast.
    """

    def __init__(self, epsilon: float = 0.1):
        """
        Initialize entropic OT.

        Parameters
        ----------
        epsilon : float
            Entropic regularization parameter
        """
        self.epsilon = epsilon

    def sinkhorn(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray,
        max_iter: int = 1000,
        tol: float = 1e-6
    ) -> Tuple[np.ndarray, float]:
        """
        Sinkhorn algorithm for entropic OT.

        Parameters
        ----------
        mu : np.ndarray
            Source distribution
        nu : np.ndarray
            Target distribution
        C : np.ndarray
            Cost matrix
        max_iter : int
            Maximum iterations
        tol : float
            Convergence tolerance

        Returns
        -------
        tuple
            (coupling, cost)
        """
        n, m = C.shape

        # Kernel matrix
        K = np.exp(-C / self.epsilon)

        # Initialize scaling vectors
        u = np.ones(n) / n
        v = np.ones(m) / m

        for iteration in range(max_iter):
            u_prev = u.copy()

            # Update u
            u = mu / (K @ v)

            # Update v
            v = nu / (K.T @ u)

            # Check convergence
            if np.max(np.abs(u - u_prev)) < tol:
                break

        # Compute coupling
        coupling = u[:, np.newaxis] * K * v[np.newaxis, :]

        # Compute cost
        cost = np.sum(coupling * C)

        return coupling, cost

    def sinkhorn_log_stabilized(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray,
        max_iter: int = 1000,
        tol: float = 1e-6
    ) -> Tuple[np.ndarray, float]:
        """
        Log-stabilized Sinkhorn algorithm (more numerically stable).

        Works in log-domain to avoid overflow/underflow.

        Parameters
        ----------
        mu : np.ndarray
            Source distribution
        nu : np.ndarray
            Target distribution
        C : np.ndarray
            Cost matrix
        max_iter : int
            Maximum iterations
        tol : float
            Tolerance

        Returns
        -------
        tuple
            (coupling, cost)
        """
        n, m = C.shape

        # Log-domain variables
        log_mu = np.log(mu)
        log_nu = np.log(nu)

        # Initialize
        f = np.zeros(n)
        g = np.zeros(m)

        for iteration in range(max_iter):
            f_prev = f.copy()

            # Update f
            f = -self.epsilon * self._log_sum_exp(
                (g[np.newaxis, :] - C) / self.epsilon,
                axis=1
            ) + log_mu + self.epsilon * np.log(n)

            # Update g
            g = -self.epsilon * self._log_sum_exp(
                (f[:, np.newaxis] - C) / self.epsilon,
                axis=0
            ) + log_nu + self.epsilon * np.log(m)

            # Check convergence
            if np.max(np.abs(f - f_prev)) < tol:
                break

        # Compute coupling (in log domain, then exp)
        log_coupling = (f[:, np.newaxis] + g[np.newaxis, :] - C) / self.epsilon
        coupling = np.exp(log_coupling)

        # Normalize
        coupling = coupling / np.sum(coupling)

        cost = np.sum(coupling * C)

        return coupling, cost

    def _log_sum_exp(self, X: np.ndarray, axis: int) -> np.ndarray:
        """
        Numerically stable log-sum-exp.

        log(∑ exp(X_i)) computed stably.

        Parameters
        ----------
        X : np.ndarray
            Input array
        axis : int
            Axis to sum over

        Returns
        -------
        np.ndarray
            log-sum-exp result
        """
        max_X = np.max(X, axis=axis, keepdims=True)
        return np.log(np.sum(np.exp(X - max_X), axis=axis)) + np.squeeze(max_X, axis=axis)


class UnbalancedOT:
    """
    Unbalanced optimal transport.

    Relaxes the marginal constraints, allowing mass creation/destruction.
    Useful when distributions don't have same total mass.

    Formulation:
        min_γ ⟨γ, C⟩ + ε H(γ) + τ KL(γ1_m | μ) + τ KL(γ^T 1_n | ν)

    where KL is Kullback-Leibler divergence.
    """

    def __init__(self, epsilon: float = 0.1, tau: float = 0.1):
        """
        Initialize unbalanced OT.

        Parameters
        ----------
        epsilon : float
            Entropic regularization
        tau : float
            Marginal relaxation parameter
        """
        self.epsilon = epsilon
        self.tau = tau

    def solve(
        self,
        mu: np.ndarray,
        nu: np.ndarray,
        C: np.ndarray,
        max_iter: int = 1000
    ) -> Tuple[np.ndarray, float]:
        """
        Solve unbalanced OT via Sinkhorn-Knopp algorithm.

        Parameters
        ----------
        mu : np.ndarray
            Source (unbalanced) distribution
        nu : np.ndarray
            Target (unbalanced) distribution
        C : np.ndarray
            Cost matrix
        max_iter : int
            Maximum iterations

        Returns
        -------
        tuple
            (coupling, cost)
        """
        if not HAS_POT:
            raise ImportError("POT library required for unbalanced OT")

        # Use POT's unbalanced solver
        coupling = pot.unbalanced.sinkhorn_unbalanced(
            mu, nu, C,
            reg=self.epsilon,
            reg_m=self.tau,
            numItermax=max_iter
        )

        cost = np.sum(coupling * C)

        return coupling, cost


class GromovWassersteinDistance:
    """
    Gromov-Wasserstein distance for comparing metric spaces.

    Useful when we want to compare structures (graphs, networks)
    rather than points in a common space.
    """

    def __init__(self, epsilon: float = 0.1):
        """
        Initialize Gromov-Wasserstein.

        Parameters
        ----------
        epsilon : float
            Entropic regularization
        """
        self.epsilon = epsilon

    def compute(
        self,
        C1: np.ndarray,
        C2: np.ndarray,
        p: Optional[np.ndarray] = None,
        q: Optional[np.ndarray] = None,
        loss_fun: str = 'square_loss',
        max_iter: int = 100
    ) -> Tuple[np.ndarray, float]:
        """
        Compute Gromov-Wasserstein distance.

        Parameters
        ----------
        C1 : np.ndarray
            Intra-space cost matrix for space 1
        C2 : np.ndarray
            Intra-space cost matrix for space 2
        p : np.ndarray, optional
            Distribution on space 1
        q : np.ndarray, optional
            Distribution on space 2
        loss_fun : str
            Loss function
        max_iter : int
            Maximum iterations

        Returns
        -------
        tuple
            (coupling, GW_distance)
        """
        if not HAS_POT:
            raise ImportError("POT library required for Gromov-Wasserstein")

        n1, n2 = C1.shape[0], C2.shape[0]

        # Default uniform distributions
        if p is None:
            p = np.ones(n1) / n1
        if q is None:
            q = np.ones(n2) / n2

        # Compute Gromov-Wasserstein
        gw_dist, log = pot.gromov.entropic_gromov_wasserstein(
            C1, C2, p, q,
            loss_fun=loss_fun,
            epsilon=self.epsilon,
            max_iter=max_iter,
            log=True
        )

        coupling = log['T']

        return coupling, gw_dist
geobot/core/optimal_transport.py
ADDED
|
@@ -0,0 +1,360 @@
"""
Optimal Transport Module - Wasserstein Distances

Provides geometric measures of how much "effort" is needed to move from
one geopolitical scenario to another using optimal transport theory.

Applications:
- Measure regime shifts
- Compare distributions of Monte Carlo futures
- Quantify shock impact
- Measure closeness of geopolitical scenarios
- Detect structural change
- Logistics modeling
"""

import numpy as np
from typing import Union, Tuple, Optional, List
from scipy.spatial.distance import cdist
from scipy.stats import wasserstein_distance

try:
    import ot  # Python Optimal Transport library
    HAS_POT = True
except ImportError:
    HAS_POT = False
    print("Warning: POT library not available. Some features will be limited.")


class WassersteinDistance:
    """
    Compute Wasserstein distances between probability distributions.

    The Wasserstein distance (also known as Earth Mover's Distance) provides
    a principled way to measure the distance between probability distributions,
    accounting for the geometry of the underlying space.
    """

    def __init__(self, metric: str = 'euclidean', p: int = 2):
        """
        Initialize Wasserstein distance calculator.

        Parameters
        ----------
        metric : str
            Distance metric to use for ground distance ('euclidean', 'cityblock', etc.)
        p : int
            Order of Wasserstein distance (1 or 2)
        """
        self.metric = metric
        self.p = p

    def compute_1d(
        self,
        u_values: np.ndarray,
        v_values: np.ndarray,
        u_weights: Optional[np.ndarray] = None,
        v_weights: Optional[np.ndarray] = None
    ) -> float:
        """
        Compute 1D Wasserstein distance between two distributions.

        Parameters
        ----------
        u_values : np.ndarray
            Values for first distribution
        v_values : np.ndarray
            Values for second distribution
        u_weights : np.ndarray, optional
            Weights for first distribution (defaults to uniform)
        v_weights : np.ndarray, optional
            Weights for second distribution (defaults to uniform)

        Returns
        -------
        float
            Wasserstein distance
        """
        return wasserstein_distance(u_values, v_values, u_weights, v_weights)

    def compute_nd(
        self,
        X_source: np.ndarray,
        X_target: np.ndarray,
        a: Optional[np.ndarray] = None,
        b: Optional[np.ndarray] = None,
        method: str = 'sinkhorn'
    ) -> float:
        """
        Compute n-dimensional Wasserstein distance.

        Parameters
        ----------
        X_source : np.ndarray, shape (n_samples_source, n_features)
            Source distribution samples
        X_target : np.ndarray, shape (n_samples_target, n_features)
            Target distribution samples
        a : np.ndarray, optional
            Weights for source distribution (defaults to uniform)
        b : np.ndarray, optional
            Weights for target distribution (defaults to uniform)
        method : str
            Method to use ('sinkhorn', 'emd', 'emd2')

        Returns
        -------
        float
            Wasserstein distance
        """
        if not HAS_POT:
            raise ImportError("POT library required for n-dimensional distances")

        n_source = X_source.shape[0]
        n_target = X_target.shape[0]

        # Default to uniform distributions
        if a is None:
            a = np.ones(n_source) / n_source
        if b is None:
            b = np.ones(n_target) / n_target

        # Compute cost matrix
        M = cdist(X_source, X_target, metric=self.metric)

        # Compute optimal transport
        if method == 'sinkhorn':
            # Sinkhorn algorithm (faster, approximate)
            distance = ot.sinkhorn2(a, b, M, reg=0.1)
        elif method == 'emd':
            # Exact EMD
            distance = ot.emd2(a, b, M)
        elif method == 'emd2':
            # Squared EMD
            distance = ot.emd2(a, b, M**2)
        else:
            raise ValueError(f"Unknown method: {method}")

        return float(distance)

    def compute_barycenter(
        self,
        distributions: List[np.ndarray],
        weights: Optional[np.ndarray] = None,
        method: str = 'sinkhorn'
    ) -> np.ndarray:
        """
        Compute Wasserstein barycenter of multiple distributions.

        This finds the "average" distribution in Wasserstein space.

        Parameters
        ----------
        distributions : list of np.ndarray
            List of distributions to average
        weights : np.ndarray, optional
            Weights for each distribution
        method : str
            Method to use ('sinkhorn')

        Returns
        -------
        np.ndarray
            Wasserstein barycenter
        """
        if not HAS_POT:
            raise ImportError("POT library required for barycenter computation")

        n_distributions = len(distributions)

        if weights is None:
            weights = np.ones(n_distributions) / n_distributions

        # Stack distributions
        A = np.column_stack(distributions)

        # Compute barycenter
        if method == 'sinkhorn':
            barycenter = ot.bregman.barycenter(A, M=None, reg=0.1, weights=weights)
        else:
            raise ValueError(f"Unknown method: {method}")

        return barycenter


class ScenarioComparator:
    """
    Compare geopolitical scenarios using optimal transport.

    This class provides high-level methods for comparing scenarios,
    detecting regime shifts, and quantifying shock impacts.
    """

    def __init__(self, metric: str = 'euclidean'):
        """
        Initialize scenario comparator.

        Parameters
        ----------
        metric : str
            Distance metric for ground distance
        """
        self.wasserstein = WassersteinDistance(metric=metric)

    def compare_scenarios(
        self,
        scenario1: np.ndarray,
        scenario2: np.ndarray,
        weights1: Optional[np.ndarray] = None,
        weights2: Optional[np.ndarray] = None
    ) -> float:
        """
        Compare two geopolitical scenarios.

        Parameters
        ----------
        scenario1 : np.ndarray
            First scenario (features x samples)
        scenario2 : np.ndarray
            Second scenario (features x samples)
        weights1 : np.ndarray, optional
            Weights for first scenario
        weights2 : np.ndarray, optional
            Weights for second scenario

        Returns
        -------
        float
            Distance between scenarios
        """
        return self.wasserstein.compute_nd(scenario1, scenario2, weights1, weights2)

    def detect_regime_shift(
        self,
        baseline: np.ndarray,
        current: np.ndarray,
        threshold: float = 0.1
    ) -> Tuple[bool, float]:
        """
        Detect if a regime shift has occurred.

        Parameters
        ----------
        baseline : np.ndarray
            Baseline scenario distribution
        current : np.ndarray
            Current scenario distribution
        threshold : float
            Threshold for detecting shift

        Returns
        -------
        tuple
            (shift_detected, distance)
        """
        distance = self.compare_scenarios(baseline, current)
        shift_detected = distance > threshold

        return shift_detected, distance

    def quantify_shock_impact(
        self,
        pre_shock: np.ndarray,
        post_shock: np.ndarray
    ) -> dict:
        """
        Quantify the impact of a shock event.

        Parameters
        ----------
        pre_shock : np.ndarray
            Pre-shock scenario distribution
        post_shock : np.ndarray
            Post-shock scenario distribution

        Returns
        -------
        dict
            Dictionary with impact metrics
        """
        distance = self.compare_scenarios(pre_shock, post_shock)

        # Compute additional metrics
        mean_shift = np.linalg.norm(np.mean(post_shock, axis=0) - np.mean(pre_shock, axis=0))
        variance_change = np.abs(np.var(post_shock) - np.var(pre_shock))

        return {
            'wasserstein_distance': distance,
            'mean_shift': mean_shift,
            'variance_change': variance_change,
            'impact_magnitude': distance * mean_shift
        }

    def compute_scenario_trajectory(
        self,
        scenarios: List[np.ndarray]
    ) -> np.ndarray:
        """
        Compute trajectory of scenarios over time.

        Parameters
        ----------
        scenarios : list of np.ndarray
            Time series of scenarios

        Returns
        -------
        np.ndarray
            Array of distances between consecutive scenarios
        """
        n_scenarios = len(scenarios)
        distances = np.zeros(n_scenarios - 1)

        for i in range(n_scenarios - 1):
            distances[i] = self.compare_scenarios(scenarios[i], scenarios[i + 1])

        return distances

    def logistics_optimal_transport(
        self,
        supply: np.ndarray,
        demand: np.ndarray,
        supply_locations: np.ndarray,
        demand_locations: np.ndarray
    ) -> Tuple[np.ndarray, float]:
        """
        Solve logistics problem using optimal transport.

        Parameters
        ----------
        supply : np.ndarray
            Supply amounts at each location
        demand : np.ndarray
            Demand amounts at each location
        supply_locations : np.ndarray
            Coordinates of supply locations
        demand_locations : np.ndarray
            Coordinates of demand locations

        Returns
        -------
        tuple
            (transport_plan, total_cost)
        """
        if not HAS_POT:
            raise ImportError("POT library required for logistics optimization")

        # Normalize supply and demand
        supply_norm = supply / supply.sum()
        demand_norm = demand / demand.sum()

        # Compute cost matrix (distances)
        M = cdist(supply_locations, demand_locations, metric=self.wasserstein.metric)

        # Compute optimal transport plan
        transport_plan = ot.emd(supply_norm, demand_norm, M)
        total_cost = np.sum(transport_plan * M)

        # Scale back to original quantities
        transport_plan *= supply.sum()

        return transport_plan, total_cost
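An illustrative regime-shift check using the ScenarioComparator above, assuming the POT library is installed (compute_nd requires it); the sample data are synthetic:

import numpy as np
from geobot.core import ScenarioComparator

rng = np.random.default_rng(0)
baseline = rng.normal(0.0, 1.0, size=(200, 3))   # Monte Carlo samples over 3 indicators
current = rng.normal(0.5, 1.2, size=(200, 3))    # shifted, more dispersed regime

comparator = ScenarioComparator(metric='euclidean')
shifted, distance = comparator.detect_regime_shift(baseline, current, threshold=0.3)
print(f"regime shift: {shifted}, Wasserstein distance: {distance:.3f}")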
geobot/core/scenario.py
ADDED
|
@@ -0,0 +1,285 @@
"""
Scenario representation and management for geopolitical modeling.
"""

import numpy as np
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class Scenario:
    """
    Represents a geopolitical scenario with multiple features and metadata.

    Attributes
    ----------
    name : str
        Name or identifier for the scenario
    features : Dict[str, np.ndarray]
        Dictionary of feature names to values
    timestamp : datetime
        Timestamp of the scenario
    metadata : Dict[str, Any]
        Additional metadata
    probability : float
        Probability or weight of this scenario (for ensembles)
    """
    name: str
    features: Dict[str, np.ndarray]
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)
    probability: float = 1.0

    def get_feature_vector(self, feature_names: Optional[List[str]] = None) -> np.ndarray:
        """
        Get features as a vector.

        Parameters
        ----------
        feature_names : List[str], optional
            List of feature names to include (if None, use all)

        Returns
        -------
        np.ndarray
            Feature vector
        """
        if feature_names is None:
            feature_names = list(self.features.keys())

        vectors = [self.features[name].flatten() for name in feature_names if name in self.features]
        return np.concatenate(vectors)

    def get_feature_matrix(self) -> np.ndarray:
        """
        Get all features as a matrix.

        Returns
        -------
        np.ndarray
            Feature matrix (n_features, ...)
        """
        return np.array([v for v in self.features.values()])

    def add_feature(self, name: str, values: np.ndarray) -> None:
        """
        Add a new feature to the scenario.

        Parameters
        ----------
        name : str
            Feature name
        values : np.ndarray
            Feature values
        """
        self.features[name] = values

    def remove_feature(self, name: str) -> None:
        """
        Remove a feature from the scenario.

        Parameters
        ----------
        name : str
            Feature name to remove
        """
        if name in self.features:
            del self.features[name]

    def clone(self) -> 'Scenario':
        """
        Create a deep copy of the scenario.

        Returns
        -------
        Scenario
            Cloned scenario
        """
        return Scenario(
            name=self.name,
            features={k: v.copy() for k, v in self.features.items()},
            timestamp=self.timestamp,
            metadata=self.metadata.copy(),
            probability=self.probability
        )


class ScenarioDistribution:
    """
    Represents a distribution over multiple scenarios.

    This is useful for Monte Carlo simulations, ensemble forecasting,
    and probabilistic reasoning.
    """

    def __init__(self, scenarios: Optional[List[Scenario]] = None):
        """
        Initialize scenario distribution.

        Parameters
        ----------
        scenarios : List[Scenario], optional
            Initial list of scenarios
        """
        self.scenarios: List[Scenario] = scenarios if scenarios is not None else []

    def add_scenario(self, scenario: Scenario) -> None:
        """
        Add a scenario to the distribution.

        Parameters
        ----------
        scenario : Scenario
            Scenario to add
        """
        self.scenarios.append(scenario)

    def get_probabilities(self) -> np.ndarray:
        """
        Get probabilities of all scenarios.

        Returns
        -------
        np.ndarray
            Array of probabilities
        """
        probs = np.array([s.probability for s in self.scenarios])
        # Normalize
        return probs / probs.sum()

    def normalize_probabilities(self) -> None:
        """
        Normalize scenario probabilities to sum to 1.
        """
        total_prob = sum(s.probability for s in self.scenarios)
        for scenario in self.scenarios:
            scenario.probability /= total_prob

    def get_feature_samples(self, feature_names: Optional[List[str]] = None) -> np.ndarray:
        """
        Get feature samples from all scenarios.

        Parameters
        ----------
        feature_names : List[str], optional
            List of feature names to include

        Returns
        -------
        np.ndarray
            Feature samples (n_scenarios, n_features)
        """
        samples = [s.get_feature_vector(feature_names) for s in self.scenarios]
        return np.array(samples)

    def get_weighted_mean(self, feature_names: Optional[List[str]] = None) -> np.ndarray:
        """
        Compute weighted mean of features.

        Parameters
        ----------
        feature_names : List[str], optional
            List of feature names to include

        Returns
        -------
        np.ndarray
            Weighted mean feature vector
        """
        samples = self.get_feature_samples(feature_names)
        probs = self.get_probabilities()
        return np.average(samples, axis=0, weights=probs)

    def get_variance(self, feature_names: Optional[List[str]] = None) -> np.ndarray:
        """
        Compute variance of features.

        Parameters
        ----------
        feature_names : List[str], optional
            List of feature names to include

        Returns
        -------
        np.ndarray
            Variance of features
        """
        samples = self.get_feature_samples(feature_names)
        probs = self.get_probabilities()
        mean = self.get_weighted_mean(feature_names)

        variance = np.average((samples - mean) ** 2, axis=0, weights=probs)
        return variance

    def sample(self, n_samples: int = 1, replace: bool = True) -> List[Scenario]:
        """
        Sample scenarios from the distribution.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw
        replace : bool
            Whether to sample with replacement

        Returns
        -------
        List[Scenario]
            Sampled scenarios
        """
        probs = self.get_probabilities()
        indices = np.random.choice(
            len(self.scenarios),
            size=n_samples,
            replace=replace,
            p=probs
        )
        return [self.scenarios[i] for i in indices]

    def filter_by_probability(self, threshold: float) -> 'ScenarioDistribution':
        """
        Filter scenarios by probability threshold.

        Parameters
        ----------
        threshold : float
            Minimum probability threshold

        Returns
        -------
        ScenarioDistribution
            New distribution with filtered scenarios
        """
        filtered_scenarios = [s for s in self.scenarios if s.probability >= threshold]
        return ScenarioDistribution(filtered_scenarios)

    def get_top_k(self, k: int) -> 'ScenarioDistribution':
        """
        Get top k scenarios by probability.

        Parameters
        ----------
        k : int
            Number of scenarios to return

        Returns
        -------
        ScenarioDistribution
            Distribution with top k scenarios
        """
        sorted_scenarios = sorted(self.scenarios, key=lambda s: s.probability, reverse=True)
        return ScenarioDistribution(sorted_scenarios[:k])

    def __len__(self) -> int:
        """Return number of scenarios."""
        return len(self.scenarios)

    def __getitem__(self, idx: int) -> Scenario:
        """Get scenario by index."""
        return self.scenarios[idx]

    def __iter__(self):
        """Iterate over scenarios."""
        return iter(self.scenarios)
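A short sketch of the Scenario and ScenarioDistribution classes above; the scenario names, feature values, and probabilities are invented for illustration:

import numpy as np
from geobot.core import Scenario, ScenarioDistribution

dist = ScenarioDistribution()
dist.add_scenario(Scenario(name="status_quo",
                           features={"tension": np.array([0.2]), "trade": np.array([1.0])},
                           probability=0.7))
dist.add_scenario(Scenario(name="escalation",
                           features={"tension": np.array([0.9]), "trade": np.array([0.4])},
                           probability=0.3))

print(dist.get_weighted_mean())     # probability-weighted feature vector
top = dist.get_top_k(1)             # keep only the most likely scenario
draws = dist.sample(n_samples=5)    # Monte Carlo draws over scenarios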
geobot/data_ingestion/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
"""
Data ingestion modules for GeoBotv1

Support for:
- PDF document reading and processing
- Web scraping and article extraction
- News feed ingestion
- Structured event extraction
- Event database and temporal normalization
"""

from .pdf_reader import PDFReader, PDFProcessor
from .web_scraper import WebScraper, ArticleExtractor, NewsAggregator
from .event_extraction import (
    EventExtractor,
    GeopoliticalEvent,
    EventType,
    TemporalNormalizer,
    CausalFeatureExtractor
)
from .event_database import EventDatabase, EventStream

__all__ = [
    "PDFReader",
    "PDFProcessor",
    "WebScraper",
    "ArticleExtractor",
    "NewsAggregator",
    "EventExtractor",
    "GeopoliticalEvent",
    "EventType",
    "TemporalNormalizer",
    "CausalFeatureExtractor",
    "EventDatabase",
    "EventStream",
]
geobot/data_ingestion/event_database.py
ADDED
|
@@ -0,0 +1,574 @@
| 1 |
+
"""
|
| 2 |
+
Event Database for Geopolitical Intelligence
|
| 3 |
+
|
| 4 |
+
Persistent storage and querying for structured events.
|
| 5 |
+
|
| 6 |
+
Features:
|
| 7 |
+
- Efficient time-range queries
|
| 8 |
+
- Actor-based filtering
|
| 9 |
+
- Event type filtering
|
| 10 |
+
- Temporal aggregation
|
| 11 |
+
- Causal graph construction from events
|
| 12 |
+
- Export to panel data formats
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import sqlite3
|
| 17 |
+
from datetime import datetime, timedelta
|
| 18 |
+
from typing import List, Dict, Optional, Tuple, Any
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
from .event_extraction import GeopoliticalEvent, EventType, TemporalNormalizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class EventDatabase:
|
| 26 |
+
"""
|
| 27 |
+
SQLite-based event database with efficient querying.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, db_path: str = "events.db"):
|
| 31 |
+
"""
|
| 32 |
+
Initialize event database.
|
| 33 |
+
|
| 34 |
+
Parameters
|
| 35 |
+
----------
|
| 36 |
+
db_path : str
|
| 37 |
+
Path to SQLite database file
|
| 38 |
+
"""
|
| 39 |
+
self.db_path = db_path
|
| 40 |
+
self.conn = None
|
| 41 |
+
self._connect()
|
| 42 |
+
self._create_tables()
|
| 43 |
+
|
| 44 |
+
def _connect(self):
|
| 45 |
+
"""Connect to database."""
|
| 46 |
+
self.conn = sqlite3.connect(self.db_path)
|
| 47 |
+
self.conn.row_factory = sqlite3.Row
|
| 48 |
+
|
| 49 |
+
def _create_tables(self):
|
| 50 |
+
"""Create database schema."""
|
| 51 |
+
cursor = self.conn.cursor()
|
| 52 |
+
|
| 53 |
+
# Events table
|
| 54 |
+
cursor.execute('''
|
| 55 |
+
CREATE TABLE IF NOT EXISTS events (
|
| 56 |
+
event_id TEXT PRIMARY KEY,
|
| 57 |
+
timestamp TEXT NOT NULL,
|
| 58 |
+
event_type TEXT NOT NULL,
|
| 59 |
+
location TEXT,
|
| 60 |
+
magnitude REAL,
|
| 61 |
+
confidence REAL,
|
| 62 |
+
source TEXT,
|
| 63 |
+
text TEXT,
|
| 64 |
+
metadata TEXT
|
| 65 |
+
)
|
| 66 |
+
''')
|
| 67 |
+
|
| 68 |
+
# Actors table (many-to-many with events)
|
| 69 |
+
cursor.execute('''
|
| 70 |
+
CREATE TABLE IF NOT EXISTS event_actors (
|
| 71 |
+
event_id TEXT,
|
| 72 |
+
actor TEXT,
|
| 73 |
+
role TEXT,
|
| 74 |
+
FOREIGN KEY (event_id) REFERENCES events(event_id),
|
| 75 |
+
PRIMARY KEY (event_id, actor)
|
| 76 |
+
)
|
| 77 |
+
''')
|
| 78 |
+
|
| 79 |
+
# Causal relationships
|
| 80 |
+
cursor.execute('''
|
| 81 |
+
CREATE TABLE IF NOT EXISTS causal_links (
|
| 82 |
+
cause_event_id TEXT,
|
| 83 |
+
effect_event_id TEXT,
|
| 84 |
+
strength REAL,
|
| 85 |
+
confidence REAL,
|
| 86 |
+
FOREIGN KEY (cause_event_id) REFERENCES events(event_id),
|
| 87 |
+
FOREIGN KEY (effect_event_id) REFERENCES events(event_id),
|
| 88 |
+
PRIMARY KEY (cause_event_id, effect_event_id)
|
| 89 |
+
)
|
| 90 |
+
''')
|
| 91 |
+
|
| 92 |
+
# Create indices for fast queries
|
| 93 |
+
cursor.execute('CREATE INDEX IF NOT EXISTS idx_timestamp ON events(timestamp)')
|
| 94 |
+
cursor.execute('CREATE INDEX IF NOT EXISTS idx_event_type ON events(event_type)')
|
| 95 |
+
cursor.execute('CREATE INDEX IF NOT EXISTS idx_actor ON event_actors(actor)')
|
| 96 |
+
|
| 97 |
+
self.conn.commit()
|
| 98 |
+
|
| 99 |
+
def insert_event(self, event: GeopoliticalEvent) -> None:
|
| 100 |
+
"""
|
| 101 |
+
Insert event into database.
|
| 102 |
+
|
| 103 |
+
Parameters
|
| 104 |
+
----------
|
| 105 |
+
event : GeopoliticalEvent
|
| 106 |
+
Event to insert
|
| 107 |
+
"""
|
| 108 |
+
cursor = self.conn.cursor()
|
| 109 |
+
|
| 110 |
+
# Normalize timestamp
|
| 111 |
+
timestamp_str = TemporalNormalizer.normalize_to_utc(event.timestamp).isoformat()
|
| 112 |
+
|
| 113 |
+
# Insert main event
|
| 114 |
+
cursor.execute('''
|
| 115 |
+
INSERT OR REPLACE INTO events
|
| 116 |
+
(event_id, timestamp, event_type, location, magnitude, confidence, source, text, metadata)
|
| 117 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 118 |
+
''', (
|
| 119 |
+
event.event_id,
|
| 120 |
+
timestamp_str,
|
| 121 |
+
event.event_type.value,
|
| 122 |
+
event.location,
|
| 123 |
+
event.magnitude,
|
| 124 |
+
event.confidence,
|
| 125 |
+
event.source,
|
| 126 |
+
event.text,
|
| 127 |
+
json.dumps(event.metadata)
|
| 128 |
+
))
|
| 129 |
+
|
| 130 |
+
# Insert actors
|
| 131 |
+
for actor in event.actors:
|
| 132 |
+
cursor.execute('''
|
| 133 |
+
INSERT OR REPLACE INTO event_actors (event_id, actor, role)
|
| 134 |
+
VALUES (?, ?, ?)
|
| 135 |
+
''', (event.event_id, actor, 'participant'))
|
| 136 |
+
|
| 137 |
+
# Insert target as actor with different role
|
| 138 |
+
if event.target:
|
| 139 |
+
cursor.execute('''
|
| 140 |
+
INSERT OR REPLACE INTO event_actors (event_id, actor, role)
|
| 141 |
+
VALUES (?, ?, ?)
|
| 142 |
+
''', (event.event_id, event.target, 'target'))
|
| 143 |
+
|
| 144 |
+
self.conn.commit()
|
| 145 |
+
|
| 146 |
+
def insert_events(self, events: List[GeopoliticalEvent]) -> None:
|
| 147 |
+
"""
|
| 148 |
+
Bulk insert events.
|
| 149 |
+
|
| 150 |
+
Parameters
|
| 151 |
+
----------
|
| 152 |
+
events : list
|
| 153 |
+
List of events to insert
|
| 154 |
+
"""
|
| 155 |
+
for event in events:
|
| 156 |
+
self.insert_event(event)
|
| 157 |
+
|
| 158 |
+
def query_events(
|
| 159 |
+
self,
|
| 160 |
+
start_time: Optional[datetime] = None,
|
| 161 |
+
end_time: Optional[datetime] = None,
|
| 162 |
+
event_types: Optional[List[EventType]] = None,
|
| 163 |
+
actors: Optional[List[str]] = None,
|
| 164 |
+
min_magnitude: Optional[float] = None,
|
| 165 |
+
limit: Optional[int] = None
|
| 166 |
+
) -> List[GeopoliticalEvent]:
|
| 167 |
+
"""
|
| 168 |
+
Query events with filters.
|
| 169 |
+
|
| 170 |
+
Parameters
|
| 171 |
+
----------
|
| 172 |
+
start_time : datetime, optional
|
| 173 |
+
Start of time range
|
| 174 |
+
end_time : datetime, optional
|
| 175 |
+
End of time range
|
| 176 |
+
event_types : list, optional
|
| 177 |
+
Filter by event types
|
| 178 |
+
actors : list, optional
|
| 179 |
+
Filter by actors
|
| 180 |
+
min_magnitude : float, optional
|
| 181 |
+
Minimum magnitude
|
| 182 |
+
limit : int, optional
|
| 183 |
+
Maximum number of results
|
| 184 |
+
|
| 185 |
+
Returns
|
| 186 |
+
-------
|
| 187 |
+
list
|
| 188 |
+
List of matching events
|
| 189 |
+
"""
|
| 190 |
+
cursor = self.conn.cursor()
|
| 191 |
+
|
| 192 |
+
query = "SELECT DISTINCT e.* FROM events e"
|
| 193 |
+
conditions = []
|
| 194 |
+
params = []
|
| 195 |
+
|
| 196 |
+
# Join with actors if needed
|
| 197 |
+
if actors:
|
| 198 |
+
query += " JOIN event_actors ea ON e.event_id = ea.event_id"
|
| 199 |
+
|
| 200 |
+
# Time range
|
| 201 |
+
if start_time:
|
| 202 |
+
conditions.append("e.timestamp >= ?")
|
| 203 |
+
params.append(start_time.isoformat())
|
| 204 |
+
if end_time:
|
| 205 |
+
conditions.append("e.timestamp <= ?")
|
| 206 |
+
params.append(end_time.isoformat())
|
| 207 |
+
|
| 208 |
+
# Event types
|
| 209 |
+
if event_types:
|
| 210 |
+
placeholders = ','.join('?' * len(event_types))
|
| 211 |
+
conditions.append(f"e.event_type IN ({placeholders})")
|
| 212 |
+
params.extend([et.value for et in event_types])
|
| 213 |
+
|
| 214 |
+
# Actors
|
| 215 |
+
if actors:
|
| 216 |
+
placeholders = ','.join('?' * len(actors))
|
| 217 |
+
conditions.append(f"ea.actor IN ({placeholders})")
|
| 218 |
+
params.extend(actors)
|
| 219 |
+
|
| 220 |
+
# Magnitude
|
| 221 |
+
if min_magnitude is not None:
|
| 222 |
+
conditions.append("e.magnitude >= ?")
|
| 223 |
+
params.append(min_magnitude)
|
| 224 |
+
|
| 225 |
+
# Build query
|
| 226 |
+
if conditions:
|
| 227 |
+
query += " WHERE " + " AND ".join(conditions)
|
| 228 |
+
|
| 229 |
+
query += " ORDER BY e.timestamp DESC"
|
| 230 |
+
|
| 231 |
+
if limit:
|
| 232 |
+
query += f" LIMIT {limit}"
|
| 233 |
+
|
| 234 |
+
# Execute
|
| 235 |
+
cursor.execute(query, params)
|
| 236 |
+
rows = cursor.fetchall()
|
| 237 |
+
|
| 238 |
+
# Convert to GeopoliticalEvent objects
|
| 239 |
+
events = []
|
| 240 |
+
for row in rows:
|
| 241 |
+
# Get actors
|
| 242 |
+
cursor.execute(
|
| 243 |
+
"SELECT actor FROM event_actors WHERE event_id = ?",
|
| 244 |
+
(row['event_id'],)
|
| 245 |
+
)
|
| 246 |
+
actors_rows = cursor.fetchall()
|
| 247 |
+
event_actors = [r['actor'] for r in actors_rows]
|
| 248 |
+
|
| 249 |
+
# Reconstruct event
|
| 250 |
+
event = GeopoliticalEvent(
|
| 251 |
+
event_id=row['event_id'],
|
| 252 |
+
timestamp=datetime.fromisoformat(row['timestamp']),
|
| 253 |
+
event_type=EventType(row['event_type']),
|
| 254 |
+
actors=event_actors,
|
| 255 |
+
location=row['location'],
|
| 256 |
+
magnitude=row['magnitude'],
|
| 257 |
+
confidence=row['confidence'],
|
| 258 |
+
source=row['source'],
|
| 259 |
+
text=row['text'],
|
| 260 |
+
metadata=json.loads(row['metadata']) if row['metadata'] else {}
|
| 261 |
+
)
|
| 262 |
+
events.append(event)
|
| 263 |
+
|
| 264 |
+
return events
|
| 265 |
+
|
| 266 |
+
def get_actor_timeline(
|
| 267 |
+
self,
|
| 268 |
+
actor: str,
|
| 269 |
+
start_time: Optional[datetime] = None,
|
| 270 |
+
end_time: Optional[datetime] = None
|
| 271 |
+
) -> List[GeopoliticalEvent]:
|
| 272 |
+
"""
|
| 273 |
+
Get timeline of events for a specific actor.
|
| 274 |
+
|
| 275 |
+
Parameters
|
| 276 |
+
----------
|
| 277 |
+
actor : str
|
| 278 |
+
Actor name
|
| 279 |
+
start_time : datetime, optional
|
| 280 |
+
Start time
|
| 281 |
+
end_time : datetime, optional
|
| 282 |
+
End time
|
| 283 |
+
|
| 284 |
+
Returns
|
| 285 |
+
-------
|
| 286 |
+
list
|
| 287 |
+
Events involving actor
|
| 288 |
+
"""
|
| 289 |
+
return self.query_events(
|
| 290 |
+
start_time=start_time,
|
| 291 |
+
end_time=end_time,
|
| 292 |
+
actors=[actor]
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def get_event_counts_by_type(
|
| 296 |
+
self,
|
| 297 |
+
start_time: Optional[datetime] = None,
|
| 298 |
+
end_time: Optional[datetime] = None
|
| 299 |
+
) -> Dict[str, int]:
|
| 300 |
+
"""
|
| 301 |
+
Get event counts by type.
|
| 302 |
+
|
| 303 |
+
Parameters
|
| 304 |
+
----------
|
| 305 |
+
start_time : datetime, optional
|
| 306 |
+
Start time
|
| 307 |
+
end_time : datetime, optional
|
| 308 |
+
End time
|
| 309 |
+
|
| 310 |
+
Returns
|
| 311 |
+
-------
|
| 312 |
+
dict
|
| 313 |
+
Counts by event type
|
| 314 |
+
"""
|
| 315 |
+
cursor = self.conn.cursor()
|
| 316 |
+
|
| 317 |
+
query = "SELECT event_type, COUNT(*) as count FROM events"
|
| 318 |
+
conditions = []
|
| 319 |
+
params = []
|
| 320 |
+
|
| 321 |
+
if start_time:
|
| 322 |
+
conditions.append("timestamp >= ?")
|
| 323 |
+
params.append(start_time.isoformat())
|
| 324 |
+
if end_time:
|
| 325 |
+
conditions.append("timestamp <= ?")
|
| 326 |
+
params.append(end_time.isoformat())
|
| 327 |
+
|
| 328 |
+
if conditions:
|
| 329 |
+
query += " WHERE " + " AND ".join(conditions)
|
| 330 |
+
|
| 331 |
+
query += " GROUP BY event_type"
|
| 332 |
+
|
| 333 |
+
cursor.execute(query, params)
|
| 334 |
+
rows = cursor.fetchall()
|
| 335 |
+
|
| 336 |
+
return {row['event_type']: row['count'] for row in rows}
|
| 337 |
+
|
| 338 |
+
def aggregate_by_time(
|
| 339 |
+
self,
|
| 340 |
+
granularity: str = 'day',
|
| 341 |
+
start_time: Optional[datetime] = None,
|
| 342 |
+
end_time: Optional[datetime] = None,
|
| 343 |
+
event_types: Optional[List[EventType]] = None
|
| 344 |
+
) -> pd.DataFrame:
|
| 345 |
+
"""
|
| 346 |
+
Aggregate events by time period.
|
| 347 |
+
|
| 348 |
+
Parameters
|
| 349 |
+
----------
|
| 350 |
+
granularity : str
|
| 351 |
+
Time granularity ('day', 'week', 'month')
|
| 352 |
+
start_time : datetime, optional
|
| 353 |
+
Start time
|
| 354 |
+
end_time : datetime, optional
|
| 355 |
+
End time
|
| 356 |
+
event_types : list, optional
|
| 357 |
+
Filter by event types
|
| 358 |
+
|
| 359 |
+
Returns
|
| 360 |
+
-------
|
| 361 |
+
pd.DataFrame
|
| 362 |
+
Time series of event counts
|
| 363 |
+
"""
|
| 364 |
+
events = self.query_events(
|
| 365 |
+
start_time=start_time,
|
| 366 |
+
end_time=end_time,
|
| 367 |
+
event_types=event_types
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if not events:
|
| 371 |
+
return pd.DataFrame()
|
| 372 |
+
|
| 373 |
+
# Convert to DataFrame
|
| 374 |
+
df = pd.DataFrame([
|
| 375 |
+
{
|
| 376 |
+
'timestamp': e.timestamp,
|
| 377 |
+
'event_type': e.event_type.value,
|
| 378 |
+
'magnitude': e.magnitude
|
| 379 |
+
}
|
| 380 |
+
for e in events
|
| 381 |
+
])
|
| 382 |
+
|
| 383 |
+
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
| 384 |
+
df = df.set_index('timestamp')
|
| 385 |
+
|
| 386 |
+
# Resample
|
| 387 |
+
if granularity == 'day':
|
| 388 |
+
freq = 'D'
|
| 389 |
+
elif granularity == 'week':
|
| 390 |
+
freq = 'W'
|
| 391 |
+
elif granularity == 'month':
|
| 392 |
+
freq = 'M'
|
| 393 |
+
else:
|
| 394 |
+
raise ValueError(f"Unknown granularity: {granularity}")
|
| 395 |
+
|
| 396 |
+
# Aggregate
|
| 397 |
+
aggregated = df.resample(freq).agg({
|
| 398 |
+
'magnitude': ['count', 'mean', 'sum']
|
| 399 |
+
})
|
| 400 |
+
|
| 401 |
+
return aggregated
|
| 402 |
+
|
| 403 |
+
def export_to_panel_data(
|
| 404 |
+
self,
|
| 405 |
+
actors: List[str],
|
| 406 |
+
start_time: datetime,
|
| 407 |
+
end_time: datetime,
|
| 408 |
+
granularity: str = 'day'
|
| 409 |
+
) -> Dict[str, pd.DataFrame]:
|
| 410 |
+
"""
|
| 411 |
+
Export to panel data format.
|
| 412 |
+
|
| 413 |
+
Parameters
|
| 414 |
+
----------
|
| 415 |
+
actors : list
|
| 416 |
+
List of actors
|
| 417 |
+
start_time : datetime
|
| 418 |
+
Start time
|
| 419 |
+
end_time : datetime
|
| 420 |
+
End time
|
| 421 |
+
granularity : str
|
| 422 |
+
Time granularity
|
| 423 |
+
|
| 424 |
+
Returns
|
| 425 |
+
-------
|
| 426 |
+
dict
|
| 427 |
+
Panel data {actor: DataFrame}
|
| 428 |
+
"""
|
| 429 |
+
from .event_extraction import CausalFeatureExtractor
|
| 430 |
+
|
| 431 |
+
# Get events for each actor
|
| 432 |
+
panel = {}
|
| 433 |
+
for actor in actors:
|
| 434 |
+
events = self.get_actor_timeline(actor, start_time, end_time)
|
| 435 |
+
|
| 436 |
+
# Extract features
|
| 437 |
+
extractor = CausalFeatureExtractor()
|
| 438 |
+
panel_data = extractor.construct_panel_data([events], [actor], granularity)
|
| 439 |
+
|
| 440 |
+
if actor in panel_data:
|
| 441 |
+
panel[actor] = panel_data[actor]
|
| 442 |
+
|
| 443 |
+
return panel
|
| 444 |
+
|
| 445 |
+
def add_causal_link(
|
| 446 |
+
self,
|
| 447 |
+
cause_event_id: str,
|
| 448 |
+
effect_event_id: str,
|
| 449 |
+
strength: float = 1.0,
|
| 450 |
+
confidence: float = 0.5
|
| 451 |
+
) -> None:
|
| 452 |
+
"""
|
| 453 |
+
Add causal link between events.
|
| 454 |
+
|
| 455 |
+
Parameters
|
| 456 |
+
----------
|
| 457 |
+
cause_event_id : str
|
| 458 |
+
ID of cause event
|
| 459 |
+
effect_event_id : str
|
| 460 |
+
ID of effect event
|
| 461 |
+
strength : float
|
| 462 |
+
Causal strength
|
| 463 |
+
confidence : float
|
| 464 |
+
Confidence in link
|
| 465 |
+
"""
|
| 466 |
+
cursor = self.conn.cursor()
|
| 467 |
+
|
| 468 |
+
cursor.execute('''
|
| 469 |
+
INSERT OR REPLACE INTO causal_links
|
| 470 |
+
(cause_event_id, effect_event_id, strength, confidence)
|
| 471 |
+
VALUES (?, ?, ?, ?)
|
| 472 |
+
''', (cause_event_id, effect_event_id, strength, confidence))
|
| 473 |
+
|
| 474 |
+
self.conn.commit()
|
| 475 |
+
|
| 476 |
+
def get_causal_graph(self) -> Dict[str, List[str]]:
|
| 477 |
+
"""
|
| 478 |
+
Get causal graph from event links.
|
| 479 |
+
|
| 480 |
+
Returns
|
| 481 |
+
-------
|
| 482 |
+
dict
|
| 483 |
+
Adjacency list representation
|
| 484 |
+
"""
|
| 485 |
+
cursor = self.conn.cursor()
|
| 486 |
+
|
| 487 |
+
cursor.execute("SELECT cause_event_id, effect_event_id FROM causal_links")
|
| 488 |
+
rows = cursor.fetchall()
|
| 489 |
+
|
| 490 |
+
graph = {}
|
| 491 |
+
for row in rows:
|
| 492 |
+
cause = row['cause_event_id']
|
| 493 |
+
effect = row['effect_event_id']
|
| 494 |
+
|
| 495 |
+
if cause not in graph:
|
| 496 |
+
graph[cause] = []
|
| 497 |
+
graph[cause].append(effect)
|
| 498 |
+
|
| 499 |
+
return graph
|
| 500 |
+
|
| 501 |
+
def close(self):
|
| 502 |
+
"""Close database connection."""
|
| 503 |
+
if self.conn:
|
| 504 |
+
self.conn.close()
|
| 505 |
+
|
| 506 |
+
def __enter__(self):
|
| 507 |
+
"""Context manager entry."""
|
| 508 |
+
return self
|
| 509 |
+
|
| 510 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 511 |
+
"""Context manager exit."""
|
| 512 |
+
self.close()
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class EventStream:
|
| 516 |
+
"""
|
| 517 |
+
Real-time event stream processor.
|
| 518 |
+
|
| 519 |
+
Monitors and processes incoming events in real-time.
|
| 520 |
+
"""
|
| 521 |
+
|
| 522 |
+
def __init__(self, db: EventDatabase):
|
| 523 |
+
"""
|
| 524 |
+
Initialize event stream.
|
| 525 |
+
|
| 526 |
+
Parameters
|
| 527 |
+
----------
|
| 528 |
+
db : EventDatabase
|
| 529 |
+
Event database
|
| 530 |
+
"""
|
| 531 |
+
self.db = db
|
| 532 |
+
self.subscribers = []
|
| 533 |
+
|
| 534 |
+
def subscribe(self, callback: callable) -> None:
|
| 535 |
+
"""
|
| 536 |
+
Subscribe to event stream.
|
| 537 |
+
|
| 538 |
+
Parameters
|
| 539 |
+
----------
|
| 540 |
+
callback : callable
|
| 541 |
+
Function to call on new events
|
| 542 |
+
"""
|
| 543 |
+
self.subscribers.append(callback)
|
| 544 |
+
|
| 545 |
+
def process_event(self, event: GeopoliticalEvent) -> None:
|
| 546 |
+
"""
|
| 547 |
+
Process and store new event.
|
| 548 |
+
|
| 549 |
+
Parameters
|
| 550 |
+
----------
|
| 551 |
+
event : GeopoliticalEvent
|
| 552 |
+
New event
|
| 553 |
+
"""
|
| 554 |
+
# Store in database
|
| 555 |
+
self.db.insert_event(event)
|
| 556 |
+
|
| 557 |
+
# Notify subscribers
|
| 558 |
+
for callback in self.subscribers:
|
| 559 |
+
callback(event)
|
| 560 |
+
|
| 561 |
+
def process_batch(self, events: List[GeopoliticalEvent]) -> None:
|
| 562 |
+
"""
|
| 563 |
+
Process batch of events.
|
| 564 |
+
|
| 565 |
+
Parameters
|
| 566 |
+
----------
|
| 567 |
+
events : list
|
| 568 |
+
List of events
|
| 569 |
+
"""
|
| 570 |
+
self.db.insert_events(events)
|
| 571 |
+
|
| 572 |
+
for event in events:
|
| 573 |
+
for callback in self.subscribers:
|
| 574 |
+
callback(event)
|
geobot/data_ingestion/event_extraction.py
ADDED
|
@@ -0,0 +1,539 @@
| 1 |
+
"""
|
| 2 |
+
Structured Event Extraction Pipeline
|
| 3 |
+
|
| 4 |
+
Converts unstructured intelligence (PDFs, articles, reports) into structured,
|
| 5 |
+
timestamped events suitable for:
|
| 6 |
+
- Causal graph construction and updates
|
| 7 |
+
- Time-series analysis
|
| 8 |
+
- Panel data modeling
|
| 9 |
+
- Temporal feature engineering
|
| 10 |
+
|
| 11 |
+
Event schema:
|
| 12 |
+
- Timestamp (normalized to UTC)
|
| 13 |
+
- Event type (conflict, diplomacy, economic, etc.)
|
| 14 |
+
- Actors (countries, organizations)
|
| 15 |
+
- Location (geospatial)
|
| 16 |
+
- Magnitude/severity
|
| 17 |
+
- Source and confidence
|
| 18 |
+
- Causal attributes (preconditions, effects)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import re
|
| 22 |
+
from datetime import datetime, timedelta, timezone
|
| 23 |
+
from typing import List, Dict, Optional, Tuple, Any
|
| 24 |
+
import numpy as np
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from enum import Enum
|
| 27 |
+
import json
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class EventType(Enum):
|
| 31 |
+
"""Event taxonomy."""
|
| 32 |
+
CONFLICT = "conflict"
|
| 33 |
+
DIPLOMACY = "diplomacy"
|
| 34 |
+
ECONOMIC = "economic"
|
| 35 |
+
MILITARY_MOBILIZATION = "military_mobilization"
|
| 36 |
+
SANCTIONS = "sanctions"
|
| 37 |
+
ALLIANCE_FORMATION = "alliance_formation"
|
| 38 |
+
TREATY_SIGNING = "treaty_signing"
|
| 39 |
+
PROTEST = "protest"
|
| 40 |
+
ELECTION = "election"
|
| 41 |
+
COUP = "coup"
|
| 42 |
+
TERROR_ATTACK = "terror_attack"
|
| 43 |
+
CYBER_ATTACK = "cyber_attack"
|
| 44 |
+
TRADE_AGREEMENT = "trade_agreement"
|
| 45 |
+
ARMS_DEAL = "arms_deal"
|
| 46 |
+
HUMANITARIAN_CRISIS = "humanitarian_crisis"
|
| 47 |
+
OTHER = "other"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class GeopoliticalEvent:
|
| 52 |
+
"""
|
| 53 |
+
Structured geopolitical event.
|
| 54 |
+
|
| 55 |
+
Attributes
|
| 56 |
+
----------
|
| 57 |
+
event_id : str
|
| 58 |
+
Unique event identifier
|
| 59 |
+
timestamp : datetime
|
| 60 |
+
Event timestamp (normalized to UTC)
|
| 61 |
+
event_type : EventType
|
| 62 |
+
Type of event
|
| 63 |
+
actors : List[str]
|
| 64 |
+
Involved actors (countries, organizations)
|
| 65 |
+
target : Optional[str]
|
| 66 |
+
Target of action (if applicable)
|
| 67 |
+
location : Optional[str]
|
| 68 |
+
Geographic location
|
| 69 |
+
magnitude : float
|
| 70 |
+
Event magnitude/severity (0-1)
|
| 71 |
+
confidence : float
|
| 72 |
+
Extraction confidence (0-1)
|
| 73 |
+
source : str
|
| 74 |
+
Source document/article
|
| 75 |
+
text : str
|
| 76 |
+
Original text describing event
|
| 77 |
+
causal_preconditions : List[str]
|
| 78 |
+
Identified preconditions
|
| 79 |
+
causal_effects : List[str]
|
| 80 |
+
Identified effects
|
| 81 |
+
metadata : Dict[str, Any]
|
| 82 |
+
Additional metadata
|
| 83 |
+
"""
|
| 84 |
+
event_id: str
|
| 85 |
+
timestamp: datetime
|
| 86 |
+
event_type: EventType
|
| 87 |
+
actors: List[str]
|
| 88 |
+
target: Optional[str] = None
|
| 89 |
+
location: Optional[str] = None
|
| 90 |
+
magnitude: float = 0.5
|
| 91 |
+
confidence: float = 0.5
|
| 92 |
+
source: str = ""
|
| 93 |
+
text: str = ""
|
| 94 |
+
causal_preconditions: List[str] = field(default_factory=list)
|
| 95 |
+
causal_effects: List[str] = field(default_factory=list)
|
| 96 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 97 |
+
|
| 98 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 99 |
+
"""Convert to dictionary."""
|
| 100 |
+
return {
|
| 101 |
+
'event_id': self.event_id,
|
| 102 |
+
'timestamp': self.timestamp.isoformat(),
|
| 103 |
+
'event_type': self.event_type.value,
|
| 104 |
+
'actors': self.actors,
|
| 105 |
+
'target': self.target,
|
| 106 |
+
'location': self.location,
|
| 107 |
+
'magnitude': self.magnitude,
|
| 108 |
+
'confidence': self.confidence,
|
| 109 |
+
'source': self.source,
|
| 110 |
+
'text': self.text,
|
| 111 |
+
'causal_preconditions': self.causal_preconditions,
|
| 112 |
+
'causal_effects': self.causal_effects,
|
| 113 |
+
'metadata': self.metadata
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'GeopoliticalEvent':
|
| 118 |
+
"""Load from dictionary."""
|
| 119 |
+
data['timestamp'] = datetime.fromisoformat(data['timestamp'])
|
| 120 |
+
data['event_type'] = EventType(data['event_type'])
|
| 121 |
+
return cls(**data)
|
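A short round-trip sketch of the serialization helpers above (field values are illustrative):

from datetime import datetime, timezone

evt = GeopoliticalEvent(
    event_id="demo_1",
    timestamp=datetime(2024, 3, 1, tzinfo=timezone.utc),
    event_type=EventType.SANCTIONS,
    actors=["EU"],
    target="Russia",
    magnitude=0.7,
)
payload = evt.to_dict()                        # JSON-serializable dict
restored = GeopoliticalEvent.from_dict(payload)
assert restored.event_type is EventType.SANCTIONS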
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class EventExtractor:
|
| 125 |
+
"""
|
| 126 |
+
Extract structured events from unstructured text.
|
| 127 |
+
|
| 128 |
+
Uses rule-based patterns and NLP to identify:
|
| 129 |
+
- Event mentions
|
| 130 |
+
- Temporal expressions
|
| 131 |
+
- Actor identification
|
| 132 |
+
- Event type classification
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
def __init__(self):
|
| 136 |
+
"""Initialize event extractor."""
|
| 137 |
+
self.country_names = self._load_country_names()
|
| 138 |
+
self.organization_names = self._load_organization_names()
|
| 139 |
+
self.event_patterns = self._compile_event_patterns()
|
| 140 |
+
|
| 141 |
+
def _load_country_names(self) -> List[str]:
|
| 142 |
+
"""Load list of country names."""
|
| 143 |
+
# Extended list of countries
|
| 144 |
+
return [
|
| 145 |
+
'United States', 'USA', 'China', 'Russia', 'India', 'Pakistan',
|
| 146 |
+
'Iran', 'North Korea', 'South Korea', 'Japan', 'Germany', 'France',
|
| 147 |
+
'United Kingdom', 'UK', 'Israel', 'Saudi Arabia', 'Turkey', 'Egypt',
|
| 148 |
+
'Syria', 'Iraq', 'Afghanistan', 'Ukraine', 'Poland', 'Italy', 'Spain',
|
| 149 |
+
'Canada', 'Australia', 'Brazil', 'Mexico', 'South Africa', 'Nigeria'
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
def _load_organization_names(self) -> List[str]:
|
| 153 |
+
"""Load list of international organizations."""
|
| 154 |
+
return [
|
| 155 |
+
'NATO', 'UN', 'United Nations', 'EU', 'European Union',
|
| 156 |
+
'OPEC', 'ASEAN', 'African Union', 'Arab League', 'G7', 'G20',
|
| 157 |
+
'IMF', 'World Bank', 'WTO', 'WHO', 'ICC'
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
def _compile_event_patterns(self) -> Dict[EventType, List[re.Pattern]]:
|
| 161 |
+
"""Compile regex patterns for event types."""
|
| 162 |
+
patterns = {
|
| 163 |
+
EventType.CONFLICT: [
|
| 164 |
+
re.compile(r'\b(attack|strike|bomb|missile|war|combat|clash|battle)\b', re.I),
|
| 165 |
+
re.compile(r'\b(invasion|offensive|assault|raid)\b', re.I)
|
| 166 |
+
],
|
| 167 |
+
EventType.DIPLOMACY: [
|
| 168 |
+
re.compile(r'\b(negotiation|talk|summit|meeting|dialogue)\b', re.I),
|
| 169 |
+
re.compile(r'\b(diplomatic|embassy|ambassador)\b', re.I)
|
| 170 |
+
],
|
| 171 |
+
EventType.SANCTIONS: [
|
| 172 |
+
re.compile(r'\b(sanction|embargo|restriction|ban)\b', re.I)
|
| 173 |
+
],
|
| 174 |
+
EventType.MILITARY_MOBILIZATION: [
|
| 175 |
+
re.compile(r'\b(mobiliz|deploy|troop|force|military)\b', re.I)
|
| 176 |
+
],
|
| 177 |
+
EventType.ALLIANCE_FORMATION: [
|
| 178 |
+
re.compile(r'\b(alliance|partnership|coalition|pact)\b', re.I)
|
| 179 |
+
],
|
| 180 |
+
EventType.TREATY_SIGNING: [
|
| 181 |
+
re.compile(r'\b(treaty|agreement|accord|convention)\b', re.I)
|
| 182 |
+
],
|
| 183 |
+
EventType.ELECTION: [
|
| 184 |
+
re.compile(r'\b(election|vote|ballot|referendum)\b', re.I)
|
| 185 |
+
],
|
| 186 |
+
EventType.COUP: [
|
| 187 |
+
re.compile(r'\b(coup|overthrow|takeover|regime change)\b', re.I)
|
| 188 |
+
],
|
| 189 |
+
EventType.TERROR_ATTACK: [
|
| 190 |
+
re.compile(r'\b(terror|terrorist|extremist|bombing)\b', re.I)
|
| 191 |
+
],
|
| 192 |
+
EventType.CYBER_ATTACK: [
|
| 193 |
+
re.compile(r'\b(cyber|hack|breach|malware|ransomware)\b', re.I)
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
return patterns
|
| 197 |
+
|
| 198 |
+
def extract_events(
|
| 199 |
+
self,
|
| 200 |
+
text: str,
|
| 201 |
+
source: str = "",
|
| 202 |
+
default_timestamp: Optional[datetime] = None
|
| 203 |
+
) -> List[GeopoliticalEvent]:
|
| 204 |
+
"""
|
| 205 |
+
Extract events from text.
|
| 206 |
+
|
| 207 |
+
Parameters
|
| 208 |
+
----------
|
| 209 |
+
text : str
|
| 210 |
+
Input text
|
| 211 |
+
source : str
|
| 212 |
+
Source identifier
|
| 213 |
+
default_timestamp : datetime, optional
|
| 214 |
+
Default timestamp if none found
|
| 215 |
+
|
| 216 |
+
Returns
|
| 217 |
+
-------
|
| 218 |
+
list
|
| 219 |
+
List of extracted events
|
| 220 |
+
"""
|
| 221 |
+
events = []
|
| 222 |
+
|
| 223 |
+
# Split into sentences
|
| 224 |
+
sentences = self._split_sentences(text)
|
| 225 |
+
|
| 226 |
+
for i, sentence in enumerate(sentences):
|
| 227 |
+
# Detect event type
|
| 228 |
+
event_type = self._classify_event_type(sentence)
|
| 229 |
+
|
| 230 |
+
if event_type != EventType.OTHER:
|
| 231 |
+
# Extract actors
|
| 232 |
+
actors = self._extract_actors(sentence)
|
| 233 |
+
|
| 234 |
+
# Extract timestamp
|
| 235 |
+
timestamp = self._extract_timestamp(sentence, default_timestamp)
|
| 236 |
+
|
| 237 |
+
# Extract location
|
| 238 |
+
location = self._extract_location(sentence)
|
| 239 |
+
|
| 240 |
+
# Compute magnitude
|
| 241 |
+
magnitude = self._estimate_magnitude(sentence, event_type)
|
| 242 |
+
|
| 243 |
+
# Create event
|
| 244 |
+
event = GeopoliticalEvent(
|
| 245 |
+
event_id=f"{source}_{i}",
|
| 246 |
+
timestamp=timestamp,
|
| 247 |
+
event_type=event_type,
|
| 248 |
+
actors=actors,
|
| 249 |
+
location=location,
|
| 250 |
+
magnitude=magnitude,
|
| 251 |
+
confidence=0.7, # Rule-based confidence
|
| 252 |
+
source=source,
|
| 253 |
+
text=sentence
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
events.append(event)
|
| 257 |
+
|
| 258 |
+
return events
|
| 259 |
+
|
| 260 |
+
def _split_sentences(self, text: str) -> List[str]:
|
| 261 |
+
"""Split text into sentences."""
|
| 262 |
+
# Simple sentence splitting
|
| 263 |
+
sentences = re.split(r'[.!?]+', text)
|
| 264 |
+
return [s.strip() for s in sentences if len(s.strip()) > 20]
|
| 265 |
+
|
| 266 |
+
def _classify_event_type(self, text: str) -> EventType:
|
| 267 |
+
"""Classify event type using patterns."""
|
| 268 |
+
for event_type, patterns in self.event_patterns.items():
|
| 269 |
+
for pattern in patterns:
|
| 270 |
+
if pattern.search(text):
|
| 271 |
+
return event_type
|
| 272 |
+
return EventType.OTHER
|
| 273 |
+
|
| 274 |
+
def _extract_actors(self, text: str) -> List[str]:
|
| 275 |
+
"""Extract actor entities."""
|
| 276 |
+
actors = []
|
| 277 |
+
|
| 278 |
+
# Check for countries
|
| 279 |
+
for country in self.country_names:
|
| 280 |
+
if country.lower() in text.lower():
|
| 281 |
+
actors.append(country)
|
| 282 |
+
|
| 283 |
+
# Check for organizations
|
| 284 |
+
for org in self.organization_names:
|
| 285 |
+
if org.lower() in text.lower():
|
| 286 |
+
actors.append(org)
|
| 287 |
+
|
| 288 |
+
return list(set(actors)) # Remove duplicates
|
| 289 |
+
|
| 290 |
+
def _extract_timestamp(
|
| 291 |
+
self,
|
| 292 |
+
text: str,
|
| 293 |
+
default: Optional[datetime] = None
|
| 294 |
+
) -> datetime:
|
| 295 |
+
"""Extract timestamp from text."""
|
| 296 |
+
# Try to find date patterns
|
| 297 |
+
date_patterns = [
|
| 298 |
+
r'(\d{4})-(\d{2})-(\d{2})', # YYYY-MM-DD
|
| 299 |
+
r'(\d{1,2})/(\d{1,2})/(\d{4})', # MM/DD/YYYY
|
| 300 |
+
r'(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{1,2}),?\s+(\d{4})'
|
| 301 |
+
]
|
| 302 |
+
|
| 303 |
+
for pattern in date_patterns:
|
| 304 |
+
match = re.search(pattern, text)
|
| 305 |
+
if match:
|
| 306 |
+
# Parse date (simplified)
|
| 307 |
+
try:
|
| 308 |
+
date_str = match.group(0)
|
| 309 |
+
# Try multiple formats
|
| 310 |
+
for fmt in ['%Y-%m-%d', '%m/%d/%Y']:
|
| 311 |
+
try:
|
| 312 |
+
return datetime.strptime(date_str, fmt).replace(tzinfo=timezone.utc)
|
| 313 |
+
except ValueError:
|
| 314 |
+
continue
|
| 315 |
+
except Exception:
|
| 316 |
+
pass
|
| 317 |
+
|
| 318 |
+
# Default to current time or provided default
|
| 319 |
+
return default or datetime.now(timezone.utc)
|
| 320 |
+
|
| 321 |
+
def _extract_location(self, text: str) -> Optional[str]:
|
| 322 |
+
"""Extract location from text."""
|
| 323 |
+
# Check for country names as locations
|
| 324 |
+
for country in self.country_names:
|
| 325 |
+
if country.lower() in text.lower():
|
| 326 |
+
return country
|
| 327 |
+
|
| 328 |
+
return None
|
| 329 |
+
|
| 330 |
+
def _estimate_magnitude(self, text: str, event_type: EventType) -> float:
|
| 331 |
+
"""Estimate event magnitude/severity."""
|
| 332 |
+
# Keywords indicating severity
|
| 333 |
+
high_severity_words = ['major', 'massive', 'large-scale', 'significant', 'devastating']
|
| 334 |
+
low_severity_words = ['minor', 'small', 'limited', 'isolated']
|
| 335 |
+
|
| 336 |
+
text_lower = text.lower()
|
| 337 |
+
|
| 338 |
+
if any(word in text_lower for word in high_severity_words):
|
| 339 |
+
return 0.8
|
| 340 |
+
elif any(word in text_lower for word in low_severity_words):
|
| 341 |
+
return 0.3
|
| 342 |
+
else:
|
| 343 |
+
return 0.5 # Default
|
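A minimal sketch of the rule-based extractor above on one illustrative sentence (the text and source label are made up):

extractor = EventExtractor()
events = extractor.extract_events(
    "On 2024-03-01 the EU announced new sanctions against Russia.",
    source="example_note",
)
for e in events:
    print(e.event_type.value, e.actors, e.timestamp.date(), e.magnitude)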
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TemporalNormalizer:
|
| 347 |
+
"""
|
| 348 |
+
Normalize timestamps to consistent format (UTC).
|
| 349 |
+
|
| 350 |
+
Handles:
|
| 351 |
+
- Time zone conversion
|
| 352 |
+
- Temporal granularity (day, week, month)
|
| 353 |
+
- Missing timestamps (imputation)
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
@staticmethod
|
| 357 |
+
def normalize_to_utc(dt: datetime) -> datetime:
|
| 358 |
+
"""
|
| 359 |
+
Normalize datetime to UTC.
|
| 360 |
+
|
| 361 |
+
Parameters
|
| 362 |
+
----------
|
| 363 |
+
dt : datetime
|
| 364 |
+
Input datetime
|
| 365 |
+
|
| 366 |
+
Returns
|
| 367 |
+
-------
|
| 368 |
+
datetime
|
| 369 |
+
UTC datetime
|
| 370 |
+
"""
|
| 371 |
+
if dt.tzinfo is None:
|
| 372 |
+
# Naive datetime: treat it as already being in UTC
|
| 373 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 374 |
+
else:
|
| 375 |
+
return dt.astimezone(timezone.utc)
|
| 376 |
+
|
| 377 |
+
@staticmethod
|
| 378 |
+
def round_to_day(dt: datetime) -> datetime:
|
| 379 |
+
"""Round to start of day."""
|
| 380 |
+
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
|
| 381 |
+
|
| 382 |
+
@staticmethod
|
| 383 |
+
def round_to_week(dt: datetime) -> datetime:
|
| 384 |
+
"""Round to start of week (Monday)."""
|
| 385 |
+
day_of_week = dt.weekday()
|
| 386 |
+
days_to_subtract = day_of_week
|
| 387 |
+
week_start = dt - timedelta(days=days_to_subtract)
|
| 388 |
+
return TemporalNormalizer.round_to_day(week_start)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class CausalFeatureExtractor:
|
| 392 |
+
"""
|
| 393 |
+
Extract causal features from events for modeling.
|
| 394 |
+
|
| 395 |
+
Constructs features suitable for:
|
| 396 |
+
- Causal graph learning
|
| 397 |
+
- Structural equation modeling
|
| 398 |
+
- Time-series forecasting
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
def __init__(self):
|
| 402 |
+
"""Initialize causal feature extractor."""
|
| 403 |
+
pass
|
| 404 |
+
|
| 405 |
+
def extract_features(
|
| 406 |
+
self,
|
| 407 |
+
events: List[GeopoliticalEvent],
|
| 408 |
+
time_window: int = 30
|
| 409 |
+
) -> Dict[str, np.ndarray]:
|
| 410 |
+
"""
|
| 411 |
+
Extract causal features from event sequence.
|
| 412 |
+
|
| 413 |
+
Parameters
|
| 414 |
+
----------
|
| 415 |
+
events : list
|
| 416 |
+
List of events
|
| 417 |
+
time_window : int
|
| 418 |
+
Time window in days
|
| 419 |
+
|
| 420 |
+
Returns
|
| 421 |
+
-------
|
| 422 |
+
dict
|
| 423 |
+
Feature dictionary
|
| 424 |
+
"""
|
| 425 |
+
import numpy as np
|
| 426 |
+
|
| 427 |
+
# Sort events by timestamp
|
| 428 |
+
sorted_events = sorted(events, key=lambda e: e.timestamp)
|
| 429 |
+
|
| 430 |
+
# Count events by type
|
| 431 |
+
event_counts = {}
|
| 432 |
+
for event_type in EventType:
|
| 433 |
+
event_counts[event_type.value] = sum(
|
| 434 |
+
1 for e in sorted_events if e.event_type == event_type
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Actor involvement matrix
|
| 438 |
+
all_actors = list(set(actor for e in sorted_events for actor in e.actors))
|
| 439 |
+
actor_indices = {actor: i for i, actor in enumerate(all_actors)}
|
| 440 |
+
|
| 441 |
+
# Event-actor matrix
|
| 442 |
+
n_events = len(sorted_events)
|
| 443 |
+
n_actors = len(all_actors)
|
| 444 |
+
actor_matrix = np.zeros((n_events, n_actors))
|
| 445 |
+
|
| 446 |
+
for i, event in enumerate(sorted_events):
|
| 447 |
+
for actor in event.actors:
|
| 448 |
+
if actor in actor_indices:
|
| 449 |
+
actor_matrix[i, actor_indices[actor]] = 1
|
| 450 |
+
|
| 451 |
+
# Temporal features
|
| 452 |
+
if sorted_events:
|
| 453 |
+
time_deltas = []
|
| 454 |
+
for i in range(1, len(sorted_events)):
|
| 455 |
+
delta = (sorted_events[i].timestamp - sorted_events[i-1].timestamp).total_seconds() / 86400 # days
|
| 456 |
+
time_deltas.append(delta)
|
| 457 |
+
mean_time_delta = np.mean(time_deltas) if time_deltas else 0
|
| 458 |
+
else:
|
| 459 |
+
mean_time_delta = 0
|
| 460 |
+
|
| 461 |
+
features = {
|
| 462 |
+
'event_counts': np.array([event_counts[et.value] for et in EventType]),
|
| 463 |
+
'actor_matrix': actor_matrix,
|
| 464 |
+
'mean_time_delta': mean_time_delta,
|
| 465 |
+
'total_events': n_events,
|
| 466 |
+
'unique_actors': n_actors
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
return features
|
| 470 |
+
|
| 471 |
+
def construct_panel_data(
|
| 472 |
+
self,
|
| 473 |
+
events: List[GeopoliticalEvent],
|
| 474 |
+
actors: List[str],
|
| 475 |
+
time_granularity: str = 'day'
|
| 476 |
+
) -> Dict[str, Any]:
|
| 477 |
+
"""
|
| 478 |
+
Construct panel data structure from events.
|
| 479 |
+
|
| 480 |
+
Panel data format: (actor, time) -> features
|
| 481 |
+
|
| 482 |
+
Parameters
|
| 483 |
+
----------
|
| 484 |
+
events : list
|
| 485 |
+
List of events
|
| 486 |
+
actors : list
|
| 487 |
+
List of actors
|
| 488 |
+
time_granularity : str
|
| 489 |
+
Time granularity ('day', 'week', 'month')
|
| 490 |
+
|
| 491 |
+
Returns
|
| 492 |
+
-------
|
| 493 |
+
dict
|
| 494 |
+
Panel data structure
|
| 495 |
+
"""
|
| 496 |
+
import pandas as pd
|
| 497 |
+
import numpy as np
|
| 498 |
+
|
| 499 |
+
# Create time index
|
| 500 |
+
if not events:
|
| 501 |
+
return {}
|
| 502 |
+
|
| 503 |
+
sorted_events = sorted(events, key=lambda e: e.timestamp)
|
| 504 |
+
start_time = sorted_events[0].timestamp
|
| 505 |
+
end_time = sorted_events[-1].timestamp
|
| 506 |
+
|
| 507 |
+
# Generate time grid
|
| 508 |
+
if time_granularity == 'day':
|
| 509 |
+
time_index = pd.date_range(start_time, end_time, freq='D')
|
| 510 |
+
elif time_granularity == 'week':
|
| 511 |
+
time_index = pd.date_range(start_time, end_time, freq='W')
|
| 512 |
+
elif time_granularity == 'month':
|
| 513 |
+
time_index = pd.date_range(start_time, end_time, freq='M')
|
| 514 |
+
else:
|
| 515 |
+
raise ValueError(f"Unknown granularity: {time_granularity}")
|
| 516 |
+
|
| 517 |
+
# Initialize panel
|
| 518 |
+
panel = {}
|
| 519 |
+
for actor in actors:
|
| 520 |
+
panel[actor] = pd.DataFrame(index=time_index, columns=['event_count', 'avg_magnitude'])
|
| 521 |
+
panel[actor] = panel[actor].fillna(0)
|
| 522 |
+
|
| 523 |
+
# Fill panel with events
|
| 524 |
+
for event in sorted_events:
|
| 525 |
+
for actor in event.actors:
|
| 526 |
+
if actor in panel:
|
| 527 |
+
# Find closest time point
|
| 528 |
+
event_date = pd.Timestamp(event.timestamp)
|
| 529 |
+
closest_idx = time_index[np.argmin(np.abs(time_index - event_date))]
|
| 530 |
+
|
| 531 |
+
panel[actor].loc[closest_idx, 'event_count'] += 1
|
| 532 |
+
panel[actor].loc[closest_idx, 'avg_magnitude'] += event.magnitude
|
| 533 |
+
|
| 534 |
+
# Normalize magnitudes
|
| 535 |
+
for actor in actors:
|
| 536 |
+
mask = panel[actor]['event_count'] > 0
|
| 537 |
+
panel[actor].loc[mask, 'avg_magnitude'] /= panel[actor].loc[mask, 'event_count']
|
| 538 |
+
|
| 539 |
+
return panel
|
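A minimal sketch wiring extracted events into the feature and panel helpers above (the `events` list is assumed to come from EventExtractor.extract_events; actor names are illustrative):

fx = CausalFeatureExtractor()
features = fx.extract_features(events)                         # counts, actor matrix, timing stats
panel = fx.construct_panel_data(events, actors=["Russia", "EU"], time_granularity='day')
print(features['total_events'], features['unique_actors'])
print(panel["EU"].head())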
geobot/data_ingestion/pdf_reader.py
ADDED
|
@@ -0,0 +1,493 @@
| 1 |
+
"""
|
| 2 |
+
PDF Reading and Processing Module
|
| 3 |
+
|
| 4 |
+
Comprehensive PDF ingestion capabilities for geopolitical intelligence documents,
|
| 5 |
+
reports, briefings, and analysis.
|
| 6 |
+
|
| 7 |
+
Supports:
|
| 8 |
+
- Text extraction from PDFs
|
| 9 |
+
- Table extraction
|
| 10 |
+
- Metadata extraction
|
| 11 |
+
- Multi-format PDF handling
|
| 12 |
+
- Batch processing
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PDFReader:
|
| 22 |
+
"""
|
| 23 |
+
Read and extract text from PDF documents.
|
| 24 |
+
|
| 25 |
+
Supports multiple PDF libraries for robust extraction.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, method: str = 'auto'):
|
| 29 |
+
"""
|
| 30 |
+
Initialize PDF reader.
|
| 31 |
+
|
| 32 |
+
Parameters
|
| 33 |
+
----------
|
| 34 |
+
method : str
|
| 35 |
+
Extraction method ('pypdf', 'pdfplumber', 'pdfminer', 'auto')
|
| 36 |
+
"""
|
| 37 |
+
self.method = method
|
| 38 |
+
self._check_dependencies()
|
| 39 |
+
|
| 40 |
+
def _check_dependencies(self) -> None:
|
| 41 |
+
"""Check which PDF libraries are available."""
|
| 42 |
+
self.has_pypdf = False
|
| 43 |
+
self.has_pdfplumber = False
|
| 44 |
+
self.has_pdfminer = False
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
import pypdf
|
| 48 |
+
self.has_pypdf = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
import pdfplumber
|
| 54 |
+
self.has_pdfplumber = True
|
| 55 |
+
except ImportError:
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
from pdfminer.high_level import extract_text as pdfminer_extract
|
| 60 |
+
self.has_pdfminer = True
|
| 61 |
+
except ImportError:
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
if not any([self.has_pypdf, self.has_pdfplumber, self.has_pdfminer]):
|
| 65 |
+
print("Warning: No PDF libraries available. Please install pypdf, pdfplumber, or pdfminer.six")
|
| 66 |
+
|
| 67 |
+
def read_pdf(self, pdf_path: str) -> Dict[str, Any]:
|
| 68 |
+
"""
|
| 69 |
+
Read PDF and extract all information.
|
| 70 |
+
|
| 71 |
+
Parameters
|
| 72 |
+
----------
|
| 73 |
+
pdf_path : str
|
| 74 |
+
Path to PDF file
|
| 75 |
+
|
| 76 |
+
Returns
|
| 77 |
+
-------
|
| 78 |
+
dict
|
| 79 |
+
Extracted information including text, metadata, pages
|
| 80 |
+
"""
|
| 81 |
+
if not os.path.exists(pdf_path):
|
| 82 |
+
raise FileNotFoundError(f"PDF not found: {pdf_path}")
|
| 83 |
+
|
| 84 |
+
method = self.method
|
| 85 |
+
if method == 'auto':
|
| 86 |
+
# Choose best available method
|
| 87 |
+
if self.has_pdfplumber:
|
| 88 |
+
method = 'pdfplumber'
|
| 89 |
+
elif self.has_pypdf:
|
| 90 |
+
method = 'pypdf'
|
| 91 |
+
elif self.has_pdfminer:
|
| 92 |
+
method = 'pdfminer'
|
| 93 |
+
else:
|
| 94 |
+
raise ImportError("No PDF library available")
|
| 95 |
+
|
| 96 |
+
if method == 'pypdf':
|
| 97 |
+
return self._read_with_pypdf(pdf_path)
|
| 98 |
+
elif method == 'pdfplumber':
|
| 99 |
+
return self._read_with_pdfplumber(pdf_path)
|
| 100 |
+
elif method == 'pdfminer':
|
| 101 |
+
return self._read_with_pdfminer(pdf_path)
|
| 102 |
+
else:
|
| 103 |
+
raise ValueError(f"Unknown method: {method}")
|
| 104 |
+
|
| 105 |
+
def _read_with_pypdf(self, pdf_path: str) -> Dict[str, Any]:
|
| 106 |
+
"""Read PDF using pypdf."""
|
| 107 |
+
import pypdf
|
| 108 |
+
|
| 109 |
+
result = {
|
| 110 |
+
'text': '',
|
| 111 |
+
'pages': [],
|
| 112 |
+
'metadata': {},
|
| 113 |
+
'num_pages': 0
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
with open(pdf_path, 'rb') as file:
|
| 117 |
+
reader = pypdf.PdfReader(file)
|
| 118 |
+
result['num_pages'] = len(reader.pages)
|
| 119 |
+
|
| 120 |
+
# Extract metadata
|
| 121 |
+
if reader.metadata:
|
| 122 |
+
result['metadata'] = {
|
| 123 |
+
'title': reader.metadata.get('/Title', ''),
|
| 124 |
+
'author': reader.metadata.get('/Author', ''),
|
| 125 |
+
'subject': reader.metadata.get('/Subject', ''),
|
| 126 |
+
'creator': reader.metadata.get('/Creator', ''),
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# Extract text from each page
|
| 130 |
+
for page_num, page in enumerate(reader.pages):
|
| 131 |
+
page_text = page.extract_text()
|
| 132 |
+
result['pages'].append({
|
| 133 |
+
'page_number': page_num + 1,
|
| 134 |
+
'text': page_text
|
| 135 |
+
})
|
| 136 |
+
result['text'] += page_text + '\n'
|
| 137 |
+
|
| 138 |
+
return result
|
| 139 |
+
|
| 140 |
+
def _read_with_pdfplumber(self, pdf_path: str) -> Dict[str, Any]:
|
| 141 |
+
"""Read PDF using pdfplumber (best for tables)."""
|
| 142 |
+
import pdfplumber
|
| 143 |
+
|
| 144 |
+
result = {
|
| 145 |
+
'text': '',
|
| 146 |
+
'pages': [],
|
| 147 |
+
'tables': [],
|
| 148 |
+
'metadata': {},
|
| 149 |
+
'num_pages': 0
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
with pdfplumber.open(pdf_path) as pdf:
|
| 153 |
+
result['num_pages'] = len(pdf.pages)
|
| 154 |
+
result['metadata'] = pdf.metadata
|
| 155 |
+
|
| 156 |
+
for page_num, page in enumerate(pdf.pages):
|
| 157 |
+
page_text = page.extract_text()
|
| 158 |
+
page_tables = page.extract_tables()
|
| 159 |
+
|
| 160 |
+
result['pages'].append({
|
| 161 |
+
'page_number': page_num + 1,
|
| 162 |
+
'text': page_text,
|
| 163 |
+
'tables': page_tables
|
| 164 |
+
})
|
| 165 |
+
|
| 166 |
+
result['text'] += page_text + '\n' if page_text else ''
|
| 167 |
+
|
| 168 |
+
if page_tables:
|
| 169 |
+
result['tables'].extend([{
|
| 170 |
+
'page': page_num + 1,
|
| 171 |
+
'data': table
|
| 172 |
+
} for table in page_tables])
|
| 173 |
+
|
| 174 |
+
return result
|
| 175 |
+
|
| 176 |
+
def _read_with_pdfminer(self, pdf_path: str) -> Dict[str, Any]:
|
| 177 |
+
"""Read PDF using pdfminer."""
|
| 178 |
+
from pdfminer.high_level import extract_text, extract_pages
|
| 179 |
+
from pdfminer.layout import LTTextContainer
|
| 180 |
+
|
| 181 |
+
result = {
|
| 182 |
+
'text': '',
|
| 183 |
+
'pages': [],
|
| 184 |
+
'metadata': {},
|
| 185 |
+
'num_pages': 0
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
# Extract all text
|
| 189 |
+
result['text'] = extract_text(pdf_path)
|
| 190 |
+
|
| 191 |
+
# Extract page by page
|
| 192 |
+
pages = list(extract_pages(pdf_path))
|
| 193 |
+
result['num_pages'] = len(pages)
|
| 194 |
+
|
| 195 |
+
for page_num, page_layout in enumerate(pages):
|
| 196 |
+
page_text = ''
|
| 197 |
+
for element in page_layout:
|
| 198 |
+
if isinstance(element, LTTextContainer):
|
| 199 |
+
page_text += element.get_text()
|
| 200 |
+
|
| 201 |
+
result['pages'].append({
|
| 202 |
+
'page_number': page_num + 1,
|
| 203 |
+
'text': page_text
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
return result
|
| 207 |
+
|
| 208 |
+
def extract_text(self, pdf_path: str) -> str:
|
| 209 |
+
"""
|
| 210 |
+
Extract text from PDF (simple interface).
|
| 211 |
+
|
| 212 |
+
Parameters
|
| 213 |
+
----------
|
| 214 |
+
pdf_path : str
|
| 215 |
+
Path to PDF
|
| 216 |
+
|
| 217 |
+
Returns
|
| 218 |
+
-------
|
| 219 |
+
str
|
| 220 |
+
Extracted text
|
| 221 |
+
"""
|
| 222 |
+
result = self.read_pdf(pdf_path)
|
| 223 |
+
return result['text']
|
| 224 |
+
|
| 225 |
+
def extract_tables(self, pdf_path: str) -> List[List[List[str]]]:
|
| 226 |
+
"""
|
| 227 |
+
Extract tables from PDF.
|
| 228 |
+
|
| 229 |
+
Parameters
|
| 230 |
+
----------
|
| 231 |
+
pdf_path : str
|
| 232 |
+
Path to PDF
|
| 233 |
+
|
| 234 |
+
Returns
|
| 235 |
+
-------
|
| 236 |
+
list
|
| 237 |
+
List of tables
|
| 238 |
+
"""
|
| 239 |
+
if not self.has_pdfplumber:
|
| 240 |
+
print("Warning: pdfplumber required for table extraction")
|
| 241 |
+
return []
|
| 242 |
+
|
| 243 |
+
result = self._read_with_pdfplumber(pdf_path)
|
| 244 |
+
return [table['data'] for table in result.get('tables', [])]
|
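A minimal usage sketch of the reader above (the file path is hypothetical; 'auto' falls back across pdfplumber, pypdf and pdfminer depending on what is installed):

reader = PDFReader(method='auto')
doc = reader.read_pdf("report.pdf")             # text, pages, metadata, num_pages
print(doc['num_pages'], doc['metadata'])
tables = reader.extract_tables("report.pdf")    # table extraction needs pdfplumber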
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class PDFProcessor:
|
| 248 |
+
"""
|
| 249 |
+
Process and analyze PDF documents for geopolitical intelligence.
|
| 250 |
+
|
| 251 |
+
Provides high-level processing capabilities including:
|
| 252 |
+
- Entity extraction
|
| 253 |
+
- Topic extraction
|
| 254 |
+
- Sentiment analysis
|
| 255 |
+
- Key phrase extraction
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
def __init__(self, pdf_reader: Optional[PDFReader] = None):
|
| 259 |
+
"""
|
| 260 |
+
Initialize PDF processor.
|
| 261 |
+
|
| 262 |
+
Parameters
|
| 263 |
+
----------
|
| 264 |
+
pdf_reader : PDFReader, optional
|
| 265 |
+
PDF reader to use
|
| 266 |
+
"""
|
| 267 |
+
self.reader = pdf_reader or PDFReader()
|
| 268 |
+
|
| 269 |
+
def process_document(self, pdf_path: str) -> Dict[str, Any]:
|
| 270 |
+
"""
|
| 271 |
+
Process PDF document and extract intelligence.
|
| 272 |
+
|
| 273 |
+
Parameters
|
| 274 |
+
----------
|
| 275 |
+
pdf_path : str
|
| 276 |
+
Path to PDF
|
| 277 |
+
|
| 278 |
+
Returns
|
| 279 |
+
-------
|
| 280 |
+
dict
|
| 281 |
+
Processed document with analysis
|
| 282 |
+
"""
|
| 283 |
+
# Extract content
|
| 284 |
+
content = self.reader.read_pdf(pdf_path)
|
| 285 |
+
|
| 286 |
+
# Basic processing
|
| 287 |
+
processed = {
|
| 288 |
+
'file_path': pdf_path,
|
| 289 |
+
'file_name': Path(pdf_path).name,
|
| 290 |
+
'text': content['text'],
|
| 291 |
+
'num_pages': content['num_pages'],
|
| 292 |
+
'metadata': content.get('metadata', {}),
|
| 293 |
+
'word_count': len(content['text'].split()),
|
| 294 |
+
'char_count': len(content['text']),
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
# Extract key information
|
| 298 |
+
processed['entities'] = self._extract_entities(content['text'])
|
| 299 |
+
processed['keywords'] = self._extract_keywords(content['text'])
|
| 300 |
+
processed['summary'] = self._generate_summary(content['text'])
|
| 301 |
+
|
| 302 |
+
return processed
|
| 303 |
+
|
| 304 |
+
def _extract_entities(self, text: str) -> Dict[str, List[str]]:
|
| 305 |
+
"""
|
| 306 |
+
Extract named entities (countries, organizations, people).
|
| 307 |
+
|
| 308 |
+
Parameters
|
| 309 |
+
----------
|
| 310 |
+
text : str
|
| 311 |
+
Text to analyze
|
| 312 |
+
|
| 313 |
+
Returns
|
| 314 |
+
-------
|
| 315 |
+
dict
|
| 316 |
+
Extracted entities by type
|
| 317 |
+
"""
|
| 318 |
+
entities = {
|
| 319 |
+
'countries': [],
|
| 320 |
+
'organizations': [],
|
| 321 |
+
'people': [],
|
| 322 |
+
'locations': []
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
# Simple pattern-based extraction (can be enhanced with NER)
|
| 326 |
+
# Common country names
|
| 327 |
+
countries = ['United States', 'China', 'Russia', 'Iran', 'North Korea',
|
| 328 |
+
'India', 'Pakistan', 'Israel', 'Saudi Arabia', 'Turkey',
|
| 329 |
+
'France', 'Germany', 'United Kingdom', 'Japan', 'South Korea']
|
| 330 |
+
|
| 331 |
+
for country in countries:
|
| 332 |
+
if country in text:
|
| 333 |
+
entities['countries'].append(country)
|
| 334 |
+
|
| 335 |
+
# Organizations (simple patterns)
|
| 336 |
+
org_patterns = [r'\b([A-Z][A-Za-z]+(?:\s+[A-Z][A-Za-z]+)*)\s+(?:Organization|Agency|Ministry|Department|Council)\b']
|
| 337 |
+
for pattern in org_patterns:
|
| 338 |
+
matches = re.findall(pattern, text)
|
| 339 |
+
entities['organizations'].extend(matches)
|
| 340 |
+
|
| 341 |
+
return entities
|
| 342 |
+
|
| 343 |
+
def _extract_keywords(self, text: str, n_keywords: int = 10) -> List[Tuple[str, float]]:
|
| 344 |
+
"""
|
| 345 |
+
Extract keywords from text.
|
| 346 |
+
|
| 347 |
+
Parameters
|
| 348 |
+
----------
|
| 349 |
+
text : str
|
| 350 |
+
Text to analyze
|
| 351 |
+
n_keywords : int
|
| 352 |
+
Number of keywords to extract
|
| 353 |
+
|
| 354 |
+
Returns
|
| 355 |
+
-------
|
| 356 |
+
list
|
| 357 |
+
List of (keyword, score) tuples
|
| 358 |
+
"""
|
| 359 |
+
# Simple frequency-based extraction
|
| 360 |
+
words = text.lower().split()
|
| 361 |
+
|
| 362 |
+
# Remove common words
|
| 363 |
+
stopwords = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
|
| 364 |
+
'to', 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was',
|
| 365 |
+
'are', 'were', 'been', 'be', 'have', 'has', 'had', 'do',
|
| 366 |
+
'does', 'did', 'will', 'would', 'should', 'could', 'may',
|
| 367 |
+
'might', 'can', 'this', 'that', 'these', 'those'}
|
| 368 |
+
|
| 369 |
+
words = [w for w in words if w not in stopwords and len(w) > 3]
|
| 370 |
+
|
| 371 |
+
# Count frequencies
|
| 372 |
+
from collections import Counter
|
| 373 |
+
word_freq = Counter(words)
|
| 374 |
+
|
| 375 |
+
# Return top keywords
|
| 376 |
+
return word_freq.most_common(n_keywords)
|
| 377 |
+
|
| 378 |
+
def _generate_summary(self, text: str, num_sentences: int = 3) -> str:
|
| 379 |
+
"""
|
| 380 |
+
Generate simple extractive summary.
|
| 381 |
+
|
| 382 |
+
Parameters
|
| 383 |
+
----------
|
| 384 |
+
text : str
|
| 385 |
+
Text to summarize
|
| 386 |
+
num_sentences : int
|
| 387 |
+
Number of sentences in summary
|
| 388 |
+
|
| 389 |
+
Returns
|
| 390 |
+
-------
|
| 391 |
+
str
|
| 392 |
+
Summary
|
| 393 |
+
"""
|
| 394 |
+
# Split into sentences
|
| 395 |
+
sentences = re.split(r'[.!?]+', text)
|
| 396 |
+
sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
|
| 397 |
+
|
| 398 |
+
# Take first few sentences as summary (simple approach)
|
| 399 |
+
summary_sentences = sentences[:num_sentences]
|
| 400 |
+
|
| 401 |
+
return '. '.join(summary_sentences) + '.'
|
| 402 |
+
|
| 403 |
+
def batch_process(self, pdf_directory: str, pattern: str = '*.pdf') -> List[Dict[str, Any]]:
|
| 404 |
+
"""
|
| 405 |
+
Process multiple PDFs in a directory.
|
| 406 |
+
|
| 407 |
+
Parameters
|
| 408 |
+
----------
|
| 409 |
+
pdf_directory : str
|
| 410 |
+
Directory containing PDFs
|
| 411 |
+
pattern : str
|
| 412 |
+
File pattern to match
|
| 413 |
+
|
| 414 |
+
Returns
|
| 415 |
+
-------
|
| 416 |
+
list
|
| 417 |
+
List of processed documents
|
| 418 |
+
"""
|
| 419 |
+
pdf_dir = Path(pdf_directory)
|
| 420 |
+
pdf_files = list(pdf_dir.glob(pattern))
|
| 421 |
+
|
| 422 |
+
results = []
|
| 423 |
+
for pdf_file in pdf_files:
|
| 424 |
+
try:
|
| 425 |
+
processed = self.process_document(str(pdf_file))
|
| 426 |
+
results.append(processed)
|
| 427 |
+
except Exception as e:
|
| 428 |
+
print(f"Error processing {pdf_file}: {e}")
|
| 429 |
+
|
| 430 |
+
return results
|
| 431 |
+
|
| 432 |
+
def extract_intelligence(self, pdf_path: str) -> Dict[str, Any]:
|
| 433 |
+
"""
|
| 434 |
+
Extract geopolitical intelligence from PDF.
|
| 435 |
+
|
| 436 |
+
Parameters
|
| 437 |
+
----------
|
| 438 |
+
pdf_path : str
|
| 439 |
+
Path to PDF
|
| 440 |
+
|
| 441 |
+
Returns
|
| 442 |
+
-------
|
| 443 |
+
dict
|
| 444 |
+
Intelligence summary
|
| 445 |
+
"""
|
| 446 |
+
processed = self.process_document(pdf_path)
|
| 447 |
+
|
| 448 |
+
# Analyze for geopolitical indicators
|
| 449 |
+
text = processed['text'].lower()
|
| 450 |
+
|
| 451 |
+
indicators = {
|
| 452 |
+
'conflict_indicators': self._detect_conflict_indicators(text),
|
| 453 |
+
'risk_level': self._assess_risk_level(text),
|
| 454 |
+
'mentioned_countries': processed['entities'].get('countries', []),
|
| 455 |
+
'key_topics': [kw[0] for kw in processed['keywords'][:5]],
|
| 456 |
+
'document_type': self._classify_document_type(text)
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
return {**processed, 'intelligence': indicators}
|

    def _detect_conflict_indicators(self, text: str) -> List[str]:
        """Detect conflict-related keywords."""
        conflict_keywords = ['war', 'conflict', 'military', 'attack', 'invasion',
                             'sanctions', 'escalation', 'tension', 'threat', 'crisis']

        detected = [kw for kw in conflict_keywords if kw in text]
        return detected

    def _assess_risk_level(self, text: str) -> str:
        """Simple risk level assessment."""
        high_risk_terms = ['imminent', 'urgent', 'critical', 'severe', 'escalating']
        medium_risk_terms = ['concern', 'monitoring', 'potential', 'emerging']

        high_count = sum(1 for term in high_risk_terms if term in text)
        medium_count = sum(1 for term in medium_risk_terms if term in text)

        if high_count > 2:
            return 'HIGH'
        elif medium_count > 2:
            return 'MEDIUM'
        else:
            return 'LOW'

    def _classify_document_type(self, text: str) -> str:
        """Classify document type."""
        if 'intelligence report' in text or 'classified' in text:
            return 'Intelligence Report'
        elif 'analysis' in text or 'assessment' in text:
            return 'Analysis'
        elif 'briefing' in text:
            return 'Briefing'
        else:
            return 'General Document'
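The two entry points above (process_document/batch_process and extract_intelligence) are easiest to see in a short usage sketch. The snippet below is illustrative only: the file paths are hypothetical and the risk labels come from the simple keyword heuristics defined above.

from geobot.data_ingestion.pdf_reader import PDFProcessor

processor = PDFProcessor()

# Single document: full text plus the heuristic intelligence summary
report = processor.extract_intelligence("reports/sample_briefing.pdf")   # hypothetical path
print(report['intelligence']['risk_level'])           # 'HIGH', 'MEDIUM' or 'LOW'
print(report['intelligence']['mentioned_countries'])  # countries matched in the text
print(report['intelligence']['key_topics'])           # top frequency-based keywords

# Whole directory: one processed dict per successfully parsed PDF
for doc in processor.batch_process("reports/", pattern="*.pdf"):
    print(doc['file_name'], doc['num_pages'], doc['word_count'])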
geobot/data_ingestion/web_scraper.py
ADDED
|
@@ -0,0 +1,365 @@
| 1 |
+
"""
|
| 2 |
+
Web Scraping and Article Extraction Module
|
| 3 |
+
|
| 4 |
+
Comprehensive web scraping capabilities for:
|
| 5 |
+
- News articles
|
| 6 |
+
- Analysis pieces
|
| 7 |
+
- Intelligence reports
|
| 8 |
+
- Research papers
|
| 9 |
+
- Real-time news feeds
|
| 10 |
+
|
| 11 |
+
Supports multiple extraction methods for robustness.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
import requests
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from urllib.parse import urlparse
|
| 18 |
+
from typing import (
|
| 19 |
+
List,
|
| 20 |
+
Dict,
|
| 21 |
+
Any,
|
| 22 |
+
Tuple,
|
| 23 |
+
Optional,
|
| 24 |
+
Callable,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# -----------------------------------------------------
|
| 28 |
+
# WEB SCRAPER
|
| 29 |
+
# -----------------------------------------------------
|
| 30 |
+
|
| 31 |
+
class WebScraper:
|
| 32 |
+
"""
|
| 33 |
+
General-purpose web scraper for geopolitical content.
|
| 34 |
+
Handles various website structures and content types.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, user_agent: Optional[str] = None):
|
| 38 |
+
self.user_agent = user_agent or "GeoBotv1/1.0 (Geopolitical Analysis)"
|
| 39 |
+
self.session = requests.Session()
|
| 40 |
+
self.session.headers.update({"User-Agent": self.user_agent})
|
| 41 |
+
|
| 42 |
+
def fetch_url(self, url: str, timeout: int = 30) -> Dict[str, Any]:
|
| 43 |
+
"""Fetch raw HTML from a URL."""
|
| 44 |
+
try:
|
| 45 |
+
response = self.session.get(url, timeout=timeout)
|
| 46 |
+
response.raise_for_status()
|
| 47 |
+
return {
|
| 48 |
+
"url": url,
|
| 49 |
+
"status_code": response.status_code,
|
| 50 |
+
"content": response.text,
|
| 51 |
+
"headers": dict(response.headers),
|
| 52 |
+
"encoding": response.encoding,
|
| 53 |
+
"timestamp": datetime.now().isoformat(),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
except requests.RequestException as e:
|
| 57 |
+
return {
|
| 58 |
+
"url": url,
|
| 59 |
+
"error": str(e),
|
| 60 |
+
"status_code": None,
|
| 61 |
+
"content": None,
|
| 62 |
+
"timestamp": datetime.now().isoformat(),
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
def parse_html(self, html_content: str) -> Dict[str, Any]:
|
| 66 |
+
"""Parse HTML using BeautifulSoup if available."""
|
| 67 |
+
try:
|
| 68 |
+
from bs4 import BeautifulSoup
|
| 69 |
+
|
| 70 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
| 71 |
+
|
| 72 |
+
parsed = {
|
| 73 |
+
"title": soup.title.string if soup.title else "",
|
| 74 |
+
"text": soup.get_text(),
|
| 75 |
+
"links": [a.get("href") for a in soup.find_all("a", href=True)],
|
| 76 |
+
"images": [img.get("src") for img in soup.find_all("img", src=True)],
|
| 77 |
+
"meta": {},
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
for meta in soup.find_all("meta"):
|
| 81 |
+
name = meta.get("name") or meta.get("property")
|
| 82 |
+
content = meta.get("content")
|
| 83 |
+
if name and content:
|
| 84 |
+
parsed["meta"][name] = content
|
| 85 |
+
|
| 86 |
+
return parsed
|
| 87 |
+
|
| 88 |
+
except ImportError:
|
| 89 |
+
# Fallback if soup is missing
|
| 90 |
+
return {
|
| 91 |
+
"title": "",
|
| 92 |
+
"text": self._simple_html_strip(html_content),
|
| 93 |
+
"links": [],
|
| 94 |
+
"images": [],
|
| 95 |
+
"meta": {},
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
def _simple_html_strip(self, html: str) -> str:
|
| 99 |
+
"""Simple fallback for removing HTML tags."""
|
| 100 |
+
return re.sub(r"<[^>]+>", "", html)
|
| 101 |
+
|
| 102 |
+
def scrape_url(self, url: str) -> Dict[str, Any]:
|
| 103 |
+
"""Fetch + parse a URL."""
|
| 104 |
+
response = self.fetch_url(url)
|
| 105 |
+
if response.get("error"):
|
| 106 |
+
return response
|
| 107 |
+
|
| 108 |
+
parsed = self.parse_html(response["content"])
|
| 109 |
+
|
| 110 |
+
return {
|
| 111 |
+
"url": url,
|
| 112 |
+
"domain": urlparse(url).netloc,
|
| 113 |
+
"title": parsed["title"],
|
| 114 |
+
"text": parsed["text"],
|
| 115 |
+
"meta": parsed["meta"],
|
| 116 |
+
"links": parsed["links"],
|
| 117 |
+
"images": parsed["images"],
|
| 118 |
+
"timestamp": response["timestamp"],
|
| 119 |
+
"status_code": response["status_code"],
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# -----------------------------------------------------
|
| 124 |
+
# ARTICLE EXTRACTION
|
| 125 |
+
# -----------------------------------------------------
|
| 126 |
+
|
| 127 |
+
class ArticleExtractor:
|
| 128 |
+
"""
|
| 129 |
+
Wrapper for newspaper3k / trafilatura / fallback extraction.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(self, method: str = "auto"):
|
| 133 |
+
self.method = method
|
| 134 |
+
self._check_dependencies()
|
| 135 |
+
|
| 136 |
+
def _check_dependencies(self) -> None:
|
| 137 |
+
self.has_newspaper = False
|
| 138 |
+
self.has_trafilatura = False
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
import newspaper # noqa
|
| 142 |
+
self.has_newspaper = True
|
| 143 |
+
except ImportError:
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
import trafilatura # noqa
|
| 148 |
+
self.has_trafilatura = True
|
| 149 |
+
except ImportError:
|
| 150 |
+
pass
|
| 151 |
+
|
| 152 |
+
def extract_article(self, url: str) -> Dict[str, Any]:
|
| 153 |
+
"""Choose extraction method based on dependencies."""
|
| 154 |
+
method = self.method
|
| 155 |
+
|
| 156 |
+
if method == "auto":
|
| 157 |
+
if self.has_newspaper:
|
| 158 |
+
method = "newspaper"
|
| 159 |
+
elif self.has_trafilatura:
|
| 160 |
+
method = "trafilatura"
|
| 161 |
+
else:
|
| 162 |
+
method = "basic"
|
| 163 |
+
|
| 164 |
+
if method == "newspaper":
|
| 165 |
+
return self._extract_with_newspaper(url)
|
| 166 |
+
elif method == "trafilatura":
|
| 167 |
+
return self._extract_with_trafilatura(url)
|
| 168 |
+
else:
|
| 169 |
+
return self._extract_basic(url)
|
| 170 |
+
|
| 171 |
+
def _extract_with_newspaper(self, url: str) -> Dict[str, Any]:
|
| 172 |
+
"""Extract article using newspaper3k."""
|
| 173 |
+
try:
|
| 174 |
+
from newspaper import Article
|
| 175 |
+
|
| 176 |
+
article = Article(url)
|
| 177 |
+
article.download()
|
| 178 |
+
article.parse()
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
article.nlp()
|
| 182 |
+
keywords = article.keywords
|
| 183 |
+
summary = article.summary
|
| 184 |
+
except Exception:
|
| 185 |
+
keywords = []
|
| 186 |
+
summary = ""
|
| 187 |
+
|
| 188 |
+
return {
|
| 189 |
+
"url": url,
|
| 190 |
+
"title": article.title,
|
| 191 |
+
"text": article.text,
|
| 192 |
+
"authors": article.authors,
|
| 193 |
+
"publish_date": article.publish_date.isoformat() if article.publish_date else None,
|
| 194 |
+
"keywords": keywords,
|
| 195 |
+
"summary": summary,
|
| 196 |
+
"top_image": article.top_image,
|
| 197 |
+
"images": list(article.images),
|
| 198 |
+
"method": "newspaper",
|
| 199 |
+
"timestamp": datetime.now().isoformat(),
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
except Exception as e:
|
| 203 |
+
return {"url": url, "error": str(e), "method": "newspaper"}
|
| 204 |
+
|
| 205 |
+
def _extract_with_trafilatura(self, url: str) -> Dict[str, Any]:
|
| 206 |
+
"""Extract article using trafilatura."""
|
| 207 |
+
try:
|
| 208 |
+
import trafilatura
|
| 209 |
+
|
| 210 |
+
downloaded = trafilatura.fetch_url(url)
|
| 211 |
+
text = trafilatura.extract(downloaded)
|
| 212 |
+
metadata = trafilatura.extract_metadata(downloaded)
|
| 213 |
+
|
| 214 |
+
return {
|
| 215 |
+
"url": url,
|
| 216 |
+
"title": metadata.title if metadata else "",
|
| 217 |
+
"text": text or "",
|
| 218 |
+
"authors": [metadata.author] if metadata and metadata.author else [],
|
| 219 |
+
"publish_date": metadata.date if metadata else None,
|
| 220 |
+
"description": metadata.description if metadata else "",
|
| 221 |
+
"method": "trafilatura",
|
| 222 |
+
"timestamp": datetime.now().isoformat(),
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
return {"url": url, "error": str(e), "method": "trafilatura"}
|
| 227 |
+
|
| 228 |
+
def _extract_basic(self, url: str) -> Dict[str, Any]:
|
| 229 |
+
scraper = WebScraper()
|
| 230 |
+
content = scraper.scrape_url(url)
|
| 231 |
+
|
| 232 |
+
return {
|
| 233 |
+
"url": url,
|
| 234 |
+
"title": content.get("title", ""),
|
| 235 |
+
"text": content.get("text", ""),
|
| 236 |
+
"meta": content.get("meta", {}),
|
| 237 |
+
"method": "basic",
|
| 238 |
+
"timestamp": datetime.now().isoformat(),
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
def batch_extract(self, urls: List[str]) -> List[Dict[str, Any]]:
|
| 242 |
+
articles = []
|
| 243 |
+
for url in urls:
|
| 244 |
+
try:
|
| 245 |
+
articles.append(self.extract_article(url))
|
| 246 |
+
except Exception as e:
|
| 247 |
+
print(f"Error extracting {url}: {e}")
|
| 248 |
+
return articles
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# -----------------------------------------------------
|
| 252 |
+
# NEWS AGGREGATOR
|
| 253 |
+
# -----------------------------------------------------
|
| 254 |
+
|
| 255 |
+
class NewsAggregator:
|
| 256 |
+
"""Aggregate RSS feeds + websites into normalized article objects."""
|
| 257 |
+
|
| 258 |
+
def __init__(self):
|
| 259 |
+
self.extractor = ArticleExtractor()
|
| 260 |
+
self.sources: List[Dict[str, Any]] = []
|
| 261 |
+
|
| 262 |
+
def add_source(self, name: str, url: str, source_type: str = "rss") -> None:
|
| 263 |
+
self.sources.append({"name": name, "url": url, "type": source_type})
|
| 264 |
+
|
| 265 |
+
def fetch_news(self, keywords: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
| 266 |
+
articles = []
|
| 267 |
+
|
| 268 |
+
for source in self.sources:
|
| 269 |
+
try:
|
| 270 |
+
if source["type"] == "rss":
|
| 271 |
+
pulled = self._fetch_rss(source["url"])
|
| 272 |
+
else:
|
| 273 |
+
pulled = self._fetch_website(source["url"])
|
| 274 |
+
|
| 275 |
+
for a in pulled:
|
| 276 |
+
a["source"] = source["name"]
|
| 277 |
+
if keywords:
|
| 278 |
+
txt = a.get("text", "").lower()
|
| 279 |
+
if any(kw.lower() in txt for kw in keywords):
|
| 280 |
+
articles.append(a)
|
| 281 |
+
else:
|
| 282 |
+
articles.append(a)
|
| 283 |
+
|
| 284 |
+
except Exception as e:
|
| 285 |
+
print(f"Error fetching from {source['name']}: {e}")
|
| 286 |
+
|
| 287 |
+
return articles
|
| 288 |
+
|
| 289 |
+
def _fetch_rss(self, rss_url: str) -> List[Dict[str, Any]]:
|
| 290 |
+
try:
|
| 291 |
+
import feedparser
|
| 292 |
+
|
| 293 |
+
feed = feedparser.parse(rss_url)
|
| 294 |
+
articles = []
|
| 295 |
+
|
| 296 |
+
for entry in feed.entries:
|
| 297 |
+
base = {
|
| 298 |
+
"title": entry.get("title", ""),
|
| 299 |
+
"url": entry.get("link", ""),
|
| 300 |
+
"summary": entry.get("summary", ""),
|
| 301 |
+
"publish_date": entry.get("published", ""),
|
| 302 |
+
"authors": [a.get("name") for a in entry.get("authors", [])],
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
if base["url"]:
|
| 306 |
+
try:
|
| 307 |
+
full = self.extractor.extract_article(base["url"])
|
| 308 |
+
base["text"] = full.get("text", base["summary"])
|
| 309 |
+
except Exception:
|
| 310 |
+
base["text"] = base["summary"]
|
| 311 |
+
|
| 312 |
+
articles.append(base)
|
| 313 |
+
|
| 314 |
+
return articles
|
| 315 |
+
|
| 316 |
+
except ImportError:
|
| 317 |
+
print("feedparser not installed: pip install feedparser")
|
| 318 |
+
return []
|
| 319 |
+
|
| 320 |
+
def _fetch_website(self, url: str) -> List[Dict[str, Any]]:
|
| 321 |
+
article = self.extractor.extract_article(url)
|
| 322 |
+
return [article] if not article.get("error") else []
|
| 323 |
+
|
| 324 |
+
def monitor_sources(
|
| 325 |
+
self,
|
| 326 |
+
keywords: List[str],
|
| 327 |
+
callback: Optional[Callable[[List[Dict[str, Any]]], None]] = None,
|
| 328 |
+
interval: int = 3600,
|
| 329 |
+
) -> None:
|
| 330 |
+
"""Continuously monitor sources for new articles."""
|
| 331 |
+
import time
|
| 332 |
+
|
| 333 |
+
seen: set = set()
|
| 334 |
+
|
| 335 |
+
while True:
|
| 336 |
+
articles = self.fetch_news(keywords)
|
| 337 |
+
new_articles = [a for a in articles if a["url"] not in seen]
|
| 338 |
+
|
| 339 |
+
if new_articles and callback:
|
| 340 |
+
callback(new_articles)
|
| 341 |
+
|
| 342 |
+
seen.update(a["url"] for a in new_articles)
|
| 343 |
+
|
| 344 |
+
time.sleep(interval)
|
| 345 |
+
|
| 346 |
+
def get_trending_topics(
|
| 347 |
+
self,
|
| 348 |
+
articles: List[Dict[str, Any]],
|
| 349 |
+
n_topics: int = 10
|
| 350 |
+
) -> List[Tuple[str, int]]:
|
| 351 |
+
"""Compute most common keywords."""
|
| 352 |
+
from collections import Counter
|
| 353 |
+
|
| 354 |
+
words = []
|
| 355 |
+
stop = {
|
| 356 |
+
"the", "a", "an", "and", "or", "but", "in", "on", "at",
|
| 357 |
+
"to", "for", "of", "with", "by", "from"
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
for art in articles:
|
| 361 |
+
text = (art.get("text", "") + " " + art.get("title", "")).lower()
|
| 362 |
+
ws = [w for w in text.split() if len(w) > 3 and w not in stop]
|
| 363 |
+
words.extend(ws)
|
| 364 |
+
|
| 365 |
+
return Counter(words).most_common(n_topics)
|
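A minimal usage sketch for the extraction and aggregation classes above; the URLs, feed, and keywords are placeholders, and extraction quality depends on which optional backends (newspaper3k, trafilatura, feedparser, BeautifulSoup) happen to be installed.

from geobot.data_ingestion.web_scraper import ArticleExtractor, NewsAggregator

# One-off article extraction; falls back to basic scraping when optional backends are missing
extractor = ArticleExtractor(method="auto")
article = extractor.extract_article("https://example.com/analysis-piece")   # placeholder URL
print(article.get("method"), article.get("title"))

# Aggregate configured sources and filter by keywords
aggregator = NewsAggregator()
aggregator.add_source("Example Feed", "https://example.com/rss.xml", source_type="rss")
articles = aggregator.fetch_news(keywords=["sanctions", "escalation"])
print(aggregator.get_trending_topics(articles, n_topics=5))

# Continuous monitoring (blocking loop; run it in a background worker)
# aggregator.monitor_sources(["mobilization"], callback=lambda new: print(len(new), "new"), interval=1800)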
geobot/inference/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
"""
Inference engines for GeoBotv1
"""

from .do_calculus import DoCalculus, InterventionSimulator
from .bayesian_engine import BayesianEngine, BeliefUpdater
from .particle_filter import SequentialMonteCarlo, AuxiliaryParticleFilter, RaoBlackwellizedParticleFilter
from .variational_inference import VariationalInference, MeanFieldVI, ADVI

__all__ = [
    "DoCalculus",
    "InterventionSimulator",
    "BayesianEngine",
    "BeliefUpdater",
    "SequentialMonteCarlo",
    "AuxiliaryParticleFilter",
    "RaoBlackwellizedParticleFilter",
    "VariationalInference",
    "MeanFieldVI",
    "ADVI",
]
geobot/inference/bayesian_engine.py
ADDED
|
@@ -0,0 +1,606 @@
| 1 |
+
"""
|
| 2 |
+
Bayesian Inference Engine
|
| 3 |
+
|
| 4 |
+
Provides principled way to update beliefs as new intelligence, rumors,
|
| 5 |
+
events, or data arrive.
|
| 6 |
+
|
| 7 |
+
Components:
|
| 8 |
+
- Priors: Baseline beliefs
|
| 9 |
+
- Likelihood: Evidence
|
| 10 |
+
- Posteriors: Updated beliefs
|
| 11 |
+
|
| 12 |
+
Necessary for:
|
| 13 |
+
- Real-time updates
|
| 14 |
+
- Intelligence feeds
|
| 15 |
+
- Event-driven recalibration
|
| 16 |
+
- Uncertainty tracking
|
| 17 |
+
|
| 18 |
+
Monte Carlo + Bayesian updating = elite forecasting
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from typing import Dict, List, Optional, Callable, Any, Tuple
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from scipy import stats
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class Prior:
|
| 30 |
+
"""
|
| 31 |
+
Represents a prior distribution.
|
| 32 |
+
|
| 33 |
+
Attributes
|
| 34 |
+
----------
|
| 35 |
+
name : str
|
| 36 |
+
Name of the variable
|
| 37 |
+
distribution : Any
|
| 38 |
+
Prior distribution (scipy.stats distribution)
|
| 39 |
+
parameters : dict
|
| 40 |
+
Distribution parameters
|
| 41 |
+
"""
|
| 42 |
+
name: str
|
| 43 |
+
distribution: Any
|
| 44 |
+
parameters: Dict[str, float]
|
| 45 |
+
|
| 46 |
+
def sample(self, n_samples: int = 1) -> np.ndarray:
|
| 47 |
+
"""
|
| 48 |
+
Sample from prior.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
n_samples : int
|
| 53 |
+
Number of samples
|
| 54 |
+
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
np.ndarray
|
| 58 |
+
Samples from prior
|
| 59 |
+
"""
|
| 60 |
+
return self.distribution.rvs(size=n_samples, **self.parameters)
|
| 61 |
+
|
| 62 |
+
def pdf(self, x: np.ndarray) -> np.ndarray:
|
| 63 |
+
"""
|
| 64 |
+
Evaluate prior probability density.
|
| 65 |
+
|
| 66 |
+
Parameters
|
| 67 |
+
----------
|
| 68 |
+
x : np.ndarray
|
| 69 |
+
Points to evaluate
|
| 70 |
+
|
| 71 |
+
Returns
|
| 72 |
+
-------
|
| 73 |
+
np.ndarray
|
| 74 |
+
Probability densities
|
| 75 |
+
"""
|
| 76 |
+
return self.distribution.pdf(x, **self.parameters)
|
| 77 |
+
|
| 78 |
+
def log_pdf(self, x: np.ndarray) -> np.ndarray:
|
| 79 |
+
"""
|
| 80 |
+
Evaluate log prior probability density.
|
| 81 |
+
|
| 82 |
+
Parameters
|
| 83 |
+
----------
|
| 84 |
+
x : np.ndarray
|
| 85 |
+
Points to evaluate
|
| 86 |
+
|
| 87 |
+
Returns
|
| 88 |
+
-------
|
| 89 |
+
np.ndarray
|
| 90 |
+
Log probability densities
|
| 91 |
+
"""
|
| 92 |
+
return self.distribution.logpdf(x, **self.parameters)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class Evidence:
|
| 97 |
+
"""
|
| 98 |
+
Represents evidence/observation.
|
| 99 |
+
|
| 100 |
+
Attributes
|
| 101 |
+
----------
|
| 102 |
+
observation : Any
|
| 103 |
+
Observed data
|
| 104 |
+
likelihood_fn : Callable
|
| 105 |
+
Likelihood function
|
| 106 |
+
timestamp : float
|
| 107 |
+
Time of observation
|
| 108 |
+
confidence : float
|
| 109 |
+
Confidence in observation (0-1)
|
| 110 |
+
"""
|
| 111 |
+
observation: Any
|
| 112 |
+
likelihood_fn: Callable
|
| 113 |
+
timestamp: float
|
| 114 |
+
confidence: float = 1.0
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class BayesianEngine:
|
| 118 |
+
"""
|
| 119 |
+
Bayesian inference engine for belief updating.
|
| 120 |
+
|
| 121 |
+
This engine maintains and updates probability distributions
|
| 122 |
+
as new evidence arrives, enabling real-time forecasting with
|
| 123 |
+
uncertainty quantification.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self):
|
| 127 |
+
"""Initialize Bayesian engine."""
|
| 128 |
+
self.priors: Dict[str, Prior] = {}
|
| 129 |
+
self.posteriors: Dict[str, np.ndarray] = {}
|
| 130 |
+
self.evidence_history: List[Evidence] = []
|
| 131 |
+
|
| 132 |
+
def set_prior(self, prior: Prior) -> None:
|
| 133 |
+
"""
|
| 134 |
+
Set prior distribution for a variable.
|
| 135 |
+
|
| 136 |
+
Parameters
|
| 137 |
+
----------
|
| 138 |
+
prior : Prior
|
| 139 |
+
Prior distribution
|
| 140 |
+
"""
|
| 141 |
+
self.priors[prior.name] = prior
|
| 142 |
+
|
| 143 |
+
def update(
|
| 144 |
+
self,
|
| 145 |
+
variable: str,
|
| 146 |
+
evidence: Evidence,
|
| 147 |
+
method: str = 'grid'
|
| 148 |
+
) -> np.ndarray:
|
| 149 |
+
"""
|
| 150 |
+
Update beliefs given evidence.
|
| 151 |
+
|
| 152 |
+
Parameters
|
| 153 |
+
----------
|
| 154 |
+
variable : str
|
| 155 |
+
Variable to update
|
| 156 |
+
evidence : Evidence
|
| 157 |
+
New evidence
|
| 158 |
+
method : str
|
| 159 |
+
Update method ('grid', 'mcmc', 'analytical')
|
| 160 |
+
|
| 161 |
+
Returns
|
| 162 |
+
-------
|
| 163 |
+
np.ndarray
|
| 164 |
+
Posterior samples
|
| 165 |
+
"""
|
| 166 |
+
if variable not in self.priors:
|
| 167 |
+
raise ValueError(f"No prior set for {variable}")
|
| 168 |
+
|
| 169 |
+
self.evidence_history.append(evidence)
|
| 170 |
+
|
| 171 |
+
if method == 'grid':
|
| 172 |
+
posterior = self._grid_update(variable, evidence)
|
| 173 |
+
elif method == 'mcmc':
|
| 174 |
+
posterior = self._mcmc_update(variable, evidence)
|
| 175 |
+
elif method == 'analytical':
|
| 176 |
+
posterior = self._analytical_update(variable, evidence)
|
| 177 |
+
else:
|
| 178 |
+
raise ValueError(f"Unknown method: {method}")
|
| 179 |
+
|
| 180 |
+
self.posteriors[variable] = posterior
|
| 181 |
+
return posterior
|
| 182 |
+
|
| 183 |
+
def _grid_update(self, variable: str, evidence: Evidence) -> np.ndarray:
|
| 184 |
+
"""
|
| 185 |
+
Grid approximation for Bayesian update.
|
| 186 |
+
|
| 187 |
+
Parameters
|
| 188 |
+
----------
|
| 189 |
+
variable : str
|
| 190 |
+
Variable name
|
| 191 |
+
evidence : Evidence
|
| 192 |
+
Evidence
|
| 193 |
+
|
| 194 |
+
Returns
|
| 195 |
+
-------
|
| 196 |
+
np.ndarray
|
| 197 |
+
Posterior samples
|
| 198 |
+
"""
|
| 199 |
+
prior = self.priors[variable]
|
| 200 |
+
|
| 201 |
+
# Create grid
|
| 202 |
+
n_grid = 1000
|
| 203 |
+
if hasattr(prior.distribution, 'support'):
|
| 204 |
+
support = prior.distribution.support()
|
| 205 |
+
grid = np.linspace(support[0], support[1], n_grid)
|
| 206 |
+
else:
|
| 207 |
+
# Default grid
|
| 208 |
+
mean = prior.parameters.get('loc', 0)
|
| 209 |
+
std = prior.parameters.get('scale', 1)
|
| 210 |
+
grid = np.linspace(mean - 4*std, mean + 4*std, n_grid)
|
| 211 |
+
|
| 212 |
+
# Compute prior * likelihood
|
| 213 |
+
prior_vals = prior.pdf(grid)
|
| 214 |
+
likelihood_vals = evidence.likelihood_fn(grid, evidence.observation)
|
| 215 |
+
|
| 216 |
+
# Weight by evidence confidence
|
| 217 |
+
likelihood_vals = likelihood_vals ** evidence.confidence
|
| 218 |
+
|
| 219 |
+
# Compute posterior (unnormalized)
|
| 220 |
+
posterior_vals = prior_vals * likelihood_vals
|
| 221 |
+
|
| 222 |
+
# Normalize
|
| 223 |
+
posterior_vals /= posterior_vals.sum()
|
| 224 |
+
|
| 225 |
+
# Sample from posterior
|
| 226 |
+
n_samples = 10000
|
| 227 |
+
posterior_samples = np.random.choice(grid, size=n_samples, p=posterior_vals)
|
| 228 |
+
|
| 229 |
+
return posterior_samples
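# Stand-alone numeric illustration of the grid approximation used above.
# The prior, observation, and noise scale are chosen arbitrarily for the example.
import numpy as np
from scipy import stats

grid = np.linspace(-4, 4, 1000)
prior_vals = stats.norm.pdf(grid, loc=0.0, scale=1.0)        # N(0, 1) prior
likelihood_vals = stats.norm.pdf(grid, loc=1.2, scale=0.5)   # observation of 1.2 with sd 0.5
posterior_vals = prior_vals * likelihood_vals
posterior_vals /= posterior_vals.sum()

samples = np.random.choice(grid, size=10_000, p=posterior_vals)
print(samples.mean())   # pulled from the prior mean toward the observation, roughly 0.96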
|
| 230 |
+
|
| 231 |
+
def _mcmc_update(
|
| 232 |
+
self,
|
| 233 |
+
variable: str,
|
| 234 |
+
evidence: Evidence,
|
| 235 |
+
n_samples: int = 10000
|
| 236 |
+
) -> np.ndarray:
|
| 237 |
+
"""
|
| 238 |
+
MCMC-based Bayesian update.
|
| 239 |
+
|
| 240 |
+
Parameters
|
| 241 |
+
----------
|
| 242 |
+
variable : str
|
| 243 |
+
Variable name
|
| 244 |
+
evidence : Evidence
|
| 245 |
+
Evidence
|
| 246 |
+
n_samples : int
|
| 247 |
+
Number of MCMC samples
|
| 248 |
+
|
| 249 |
+
Returns
|
| 250 |
+
-------
|
| 251 |
+
np.ndarray
|
| 252 |
+
Posterior samples
|
| 253 |
+
"""
|
| 254 |
+
prior = self.priors[variable]
|
| 255 |
+
|
| 256 |
+
def log_posterior(x):
|
| 257 |
+
log_prior = prior.log_pdf(np.array([x]))[0]
|
| 258 |
+
log_likelihood = np.log(evidence.likelihood_fn(np.array([x]), evidence.observation)[0] + 1e-10)
|
| 259 |
+
return log_prior + evidence.confidence * log_likelihood
|
| 260 |
+
|
| 261 |
+
# Simple Metropolis-Hastings
|
| 262 |
+
samples = []
|
| 263 |
+
current = prior.sample(1)[0]
|
| 264 |
+
current_log_p = log_posterior(current)
|
| 265 |
+
|
| 266 |
+
for _ in range(n_samples):
|
| 267 |
+
# Propose
|
| 268 |
+
proposal = current + np.random.normal(0, 0.1)
|
| 269 |
+
proposal_log_p = log_posterior(proposal)
|
| 270 |
+
|
| 271 |
+
# Accept/reject
|
| 272 |
+
log_alpha = proposal_log_p - current_log_p
|
| 273 |
+
if np.log(np.random.uniform()) < log_alpha:
|
| 274 |
+
current = proposal
|
| 275 |
+
current_log_p = proposal_log_p
|
| 276 |
+
|
| 277 |
+
samples.append(current)
|
| 278 |
+
|
| 279 |
+
return np.array(samples)
|
| 280 |
+
|
| 281 |
+
def _analytical_update(self, variable: str, evidence: Evidence) -> np.ndarray:
|
| 282 |
+
"""
|
| 283 |
+
Analytical Bayesian update (for conjugate priors).
|
| 284 |
+
|
| 285 |
+
Parameters
|
| 286 |
+
----------
|
| 287 |
+
variable : str
|
| 288 |
+
Variable name
|
| 289 |
+
evidence : Evidence
|
| 290 |
+
Evidence
|
| 291 |
+
|
| 292 |
+
Returns
|
| 293 |
+
-------
|
| 294 |
+
np.ndarray
|
| 295 |
+
Posterior samples
|
| 296 |
+
"""
|
| 297 |
+
# Placeholder - would implement conjugate updates
|
| 298 |
+
# For now, fall back to grid
|
| 299 |
+
return self._grid_update(variable, evidence)
|
| 300 |
+
|
| 301 |
+
def get_posterior_summary(self, variable: str) -> Dict[str, float]:
|
| 302 |
+
"""
|
| 303 |
+
Get summary statistics of posterior.
|
| 304 |
+
|
| 305 |
+
Parameters
|
| 306 |
+
----------
|
| 307 |
+
variable : str
|
| 308 |
+
Variable name
|
| 309 |
+
|
| 310 |
+
Returns
|
| 311 |
+
-------
|
| 312 |
+
dict
|
| 313 |
+
Summary statistics
|
| 314 |
+
"""
|
| 315 |
+
if variable not in self.posteriors:
|
| 316 |
+
raise ValueError(f"No posterior for {variable}")
|
| 317 |
+
|
| 318 |
+
samples = self.posteriors[variable]
|
| 319 |
+
|
| 320 |
+
return {
|
| 321 |
+
'mean': np.mean(samples),
|
| 322 |
+
'median': np.median(samples),
|
| 323 |
+
'std': np.std(samples),
|
| 324 |
+
'q5': np.percentile(samples, 5),
|
| 325 |
+
'q25': np.percentile(samples, 25),
|
| 326 |
+
'q75': np.percentile(samples, 75),
|
| 327 |
+
'q95': np.percentile(samples, 95)
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
def get_credible_interval(
|
| 331 |
+
self,
|
| 332 |
+
variable: str,
|
| 333 |
+
alpha: float = 0.05
|
| 334 |
+
) -> Tuple[float, float]:
|
| 335 |
+
"""
|
| 336 |
+
Get credible interval for posterior.
|
| 337 |
+
|
| 338 |
+
Parameters
|
| 339 |
+
----------
|
| 340 |
+
variable : str
|
| 341 |
+
Variable name
|
| 342 |
+
alpha : float
|
| 343 |
+
Significance level
|
| 344 |
+
|
| 345 |
+
Returns
|
| 346 |
+
-------
|
| 347 |
+
tuple
|
| 348 |
+
(lower, upper) bounds of credible interval
|
| 349 |
+
"""
|
| 350 |
+
if variable not in self.posteriors:
|
| 351 |
+
raise ValueError(f"No posterior for {variable}")
|
| 352 |
+
|
| 353 |
+
samples = self.posteriors[variable]
|
| 354 |
+
lower = np.percentile(samples, 100 * alpha / 2)
|
| 355 |
+
upper = np.percentile(samples, 100 * (1 - alpha / 2))
|
| 356 |
+
|
| 357 |
+
return lower, upper
|
| 358 |
+
|
| 359 |
+
def compute_bayes_factor(
|
| 360 |
+
self,
|
| 361 |
+
variable: str,
|
| 362 |
+
hypothesis1: Callable,
|
| 363 |
+
hypothesis2: Callable
|
| 364 |
+
) -> float:
|
| 365 |
+
"""
|
| 366 |
+
Compute Bayes factor for two hypotheses.
|
| 367 |
+
|
| 368 |
+
Parameters
|
| 369 |
+
----------
|
| 370 |
+
variable : str
|
| 371 |
+
Variable name
|
| 372 |
+
hypothesis1 : callable
|
| 373 |
+
First hypothesis (returns bool)
|
| 374 |
+
hypothesis2 : callable
|
| 375 |
+
Second hypothesis (returns bool)
|
| 376 |
+
|
| 377 |
+
Returns
|
| 378 |
+
-------
|
| 379 |
+
float
|
| 380 |
+
Bayes factor (BF > 1 favors hypothesis1)
|
| 381 |
+
"""
|
| 382 |
+
if variable not in self.posteriors:
|
| 383 |
+
raise ValueError(f"No posterior for {variable}")
|
| 384 |
+
|
| 385 |
+
samples = self.posteriors[variable]
|
| 386 |
+
|
| 387 |
+
p1 = np.mean([hypothesis1(x) for x in samples])
|
| 388 |
+
p2 = np.mean([hypothesis2(x) for x in samples])
|
| 389 |
+
|
| 390 |
+
if p2 == 0:
|
| 391 |
+
return np.inf
|
| 392 |
+
return p1 / p2
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class BeliefUpdater:
|
| 396 |
+
"""
|
| 397 |
+
High-level interface for updating geopolitical beliefs.
|
| 398 |
+
|
| 399 |
+
This class provides domain-specific methods for updating
|
| 400 |
+
beliefs based on intelligence, events, and rumors.
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
def __init__(self):
|
| 404 |
+
"""Initialize belief updater."""
|
| 405 |
+
self.engine = BayesianEngine()
|
| 406 |
+
self.beliefs: Dict[str, Dict[str, Any]] = {}
|
| 407 |
+
|
| 408 |
+
def initialize_belief(
|
| 409 |
+
self,
|
| 410 |
+
name: str,
|
| 411 |
+
prior_mean: float,
|
| 412 |
+
prior_std: float,
|
| 413 |
+
belief_type: str = 'continuous'
|
| 414 |
+
) -> None:
|
| 415 |
+
"""
|
| 416 |
+
Initialize a belief with prior.
|
| 417 |
+
|
| 418 |
+
Parameters
|
| 419 |
+
----------
|
| 420 |
+
name : str
|
| 421 |
+
Belief name
|
| 422 |
+
prior_mean : float
|
| 423 |
+
Prior mean
|
| 424 |
+
prior_std : float
|
| 425 |
+
Prior standard deviation
|
| 426 |
+
belief_type : str
|
| 427 |
+
Type of belief ('continuous', 'probability')
|
| 428 |
+
"""
|
| 429 |
+
if belief_type == 'continuous':
|
| 430 |
+
distribution = stats.norm
|
| 431 |
+
elif belief_type == 'probability':
|
| 432 |
+
# Use beta distribution for probabilities
|
| 433 |
+
# Convert mean/std to alpha/beta parameters
|
| 434 |
+
mean = np.clip(prior_mean, 0.01, 0.99)
|
| 435 |
+
var = prior_std ** 2
|
| 436 |
+
alpha = mean * (mean * (1 - mean) / var - 1)
|
| 437 |
+
beta = (1 - mean) * (mean * (1 - mean) / var - 1)
|
| 438 |
+
distribution = stats.beta
|
| 439 |
+
prior_mean = alpha
|
| 440 |
+
prior_std = beta
|
| 441 |
+
else:
|
| 442 |
+
distribution = stats.norm
|
| 443 |
+
|
| 444 |
+
prior = Prior(
|
| 445 |
+
name=name,
|
| 446 |
+
distribution=distribution,
|
| 447 |
+
parameters={'loc': prior_mean, 'scale': prior_std}
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
self.engine.set_prior(prior)
|
| 451 |
+
self.beliefs[name] = {
|
| 452 |
+
'type': belief_type,
|
| 453 |
+
'initialized': True
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
def update_from_intelligence(
|
| 457 |
+
self,
|
| 458 |
+
belief: str,
|
| 459 |
+
observation: float,
|
| 460 |
+
reliability: float = 0.8
|
| 461 |
+
) -> Dict[str, float]:
|
| 462 |
+
"""
|
| 463 |
+
Update belief from intelligence report.
|
| 464 |
+
|
| 465 |
+
Parameters
|
| 466 |
+
----------
|
| 467 |
+
belief : str
|
| 468 |
+
Belief name
|
| 469 |
+
observation : float
|
| 470 |
+
Observed value from intelligence
|
| 471 |
+
reliability : float
|
| 472 |
+
Reliability of intelligence source (0-1)
|
| 473 |
+
|
| 474 |
+
Returns
|
| 475 |
+
-------
|
| 476 |
+
dict
|
| 477 |
+
Posterior summary
|
| 478 |
+
"""
|
| 479 |
+
def likelihood_fn(x, obs):
|
| 480 |
+
# Gaussian likelihood centered at observation
|
| 481 |
+
# Width depends on reliability
|
| 482 |
+
std = 1.0 / reliability
|
| 483 |
+
return stats.norm.pdf(x, loc=obs, scale=std)
|
| 484 |
+
|
| 485 |
+
evidence = Evidence(
|
| 486 |
+
observation=observation,
|
| 487 |
+
likelihood_fn=likelihood_fn,
|
| 488 |
+
timestamp=pd.Timestamp.now().timestamp(),
|
| 489 |
+
confidence=reliability
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
self.engine.update(belief, evidence, method='grid')
|
| 493 |
+
return self.engine.get_posterior_summary(belief)
|
| 494 |
+
|
| 495 |
+
def update_from_event(
|
| 496 |
+
self,
|
| 497 |
+
belief: str,
|
| 498 |
+
event_impact: float,
|
| 499 |
+
event_certainty: float = 1.0
|
| 500 |
+
) -> Dict[str, float]:
|
| 501 |
+
"""
|
| 502 |
+
Update belief from observed event.
|
| 503 |
+
|
| 504 |
+
Parameters
|
| 505 |
+
----------
|
| 506 |
+
belief : str
|
| 507 |
+
Belief name
|
| 508 |
+
event_impact : float
|
| 509 |
+
Impact of event (shift in belief)
|
| 510 |
+
event_certainty : float
|
| 511 |
+
Certainty that event occurred (0-1)
|
| 512 |
+
|
| 513 |
+
Returns
|
| 514 |
+
-------
|
| 515 |
+
dict
|
| 516 |
+
Posterior summary
|
| 517 |
+
"""
|
| 518 |
+
# Get current belief
|
| 519 |
+
if belief not in self.engine.posteriors:
|
| 520 |
+
# Use prior
|
| 521 |
+
current_samples = self.engine.priors[belief].sample(10000)
|
| 522 |
+
else:
|
| 523 |
+
current_samples = self.engine.posteriors[belief]
|
| 524 |
+
|
| 525 |
+
current_mean = np.mean(current_samples)
|
| 526 |
+
|
| 527 |
+
# Create shifted observation
|
| 528 |
+
observation = current_mean + event_impact
|
| 529 |
+
|
| 530 |
+
def likelihood_fn(x, obs):
|
| 531 |
+
return stats.norm.pdf(x, loc=obs, scale=abs(event_impact) * 0.1)
|
| 532 |
+
|
| 533 |
+
evidence = Evidence(
|
| 534 |
+
observation=observation,
|
| 535 |
+
likelihood_fn=likelihood_fn,
|
| 536 |
+
timestamp=pd.Timestamp.now().timestamp(),
|
| 537 |
+
confidence=event_certainty
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
self.engine.update(belief, evidence, method='grid')
|
| 541 |
+
return self.engine.get_posterior_summary(belief)
|
| 542 |
+
|
| 543 |
+
def get_belief_probability(
|
| 544 |
+
self,
|
| 545 |
+
belief: str,
|
| 546 |
+
threshold: float,
|
| 547 |
+
direction: str = 'greater'
|
| 548 |
+
) -> float:
|
| 549 |
+
"""
|
| 550 |
+
Get probability that belief exceeds threshold.
|
| 551 |
+
|
| 552 |
+
Parameters
|
| 553 |
+
----------
|
| 554 |
+
belief : str
|
| 555 |
+
Belief name
|
| 556 |
+
threshold : float
|
| 557 |
+
Threshold value
|
| 558 |
+
direction : str
|
| 559 |
+
'greater' or 'less'
|
| 560 |
+
|
| 561 |
+
Returns
|
| 562 |
+
-------
|
| 563 |
+
float
|
| 564 |
+
Probability
|
| 565 |
+
"""
|
| 566 |
+
if belief not in self.engine.posteriors:
|
| 567 |
+
samples = self.engine.priors[belief].sample(10000)
|
| 568 |
+
else:
|
| 569 |
+
samples = self.engine.posteriors[belief]
|
| 570 |
+
|
| 571 |
+
if direction == 'greater':
|
| 572 |
+
return np.mean(samples > threshold)
|
| 573 |
+
else:
|
| 574 |
+
return np.mean(samples < threshold)
|
| 575 |
+
|
| 576 |
+
def compare_beliefs(self, belief1: str, belief2: str) -> Dict[str, float]:
|
| 577 |
+
"""
|
| 578 |
+
Compare two beliefs.
|
| 579 |
+
|
| 580 |
+
Parameters
|
| 581 |
+
----------
|
| 582 |
+
belief1 : str
|
| 583 |
+
First belief
|
| 584 |
+
belief2 : str
|
| 585 |
+
Second belief
|
| 586 |
+
|
| 587 |
+
Returns
|
| 588 |
+
-------
|
| 589 |
+
dict
|
| 590 |
+
Comparison results
|
| 591 |
+
"""
|
| 592 |
+
if belief1 not in self.engine.posteriors:
|
| 593 |
+
samples1 = self.engine.priors[belief1].sample(10000)
|
| 594 |
+
else:
|
| 595 |
+
samples1 = self.engine.posteriors[belief1]
|
| 596 |
+
|
| 597 |
+
if belief2 not in self.engine.posteriors:
|
| 598 |
+
samples2 = self.engine.priors[belief2].sample(10000)
|
| 599 |
+
else:
|
| 600 |
+
samples2 = self.engine.posteriors[belief2]
|
| 601 |
+
|
| 602 |
+
return {
|
| 603 |
+
'p_belief1_greater': np.mean(samples1 > samples2),
|
| 604 |
+
'mean_difference': np.mean(samples1 - samples2),
|
| 605 |
+
'correlation': np.corrcoef(samples1, samples2)[0, 1]
|
| 606 |
+
}
|
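To ground the classes above, here is a minimal belief-updating sketch using the high-level BeliefUpdater; the prior, observation, and reliability values are purely illustrative.

from geobot.inference.bayesian_engine import BeliefUpdater

updater = BeliefUpdater()

# Gaussian prior over an arbitrary continuous "escalation index"
updater.initialize_belief("escalation_index", prior_mean=0.0, prior_std=1.0)

# A moderately reliable report suggests the index is around 1.5
summary = updater.update_from_intelligence("escalation_index", observation=1.5, reliability=0.7)
print(summary['mean'], summary['q5'], summary['q95'])

# Probability that the index now exceeds a chosen threshold
p = updater.get_belief_probability("escalation_index", threshold=1.0, direction='greater')
print(f"P(escalation_index > 1.0) = {p:.2f}")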
geobot/inference/do_calculus.py
ADDED
|
@@ -0,0 +1,510 @@
| 1 |
+
"""
|
| 2 |
+
Do-Calculus Module - Intervention Reasoning
|
| 3 |
+
|
| 4 |
+
Implements Pearl's do-calculus for counterfactual analysis and policy simulation.
|
| 5 |
+
|
| 6 |
+
Instead of just forecasting "what will happen," this module enables:
|
| 7 |
+
- "What if the U.S. sanctions X?"
|
| 8 |
+
- "What if China mobilizes?"
|
| 9 |
+
- "What if NATO deploys troops?"
|
| 10 |
+
- "What if an election is rigged?"
|
| 11 |
+
|
| 12 |
+
This is the foundation for counterfactual geopolitics.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from typing import Dict, List, Set, Optional, Tuple, Any
|
| 18 |
+
from ..models.causal_graph import CausalGraph, StructuralCausalModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DoCalculus:
|
| 22 |
+
"""
|
| 23 |
+
Implement Pearl's do-calculus for causal inference.
|
| 24 |
+
|
| 25 |
+
The do-calculus provides rules for transforming interventional
|
| 26 |
+
distributions into observational ones, enabling causal effect
|
| 27 |
+
estimation from observational data.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, causal_graph: CausalGraph):
|
| 31 |
+
"""
|
| 32 |
+
Initialize do-calculus engine.
|
| 33 |
+
|
| 34 |
+
Parameters
|
| 35 |
+
----------
|
| 36 |
+
causal_graph : CausalGraph
|
| 37 |
+
Causal graph structure
|
| 38 |
+
"""
|
| 39 |
+
self.graph = causal_graph
|
| 40 |
+
|
| 41 |
+
def is_identifiable(
|
| 42 |
+
self,
|
| 43 |
+
treatment: str,
|
| 44 |
+
outcome: str,
|
| 45 |
+
confounders: Optional[Set[str]] = None
|
| 46 |
+
) -> bool:
|
| 47 |
+
"""
|
| 48 |
+
Check if causal effect is identifiable.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
treatment : str
|
| 53 |
+
Treatment variable
|
| 54 |
+
outcome : str
|
| 55 |
+
Outcome variable
|
| 56 |
+
confounders : Set[str], optional
|
| 57 |
+
Known confounders
|
| 58 |
+
|
| 59 |
+
Returns
|
| 60 |
+
-------
|
| 61 |
+
bool
|
| 62 |
+
True if effect is identifiable
|
| 63 |
+
"""
|
| 64 |
+
# Basic check: are treatment and outcome d-separated after intervention?
|
| 65 |
+
# This is a simplified version
|
| 66 |
+
|
| 67 |
+
# Get all backdoor paths
|
| 68 |
+
backdoor_paths = self._get_backdoor_paths(treatment, outcome)
|
| 69 |
+
|
| 70 |
+
if len(backdoor_paths) == 0:
|
| 71 |
+
# No backdoor paths, effect is identifiable
|
| 72 |
+
return True
|
| 73 |
+
|
| 74 |
+
if confounders is not None:
|
| 75 |
+
# Check if confounders block all backdoor paths
|
| 76 |
+
return self._blocks_backdoor_paths(backdoor_paths, confounders)
|
| 77 |
+
|
| 78 |
+
return False
|
| 79 |
+
|
| 80 |
+
def _get_backdoor_paths(self, treatment: str, outcome: str) -> List[List[str]]:
|
| 81 |
+
"""
|
| 82 |
+
Get all backdoor paths from treatment to outcome.
|
| 83 |
+
|
| 84 |
+
A backdoor path is a path from treatment to outcome that
|
| 85 |
+
starts with an arrow into the treatment.
|
| 86 |
+
|
| 87 |
+
Parameters
|
| 88 |
+
----------
|
| 89 |
+
treatment : str
|
| 90 |
+
Treatment variable
|
| 91 |
+
outcome : str
|
| 92 |
+
Outcome variable
|
| 93 |
+
|
| 94 |
+
Returns
|
| 95 |
+
-------
|
| 96 |
+
List[List[str]]
|
| 97 |
+
List of backdoor paths
|
| 98 |
+
"""
|
| 99 |
+
import networkx as nx
|
| 100 |
+
|
| 101 |
+
backdoor_paths = []
|
| 102 |
+
|
| 103 |
+
# Get all simple paths from treatment to outcome
|
| 104 |
+
try:
|
| 105 |
+
all_paths = list(nx.all_simple_paths(
|
| 106 |
+
self.graph.graph.to_undirected(),
|
| 107 |
+
treatment,
|
| 108 |
+
outcome
|
| 109 |
+
))
|
| 110 |
+
except nx.NetworkXNoPath:
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
# Filter for backdoor paths
|
| 114 |
+
for path in all_paths:
|
| 115 |
+
if len(path) > 2: # Must have intermediate nodes
|
| 116 |
+
# Check if first edge goes into treatment
|
| 117 |
+
second_node = path[1]
|
| 118 |
+
if self.graph.graph.has_edge(second_node, treatment):
|
| 119 |
+
backdoor_paths.append(path)
|
| 120 |
+
|
| 121 |
+
return backdoor_paths
|
| 122 |
+
|
| 123 |
+
def _blocks_backdoor_paths(
|
| 124 |
+
self,
|
| 125 |
+
paths: List[List[str]],
|
| 126 |
+
conditioning_set: Set[str]
|
| 127 |
+
) -> bool:
|
| 128 |
+
"""
|
| 129 |
+
Check if conditioning set blocks all backdoor paths.
|
| 130 |
+
|
| 131 |
+
Parameters
|
| 132 |
+
----------
|
| 133 |
+
paths : List[List[str]]
|
| 134 |
+
Backdoor paths
|
| 135 |
+
conditioning_set : Set[str]
|
| 136 |
+
Variables to condition on
|
| 137 |
+
|
| 138 |
+
Returns
|
| 139 |
+
-------
|
| 140 |
+
bool
|
| 141 |
+
True if all paths are blocked
|
| 142 |
+
"""
|
| 143 |
+
for path in paths:
|
| 144 |
+
if not self._is_path_blocked(path, conditioning_set):
|
| 145 |
+
return False
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
+
def _is_path_blocked(self, path: List[str], conditioning_set: Set[str]) -> bool:
|
| 149 |
+
"""
|
| 150 |
+
Check if a path is blocked by conditioning set.
|
| 151 |
+
|
| 152 |
+
Parameters
|
| 153 |
+
----------
|
| 154 |
+
path : List[str]
|
| 155 |
+
Path to check
|
| 156 |
+
conditioning_set : Set[str]
|
| 157 |
+
Conditioning set
|
| 158 |
+
|
| 159 |
+
Returns
|
| 160 |
+
-------
|
| 161 |
+
bool
|
| 162 |
+
True if path is blocked
|
| 163 |
+
"""
|
| 164 |
+
# Simplified version: check if any non-collider in path is in conditioning set
|
| 165 |
+
for node in path[1:-1]: # Exclude endpoints
|
| 166 |
+
if node in conditioning_set:
|
| 167 |
+
# Check if it's a collider
|
| 168 |
+
idx = path.index(node)
|
| 169 |
+
prev_node = path[idx - 1]
|
| 170 |
+
next_node = path[idx + 1]
|
| 171 |
+
|
| 172 |
+
# It's a collider if both edges point to it
|
| 173 |
+
is_collider = (
|
| 174 |
+
self.graph.graph.has_edge(prev_node, node) and
|
| 175 |
+
self.graph.graph.has_edge(next_node, node)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if not is_collider:
|
| 179 |
+
return True
|
| 180 |
+
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
def find_adjustment_set(
|
| 184 |
+
self,
|
| 185 |
+
treatment: str,
|
| 186 |
+
outcome: str,
|
| 187 |
+
method: str = 'backdoor'
|
| 188 |
+
) -> Set[str]:
|
| 189 |
+
"""
|
| 190 |
+
Find valid adjustment set for identifying causal effect.
|
| 191 |
+
|
| 192 |
+
Parameters
|
| 193 |
+
----------
|
| 194 |
+
treatment : str
|
| 195 |
+
Treatment variable
|
| 196 |
+
outcome : str
|
| 197 |
+
Outcome variable
|
| 198 |
+
method : str
|
| 199 |
+
Method to use ('backdoor', 'minimal')
|
| 200 |
+
|
| 201 |
+
Returns
|
| 202 |
+
-------
|
| 203 |
+
Set[str]
|
| 204 |
+
Valid adjustment set
|
| 205 |
+
"""
|
| 206 |
+
if method == 'backdoor':
|
| 207 |
+
return self._backdoor_adjustment_set(treatment, outcome)
|
| 208 |
+
elif method == 'minimal':
|
| 209 |
+
return self._minimal_adjustment_set(treatment, outcome)
|
| 210 |
+
else:
|
| 211 |
+
raise ValueError(f"Unknown method: {method}")
|
| 212 |
+
|
| 213 |
+
def _backdoor_adjustment_set(self, treatment: str, outcome: str) -> Set[str]:
|
| 214 |
+
"""
|
| 215 |
+
Find backdoor adjustment set.
|
| 216 |
+
|
| 217 |
+
Parameters
|
| 218 |
+
----------
|
| 219 |
+
treatment : str
|
| 220 |
+
Treatment variable
|
| 221 |
+
outcome : str
|
| 222 |
+
Outcome variable
|
| 223 |
+
|
| 224 |
+
Returns
|
| 225 |
+
-------
|
| 226 |
+
Set[str]
|
| 227 |
+
Backdoor adjustment set
|
| 228 |
+
"""
|
| 229 |
+
# Get all parents of treatment (excluding outcome's descendants)
|
| 230 |
+
parents = set(self.graph.get_parents(treatment))
|
| 231 |
+
|
| 232 |
+
# Remove outcome and its descendants
|
| 233 |
+
outcome_descendants = self.graph.get_descendants(outcome)
|
| 234 |
+
adjustment_set = parents - outcome_descendants - {outcome}
|
| 235 |
+
|
| 236 |
+
return adjustment_set
|
| 237 |
+
|
| 238 |
+
def _minimal_adjustment_set(self, treatment: str, outcome: str) -> Set[str]:
|
| 239 |
+
"""
|
| 240 |
+
Find minimal adjustment set.
|
| 241 |
+
|
| 242 |
+
Parameters
|
| 243 |
+
----------
|
| 244 |
+
treatment : str
|
| 245 |
+
Treatment variable
|
| 246 |
+
outcome : str
|
| 247 |
+
Outcome variable
|
| 248 |
+
|
| 249 |
+
Returns
|
| 250 |
+
-------
|
| 251 |
+
Set[str]
|
| 252 |
+
Minimal adjustment set
|
| 253 |
+
"""
|
| 254 |
+
# Start with backdoor set
|
| 255 |
+
backdoor_set = self._backdoor_adjustment_set(treatment, outcome)
|
| 256 |
+
|
| 257 |
+
# Try removing variables one by one
|
| 258 |
+
minimal_set = backdoor_set.copy()
|
| 259 |
+
|
| 260 |
+
for var in backdoor_set:
|
| 261 |
+
candidate_set = minimal_set - {var}
|
| 262 |
+
backdoor_paths = self._get_backdoor_paths(treatment, outcome)
|
| 263 |
+
|
| 264 |
+
if self._blocks_backdoor_paths(backdoor_paths, candidate_set):
|
| 265 |
+
minimal_set = candidate_set
|
| 266 |
+
|
| 267 |
+
return minimal_set
|
| 268 |
+
|
| 269 |
+
def compute_ate(
|
| 270 |
+
self,
|
| 271 |
+
data: pd.DataFrame,
|
| 272 |
+
treatment: str,
|
| 273 |
+
outcome: str,
|
| 274 |
+
adjustment_set: Optional[Set[str]] = None
|
| 275 |
+
) -> float:
|
| 276 |
+
"""
|
| 277 |
+
Compute Average Treatment Effect (ATE).
|
| 278 |
+
|
| 279 |
+
ATE = E[Y | do(X=1)] - E[Y | do(X=0)]
|
| 280 |
+
|
| 281 |
+
Parameters
|
| 282 |
+
----------
|
| 283 |
+
data : pd.DataFrame
|
| 284 |
+
Observational data
|
| 285 |
+
treatment : str
|
| 286 |
+
Treatment variable
|
| 287 |
+
outcome : str
|
| 288 |
+
Outcome variable
|
| 289 |
+
adjustment_set : Set[str], optional
|
| 290 |
+
Variables to adjust for
|
| 291 |
+
|
| 292 |
+
Returns
|
| 293 |
+
-------
|
| 294 |
+
float
|
| 295 |
+
Average Treatment Effect
|
| 296 |
+
"""
|
| 297 |
+
if adjustment_set is None:
|
| 298 |
+
adjustment_set = self.find_adjustment_set(treatment, outcome)
|
| 299 |
+
|
| 300 |
+
# Stratification estimator
|
| 301 |
+
if len(adjustment_set) == 0:
|
| 302 |
+
# No confounding
|
| 303 |
+
treated = data[data[treatment] == 1][outcome].mean()
|
| 304 |
+
control = data[data[treatment] == 0][outcome].mean()
|
| 305 |
+
return treated - control
|
| 306 |
+
|
| 307 |
+
# With adjustment
|
| 308 |
+
# Group by adjustment variables
|
| 309 |
+
adjustment_vars = list(adjustment_set)
|
| 310 |
+
|
| 311 |
+
ate = 0.0
|
| 312 |
+
for strata, group in data.groupby(adjustment_vars):
|
| 313 |
+
if len(group) > 0:
|
| 314 |
+
# Compute effect in this stratum
|
| 315 |
+
treated = group[group[treatment] == 1][outcome].mean()
|
| 316 |
+
control = group[group[treatment] == 0][outcome].mean()
|
| 317 |
+
|
| 318 |
+
if not np.isnan(treated) and not np.isnan(control):
|
| 319 |
+
strata_effect = treated - control
|
| 320 |
+
strata_weight = len(group) / len(data)
|
| 321 |
+
ate += strata_effect * strata_weight
|
| 322 |
+
|
| 323 |
+
return ate
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class InterventionSimulator:
|
| 327 |
+
"""
|
| 328 |
+
Simulate policy interventions using structural causal models.
|
| 329 |
+
|
| 330 |
+
This class provides a high-level interface for testing
|
| 331 |
+
"what if" scenarios in geopolitical contexts.
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
def __init__(self, scm: StructuralCausalModel):
|
| 335 |
+
"""
|
| 336 |
+
Initialize intervention simulator.
|
| 337 |
+
|
| 338 |
+
Parameters
|
| 339 |
+
----------
|
| 340 |
+
scm : StructuralCausalModel
|
| 341 |
+
Structural causal model
|
| 342 |
+
"""
|
| 343 |
+
self.scm = scm
|
| 344 |
+
self.do_calculus = DoCalculus(scm.graph)
|
| 345 |
+
|
| 346 |
+
def simulate_intervention(
|
| 347 |
+
self,
|
| 348 |
+
intervention: Dict[str, float],
|
| 349 |
+
n_samples: int = 1000,
|
| 350 |
+
outcomes: Optional[List[str]] = None
|
| 351 |
+
) -> Dict[str, np.ndarray]:
|
| 352 |
+
"""
|
| 353 |
+
Simulate an intervention.
|
| 354 |
+
|
| 355 |
+
Parameters
|
| 356 |
+
----------
|
| 357 |
+
intervention : dict
|
| 358 |
+
Intervention specification {variable: value}
|
| 359 |
+
n_samples : int
|
| 360 |
+
Number of Monte Carlo samples
|
| 361 |
+
outcomes : List[str], optional
|
| 362 |
+
Outcome variables to track
|
| 363 |
+
|
| 364 |
+
Returns
|
| 365 |
+
-------
|
| 366 |
+
dict
|
| 367 |
+
Simulated outcomes
|
| 368 |
+
"""
|
| 369 |
+
# Sample from intervened distribution
|
| 370 |
+
samples = self.scm.sample(n_samples=n_samples, interventions=intervention)
|
| 371 |
+
|
| 372 |
+
if outcomes is not None:
|
| 373 |
+
samples = {k: v for k, v in samples.items() if k in outcomes}
|
| 374 |
+
|
| 375 |
+
return samples
|
| 376 |
+
|
| 377 |
+
def compare_interventions(
|
| 378 |
+
self,
|
| 379 |
+
interventions: List[Dict[str, float]],
|
| 380 |
+
outcome: str,
|
| 381 |
+
n_samples: int = 1000
|
| 382 |
+
) -> Dict[str, Dict[str, float]]:
|
| 383 |
+
"""
|
| 384 |
+
Compare multiple interventions.
|
| 385 |
+
|
| 386 |
+
Parameters
|
| 387 |
+
----------
|
| 388 |
+
interventions : List[dict]
|
| 389 |
+
List of interventions to compare
|
| 390 |
+
outcome : str
|
| 391 |
+
Outcome variable to compare
|
| 392 |
+
n_samples : int
|
| 393 |
+
Number of samples per intervention
|
| 394 |
+
|
| 395 |
+
Returns
|
| 396 |
+
-------
|
| 397 |
+
dict
|
| 398 |
+
Comparison results
|
| 399 |
+
"""
|
| 400 |
+
results = {}
|
| 401 |
+
|
| 402 |
+
for i, intervention in enumerate(interventions):
|
| 403 |
+
samples = self.simulate_intervention(intervention, n_samples, [outcome])
|
| 404 |
+
outcome_samples = samples[outcome]
|
| 405 |
+
|
| 406 |
+
results[f"intervention_{i}"] = {
|
| 407 |
+
'intervention': intervention,
|
| 408 |
+
'mean': np.mean(outcome_samples),
|
| 409 |
+
'std': np.std(outcome_samples),
|
| 410 |
+
'median': np.median(outcome_samples),
|
| 411 |
+
'q25': np.percentile(outcome_samples, 25),
|
| 412 |
+
'q75': np.percentile(outcome_samples, 75)
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
return results
|
| 416 |
+
|
| 417 |
+
def optimal_intervention(
|
| 418 |
+
self,
|
| 419 |
+
target_var: str,
|
| 420 |
+
intervention_vars: List[str],
|
| 421 |
+
intervention_ranges: Dict[str, Tuple[float, float]],
|
| 422 |
+
objective: str = 'maximize',
|
| 423 |
+
n_trials: int = 100,
|
| 424 |
+
n_samples: int = 1000
|
| 425 |
+
) -> Dict[str, Any]:
|
| 426 |
+
"""
|
| 427 |
+
Find optimal intervention to achieve target.
|
| 428 |
+
|
| 429 |
+
Parameters
|
| 430 |
+
----------
|
| 431 |
+
target_var : str
|
| 432 |
+
Target variable to optimize
|
| 433 |
+
intervention_vars : List[str]
|
| 434 |
+
Variables that can be intervened on
|
| 435 |
+
intervention_ranges : dict
|
| 436 |
+
Ranges for each intervention variable
|
| 437 |
+
objective : str
|
| 438 |
+
'maximize' or 'minimize'
|
| 439 |
+
n_trials : int
|
| 440 |
+
Number of random trials
|
| 441 |
+
n_samples : int
|
| 442 |
+
Samples per trial
|
| 443 |
+
|
| 444 |
+
Returns
|
| 445 |
+
-------
|
| 446 |
+
dict
|
| 447 |
+
Optimal intervention and results
|
| 448 |
+
"""
|
| 449 |
+
best_intervention = None
|
| 450 |
+
best_value = float('-inf') if objective == 'maximize' else float('inf')
|
| 451 |
+
|
| 452 |
+
for _ in range(n_trials):
|
| 453 |
+
# Sample random intervention
|
| 454 |
+
intervention = {}
|
| 455 |
+
for var in intervention_vars:
|
| 456 |
+
low, high = intervention_ranges[var]
|
| 457 |
+
intervention[var] = np.random.uniform(low, high)
|
| 458 |
+
|
| 459 |
+
# Simulate
|
| 460 |
+
samples = self.simulate_intervention(intervention, n_samples, [target_var])
|
| 461 |
+
mean_value = np.mean(samples[target_var])
|
| 462 |
+
|
| 463 |
+
# Update best
|
| 464 |
+
if objective == 'maximize':
|
| 465 |
+
if mean_value > best_value:
|
| 466 |
+
best_value = mean_value
|
| 467 |
+
best_intervention = intervention
|
| 468 |
+
else:
|
| 469 |
+
if mean_value < best_value:
|
| 470 |
+
best_value = mean_value
|
| 471 |
+
best_intervention = intervention
|
| 472 |
+
|
| 473 |
+
return {
|
| 474 |
+
'optimal_intervention': best_intervention,
|
| 475 |
+
'optimal_value': best_value,
|
| 476 |
+
'objective': objective
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
def counterfactual_analysis(
|
| 480 |
+
self,
|
| 481 |
+
observed: Dict[str, float],
|
| 482 |
+
intervention: Dict[str, float],
|
| 483 |
+
outcome: str
|
| 484 |
+
) -> Dict[str, float]:
|
| 485 |
+
"""
|
| 486 |
+
Perform counterfactual analysis.
|
| 487 |
+
|
| 488 |
+
"Given that we observed X, what would have happened if we had done Y?"
|
| 489 |
+
|
| 490 |
+
Parameters
|
| 491 |
+
----------
|
| 492 |
+
observed : dict
|
| 493 |
+
Observed values
|
| 494 |
+
intervention : dict
|
| 495 |
+
Counterfactual intervention
|
| 496 |
+
outcome : str
|
| 497 |
+
Outcome variable
|
| 498 |
+
|
| 499 |
+
Returns
|
| 500 |
+
-------
|
| 501 |
+
dict
|
| 502 |
+
Counterfactual results
|
| 503 |
+
"""
|
| 504 |
+
counterfactual = self.scm.compute_counterfactual(observed, intervention)
|
| 505 |
+
|
| 506 |
+
return {
|
| 507 |
+
'observed_outcome': observed.get(outcome, None),
|
| 508 |
+
'counterfactual_outcome': counterfactual.get(outcome, None),
|
| 509 |
+
'effect': counterfactual.get(outcome, 0) - observed.get(outcome, 0)
|
| 510 |
+
}
|
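The stratified estimator in `compute_ate` above is the basic backdoor-adjustment workhorse. Below is a minimal standalone sketch of the same estimator on synthetic confounded data; it uses only NumPy and pandas, and the data-generating process (a binary confounder `Z` driving both treatment `X` and outcome `Y`) is an illustrative assumption, not part of the framework:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 10_000
z = rng.integers(0, 2, size=n)                     # binary confounder
x = rng.binomial(1, 0.2 + 0.6 * z)                 # treatment depends on Z
y = 2.0 * x + 3.0 * z + rng.normal(0, 1, size=n)   # true treatment effect is 2.0
data = pd.DataFrame({"Z": z, "X": x, "Y": y})

# Naive contrast is biased upward because Z raises both X and Y
naive = data.loc[data.X == 1, "Y"].mean() - data.loc[data.X == 0, "Y"].mean()

# Stratify on Z and weight each stratum by its share of the data,
# mirroring the loop in DoCalculus.compute_ate
ate = 0.0
for _, group in data.groupby("Z"):
    treated = group.loc[group.X == 1, "Y"].mean()
    control = group.loc[group.X == 0, "Y"].mean()
    ate += (treated - control) * len(group) / len(data)

print(f"naive: {naive:.2f}, adjusted ATE: {ate:.2f}")   # roughly 3.8 vs 2.0
```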
geobot/inference/particle_filter.py
ADDED
|
@@ -0,0 +1,557 @@
|
| 1 |
+
"""
|
| 2 |
+
Sequential Monte Carlo (SMC) and Particle Filtering
|
| 3 |
+
|
| 4 |
+
Implements advanced particle filtering algorithms for:
|
| 5 |
+
- Recursive Bayesian inference on latent states
|
| 6 |
+
- High-dimensional posterior computation
|
| 7 |
+
- Nonlinear/non-Gaussian state estimation
|
| 8 |
+
- Degeneracy handling through resampling
|
| 9 |
+
|
| 10 |
+
Methods:
|
| 11 |
+
- Bootstrap particle filter
|
| 12 |
+
- Auxiliary particle filter
|
| 13 |
+
- Rao-Blackwellized particle filter
|
| 14 |
+
- Systematic resampling, stratified resampling
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from typing import Callable, Optional, Tuple, Dict, Any
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from scipy.stats import multivariate_normal
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ParticleState:
|
| 25 |
+
"""
|
| 26 |
+
Represents a particle filter state.
|
| 27 |
+
|
| 28 |
+
Attributes
|
| 29 |
+
----------
|
| 30 |
+
particles : np.ndarray, shape (n_particles, state_dim)
|
| 31 |
+
Particle positions
|
| 32 |
+
weights : np.ndarray, shape (n_particles,)
|
| 33 |
+
Normalized particle weights
|
| 34 |
+
log_weights : np.ndarray, shape (n_particles,)
|
| 35 |
+
Log weights (for numerical stability)
|
| 36 |
+
ess : float
|
| 37 |
+
Effective sample size
|
| 38 |
+
"""
|
| 39 |
+
particles: np.ndarray
|
| 40 |
+
weights: np.ndarray
|
| 41 |
+
log_weights: np.ndarray
|
| 42 |
+
ess: float
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SequentialMonteCarlo:
|
| 46 |
+
"""
|
| 47 |
+
Sequential Monte Carlo (particle filter) for recursive Bayesian inference.
|
| 48 |
+
|
| 49 |
+
Performs filtering on nonlinear/non-Gaussian state-space models:
|
| 50 |
+
x_t ~ p(x_t | x_{t-1}) (dynamics)
|
| 51 |
+
y_t ~ p(y_t | x_t) (observation)
|
| 52 |
+
|
| 53 |
+
Maintains posterior approximation p(x_t | y_{1:t}) via weighted particles.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
n_particles: int,
|
| 59 |
+
state_dim: int,
|
| 60 |
+
dynamics_fn: Callable,
|
| 61 |
+
observation_fn: Callable,
|
| 62 |
+
dynamics_noise_fn: Optional[Callable] = None,
|
| 63 |
+
observation_noise_fn: Optional[Callable] = None,
|
| 64 |
+
resample_threshold: float = 0.5
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
Initialize Sequential Monte Carlo filter.
|
| 68 |
+
|
| 69 |
+
Parameters
|
| 70 |
+
----------
|
| 71 |
+
n_particles : int
|
| 72 |
+
Number of particles
|
| 73 |
+
state_dim : int
|
| 74 |
+
Dimension of state space
|
| 75 |
+
dynamics_fn : callable
|
| 76 |
+
State transition function: x_t = f(x_{t-1}, noise)
|
| 77 |
+
observation_fn : callable
|
| 78 |
+
Observation likelihood: p(y_t | x_t)
|
| 79 |
+
dynamics_noise_fn : callable, optional
|
| 80 |
+
Dynamics noise sampler
|
| 81 |
+
observation_noise_fn : callable, optional
|
| 82 |
+
Observation noise sampler
|
| 83 |
+
resample_threshold : float
|
| 84 |
+
ESS threshold for resampling (as fraction of n_particles)
|
| 85 |
+
"""
|
| 86 |
+
self.n_particles = n_particles
|
| 87 |
+
self.state_dim = state_dim
|
| 88 |
+
self.dynamics_fn = dynamics_fn
|
| 89 |
+
self.observation_fn = observation_fn
|
| 90 |
+
self.dynamics_noise_fn = dynamics_noise_fn
|
| 91 |
+
self.observation_noise_fn = observation_noise_fn
|
| 92 |
+
self.resample_threshold = resample_threshold
|
| 93 |
+
|
| 94 |
+
# Initialize particles uniformly (or from prior)
|
| 95 |
+
self.particles = np.random.randn(n_particles, state_dim)
|
| 96 |
+
self.weights = np.ones(n_particles) / n_particles
|
| 97 |
+
self.log_weights = np.log(self.weights)
|
| 98 |
+
|
| 99 |
+
# History
|
| 100 |
+
self.history = []
|
| 101 |
+
|
| 102 |
+
def initialize_from_prior(self, prior_sampler: Callable) -> None:
|
| 103 |
+
"""
|
| 104 |
+
Initialize particles from prior distribution.
|
| 105 |
+
|
| 106 |
+
Parameters
|
| 107 |
+
----------
|
| 108 |
+
prior_sampler : callable
|
| 109 |
+
Function that samples from prior: x ~ p(x_0)
|
| 110 |
+
"""
|
| 111 |
+
self.particles = np.array([prior_sampler() for _ in range(self.n_particles)])
|
| 112 |
+
self.weights = np.ones(self.n_particles) / self.n_particles
|
| 113 |
+
self.log_weights = np.log(self.weights)
|
| 114 |
+
|
| 115 |
+
def predict(self) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Prediction step: propagate particles through dynamics.
|
| 118 |
+
|
| 119 |
+
x_t^i ~ p(x_t | x_{t-1}^i)
|
| 120 |
+
"""
|
| 121 |
+
new_particles = np.zeros_like(self.particles)
|
| 122 |
+
|
| 123 |
+
for i in range(self.n_particles):
|
| 124 |
+
# Sample noise
|
| 125 |
+
if self.dynamics_noise_fn:
|
| 126 |
+
noise = self.dynamics_noise_fn()
|
| 127 |
+
else:
|
| 128 |
+
noise = np.random.randn(self.state_dim) * 0.1
|
| 129 |
+
|
| 130 |
+
# Propagate particle
|
| 131 |
+
new_particles[i] = self.dynamics_fn(self.particles[i], noise)
|
| 132 |
+
|
| 133 |
+
self.particles = new_particles
|
| 134 |
+
|
| 135 |
+
def update(self, observation: np.ndarray) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Update step: reweight particles based on observation likelihood.
|
| 138 |
+
|
| 139 |
+
w_t^i ∝ p(y_t | x_t^i) w_{t-1}^i
|
| 140 |
+
|
| 141 |
+
Parameters
|
| 142 |
+
----------
|
| 143 |
+
observation : np.ndarray
|
| 144 |
+
Observation y_t
|
| 145 |
+
"""
|
| 146 |
+
# Compute log-likelihoods
|
| 147 |
+
log_likelihoods = np.zeros(self.n_particles)
|
| 148 |
+
|
| 149 |
+
for i in range(self.n_particles):
|
| 150 |
+
log_likelihoods[i] = self.observation_fn(observation, self.particles[i])
|
| 151 |
+
|
| 152 |
+
# Update log-weights
|
| 153 |
+
self.log_weights = self.log_weights + log_likelihoods
|
| 154 |
+
|
| 155 |
+
# Normalize weights (in log space for stability)
|
| 156 |
+
max_log_weight = np.max(self.log_weights)
|
| 157 |
+
self.log_weights = self.log_weights - max_log_weight
|
| 158 |
+
self.weights = np.exp(self.log_weights)
|
| 159 |
+
self.weights = self.weights / np.sum(self.weights)
|
| 160 |
+
self.log_weights = np.log(self.weights)
|
| 161 |
+
|
| 162 |
+
def compute_ess(self) -> float:
|
| 163 |
+
"""
|
| 164 |
+
Compute effective sample size (ESS).
|
| 165 |
+
|
| 166 |
+
ESS = 1 / sum(w_i^2)
|
| 167 |
+
|
| 168 |
+
Returns
|
| 169 |
+
-------
|
| 170 |
+
float
|
| 171 |
+
Effective sample size
|
| 172 |
+
"""
|
| 173 |
+
return 1.0 / np.sum(self.weights ** 2)
|
| 174 |
+
|
| 175 |
+
def resample(self, method: str = 'systematic') -> None:
|
| 176 |
+
"""
|
| 177 |
+
Resample particles to combat degeneracy.
|
| 178 |
+
|
| 179 |
+
Parameters
|
| 180 |
+
----------
|
| 181 |
+
method : str
|
| 182 |
+
Resampling method ('systematic', 'stratified', 'multinomial')
|
| 183 |
+
"""
|
| 184 |
+
if method == 'systematic':
|
| 185 |
+
indices = self._systematic_resample()
|
| 186 |
+
elif method == 'stratified':
|
| 187 |
+
indices = self._stratified_resample()
|
| 188 |
+
elif method == 'multinomial':
|
| 189 |
+
indices = np.random.choice(
|
| 190 |
+
self.n_particles,
|
| 191 |
+
size=self.n_particles,
|
| 192 |
+
p=self.weights
|
| 193 |
+
)
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError(f"Unknown resampling method: {method}")
|
| 196 |
+
|
| 197 |
+
# Resample particles
|
| 198 |
+
self.particles = self.particles[indices]
|
| 199 |
+
|
| 200 |
+
# Reset weights to uniform
|
| 201 |
+
self.weights = np.ones(self.n_particles) / self.n_particles
|
| 202 |
+
self.log_weights = np.log(self.weights)
|
| 203 |
+
|
| 204 |
+
def _systematic_resample(self) -> np.ndarray:
|
| 205 |
+
"""
|
| 206 |
+
Systematic resampling (low variance).
|
| 207 |
+
|
| 208 |
+
Returns
|
| 209 |
+
-------
|
| 210 |
+
np.ndarray
|
| 211 |
+
Resampled indices
|
| 212 |
+
"""
|
| 213 |
+
positions = (np.arange(self.n_particles) + np.random.uniform()) / self.n_particles
|
| 214 |
+
indices = np.zeros(self.n_particles, dtype=int)
|
| 215 |
+
cumulative_sum = np.cumsum(self.weights)
|
| 216 |
+
|
| 217 |
+
i, j = 0, 0
|
| 218 |
+
while i < self.n_particles:
|
| 219 |
+
if positions[i] < cumulative_sum[j]:
|
| 220 |
+
indices[i] = j
|
| 221 |
+
i += 1
|
| 222 |
+
else:
|
| 223 |
+
j += 1
|
| 224 |
+
|
| 225 |
+
return indices
|
| 226 |
+
|
| 227 |
+
def _stratified_resample(self) -> np.ndarray:
|
| 228 |
+
"""
|
| 229 |
+
Stratified resampling.
|
| 230 |
+
|
| 231 |
+
Returns
|
| 232 |
+
-------
|
| 233 |
+
np.ndarray
|
| 234 |
+
Resampled indices
|
| 235 |
+
"""
|
| 236 |
+
positions = (np.arange(self.n_particles) + np.random.uniform(size=self.n_particles)) / self.n_particles
|
| 237 |
+
indices = np.zeros(self.n_particles, dtype=int)
|
| 238 |
+
cumulative_sum = np.cumsum(self.weights)
|
| 239 |
+
|
| 240 |
+
i, j = 0, 0
|
| 241 |
+
while i < self.n_particles:
|
| 242 |
+
if positions[i] < cumulative_sum[j]:
|
| 243 |
+
indices[i] = j
|
| 244 |
+
i += 1
|
| 245 |
+
else:
|
| 246 |
+
j += 1
|
| 247 |
+
|
| 248 |
+
return indices
|
| 249 |
+
|
| 250 |
+
def filter_step(self, observation: np.ndarray, resample: bool = True) -> ParticleState:
|
| 251 |
+
"""
|
| 252 |
+
Single filtering step: predict + update + (optional) resample.
|
| 253 |
+
|
| 254 |
+
Parameters
|
| 255 |
+
----------
|
| 256 |
+
observation : np.ndarray
|
| 257 |
+
Observation at current time
|
| 258 |
+
resample : bool
|
| 259 |
+
Whether to check ESS and resample if needed
|
| 260 |
+
|
| 261 |
+
Returns
|
| 262 |
+
-------
|
| 263 |
+
ParticleState
|
| 264 |
+
Current particle filter state
|
| 265 |
+
"""
|
| 266 |
+
# Predict
|
| 267 |
+
self.predict()
|
| 268 |
+
|
| 269 |
+
# Update
|
| 270 |
+
self.update(observation)
|
| 271 |
+
|
| 272 |
+
# Compute ESS
|
| 273 |
+
ess = self.compute_ess()
|
| 274 |
+
|
| 275 |
+
# Resample if ESS too low
|
| 276 |
+
if resample and ess < self.resample_threshold * self.n_particles:
|
| 277 |
+
self.resample(method='systematic')
|
| 278 |
+
ess = self.n_particles # After resampling, ESS = N
|
| 279 |
+
|
| 280 |
+
# Save state
|
| 281 |
+
state = ParticleState(
|
| 282 |
+
particles=self.particles.copy(),
|
| 283 |
+
weights=self.weights.copy(),
|
| 284 |
+
log_weights=self.log_weights.copy(),
|
| 285 |
+
ess=ess
|
| 286 |
+
)
|
| 287 |
+
self.history.append(state)
|
| 288 |
+
|
| 289 |
+
return state
|
| 290 |
+
|
| 291 |
+
def filter(self, observations: np.ndarray) -> list:
|
| 292 |
+
"""
|
| 293 |
+
Run particle filter on sequence of observations.
|
| 294 |
+
|
| 295 |
+
Parameters
|
| 296 |
+
----------
|
| 297 |
+
observations : np.ndarray, shape (n_timesteps, obs_dim)
|
| 298 |
+
Sequence of observations
|
| 299 |
+
|
| 300 |
+
Returns
|
| 301 |
+
-------
|
| 302 |
+
list
|
| 303 |
+
List of ParticleState objects
|
| 304 |
+
"""
|
| 305 |
+
states = []
|
| 306 |
+
for obs in observations:
|
| 307 |
+
state = self.filter_step(obs)
|
| 308 |
+
states.append(state)
|
| 309 |
+
return states
|
| 310 |
+
|
| 311 |
+
def get_state_estimate(self) -> Tuple[np.ndarray, np.ndarray]:
|
| 312 |
+
"""
|
| 313 |
+
Get posterior mean and covariance estimate.
|
| 314 |
+
|
| 315 |
+
Returns
|
| 316 |
+
-------
|
| 317 |
+
tuple
|
| 318 |
+
(mean, covariance)
|
| 319 |
+
"""
|
| 320 |
+
mean = np.average(self.particles, weights=self.weights, axis=0)
|
| 321 |
+
|
| 322 |
+
# Weighted covariance
|
| 323 |
+
diff = self.particles - mean
|
| 324 |
+
cov = np.dot(self.weights * diff.T, diff)
|
| 325 |
+
|
| 326 |
+
return mean, cov
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class AuxiliaryParticleFilter(SequentialMonteCarlo):
|
| 330 |
+
"""
|
| 331 |
+
Auxiliary Particle Filter.
|
| 332 |
+
|
| 333 |
+
Improves importance distribution by looking ahead at next observation.
|
| 334 |
+
Uses auxiliary variables to guide particle propagation.
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
def __init__(self, *args, look_ahead_fn: Optional[Callable] = None, **kwargs):
|
| 338 |
+
"""
|
| 339 |
+
Initialize auxiliary particle filter.
|
| 340 |
+
|
| 341 |
+
Parameters
|
| 342 |
+
----------
|
| 343 |
+
look_ahead_fn : callable, optional
|
| 344 |
+
Function to compute look-ahead weights: μ_t^i = p(y_t | m_t^i)
|
| 345 |
+
where m_t^i is a prediction of x_t from x_{t-1}^i
|
| 346 |
+
"""
|
| 347 |
+
super().__init__(*args, **kwargs)
|
| 348 |
+
self.look_ahead_fn = look_ahead_fn
|
| 349 |
+
|
| 350 |
+
def filter_step(self, observation: np.ndarray, resample: bool = True) -> ParticleState:
|
| 351 |
+
"""
|
| 352 |
+
Auxiliary particle filter step.
|
| 353 |
+
|
| 354 |
+
Parameters
|
| 355 |
+
----------
|
| 356 |
+
observation : np.ndarray
|
| 357 |
+
Current observation
|
| 358 |
+
resample : bool
|
| 359 |
+
Whether to resample
|
| 360 |
+
|
| 361 |
+
Returns
|
| 362 |
+
-------
|
| 363 |
+
ParticleState
|
| 364 |
+
Filter state
|
| 365 |
+
"""
|
| 366 |
+
# Step 1: Compute auxiliary weights (look-ahead)
|
| 367 |
+
if self.look_ahead_fn:
|
| 368 |
+
aux_weights = np.zeros(self.n_particles)
|
| 369 |
+
for i in range(self.n_particles):
|
| 370 |
+
# Predict particle position
|
| 371 |
+
predicted = self.dynamics_fn(self.particles[i], np.zeros(self.state_dim))
|
| 372 |
+
# Compute look-ahead likelihood
|
| 373 |
+
aux_weights[i] = np.exp(self.observation_fn(observation, predicted))
|
| 374 |
+
|
| 375 |
+
# Combine with current weights
|
| 376 |
+
aux_weights = self.weights * aux_weights
|
| 377 |
+
aux_weights = aux_weights / np.sum(aux_weights)
|
| 378 |
+
else:
|
| 379 |
+
aux_weights = self.weights
|
| 380 |
+
|
| 381 |
+
# Step 2: Resample using auxiliary weights
|
| 382 |
+
indices = np.random.choice(self.n_particles, size=self.n_particles, p=aux_weights)
|
| 383 |
+
self.particles = self.particles[indices]
|
| 384 |
+
selected_weights = self.weights[indices]
|
| 385 |
+
selected_aux_weights = aux_weights[indices]
|
| 386 |
+
|
| 387 |
+
# Step 3: Propagate
|
| 388 |
+
self.predict()
|
| 389 |
+
|
| 390 |
+
# Step 4: Update with importance weights
|
| 391 |
+
self.update(observation)
|
| 392 |
+
|
| 393 |
+
# Adjust weights for auxiliary sampling
|
| 394 |
+
self.weights = self.weights * selected_weights / (selected_aux_weights + 1e-10)
|
| 395 |
+
self.weights = self.weights / np.sum(self.weights)
|
| 396 |
+
self.log_weights = np.log(self.weights)
|
| 397 |
+
|
| 398 |
+
# ESS and resampling
|
| 399 |
+
ess = self.compute_ess()
|
| 400 |
+
if resample and ess < self.resample_threshold * self.n_particles:
|
| 401 |
+
self.resample()
|
| 402 |
+
ess = self.n_particles
|
| 403 |
+
|
| 404 |
+
state = ParticleState(
|
| 405 |
+
particles=self.particles.copy(),
|
| 406 |
+
weights=self.weights.copy(),
|
| 407 |
+
log_weights=self.log_weights.copy(),
|
| 408 |
+
ess=ess
|
| 409 |
+
)
|
| 410 |
+
self.history.append(state)
|
| 411 |
+
|
| 412 |
+
return state
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class RaoBlackwellizedParticleFilter:
|
| 416 |
+
"""
|
| 417 |
+
Rao-Blackwellized Particle Filter (RBPF).
|
| 418 |
+
|
| 419 |
+
For models with linear-Gaussian substructure:
|
| 420 |
+
- Part of state updated with Kalman filter (exact)
|
| 421 |
+
- Remaining part updated with particle filter
|
| 422 |
+
|
| 423 |
+
Reduces variance by marginalizing out linear components.
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
def __init__(
|
| 427 |
+
self,
|
| 428 |
+
n_particles: int,
|
| 429 |
+
nonlinear_dim: int,
|
| 430 |
+
linear_dim: int,
|
| 431 |
+
nonlinear_dynamics_fn: Callable,
|
| 432 |
+
linear_dynamics_fn: Callable,
|
| 433 |
+
observation_fn: Callable,
|
| 434 |
+
F_linear: np.ndarray,
|
| 435 |
+
H_linear: np.ndarray,
|
| 436 |
+
Q_linear: np.ndarray,
|
| 437 |
+
R: np.ndarray
|
| 438 |
+
):
|
| 439 |
+
"""
|
| 440 |
+
Initialize Rao-Blackwellized particle filter.
|
| 441 |
+
|
| 442 |
+
Parameters
|
| 443 |
+
----------
|
| 444 |
+
n_particles : int
|
| 445 |
+
Number of particles for nonlinear part
|
| 446 |
+
nonlinear_dim : int
|
| 447 |
+
Dimension of nonlinear state
|
| 448 |
+
linear_dim : int
|
| 449 |
+
Dimension of linear state
|
| 450 |
+
nonlinear_dynamics_fn : callable
|
| 451 |
+
Nonlinear state dynamics
|
| 452 |
+
linear_dynamics_fn : callable
|
| 453 |
+
Linear state dynamics (conditioned on nonlinear state)
|
| 454 |
+
observation_fn : callable
|
| 455 |
+
Observation likelihood
|
| 456 |
+
F_linear : np.ndarray
|
| 457 |
+
Linear dynamics matrix
|
| 458 |
+
H_linear : np.ndarray
|
| 459 |
+
Linear observation matrix
|
| 460 |
+
Q_linear : np.ndarray
|
| 461 |
+
Linear process noise covariance
|
| 462 |
+
R : np.ndarray
|
| 463 |
+
Observation noise covariance
|
| 464 |
+
"""
|
| 465 |
+
self.n_particles = n_particles
|
| 466 |
+
self.nonlinear_dim = nonlinear_dim
|
| 467 |
+
self.linear_dim = linear_dim
|
| 468 |
+
|
| 469 |
+
self.nonlinear_dynamics_fn = nonlinear_dynamics_fn
|
| 470 |
+
self.linear_dynamics_fn = linear_dynamics_fn
|
| 471 |
+
self.observation_fn = observation_fn
|
| 472 |
+
|
| 473 |
+
# Linear substructure parameters (for Kalman filter)
|
| 474 |
+
self.F_linear = F_linear
|
| 475 |
+
self.H_linear = H_linear
|
| 476 |
+
self.Q_linear = Q_linear
|
| 477 |
+
self.R = R
|
| 478 |
+
|
| 479 |
+
# Initialize particles (nonlinear part)
|
| 480 |
+
self.nonlinear_particles = np.random.randn(n_particles, nonlinear_dim)
|
| 481 |
+
self.weights = np.ones(n_particles) / n_particles
|
| 482 |
+
|
| 483 |
+
# Initialize Kalman filters (one per particle)
|
| 484 |
+
self.linear_means = [np.zeros(linear_dim) for _ in range(n_particles)]
|
| 485 |
+
self.linear_covs = [np.eye(linear_dim) for _ in range(n_particles)]
|
| 486 |
+
|
| 487 |
+
def filter_step(self, observation: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 488 |
+
"""
|
| 489 |
+
RBPF filtering step.
|
| 490 |
+
|
| 491 |
+
Parameters
|
| 492 |
+
----------
|
| 493 |
+
observation : np.ndarray
|
| 494 |
+
Current observation
|
| 495 |
+
|
| 496 |
+
Returns
|
| 497 |
+
-------
|
| 498 |
+
tuple
|
| 499 |
+
(nonlinear_estimate, linear_estimate)
|
| 500 |
+
"""
|
| 501 |
+
# Step 1: Propagate nonlinear particles
|
| 502 |
+
new_nonlinear_particles = np.zeros_like(self.nonlinear_particles)
|
| 503 |
+
for i in range(self.n_particles):
|
| 504 |
+
noise = np.random.randn(self.nonlinear_dim) * 0.1
|
| 505 |
+
new_nonlinear_particles[i] = self.nonlinear_dynamics_fn(
|
| 506 |
+
self.nonlinear_particles[i], noise
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
# Step 2: Update linear state with Kalman filter (per particle)
|
| 510 |
+
new_linear_means = []
|
| 511 |
+
new_linear_covs = []
|
| 512 |
+
log_likelihoods = np.zeros(self.n_particles)
|
| 513 |
+
|
| 514 |
+
for i in range(self.n_particles):
|
| 515 |
+
# Kalman prediction
|
| 516 |
+
m_pred = self.F_linear @ self.linear_means[i]
|
| 517 |
+
P_pred = self.F_linear @ self.linear_covs[i] @ self.F_linear.T + self.Q_linear
|
| 518 |
+
|
| 519 |
+
# Kalman update
|
| 520 |
+
innovation = observation - self.H_linear @ m_pred
|
| 521 |
+
S = self.H_linear @ P_pred @ self.H_linear.T + self.R
|
| 522 |
+
K = P_pred @ self.H_linear.T @ np.linalg.inv(S)
|
| 523 |
+
|
| 524 |
+
m_new = m_pred + K @ innovation
|
| 525 |
+
P_new = (np.eye(self.linear_dim) - K @ self.H_linear) @ P_pred
|
| 526 |
+
|
| 527 |
+
new_linear_means.append(m_new)
|
| 528 |
+
new_linear_covs.append(P_new)
|
| 529 |
+
|
| 530 |
+
# Log-likelihood
|
| 531 |
+
log_likelihoods[i] = multivariate_normal.logpdf(innovation, mean=np.zeros_like(innovation), cov=S)
|
| 532 |
+
|
| 533 |
+
# Step 3: Update weights
|
| 534 |
+
max_ll = np.max(log_likelihoods)
|
| 535 |
+
weights = np.exp(log_likelihoods - max_ll)
|
| 536 |
+
self.weights = self.weights * weights
|
| 537 |
+
self.weights = self.weights / np.sum(self.weights)
|
| 538 |
+
|
| 539 |
+
# Step 4: Resample if needed
|
| 540 |
+
ess = 1.0 / np.sum(self.weights ** 2)
|
| 541 |
+
if ess < 0.5 * self.n_particles:
|
| 542 |
+
indices = np.random.choice(self.n_particles, size=self.n_particles, p=self.weights)
|
| 543 |
+
new_nonlinear_particles = new_nonlinear_particles[indices]
|
| 544 |
+
new_linear_means = [new_linear_means[i] for i in indices]
|
| 545 |
+
new_linear_covs = [new_linear_covs[i] for i in indices]
|
| 546 |
+
self.weights = np.ones(self.n_particles) / self.n_particles
|
| 547 |
+
|
| 548 |
+
# Update state
|
| 549 |
+
self.nonlinear_particles = new_nonlinear_particles
|
| 550 |
+
self.linear_means = new_linear_means
|
| 551 |
+
self.linear_covs = new_linear_covs
|
| 552 |
+
|
| 553 |
+
# Estimates
|
| 554 |
+
nonlinear_estimate = np.average(self.nonlinear_particles, weights=self.weights, axis=0)
|
| 555 |
+
linear_estimate = np.average(new_linear_means, weights=self.weights, axis=0)
|
| 556 |
+
|
| 557 |
+
return nonlinear_estimate, linear_estimate
|
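A short end-to-end run of `SequentialMonteCarlo` above on a one-dimensional random walk observed in Gaussian noise. The import path mirrors the file location in this commit; the dynamics and observation models, noise scales, and trajectory length are toy assumptions chosen only for illustration:

```python
import numpy as np
from scipy.stats import norm
from geobot.inference.particle_filter import SequentialMonteCarlo

def dynamics_fn(x, noise):
    # Random-walk transition: x_t = x_{t-1} + noise
    return x + noise

def observation_fn(y, x):
    # Log-likelihood of a noisy direct observation: log N(y | x, 0.5^2)
    return norm.logpdf(y[0], loc=x[0], scale=0.5)

pf = SequentialMonteCarlo(
    n_particles=500,
    state_dim=1,
    dynamics_fn=dynamics_fn,
    observation_fn=observation_fn,
    dynamics_noise_fn=lambda: np.random.randn(1) * 0.1,
)

# Simulate a latent trajectory plus noisy observations, then filter
true_x = np.cumsum(np.random.randn(50) * 0.1)
observations = true_x[:, None] + np.random.randn(50, 1) * 0.5
states = pf.filter(observations)

mean, cov = pf.get_state_estimate()
print("final state estimate:", mean, "ESS at last step:", states[-1].ess)
```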
geobot/inference/variational_inference.py
ADDED
|
@@ -0,0 +1,531 @@
|
| 1 |
+
"""
|
| 2 |
+
Variational Inference (VI) Engine
|
| 3 |
+
|
| 4 |
+
Implements scalable approximate Bayesian inference via optimization:
|
| 5 |
+
- Mean-field variational inference
|
| 6 |
+
- Automatic Differentiation Variational Inference (ADVI)
|
| 7 |
+
- Evidence Lower Bound (ELBO) optimization
|
| 8 |
+
- Coordinate ascent variational inference (CAVI)
|
| 9 |
+
|
| 10 |
+
Provides high-dimensional posterior approximation when MCMC is intractable.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from typing import Callable, Dict, List, Optional, Tuple, Any
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from scipy.stats import norm, multivariate_normal
|
| 17 |
+
from scipy.optimize import minimize
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class VariationalDistribution:
|
| 22 |
+
"""
|
| 23 |
+
Parametric variational distribution q(z | λ).
|
| 24 |
+
|
| 25 |
+
Attributes
|
| 26 |
+
----------
|
| 27 |
+
family : str
|
| 28 |
+
Distribution family ('normal', 'multivariate_normal')
|
| 29 |
+
parameters : dict
|
| 30 |
+
Distribution parameters
|
| 31 |
+
"""
|
| 32 |
+
family: str
|
| 33 |
+
parameters: Dict[str, np.ndarray]
|
| 34 |
+
|
| 35 |
+
def sample(self, n_samples: int = 1) -> np.ndarray:
|
| 36 |
+
"""Sample from variational distribution."""
|
| 37 |
+
if self.family == 'normal':
|
| 38 |
+
mu = self.parameters['mu']
|
| 39 |
+
sigma = self.parameters['sigma']
|
| 40 |
+
return np.random.normal(mu, sigma, size=(n_samples, len(mu)))
|
| 41 |
+
elif self.family == 'multivariate_normal':
|
| 42 |
+
mu = self.parameters['mu']
|
| 43 |
+
cov = self.parameters['cov']
|
| 44 |
+
return np.random.multivariate_normal(mu, cov, size=n_samples)
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError(f"Unknown family: {self.family}")
|
| 47 |
+
|
| 48 |
+
def log_prob(self, z: np.ndarray) -> np.ndarray:
|
| 49 |
+
"""Compute log probability."""
|
| 50 |
+
if self.family == 'normal':
|
| 51 |
+
mu = self.parameters['mu']
|
| 52 |
+
sigma = self.parameters['sigma']
|
| 53 |
+
return np.sum(norm.logpdf(z, loc=mu, scale=sigma), axis=-1)
|
| 54 |
+
elif self.family == 'multivariate_normal':
|
| 55 |
+
mu = self.parameters['mu']
|
| 56 |
+
cov = self.parameters['cov']
|
| 57 |
+
return multivariate_normal.logpdf(z, mean=mu, cov=cov)
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Unknown family: {self.family}")
|
| 60 |
+
|
| 61 |
+
def entropy(self) -> float:
|
| 62 |
+
"""Compute entropy H[q]."""
|
| 63 |
+
if self.family == 'normal':
|
| 64 |
+
sigma = self.parameters['sigma']
|
| 65 |
+
# H = 0.5 * log(2πeσ²)
|
| 66 |
+
return 0.5 * np.sum(np.log(2 * np.pi * np.e * sigma**2))
|
| 67 |
+
elif self.family == 'multivariate_normal':
|
| 68 |
+
cov = self.parameters['cov']
|
| 69 |
+
d = len(cov)
|
| 70 |
+
# H = 0.5 * log((2πe)^d |Σ|)
|
| 71 |
+
sign, logdet = np.linalg.slogdet(cov)
|
| 72 |
+
return 0.5 * (d * np.log(2 * np.pi * np.e) + logdet)
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Unknown family: {self.family}")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class VariationalInference:
|
| 78 |
+
"""
|
| 79 |
+
Variational Inference engine.
|
| 80 |
+
|
| 81 |
+
Approximates posterior p(z|x) with variational distribution q(z|λ)
|
| 82 |
+
by maximizing Evidence Lower Bound (ELBO):
|
| 83 |
+
|
| 84 |
+
ELBO(λ) = E_q[log p(x,z)] - E_q[log q(z|λ)]
|
| 85 |
+
= E_q[log p(x|z)] + E_q[log p(z)] - E_q[log q(z|λ)]
|
| 86 |
+
|
| 87 |
+
Equivalently: minimize KL(q(z|λ) || p(z|x))
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
log_joint: Callable,
|
| 93 |
+
variational_family: str = 'normal',
|
| 94 |
+
n_samples: int = 100
|
| 95 |
+
):
|
| 96 |
+
"""
|
| 97 |
+
Initialize variational inference.
|
| 98 |
+
|
| 99 |
+
Parameters
|
| 100 |
+
----------
|
| 101 |
+
log_joint : callable
|
| 102 |
+
Log joint probability: log p(x, z)
|
| 103 |
+
variational_family : str
|
| 104 |
+
Variational family ('normal', 'multivariate_normal')
|
| 105 |
+
n_samples : int
|
| 106 |
+
Number of Monte Carlo samples for ELBO estimation
|
| 107 |
+
"""
|
| 108 |
+
self.log_joint = log_joint
|
| 109 |
+
self.variational_family = variational_family
|
| 110 |
+
self.n_samples = n_samples
|
| 111 |
+
self.q = None
|
| 112 |
+
|
| 113 |
+
def elbo(
|
| 114 |
+
self,
|
| 115 |
+
variational_params: np.ndarray,
|
| 116 |
+
param_shapes: Dict[str, Tuple],
|
| 117 |
+
observed_data: Any
|
| 118 |
+
) -> float:
|
| 119 |
+
"""
|
| 120 |
+
Compute Evidence Lower Bound (ELBO).
|
| 121 |
+
|
| 122 |
+
ELBO = E_q[log p(x,z)] - E_q[log q(z)]
|
| 123 |
+
|
| 124 |
+
Parameters
|
| 125 |
+
----------
|
| 126 |
+
variational_params : np.ndarray
|
| 127 |
+
Flattened variational parameters
|
| 128 |
+
param_shapes : dict
|
| 129 |
+
Shapes of each parameter
|
| 130 |
+
observed_data : any
|
| 131 |
+
Observed data x
|
| 132 |
+
|
| 133 |
+
Returns
|
| 134 |
+
-------
|
| 135 |
+
float
|
| 136 |
+
ELBO value
|
| 137 |
+
"""
|
| 138 |
+
# Unpack parameters
|
| 139 |
+
params = self._unpack_params(variational_params, param_shapes)
|
| 140 |
+
|
| 141 |
+
# Create variational distribution
|
| 142 |
+
q = VariationalDistribution(self.variational_family, params)
|
| 143 |
+
|
| 144 |
+
# Sample from q
|
| 145 |
+
z_samples = q.sample(self.n_samples)
|
| 146 |
+
|
| 147 |
+
# Compute E_q[log p(x, z)]
|
| 148 |
+
log_joint_vals = np.array([self.log_joint(z, observed_data) for z in z_samples])
|
| 149 |
+
expected_log_joint = np.mean(log_joint_vals)
|
| 150 |
+
|
| 151 |
+
# Compute E_q[log q(z)]
|
| 152 |
+
log_q_vals = q.log_prob(z_samples)
|
| 153 |
+
expected_log_q = np.mean(log_q_vals)
|
| 154 |
+
|
| 155 |
+
# ELBO
|
| 156 |
+
elbo_val = expected_log_joint - expected_log_q
|
| 157 |
+
|
| 158 |
+
return elbo_val
|
| 159 |
+
|
| 160 |
+
def neg_elbo(self, variational_params: np.ndarray, param_shapes: Dict, observed_data: Any) -> float:
|
| 161 |
+
"""Negative ELBO for minimization."""
|
| 162 |
+
return -self.elbo(variational_params, param_shapes, observed_data)
|
| 163 |
+
|
| 164 |
+
def fit(
|
| 165 |
+
self,
|
| 166 |
+
observed_data: Any,
|
| 167 |
+
init_params: Dict[str, np.ndarray],
|
| 168 |
+
max_iter: int = 1000,
|
| 169 |
+
method: str = 'L-BFGS-B'
|
| 170 |
+
) -> VariationalDistribution:
|
| 171 |
+
"""
|
| 172 |
+
Fit variational distribution via ELBO optimization.
|
| 173 |
+
|
| 174 |
+
Parameters
|
| 175 |
+
----------
|
| 176 |
+
observed_data : any
|
| 177 |
+
Observed data
|
| 178 |
+
init_params : dict
|
| 179 |
+
Initial variational parameters
|
| 180 |
+
max_iter : int
|
| 181 |
+
Maximum optimization iterations
|
| 182 |
+
method : str
|
| 183 |
+
Optimization method
|
| 184 |
+
|
| 185 |
+
Returns
|
| 186 |
+
-------
|
| 187 |
+
VariationalDistribution
|
| 188 |
+
Optimized variational distribution
|
| 189 |
+
"""
|
| 190 |
+
# Pack initial parameters
|
| 191 |
+
flat_params, param_shapes = self._pack_params(init_params)
|
| 192 |
+
|
| 193 |
+
# Optimize
|
| 194 |
+
result = minimize(
|
| 195 |
+
fun=self.neg_elbo,
|
| 196 |
+
x0=flat_params,
|
| 197 |
+
args=(param_shapes, observed_data),
|
| 198 |
+
method=method,
|
| 199 |
+
options={'maxiter': max_iter, 'disp': True}
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Unpack optimized parameters
|
| 203 |
+
opt_params = self._unpack_params(result.x, param_shapes)
|
| 204 |
+
|
| 205 |
+
# Create variational distribution
|
| 206 |
+
self.q = VariationalDistribution(self.variational_family, opt_params)
|
| 207 |
+
|
| 208 |
+
return self.q
|
| 209 |
+
|
| 210 |
+
def _pack_params(self, params: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Dict]:
|
| 211 |
+
"""Pack parameters into flat array."""
|
| 212 |
+
flat = []
|
| 213 |
+
shapes = {}
|
| 214 |
+
for key, val in params.items():
|
| 215 |
+
flat.append(val.flatten())
|
| 216 |
+
shapes[key] = val.shape
|
| 217 |
+
return np.concatenate(flat), shapes
|
| 218 |
+
|
| 219 |
+
def _unpack_params(self, flat: np.ndarray, shapes: Dict) -> Dict[str, np.ndarray]:
|
| 220 |
+
"""Unpack flat array into parameters."""
|
| 221 |
+
params = {}
|
| 222 |
+
idx = 0
|
| 223 |
+
for key, shape in shapes.items():
|
| 224 |
+
size = np.prod(shape)
|
| 225 |
+
params[key] = flat[idx:idx+size].reshape(shape)
|
| 226 |
+
idx += size
|
| 227 |
+
return params
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class MeanFieldVI(VariationalInference):
|
| 231 |
+
"""
|
| 232 |
+
Mean-Field Variational Inference.
|
| 233 |
+
|
| 234 |
+
Assumes variational distribution factorizes:
|
| 235 |
+
q(z | λ) = ∏_i q_i(z_i | λ_i)
|
| 236 |
+
|
| 237 |
+
Uses coordinate ascent variational inference (CAVI) to optimize
|
| 238 |
+
each factor in turn.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
def __init__(
|
| 242 |
+
self,
|
| 243 |
+
log_joint: Callable,
|
| 244 |
+
factor_families: List[str],
|
| 245 |
+
n_samples: int = 100
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Initialize mean-field VI.
|
| 249 |
+
|
| 250 |
+
Parameters
|
| 251 |
+
----------
|
| 252 |
+
log_joint : callable
|
| 253 |
+
Log joint probability
|
| 254 |
+
factor_families : list
|
| 255 |
+
Distribution family for each factor
|
| 256 |
+
n_samples : int
|
| 257 |
+
Number of samples for ELBO
|
| 258 |
+
"""
|
| 259 |
+
super().__init__(log_joint, 'mean_field', n_samples)
|
| 260 |
+
self.factor_families = factor_families
|
| 261 |
+
self.n_factors = len(factor_families)
|
| 262 |
+
|
| 263 |
+
def fit_cavi(
|
| 264 |
+
self,
|
| 265 |
+
observed_data: Any,
|
| 266 |
+
init_params: List[Dict[str, np.ndarray]],
|
| 267 |
+
max_iter: int = 100,
|
| 268 |
+
tol: float = 1e-4
|
| 269 |
+
) -> List[VariationalDistribution]:
|
| 270 |
+
"""
|
| 271 |
+
Fit using Coordinate Ascent Variational Inference (CAVI).
|
| 272 |
+
|
| 273 |
+
Parameters
|
| 274 |
+
----------
|
| 275 |
+
observed_data : any
|
| 276 |
+
Observed data
|
| 277 |
+
init_params : list
|
| 278 |
+
Initial parameters for each factor
|
| 279 |
+
max_iter : int
|
| 280 |
+
Maximum CAVI iterations
|
| 281 |
+
tol : float
|
| 282 |
+
Convergence tolerance
|
| 283 |
+
|
| 284 |
+
Returns
|
| 285 |
+
-------
|
| 286 |
+
list
|
| 287 |
+
List of optimized factor distributions
|
| 288 |
+
"""
|
| 289 |
+
# Initialize factors
|
| 290 |
+
factors = [
|
| 291 |
+
VariationalDistribution(family, params)
|
| 292 |
+
for family, params in zip(self.factor_families, init_params)
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
prev_elbo = -np.inf
|
| 296 |
+
|
| 297 |
+
for iteration in range(max_iter):
|
| 298 |
+
# Update each factor in turn
|
| 299 |
+
for i in range(self.n_factors):
|
| 300 |
+
# Update factor i holding others fixed
|
| 301 |
+
factors[i] = self._update_factor(i, factors, observed_data)
|
| 302 |
+
|
| 303 |
+
# Compute ELBO
|
| 304 |
+
current_elbo = self._compute_mean_field_elbo(factors, observed_data)
|
| 305 |
+
|
| 306 |
+
# Check convergence
|
| 307 |
+
if abs(current_elbo - prev_elbo) < tol:
|
| 308 |
+
print(f"CAVI converged at iteration {iteration}")
|
| 309 |
+
break
|
| 310 |
+
|
| 311 |
+
prev_elbo = current_elbo
|
| 312 |
+
|
| 313 |
+
if iteration % 10 == 0:
|
| 314 |
+
print(f"Iteration {iteration}, ELBO: {current_elbo:.4f}")
|
| 315 |
+
|
| 316 |
+
self.factors = factors
|
| 317 |
+
return factors
|
| 318 |
+
|
| 319 |
+
def _update_factor(
|
| 320 |
+
self,
|
| 321 |
+
factor_idx: int,
|
| 322 |
+
factors: List[VariationalDistribution],
|
| 323 |
+
observed_data: Any
|
| 324 |
+
) -> VariationalDistribution:
|
| 325 |
+
"""
|
| 326 |
+
Update a single factor via optimization.
|
| 327 |
+
|
| 328 |
+
Parameters
|
| 329 |
+
----------
|
| 330 |
+
factor_idx : int
|
| 331 |
+
Index of factor to update
|
| 332 |
+
factors : list
|
| 333 |
+
Current factor distributions
|
| 334 |
+
observed_data : any
|
| 335 |
+
Observed data
|
| 336 |
+
|
| 337 |
+
Returns
|
| 338 |
+
-------
|
| 339 |
+
VariationalDistribution
|
| 340 |
+
Updated factor
|
| 341 |
+
"""
|
| 342 |
+
# This is a simplified version - full implementation would compute
|
| 343 |
+
# conditional expectations analytically for conjugate models
|
| 344 |
+
|
| 345 |
+
# For now, use gradient-based optimization
|
| 346 |
+
current_params = factors[factor_idx].parameters
|
| 347 |
+
|
| 348 |
+
def factor_neg_elbo(params_flat):
|
| 349 |
+
# Unpack
|
| 350 |
+
if self.factor_families[factor_idx] == 'normal':
|
| 351 |
+
d = len(params_flat) // 2
|
| 352 |
+
mu = params_flat[:d]
|
| 353 |
+
log_sigma = params_flat[d:]
|
| 354 |
+
sigma = np.exp(log_sigma)
|
| 355 |
+
params = {'mu': mu, 'sigma': sigma}
|
| 356 |
+
else:
|
| 357 |
+
raise NotImplementedError
|
| 358 |
+
|
| 359 |
+
# Create trial factor
|
| 360 |
+
trial_factor = VariationalDistribution(self.factor_families[factor_idx], params)
|
| 361 |
+
|
| 362 |
+
# Replace in factors
|
| 363 |
+
trial_factors = factors.copy()
|
| 364 |
+
trial_factors[factor_idx] = trial_factor
|
| 365 |
+
|
| 366 |
+
# Compute ELBO
|
| 367 |
+
elbo = self._compute_mean_field_elbo(trial_factors, observed_data)
|
| 368 |
+
return -elbo
|
| 369 |
+
|
| 370 |
+
# Pack current params
|
| 371 |
+
if self.factor_families[factor_idx] == 'normal':
|
| 372 |
+
params_flat = np.concatenate([
|
| 373 |
+
current_params['mu'],
|
| 374 |
+
np.log(current_params['sigma'])
|
| 375 |
+
])
|
| 376 |
+
else:
|
| 377 |
+
raise NotImplementedError
|
| 378 |
+
|
| 379 |
+
# Optimize
|
| 380 |
+
result = minimize(factor_neg_elbo, params_flat, method='L-BFGS-B')
|
| 381 |
+
|
| 382 |
+
# Unpack
|
| 383 |
+
if self.factor_families[factor_idx] == 'normal':
|
| 384 |
+
d = len(result.x) // 2
|
| 385 |
+
mu = result.x[:d]
|
| 386 |
+
sigma = np.exp(result.x[d:])
|
| 387 |
+
opt_params = {'mu': mu, 'sigma': sigma}
|
| 388 |
+
else:
|
| 389 |
+
raise NotImplementedError
|
| 390 |
+
|
| 391 |
+
return VariationalDistribution(self.factor_families[factor_idx], opt_params)
|
| 392 |
+
|
| 393 |
+
def _compute_mean_field_elbo(
|
| 394 |
+
self,
|
| 395 |
+
factors: List[VariationalDistribution],
|
| 396 |
+
observed_data: Any
|
| 397 |
+
) -> float:
|
| 398 |
+
"""
|
| 399 |
+
Compute ELBO for mean-field approximation.
|
| 400 |
+
|
| 401 |
+
Parameters
|
| 402 |
+
----------
|
| 403 |
+
factors : list
|
| 404 |
+
Factor distributions
|
| 405 |
+
observed_data : any
|
| 406 |
+
Observed data
|
| 407 |
+
|
| 408 |
+
Returns
|
| 409 |
+
-------
|
| 410 |
+
float
|
| 411 |
+
ELBO
|
| 412 |
+
"""
|
| 413 |
+
# Sample from each factor
|
| 414 |
+
samples = []
|
| 415 |
+
for factor in factors:
|
| 416 |
+
samples.append(factor.sample(self.n_samples))
|
| 417 |
+
|
| 418 |
+
# Combine samples
|
| 419 |
+
z_samples = np.column_stack(samples)
|
| 420 |
+
|
| 421 |
+
# Compute E_q[log p(x, z)]
|
| 422 |
+
log_joint_vals = np.array([self.log_joint(z, observed_data) for z in z_samples])
|
| 423 |
+
expected_log_joint = np.mean(log_joint_vals)
|
| 424 |
+
|
| 425 |
+
# Compute E_q[log q(z)] = sum_i E_q[log q_i(z_i)]
|
| 426 |
+
expected_log_q = 0.0
|
| 427 |
+
for i, factor in enumerate(factors):
|
| 428 |
+
log_q_vals = factor.log_prob(samples[i])
|
| 429 |
+
expected_log_q += np.mean(log_q_vals)
|
| 430 |
+
|
| 431 |
+
return expected_log_joint - expected_log_q
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class ADVI:
|
| 435 |
+
"""
|
| 436 |
+
Automatic Differentiation Variational Inference (ADVI).
|
| 437 |
+
|
| 438 |
+
Transforms constrained latent variables to unconstrained space,
|
| 439 |
+
then performs VI with Gaussian variational family.
|
| 440 |
+
|
| 441 |
+
Uses reparameterization trick for low-variance gradient estimates.
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
def __init__(
|
| 445 |
+
self,
|
| 446 |
+
log_joint: Callable,
|
| 447 |
+
transform_fn: Optional[Callable] = None,
|
| 448 |
+
inverse_transform_fn: Optional[Callable] = None
|
| 449 |
+
):
|
| 450 |
+
"""
|
| 451 |
+
Initialize ADVI.
|
| 452 |
+
|
| 453 |
+
Parameters
|
| 454 |
+
----------
|
| 455 |
+
log_joint : callable
|
| 456 |
+
Log joint in original (possibly constrained) space
|
| 457 |
+
transform_fn : callable, optional
|
| 458 |
+
Transform to unconstrained space
|
| 459 |
+
inverse_transform_fn : callable, optional
|
| 460 |
+
Inverse transform
|
| 461 |
+
"""
|
| 462 |
+
self.log_joint = log_joint
|
| 463 |
+
self.transform_fn = transform_fn or (lambda x: x)
|
| 464 |
+
self.inverse_transform_fn = inverse_transform_fn or (lambda x: x)
|
| 465 |
+
|
| 466 |
+
def fit(
|
| 467 |
+
self,
|
| 468 |
+
observed_data: Any,
|
| 469 |
+
latent_dim: int,
|
| 470 |
+
n_samples: int = 10,
|
| 471 |
+
max_iter: int = 1000,
|
| 472 |
+
learning_rate: float = 0.01
|
| 473 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 474 |
+
"""
|
| 475 |
+
Fit ADVI using gradient ascent on ELBO.
|
| 476 |
+
|
| 477 |
+
Parameters
|
| 478 |
+
----------
|
| 479 |
+
observed_data : any
|
| 480 |
+
Observed data
|
| 481 |
+
latent_dim : int
|
| 482 |
+
Dimension of latent variables
|
| 483 |
+
n_samples : int
|
| 484 |
+
Number of samples for ELBO gradient estimation
|
| 485 |
+
max_iter : int
|
| 486 |
+
Maximum iterations
|
| 487 |
+
learning_rate : float
|
| 488 |
+
Learning rate for gradient ascent
|
| 489 |
+
|
| 490 |
+
Returns
|
| 491 |
+
-------
|
| 492 |
+
tuple
|
| 493 |
+
(mean, log_std) of variational distribution
|
| 494 |
+
"""
|
| 495 |
+
# Initialize variational parameters (Gaussian in unconstrained space)
|
| 496 |
+
mu = np.zeros(latent_dim)
|
| 497 |
+
log_sigma = np.zeros(latent_dim)
|
| 498 |
+
|
| 499 |
+
for iteration in range(max_iter):
|
| 500 |
+
# Sample from standard normal
|
| 501 |
+
epsilon = np.random.randn(n_samples, latent_dim)
|
| 502 |
+
|
| 503 |
+
# Reparameterization: z = μ + σ * ε
|
| 504 |
+
sigma = np.exp(log_sigma)
|
| 505 |
+
z_unconstrained = mu + sigma * epsilon
|
| 506 |
+
|
| 507 |
+
# Transform to constrained space
|
| 508 |
+
z_constrained = np.array([self.inverse_transform_fn(z) for z in z_unconstrained])
|
| 509 |
+
|
| 510 |
+
# Compute log joint
|
| 511 |
+
log_joints = np.array([self.log_joint(z, observed_data) for z in z_constrained])
|
| 512 |
+
|
| 513 |
+
# Compute ELBO (with entropy)
|
| 514 |
+
entropy = 0.5 * np.sum(np.log(2 * np.pi * np.e * sigma**2))
|
| 515 |
+
elbo = np.mean(log_joints) + entropy
|
| 516 |
+
|
| 517 |
+
# Gradient estimates (simplified - would use autograd in practice)
|
| 518 |
+
grad_mu = np.mean((log_joints[:, np.newaxis] - elbo) * (z_unconstrained - mu) / (sigma**2), axis=0)
|
| 519 |
+
grad_log_sigma = np.mean(
|
| 520 |
+
(log_joints[:, np.newaxis] - elbo) * ((z_unconstrained - mu)**2 / sigma**2 - 1),
|
| 521 |
+
axis=0
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Update parameters
|
| 525 |
+
mu = mu + learning_rate * grad_mu
|
| 526 |
+
log_sigma = log_sigma + learning_rate * grad_log_sigma
|
| 527 |
+
|
| 528 |
+
if iteration % 100 == 0:
|
| 529 |
+
print(f"Iteration {iteration}, ELBO: {elbo:.4f}")
|
| 530 |
+
|
| 531 |
+
return mu, log_sigma
|
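Before a model is wired into the ELBO optimizers above, the `VariationalDistribution` helper can be exercised on its own. Below is a small sanity check comparing its closed-form entropy with a Monte Carlo estimate from its own samples; the parameter values are arbitrary:

```python
import numpy as np
from geobot.inference.variational_inference import VariationalDistribution

q = VariationalDistribution(
    family='normal',
    parameters={'mu': np.array([0.0, 2.0]), 'sigma': np.array([1.0, 0.5])},
)

z = q.sample(n_samples=50_000)           # shape (50000, 2)
mc_entropy = -np.mean(q.log_prob(z))     # Monte Carlo estimate of E_q[-log q(z)]
print(f"closed form: {q.entropy():.3f}, Monte Carlo: {mc_entropy:.3f}")
```

The same object type is what `VariationalInference.fit` returns, so the `sample` and `log_prob` calls above are also how a fitted posterior approximation is consumed downstream.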
geobot/ml/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
"""
|
| 2 |
+
Machine Learning enhancers for GeoBotv1
|
| 3 |
+
|
| 4 |
+
Optional but powerful additions once the causal backbone is built.
|
| 5 |
+
Critical principle: These help discover new relationships but must not replace causality.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .risk_scoring import RiskScorer
|
| 9 |
+
from .feature_discovery import FeatureDiscovery
|
| 10 |
+
from .embedding import GeopoliticalEmbedding
|
| 11 |
+
# GNN imports are optional (require PyTorch)
|
| 12 |
+
try:
|
| 13 |
+
from .graph_neural_networks import (
|
| 14 |
+
CausalGNN,
|
| 15 |
+
GeopoliticalNetworkGNN,
|
| 16 |
+
AttentionGNN,
|
| 17 |
+
MessagePassingCausalGNN,
|
| 18 |
+
GNNTrainer,
|
| 19 |
+
NetworkToGraph
|
| 20 |
+
)
|
| 21 |
+
_has_gnn = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
_has_gnn = False
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"RiskScorer",
|
| 27 |
+
"FeatureDiscovery",
|
| 28 |
+
"GeopoliticalEmbedding",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
if _has_gnn:
|
| 32 |
+
__all__.extend([
|
| 33 |
+
"CausalGNN",
|
| 34 |
+
"GeopoliticalNetworkGNN",
|
| 35 |
+
"AttentionGNN",
|
| 36 |
+
"MessagePassingCausalGNN",
|
| 37 |
+
"GNNTrainer",
|
| 38 |
+
"NetworkToGraph",
|
| 39 |
+
])
|
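The try/except guard above makes the GNN classes optional exports. A consumer-side sketch of the matching defensive import follows; the fallback branch is illustrative and not required by the package:

```python
# Core enhancers are always importable; GNN names exist in geobot.ml only
# when PyTorch (and the graph_neural_networks module) imported cleanly.
from geobot import ml

print(ml.__all__)  # always lists RiskScorer, FeatureDiscovery, GeopoliticalEmbedding

try:
    from geobot.ml import GNNTrainer   # exported only when the GNN stack is available
except ImportError:
    GNNTrainer = None                  # illustrative fallback: run without GNN support
```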
geobot/ml/embedding.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
"""
|
| 2 |
+
Geopolitical embeddings for text and entities.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Dict, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GeopoliticalEmbedding:
|
| 10 |
+
"""
|
| 11 |
+
Create embeddings for geopolitical entities and text.
|
| 12 |
+
|
| 13 |
+
Transforms text into risk vectors using NLP models.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
|
| 17 |
+
"""
|
| 18 |
+
Initialize embedding model.
|
| 19 |
+
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
model_name : str
|
| 23 |
+
Name of the embedding model
|
| 24 |
+
"""
|
| 25 |
+
self.model_name = model_name
|
| 26 |
+
self.model = None
|
| 27 |
+
self._load_model()
|
| 28 |
+
|
| 29 |
+
def _load_model(self) -> None:
|
| 30 |
+
"""Load embedding model."""
|
| 31 |
+
try:
|
| 32 |
+
from sentence_transformers import SentenceTransformer
|
| 33 |
+
self.model = SentenceTransformer(self.model_name)
|
| 34 |
+
except ImportError:
|
| 35 |
+
print("sentence-transformers not installed. Embeddings will not be available.")
|
| 36 |
+
self.model = None
|
| 37 |
+
|
| 38 |
+
def encode_text(self, texts: List[str]) -> np.ndarray:
|
| 39 |
+
"""
|
| 40 |
+
Encode texts into vectors.
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
texts : list
|
| 45 |
+
List of texts to encode
|
| 46 |
+
|
| 47 |
+
Returns
|
| 48 |
+
-------
|
| 49 |
+
np.ndarray
|
| 50 |
+
Embeddings
|
| 51 |
+
"""
|
| 52 |
+
if self.model is None:
|
| 53 |
+
raise ValueError("Embedding model not loaded; install sentence-transformers to enable encoding")
|
| 54 |
+
|
| 55 |
+
return self.model.encode(texts)
|
| 56 |
+
|
| 57 |
+
def compute_similarity(self, text1: str, text2: str) -> float:
|
| 58 |
+
"""
|
| 59 |
+
Compute similarity between two texts.
|
| 60 |
+
|
| 61 |
+
Parameters
|
| 62 |
+
----------
|
| 63 |
+
text1 : str
|
| 64 |
+
First text
|
| 65 |
+
text2 : str
|
| 66 |
+
Second text
|
| 67 |
+
|
| 68 |
+
Returns
|
| 69 |
+
-------
|
| 70 |
+
float
|
| 71 |
+
Cosine similarity
|
| 72 |
+
"""
|
| 73 |
+
embeddings = self.encode_text([text1, text2])
|
| 74 |
+
similarity = np.dot(embeddings[0], embeddings[1]) / \
|
| 75 |
+
(np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
|
| 76 |
+
return float(similarity)
|
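A minimal usage sketch for GeopoliticalEmbedding, assuming sentence-transformers is installed so _load_model succeeds; the headline strings are illustrative only:

    from geobot.ml import GeopoliticalEmbedding

    embedder = GeopoliticalEmbedding()  # defaults to all-MiniLM-L6-v2

    headlines = [
        "Naval exercises announced near disputed waters",
        "New round of trade negotiations scheduled",
    ]
    vectors = embedder.encode_text(headlines)   # array of shape (2, embedding_dim)
    score = embedder.compute_similarity(headlines[0], headlines[1])
    print(f"Cosine similarity: {score:.3f}")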
geobot/ml/feature_discovery.py
ADDED
|
@@ -0,0 +1,81 @@
|
| 1 |
+
"""
|
| 2 |
+
Feature discovery using Random Forests and other methods.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import Dict, List, Tuple
|
| 8 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FeatureDiscovery:
|
| 13 |
+
"""
|
| 14 |
+
Discover important features and relationships in geopolitical data.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self):
|
| 18 |
+
"""Initialize feature discovery."""
|
| 19 |
+
self.feature_scores = {}
|
| 20 |
+
|
| 21 |
+
def discover_important_features(
|
| 22 |
+
self,
|
| 23 |
+
X: pd.DataFrame,
|
| 24 |
+
y: np.ndarray,
|
| 25 |
+
n_top: int = 10
|
| 26 |
+
) -> List[Tuple[str, float]]:
|
| 27 |
+
"""
|
| 28 |
+
Discover most important features using Random Forest.
|
| 29 |
+
|
| 30 |
+
Parameters
|
| 31 |
+
----------
|
| 32 |
+
X : pd.DataFrame
|
| 33 |
+
Feature matrix
|
| 34 |
+
y : np.ndarray
|
| 35 |
+
Target variable
|
| 36 |
+
n_top : int
|
| 37 |
+
Number of top features to return
|
| 38 |
+
|
| 39 |
+
Returns
|
| 40 |
+
-------
|
| 41 |
+
list
|
| 42 |
+
List of (feature_name, importance_score) tuples
|
| 43 |
+
"""
|
| 44 |
+
model = RandomForestRegressor(n_estimators=100, random_state=42)
|
| 45 |
+
model.fit(X, y)
|
| 46 |
+
|
| 47 |
+
importance = model.feature_importances_
|
| 48 |
+
self.feature_scores = dict(zip(X.columns, importance))
|
| 49 |
+
|
| 50 |
+
sorted_features = sorted(
|
| 51 |
+
self.feature_scores.items(),
|
| 52 |
+
key=lambda x: x[1],
|
| 53 |
+
reverse=True
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
return sorted_features[:n_top]
|
| 57 |
+
|
| 58 |
+
def discover_latent_factors(
|
| 59 |
+
self,
|
| 60 |
+
X: pd.DataFrame,
|
| 61 |
+
n_components: int = 5
|
| 62 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 63 |
+
"""
|
| 64 |
+
Discover latent factors using PCA.
|
| 65 |
+
|
| 66 |
+
Parameters
|
| 67 |
+
----------
|
| 68 |
+
X : pd.DataFrame
|
| 69 |
+
Feature matrix
|
| 70 |
+
n_components : int
|
| 71 |
+
Number of components
|
| 72 |
+
|
| 73 |
+
Returns
|
| 74 |
+
-------
|
| 75 |
+
tuple
|
| 76 |
+
(transformed_data, explained_variance)
|
| 77 |
+
"""
|
| 78 |
+
pca = PCA(n_components=n_components)
|
| 79 |
+
transformed = pca.fit_transform(X)
|
| 80 |
+
|
| 81 |
+
return transformed, pca.explained_variance_ratio_
|
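A minimal usage sketch for FeatureDiscovery on synthetic data; the column names and coefficients are illustrative assumptions, not part of the framework:

    import numpy as np
    import pandas as pd
    from geobot.ml import FeatureDiscovery

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(200, 4)),
                     columns=["sanctions", "troop_movements", "trade_volume", "rhetoric"])
    y = 0.8 * X["sanctions"].to_numpy() + 0.3 * X["rhetoric"].to_numpy() \
        + rng.normal(scale=0.1, size=200)

    fd = FeatureDiscovery()
    print(fd.discover_important_features(X, y, n_top=2))   # Random Forest importances
    factors, explained = fd.discover_latent_factors(X, n_components=2)
    print(explained)                                        # PCA explained variance ratios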
geobot/ml/graph_neural_networks.py
ADDED
|
@@ -0,0 +1,628 @@
|
| 1 |
+
"""
|
| 2 |
+
Graph Neural Networks for Causal Graphs and Geopolitical Networks
|
| 3 |
+
|
| 4 |
+
Implements GNNs for:
|
| 5 |
+
- Alliance and trade network analysis
|
| 6 |
+
- Causal graph representation learning
|
| 7 |
+
- Message passing on DAGs
|
| 8 |
+
- Attention mechanisms for influence propagation
|
| 9 |
+
- Graph classification and regression
|
| 10 |
+
|
| 11 |
+
Respects identifiability and invariance constraints from causal theory.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
from typing import List, Dict, Tuple, Optional, Callable
|
| 16 |
+
import networkx as nx
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
HAS_TORCH = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
HAS_TORCH = False
|
| 25 |
+
print("PyTorch not available. GNN functionality limited.")
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import torch_geometric
|
| 29 |
+
from torch_geometric.nn import GCNConv, GATConv, MessagePassing
|
| 30 |
+
from torch_geometric.data import Data
|
| 31 |
+
HAS_TORCH_GEOMETRIC = True
|
| 32 |
+
except ImportError:
|
| 33 |
+
HAS_TORCH_GEOMETRIC = False
|
| 34 |
+
print("torch_geometric not available. Install with: pip install torch-geometric")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if HAS_TORCH:
|
| 38 |
+
class CausalGNN(nn.Module):
|
| 39 |
+
"""
|
| 40 |
+
Graph Neural Network for causal graphs.
|
| 41 |
+
|
| 42 |
+
Respects causal ordering (topological) and propagates information
|
| 43 |
+
along causal edges.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
node_features: int,
|
| 49 |
+
hidden_dim: int,
|
| 50 |
+
output_dim: int,
|
| 51 |
+
num_layers: int = 2,
|
| 52 |
+
attention: bool = False
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Initialize Causal GNN.
|
| 56 |
+
|
| 57 |
+
Parameters
|
| 58 |
+
----------
|
| 59 |
+
node_features : int
|
| 60 |
+
Dimension of input node features
|
| 61 |
+
hidden_dim : int
|
| 62 |
+
Hidden layer dimension
|
| 63 |
+
output_dim : int
|
| 64 |
+
Output dimension
|
| 65 |
+
num_layers : int
|
| 66 |
+
Number of GNN layers
|
| 67 |
+
attention : bool
|
| 68 |
+
Use attention mechanism (GAT)
|
| 69 |
+
"""
|
| 70 |
+
super(CausalGNN, self).__init__()
|
| 71 |
+
|
| 72 |
+
self.num_layers = num_layers
|
| 73 |
+
self.attention = attention
|
| 74 |
+
|
| 75 |
+
# Input layer
|
| 76 |
+
if attention and HAS_TORCH_GEOMETRIC:
|
| 77 |
+
self.conv1 = GATConv(node_features, hidden_dim, heads=4, concat=True)
|
| 78 |
+
self.convs = nn.ModuleList([
|
| 79 |
+
GATConv(hidden_dim * 4, hidden_dim, heads=4, concat=True)
|
| 80 |
+
for _ in range(num_layers - 2)
|
| 81 |
+
])
|
| 82 |
+
self.conv_final = GATConv(hidden_dim * 4, output_dim, heads=1, concat=False)
|
| 83 |
+
elif HAS_TORCH_GEOMETRIC:
|
| 84 |
+
self.conv1 = GCNConv(node_features, hidden_dim)
|
| 85 |
+
self.convs = nn.ModuleList([
|
| 86 |
+
GCNConv(hidden_dim, hidden_dim)
|
| 87 |
+
for _ in range(num_layers - 2)
|
| 88 |
+
])
|
| 89 |
+
self.conv_final = GCNConv(hidden_dim, output_dim)
|
| 90 |
+
else:
|
| 91 |
+
# Fallback to simple linear layers
|
| 92 |
+
self.linear1 = nn.Linear(node_features, hidden_dim)
|
| 93 |
+
self.linears = nn.ModuleList([
|
| 94 |
+
nn.Linear(hidden_dim, hidden_dim)
|
| 95 |
+
for _ in range(num_layers - 2)
|
| 96 |
+
])
|
| 97 |
+
self.linear_final = nn.Linear(hidden_dim, output_dim)
|
| 98 |
+
|
| 99 |
+
def forward(self, data):
|
| 100 |
+
"""
|
| 101 |
+
Forward pass.
|
| 102 |
+
|
| 103 |
+
Parameters
|
| 104 |
+
----------
|
| 105 |
+
data : torch_geometric.data.Data
|
| 106 |
+
Graph data with x (node features) and edge_index
|
| 107 |
+
|
| 108 |
+
Returns
|
| 109 |
+
-------
|
| 110 |
+
torch.Tensor
|
| 111 |
+
Node embeddings
|
| 112 |
+
"""
|
| 113 |
+
if HAS_TORCH_GEOMETRIC:
|
| 114 |
+
x, edge_index = data.x, data.edge_index
|
| 115 |
+
|
| 116 |
+
# First layer
|
| 117 |
+
x = self.conv1(x, edge_index)
|
| 118 |
+
x = F.relu(x)
|
| 119 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
| 120 |
+
|
| 121 |
+
# Hidden layers
|
| 122 |
+
for conv in self.convs:
|
| 123 |
+
x = conv(x, edge_index)
|
| 124 |
+
x = F.relu(x)
|
| 125 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
| 126 |
+
|
| 127 |
+
# Output layer
|
| 128 |
+
x = self.conv_final(x, edge_index)
|
| 129 |
+
|
| 130 |
+
else:
|
| 131 |
+
x = data.x
|
| 132 |
+
|
| 133 |
+
x = self.linear1(x)
|
| 134 |
+
x = F.relu(x)
|
| 135 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
| 136 |
+
|
| 137 |
+
for linear in self.linears:
|
| 138 |
+
x = linear(x)
|
| 139 |
+
x = F.relu(x)
|
| 140 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
| 141 |
+
|
| 142 |
+
x = self.linear_final(x)
|
| 143 |
+
|
| 144 |
+
return x
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class GeopoliticalNetworkGNN(nn.Module):
|
| 148 |
+
"""
|
| 149 |
+
GNN for geopolitical networks (alliances, trade, etc.).
|
| 150 |
+
|
| 151 |
+
Models influence propagation and network effects.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
node_features: int,
|
| 157 |
+
edge_features: int,
|
| 158 |
+
hidden_dim: int,
|
| 159 |
+
output_dim: int
|
| 160 |
+
):
|
| 161 |
+
"""
|
| 162 |
+
Initialize geopolitical network GNN.
|
| 163 |
+
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
node_features : int
|
| 167 |
+
Node feature dimension
|
| 168 |
+
edge_features : int
|
| 169 |
+
Edge feature dimension
|
| 170 |
+
hidden_dim : int
|
| 171 |
+
Hidden dimension
|
| 172 |
+
output_dim : int
|
| 173 |
+
Output dimension
|
| 174 |
+
"""
|
| 175 |
+
super(GeopoliticalNetworkGNN, self).__init__()
|
| 176 |
+
|
| 177 |
+
if HAS_TORCH_GEOMETRIC:
|
| 178 |
+
self.conv1 = GCNConv(node_features, hidden_dim)
|
| 179 |
+
self.conv2 = GCNConv(hidden_dim, hidden_dim)
|
| 180 |
+
self.conv3 = GCNConv(hidden_dim, output_dim)
|
| 181 |
+
else:
|
| 182 |
+
self.linear1 = nn.Linear(node_features, hidden_dim)
|
| 183 |
+
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
| 184 |
+
self.linear3 = nn.Linear(hidden_dim, output_dim)
|
| 185 |
+
|
| 186 |
+
# Edge feature processing
|
| 187 |
+
self.edge_mlp = nn.Sequential(
|
| 188 |
+
nn.Linear(edge_features, hidden_dim),
|
| 189 |
+
nn.ReLU(),
|
| 190 |
+
nn.Linear(hidden_dim, hidden_dim)
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
def forward(self, data):
|
| 194 |
+
"""
|
| 195 |
+
Forward pass.
|
| 196 |
+
|
| 197 |
+
Parameters
|
| 198 |
+
----------
|
| 199 |
+
data : Data
|
| 200 |
+
Graph data
|
| 201 |
+
|
| 202 |
+
Returns
|
| 203 |
+
-------
|
| 204 |
+
torch.Tensor
|
| 205 |
+
Node embeddings
|
| 206 |
+
"""
|
| 207 |
+
if HAS_TORCH_GEOMETRIC:
|
| 208 |
+
x, edge_index = data.x, data.edge_index
|
| 209 |
+
|
| 210 |
+
x = self.conv1(x, edge_index)
|
| 211 |
+
x = F.relu(x)
|
| 212 |
+
|
| 213 |
+
x = self.conv2(x, edge_index)
|
| 214 |
+
x = F.relu(x)
|
| 215 |
+
|
| 216 |
+
x = self.conv3(x, edge_index)
|
| 217 |
+
else:
|
| 218 |
+
x = data.x
|
| 219 |
+
|
| 220 |
+
x = self.linear1(x)
|
| 221 |
+
x = F.relu(x)
|
| 222 |
+
|
| 223 |
+
x = self.linear2(x)
|
| 224 |
+
x = F.relu(x)
|
| 225 |
+
|
| 226 |
+
x = self.linear3(x)
|
| 227 |
+
|
| 228 |
+
return x
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class MessagePassingCausalGNN(MessagePassing if HAS_TORCH_GEOMETRIC else nn.Module):
|
| 232 |
+
"""
|
| 233 |
+
Custom message passing for causal graphs.
|
| 234 |
+
|
| 235 |
+
Implements directed message passing that respects causal structure:
|
| 236 |
+
- Messages flow only in direction of causal edges
|
| 237 |
+
- Aggregation respects causal mechanisms
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int):
|
| 241 |
+
"""
|
| 242 |
+
Initialize message passing GNN.
|
| 243 |
+
|
| 244 |
+
Parameters
|
| 245 |
+
----------
|
| 246 |
+
node_dim : int
|
| 247 |
+
Node feature dimension
|
| 248 |
+
edge_dim : int
|
| 249 |
+
Edge feature dimension
|
| 250 |
+
hidden_dim : int
|
| 251 |
+
Hidden dimension
|
| 252 |
+
"""
|
| 253 |
+
if HAS_TORCH_GEOMETRIC:
|
| 254 |
+
super(MessagePassingCausalGNN, self).__init__(aggr='add')
|
| 255 |
+
else:
|
| 256 |
+
super(MessagePassingCausalGNN, self).__init__()
|
| 257 |
+
|
| 258 |
+
self.in_node_dim = node_dim  # keep separate from MessagePassing.node_dim, which is the aggregation axis
|
| 259 |
+
self.edge_dim = edge_dim
|
| 260 |
+
self.hidden_dim = hidden_dim
|
| 261 |
+
|
| 262 |
+
# Message function
|
| 263 |
+
self.message_mlp = nn.Sequential(
|
| 264 |
+
nn.Linear(node_dim + node_dim + edge_dim, hidden_dim),
|
| 265 |
+
nn.ReLU(),
|
| 266 |
+
nn.Linear(hidden_dim, hidden_dim)
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Update function
|
| 270 |
+
self.update_mlp = nn.Sequential(
|
| 271 |
+
nn.Linear(node_dim + hidden_dim, hidden_dim),
|
| 272 |
+
nn.ReLU(),
|
| 273 |
+
nn.Linear(hidden_dim, node_dim)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
def forward(self, x, edge_index, edge_attr):
|
| 277 |
+
"""
|
| 278 |
+
Forward pass.
|
| 279 |
+
|
| 280 |
+
Parameters
|
| 281 |
+
----------
|
| 282 |
+
x : torch.Tensor
|
| 283 |
+
Node features
|
| 284 |
+
edge_index : torch.Tensor
|
| 285 |
+
Edge indices
|
| 286 |
+
edge_attr : torch.Tensor
|
| 287 |
+
Edge attributes
|
| 288 |
+
|
| 289 |
+
Returns
|
| 290 |
+
-------
|
| 291 |
+
torch.Tensor
|
| 292 |
+
Updated node features
|
| 293 |
+
"""
|
| 294 |
+
if HAS_TORCH_GEOMETRIC:
|
| 295 |
+
return self.propagate(edge_index, x=x, edge_attr=edge_attr)
|
| 296 |
+
else:
|
| 297 |
+
# Fallback implementation
|
| 298 |
+
return x
|
| 299 |
+
|
| 300 |
+
def message(self, x_i, x_j, edge_attr):
|
| 301 |
+
"""
|
| 302 |
+
Construct messages.
|
| 303 |
+
|
| 304 |
+
x_i: target node features
|
| 305 |
+
x_j: source node features
|
| 306 |
+
edge_attr: edge features
|
| 307 |
+
|
| 308 |
+
Returns
|
| 309 |
+
-------
|
| 310 |
+
torch.Tensor
|
| 311 |
+
Messages
|
| 312 |
+
"""
|
| 313 |
+
# Concatenate source, target, and edge features
|
| 314 |
+
msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1)
|
| 315 |
+
return self.message_mlp(msg_input)
|
| 316 |
+
|
| 317 |
+
def update(self, aggr_out, x):
|
| 318 |
+
"""
|
| 319 |
+
Update node features.
|
| 320 |
+
|
| 321 |
+
Parameters
|
| 322 |
+
----------
|
| 323 |
+
aggr_out : torch.Tensor
|
| 324 |
+
Aggregated messages
|
| 325 |
+
x : torch.Tensor
|
| 326 |
+
Current node features
|
| 327 |
+
|
| 328 |
+
Returns
|
| 329 |
+
-------
|
| 330 |
+
torch.Tensor
|
| 331 |
+
Updated node features
|
| 332 |
+
"""
|
| 333 |
+
# Concatenate aggregated messages with current features
|
| 334 |
+
update_input = torch.cat([x, aggr_out], dim=-1)
|
| 335 |
+
return self.update_mlp(update_input)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class AttentionGNN(nn.Module):
|
| 339 |
+
"""
|
| 340 |
+
Graph Attention Network for geopolitical influence.
|
| 341 |
+
|
| 342 |
+
Uses attention to weight importance of different neighbors/allies.
|
| 343 |
+
"""
|
| 344 |
+
|
| 345 |
+
def __init__(
|
| 346 |
+
self,
|
| 347 |
+
node_features: int,
|
| 348 |
+
hidden_dim: int,
|
| 349 |
+
output_dim: int,
|
| 350 |
+
num_heads: int = 4
|
| 351 |
+
):
|
| 352 |
+
"""
|
| 353 |
+
Initialize attention GNN.
|
| 354 |
+
|
| 355 |
+
Parameters
|
| 356 |
+
----------
|
| 357 |
+
node_features : int
|
| 358 |
+
Input node feature dimension
|
| 359 |
+
hidden_dim : int
|
| 360 |
+
Hidden dimension
|
| 361 |
+
output_dim : int
|
| 362 |
+
Output dimension
|
| 363 |
+
num_heads : int
|
| 364 |
+
Number of attention heads
|
| 365 |
+
"""
|
| 366 |
+
super(AttentionGNN, self).__init__()
|
| 367 |
+
|
| 368 |
+
if HAS_TORCH_GEOMETRIC:
|
| 369 |
+
self.conv1 = GATConv(node_features, hidden_dim, heads=num_heads, concat=True)
|
| 370 |
+
self.conv2 = GATConv(hidden_dim * num_heads, output_dim, heads=1, concat=False)
|
| 371 |
+
else:
|
| 372 |
+
# Fallback
|
| 373 |
+
self.linear1 = nn.Linear(node_features, hidden_dim * num_heads)
|
| 374 |
+
self.linear2 = nn.Linear(hidden_dim * num_heads, output_dim)
|
| 375 |
+
|
| 376 |
+
def forward(self, data):
|
| 377 |
+
"""
|
| 378 |
+
Forward pass with attention.
|
| 379 |
+
|
| 380 |
+
Parameters
|
| 381 |
+
----------
|
| 382 |
+
data : Data
|
| 383 |
+
Graph data
|
| 384 |
+
|
| 385 |
+
Returns
|
| 386 |
+
-------
|
| 387 |
+
torch.Tensor
|
| 388 |
+
Node embeddings with attention weights
|
| 389 |
+
"""
|
| 390 |
+
if HAS_TORCH_GEOMETRIC:
|
| 391 |
+
x, edge_index = data.x, data.edge_index
|
| 392 |
+
|
| 393 |
+
# First layer with multi-head attention
|
| 394 |
+
x, attention_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
|
| 395 |
+
x = F.elu(x)
|
| 396 |
+
x = F.dropout(x, p=0.3, training=self.training)
|
| 397 |
+
|
| 398 |
+
# Second layer
|
| 399 |
+
x, attention_weights2 = self.conv2(x, edge_index, return_attention_weights=True)
|
| 400 |
+
|
| 401 |
+
return x, (attention_weights1, attention_weights2)
|
| 402 |
+
else:
|
| 403 |
+
x = data.x
|
| 404 |
+
|
| 405 |
+
x = self.linear1(x)
|
| 406 |
+
x = F.elu(x)
|
| 407 |
+
x = F.dropout(x, p=0.3, training=self.training)
|
| 408 |
+
|
| 409 |
+
x = self.linear2(x)
|
| 410 |
+
|
| 411 |
+
return x, None
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class GNNTrainer:
|
| 415 |
+
"""
|
| 416 |
+
Trainer for GNN models.
|
| 417 |
+
|
| 418 |
+
Handles training, validation, and evaluation.
|
| 419 |
+
"""
|
| 420 |
+
|
| 421 |
+
def __init__(
|
| 422 |
+
self,
|
| 423 |
+
model,
|
| 424 |
+
learning_rate: float = 0.01,
|
| 425 |
+
weight_decay: float = 5e-4
|
| 426 |
+
):
|
| 427 |
+
"""
|
| 428 |
+
Initialize GNN trainer.
|
| 429 |
+
|
| 430 |
+
Parameters
|
| 431 |
+
----------
|
| 432 |
+
model : nn.Module
|
| 433 |
+
GNN model
|
| 434 |
+
learning_rate : float
|
| 435 |
+
Learning rate
|
| 436 |
+
weight_decay : float
|
| 437 |
+
Weight decay (L2 regularization)
|
| 438 |
+
"""
|
| 439 |
+
if not HAS_TORCH:
|
| 440 |
+
raise ImportError("PyTorch required for GNN training")
|
| 441 |
+
|
| 442 |
+
self.model = model
|
| 443 |
+
self.optimizer = torch.optim.Adam(
|
| 444 |
+
model.parameters(),
|
| 445 |
+
lr=learning_rate,
|
| 446 |
+
weight_decay=weight_decay
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
def train_step(
|
| 450 |
+
self,
|
| 451 |
+
data,
|
| 452 |
+
labels,
|
| 453 |
+
loss_fn: Callable
|
| 454 |
+
) -> float:
|
| 455 |
+
"""
|
| 456 |
+
Single training step.
|
| 457 |
+
|
| 458 |
+
Parameters
|
| 459 |
+
----------
|
| 460 |
+
data : Data
|
| 461 |
+
Graph data
|
| 462 |
+
labels : torch.Tensor
|
| 463 |
+
Labels
|
| 464 |
+
loss_fn : callable
|
| 465 |
+
Loss function
|
| 466 |
+
|
| 467 |
+
Returns
|
| 468 |
+
-------
|
| 469 |
+
float
|
| 470 |
+
Loss value
|
| 471 |
+
"""
|
| 472 |
+
self.model.train()
|
| 473 |
+
self.optimizer.zero_grad()
|
| 474 |
+
|
| 475 |
+
# Forward pass
|
| 476 |
+
out = self.model(data)
|
| 477 |
+
|
| 478 |
+
# Compute loss
|
| 479 |
+
loss = loss_fn(out, labels)
|
| 480 |
+
|
| 481 |
+
# Backward pass
|
| 482 |
+
loss.backward()
|
| 483 |
+
self.optimizer.step()
|
| 484 |
+
|
| 485 |
+
return loss.item()
|
| 486 |
+
|
| 487 |
+
def evaluate(
|
| 488 |
+
self,
|
| 489 |
+
data,
|
| 490 |
+
labels,
|
| 491 |
+
metric_fn: Callable
|
| 492 |
+
) -> float:
|
| 493 |
+
"""
|
| 494 |
+
Evaluate model.
|
| 495 |
+
|
| 496 |
+
Parameters
|
| 497 |
+
----------
|
| 498 |
+
data : Data
|
| 499 |
+
Graph data
|
| 500 |
+
labels : torch.Tensor
|
| 501 |
+
True labels
|
| 502 |
+
metric_fn : callable
|
| 503 |
+
Evaluation metric
|
| 504 |
+
|
| 505 |
+
Returns
|
| 506 |
+
-------
|
| 507 |
+
float
|
| 508 |
+
Metric value
|
| 509 |
+
"""
|
| 510 |
+
self.model.eval()
|
| 511 |
+
|
| 512 |
+
with torch.no_grad():
|
| 513 |
+
out = self.model(data)
|
| 514 |
+
metric = metric_fn(out, labels)
|
| 515 |
+
|
| 516 |
+
return metric
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class NetworkToGraph:
|
| 520 |
+
"""
|
| 521 |
+
Convert NetworkX graph to PyTorch Geometric format.
|
| 522 |
+
"""
|
| 523 |
+
|
| 524 |
+
@staticmethod
|
| 525 |
+
def convert(
|
| 526 |
+
G: nx.Graph,
|
| 527 |
+
node_features: Optional[Dict] = None,
|
| 528 |
+
edge_features: Optional[Dict] = None
|
| 529 |
+
):
|
| 530 |
+
"""
|
| 531 |
+
Convert NetworkX graph to PyTorch Geometric Data.
|
| 532 |
+
|
| 533 |
+
Parameters
|
| 534 |
+
----------
|
| 535 |
+
G : nx.Graph
|
| 536 |
+
NetworkX graph
|
| 537 |
+
node_features : dict, optional
|
| 538 |
+
Node feature dictionary
|
| 539 |
+
edge_features : dict, optional
|
| 540 |
+
Edge feature dictionary
|
| 541 |
+
|
| 542 |
+
Returns
|
| 543 |
+
-------
|
| 544 |
+
Data
|
| 545 |
+
PyTorch Geometric Data object
|
| 546 |
+
"""
|
| 547 |
+
if not HAS_TORCH or not HAS_TORCH_GEOMETRIC:
|
| 548 |
+
raise ImportError("PyTorch and torch_geometric required")
|
| 549 |
+
|
| 550 |
+
# Node features
|
| 551 |
+
if node_features:
|
| 552 |
+
x = torch.tensor([
|
| 553 |
+
node_features[node]
|
| 554 |
+
for node in G.nodes()
|
| 555 |
+
], dtype=torch.float)
|
| 556 |
+
else:
|
| 557 |
+
# Default: one-hot encoding
|
| 558 |
+
n_nodes = G.number_of_nodes()
|
| 559 |
+
x = torch.eye(n_nodes)
|
| 560 |
+
|
| 561 |
+
# Edge index (map node labels to integer indices so string-labeled nodes work)
|
| 562 |
+
node_to_idx = {node: i for i, node in enumerate(G.nodes())}
|
| 563 |
+
edge_index = torch.tensor(
|
| 564 |
+
[[node_to_idx[u], node_to_idx[v]] for u, v in G.edges()],
|
| 565 |
+
dtype=torch.long).t().contiguous()
|
| 566 |
+
|
| 567 |
+
# Edge features
|
| 568 |
+
if edge_features:
|
| 569 |
+
edge_attr = torch.tensor([
|
| 570 |
+
edge_features[(u, v)]
|
| 571 |
+
for u, v in G.edges()
|
| 572 |
+
], dtype=torch.float)
|
| 573 |
+
else:
|
| 574 |
+
edge_attr = None
|
| 575 |
+
|
| 576 |
+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
| 577 |
+
|
| 578 |
+
return data
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def example_geopolitical_gnn():
|
| 582 |
+
"""
|
| 583 |
+
Example: GNN for geopolitical alliance network.
|
| 584 |
+
"""
|
| 585 |
+
if not HAS_TORCH or not HAS_TORCH_GEOMETRIC:
|
| 586 |
+
print("PyTorch and torch_geometric required for GNN examples")
|
| 587 |
+
return
|
| 588 |
+
|
| 589 |
+
# Create example graph (alliance network)
|
| 590 |
+
G = nx.DiGraph()
|
| 591 |
+
countries = ['USA', 'China', 'Russia', 'EU', 'India']
|
| 592 |
+
G.add_nodes_from(countries)
|
| 593 |
+
|
| 594 |
+
# Add alliance edges
|
| 595 |
+
alliances = [
|
| 596 |
+
('USA', 'EU'),
|
| 597 |
+
('USA', 'India'),
|
| 598 |
+
('China', 'Russia'),
|
| 599 |
+
]
|
| 600 |
+
G.add_edges_from(alliances)
|
| 601 |
+
|
| 602 |
+
# Node features (e.g., GDP, military strength, etc.)
|
| 603 |
+
node_features = {
|
| 604 |
+
'USA': [1.0, 0.9, 0.8],
|
| 605 |
+
'China': [0.8, 0.7, 0.9],
|
| 606 |
+
'Russia': [0.5, 0.7, 0.6],
|
| 607 |
+
'EU': [0.9, 0.6, 0.7],
|
| 608 |
+
'India': [0.6, 0.6, 0.7]
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
# Convert to PyTorch Geometric format
|
| 612 |
+
data = NetworkToGraph.convert(G, node_features)
|
| 613 |
+
|
| 614 |
+
# Create model
|
| 615 |
+
model = CausalGNN(
|
| 616 |
+
node_features=3,
|
| 617 |
+
hidden_dim=16,
|
| 618 |
+
output_dim=8,
|
| 619 |
+
num_layers=2
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# Forward pass
|
| 623 |
+
embeddings = model(data)
|
| 624 |
+
|
| 625 |
+
print("Node embeddings shape:", embeddings.shape)
|
| 626 |
+
print("Node embeddings:", embeddings)
|
| 627 |
+
|
| 628 |
+
return model, data, embeddings
|
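A training-loop sketch built on the pieces above, assuming PyTorch and torch_geometric are installed so example_geopolitical_gnn returns a model and graph; the regression targets are random stand-ins, not real labels:

    import torch
    import torch.nn.functional as F
    from geobot.ml import GNNTrainer
    from geobot.ml.graph_neural_networks import example_geopolitical_gnn

    model, data, _ = example_geopolitical_gnn()    # small alliance-network example
    targets = torch.randn(data.num_nodes, 8)       # stand-in node-level targets (output_dim=8)

    trainer = GNNTrainer(model, learning_rate=0.01)
    for epoch in range(50):
        loss = trainer.train_step(data, targets, loss_fn=F.mse_loss)
        if epoch % 10 == 0:
            print(f"epoch {epoch}: loss {loss:.4f}")

    final_mse = trainer.evaluate(data, targets, metric_fn=F.mse_loss)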
geobot/ml/risk_scoring.py
ADDED
|
@@ -0,0 +1,159 @@
|
| 1 |
+
"""
|
| 2 |
+
Risk scoring using gradient boosting and ensemble methods.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 8 |
+
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
|
| 9 |
+
from sklearn.model_selection import cross_val_score
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RiskScorer:
|
| 13 |
+
"""
|
| 14 |
+
Nonlinear risk scoring using ensemble methods.
|
| 15 |
+
|
| 16 |
+
Uses Gradient Boosting and Random Forests to discover nonlinear
|
| 17 |
+
patterns in geopolitical risk factors.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, method: str = 'gradient_boosting'):
|
| 21 |
+
"""
|
| 22 |
+
Initialize risk scorer.
|
| 23 |
+
|
| 24 |
+
Parameters
|
| 25 |
+
----------
|
| 26 |
+
method : str
|
| 27 |
+
Method to use ('gradient_boosting' or 'random_forest')
|
| 28 |
+
"""
|
| 29 |
+
self.method = method
|
| 30 |
+
self.model = None
|
| 31 |
+
self.feature_names = None
|
| 32 |
+
self.is_trained = False
|
| 33 |
+
|
| 34 |
+
def train(
|
| 35 |
+
self,
|
| 36 |
+
X: pd.DataFrame,
|
| 37 |
+
y: np.ndarray,
|
| 38 |
+
n_estimators: int = 100,
|
| 39 |
+
max_depth: int = 5
|
| 40 |
+
) -> Dict[str, Any]:
|
| 41 |
+
"""
|
| 42 |
+
Train risk scoring model.
|
| 43 |
+
|
| 44 |
+
Parameters
|
| 45 |
+
----------
|
| 46 |
+
X : pd.DataFrame
|
| 47 |
+
Feature matrix
|
| 48 |
+
y : np.ndarray
|
| 49 |
+
Risk labels (0 = low risk, 1 = high risk)
|
| 50 |
+
n_estimators : int
|
| 51 |
+
Number of estimators
|
| 52 |
+
max_depth : int
|
| 53 |
+
Maximum tree depth
|
| 54 |
+
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
dict
|
| 58 |
+
Training results
|
| 59 |
+
"""
|
| 60 |
+
self.feature_names = X.columns.tolist()
|
| 61 |
+
|
| 62 |
+
if self.method == 'gradient_boosting':
|
| 63 |
+
self.model = GradientBoostingClassifier(
|
| 64 |
+
n_estimators=n_estimators,
|
| 65 |
+
max_depth=max_depth,
|
| 66 |
+
random_state=42
|
| 67 |
+
)
|
| 68 |
+
elif self.method == 'random_forest':
|
| 69 |
+
self.model = RandomForestClassifier(
|
| 70 |
+
n_estimators=n_estimators,
|
| 71 |
+
max_depth=max_depth,
|
| 72 |
+
random_state=42
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Unknown method: {self.method}")
|
| 76 |
+
|
| 77 |
+
# Train
|
| 78 |
+
self.model.fit(X, y)
|
| 79 |
+
self.is_trained = True
|
| 80 |
+
|
| 81 |
+
# Cross-validation score
|
| 82 |
+
cv_scores = cross_val_score(self.model, X, y, cv=5)
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
'cv_mean': cv_scores.mean(),
|
| 86 |
+
'cv_std': cv_scores.std(),
|
| 87 |
+
'feature_importance': self.get_feature_importance()
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
def predict_risk(self, X: pd.DataFrame) -> np.ndarray:
|
| 91 |
+
"""
|
| 92 |
+
Predict risk scores.
|
| 93 |
+
|
| 94 |
+
Parameters
|
| 95 |
+
----------
|
| 96 |
+
X : pd.DataFrame
|
| 97 |
+
Features
|
| 98 |
+
|
| 99 |
+
Returns
|
| 100 |
+
-------
|
| 101 |
+
np.ndarray
|
| 102 |
+
Risk probabilities
|
| 103 |
+
"""
|
| 104 |
+
if not self.is_trained:
|
| 105 |
+
raise ValueError("Model not trained yet")
|
| 106 |
+
|
| 107 |
+
return self.model.predict_proba(X)[:, 1]
|
| 108 |
+
|
| 109 |
+
def get_feature_importance(self) -> Dict[str, float]:
|
| 110 |
+
"""
|
| 111 |
+
Get feature importance scores.
|
| 112 |
+
|
| 113 |
+
Returns
|
| 114 |
+
-------
|
| 115 |
+
dict
|
| 116 |
+
Feature importance
|
| 117 |
+
"""
|
| 118 |
+
if not self.is_trained:
|
| 119 |
+
raise ValueError("Model not trained yet")
|
| 120 |
+
|
| 121 |
+
importance = self.model.feature_importances_
|
| 122 |
+
return dict(zip(self.feature_names, importance))
|
| 123 |
+
|
| 124 |
+
def explain_prediction(self, X: pd.DataFrame, index: int) -> Dict[str, Any]:
|
| 125 |
+
"""
|
| 126 |
+
Explain a specific prediction.
|
| 127 |
+
|
| 128 |
+
Parameters
|
| 129 |
+
----------
|
| 130 |
+
X : pd.DataFrame
|
| 131 |
+
Features
|
| 132 |
+
index : int
|
| 133 |
+
Index of sample to explain
|
| 134 |
+
|
| 135 |
+
Returns
|
| 136 |
+
-------
|
| 137 |
+
dict
|
| 138 |
+
Explanation
|
| 139 |
+
"""
|
| 140 |
+
if not self.is_trained:
|
| 141 |
+
raise ValueError("Model not trained yet")
|
| 142 |
+
|
| 143 |
+
sample = X.iloc[index:index+1]
|
| 144 |
+
risk_score = self.predict_risk(sample)[0]
|
| 145 |
+
|
| 146 |
+
# Feature contributions (simplified)
|
| 147 |
+
feature_values = sample.iloc[0].to_dict()
|
| 148 |
+
feature_importance = self.get_feature_importance()
|
| 149 |
+
|
| 150 |
+
contributions = {
|
| 151 |
+
feat: feature_values[feat] * feature_importance[feat]
|
| 152 |
+
for feat in feature_values.keys()
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
'risk_score': risk_score,
|
| 157 |
+
'top_risk_factors': sorted(contributions.items(), key=lambda x: x[1], reverse=True)[:5],
|
| 158 |
+
'feature_values': feature_values
|
| 159 |
+
}
|
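A minimal usage sketch for RiskScorer on synthetic labels; the feature names and labeling rule are illustrative assumptions only:

    import numpy as np
    import pandas as pd
    from geobot.ml import RiskScorer

    rng = np.random.default_rng(1)
    X = pd.DataFrame(rng.normal(size=(300, 3)),
                     columns=["border_incidents", "sanctions_index", "alliance_strain"])
    y = (X["border_incidents"] + 0.5 * X["alliance_strain"] > 0.5).astype(int).to_numpy()

    scorer = RiskScorer(method="gradient_boosting")
    results = scorer.train(X, y, n_estimators=50, max_depth=3)
    print(results["cv_mean"], results["cv_std"])

    probs = scorer.predict_risk(X.head())          # probability of the high-risk class
    print(scorer.explain_prediction(X, index=0)["top_risk_factors"])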
geobot/models/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
| 1 |
+
"""
|
| 2 |
+
Causal models and structural frameworks for GeoBotv1
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .causal_graph import CausalGraph, StructuralCausalModel
|
| 6 |
+
from .causal_discovery import CausalDiscovery
|
| 7 |
+
from .quasi_experimental import (
|
| 8 |
+
SyntheticControlMethod,
|
| 9 |
+
DifferenceinDifferences,
|
| 10 |
+
RegressionDiscontinuity,
|
| 11 |
+
InstrumentalVariables,
|
| 12 |
+
SyntheticControlResult,
|
| 13 |
+
DIDResult,
|
| 14 |
+
RDDResult,
|
| 15 |
+
IVResult,
|
| 16 |
+
estimate_treatment_effect_bounds
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"CausalGraph",
|
| 21 |
+
"StructuralCausalModel",
|
| 22 |
+
"CausalDiscovery",
|
| 23 |
+
"SyntheticControlMethod",
|
| 24 |
+
"DifferenceinDifferences",
|
| 25 |
+
"RegressionDiscontinuity",
|
| 26 |
+
"InstrumentalVariables",
|
| 27 |
+
"SyntheticControlResult",
|
| 28 |
+
"DIDResult",
|
| 29 |
+
"RDDResult",
|
| 30 |
+
"IVResult",
|
| 31 |
+
"estimate_treatment_effect_bounds",
|
| 32 |
+
]
|
geobot/models/causal_discovery.py
ADDED
|
@@ -0,0 +1,363 @@
|
| 1 |
+
"""
|
| 2 |
+
Causal Discovery Module
|
| 3 |
+
|
| 4 |
+
Discover causal relationships from observational data using various algorithms.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from typing import Optional, Dict, List, Tuple
|
| 10 |
+
from .causal_graph import CausalGraph
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CausalDiscovery:
|
| 14 |
+
"""
|
| 15 |
+
Discover causal structures from data.
|
| 16 |
+
|
| 17 |
+
Implements various causal discovery algorithms to learn
|
| 18 |
+
causal graphs from observational data.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, method: str = 'pc'):
|
| 22 |
+
"""
|
| 23 |
+
Initialize causal discovery.
|
| 24 |
+
|
| 25 |
+
Parameters
|
| 26 |
+
----------
|
| 27 |
+
method : str
|
| 28 |
+
Discovery method ('pc', 'ges', 'lingam')
|
| 29 |
+
"""
|
| 30 |
+
self.method = method
|
| 31 |
+
|
| 32 |
+
def discover_from_data(
|
| 33 |
+
self,
|
| 34 |
+
data: pd.DataFrame,
|
| 35 |
+
alpha: float = 0.05,
|
| 36 |
+
max_cond_vars: int = 3
|
| 37 |
+
) -> CausalGraph:
|
| 38 |
+
"""
|
| 39 |
+
Discover causal graph from data.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
data : pd.DataFrame
|
| 44 |
+
Observational data
|
| 45 |
+
alpha : float
|
| 46 |
+
Significance level for independence tests
|
| 47 |
+
max_cond_vars : int
|
| 48 |
+
Maximum number of conditioning variables
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
CausalGraph
|
| 53 |
+
Discovered causal graph
|
| 54 |
+
"""
|
| 55 |
+
if self.method == 'pc':
|
| 56 |
+
return self._pc_algorithm(data, alpha, max_cond_vars)
|
| 57 |
+
elif self.method == 'ges':
|
| 58 |
+
return self._ges_algorithm(data)
|
| 59 |
+
elif self.method == 'lingam':
|
| 60 |
+
return self._lingam_algorithm(data)
|
| 61 |
+
else:
|
| 62 |
+
raise ValueError(f"Unknown method: {self.method}")
|
| 63 |
+
|
| 64 |
+
def _pc_algorithm(
|
| 65 |
+
self,
|
| 66 |
+
data: pd.DataFrame,
|
| 67 |
+
alpha: float,
|
| 68 |
+
max_cond_vars: int
|
| 69 |
+
) -> CausalGraph:
|
| 70 |
+
"""
|
| 71 |
+
PC (Peter-Clark) algorithm for causal discovery.
|
| 72 |
+
|
| 73 |
+
This is a constraint-based algorithm that uses conditional
|
| 74 |
+
independence tests to discover causal structure.
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
data : pd.DataFrame
|
| 79 |
+
Observational data
|
| 80 |
+
alpha : float
|
| 81 |
+
Significance level
|
| 82 |
+
max_cond_vars : int
|
| 83 |
+
Maximum conditioning set size
|
| 84 |
+
|
| 85 |
+
Returns
|
| 86 |
+
-------
|
| 87 |
+
CausalGraph
|
| 88 |
+
Discovered graph
|
| 89 |
+
"""
|
| 90 |
+
try:
|
| 91 |
+
from pgmpy.estimators import PC
|
| 92 |
+
# PC.estimate() uses pgmpy's built-in conditional-independence tests; no separate import needed
|
| 93 |
+
|
| 94 |
+
# PC algorithm
|
| 95 |
+
pc = PC(data=data)
|
| 96 |
+
model = pc.estimate(
|
| 97 |
+
significance_level=alpha,
|
| 98 |
+
max_cond_vars=max_cond_vars
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Convert to CausalGraph
|
| 102 |
+
graph = CausalGraph(name="pc_discovered")
|
| 103 |
+
|
| 104 |
+
# Add nodes
|
| 105 |
+
for node in model.nodes():
|
| 106 |
+
graph.add_node(node)
|
| 107 |
+
|
| 108 |
+
# Add edges
|
| 109 |
+
for edge in model.edges():
|
| 110 |
+
graph.add_edge(edge[0], edge[1])
|
| 111 |
+
|
| 112 |
+
return graph
|
| 113 |
+
|
| 114 |
+
except ImportError:
|
| 115 |
+
print("pgmpy required for PC algorithm")
|
| 116 |
+
return self._simple_correlation_graph(data)
|
| 117 |
+
|
| 118 |
+
def _ges_algorithm(self, data: pd.DataFrame) -> CausalGraph:
|
| 119 |
+
"""
|
| 120 |
+
GES (Greedy Equivalence Search) algorithm.
|
| 121 |
+
|
| 122 |
+
Score-based causal discovery algorithm.
|
| 123 |
+
|
| 124 |
+
Parameters
|
| 125 |
+
----------
|
| 126 |
+
data : pd.DataFrame
|
| 127 |
+
Observational data
|
| 128 |
+
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
CausalGraph
|
| 132 |
+
Discovered graph
|
| 133 |
+
"""
|
| 134 |
+
# Placeholder - requires causal-learn or similar
|
| 135 |
+
print("GES algorithm not fully implemented yet")
|
| 136 |
+
return self._simple_correlation_graph(data)
|
| 137 |
+
|
| 138 |
+
def _lingam_algorithm(self, data: pd.DataFrame) -> CausalGraph:
|
| 139 |
+
"""
|
| 140 |
+
LiNGAM (Linear Non-Gaussian Acyclic Model) algorithm.
|
| 141 |
+
|
| 142 |
+
Assumes linear relationships and non-Gaussian noise.
|
| 143 |
+
|
| 144 |
+
Parameters
|
| 145 |
+
----------
|
| 146 |
+
data : pd.DataFrame
|
| 147 |
+
Observational data
|
| 148 |
+
|
| 149 |
+
Returns
|
| 150 |
+
-------
|
| 151 |
+
CausalGraph
|
| 152 |
+
Discovered graph
|
| 153 |
+
"""
|
| 154 |
+
# Placeholder - requires lingam package
|
| 155 |
+
print("LiNGAM algorithm not fully implemented yet")
|
| 156 |
+
return self._simple_correlation_graph(data)
|
| 157 |
+
|
| 158 |
+
def _simple_correlation_graph(
|
| 159 |
+
self,
|
| 160 |
+
data: pd.DataFrame,
|
| 161 |
+
threshold: float = 0.3
|
| 162 |
+
) -> CausalGraph:
|
| 163 |
+
"""
|
| 164 |
+
Create a simple graph based on correlations.
|
| 165 |
+
|
| 166 |
+
This is a fallback method and does NOT imply causation.
|
| 167 |
+
|
| 168 |
+
Parameters
|
| 169 |
+
----------
|
| 170 |
+
data : pd.DataFrame
|
| 171 |
+
Data
|
| 172 |
+
threshold : float
|
| 173 |
+
Correlation threshold
|
| 174 |
+
|
| 175 |
+
Returns
|
| 176 |
+
-------
|
| 177 |
+
CausalGraph
|
| 178 |
+
Correlation-based graph
|
| 179 |
+
"""
|
| 180 |
+
graph = CausalGraph(name="correlation_based")
|
| 181 |
+
|
| 182 |
+
# Add nodes
|
| 183 |
+
for col in data.columns:
|
| 184 |
+
graph.add_node(col)
|
| 185 |
+
|
| 186 |
+
# Add edges based on correlation
|
| 187 |
+
corr_matrix = data.corr()
|
| 188 |
+
|
| 189 |
+
for i, col1 in enumerate(data.columns):
|
| 190 |
+
for j, col2 in enumerate(data.columns):
|
| 191 |
+
if i < j: # Avoid duplicates
|
| 192 |
+
corr = abs(corr_matrix.loc[col1, col2])
|
| 193 |
+
if corr > threshold:
|
| 194 |
+
# Arbitrary direction - this is NOT causal
|
| 195 |
+
try:
|
| 196 |
+
graph.add_edge(
|
| 197 |
+
col1, col2,
|
| 198 |
+
strength=corr,
|
| 199 |
+
confidence=0.5,
|
| 200 |
+
mechanism="correlation (not causal)"
|
| 201 |
+
)
|
| 202 |
+
except ValueError:
|
| 203 |
+
# Would create cycle, try other direction
|
| 204 |
+
try:
|
| 205 |
+
graph.add_edge(
|
| 206 |
+
col2, col1,
|
| 207 |
+
strength=corr,
|
| 208 |
+
confidence=0.5,
|
| 209 |
+
mechanism="correlation (not causal)"
|
| 210 |
+
)
|
| 211 |
+
except ValueError:
|
| 212 |
+
# Both directions create cycles, skip
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
return graph
|
| 216 |
+
|
| 217 |
+
def test_conditional_independence(
|
| 218 |
+
self,
|
| 219 |
+
data: pd.DataFrame,
|
| 220 |
+
X: str,
|
| 221 |
+
Y: str,
|
| 222 |
+
Z: Optional[List[str]] = None,
|
| 223 |
+
method: str = 'fisherz'
|
| 224 |
+
) -> Tuple[float, float]:
|
| 225 |
+
"""
|
| 226 |
+
Test conditional independence X ⊥ Y | Z.
|
| 227 |
+
|
| 228 |
+
Parameters
|
| 229 |
+
----------
|
| 230 |
+
data : pd.DataFrame
|
| 231 |
+
Data
|
| 232 |
+
X : str
|
| 233 |
+
First variable
|
| 234 |
+
Y : str
|
| 235 |
+
Second variable
|
| 236 |
+
Z : List[str], optional
|
| 237 |
+
Conditioning variables
|
| 238 |
+
method : str
|
| 239 |
+
Test method ('fisherz', 'chi_square')
|
| 240 |
+
|
| 241 |
+
Returns
|
| 242 |
+
-------
|
| 243 |
+
tuple
|
| 244 |
+
(test_statistic, p_value)
|
| 245 |
+
"""
|
| 246 |
+
if Z is None:
|
| 247 |
+
Z = []
|
| 248 |
+
|
| 249 |
+
if method == 'fisherz':
|
| 250 |
+
return self._fisherz_test(data, X, Y, Z)
|
| 251 |
+
elif method == 'chi_square':
|
| 252 |
+
return self._chi_square_test(data, X, Y, Z)
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError(f"Unknown test method: {method}")
|
| 255 |
+
|
| 256 |
+
def _fisherz_test(
|
| 257 |
+
self,
|
| 258 |
+
data: pd.DataFrame,
|
| 259 |
+
X: str,
|
| 260 |
+
Y: str,
|
| 261 |
+
Z: List[str]
|
| 262 |
+
) -> Tuple[float, float]:
|
| 263 |
+
"""
|
| 264 |
+
Fisher's Z test for conditional independence.
|
| 265 |
+
|
| 266 |
+
Parameters
|
| 267 |
+
----------
|
| 268 |
+
data : pd.DataFrame
|
| 269 |
+
Data
|
| 270 |
+
X : str
|
| 271 |
+
First variable
|
| 272 |
+
Y : str
|
| 273 |
+
Second variable
|
| 274 |
+
Z : List[str]
|
| 275 |
+
Conditioning variables
|
| 276 |
+
|
| 277 |
+
Returns
|
| 278 |
+
-------
|
| 279 |
+
tuple
|
| 280 |
+
(test_statistic, p_value)
|
| 281 |
+
"""
|
| 282 |
+
from scipy.stats import norm
|
| 283 |
+
|
| 284 |
+
n = len(data)
|
| 285 |
+
|
| 286 |
+
if len(Z) == 0:
|
| 287 |
+
# Unconditional correlation
|
| 288 |
+
corr = data[[X, Y]].corr().loc[X, Y]
|
| 289 |
+
else:
|
| 290 |
+
# Partial correlation
|
| 291 |
+
all_vars = [X, Y] + Z
|
| 292 |
+
corr_matrix = data[all_vars].corr()
|
| 293 |
+
|
| 294 |
+
# Compute the partial correlation of X and Y given Z from the precision matrix
|
| 295 |
+
precision = np.linalg.pinv(corr_matrix.values)
|
| 296 |
+
i, j = all_vars.index(X), all_vars.index(Y)
|
| 297 |
+
corr = -precision[i, j] / np.sqrt(precision[i, i] * precision[j, j])
|
| 298 |
+
|
| 299 |
+
# Fisher's Z transformation
|
| 300 |
+
if abs(corr) >= 0.9999:
|
| 301 |
+
corr = 0.9999 * np.sign(corr)
|
| 302 |
+
|
| 303 |
+
z = 0.5 * np.log((1 + corr) / (1 - corr))
|
| 304 |
+
test_stat = np.sqrt(n - len(Z) - 3) * z
|
| 305 |
+
|
| 306 |
+
# Two-tailed p-value
|
| 307 |
+
p_value = 2 * (1 - norm.cdf(abs(test_stat)))
|
| 308 |
+
|
| 309 |
+
return test_stat, p_value
|
| 310 |
+
|
| 311 |
+
def _chi_square_test(
|
| 312 |
+
self,
|
| 313 |
+
data: pd.DataFrame,
|
| 314 |
+
X: str,
|
| 315 |
+
Y: str,
|
| 316 |
+
Z: List[str]
|
| 317 |
+
) -> Tuple[float, float]:
|
| 318 |
+
"""
|
| 319 |
+
Chi-square test for conditional independence.
|
| 320 |
+
|
| 321 |
+
Parameters
|
| 322 |
+
----------
|
| 323 |
+
data : pd.DataFrame
|
| 324 |
+
Data
|
| 325 |
+
X : str
|
| 326 |
+
First variable
|
| 327 |
+
Y : str
|
| 328 |
+
Second variable
|
| 329 |
+
Z : List[str]
|
| 330 |
+
Conditioning variables
|
| 331 |
+
|
| 332 |
+
Returns
|
| 333 |
+
-------
|
| 334 |
+
tuple
|
| 335 |
+
(test_statistic, p_value)
|
| 336 |
+
"""
|
| 337 |
+
from scipy.stats import chi2_contingency
|
| 338 |
+
|
| 339 |
+
if len(Z) == 0:
|
| 340 |
+
# Unconditional test
|
| 341 |
+
contingency_table = pd.crosstab(data[X], data[Y])
|
| 342 |
+
chi2, p_value, dof, expected = chi2_contingency(contingency_table)
|
| 343 |
+
return chi2, p_value
|
| 344 |
+
else:
|
| 345 |
+
# Conditional test - stratify by Z
|
| 346 |
+
# This is simplified
|
| 347 |
+
chi2_sum = 0
|
| 348 |
+
dof_sum = 0
|
| 349 |
+
|
| 350 |
+
for z_value in data[Z[0]].unique():
|
| 351 |
+
subset = data[data[Z[0]] == z_value]
|
| 352 |
+
if len(subset) > 1:
|
| 353 |
+
contingency_table = pd.crosstab(subset[X], subset[Y])
|
| 354 |
+
if contingency_table.shape[0] > 1 and contingency_table.shape[1] > 1:
|
| 355 |
+
chi2, _, dof, _ = chi2_contingency(contingency_table)
|
| 356 |
+
chi2_sum += chi2
|
| 357 |
+
dof_sum += dof
|
| 358 |
+
|
| 359 |
+
# Approximate p-value
|
| 360 |
+
from scipy.stats import chi2
|
| 361 |
+
p_value = 1 - chi2.cdf(chi2_sum, dof_sum) if dof_sum > 0 else 1.0
|
| 362 |
+
|
| 363 |
+
return chi2_sum, p_value
|
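A minimal usage sketch for the conditional-independence helper above on synthetic data; the variable names and noise scales are illustrative, and the conditional test relies on the partial-correlation computation in _fisherz_test:

    import numpy as np
    import pandas as pd
    from geobot.models import CausalDiscovery

    rng = np.random.default_rng(2)
    z = rng.normal(size=500)
    data = pd.DataFrame({
        "mediator": z,
        "pressure": z + rng.normal(scale=0.5, size=500),
        "unrest": z + rng.normal(scale=0.5, size=500),
    })

    cd = CausalDiscovery(method="pc")
    _, p_marginal = cd.test_conditional_independence(data, "pressure", "unrest")
    _, p_given_z = cd.test_conditional_independence(data, "pressure", "unrest", Z=["mediator"])
    print(f"marginal p={p_marginal:.3f}, conditional-on-mediator p={p_given_z:.3f}")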
geobot/models/causal_graph.py
ADDED
|
@@ -0,0 +1,588 @@
|
| 1 |
+
"""
|
| 2 |
+
Causal Graph Module - DAG Representation
|
| 3 |
+
|
| 4 |
+
Provides infrastructure for representing and analyzing causal relationships
|
| 5 |
+
in geopolitical systems using Directed Acyclic Graphs (DAGs).
|
| 6 |
+
|
| 7 |
+
This module answers:
|
| 8 |
+
- What causes conflict?
|
| 9 |
+
- What causes collapse?
|
| 10 |
+
- What causes escalation?
|
| 11 |
+
- What causes mobilization?
|
| 12 |
+
- What causes instability?
|
| 13 |
+
|
| 14 |
+
Critical for: Real forecasting of interventions, not just correlation-based guessing.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import networkx as nx
|
| 19 |
+
from typing import Dict, List, Set, Optional, Tuple, Any, Callable
|
| 20 |
+
from dataclasses import dataclass, field
|
| 21 |
+
import json
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class CausalEdge:
|
| 26 |
+
"""
|
| 27 |
+
Represents a causal edge in the graph.
|
| 28 |
+
|
| 29 |
+
Attributes
|
| 30 |
+
----------
|
| 31 |
+
source : str
|
| 32 |
+
Source node (cause)
|
| 33 |
+
target : str
|
| 34 |
+
Target node (effect)
|
| 35 |
+
strength : float
|
| 36 |
+
Strength of causal relationship (-1 to 1)
|
| 37 |
+
confidence : float
|
| 38 |
+
Confidence in this relationship (0 to 1)
|
| 39 |
+
mechanism : str
|
| 40 |
+
Description of causal mechanism
|
| 41 |
+
"""
|
| 42 |
+
source: str
|
| 43 |
+
target: str
|
| 44 |
+
strength: float = 1.0
|
| 45 |
+
confidence: float = 1.0
|
| 46 |
+
mechanism: str = ""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class CausalGraph:
|
| 50 |
+
"""
|
| 51 |
+
Directed Acyclic Graph (DAG) for causal relationships.
|
| 52 |
+
|
| 53 |
+
This class provides the foundation for causal inference in geopolitical
|
| 54 |
+
forecasting, ensuring that we understand what actually causes events
|
| 55 |
+
rather than just observing correlations.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self, name: str = "geopolitical_dag"):
|
| 59 |
+
"""
|
| 60 |
+
Initialize causal graph.
|
| 61 |
+
|
| 62 |
+
Parameters
|
| 63 |
+
----------
|
| 64 |
+
name : str
|
| 65 |
+
Name of the causal graph
|
| 66 |
+
"""
|
| 67 |
+
self.name = name
|
| 68 |
+
self.graph = nx.DiGraph()
|
| 69 |
+
self.edges: List[CausalEdge] = []
|
| 70 |
+
self.node_metadata: Dict[str, Dict[str, Any]] = {}
|
| 71 |
+
|
| 72 |
+
def add_node(
|
| 73 |
+
self,
|
| 74 |
+
node: str,
|
| 75 |
+
node_type: str = "variable",
|
| 76 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 77 |
+
) -> None:
|
| 78 |
+
"""
|
| 79 |
+
Add a node to the causal graph.
|
| 80 |
+
|
| 81 |
+
Parameters
|
| 82 |
+
----------
|
| 83 |
+
node : str
|
| 84 |
+
Node identifier
|
| 85 |
+
node_type : str
|
| 86 |
+
Type of node ('variable', 'event', 'policy', 'state')
|
| 87 |
+
metadata : dict, optional
|
| 88 |
+
Additional metadata for the node
|
| 89 |
+
"""
|
| 90 |
+
self.graph.add_node(node)
|
| 91 |
+
self.node_metadata[node] = {
|
| 92 |
+
'type': node_type,
|
| 93 |
+
'metadata': metadata or {}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def add_edge(
|
| 97 |
+
self,
|
| 98 |
+
source: str,
|
| 99 |
+
target: str,
|
| 100 |
+
strength: float = 1.0,
|
| 101 |
+
confidence: float = 1.0,
|
| 102 |
+
mechanism: str = ""
|
| 103 |
+
) -> None:
|
| 104 |
+
"""
|
| 105 |
+
Add a causal edge to the graph.
|
| 106 |
+
|
| 107 |
+
Parameters
|
| 108 |
+
----------
|
| 109 |
+
source : str
|
| 110 |
+
Source node (cause)
|
| 111 |
+
target : str
|
| 112 |
+
Target node (effect)
|
| 113 |
+
strength : float
|
| 114 |
+
Strength of causal relationship
|
| 115 |
+
confidence : float
|
| 116 |
+
Confidence in this relationship
|
| 117 |
+
mechanism : str
|
| 118 |
+
Description of causal mechanism
|
| 119 |
+
"""
|
| 120 |
+
# Check for cycles
|
| 121 |
+
if not self._would_create_cycle(source, target):
|
| 122 |
+
self.graph.add_edge(source, target)
|
| 123 |
+
edge = CausalEdge(source, target, strength, confidence, mechanism)
|
| 124 |
+
self.edges.append(edge)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"Adding edge {source} -> {target} would create a cycle")
|
| 127 |
+
|
| 128 |
+
def remove_edge(self, source: str, target: str) -> None:
|
| 129 |
+
"""
|
| 130 |
+
Remove a causal edge.
|
| 131 |
+
|
| 132 |
+
Parameters
|
| 133 |
+
----------
|
| 134 |
+
source : str
|
| 135 |
+
Source node
|
| 136 |
+
target : str
|
| 137 |
+
Target node
|
| 138 |
+
"""
|
| 139 |
+
if self.graph.has_edge(source, target):
|
| 140 |
+
self.graph.remove_edge(source, target)
|
| 141 |
+
self.edges = [e for e in self.edges if not (e.source == source and e.target == target)]
|
| 142 |
+
|
| 143 |
+
def _would_create_cycle(self, source: str, target: str) -> bool:
|
| 144 |
+
"""
|
| 145 |
+
Check if adding an edge would create a cycle.
|
| 146 |
+
|
| 147 |
+
Parameters
|
| 148 |
+
----------
|
| 149 |
+
source : str
|
| 150 |
+
Source node
|
| 151 |
+
target : str
|
| 152 |
+
Target node
|
| 153 |
+
|
| 154 |
+
Returns
|
| 155 |
+
-------
|
| 156 |
+
bool
|
| 157 |
+
True if edge would create cycle
|
| 158 |
+
"""
|
| 159 |
+
# Add nodes if they don't exist
|
| 160 |
+
if source not in self.graph:
|
| 161 |
+
self.graph.add_node(source)
|
| 162 |
+
if target not in self.graph:
|
| 163 |
+
self.graph.add_node(target)
|
| 164 |
+
|
| 165 |
+
# Temporarily add edge and check for cycles
|
| 166 |
+
self.graph.add_edge(source, target)
|
| 167 |
+
has_cycle = not nx.is_directed_acyclic_graph(self.graph)
|
| 168 |
+
self.graph.remove_edge(source, target)
|
| 169 |
+
|
| 170 |
+
return has_cycle
|
| 171 |
+
|
| 172 |
+
def get_parents(self, node: str) -> List[str]:
|
| 173 |
+
"""
|
| 174 |
+
Get direct parents (causes) of a node.
|
| 175 |
+
|
| 176 |
+
Parameters
|
| 177 |
+
----------
|
| 178 |
+
node : str
|
| 179 |
+
Node identifier
|
| 180 |
+
|
| 181 |
+
Returns
|
| 182 |
+
-------
|
| 183 |
+
List[str]
|
| 184 |
+
List of parent nodes
|
| 185 |
+
"""
|
| 186 |
+
return list(self.graph.predecessors(node))
|
| 187 |
+
|
| 188 |
+
def get_children(self, node: str) -> List[str]:
|
| 189 |
+
"""
|
| 190 |
+
Get direct children (effects) of a node.
|
| 191 |
+
|
| 192 |
+
Parameters
|
| 193 |
+
----------
|
| 194 |
+
node : str
|
| 195 |
+
Node identifier
|
| 196 |
+
|
| 197 |
+
Returns
|
| 198 |
+
-------
|
| 199 |
+
List[str]
|
| 200 |
+
List of child nodes
|
| 201 |
+
"""
|
| 202 |
+
return list(self.graph.successors(node))
|
| 203 |
+
|
| 204 |
+
def get_ancestors(self, node: str) -> Set[str]:
|
| 205 |
+
"""
|
| 206 |
+
Get all ancestors (causes) of a node.
|
| 207 |
+
|
| 208 |
+
Parameters
|
| 209 |
+
----------
|
| 210 |
+
node : str
|
| 211 |
+
Node identifier
|
| 212 |
+
|
| 213 |
+
Returns
|
| 214 |
+
-------
|
| 215 |
+
Set[str]
|
| 216 |
+
Set of ancestor nodes
|
| 217 |
+
"""
|
| 218 |
+
return nx.ancestors(self.graph, node)
|
| 219 |
+
|
| 220 |
+
def get_descendants(self, node: str) -> Set[str]:
|
| 221 |
+
"""
|
| 222 |
+
Get all descendants (effects) of a node.
|
| 223 |
+
|
| 224 |
+
Parameters
|
| 225 |
+
----------
|
| 226 |
+
node : str
|
| 227 |
+
Node identifier
|
| 228 |
+
|
| 229 |
+
Returns
|
| 230 |
+
-------
|
| 231 |
+
Set[str]
|
| 232 |
+
Set of descendant nodes
|
| 233 |
+
"""
|
| 234 |
+
return nx.descendants(self.graph, node)
|
| 235 |
+
|
| 236 |
+
def get_topological_order(self) -> List[str]:
|
| 237 |
+
"""
|
| 238 |
+
Get topological ordering of nodes.
|
| 239 |
+
|
| 240 |
+
This is useful for computing values in causal order.
|
| 241 |
+
|
| 242 |
+
Returns
|
| 243 |
+
-------
|
| 244 |
+
List[str]
|
| 245 |
+
Nodes in topological order
|
| 246 |
+
"""
|
| 247 |
+
return list(nx.topological_sort(self.graph))
|
| 248 |
+
|
| 249 |
+
def is_ancestor(self, node1: str, node2: str) -> bool:
|
| 250 |
+
"""
|
| 251 |
+
Check if node1 is an ancestor of node2.
|
| 252 |
+
|
| 253 |
+
Parameters
|
| 254 |
+
----------
|
| 255 |
+
node1 : str
|
| 256 |
+
Potential ancestor
|
| 257 |
+
node2 : str
|
| 258 |
+
Potential descendant
|
| 259 |
+
|
| 260 |
+
Returns
|
| 261 |
+
-------
|
| 262 |
+
bool
|
| 263 |
+
True if node1 is ancestor of node2
|
| 264 |
+
"""
|
| 265 |
+
return node1 in self.get_ancestors(node2)
|
| 266 |
+
|
| 267 |
+
def is_descendant(self, node1: str, node2: str) -> bool:
|
| 268 |
+
"""
|
| 269 |
+
Check if node1 is a descendant of node2.
|
| 270 |
+
|
| 271 |
+
Parameters
|
| 272 |
+
----------
|
| 273 |
+
node1 : str
|
| 274 |
+
Potential descendant
|
| 275 |
+
node2 : str
|
| 276 |
+
Potential ancestor
|
| 277 |
+
|
| 278 |
+
Returns
|
| 279 |
+
-------
|
| 280 |
+
bool
|
| 281 |
+
True if node1 is descendant of node2
|
| 282 |
+
"""
|
| 283 |
+
return node1 in self.get_descendants(node2)
|
| 284 |
+
|
| 285 |
+
def get_markov_blanket(self, node: str) -> Set[str]:
|
| 286 |
+
"""
|
| 287 |
+
Get Markov blanket of a node.
|
| 288 |
+
|
| 289 |
+
The Markov blanket includes: parents, children, and co-parents
|
| 290 |
+
(other parents of children).
|
| 291 |
+
|
| 292 |
+
Parameters
|
| 293 |
+
----------
|
| 294 |
+
node : str
|
| 295 |
+
Node identifier
|
| 296 |
+
|
| 297 |
+
Returns
|
| 298 |
+
-------
|
| 299 |
+
Set[str]
|
| 300 |
+
Markov blanket nodes
|
| 301 |
+
"""
|
| 302 |
+
parents = set(self.get_parents(node))
|
| 303 |
+
children = set(self.get_children(node))
|
| 304 |
+
|
| 305 |
+
# Get co-parents (parents of children)
|
| 306 |
+
co_parents = set()
|
| 307 |
+
for child in children:
|
| 308 |
+
co_parents.update(self.get_parents(child))
|
| 309 |
+
|
| 310 |
+
co_parents.discard(node)
|
| 311 |
+
|
| 312 |
+
return parents | children | co_parents
|
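As a quick illustration of the definition above, here is a minimal sketch using the CausalGraph class from this file; the node names are invented for the example and are not part of the framework.

# Hypothetical illustration: Markov blanket in a small graph.
# "Sanctions" and "Oil_Price" are co-parents of "GDP".
g = CausalGraph(name="markov_blanket_demo")
g.add_edge("Sanctions", "GDP")
g.add_edge("Oil_Price", "GDP")
g.add_edge("GDP", "Unrest")
g.add_edge("Regime_Type", "Sanctions")

# Markov blanket of "Sanctions": its parent ("Regime_Type"),
# its child ("GDP"), and the other parent of that child ("Oil_Price").
print(g.get_markov_blanket("Sanctions"))
# Expected: {"Regime_Type", "GDP", "Oil_Price"}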
| 313 |
+
|
| 314 |
+
def d_separated(self, X: Set[str], Y: Set[str], Z: Set[str]) -> bool:
|
| 315 |
+
"""
|
| 316 |
+
Test if X and Y are d-separated given Z.
|
| 317 |
+
|
| 318 |
+
This is fundamental for determining conditional independence.
|
| 319 |
+
|
| 320 |
+
Parameters
|
| 321 |
+
----------
|
| 322 |
+
X : Set[str]
|
| 323 |
+
First set of nodes
|
| 324 |
+
Y : Set[str]
|
| 325 |
+
Second set of nodes
|
| 326 |
+
Z : Set[str]
|
| 327 |
+
Conditioning set
|
| 328 |
+
|
| 329 |
+
Returns
|
| 330 |
+
-------
|
| 331 |
+
bool
|
| 332 |
+
True if X and Y are d-separated given Z
|
| 333 |
+
"""
|
| 334 |
+
return nx.d_separated(self.graph, X, Y, Z)
|
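A short illustrative check of d-separation on a three-node chain, again with made-up node names:

# Chain: Mobilization -> Border_Incident -> Conflict
g = CausalGraph(name="d_separation_demo")
g.add_edge("Mobilization", "Border_Incident")
g.add_edge("Border_Incident", "Conflict")

# Unconditionally, the endpoints of the chain are dependent (not d-separated).
print(g.d_separated({"Mobilization"}, {"Conflict"}, set()))                     # False
# Conditioning on the mediator blocks the path.
print(g.d_separated({"Mobilization"}, {"Conflict"}, {"Border_Incident"}))      # True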
| 335 |
+
|
| 336 |
+
def visualize(self, output_path: Optional[str] = None) -> None:
|
| 337 |
+
"""
|
| 338 |
+
Visualize the causal graph.
|
| 339 |
+
|
| 340 |
+
Parameters
|
| 341 |
+
----------
|
| 342 |
+
output_path : str, optional
|
| 343 |
+
Path to save visualization
|
| 344 |
+
"""
|
| 345 |
+
try:
|
| 346 |
+
import matplotlib.pyplot as plt
|
| 347 |
+
|
| 348 |
+
pos = nx.spring_layout(self.graph)
|
| 349 |
+
plt.figure(figsize=(12, 8))
|
| 350 |
+
|
| 351 |
+
nx.draw(
|
| 352 |
+
self.graph,
|
| 353 |
+
pos,
|
| 354 |
+
with_labels=True,
|
| 355 |
+
node_color='lightblue',
|
| 356 |
+
node_size=3000,
|
| 357 |
+
font_size=10,
|
| 358 |
+
font_weight='bold',
|
| 359 |
+
arrows=True,
|
| 360 |
+
arrowsize=20,
|
| 361 |
+
edge_color='gray'
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
plt.title(f"Causal Graph: {self.name}")
|
| 365 |
+
|
| 366 |
+
if output_path:
|
| 367 |
+
plt.savefig(output_path)
|
| 368 |
+
else:
|
| 369 |
+
plt.show()
|
| 370 |
+
|
| 371 |
+
except ImportError:
|
| 372 |
+
print("Matplotlib required for visualization")
|
| 373 |
+
|
| 374 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 375 |
+
"""
|
| 376 |
+
Convert graph to dictionary representation.
|
| 377 |
+
|
| 378 |
+
Returns
|
| 379 |
+
-------
|
| 380 |
+
dict
|
| 381 |
+
Dictionary representation
|
| 382 |
+
"""
|
| 383 |
+
return {
|
| 384 |
+
'name': self.name,
|
| 385 |
+
'nodes': [
|
| 386 |
+
{'id': node, **self.node_metadata.get(node, {})}
|
| 387 |
+
for node in self.graph.nodes()
|
| 388 |
+
],
|
| 389 |
+
'edges': [
|
| 390 |
+
{
|
| 391 |
+
'source': edge.source,
|
| 392 |
+
'target': edge.target,
|
| 393 |
+
'strength': edge.strength,
|
| 394 |
+
'confidence': edge.confidence,
|
| 395 |
+
'mechanism': edge.mechanism
|
| 396 |
+
}
|
| 397 |
+
for edge in self.edges
|
| 398 |
+
]
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
def to_json(self, path: str) -> None:
|
| 402 |
+
"""
|
| 403 |
+
Save graph to JSON file.
|
| 404 |
+
|
| 405 |
+
Parameters
|
| 406 |
+
----------
|
| 407 |
+
path : str
|
| 408 |
+
Output file path
|
| 409 |
+
"""
|
| 410 |
+
with open(path, 'w') as f:
|
| 411 |
+
json.dump(self.to_dict(), f, indent=2)
|
| 412 |
+
|
| 413 |
+
@classmethod
|
| 414 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'CausalGraph':
|
| 415 |
+
"""
|
| 416 |
+
Load graph from dictionary.
|
| 417 |
+
|
| 418 |
+
Parameters
|
| 419 |
+
----------
|
| 420 |
+
data : dict
|
| 421 |
+
Dictionary representation
|
| 422 |
+
|
| 423 |
+
Returns
|
| 424 |
+
-------
|
| 425 |
+
CausalGraph
|
| 426 |
+
Loaded graph
|
| 427 |
+
"""
|
| 428 |
+
graph = cls(name=data['name'])
|
| 429 |
+
|
| 430 |
+
# Add nodes
|
| 431 |
+
for node_data in data['nodes']:
|
| 432 |
+
graph.add_node(
|
| 433 |
+
node_data['id'],
|
| 434 |
+
node_type=node_data.get('type', 'variable'),
|
| 435 |
+
metadata=node_data.get('metadata', {})
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Add edges
|
| 439 |
+
for edge_data in data['edges']:
|
| 440 |
+
graph.add_edge(
|
| 441 |
+
edge_data['source'],
|
| 442 |
+
edge_data['target'],
|
| 443 |
+
strength=edge_data.get('strength', 1.0),
|
| 444 |
+
confidence=edge_data.get('confidence', 1.0),
|
| 445 |
+
mechanism=edge_data.get('mechanism', '')
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
return graph
|
| 449 |
+
|
| 450 |
+
@classmethod
|
| 451 |
+
def from_json(cls, path: str) -> 'CausalGraph':
|
| 452 |
+
"""
|
| 453 |
+
Load graph from JSON file.
|
| 454 |
+
|
| 455 |
+
Parameters
|
| 456 |
+
----------
|
| 457 |
+
path : str
|
| 458 |
+
Input file path
|
| 459 |
+
|
| 460 |
+
Returns
|
| 461 |
+
-------
|
| 462 |
+
CausalGraph
|
| 463 |
+
Loaded graph
|
| 464 |
+
"""
|
| 465 |
+
with open(path, 'r') as f:
|
| 466 |
+
data = json.load(f)
|
| 467 |
+
return cls.from_dict(data)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class StructuralCausalModel:
|
| 471 |
+
"""
|
| 472 |
+
Structural Causal Model (SCM) with functional equations.
|
| 473 |
+
|
| 474 |
+
An SCM defines how each variable is generated from its parents
|
| 475 |
+
and exogenous noise.
|
| 476 |
+
"""
|
| 477 |
+
|
| 478 |
+
def __init__(self, causal_graph: CausalGraph):
|
| 479 |
+
"""
|
| 480 |
+
Initialize structural causal model.
|
| 481 |
+
|
| 482 |
+
Parameters
|
| 483 |
+
----------
|
| 484 |
+
causal_graph : CausalGraph
|
| 485 |
+
Underlying causal graph
|
| 486 |
+
"""
|
| 487 |
+
self.graph = causal_graph
|
| 488 |
+
self.functions: Dict[str, Callable] = {}
|
| 489 |
+
self.noise_distributions: Dict[str, Any] = {}
|
| 490 |
+
|
| 491 |
+
def set_function(
|
| 492 |
+
self,
|
| 493 |
+
node: str,
|
| 494 |
+
function: Callable,
|
| 495 |
+
noise_dist: Optional[Any] = None
|
| 496 |
+
) -> None:
|
| 497 |
+
"""
|
| 498 |
+
Set structural equation for a node.
|
| 499 |
+
|
| 500 |
+
Parameters
|
| 501 |
+
----------
|
| 502 |
+
node : str
|
| 503 |
+
Node identifier
|
| 504 |
+
function : callable
|
| 505 |
+
Function that computes node value from parents
|
| 506 |
+
Signature: f(parent_values, noise) -> value
|
| 507 |
+
noise_dist : optional
|
| 508 |
+
Noise distribution for this variable
|
| 509 |
+
"""
|
| 510 |
+
self.functions[node] = function
|
| 511 |
+
if noise_dist is not None:
|
| 512 |
+
self.noise_distributions[node] = noise_dist
|
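A minimal sketch of registering structural equations, assuming scipy.stats distributions for the noise terms (anything exposing an .rvs(n) method works); the graph, variable names, and coefficients are purely illustrative:

import numpy as np
from scipy import stats

# Toy model: X is exogenous, Y = 2.0 * X + noise.
g = CausalGraph(name="scm_demo")
g.add_edge("X", "Y")
scm = StructuralCausalModel(g)

scm.set_function(
    "X",
    function=lambda parents, noise: noise,                       # exogenous root node
    noise_dist=stats.norm(0, 1),
)
scm.set_function(
    "Y",
    function=lambda parents, noise: 2.0 * parents["X"] + noise,  # linear structural equation
    noise_dist=stats.norm(0, 0.5),
)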
| 513 |
+
|
| 514 |
+
def sample(
|
| 515 |
+
self,
|
| 516 |
+
n_samples: int = 1,
|
| 517 |
+
interventions: Optional[Dict[str, float]] = None
|
| 518 |
+
) -> Dict[str, np.ndarray]:
|
| 519 |
+
"""
|
| 520 |
+
Sample from the structural causal model.
|
| 521 |
+
|
| 522 |
+
Parameters
|
| 523 |
+
----------
|
| 524 |
+
n_samples : int
|
| 525 |
+
Number of samples to generate
|
| 526 |
+
interventions : dict, optional
|
| 527 |
+
Dictionary of interventions {node: value}
|
| 528 |
+
|
| 529 |
+
Returns
|
| 530 |
+
-------
|
| 531 |
+
dict
|
| 532 |
+
Dictionary of samples for each variable
|
| 533 |
+
"""
|
| 534 |
+
samples = {node: np.zeros(n_samples) for node in self.graph.graph.nodes()}
|
| 535 |
+
|
| 536 |
+
# Sample in topological order
|
| 537 |
+
for node in self.graph.get_topological_order():
|
| 538 |
+
# Check if this node is intervened upon
|
| 539 |
+
if interventions and node in interventions:
|
| 540 |
+
samples[node] = np.full(n_samples, interventions[node])
|
| 541 |
+
else:
|
| 542 |
+
# Get parent values
|
| 543 |
+
parents = self.graph.get_parents(node)
|
| 544 |
+
parent_values = {p: samples[p] for p in parents}
|
| 545 |
+
|
| 546 |
+
# Sample noise
|
| 547 |
+
if node in self.noise_distributions:
|
| 548 |
+
noise = self.noise_distributions[node].rvs(n_samples)
|
| 549 |
+
else:
|
| 550 |
+
noise = np.zeros(n_samples)
|
| 551 |
+
|
| 552 |
+
# Compute value using structural equation
|
| 553 |
+
if node in self.functions:
|
| 554 |
+
samples[node] = self.functions[node](parent_values, noise)
|
| 555 |
+
else:
|
| 556 |
+
# Default: just use noise
|
| 557 |
+
samples[node] = noise
|
| 558 |
+
|
| 559 |
+
return samples
|
| 560 |
+
|
| 561 |
+
def compute_counterfactual(
|
| 562 |
+
self,
|
| 563 |
+
observed: Dict[str, float],
|
| 564 |
+
interventions: Dict[str, float]
|
| 565 |
+
) -> Dict[str, float]:
|
| 566 |
+
"""
|
| 567 |
+
Compute counterfactual: What would happen if we intervened?
|
| 568 |
+
|
| 569 |
+
Parameters
|
| 570 |
+
----------
|
| 571 |
+
observed : dict
|
| 572 |
+
Observed values
|
| 573 |
+
interventions : dict
|
| 574 |
+
Interventions to apply
|
| 575 |
+
|
| 576 |
+
Returns
|
| 577 |
+
-------
|
| 578 |
+
dict
|
| 579 |
+
Counterfactual values
|
| 580 |
+
"""
|
| 581 |
+
# This is a simplified version
|
| 582 |
+
# Full counterfactual computation requires abduction-action-prediction
|
| 583 |
+
|
| 584 |
+
# For now, we sample with interventions
|
| 585 |
+
samples = self.sample(n_samples=1000, interventions=interventions)
|
| 586 |
+
|
| 587 |
+
# Return means
|
| 588 |
+
return {node: np.mean(values) for node, values in samples.items()}
|
geobot/models/quasi_experimental.py
ADDED
|
@@ -0,0 +1,715 @@
|
| 1 |
+
"""
|
| 2 |
+
Quasi-Experimental Causal Inference Methods
|
| 3 |
+
|
| 4 |
+
When randomized experiments are impossible, quasi-experimental designs
|
| 5 |
+
provide credible causal identification under weaker assumptions.
|
| 6 |
+
|
| 7 |
+
Core methods:
|
| 8 |
+
1. Synthetic Control Method (SCM): Construct counterfactual from weighted controls
|
| 9 |
+
2. Difference-in-Differences (DiD): Compare treatment vs control before/after
|
| 10 |
+
3. Regression Discontinuity Design (RDD): Exploit threshold-based treatment assignment
|
| 11 |
+
4. Instrumental Variables (IV): Use exogenous variation to identify causal effects
|
| 12 |
+
5. Causal Forests: Machine learning for heterogeneous treatment effects
|
| 13 |
+
|
| 14 |
+
Applications in geopolitics:
|
| 15 |
+
- SCM: Effect of sanctions on target country (compare to synthetic control)
|
| 16 |
+
- DiD: Impact of regime change (compare neighboring countries before/after)
|
| 17 |
+
- RDD: Effect of election outcomes (winners vs losers near threshold)
|
| 18 |
+
- IV: Effect of trade on conflict (use geographic instruments)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
from scipy import optimize, stats
|
| 23 |
+
from typing import Dict, List, Tuple, Optional, Union
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
import warnings
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class SyntheticControlResult:
|
| 30 |
+
"""Results from Synthetic Control Method."""
|
| 31 |
+
weights: np.ndarray # Weights on control units
|
| 32 |
+
treated_outcome: np.ndarray # Actual treated unit outcomes
|
| 33 |
+
synthetic_outcome: np.ndarray # Synthetic control outcomes
|
| 34 |
+
treatment_effect: np.ndarray # Difference (post-treatment)
|
| 35 |
+
pre_treatment_fit: float # RMSPE in pre-treatment period
|
| 36 |
+
control_units: List[str] # Names of control units
|
| 37 |
+
treatment_time: int # Index where treatment starts
|
| 38 |
+
p_value: Optional[float] = None # From permutation test
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class DIDResult:
|
| 43 |
+
"""Results from Difference-in-Differences."""
|
| 44 |
+
att: float # Average Treatment effect on Treated
|
| 45 |
+
se: float # Standard error
|
| 46 |
+
t_stat: float
|
| 47 |
+
p_value: float
|
| 48 |
+
pre_treatment_diff: float # Check parallel trends
|
| 49 |
+
post_treatment_diff: float
|
| 50 |
+
n_treated: int
|
| 51 |
+
n_control: int
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class RDDResult:
|
| 56 |
+
"""Results from Regression Discontinuity Design."""
|
| 57 |
+
treatment_effect: float # Local Average Treatment Effect (LATE)
|
| 58 |
+
se: float
|
| 59 |
+
t_stat: float
|
| 60 |
+
p_value: float
|
| 61 |
+
bandwidth: float
|
| 62 |
+
n_left: int # Observations below cutoff
|
| 63 |
+
n_right: int # Observations above cutoff
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class IVResult:
|
| 68 |
+
"""Results from Instrumental Variables estimation."""
|
| 69 |
+
beta_iv: np.ndarray # IV estimates
|
| 70 |
+
beta_ols: np.ndarray # OLS estimates (for comparison)
|
| 71 |
+
se_iv: np.ndarray # Standard errors
|
| 72 |
+
first_stage_f: float # First stage F-statistic
|
| 73 |
+
weak_instrument: bool # True if F < 10
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class SyntheticControlMethod:
|
| 77 |
+
"""
|
| 78 |
+
Synthetic Control Method (Abadie, Diamond, Hainmueller 2010, 2015)
|
| 79 |
+
|
| 80 |
+
Creates a synthetic version of the treated unit as a weighted average
|
| 81 |
+
of control units to estimate counterfactual outcomes.
|
| 82 |
+
|
| 83 |
+
Key idea: If we can match pre-treatment outcomes and covariates perfectly,
|
| 84 |
+
the synthetic control provides a valid counterfactual.
|
| 85 |
+
|
| 86 |
+
Example:
|
| 87 |
+
>>> # Effect of sanctions on Iran's GDP
|
| 88 |
+
>>> scm = SyntheticControlMethod()
|
| 89 |
+
>>> result = scm.fit(
|
| 90 |
+
... treated_outcome=iran_gdp, # (T,)
|
| 91 |
+
... control_outcomes=other_countries_gdp, # (T, J)
|
| 92 |
+
... treatment_time=20, # Sanctions imposed at t=20
|
| 93 |
+
... treated_covariates=iran_covariates, # (K,)
|
| 94 |
+
... control_covariates=other_covariates # (J, K)
|
| 95 |
+
... )
|
| 96 |
+
>>> print(f"Average treatment effect: {np.mean(result.treatment_effect):.2f}")
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(self, loss: str = 'l2'):
|
| 100 |
+
"""
|
| 101 |
+
Initialize SCM.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
loss: Loss function for matching ('l2' or 'l1')
|
| 105 |
+
"""
|
| 106 |
+
self.loss = loss
|
| 107 |
+
|
| 108 |
+
def fit(self, treated_outcome: np.ndarray, control_outcomes: np.ndarray,
|
| 109 |
+
treatment_time: int,
|
| 110 |
+
treated_covariates: Optional[np.ndarray] = None,
|
| 111 |
+
control_covariates: Optional[np.ndarray] = None,
|
| 112 |
+
control_names: Optional[List[str]] = None,
|
| 113 |
+
custom_weights: Optional[np.ndarray] = None) -> SyntheticControlResult:
|
| 114 |
+
"""
|
| 115 |
+
Fit synthetic control model.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
treated_outcome: Outcome for treated unit, shape (T,)
|
| 119 |
+
control_outcomes: Outcomes for control units, shape (T, J)
|
| 120 |
+
treatment_time: Time index when treatment begins
|
| 121 |
+
treated_covariates: Covariates for treated unit, shape (K,)
|
| 122 |
+
control_covariates: Covariates for controls, shape (J, K)
|
| 123 |
+
control_names: Names of control units
|
| 124 |
+
custom_weights: Optional custom weights for different predictors
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
SyntheticControlResult with estimated effects
|
| 128 |
+
"""
|
| 129 |
+
T, J = control_outcomes.shape
|
| 130 |
+
|
| 131 |
+
if control_names is None:
|
| 132 |
+
control_names = [f"control_{j}" for j in range(J)]
|
| 133 |
+
|
| 134 |
+
# Pre-treatment period
|
| 135 |
+
Y1_pre = treated_outcome[:treatment_time]
|
| 136 |
+
Y0_pre = control_outcomes[:treatment_time, :]
|
| 137 |
+
|
| 138 |
+
# Construct predictors matrix
|
| 139 |
+
if treated_covariates is not None and control_covariates is not None:
|
| 140 |
+
# Include both outcomes and covariates
|
| 141 |
+
X1 = np.concatenate([Y1_pre, treated_covariates])
|
| 142 |
+
X0 = np.vstack([Y0_pre.T, control_covariates.T]) # Shape: (J, T_pre + K)
|
| 143 |
+
else:
|
| 144 |
+
# Use only pre-treatment outcomes
|
| 145 |
+
X1 = Y1_pre
|
| 146 |
+
X0 = Y0_pre.T # Shape: (J, T_pre)
|
| 147 |
+
|
| 148 |
+
# Find weights that minimize ||X1 - X0 w||
|
| 149 |
+
weights = self._optimize_weights(X1, X0, custom_weights)
|
| 150 |
+
|
| 151 |
+
# Construct synthetic control
|
| 152 |
+
synthetic_outcome = control_outcomes @ weights
|
| 153 |
+
|
| 154 |
+
# Compute treatment effects (post-treatment)
|
| 155 |
+
treatment_effect = np.zeros(T)
|
| 156 |
+
treatment_effect[treatment_time:] = (
|
| 157 |
+
treated_outcome[treatment_time:] - synthetic_outcome[treatment_time:]
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Pre-treatment fit quality
|
| 161 |
+
pre_treatment_fit = np.sqrt(np.mean((Y1_pre - synthetic_outcome[:treatment_time]) ** 2))
|
| 162 |
+
|
| 163 |
+
return SyntheticControlResult(
|
| 164 |
+
weights=weights,
|
| 165 |
+
treated_outcome=treated_outcome,
|
| 166 |
+
synthetic_outcome=synthetic_outcome,
|
| 167 |
+
treatment_effect=treatment_effect,
|
| 168 |
+
pre_treatment_fit=pre_treatment_fit,
|
| 169 |
+
control_units=control_names,
|
| 170 |
+
treatment_time=treatment_time
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def _optimize_weights(self, X1: np.ndarray, X0: np.ndarray,
|
| 174 |
+
V: Optional[np.ndarray] = None) -> np.ndarray:
|
| 175 |
+
"""
|
| 176 |
+
Optimize weights to minimize prediction error.
|
| 177 |
+
|
| 178 |
+
min_w ||X1 - X0 w||_V^2
|
| 179 |
+
s.t. w >= 0, sum(w) = 1
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
X1: Target predictors, shape (K,)
|
| 183 |
+
X0: Control predictors, shape (J, K)
|
| 184 |
+
V: Optional weighting matrix
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Optimal weights, shape (J,)
|
| 188 |
+
"""
|
| 189 |
+
J = X0.shape[0]
|
| 190 |
+
|
| 191 |
+
if V is None:
|
| 192 |
+
V = np.eye(len(X1))
|
| 193 |
+
|
| 194 |
+
# Objective function
|
| 195 |
+
def objective(w):
|
| 196 |
+
diff = X1 - X0.T @ w
|
| 197 |
+
return diff.T @ V @ diff
|
| 198 |
+
|
| 199 |
+
# Constraints: w >= 0, sum(w) = 1
|
| 200 |
+
constraints = {'type': 'eq', 'fun': lambda w: np.sum(w) - 1}
|
| 201 |
+
bounds = [(0, 1) for _ in range(J)]
|
| 202 |
+
|
| 203 |
+
# Initial guess: equal weights
|
| 204 |
+
w0 = np.ones(J) / J
|
| 205 |
+
|
| 206 |
+
# Optimize
|
| 207 |
+
result = optimize.minimize(
|
| 208 |
+
objective,
|
| 209 |
+
x0=w0,
|
| 210 |
+
method='SLSQP',
|
| 211 |
+
bounds=bounds,
|
| 212 |
+
constraints=constraints
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if not result.success:
|
| 216 |
+
warnings.warn("Optimization did not fully converge")
|
| 217 |
+
|
| 218 |
+
return result.x
|
| 219 |
+
|
| 220 |
+
def placebo_test(self, treated_outcome: np.ndarray, control_outcomes: np.ndarray,
|
| 221 |
+
treatment_time: int, n_permutations: int = 100) -> float:
|
| 222 |
+
"""
|
| 223 |
+
Conduct placebo test by applying SCM to control units.
|
| 224 |
+
|
| 225 |
+
Tests whether the observed treatment effect is unusually large
|
| 226 |
+
compared to effects from placebo treatments on controls.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
treated_outcome: Treated unit outcome
|
| 230 |
+
control_outcomes: Control units outcomes
|
| 231 |
+
treatment_time: Treatment time
|
| 232 |
+
n_permutations: Number of placebo tests
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
p-value: Proportion of placebos with larger effect
|
| 236 |
+
"""
|
| 237 |
+
# Fit actual SCM
|
| 238 |
+
actual_result = self.fit(treated_outcome, control_outcomes, treatment_time)
|
| 239 |
+
actual_effect = np.abs(np.mean(actual_result.treatment_effect[treatment_time:]))
|
| 240 |
+
|
| 241 |
+
# Run placebo tests
|
| 242 |
+
placebo_effects = []
|
| 243 |
+
J = control_outcomes.shape[1]
|
| 244 |
+
|
| 245 |
+
for j in range(min(J, n_permutations)):
|
| 246 |
+
# Treat control j as if it were treated
|
| 247 |
+
placebo_treated = control_outcomes[:, j]
|
| 248 |
+
placebo_controls = np.delete(control_outcomes, j, axis=1)
|
| 249 |
+
|
| 250 |
+
try:
|
| 251 |
+
placebo_result = self.fit(placebo_treated, placebo_controls, treatment_time)
|
| 252 |
+
placebo_effect = np.abs(np.mean(placebo_result.treatment_effect[treatment_time:]))
|
| 253 |
+
placebo_effects.append(placebo_effect)
|
| 254 |
+
except Exception:
|
| 255 |
+
continue
|
| 256 |
+
|
| 257 |
+
# p-value: proportion of placebos with larger effect
|
| 258 |
+
placebo_effects = np.array(placebo_effects)
|
| 259 |
+
p_value = np.mean(placebo_effects >= actual_effect)
|
| 260 |
+
|
| 261 |
+
return p_value
|
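An end-to-end sketch on simulated data; the series, the "sanctions at t=30" story, and the effect size are all invented for illustration:

import numpy as np

rng = np.random.default_rng(0)
T, J, t0 = 50, 8, 30

# Control units share a common trend plus idiosyncratic noise.
trend = np.linspace(100.0, 120.0, T)
controls = trend[:, None] + rng.normal(0.0, 1.0, size=(T, J))

# The treated unit tracks the controls, then drops by 5 after "sanctions" at t0.
treated = trend + rng.normal(0.0, 1.0, size=T)
treated[t0:] -= 5.0

scm = SyntheticControlMethod()
res = scm.fit(treated_outcome=treated, control_outcomes=controls, treatment_time=t0)
print(f"Pre-treatment RMSPE: {res.pre_treatment_fit:.2f}")
print(f"Mean post-treatment effect: {np.mean(res.treatment_effect[t0:]):.2f}")  # roughly -5
print(f"Placebo p-value: {scm.placebo_test(treated, controls, treatment_time=t0):.2f}")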
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class DifferenceinDifferences:
|
| 265 |
+
"""
|
| 266 |
+
Difference-in-Differences (DiD) Estimation
|
| 267 |
+
|
| 268 |
+
Compares changes over time between treatment and control groups.
|
| 269 |
+
|
| 270 |
+
Model:
|
| 271 |
+
Y_it = β_0 + β_1 * Treated_i + β_2 * Post_t + β_3 * (Treated_i × Post_t) + ε_it
|
| 272 |
+
|
| 273 |
+
where β_3 is the DiD estimate (Average Treatment effect on Treated).
|
| 274 |
+
|
| 275 |
+
Key assumption: Parallel trends (treatment and control would have
|
| 276 |
+
followed the same trend absent treatment).
|
| 277 |
+
|
| 278 |
+
Example:
|
| 279 |
+
>>> # Effect of regime change in country A
|
| 280 |
+
>>> did = DifferenceinDifferences()
|
| 281 |
+
>>> result = did.estimate(
|
| 282 |
+
... treated_pre=country_a_gdp_before,
|
| 283 |
+
... treated_post=country_a_gdp_after,
|
| 284 |
+
... control_pre=neighbors_gdp_before,
|
| 285 |
+
... control_post=neighbors_gdp_after
|
| 286 |
+
... )
|
| 287 |
+
>>> print(f"ATT: {result.att:.3f} (p={result.p_value:.3f})")
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
def estimate(self, treated_pre: np.ndarray, treated_post: np.ndarray,
|
| 291 |
+
control_pre: np.ndarray, control_post: np.ndarray,
|
| 292 |
+
cluster_robust: bool = False) -> DIDResult:
|
| 293 |
+
"""
|
| 294 |
+
Estimate DiD effect.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
treated_pre: Treated group pre-treatment, shape (n_treated,)
|
| 298 |
+
treated_post: Treated group post-treatment, shape (n_treated,)
|
| 299 |
+
control_pre: Control group pre-treatment, shape (n_control,)
|
| 300 |
+
control_post: Control group post-treatment, shape (n_control,)
|
| 301 |
+
cluster_robust: Use cluster-robust standard errors
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
DIDResult with ATT estimate
|
| 305 |
+
"""
|
| 306 |
+
# Convert to arrays
|
| 307 |
+
treated_pre = np.asarray(treated_pre)
|
| 308 |
+
treated_post = np.asarray(treated_post)
|
| 309 |
+
control_pre = np.asarray(control_pre)
|
| 310 |
+
control_post = np.asarray(control_post)
|
| 311 |
+
|
| 312 |
+
# Sample sizes
|
| 313 |
+
n_treated = len(treated_pre)
|
| 314 |
+
n_control = len(control_pre)
|
| 315 |
+
|
| 316 |
+
# Mean outcomes
|
| 317 |
+
y_treated_pre = np.mean(treated_pre)
|
| 318 |
+
y_treated_post = np.mean(treated_post)
|
| 319 |
+
y_control_pre = np.mean(control_pre)
|
| 320 |
+
y_control_post = np.mean(control_post)
|
| 321 |
+
|
| 322 |
+
# DiD estimate
|
| 323 |
+
diff_treated = y_treated_post - y_treated_pre
|
| 324 |
+
diff_control = y_control_post - y_control_pre
|
| 325 |
+
att = diff_treated - diff_control
|
| 326 |
+
|
| 327 |
+
# Standard error (assuming homoskedasticity)
|
| 328 |
+
var_treated_pre = np.var(treated_pre, ddof=1)
|
| 329 |
+
var_treated_post = np.var(treated_post, ddof=1)
|
| 330 |
+
var_control_pre = np.var(control_pre, ddof=1)
|
| 331 |
+
var_control_post = np.var(control_post, ddof=1)
|
| 332 |
+
|
| 333 |
+
se = np.sqrt(
|
| 334 |
+
var_treated_post / n_treated +
|
| 335 |
+
var_treated_pre / n_treated +
|
| 336 |
+
var_control_post / n_control +
|
| 337 |
+
var_control_pre / n_control
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Test statistic
|
| 341 |
+
t_stat = att / se
|
| 342 |
+
p_value = 2 * (1 - stats.t.cdf(np.abs(t_stat), df=n_treated + n_control - 2))
|
| 343 |
+
|
| 344 |
+
return DIDResult(
|
| 345 |
+
att=att,
|
| 346 |
+
se=se,
|
| 347 |
+
t_stat=t_stat,
|
| 348 |
+
p_value=p_value,
|
| 349 |
+
pre_treatment_diff=y_treated_pre - y_control_pre,
|
| 350 |
+
post_treatment_diff=y_treated_post - y_control_post,
|
| 351 |
+
n_treated=n_treated,
|
| 352 |
+
n_control=n_control
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
def panel_did(self, panel_data: np.ndarray, treatment_indicator: np.ndarray,
|
| 356 |
+
time_indicator: np.ndarray, unit_ids: np.ndarray) -> DIDResult:
|
| 357 |
+
"""
|
| 358 |
+
Estimate DiD with panel data and fixed effects.
|
| 359 |
+
|
| 360 |
+
Model:
|
| 361 |
+
Y_it = α_i + γ_t + δ * (Treatment_i × Post_t) + ε_it
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
panel_data: Outcome variable, shape (N*T,)
|
| 365 |
+
treatment_indicator: 1 if unit is treated, 0 otherwise, shape (N*T,)
|
| 366 |
+
time_indicator: 1 if post-treatment, 0 if pre, shape (N*T,)
|
| 367 |
+
unit_ids: Unit identifiers, shape (N*T,)
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
DIDResult
|
| 371 |
+
"""
|
| 372 |
+
# Create interaction term
|
| 373 |
+
did_term = treatment_indicator * time_indicator
|
| 374 |
+
|
| 375 |
+
# Demean for fixed effects (within transformation)
|
| 376 |
+
n_obs = len(panel_data)
|
| 377 |
+
unique_units = np.unique(unit_ids)
|
| 378 |
+
unique_times = np.unique(time_indicator)
|
| 379 |
+
|
| 380 |
+
# Demean by unit (removes α_i)
|
| 381 |
+
y_demeaned = np.zeros(n_obs)
|
| 382 |
+
did_demeaned = np.zeros(n_obs)
|
| 383 |
+
|
| 384 |
+
for unit in unique_units:
|
| 385 |
+
mask = unit_ids == unit
|
| 386 |
+
y_demeaned[mask] = panel_data[mask] - np.mean(panel_data[mask])
|
| 387 |
+
did_demeaned[mask] = did_term[mask] - np.mean(did_term[mask])
|
| 388 |
+
|
| 389 |
+
# Regression: y_demeaned ~ did_demeaned (unit fixed effects removed by the within-transformation; time fixed effects are not separately absorbed here)
|
| 390 |
+
# Simple OLS
|
| 391 |
+
att = np.sum(did_demeaned * y_demeaned) / np.sum(did_demeaned ** 2)
|
| 392 |
+
|
| 393 |
+
# Standard error
|
| 394 |
+
residuals = y_demeaned - att * did_demeaned
|
| 395 |
+
rss = np.sum(residuals ** 2)
|
| 396 |
+
se = np.sqrt(rss / (n_obs - 2) / np.sum(did_demeaned ** 2))
|
| 397 |
+
|
| 398 |
+
t_stat = att / se
|
| 399 |
+
p_value = 2 * (1 - stats.t.cdf(np.abs(t_stat), df=n_obs - 2))
|
| 400 |
+
|
| 401 |
+
n_treated = np.sum(treatment_indicator > 0)
|
| 402 |
+
n_control = n_obs - n_treated
|
| 403 |
+
|
| 404 |
+
return DIDResult(
|
| 405 |
+
att=att,
|
| 406 |
+
se=se,
|
| 407 |
+
t_stat=t_stat,
|
| 408 |
+
p_value=p_value,
|
| 409 |
+
pre_treatment_diff=0.0, # Not directly computed
|
| 410 |
+
post_treatment_diff=0.0,
|
| 411 |
+
n_treated=n_treated,
|
| 412 |
+
n_control=n_control
|
| 413 |
+
)
|
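An illustrative use of the panel interface on toy data; note the toy outcome deliberately has no common time shock, since this simplified within-unit estimator does not strip one out:

import numpy as np

rng = np.random.default_rng(1)
units = np.repeat(np.arange(6), 2)                      # 6 units observed twice
post = np.tile(np.array([0.0, 1.0]), 6)                 # pre/post indicator
treated = np.repeat(np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]), 2)

# Outcome: unit heterogeneity + a true ATT of 2.0 in treated post-periods.
y = 0.5 * units + 2.0 * treated * post + rng.normal(0.0, 0.1, 12)

did = DifferenceinDifferences()
res = did.panel_did(panel_data=y, treatment_indicator=treated,
                    time_indicator=post, unit_ids=units)
print(f"Panel DiD ATT: {res.att:.2f}")                  # close to 2.0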
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class RegressionDiscontinuity:
|
| 417 |
+
"""
|
| 418 |
+
Regression Discontinuity Design (RDD)
|
| 419 |
+
|
| 420 |
+
Estimates treatment effects when treatment assignment is determined
|
| 421 |
+
by whether a running variable crosses a threshold.
|
| 422 |
+
|
| 423 |
+
Sharp RDD: Treatment deterministically assigned at cutoff
|
| 424 |
+
Fuzzy RDD: Probability of treatment jumps at cutoff
|
| 425 |
+
|
| 426 |
+
Example: Effect of election victory on policy outcomes
|
| 427 |
+
- Running variable: Vote margin
|
| 428 |
+
- Cutoff: 50%
|
| 429 |
+
- Treatment: Winning election
|
| 430 |
+
|
| 431 |
+
Example:
|
| 432 |
+
>>> # Effect of election victory on military spending
|
| 433 |
+
>>> rdd = RegressionDiscontinuity(cutoff=0.5) # 50% vote share
|
| 434 |
+
>>> result = rdd.estimate_sharp(
|
| 435 |
+
... running_var=vote_share, # Vote percentage
|
| 436 |
+
... outcome=military_spending,
|
| 437 |
+
... bandwidth=0.1 # 10% bandwidth
|
| 438 |
+
... )
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
def __init__(self, cutoff: float = 0.0):
|
| 442 |
+
"""
|
| 443 |
+
Initialize RDD.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
cutoff: Threshold value for treatment assignment
|
| 447 |
+
"""
|
| 448 |
+
self.cutoff = cutoff
|
| 449 |
+
|
| 450 |
+
def estimate_sharp(self, running_var: np.ndarray, outcome: np.ndarray,
|
| 451 |
+
bandwidth: Optional[float] = None,
|
| 452 |
+
kernel: str = 'triangular',
|
| 453 |
+
polynomial_order: int = 1) -> RDDResult:
|
| 454 |
+
"""
|
| 455 |
+
Estimate sharp RDD effect.
|
| 456 |
+
|
| 457 |
+
Args:
|
| 458 |
+
running_var: Running variable (e.g., vote share)
|
| 459 |
+
outcome: Outcome variable
|
| 460 |
+
bandwidth: Bandwidth around cutoff (if None, use data-driven selection)
|
| 461 |
+
kernel: Weighting kernel ('triangular', 'uniform', 'epanechnikov')
|
| 462 |
+
polynomial_order: Order of local polynomial
|
| 463 |
+
|
| 464 |
+
Returns:
|
| 465 |
+
RDDResult with treatment effect estimate
|
| 466 |
+
"""
|
| 467 |
+
running_var = np.asarray(running_var)
|
| 468 |
+
outcome = np.asarray(outcome)
|
| 469 |
+
|
| 470 |
+
# Center running variable at cutoff
|
| 471 |
+
X = running_var - self.cutoff
|
| 472 |
+
|
| 473 |
+
# Select bandwidth if not provided
|
| 474 |
+
if bandwidth is None:
|
| 475 |
+
bandwidth = self._select_bandwidth(X, outcome)
|
| 476 |
+
|
| 477 |
+
# Restrict to bandwidth
|
| 478 |
+
in_bandwidth = np.abs(X) <= bandwidth
|
| 479 |
+
X_bw = X[in_bandwidth]
|
| 480 |
+
Y_bw = outcome[in_bandwidth]
|
| 481 |
+
|
| 482 |
+
# Treatment indicator (above cutoff)
|
| 483 |
+
D = (X_bw >= 0).astype(float)
|
| 484 |
+
|
| 485 |
+
# Create weights
|
| 486 |
+
weights = self._kernel_weights(X_bw, bandwidth, kernel)
|
| 487 |
+
|
| 488 |
+
# Fit local polynomial separately on each side
|
| 489 |
+
# Model: Y = α + β*D + γ*X + δ*(D*X) + higher order terms
|
| 490 |
+
|
| 491 |
+
# Design matrix
|
| 492 |
+
Z = np.column_stack([
|
| 493 |
+
np.ones(len(X_bw)), # Intercept
|
| 494 |
+
D, # Treatment
|
| 495 |
+
X_bw, # Running variable
|
| 496 |
+
D * X_bw # Interaction
|
| 497 |
+
])
|
| 498 |
+
|
| 499 |
+
# Weighted least squares
|
| 500 |
+
W = np.diag(weights)
|
| 501 |
+
try:
|
| 502 |
+
beta = np.linalg.solve(Z.T @ W @ Z, Z.T @ W @ Y_bw)
|
| 503 |
+
except np.linalg.LinAlgError:
|
| 504 |
+
beta = np.linalg.lstsq(Z.T @ W @ Z, Z.T @ W @ Y_bw, rcond=None)[0]
|
| 505 |
+
|
| 506 |
+
# Treatment effect is coefficient on D
|
| 507 |
+
treatment_effect = beta[1]
|
| 508 |
+
|
| 509 |
+
# Standard error (heteroskedasticity-robust)
|
| 510 |
+
residuals = Y_bw - Z @ beta
|
| 511 |
+
meat = Z.T @ W @ np.diag(residuals ** 2) @ W @ Z
|
| 512 |
+
bread_inv = np.linalg.inv(Z.T @ W @ Z)
|
| 513 |
+
vcov = bread_inv @ meat @ bread_inv
|
| 514 |
+
se = np.sqrt(vcov[1, 1])
|
| 515 |
+
|
| 516 |
+
# Test statistic
|
| 517 |
+
t_stat = treatment_effect / se
|
| 518 |
+
n_left = np.sum(X_bw < 0)
|
| 519 |
+
n_right = np.sum(X_bw >= 0)
|
| 520 |
+
df = len(X_bw) - Z.shape[1]
|
| 521 |
+
p_value = 2 * (1 - stats.t.cdf(np.abs(t_stat), df=df))
|
| 522 |
+
|
| 523 |
+
return RDDResult(
|
| 524 |
+
treatment_effect=treatment_effect,
|
| 525 |
+
se=se,
|
| 526 |
+
t_stat=t_stat,
|
| 527 |
+
p_value=p_value,
|
| 528 |
+
bandwidth=bandwidth,
|
| 529 |
+
n_left=n_left,
|
| 530 |
+
n_right=n_right
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
def _select_bandwidth(self, X: np.ndarray, Y: np.ndarray) -> float:
|
| 534 |
+
"""
|
| 535 |
+
Select bandwidth using Imbens-Kalyanaraman method (simplified).
|
| 536 |
+
|
| 537 |
+
Args:
|
| 538 |
+
X: Centered running variable
|
| 539 |
+
Y: Outcome
|
| 540 |
+
|
| 541 |
+
Returns:
|
| 542 |
+
Optimal bandwidth
|
| 543 |
+
"""
|
| 544 |
+
# Simplified: use rule of thumb
|
| 545 |
+
# h = C * σ * n^{-1/5}
|
| 546 |
+
sigma = np.std(Y)
|
| 547 |
+
n = len(Y)
|
| 548 |
+
bandwidth = 1.06 * sigma * (n ** (-1 / 5))
|
| 549 |
+
|
| 550 |
+
# Ensure reasonable range
|
| 551 |
+
bandwidth = np.clip(bandwidth, 0.1 * np.std(X), 2.0 * np.std(X))
|
| 552 |
+
|
| 553 |
+
return bandwidth
|
| 554 |
+
|
| 555 |
+
def _kernel_weights(self, X: np.ndarray, bandwidth: float, kernel: str) -> np.ndarray:
|
| 556 |
+
"""Compute kernel weights."""
|
| 557 |
+
u = X / bandwidth
|
| 558 |
+
|
| 559 |
+
if kernel == 'triangular':
|
| 560 |
+
weights = np.maximum(1 - np.abs(u), 0)
|
| 561 |
+
elif kernel == 'uniform':
|
| 562 |
+
weights = (np.abs(u) <= 1).astype(float)
|
| 563 |
+
elif kernel == 'epanechnikov':
|
| 564 |
+
weights = np.maximum(0.75 * (1 - u ** 2), 0)
|
| 565 |
+
else:
|
| 566 |
+
weights = np.ones(len(X))
|
| 567 |
+
|
| 568 |
+
return weights
|
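A small simulated example in the spirit of the close-election design described above; the jump of 1.5 at the cutoff is invented:

import numpy as np

rng = np.random.default_rng(2)
n = 2000
vote_share = rng.uniform(0.3, 0.7, n)            # running variable
won = (vote_share >= 0.5).astype(float)          # treatment assigned at the 50% cutoff

# Outcome: smooth in vote share, plus a jump of 1.5 at the cutoff.
spending = 10 + 4 * (vote_share - 0.5) + 1.5 * won + rng.normal(0.0, 0.5, n)

rdd = RegressionDiscontinuity(cutoff=0.5)
res = rdd.estimate_sharp(running_var=vote_share, outcome=spending, bandwidth=0.1)
print(f"LATE at the cutoff: {res.treatment_effect:.2f} (SE {res.se:.2f})")  # roughly 1.5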
| 569 |
+
|
| 570 |
+
|
| 571 |
+
class InstrumentalVariables:
|
| 572 |
+
"""
|
| 573 |
+
Instrumental Variables (IV) Estimation
|
| 574 |
+
|
| 575 |
+
Addresses endogeneity (omitted variable bias, reverse causality)
|
| 576 |
+
using exogenous variation from an instrument.
|
| 577 |
+
|
| 578 |
+
Model:
|
| 579 |
+
Y = β_0 + β_1 * X + ε (Structural equation)
|
| 580 |
+
X = γ_0 + γ_1 * Z + η (First stage)
|
| 581 |
+
|
| 582 |
+
where:
|
| 583 |
+
- X: Endogenous variable
|
| 584 |
+
- Z: Instrument (exogenous, correlated with X, affects Y only through X)
|
| 585 |
+
- β_1: Causal effect of X on Y
|
| 586 |
+
|
| 587 |
+
Estimation: Two-Stage Least Squares (2SLS)
|
| 588 |
+
|
| 589 |
+
Example:
|
| 590 |
+
>>> # Effect of trade on conflict (trade is endogenous)
|
| 591 |
+
>>> # Instrument: Geographic distance to major ports
|
| 592 |
+
>>> iv = InstrumentalVariables()
|
| 593 |
+
>>> result = iv.estimate_2sls(
|
| 594 |
+
... outcome=conflict_intensity,
|
| 595 |
+
... endogenous=trade_volume,
|
| 596 |
+
... instrument=distance_to_port,
|
| 597 |
+
... exogenous_controls=other_covariates
|
| 598 |
+
... )
|
| 599 |
+
>>> print(f"IV estimate: {result.beta_iv[0]:.3f}")
|
| 600 |
+
>>> print(f"First stage F: {result.first_stage_f:.1f}")
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
def estimate_2sls(self, outcome: np.ndarray, endogenous: np.ndarray,
|
| 604 |
+
instrument: np.ndarray,
|
| 605 |
+
exogenous_controls: Optional[np.ndarray] = None) -> IVResult:
|
| 606 |
+
"""
|
| 607 |
+
Two-Stage Least Squares estimation.
|
| 608 |
+
|
| 609 |
+
Args:
|
| 610 |
+
outcome: Dependent variable Y, shape (n,)
|
| 611 |
+
endogenous: Endogenous variable X, shape (n,) or (n, k)
|
| 612 |
+
instrument: Instrument Z, shape (n,) or (n, m)
|
| 613 |
+
exogenous_controls: Additional exogenous controls, shape (n, p)
|
| 614 |
+
|
| 615 |
+
Returns:
|
| 616 |
+
IVResult with IV estimates
|
| 617 |
+
"""
|
| 618 |
+
outcome = np.asarray(outcome).reshape(-1, 1)
|
| 619 |
+
endogenous = np.asarray(endogenous)
|
| 620 |
+
if endogenous.ndim == 1:
|
| 621 |
+
endogenous = endogenous.reshape(-1, 1)
|
| 622 |
+
instrument = np.asarray(instrument)
|
| 623 |
+
if instrument.ndim == 1:
|
| 624 |
+
instrument = instrument.reshape(-1, 1)
|
| 625 |
+
|
| 626 |
+
n = len(outcome)
|
| 627 |
+
|
| 628 |
+
# Construct design matrices
|
| 629 |
+
if exogenous_controls is not None:
|
| 630 |
+
exogenous_controls = np.asarray(exogenous_controls).reshape(n, -1)  # (n, p) even for a single control column
|
| 631 |
+
W = np.column_stack([np.ones((n, 1)), exogenous_controls])
|
| 632 |
+
else:
|
| 633 |
+
W = np.ones((n, 1))
|
| 634 |
+
|
| 635 |
+
# Full instrument matrix: [W, Z]
|
| 636 |
+
Z_full = np.column_stack([W, instrument])
|
| 637 |
+
|
| 638 |
+
# STAGE 1: Regress endogenous on instruments
|
| 639 |
+
# X = Z_full @ γ + residuals
|
| 640 |
+
first_stage_coef = np.linalg.lstsq(Z_full, endogenous, rcond=None)[0]
|
| 641 |
+
X_hat = Z_full @ first_stage_coef # Fitted values
|
| 642 |
+
|
| 643 |
+
# First stage F-statistic
|
| 644 |
+
residuals_first = endogenous - X_hat
|
| 645 |
+
rss_first = np.sum(residuals_first ** 2, axis=0)
|
| 646 |
+
tss_first = np.sum((endogenous - np.mean(endogenous, axis=0)) ** 2, axis=0)
|
| 647 |
+
r_squared_first = 1 - rss_first / tss_first
|
| 648 |
+
|
| 649 |
+
k_instruments = instrument.shape[1]
|
| 650 |
+
k_exogenous = W.shape[1]
|
| 651 |
+
first_stage_f = (r_squared_first / k_instruments) / ((1 - r_squared_first) / (n - k_exogenous - k_instruments))
|
| 652 |
+
first_stage_f = float(np.mean(first_stage_f)) # Average if multiple endogenous
|
| 653 |
+
|
| 654 |
+
# STAGE 2: Regress Y on X_hat and W
|
| 655 |
+
X_full = np.column_stack([W, X_hat])
|
| 656 |
+
beta_iv = np.linalg.lstsq(X_full, outcome, rcond=None)[0]
|
| 657 |
+
|
| 658 |
+
# Standard errors (2SLS requires special formula)
|
| 659 |
+
Y_hat = X_full @ beta_iv
|
| 660 |
+
residuals_second = outcome - Y_hat
|
| 661 |
+
sigma_sq = np.sum(residuals_second ** 2) / (n - X_full.shape[1])
|
| 662 |
+
|
| 663 |
+
# Variance: σ^2 (X_hat' X_hat)^{-1}
|
| 664 |
+
vcov = sigma_sq * np.linalg.inv(X_full.T @ X_full)
|
| 665 |
+
se_iv = np.sqrt(np.diag(vcov)).reshape(-1, 1)
|
| 666 |
+
|
| 667 |
+
# OLS for comparison (biased but often smaller SE)
|
| 668 |
+
X_full_ols = np.column_stack([W, endogenous])
|
| 669 |
+
beta_ols = np.linalg.lstsq(X_full_ols, outcome, rcond=None)[0]
|
| 670 |
+
|
| 671 |
+
# Weak instrument warning
|
| 672 |
+
weak_instrument = first_stage_f < 10
|
| 673 |
+
|
| 674 |
+
return IVResult(
|
| 675 |
+
beta_iv=beta_iv[k_exogenous:, 0], # Exclude intercept/controls
|
| 676 |
+
beta_ols=beta_ols[k_exogenous:, 0],
|
| 677 |
+
se_iv=se_iv[k_exogenous:, 0],
|
| 678 |
+
first_stage_f=first_stage_f,
|
| 679 |
+
weak_instrument=weak_instrument
|
| 680 |
+
)
|
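A simulated sketch showing why 2SLS helps when the regressor is confounded; the data-generating values are invented and the instrument is constructed to be valid by design:

import numpy as np

rng = np.random.default_rng(3)
n = 5000
u = rng.normal(0, 1, n)                            # unobserved confounder
z = rng.normal(0, 1, n)                            # instrument: exogenous and relevant
x = 0.8 * z + 0.9 * u + rng.normal(0, 1, n)        # endogenous regressor
y = 1.0 * x + 1.5 * u + rng.normal(0, 1, n)        # true causal effect of x is 1.0

iv = InstrumentalVariables()
res = iv.estimate_2sls(outcome=y, endogenous=x, instrument=z)
print(f"OLS estimate: {res.beta_ols[0]:.2f}  (biased upward by the confounder)")
print(f"IV  estimate: {res.beta_iv[0]:.2f}  (close to 1.0)")
print(f"First-stage F: {res.first_stage_f:.1f}, weak instrument: {res.weak_instrument}")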
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def estimate_treatment_effect_bounds(outcome_treated: np.ndarray,
|
| 684 |
+
outcome_control: np.ndarray,
|
| 685 |
+
selection_probability: float = 0.5) -> Tuple[float, float]:
|
| 686 |
+
"""
|
| 687 |
+
Estimate bounds on treatment effect under selection on unobservables.
|
| 688 |
+
|
| 689 |
+
When treatment assignment is not random, the true effect lies within bounds.
|
| 690 |
+
This implements Manski bounds (worst-case bounds).
|
| 691 |
+
|
| 692 |
+
Args:
|
| 693 |
+
outcome_treated: Outcomes for treated group
|
| 694 |
+
outcome_control: Outcomes for control group
|
| 695 |
+
selection_probability: P(Treatment | unobservables)
|
| 696 |
+
|
| 697 |
+
Returns:
|
| 698 |
+
(lower_bound, upper_bound) on average treatment effect
|
| 699 |
+
"""
|
| 700 |
+
# Observed means
|
| 701 |
+
y_treated = np.mean(outcome_treated)
|
| 702 |
+
y_control = np.mean(outcome_control)
|
| 703 |
+
|
| 704 |
+
# Range of outcomes
|
| 705 |
+
y_min = min(np.min(outcome_treated), np.min(outcome_control))
|
| 706 |
+
y_max = max(np.max(outcome_treated), np.max(outcome_control))
|
| 707 |
+
|
| 708 |
+
# Worst-case bounds
|
| 709 |
+
# Lower bound: assume best outcomes for control in unobserved potential outcomes
|
| 710 |
+
lower_bound = y_treated - y_max
|
| 711 |
+
|
| 712 |
+
# Upper bound: assume worst outcomes for control in unobserved potential outcomes
|
| 713 |
+
upper_bound = y_treated - y_min
|
| 714 |
+
|
| 715 |
+
return (lower_bound, upper_bound)
|
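A tiny worked example of the bounds (values are arbitrary):

import numpy as np

treated = np.array([3.0, 4.0, 5.0])
control = np.array([1.0, 2.0, 3.0])

lower, upper = estimate_treatment_effect_bounds(treated, control)
print(f"ATT lies in [{lower:.1f}, {upper:.1f}]")   # worst-case bounds: [-1.0, 3.0]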
geobot/simulation/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
"""
|
| 2 |
+
Simulation engines for GeoBotv1
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .monte_carlo import MonteCarloEngine, ShockSimulator
|
| 6 |
+
from .agent_based import AgentBasedModel, GeopoliticalAgent
|
| 7 |
+
from .sde_solver import (
|
| 8 |
+
EulerMaruyama,
|
| 9 |
+
Milstein,
|
| 10 |
+
StochasticRungeKutta,
|
| 11 |
+
JumpDiffusionProcess,
|
| 12 |
+
GeopoliticalSDE,
|
| 13 |
+
ornstein_uhlenbeck_process
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# Hawkes processes (wrapper for timeseries.point_processes)
|
| 17 |
+
from . import hawkes
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"MonteCarloEngine",
|
| 21 |
+
"ShockSimulator",
|
| 22 |
+
"AgentBasedModel",
|
| 23 |
+
"GeopoliticalAgent",
|
| 24 |
+
"EulerMaruyama",
|
| 25 |
+
"Milstein",
|
| 26 |
+
"StochasticRungeKutta",
|
| 27 |
+
"JumpDiffusionProcess",
|
| 28 |
+
"GeopoliticalSDE",
|
| 29 |
+
"ornstein_uhlenbeck_process",
|
| 30 |
+
"hawkes",
|
| 31 |
+
]
|
geobot/simulation/agent_based.py
ADDED
|
@@ -0,0 +1,349 @@
|
| 1 |
+
"""
|
| 2 |
+
Agent-Based Modeling for Geopolitical Simulation
|
| 3 |
+
|
| 4 |
+
Models individual actors (states, organizations, leaders) and their interactions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AgentType(Enum):
|
| 14 |
+
"""Types of geopolitical agents."""
|
| 15 |
+
STATE = "state"
|
| 16 |
+
ORGANIZATION = "organization"
|
| 17 |
+
LEADER = "leader"
|
| 18 |
+
ALLIANCE = "alliance"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class AgentState:
|
| 23 |
+
"""State variables for an agent."""
|
| 24 |
+
position: np.ndarray # Position in feature space
|
| 25 |
+
resources: float = 1.0
|
| 26 |
+
power: float = 1.0
|
| 27 |
+
hostility: float = 0.0
|
| 28 |
+
cooperation: float = 0.5
|
| 29 |
+
stability: float = 1.0
|
| 30 |
+
custom: Dict[str, float] = field(default_factory=dict)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class GeopoliticalAgent:
|
| 34 |
+
"""
|
| 35 |
+
Represents a geopolitical actor.
|
| 36 |
+
|
| 37 |
+
Agents have internal state, decision-making logic, and
|
| 38 |
+
interact with other agents and the environment.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
agent_id: str,
|
| 44 |
+
agent_type: AgentType,
|
| 45 |
+
initial_state: AgentState
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
Initialize agent.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
agent_id : str
|
| 53 |
+
Unique agent identifier
|
| 54 |
+
agent_type : AgentType
|
| 55 |
+
Type of agent
|
| 56 |
+
initial_state : AgentState
|
| 57 |
+
Initial state
|
| 58 |
+
"""
|
| 59 |
+
self.agent_id = agent_id
|
| 60 |
+
self.agent_type = agent_type
|
| 61 |
+
self.state = initial_state
|
| 62 |
+
self.history: List[AgentState] = [initial_state]
|
| 63 |
+
self.relationships: Dict[str, float] = {} # {agent_id: relationship_strength}
|
| 64 |
+
|
| 65 |
+
def update_state(self, **kwargs) -> None:
|
| 66 |
+
"""Update agent state variables."""
|
| 67 |
+
for key, value in kwargs.items():
|
| 68 |
+
if hasattr(self.state, key):
|
| 69 |
+
setattr(self.state, key, value)
|
| 70 |
+
else:
|
| 71 |
+
self.state.custom[key] = value
|
| 72 |
+
|
| 73 |
+
def decide_action(
|
| 74 |
+
self,
|
| 75 |
+
environment: Dict[str, Any],
|
| 76 |
+
other_agents: List['GeopoliticalAgent']
|
| 77 |
+
) -> Dict[str, Any]:
|
| 78 |
+
"""
|
| 79 |
+
Decide on action based on current state and environment.
|
| 80 |
+
|
| 81 |
+
Parameters
|
| 82 |
+
----------
|
| 83 |
+
environment : dict
|
| 84 |
+
Environmental factors
|
| 85 |
+
other_agents : list
|
| 86 |
+
Other agents in the system
|
| 87 |
+
|
| 88 |
+
Returns
|
| 89 |
+
-------
|
| 90 |
+
dict
|
| 91 |
+
Chosen action
|
| 92 |
+
"""
|
| 93 |
+
# Simple decision logic (can be made more sophisticated)
|
| 94 |
+
action = {
|
| 95 |
+
'type': 'none',
|
| 96 |
+
'target': None,
|
| 97 |
+
'intensity': 0.0
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# Check for threats
|
| 101 |
+
threats = [
|
| 102 |
+
agent for agent in other_agents
|
| 103 |
+
if self.relationships.get(agent.agent_id, 0) < -0.5
|
| 104 |
+
and agent.state.power > self.state.power * 0.8
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
if threats and self.state.hostility > 0.5:
|
| 108 |
+
# Consider conflict
|
| 109 |
+
action = {
|
| 110 |
+
'type': 'escalate',
|
| 111 |
+
'target': threats[0].agent_id,
|
| 112 |
+
'intensity': self.state.hostility
|
| 113 |
+
}
|
| 114 |
+
elif self.state.cooperation > 0.7:
|
| 115 |
+
# Seek cooperation
|
| 116 |
+
potential_partners = [
|
| 117 |
+
agent for agent in other_agents
|
| 118 |
+
if self.relationships.get(agent.agent_id, 0) > 0.3
|
| 119 |
+
]
|
| 120 |
+
if potential_partners:
|
| 121 |
+
action = {
|
| 122 |
+
'type': 'cooperate',
|
| 123 |
+
'target': potential_partners[0].agent_id,
|
| 124 |
+
'intensity': self.state.cooperation
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
return action
|
| 128 |
+
|
| 129 |
+
def interact(self, other_agent: 'GeopoliticalAgent', action: Dict[str, Any]) -> None:
|
| 130 |
+
"""
|
| 131 |
+
Interact with another agent.
|
| 132 |
+
|
| 133 |
+
Parameters
|
| 134 |
+
----------
|
| 135 |
+
other_agent : GeopoliticalAgent
|
| 136 |
+
Other agent
|
| 137 |
+
action : dict
|
| 138 |
+
Action to perform
|
| 139 |
+
"""
|
| 140 |
+
action_type = action['type']
|
| 141 |
+
intensity = action['intensity']
|
| 142 |
+
|
| 143 |
+
if action_type == 'cooperate':
|
| 144 |
+
# Strengthen relationship
|
| 145 |
+
current_rel = self.relationships.get(other_agent.agent_id, 0)
|
| 146 |
+
self.relationships[other_agent.agent_id] = min(1.0, current_rel + 0.1 * intensity)
|
| 147 |
+
|
| 148 |
+
# Mutual benefit
|
| 149 |
+
self.state.resources += 0.05 * intensity
|
| 150 |
+
other_agent.state.resources += 0.05 * intensity
|
| 151 |
+
|
| 152 |
+
elif action_type == 'escalate':
|
| 153 |
+
# Weaken relationship
|
| 154 |
+
current_rel = self.relationships.get(other_agent.agent_id, 0)
|
| 155 |
+
self.relationships[other_agent.agent_id] = max(-1.0, current_rel - 0.2 * intensity)
|
| 156 |
+
|
| 157 |
+
# Conflict effects
|
| 158 |
+
power_ratio = self.state.power / (other_agent.state.power + 1e-6)
|
| 159 |
+
if power_ratio > 1:
|
| 160 |
+
self.state.resources += 0.1 * intensity
|
| 161 |
+
other_agent.state.resources -= 0.15 * intensity
|
| 162 |
+
other_agent.state.stability -= 0.1 * intensity
|
| 163 |
+
else:
|
| 164 |
+
self.state.resources -= 0.1 * intensity
|
| 165 |
+
self.state.stability -= 0.05 * intensity
|
| 166 |
+
|
| 167 |
+
def save_state(self) -> None:
|
| 168 |
+
"""Save current state to history."""
|
| 169 |
+
self.history.append(AgentState(
|
| 170 |
+
position=self.state.position.copy(),
|
| 171 |
+
resources=self.state.resources,
|
| 172 |
+
power=self.state.power,
|
| 173 |
+
hostility=self.state.hostility,
|
| 174 |
+
cooperation=self.state.cooperation,
|
| 175 |
+
stability=self.state.stability,
|
| 176 |
+
custom=self.state.custom.copy()
|
| 177 |
+
))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class AgentBasedModel:
|
| 181 |
+
"""
|
| 182 |
+
Agent-based model for geopolitical simulation.
|
| 183 |
+
|
| 184 |
+
Manages multiple agents and their interactions over time.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(self):
|
| 188 |
+
"""Initialize agent-based model."""
|
| 189 |
+
self.agents: Dict[str, GeopoliticalAgent] = {}
|
| 190 |
+
self.environment: Dict[str, Any] = {}
|
| 191 |
+
self.time: int = 0
|
| 192 |
+
|
| 193 |
+
def add_agent(self, agent: GeopoliticalAgent) -> None:
|
| 194 |
+
"""
|
| 195 |
+
Add agent to model.
|
| 196 |
+
|
| 197 |
+
Parameters
|
| 198 |
+
----------
|
| 199 |
+
agent : GeopoliticalAgent
|
| 200 |
+
Agent to add
|
| 201 |
+
"""
|
| 202 |
+
self.agents[agent.agent_id] = agent
|
| 203 |
+
|
| 204 |
+
def remove_agent(self, agent_id: str) -> None:
|
| 205 |
+
"""
|
| 206 |
+
Remove agent from model.
|
| 207 |
+
|
| 208 |
+
Parameters
|
| 209 |
+
----------
|
| 210 |
+
agent_id : str
|
| 211 |
+
Agent ID to remove
|
| 212 |
+
"""
|
| 213 |
+
if agent_id in self.agents:
|
| 214 |
+
del self.agents[agent_id]
|
| 215 |
+
|
| 216 |
+
def set_environment(self, **kwargs) -> None:
|
| 217 |
+
"""Set environmental variables."""
|
| 218 |
+
self.environment.update(kwargs)
|
| 219 |
+
|
    def step(self) -> None:
        """
        Execute one time step of simulation.

        All agents make decisions and interact.
        """
        self.time += 1

        # Phase 1: All agents decide actions
        actions = {}
        other_agents_list = list(self.agents.values())

        for agent in self.agents.values():
            action = agent.decide_action(self.environment, other_agents_list)
            actions[agent.agent_id] = action

        # Phase 2: Execute actions
        for agent_id, action in actions.items():
            agent = self.agents[agent_id]

            if action['type'] != 'none' and action['target'] is not None:
                if action['target'] in self.agents:
                    target = self.agents[action['target']]
                    agent.interact(target, action)

        # Phase 3: Environmental updates
        for agent in self.agents.values():
            # Resource growth
            agent.state.resources *= (1 + 0.01 * agent.state.stability)

            # Power calculation
            agent.state.power = agent.state.resources * agent.state.stability

            # Add noise
            agent.state.hostility += np.random.normal(0, 0.05)
            agent.state.hostility = np.clip(agent.state.hostility, 0, 1)

            agent.state.cooperation += np.random.normal(0, 0.05)
            agent.state.cooperation = np.clip(agent.state.cooperation, 0, 1)

            # Save state
            agent.save_state()

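    # Note on the update order above (illustrative, not from the original docs):
    # resources compound at roughly 1% * stability per step, and power is then
    # recomputed as resources * stability. For example, an agent with
    # resources=1.0 and stability=0.8 ends a quiet step with resources≈1.008 and
    # power≈0.806 before Gaussian noise perturbs hostility and cooperation.
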
    def run(self, n_steps: int) -> None:
        """
        Run simulation for multiple steps.

        Parameters
        ----------
        n_steps : int
            Number of time steps
        """
        for _ in range(n_steps):
            self.step()

    def get_agent_trajectories(self, agent_id: str) -> Dict[str, List[float]]:
        """
        Get historical trajectories for an agent.

        Parameters
        ----------
        agent_id : str
            Agent ID

        Returns
        -------
        dict
            Trajectories of state variables
        """
        if agent_id not in self.agents:
            return {}

        agent = self.agents[agent_id]
        history = agent.history

        return {
            'resources': [s.resources for s in history],
            'power': [s.power for s in history],
            'hostility': [s.hostility for s in history],
            'cooperation': [s.cooperation for s in history],
            'stability': [s.stability for s in history]
        }

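    # Plotting sketch (assumes matplotlib is installed; it is not imported or
    # required by this module, and "alpha" is a placeholder agent_id):
    #     traj = model.get_agent_trajectories("alpha")
    #     import matplotlib.pyplot as plt
    #     plt.plot(traj['power'], label='power')
    #     plt.plot(traj['stability'], label='stability')
    #     plt.legend(); plt.show()
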
    def get_system_state(self) -> Dict[str, Any]:
        """
        Get current state of entire system.

        Returns
        -------
        dict
            System state
        """
        return {
            'time': self.time,
            'n_agents': len(self.agents),
            'total_resources': sum(a.state.resources for a in self.agents.values()),
            'mean_hostility': np.mean([a.state.hostility for a in self.agents.values()]),
            'mean_cooperation': np.mean([a.state.cooperation for a in self.agents.values()]),
            'mean_stability': np.mean([a.state.stability for a in self.agents.values()])
        }

    def analyze_network(self) -> Dict[str, Any]:
        """
        Analyze the network of relationships.

        Returns
        -------
        dict
            Network metrics
        """
        import networkx as nx

        # Build network
        G = nx.Graph()
        for agent in self.agents.values():
            G.add_node(agent.agent_id)

        for agent in self.agents.values():
            for other_id, strength in agent.relationships.items():
                if strength > 0.1:  # Only positive relationships
                    G.add_edge(agent.agent_id, other_id, weight=strength)

        # Compute metrics
        return {
            'n_nodes': G.number_of_nodes(),
            'n_edges': G.number_of_edges(),
            'density': nx.density(G),
            'average_clustering': nx.average_clustering(G) if G.number_of_edges() > 0 else 0,
            'connected_components': nx.number_connected_components(G)
        }
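

# --- Illustrative usage sketch (not part of the committed module) ---
# The GeopoliticalAgent and AgentState constructors are defined earlier in this
# file; the argument forms below are assumptions for illustration only. The
# AgentBasedModel calls themselves match the methods defined above.
#
#     model = AgentBasedModel()
#     model.add_agent(GeopoliticalAgent("alpha", initial_state=AgentState()))
#     model.add_agent(GeopoliticalAgent("beta", initial_state=AgentState()))
#     model.set_environment(global_tension=0.5)
#     model.run(n_steps=50)
#     print(model.get_system_state())   # time, n_agents, resource totals, means
#     print(model.analyze_network())    # requires networkx for graph metrics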