File size: 5,931 Bytes
3bf8430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import random
from typing import Optional
from ...environment import VerifiableEnvironment


class CantorExpansion_Environment(VerifiableEnvironment) : # Source : https://www.luogu.com.cn/problem/P3477
    """
    Verifiable task: given a sequence A and a modulus MOD, count the distinct
    permutations of A that are lexicographically smaller than A, modulo MOD.
    Permutations that differ only by the positions of equal elements count once.
    """
    prompt_template = \
r"""Given a sequence of integers: {A}

Please count the number of distinct permutations of this sequence that are **lexicographically smaller** than the original sequence. Output a single integer — the number of such permutations modulo {MOD}.
Note: Permutations that only differ by the positions of equal elements are considered the **same**."""

    def __init__(self,
                 max_MOD : int = 100000,
                 wrong_format : float = -1.0, wrong_range : float = -0.5, correct_answer : float = +1.0, wrong_answer : float = 0.0,
                 **kwargs) :
        """
        Initialize the CantorExpansion_Environment instance.

        Args:
            max_MOD: inclusive upper bound for the random modulus, which is drawn from [2, max_MOD].
            wrong_format: reward when the output cannot be parsed as an integer.
            wrong_range: reward when the parsed integer is outside [0, MOD).
            correct_answer: reward when the parsed integer equals the reference answer.
            wrong_answer: reward for an in-range integer that is not the reference answer.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        super().__init__(**kwargs)

        self.max_MOD = max_MOD
        self.rewards = {
            "wrong_format" : wrong_format,
            "wrong_range" : wrong_range,
            "correct_answer" : correct_answer,
            "wrong_answer" : wrong_answer,
        }


    @staticmethod
    def _factorize_modulus(MOD : int) -> tuple :
        """
        Return (primes, phi): the distinct prime factors of MOD in increasing
        order, and Euler's totient phi(MOD). Uses plain trial division.
        """
        primes = []
        phi = MOD
        rest = MOD
        p = 2
        while p * p <= rest :
            if rest % p == 0 :
                primes.append(p)
                phi = phi // p * (p - 1)
                while rest % p == 0 :
                    rest //= p
            p += 1
        if rest > 1 :
            # Whatever survives trial division up to sqrt(rest) is itself prime.
            primes.append(rest)
            phi = phi // rest * (rest - 1)
        return primes, phi


    @classmethod
    def _count_smaller_permutations(cls, A : list, MOD : int) -> int :
        """
        Count the distinct permutations of the multiset A that are
        lexicographically smaller than A, modulo MOD (MOD need not be prime).

        Scanning right to left, let S = A[idx:]. The number of arrangements of S
        that start with a value smaller than A[idx] is

            w * (|S| - 1)! / prod_v cnt_S[v]!

        where w is the number of elements of A[idx + 1:] strictly smaller than
        A[idx]. Because MOD may be composite, the running quotient
        (|S| - 1)! / prod_v cnt_S[v]! is kept as a part coprime to MOD (`unit`,
        which can be divided using Euler's theorem) times explicit exponents of
        MOD's prime factors (`exponents`).

        Values of A are used directly as 1-based Fenwick-tree indices, so they are
        assumed to be positive integers (as produced by `_generate`).
        Returns 0 for sequences with fewer than two elements.
        """
        N = len(A)
        if N < 2 :
            return 0

        primes, phi = cls._factorize_modulus(MOD)

        def split(x : int) -> tuple :
            # Separate x >= 1 into (part coprime to MOD, exponent of each prime of MOD in x).
            exps = [0] * len(primes)
            for j, p in enumerate(primes) :
                while x % p == 0 :
                    x //= p
                    exps[j] += 1
            return x, exps

        max_value = max(A)
        tree = [0] * (max_value + 1)   # Fenwick tree over values 1..max_value

        def tree_add(x : int) -> None :
            while x <= max_value :
                tree[x] += 1
                x += x & -x

        def tree_prefix(x : int) -> int :
            # Number of inserted values that are <= x.
            total = 0
            while x > 0 :
                total += tree[x]
                x -= x & -x
            return total

        count = [0] * (max_value + 1)  # multiplicity of each value in the current suffix
        exponents = [0] * len(primes)  # exponents of MOD's primes in the running quotient
        unit = 1                       # part of the running quotient coprime to MOD
        answer = 0

        # The one-element suffix has 0! / 1! = 1 arrangement.
        tree_add(A[-1])
        count[A[-1]] += 1

        for idx in range(N - 2, -1, -1) :
            value = A[idx]
            smaller = tree_prefix(value - 1)

            # Extend (|S| - 1)! by the factor k = |S| - 1 for the new, longer suffix.
            k = (N - 1) - idx
            k_unit, k_exps = split(k)
            unit = unit * k_unit % MOD
            for j, e in enumerate(k_exps) :
                exponents[j] += e

            # The multiplicity of `value` grows to c, so divide by c (c! = (c - 1)! * c).
            tree_add(value)
            count[value] += 1
            c_unit, c_exps = split(count[value])
            if c_unit != 1 :
                # c_unit is coprime to MOD, so c_unit^(phi - 1) is its inverse (Euler).
                unit = unit * pow(c_unit, phi - 1, MOD) % MOD
            for j, e in enumerate(c_exps) :
                exponents[j] -= e

            if smaller > 0 :
                # Contribution is smaller * quotient. Fold it in locally without touching
                # the running state. The product is an integer, so every combined exponent
                # is non-negative here even if the quotient's own exponents are not.
                w_unit, w_exps = split(smaller)
                current = unit * w_unit % MOD
                for j, p in enumerate(primes) :
                    e = exponents[j] + w_exps[j]
                    if e :
                        current = current * pow(p, e, MOD) % MOD
                answer = (answer + current) % MOD

        return answer


    def _generate(self) -> None :
        """
        Draw a random instance and store "A", "MOD" and "reference_answer" in
        self.parameter. Requires self.parameter["N"] >= 3.

        A has N entries drawn uniformly from [1, value_range], where value_range
        is itself drawn from [2, N]. MOD is drawn from [2, self.max_MOD].
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 3, "N should be greater than or equal to 3"

        value_range = random.randint(2, N)
        A = self.parameter["A"] = [random.randint(1, value_range) for _ in range(N)]
        MOD = self.parameter["MOD"] = random.randint(2, self.max_MOD)

        self.parameter["reference_answer"] = self._count_smaller_permutations(A, MOD)


    def _prompt_generate(self) -> str :
        """Render the prompt with A as a comma-separated list and the drawn modulus."""
        return self.prompt_template.format(A = ", ".join(map(str, self.parameter["A"])), MOD = self.parameter["MOD"])


    def _process(self, answer : Optional[str]) -> Optional[int] :
        """Parse the stripped answer as an integer; return None if it is missing or unparsable."""
        if answer is None :
            return None
        try :
            return int(answer.strip())
        except ValueError :
            return None


    def scorer(self, output : str) -> float :
        """
        Score a model output. Unparsable output gets "wrong_format", and an integer
        outside [0, MOD) gets "wrong_range". Otherwise the reward is
        "correct_answer" or "wrong_answer" depending on whether the integer equals
        the reference answer.
        """
        processed_result = self.processor(output)
        if processed_result is not None :
            if not (0 <= processed_result < self.parameter["MOD"]) :
                return self.rewards["wrong_range"]
            if processed_result == self.parameter["reference_answer"] :
                return self.rewards["correct_answer"]
            else :
                return self.rewards["wrong_answer"]
        else :
            return self.rewards["wrong_format"]