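# NOTE: the three lines below appear to be page-header residue from the hosting
# site's file viewer rather than program text; kept here as comments.
# ZhiyuanZeng's picture
# Upload folder using huggingface_hub
# 3bf8430 verified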
import random
from typing import Optional
from ...environment import VerifiableEnvironment
class CantorExpansion_Environment(VerifiableEnvironment) : # Source : https://www.luogu.com.cn/problem/P3477
    """
    Environment that asks for the number of distinct permutations of a multiset
    that are lexicographically smaller than a given sequence, modulo MOD.

    `_generate` draws the sequence and the modulus and stores the reference
    answer; `scorer` compares a parsed integer answer against it.
    """

    prompt_template = \
r"""Given a sequence of integers: {A}
Please count the number of distinct permutations of this sequence that are **lexicographically smaller** than the original sequence. Output a single integer — the number of such permutations modulo {MOD}.
Note: Permutations that only differ by the positions of equal elements are considered the **same**."""

    def __init__(self,
                 max_MOD : int = 100000,
                 wrong_format : float = -1.0, wrong_range : float = -0.5, correct_answer : float = +1.0, wrong_answer : float = 0.0,
                 **kwargs) :
        """
        Set up the environment.

        Args:
            max_MOD: inclusive upper bound for the randomly drawn modulus.
            wrong_format: reward returned when the output cannot be parsed.
            wrong_range: reward returned when the parsed answer is outside [0, MOD).
            correct_answer: reward returned when the answer equals the reference.
            wrong_answer: reward returned for an in-range but incorrect answer.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)

        self.max_MOD = max_MOD
        self.rewards = {
            "wrong_format" : wrong_format,
            "wrong_range" : wrong_range,
            "correct_answer" : correct_answer,
            "wrong_answer" : wrong_answer,
        }

    def _generate(self) -> None :
        """
        Draw a random instance and compute its reference answer.

        Reads ``self.parameter["N"]`` (must be >= 3) and writes ``"A"``, ``"MOD"``
        and ``"reference_answer"`` into ``self.parameter``.
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 3, "N should be greater than or equal to 3"

        # Random draws happen in this exact order: value cap, the N elements, then MOD.
        value_cap = random.randint(2, N)
        A = [random.randint(1, value_cap) for _ in range(N)]
        self.parameter["A"] = A
        MOD = random.randint(2, self.max_MOD)
        self.parameter["MOD"] = MOD

        # Largest value actually present; sizes the Fenwick tree and the counters.
        top = max(A)

        # --- Distinct prime factors of MOD and Euler's totient phi(MOD) ---
        totient = MOD
        rest = MOD
        primes = []
        divisor = 2
        while divisor * divisor <= rest:
            if rest % divisor == 0:
                primes.append(divisor)
                totient = totient // divisor * (divisor - 1)
                while rest % divisor == 0:
                    rest //= divisor
            divisor += 1
        if rest > 1:
            # Whatever is left after trial division is a single prime factor.
            primes.append(rest)
            totient = totient // rest * (rest - 1)

        # --- Fenwick tree over values 1..top counting the elements seen so far ---
        tree = [0] * (top + 1)

        def insert(pos):
            while pos <= top:
                tree[pos] += 1
                pos += pos & -pos

        def prefix_count(pos):
            total = 0
            while pos > 0:
                total += tree[pos]
                pos -= pos & -pos
            return total

        # --- Running value  coprime * prod(primes[j] ** exponents[j])  (mod MOD) ---
        # inverse[x] is pow(x, phi(MOD) - 1, MOD); it is only looked up for x coprime to MOD.
        inverse = [0] * (N + 2)
        inverse[1] = 1
        exponents = [0] * len(primes)
        coprime = 1
        occurrences = [0] * (top + 1)
        ans = 0

        def absorb(value, sign):
            """Move every factor of MOD's primes out of `value` into `exponents`
            (adding `sign` per factor) and return what remains of `value`."""
            for slot, p in enumerate(primes):
                while value % p == 0:
                    value //= p
                    exponents[slot] += sign
            return value

        # The last element forms a suffix of length one.
        insert(A[N - 1])
        occurrences[A[N - 1]] += 1

        # Walk the remaining positions from right to left.
        for idx in reversed(range(N - 1)):
            current = A[idx]

            # Elements to the right of idx that are strictly smaller than A[idx].
            smaller = prefix_count(current - 1)

            # Multiply by k, the number of elements to the right of idx.
            k = (N - 1) - idx
            coprime = coprime * absorb(k, +1) % MOD

            # Register the element, then divide by its updated multiplicity.
            insert(current)
            inverse[k + 1] = pow(k + 1, totient - 1, MOD)
            occurrences[current] += 1
            coprime = coprime * inverse[absorb(occurrences[current], -1)] % MOD

            if smaller > 0:
                # Temporarily scale the running value by `smaller` ...
                coprime = coprime * absorb(smaller, +1) % MOD

                # ... materialise it modulo MOD and add it to the answer ...
                term = coprime
                for p, e in zip(primes, exponents):
                    if e:
                        term = term * pow(p, e, MOD) % MOD
                ans = (ans + term) % MOD

                # ... then undo the scaling so the state is as before.
                coprime = coprime * inverse[absorb(smaller, -1)] % MOD

        self.parameter["reference_answer"] = ans % MOD

    def _prompt_generate(self) -> str :
        """Fill the prompt template with the comma-separated sequence and the modulus."""
        sequence_text = ", ".join(map(str, self.parameter["A"]))
        return self.prompt_template.format(A = sequence_text, MOD = self.parameter["MOD"])

    def _process(self, answer : Optional[str]) -> Optional[int] :
        """Parse `answer` as an integer; return None if it is None or not a valid integer literal."""
        if answer is None :
            return None

        answer = answer.strip()
        try :
            return int(answer)
        except ValueError :
            return None

    def scorer(self, output : str) -> float :
        """
        Score a raw model output.

        Returns the "wrong_format" reward when the output cannot be processed,
        "wrong_range" when the value is outside [0, MOD), and otherwise
        "correct_answer" or "wrong_answer" depending on equality with the
        stored reference answer.
        """
        parsed = self.processor(output)
        if parsed is None :
            return self.rewards["wrong_format"]

        if not (0 <= parsed < self.parameter["MOD"]) :
            return self.rewards["wrong_range"]

        if parsed == self.parameter["reference_answer"] :
            return self.rewards["correct_answer"]
        return self.rewards["wrong_answer"]