File size: 39,296 Bytes
4deea85
b9fa3b7
4deea85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
###########################################################################################################################################
#||||- - - |8.19.2025| - - -                                ||   MÖBIUS MARKOV   ||                             - - - |1990two| - - -|||| #
###########################################################################################################################################
"""
Mathematical Foundation & Conceptual Documentation
-------------------------------------------------

CORE PRINCIPLE:
Combines Möbius transformations (complex analysis) with Markov chains to create
probabilistic systems evolving in dynamically warped non-Euclidean state spaces.
The geometry of the state space continuously adapts based on the system's evolution,
enabling rich, non-linear dynamics impossible in traditional Euclidean spaces.

MATHEMATICAL FOUNDATION:
=======================

1. MÖBIUS TRANSFORMATIONS:
   f(z) = (az + b)/(cz + d)
   
   Where:
   - z ∈ ℂ: complex state variable
   - a,b,c,d ∈ ℂ: complex parameters with ad - bc ≠ 0
   - f: ℂ ∪ {∞} → ℂ ∪ {∞} (extended complex plane)
   
   Properties:
   - Conformal mapping (preserves angles)
   - Maps circles/lines to circles/lines
   - Group structure under composition
   - Inverse: f⁻¹(w) = (dw - b)/(-cw + a)

2. COMPLEX STATE SPACE:
   State positions: z₁, z₂, ..., zₙ ∈ ℂ
   
   Transformed positions: w_i = f(z_i) = (az_i + b)/(cz_i + d)
   
   Distance in transformed space: d(w_i, w_j) = |w_i - w_j|

3. MARKOV TRANSITION PROBABILITIES:
   P(i → j) = softmax_j(β · K(d(w_i, w_j)) + θ_ij)   (normalized over destination states j)
   
   Where:
   - K(d): distance kernel (Gaussian, inverse, linear)
   - β: distance scaling parameter
   - θ_ij: base transition logits
   - Transformed distances create non-Euclidean transition structure

4. ADAPTIVE GEOMETRY EVOLUTION:
   ∂(a,b,c,d)/∂t = η · G(x_t, E_t)
   
   Where:
   - G: geometry evolution function
   - x_t: current state distribution
   - E_t: embedded state features
   - η: geometry learning rate
   
   The Möbius parameters evolve based on system state.

5. KERNEL FUNCTIONS:
   Gaussian: K(d) = exp(-d²/(2σ²))
   Inverse: K(d) = 1/(d^α + ε)
   Linear: K(d) = max(0, 1 - d)
   
   Different kernels create different transition locality structures.

CONCEPTUAL REASONING:
====================

WHY MÖBIUS + MARKOV?
- Standard Markov chains assume fixed, Euclidean state spaces
- Real systems often have curved, adaptive state geometries
- Möbius transformations form a rich yet analytically tractable family of conformal warpings
- Complex analysis offers elegant mathematical framework
- Dynamic geometry enables meta-learning of state representations

KEY INNOVATIONS:
1. **Dynamic Non-Euclidean Geometry**: State space warps over time
2. **Complex State Representations**: Rich 2D embedding in complex plane
3. **Conformal Invariance**: Angle-preserving transformations maintain local structure
4. **Learnable Geometry**: Möbius parameters adapt to data
5. **Multi-Scale Dynamics**: Both local transitions and global geometry evolve

APPLICATIONS:
- Dynamical systems with changing phase spaces
- Neural representations learning geometric structure
- Sequential data with non-stationary transition patterns
- Robotics in environments with changing topology
- Financial modeling with regime changes

COMPLEXITY ANALYSIS:
- Möbius Transform: O(n) for n states
- Distance Computation: O(n²) for all pairs
- Markov Step: O(n²) for transition matrix
- Geometry Evolution: O(1) for parameter updates
- Memory: O(n²) for transition probabilities

BIOLOGICAL INSPIRATION:
- Neural manifold learning in cortical representations
- Synaptic plasticity reshaping connectivity patterns
- Developmental changes in brain network topology
- Spatial navigation with changing environmental maps
- Memory consolidation through representational geometry changes
"""

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8

#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𝔦 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#

