Spaces:
Running
Running
File size: 5,931 Bytes
3bf8430 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import random
from typing import Optional
from ...environment import VerifiableEnvironment
class CantorExpansion_Environment(VerifiableEnvironment) : # Source : https://www.luogu.com.cn/problem/P3477
    """
    Environment asking for the lexicographic rank of a multiset permutation.

    An instance is a random integer sequence ``A`` together with a random
    modulus ``MOD``; the expected answer is the number of distinct
    permutations of ``A`` that are lexicographically smaller than ``A``,
    reduced modulo ``MOD``.
    """

    prompt_template = \
r"""Given a sequence of integers: {A}
Please count the number of distinct permutations of this sequence that are **lexicographically smaller** than the original sequence. Output a single integer — the number of such permutations modulo {MOD}.
Note: Permutations that only differ by the positions of equal elements are considered the **same**."""

    def __init__(self,
                 max_MOD : int = 100000,
                 wrong_format : float = -1.0, wrong_range : float = -0.5, correct_answer : float = +1.0, wrong_answer : float = 0.0,
                 **kwargs) :
        """
        Initialize the CantorExpansion_Environment instance.

        Args:
            max_MOD: inclusive upper bound for the randomly drawn modulus
                (the lower bound used by ``_generate`` is 2).
            wrong_format: reward when the output cannot be parsed as an integer.
            wrong_range: reward when the parsed integer is outside ``[0, MOD)``.
            correct_answer: reward when the parsed integer equals the reference answer.
            wrong_answer: reward for an in-range integer that differs from the reference answer.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        super().__init__(**kwargs)

        self.max_MOD = max_MOD
        self.rewards = {
            "wrong_format" : wrong_format,
            "wrong_range" : wrong_range,
            "correct_answer" : correct_answer,
            "wrong_answer" : wrong_answer,
        }

    def _generate(self) -> None :
        """
        Sample a sequence and a modulus, then store the reference answer.

        Reads ``self.parameter["N"]`` (sequence length, at least 3) and writes
        ``"A"``, ``"MOD"`` and ``"reference_answer"`` into ``self.parameter``.
        NOTE(review): ``self.parameter`` is presumably provided by the base
        class before this method runs -- confirm against VerifiableEnvironment.
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 3, "N should be greater than or equal to 3"
        # random.randint(2, max_MOD) below needs a non-empty range; fail with a clear message.
        assert self.max_MOD >= 2, "max_MOD should be greater than or equal to 2"

        # The order of the random draws is kept as-is so seeded runs stay reproducible.
        M = random.randint(2, N)
        A = self.parameter["A"] = [random.randint(1, M) for _ in range(N)]
        MOD = self.parameter["MOD"] = random.randint(2, self.max_MOD)

        self.parameter["reference_answer"] = self._count_smaller_permutations(A, MOD)

    @staticmethod
    def _count_smaller_permutations(A : list, MOD : int) -> int :
        """
        Count distinct permutations of ``A`` lexicographically smaller than ``A``, modulo ``MOD``.

        ``MOD`` may be composite, so plain modular division is unavailable.
        The running quantity ``(suffix length - 1)! / prod(cnt[v]!)`` is kept
        as ``coprime_part * prod(p_j ** exponents[j])``, where ``p_j`` are the
        distinct primes of ``MOD``: factors sharing a prime with ``MOD`` only
        move the exponents, the remaining factor is multiplied (or divided,
        via an Euler-theorem inverse) modulo ``MOD``.

        Args:
            A: sequence of integers; every value must be >= 1 because the
               values are used directly as 1-based Fenwick-tree indices.
            MOD: modulus (``_generate`` always passes a value >= 2).

        Returns:
            The count reduced into ``[0, MOD)``; 0 when ``A`` has fewer than two elements.
        """
        N = len(A)
        if N < 2 :
            return 0
        M = max(A)

        # 1. Distinct prime factors of MOD and Euler's totient phi(MOD).
        phi = MOD
        remaining = MOD
        primes = []
        candidate = 2
        while candidate * candidate <= remaining :
            if remaining % candidate == 0 :
                primes.append(candidate)
                phi = phi // candidate * (candidate - 1)
                while remaining % candidate == 0 :
                    remaining //= candidate
            candidate += 1
        if remaining > 1 :
            primes.append(remaining)
            phi = phi // remaining * (remaining - 1)

        # 2. Fenwick tree over values 1..M counting the elements already seen to the right.
        tree = [0] * (M + 1)

        def bit_add(x : int) -> None :
            while x <= M :
                tree[x] += 1
                x += x & -x

        def bit_sum(x : int) -> int :
            total = 0
            while x > 0 :
                total += tree[x]
                x -= x & -x
            return total

        # 3. Multiplicative state: exponents of MOD's primes plus the co-prime remainder.
        exponents = [0] * len(primes)

        def strip_primes(x : int, sign : int) -> int :
            """Remove MOD's primes from x, adding sign per removed factor to exponents; return the co-prime rest."""
            for j, pj in enumerate(primes) :
                while x % pj == 0 :
                    x //= pj
                    exponents[j] += sign
            return x

        inverse_cache = {1 : 1}

        def inverse(x : int) -> int :
            """Modular inverse of x (co-prime to MOD) via Euler's theorem: x ** (phi - 1)."""
            if x not in inverse_cache :
                inverse_cache[x] = pow(x, phi - 1, MOD)
            return inverse_cache[x]

        coprime_part = 1
        cnt = [0] * (M + 1)  # multiplicity of each value inside the current suffix
        ans = 0

        # Seed with the last element: the one-element suffix contributes 0! / 1! = 1.
        bit_add(A[N - 1])
        cnt[A[N - 1]] += 1

        # Process positions from right to left.
        for idx in range(N - 2, -1, -1) :
            value = A[idx]
            # Elements strictly to the right that are strictly smaller than A[idx].
            smaller = bit_sum(value - 1)

            # Numerator grows from (k - 1)! to k!, where k is the number of elements to the right.
            coprime_part = coprime_part * strip_primes((N - 1) - idx, +1) % MOD

            # Insert A[idx] into the suffix; cnt[value]! gains the factor cnt[value] in the denominator.
            bit_add(value)
            cnt[value] += 1
            coprime_part = coprime_part * inverse(strip_primes(cnt[value], -1)) % MOD

            if smaller > 0 :
                # smaller * k! / prod(cnt[v]!) permutations put a smaller value at idx.
                # This product is an integer, so every exponent is non-negative here.
                term = coprime_part * strip_primes(smaller, +1) % MOD
                for pj, exponent in zip(primes, exponents) :
                    if exponent :
                        term = term * pow(pj, exponent, MOD) % MOD
                ans = (ans + term) % MOD
                # Undo the temporary factor `smaller` on the exponents (coprime_part was never touched).
                strip_primes(smaller, -1)

        return ans % MOD

    def _prompt_generate(self) -> str :
        """Render the prompt for the current instance from ``A`` and ``MOD``."""
        return self.prompt_template.format(A = ", ".join(map(str, self.parameter["A"])), MOD = self.parameter["MOD"])

    def _process(self, answer : Optional[str]) -> Optional[int] :
        """
        Parse a raw answer string into an integer.

        Returns:
            The parsed integer, or ``None`` when ``answer`` is ``None`` or is
            not a valid integer literal after stripping whitespace.
        """
        if answer is None :
            return None
        try :
            return int(answer.strip())
        except ValueError :
            return None

    def scorer(self, output : str) -> float :
        """
        Score a model output against the stored reference answer.

        NOTE(review): ``self.processor`` is not defined in this class; it
        presumably comes from the base class and ends up calling ``_process``
        -- confirm against VerifiableEnvironment.

        Returns:
            ``rewards["wrong_format"]`` if nothing could be parsed,
            ``rewards["wrong_range"]`` if the value is outside ``[0, MOD)``,
            ``rewards["correct_answer"]`` on an exact match, and
            ``rewards["wrong_answer"]`` otherwise.
        """
        processed_result = self.processor(output)
        if processed_result is None :
            return self.rewards["wrong_format"]
        if not (0 <= processed_result < self.parameter["MOD"]) :
            return self.rewards["wrong_range"]
        if processed_result == self.parameter["reference_answer"] :
            return self.rewards["correct_answer"]
        return self.rewards["wrong_answer"]