def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace non-finite entries and clamp a tensor into ``[min_val, max_val]``.

    NaN entries become 0, ``+inf`` becomes ``max_val`` and ``-inf`` becomes
    ``min_val``; all remaining values are then clamped to the same range.

    Args:
        tensor: Real-valued tensor to sanitize.
        min_val: Lower bound; also the replacement value for ``-inf``.
        max_val: Upper bound; also the replacement value for ``+inf``.

    Returns:
        A new tensor with the same shape, device and dtype whose values are
        finite and lie inside ``[min_val, max_val]``.
    """
    # nan_to_num maps each infinity to the bound on its own side, so -inf
    # lands on min_val rather than being flipped to the positive bound.
    cleaned = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(cleaned, min_val, max_val)

def safe_complex_division(numerator, denominator, eps=EPS):
    """Divide ``numerator`` by ``denominator`` with a floor on ``|denominator|²``.

    Uses the conjugate identity

        z₁ / z₂ = z₁ · conj(z₂) / |z₂|²

    where the squared modulus ``|z₂|² = Re(z₂ · conj(z₂))`` is clamped from
    below to ``eps`` so the division never blows up.

    Args:
        numerator: Dividend tensor (broadcast-compatible with ``denominator``).
        denominator: Complex divisor tensor.
        eps: Lower bound applied to the *squared* modulus of ``denominator``
            (not to the modulus itself).

    Returns:
        ``numerator * conj(denominator) / max(|denominator|², eps)``. When
        ``|denominator|² < eps`` the result is smaller in magnitude than exact
        division would give; a zero denominator yields zero.
    """
    conj_den = denominator.conj()
    squared_modulus = (denominator * conj_den).real.clamp(min=eps)
    return numerator * conj_den / squared_modulus

###########################################################################################################################################
####################################################- - -   MÖBIUS TRANSFORM   - - -#######################################################

class MobiusTransform(nn.Module):
    """Möbius transformation f(z) = (az + b)/(cz + d) for warping the complex plane.

    Each complex coefficient a, b, c, d is stored as a length-2
    ``[real, imag]`` tensor: an ``nn.Parameter`` when ``learnable`` is True,
    otherwise a registered buffer. Coefficients are converted to complex
    scalars at call time.

    Mathematical Properties:
    - Conformal mapping (preserves angles locally)
    - Maps circles and lines to circles and lines
    - Invertible iff ad - bc ≠ 0
    - Composition of Möbius maps is again a Möbius map
    - Scale invariance: (ka, kb, kc, kd) defines the same map for any k ≠ 0

    Args:
        learnable: Store coefficients as trainable parameters. Only in that
            case does ``normalize_parameters`` modify them.
        init_identity: Start from the identity map (a=d=1, b=c=0). Otherwise
            start from the fixed (non-random) map a=d=1, b=c=0.1, which has
            det = 0.99.
    """
    def __init__(self, learnable=True, init_identity=True):
        super().__init__()
        self.learnable = learnable

        if init_identity:
            # Identity map f(z) = z: a=1, b=0, c=0, d=1
            a_init, b_init, c_init, d_init = 1.0, 0.0, 0.0, 1.0
        else:
            # Fixed near-identity start; det = 1*1 - 0.1*0.1 = 0.99 ≠ 0
            a_init, d_init = 1.0, 1.0
            b_init, c_init = 0.1, 0.1

        if learnable:
            # f(z) = (az + b)/(cz + d); each coefficient stored as [real, imag]
            self.a = nn.Parameter(torch.tensor([a_init, 0.0]))
            self.b = nn.Parameter(torch.tensor([b_init, 0.0]))
            self.c = nn.Parameter(torch.tensor([c_init, 0.0]))
            self.d = nn.Parameter(torch.tensor([d_init, 0.0]))
        else:
            # Fixed coefficients kept as buffers (move with .to(), not trained)
            self.register_buffer('a', torch.tensor([a_init, 0.0]))
            self.register_buffer('b', torch.tensor([b_init, 0.0]))
            self.register_buffer('c', torch.tensor([c_init, 0.0]))
            self.register_buffer('d', torch.tensor([d_init, 0.0]))

    def to_complex(self, param):
        """Convert a ``[real, imag]`` tensor into a complex scalar tensor.

        Args:
            param: Length-2 tensor ``[real_part, imaginary_part]``.

        Returns:
            Complex tensor ``param[0] + i·param[1]``.
        """
        return torch.complex(param[0], param[1])

    def get_determinant(self):
        """Return the complex determinant ad - bc of the coefficient matrix.

        The map is invertible only when this is non-zero.
        ``normalize_parameters`` enforces that for learnable transforms by
        resetting the coefficients when |ad - bc| < EPS.

        Note: because Möbius maps are scale-invariant, |det| by itself is not
        a geometric quantity. The local derivative is f'(z) = det/(cz + d)²,
        so the local area scaling is |f'(z)|², which depends on z.

        Returns:
            Complex scalar tensor ad - bc.
        """
        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        det = a_complex * d_complex - b_complex * c_complex
        return det

    def normalize_parameters(self):
        """Keep learnable coefficients invertible and bounded (in place).

        Only acts when ``learnable`` is True. Runs under ``torch.no_grad()``:
        1. If |ad - bc| < EPS, set a = d = 1 + 0i and shrink b and c by 10×.
           This gives a near-identity map, not an exact identity.
        2. Clamp every real and imaginary component of a, b, c, d to [-10, 10].

        Called at the start of ``transform`` and ``inverse_transform``.
        """
        if self.learnable:
            with torch.no_grad():
                det = torch.abs(self.get_determinant())
                if det < EPS:
                    # Degenerate map: restore a = d = 1 and damp b, c
                    one = torch.tensor([1.0, 0.0], device=self.a.device, dtype=self.a.dtype)
                    self.a.copy_(one)
                    self.d.copy_(one)
                    self.b.mul_(0.1)
                    self.c.mul_(0.1)

                # Component-wise bound keeps products like a*z well-scaled
                for p in (self.a, self.b, self.c, self.d):
                    p.clamp_(-10.0, 10.0)

    def transform(self, z):
        """Apply f(z) = (az + b)/(cz + d) element-wise.

        Calls ``normalize_parameters`` first. The division goes through
        ``safe_complex_division``, which floors |cz + d|² at EPS. Points
        where cz + d ≈ 0 (which would map to ∞) therefore map to finite
        values.

        Args:
            z: Complex tensor of any shape.

        Returns:
            Complex tensor of the same shape containing f(z).
        """
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = a_complex * z + b_complex
        denominator = c_complex * z + d_complex

        transformed = safe_complex_division(numerator, denominator)

        return transformed

    def inverse_transform(self, w):
        """Apply the inverse map f⁻¹(w) = (dw - b)/(-cw + a) element-wise.

        The coefficients [[d, -b], [-c, a]] form the adjugate of
        [[a, b], [c, d]]. Its determinant is da - bc = det(f), not 1/det(f).
        It still represents f⁻¹ because Möbius maps are invariant to scaling
        all four coefficients. Calls ``normalize_parameters`` first and
        divides via ``safe_complex_division``.

        Args:
            w: Complex tensor in the transformed space.

        Returns:
            Complex tensor of the same shape containing f⁻¹(w).
        """
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = d_complex * w - b_complex
        denominator = -c_complex * w + a_complex

        return safe_complex_division(numerator, denominator)

    def get_transform_info(self, atol=1e-6):
        """Summarize the current transformation for diagnostics.

        Args:
            atol: Absolute tolerance used by the identity test.

        Returns:
            Dictionary with:
            - 'determinant': complex tensor ad - bc
            - 'is_identity': Python bool, True iff f(z) = z, i.e. b ≈ 0,
              c ≈ 0 and a ≈ d ≠ 0 (any a = d ≠ 0 is the identity thanks to
              scale invariance)
            - 'parameters': dict mapping 'a', 'b', 'c', 'd' to complex tensors
        """
        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)
        det = self.get_determinant()

        # |det| == 1 does not imply identity (a=d=1, b=5 is a translation),
        # so test the coefficients directly.
        is_identity = bool(
            torch.abs(b_complex) <= atol
            and torch.abs(c_complex) <= atol
            and torch.abs(a_complex - d_complex) <= atol
            and torch.abs(a_complex) > atol
        )
        return {
            'determinant': det,
            'is_identity': is_identity,
            'parameters': {
                'a': a_complex,
                'b': b_complex,
                'c': c_complex,
                'd': d_complex
            }
        }

###########################################################################################################################################
#############################################- - -   COMPLEX STATE MARKOV CHAIN   - - -####################################################

class ComplexStateMarkovChain(nn.Module):
    """Markov chain whose states live at learnable points of the complex plane.

    State positions are pushed through a Möbius transform. Pairwise distances
    are measured in the transformed plane, and a distance kernel turns them
    into logit contributions. These are added to learned base logits before a
    row-wise softmax.

    Mathematical Framework:
    - State positions: z₁, ..., zₙ ∈ ℂ (learnable)
    - Transformed positions: w_i = f(z_i)
    - Logits: θᵢⱼ + β·K(|w_i - w_j|) + γ + 0.05·[i = j]
    - Transitions: P(i→j) = softmax_j(logits_i)

    Args:
        num_states: Number of discrete states n.
        state_embedding_dim: Width of the per-state feature vectors
            (``state_embeddings`` has shape [n, state_embedding_dim]).
        distance_kernel: 'gaussian', 'inverse', or any other value for the
            linear kernel max(0, 1 - d).
    """
    def __init__(self, num_states, state_embedding_dim=64, distance_kernel='gaussian'):
        super().__init__()
        self.num_states = num_states
        self.state_embedding_dim = state_embedding_dim
        self.distance_kernel = distance_kernel

        # One learnable complex position per state; real and imaginary parts
        # are drawn independently from N(0, 1) and scaled by 2.
        self.state_positions = nn.Parameter(
            torch.complex(
                torch.randn(num_states) * 2.0,  # Real parts
                torch.randn(num_states) * 2.0   # Imaginary parts
            )
        )

        # Per-state feature vectors. Not read by the methods of this class;
        # presumably consumed by an enclosing model — verify against callers.
        self.state_embeddings = nn.Parameter(torch.randn(num_states, state_embedding_dim) * 0.1)

        # Learned base logits θ plus kernel scale β and offset γ
        self.base_transition_logits = nn.Parameter(torch.randn(num_states, num_states) * 0.1)
        self.distance_scale = nn.Parameter(torch.tensor(1.0))
        self.distance_bias = nn.Parameter(torch.tensor(0.0))

        # Kernel-specific parameters (the linear kernel has none)
        if distance_kernel == 'gaussian':
            self.kernel_width = nn.Parameter(torch.tensor(1.0))
        elif distance_kernel == 'inverse':
            self.kernel_power = nn.Parameter(torch.tensor(1.0))

    def compute_transformed_distances(self, mobius_transform):
        """Pairwise distances between states after the Möbius warp.

        Args:
            mobius_transform: Object with a ``transform(z)`` method, e.g. a
                MobiusTransform, applied to all state positions.

        Returns:
            Tuple ``(distances, transformed_positions)``. ``distances`` is an
            [n, n] real tensor with entries |w_i - w_j| (symmetric, zero
            diagonal). ``transformed_positions`` is the [n] complex tensor w.
        """
        transformed_positions = mobius_transform.transform(self.state_positions)

        # Broadcast [1, n] against [n, 1]: entry [i, j] is w_j - w_i, whose
        # modulus is symmetric in i and j.
        row = transformed_positions.unsqueeze(0)  # [1, num_states]
        col = transformed_positions.unsqueeze(1)  # [num_states, 1]
        distances = torch.abs(row - col)

        return distances, transformed_positions

    def distance_to_probability(self, distances):
        """Map a distance matrix to kernel values K(d).

        Distances are first floored at EPS. Kernels:
        - 'gaussian': exp(-d² / (2σ²)), σ = kernel_width clamped to [0.1, 10]
        - 'inverse':  1 / (d^α + EPS), α = kernel_power clamped to [0.5, 3]
        - otherwise:  max(0, 1 - d)

        Args:
            distances: Real tensor of distances, typically [n, n].

        Returns:
            Tensor of kernel values with the same shape as ``distances``.
        """
        distances = torch.clamp(distances, min=EPS)

        if self.distance_kernel == 'gaussian':
            width = torch.clamp(self.kernel_width, min=0.1, max=10.0)
            prob_contrib = torch.exp(-distances**2 / (2 * width**2))
        elif self.distance_kernel == 'inverse':
            # NOTE(review): self-distances are 0 and get floored to EPS, so the
            # diagonal evaluates to 1/(EPS**α + EPS). That is ≥ ~1e4 for α in
            # [0.5, 3] and likely saturates the softmax toward the identity
            # matrix — confirm that is intended.
            power = torch.clamp(self.kernel_power, min=0.5, max=3.0)
            prob_contrib = 1.0 / (distances**power + EPS)
        else:
            # Linear kernel with cutoff at d = 1
            prob_contrib = torch.clamp(1.0 - distances, min=0.0)

        return prob_contrib

    def compute_transition_matrix(self, mobius_transform):
        """Build the row-stochastic transition matrix in the warped space.

        logits = θ + β·K(d) + γ + 0.05·I, with β clamped to [0.1, 10] and γ
        clamped to [-5, 5]; then P = softmax(logits, dim=1).

        Args:
            mobius_transform: Object with a ``transform(z)`` method.

        Returns:
            Tuple ``(transition_matrix, transformed_positions)``. Each row of
            ``transition_matrix`` ([n, n]) sums to 1.
        """
        distances, transformed_positions = self.compute_transformed_distances(mobius_transform)
        distance_contrib = self.distance_to_probability(distances)

        scale = torch.clamp(self.distance_scale, min=0.1, max=10.0)
        # NOTE: bias is one scalar added to every logit. A softmax along dim=1
        # is invariant to a constant shift of a whole row, so this term does
        # not change transition_matrix.
        bias = torch.clamp(self.distance_bias, min=-5.0, max=5.0)
        scaled_distance = scale * distance_contrib + bias

        transition_logits = self.base_transition_logits + scaled_distance

        # Small constant boost to the self-transition logits
        self_boost = torch.eye(
            self.num_states,
            device=transition_logits.device,
            dtype=transition_logits.dtype,
        ) * 0.05
        transition_logits = transition_logits + self_boost

        transition_matrix = F.softmax(transition_logits, dim=1)

        return transition_matrix, transformed_positions

    def forward(self, initial_state, num_steps, mobius_transform):
        """Propagate state distributions through the chain for ``num_steps``.

        The transition matrix P is computed once, then s_{t+1} = s_t @ P.

        Args:
            initial_state: [batch, n] or [n] tensor of state weights; a 1-D
                input is treated as a batch of one.
            num_steps: Number of Markov steps to take.
            mobius_transform: Object with a ``transform(z)`` method.

        Returns:
            Dictionary with:
            - 'trajectory': [num_steps + 1, batch, n] state weights per step
            - 'final_state': [batch, n] state after the last step
            - 'state_positions': [num_steps + 1, batch] complex positions of
              each row's argmax state at every step
            - 'transition_matrix': [n, n] row-stochastic matrix used
            - 'transformed_positions': [n] warped state positions
        """
        if initial_state.dim() == 1:
            current_state = initial_state.unsqueeze(0)
        else:
            current_state = initial_state

        transition_matrix, transformed_positions = self.compute_transition_matrix(mobius_transform)

        trajectory = [current_state.clone()]
        state_positions = [transformed_positions[current_state.argmax(dim=-1)]]

        for _ in range(num_steps):
            current_state = torch.matmul(current_state, transition_matrix)
            trajectory.append(current_state.clone())
            # Position (in warped space) of the most probable state per row
            state_positions.append(transformed_positions[current_state.argmax(dim=-1)])

        return {
            'trajectory': torch.stack(trajectory),
            'final_state': current_state,
            'state_positions': torch.stack(state_positions),
            'transition_matrix': transition_matrix,
            'transformed_positions': transformed_positions
        }

###########################################################################################################################################
############################################- - -   MÖBIUS MARKOV SYSTEM   - - -###########################################################

class MobiusMarkovSystem(nn.Module):
    """Complete system integrating Möbius transformations with Markov dynamics.
    
    Implements the full Möbius-Markov architecture where:
    1. Complex state positions define Markov chain geometry
    2. Möbius transformations dynamically warp the state space
    3. Transformation parameters evolve based on system state
    4. State transitions depend on transformed distances
    
    This creates a meta-learning system where the geometry of state space
    adapts based on the system's evolutionary trajectory, enabling
    discovery of optimal representational geometries for different tasks.
    
    Architecture Components:
    - MobiusTransform: Learnable complex plane warping
    - ComplexStateMarkovChain: Distance-based probabilistic transitions
    - Evolution dynamics: State-dependent geometry adaptation
    - Encoder/decoder: Interface with external representations
    
    Note:
        Every ``forward`` call updates the Möbius parameters in place (under
        ``torch.no_grad``), regardless of train/eval mode. Repeated calls,
        including those made by ``predict_sequence``, keep changing the
        model's geometry.
    
    Args:
        num_states: Number of discrete Markov states.
        state_embedding_dim: Width of the state encoder/decoder and of the
            hidden layer of the parameter-evolution network.
        evolution_steps: Number of (geometry update + one Markov step)
            iterations performed per ``forward`` call.
    """
    def __init__(self, num_states, state_embedding_dim=64, evolution_steps=10):
        super().__init__()
        self.num_states = num_states
        self.evolution_steps = evolution_steps
        
        # Core components
        self.mobius_transform = MobiusTransform(learnable=True, init_identity=True)
        self.markov_chain = ComplexStateMarkovChain(num_states, state_embedding_dim)
        
        # Evolution dynamics for Möbius parameters: maps a pooled state
        # embedding to 8 values, consumed as 4 rows of 2 (one row each for
        # a, b, c, d) by evolve_mobius_parameters.
        self.mobius_evolution = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.Tanh(),
            nn.Linear(state_embedding_dim, 8),
        )
        
        # State encoder/decoder for external interface
        self.state_encoder = nn.Sequential(
            nn.Linear(num_states, state_embedding_dim),
            nn.LayerNorm(state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, state_embedding_dim)
        )
        
        # The trailing Softmax makes predictions normalized over the last axis.
        self.state_decoder = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, num_states),
            nn.Softmax(dim=-1)
        )
        
        # Scales Möbius parameter updates; clamped to [0.01, 1.0] where used.
        self.geometry_controller = nn.Parameter(torch.tensor(0.1))
    
    def _snapshot_geometry(self):
        """Return ``get_transform_info()`` with every tensor value cloned.
        
        The Möbius parameters are mutated in place on every evolution step.
        Cloning protects recorded geometries from later updates.
        NOTE(review): this only matters if ``get_transform_info`` returns
        views of the live parameters, which cannot be seen from here.
        """
        info = self.mobius_transform.get_transform_info()
        return {key: value.clone() if torch.is_tensor(value) else value
                for key, value in info.items()}
        
    def evolve_mobius_parameters(self, state_embedding):
        """Nudge the Möbius parameters (a, b, c, d) based on a state embedding.
        
        Update rule: ``p <- p + 0.01 * clamp(eta, 0.01, 1.0) * G(embedding)``.
        ``G`` is ``self.mobius_evolution``. Its 8 outputs are viewed as 4
        rows of 2 values, one row per parameter; how a row maps onto a
        parameter depends on MobiusTransform's storage layout, which is not
        visible here. ``eta`` is ``self.geometry_controller``. After the
        update, ``normalize_parameters()`` is called on the transform.
        
        The whole computation runs under ``torch.no_grad()``. As a result,
        neither ``mobius_evolution`` nor ``geometry_controller`` receives
        gradients through this path, which is their only use in this class.
        Nothing happens when the transform is not learnable.
        
        Args:
            state_embedding: 1-D state representation [embedding_dim]
        """
        if not self.mobius_transform.learnable:
            return
        
        with torch.no_grad():
            evolution_signal = self.mobius_evolution(state_embedding)
            evolution_rate = torch.clamp(self.geometry_controller, 0.01, 1.0)
            reference = self.mobius_transform.a
            updates = (evolution_signal.view(4, 2) * evolution_rate * 0.01).to(
                device=reference.device, dtype=reference.dtype
            )
            
            # Iterating the [4, 2] tensor yields one update row per parameter.
            parameters = (
                self.mobius_transform.a,
                self.mobius_transform.b,
                self.mobius_transform.c,
                self.mobius_transform.d,
            )
            for parameter, update in zip(parameters, updates):
                parameter.add_(update)
            
            # Ensure parameters remain valid
            self.mobius_transform.normalize_parameters()
    
    def forward(self, initial_state, return_full_trajectory=False):
        """Execute complete Möbius-Markov evolution cycle.
        
        For each of ``self.evolution_steps`` iterations:
        1. Encode the current state distribution.
        2. Update the Möbius parameters from the embedding averaged over all
           leading axes. This is an in-place side effect on the model.
        3. Take one Markov step in the freshly transformed space.
        Finally, the last state is re-encoded and decoded into a prediction.
        
        Args:
            initial_state: State distribution [batch_size, num_states]. A
                single 1-D distribution [num_states] is also accepted. After
                the first Markov step, the state is batched as [1, num_states].
            return_full_trajectory: Whether to return complete evolution history
            
        Returns:
            Dictionary with 'final_state', 'final_prediction',
            'final_embedding' and 'final_geometry'. When requested, it also
            holds 'evolution_history': per-step lists keyed 'states',
            'geometries', 'transition_matrices' and 'transformed_positions'.
        """
        evolution_history = {
            'states': [],
            'geometries': [],
            'transition_matrices': [],
            'transformed_positions': []
        }
        
        current_state = initial_state
        
        for _ in range(self.evolution_steps):
            state_embedding = self.state_encoder(current_state)
            
            # Pool over every leading axis. A batched embedding gives the
            # same vector as mean(dim=0); an unbatched 1-D embedding still
            # yields an [embedding_dim] vector instead of a scalar.
            pooled_embedding = state_embedding.reshape(
                -1, state_embedding.shape[-1]
            ).mean(dim=0)
            self.evolve_mobius_parameters(pooled_embedding)
            
            # Call the submodule, not .forward(), so nn.Module hooks still run.
            markov_output = self.markov_chain(
                current_state,
                num_steps=1,
                mobius_transform=self.mobius_transform,
            )
            current_state = markov_output['final_state']
            
            if return_full_trajectory:
                evolution_history['states'].append(current_state.clone())
                evolution_history['geometries'].append(self._snapshot_geometry())
                evolution_history['transition_matrices'].append(markov_output['transition_matrix'])
                evolution_history['transformed_positions'].append(markov_output['transformed_positions'])
        
        # Generate final prediction using learned decoder
        final_embedding = self.state_encoder(current_state)
        final_prediction = self.state_decoder(final_embedding)
        
        output = {
            'final_state': current_state,
            'final_prediction': final_prediction,
            'final_embedding': final_embedding,
            'final_geometry': self._snapshot_geometry()
        }
        
        if return_full_trajectory:
            output['evolution_history'] = evolution_history
        
        return output
    
    def predict_sequence(self, initial_state, sequence_length):
        """Generate sequence of predictions through iterative evolution.
        
        Each element of the sequence is a full ``forward`` call, i.e.
        ``evolution_steps`` Markov steps. The final state of one call seeds
        the next. The Möbius parameters keep being updated in place.
        
        Args:
            initial_state: Starting state distribution
            sequence_length: Number of prediction steps
            
        Returns:
            Tensor of predictions [sequence_length, batch_size, num_states].
            When ``sequence_length <= 0``, this is an empty tensor shaped
            ``(0, *initial_state.shape)``.
        """
        predictions = []
        current_state = initial_state
        
        for _ in range(sequence_length):
            step_output = self(current_state)
            predictions.append(step_output['final_prediction'])
            current_state = step_output['final_state']
        
        if not predictions:
            # torch.stack rejects an empty list; return an empty sequence instead.
            return initial_state.new_zeros((0, *initial_state.shape))
        return torch.stack(predictions)
    
    def get_system_info(self):
        """Get comprehensive system state information.
        
        Returns:
            Dictionary with the state count, evolution steps, a snapshot of
            the current Möbius geometry, the Markov chain's state positions
            (the live attribute, not a copy) and the raw geometry evolution
            rate as a Python float.
        """
        return {
            'num_states': self.num_states,
            'evolution_steps': self.evolution_steps,
            'current_geometry': self._snapshot_geometry(),
            'state_positions': self.markov_chain.state_positions,
            'geometry_evolution_rate': self.geometry_controller.item()
        }

###########################################################################################################################################
###################################################- - -   DEMO AND TESTING   - - -########################################################

def test_mobius_markov():
    """Run a printed smoke test of the Möbius-Markov system.
    
    Builds a small system and checks Möbius invertibility numerically. It
    then runs a full evolution with history, generates a short sequence,
    and compares the final geometry for concentrated vs. uniform inputs.
    Results are only printed; the function always returns True.
    
    Note: each system call updates the Möbius parameters in place, so later
    sections see the geometry left behind by earlier ones.
    """
    print("Testing Möbius Markov - Non-Euclidean Probabilistic Systems")
    print("=" * 75)
    
    # Single source of truth for the sizes used below and in the printout.
    num_states = 8
    embedding_dim = 32
    system = MobiusMarkovSystem(
        num_states=num_states,
        state_embedding_dim=embedding_dim,
        evolution_steps=5
    )
    
    print("Created Möbius-Markov System:")
    print(f"  - Number of states: {num_states}")
    print(f"  - Evolution steps: {system.evolution_steps}")
    print(f"  - State embedding dimension: {embedding_dim}")
    print("  - Complex state space with adaptive geometry")
    
    # Create initial state distribution
    batch_size = 4
    initial_state = torch.zeros(batch_size, num_states)
    initial_state[:, 0] = 1.0  # Start in state 0
    
    print(f"\nTesting with batch size: {batch_size}")
    print("Initial state: All samples start in state 0")
    
    # Test Möbius transformation components
    print("\nTesting Möbius transformation...")
    mobius = system.mobius_transform
    test_complex = torch.complex(torch.randn(5), torch.randn(5))
    transformed = mobius.transform(test_complex)
    inverse_transformed = mobius.inverse_transform(transformed)
    
    # Check invertibility: transform followed by inverse should be ~identity.
    reconstruction_error = torch.mean(torch.abs(test_complex - inverse_transformed))
    print(f"  - Transformation invertibility error: {reconstruction_error:.6f}")
    
    # Test geometry info
    geometry_info = mobius.get_transform_info()
    det_magnitude = torch.abs(geometry_info['determinant'])
    print(f"  - Determinant magnitude: {det_magnitude:.4f}")
    print(f"  - Is identity: {geometry_info['is_identity']}")
    
    # Test forward evolution
    print("\nExecuting Möbius-Markov evolution...")
    output = system(initial_state, return_full_trajectory=True)
    
    print("Evolution results:")
    print(f"  - Final state shape: {output['final_state'].shape}")
    print(f"  - Final prediction shape: {output['final_prediction'].shape}")
    
    # Analyze geometric evolution
    print("\nGeometric evolution analysis:")
    history = output['evolution_history']
    
    for step in range(min(3, len(history['geometries']))):
        geometry = history['geometries'][step]
        det = torch.abs(geometry['determinant'])
        print(f"  Step {step+1}: Determinant magnitude = {det:.3f}")
    
    # Test state transitions and trajectory
    print("\nState transition analysis:")
    final_states = output['final_state']
    for i in range(min(batch_size, 3)):
        most_likely = final_states[i].argmax().item()
        confidence = final_states[i].max().item()
        print(f"  Sample {i+1}: Most likely state = {most_likely}, Confidence = {confidence:.3f}")
    
    # Test complex state positions
    print("\nComplex state space analysis:")
    state_positions = system.markov_chain.state_positions
    print("  - State positions in complex plane:")
    for i in range(min(4, num_states)):
        pos = state_positions[i]
        real, imag = pos.real.item(), pos.imag.item()
        magnitude = torch.abs(pos).item()
        # Print "a - bi" rather than "a + -bi" for negative imaginary parts.
        sign = '+' if imag >= 0 else '-'
        print(f"    State {i}: {real:.3f} {sign} {abs(imag):.3f}i (|z| = {magnitude:.3f})")
    
    # Test sequence prediction
    print("\n Testing sequence prediction...")
    sequence_length = 3
    sequence = system.predict_sequence(initial_state[:1], sequence_length)
    
    print(f"Generated sequence of length {sequence_length}:")
    for t in range(sequence_length):
        most_likely = sequence[t, 0].argmax().item()
        confidence = sequence[t, 0].max().item()
        print(f"  Time {t+1}: State {most_likely} (confidence: {confidence:.3f})")
    
    # System information and diagnostics
    info = system.get_system_info()
    print("\nSystem diagnostics:")
    print(f"  - Geometry evolution rate: {info['geometry_evolution_rate']:.4f}")
    print(f"  - Current determinant: {torch.abs(info['current_geometry']['determinant']):.3f}")
    
    # Test adaptive behavior with different inputs
    print("\nTesting adaptive geometry...")
    
    # Concentrated initial state
    concentrated_state = torch.zeros(1, num_states)
    concentrated_state[0, 0] = 1.0
    conc_output = system(concentrated_state)
    
    # Uniform initial state  
    uniform_state = torch.ones(1, num_states) / num_states
    uniform_output = system(uniform_state)
    
    conc_det = torch.abs(conc_output['final_geometry']['determinant'])
    uniform_det = torch.abs(uniform_output['final_geometry']['determinant'])
    
    print(f"  - Concentrated input → final determinant: {conc_det:.4f}")
    print(f"  - Uniform input → final determinant: {uniform_det:.4f}")
    print(f"  - Geometry adaptation difference: {abs(conc_det - uniform_det):.4f}")
    
    print("\n Möbius-Markov test completed!")
    print("✓ Non-Euclidean state space with dynamic geometry")
    print("✓ Markov transitions in continuously warped space")  
    print("✓ Learnable Möbius transformations with invertibility")
    print("✓ State-dependent geometric evolution")
    print("✓ Complex plane representations and distance-based transitions")
    print("✓ Adaptive spatial structure for different input patterns")
    
    return True

def visualization_demo():
    """Print how a small system's Möbius transform warps its state positions.
    
    First prints the initial complex positions of the Markov states. Then,
    for three rounds, it prints the current determinant magnitude and where
    the first three *initial* positions land under the current transform.
    One full ``system(...)`` call runs between rounds, which updates the
    geometry in place.
    """
    print("\n" + "="*60)
    print(" GEOMETRIC TRANSFORMATION DEMO")
    print("="*60)
    
    # Create simple system for clear visualization
    num_states = 6
    system = MobiusMarkovSystem(num_states=num_states, evolution_steps=3)
    
    # Get initial state positions in complex plane
    initial_positions = system.markov_chain.state_positions.detach()
    print("Initial state positions (complex plane):")
    for i, pos in enumerate(initial_positions):
        real, imag = pos.real.item(), pos.imag.item()
        magnitude = torch.abs(pos).item()
        angle = math.degrees(torch.angle(pos).item())
        # Print "a - bi" rather than "a + -bi" for negative imaginary parts.
        sign = '+' if imag >= 0 else '-'
        print(f"  State {i}: {real:.3f} {sign} {abs(imag):.3f}i (r={magnitude:.3f}, θ={angle:.1f}°)")
    
    # Apply several geometric transformations
    print("\nApplying geometric transformations...")
    
    test_state = torch.zeros(1, num_states)
    test_state[0, 0] = 1.0
    
    for step in range(3):
        # Get current geometry
        geometry = system.mobius_transform.get_transform_info()
        
        # Transform the original positions to show the current warping
        transformed_pos = system.mobius_transform.transform(initial_positions)
        
        print(f"\nStep {step+1}:")
        print(f"  Transform determinant: {torch.abs(geometry['determinant']):.3f}")
        
        # Show how first few states are transformed. The '+' format flag
        # emits the sign of the imaginary part, avoiding "1.00+-0.50i".
        for i in range(min(3, len(transformed_pos))):
            orig_re = initial_positions[i].real.item()
            orig_im = initial_positions[i].imag.item()
            trans_re = transformed_pos[i].real.item()
            trans_im = transformed_pos[i].imag.item()
            print(f"    State {i}: {orig_re:.2f}{orig_im:+.2f}i → {trans_re:.2f}{trans_im:+.2f}i")
        
        # Evolve the system (updates the Möbius parameters in place)
        output = system(test_state)
        test_state = output['final_state']
    
    print("\n Geometric evolution creates rich, non-Euclidean probabilistic dynamics!")
    print("   State space continuously warps based on system trajectory")
    print("   Distance-based transitions adapt to transformed geometry")

if __name__ == "__main__":
    # Smoke test first, then the geometry walkthrough; return values unused.
    for demo in (test_mobius_markov, visualization_demo):
        demo()

###########################################################################################################################################
###########################################################################################################################################