Spaces:
Running
Running
| import random | |
| from typing import Optional | |
| from ...environment import VerifiableEnvironment | |
class CantorExpansion_Environment(VerifiableEnvironment) : # Source : https://www.luogu.com.cn/problem/P3477
    """
    Environment asking for the rank of a multiset permutation modulo MOD.

    An instance consists of a random integer sequence A and a random modulus
    MOD. The expected answer is the number of distinct permutations of A that
    are lexicographically smaller than A, reduced modulo MOD. MOD is an
    arbitrary integer >= 2 (not necessarily prime), so the reference solution
    tracks the prime factors of MOD separately from the part that is
    invertible modulo MOD.
    """

    prompt_template = \
r"""Given a sequence of integers: {A}
Please count the number of distinct permutations of this sequence that are **lexicographically smaller** than the original sequence. Output a single integer — the number of such permutations modulo {MOD}.
Note: Permutations that only differ by the positions of equal elements are considered the **same**."""

    def __init__(self,
                 max_MOD : int = 100000,
                 wrong_format : float = -1.0, wrong_range : float = -0.5, correct_answer : float = +1.0, wrong_answer : float = 0.0,
                 **kwargs) :
        """
        Initialize the CantorExpansion_Environment instance.

        Args:
            max_MOD: Inclusive upper bound for the randomly drawn modulus;
                the modulus is drawn from [2, max_MOD].
            wrong_format: Reward when the output cannot be parsed as an integer.
            wrong_range: Reward when the parsed integer is outside [0, MOD).
            correct_answer: Reward when the parsed integer equals the reference.
            wrong_answer: Reward when the parsed integer is in range but wrong.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(**kwargs)

        self.max_MOD = max_MOD
        self.rewards = {
            "wrong_format" : wrong_format,
            "wrong_range" : wrong_range,
            "correct_answer" : correct_answer,
            "wrong_answer" : wrong_answer,
        }

    @staticmethod
    def _factor_modulus(modulus : int) -> tuple :
        """Return (distinct prime factors of modulus in increasing order, Euler's totient of modulus)."""
        primes = []
        phi = modulus
        remaining = modulus
        divisor = 2
        while divisor * divisor <= remaining:
            if remaining % divisor == 0:
                primes.append(divisor)
                phi = phi // divisor * (divisor - 1)
                while remaining % divisor == 0:
                    remaining //= divisor
            divisor += 1
        if remaining > 1:
            # Whatever is left after trial division is a single prime factor.
            primes.append(remaining)
            phi = phi // remaining * (remaining - 1)
        return primes, phi

    @classmethod
    def _count_smaller_permutations(cls, A : list, MOD : int) -> int :
        """
        Count distinct permutations of A that are lexicographically smaller than A, modulo MOD.

        Scanning positions from right to left, the quantity
        k! / prod(cnt[v]!) is maintained, where k is the number of elements
        strictly to the right of the current position and cnt[v] counts the
        occurrences of v from the current position to the end. Multiplying it
        by w (the number of elements to the right that are strictly smaller
        than the current one) gives the number of smaller permutations that
        first differ at the current position.

        Because MOD may be composite, that quantity is stored as
        coprime_part * prod(p_j ** exponents[j]), with p_j the primes of MOD:
        coprime_part is invertible modulo MOD (inverses via Euler's theorem),
        while the prime powers are tracked as plain integer exponents.

        Args:
            A: Sequence of positive integers (values index a Fenwick tree, so
                they must be >= 1).
            MOD: Modulus, expected to be >= 2.

        Returns:
            The count reduced modulo MOD; 0 for an empty sequence.
        """
        N = len(A)
        if N == 0:
            return 0

        primes, phi = cls._factor_modulus(MOD)
        top = max(A)

        # Fenwick tree over values 1..top: how many seen elements are <= x.
        tree = [0] * (top + 1)

        def bit_add(x):
            while x <= top:
                tree[x] += 1
                x += x & -x

        def bit_sum(x):
            total = 0
            while x > 0:
                total += tree[x]
                x -= x & -x
            return total

        exponents = [0] * len(primes)   # exponent of each prime of MOD in the tracked quantity
        inverse_cache = {1 : 1}         # inverses of values coprime to MOD, filled on demand

        def strip_primes(value, sign):
            """Move every prime-of-MOD factor of value into exponents (times sign); return the coprime rest."""
            for j, p in enumerate(primes):
                while value % p == 0:
                    value //= p
                    exponents[j] += sign
            return value

        def inverse(value):
            """Inverse of value modulo MOD via Euler's theorem; value must be coprime to MOD."""
            if value not in inverse_cache:
                inverse_cache[value] = pow(value, phi - 1, MOD)
            return inverse_cache[value]

        coprime_part = 1
        counts = [0] * (top + 1)        # occurrences of each value in the processed suffix
        answer = 0

        # Seed with the last element: the tracked quantity is 0! / 1! = 1.
        bit_add(A[N - 1])
        counts[A[N - 1]] += 1

        for idx in range(N - 2, -1, -1):
            value = A[idx]
            # Elements to the right of idx that are strictly smaller than A[idx].
            smaller = bit_sum(value - 1)

            # (k - 1)! -> k!, where k is the number of elements right of idx.
            k = (N - 1) - idx
            coprime_part = coprime_part * strip_primes(k, +1) % MOD

            # Bring A[idx] into the suffix and divide by its new multiplicity.
            bit_add(value)
            counts[value] += 1
            coprime_part = coprime_part * inverse(strip_primes(counts[value], -1)) % MOD

            if smaller > 0:
                # Temporarily multiply by `smaller`. The product is a sum of
                # multinomial coefficients, hence an integer, so every
                # exponent is non-negative at this point.
                coprime_part = coprime_part * strip_primes(smaller, +1) % MOD

                term = coprime_part
                for p, e in zip(primes, exponents):
                    if e:
                        term = term * pow(p, e, MOD) % MOD
                answer = (answer + term) % MOD

                # Undo the temporary factor to restore k! / prod(cnt!).
                coprime_part = coprime_part * inverse(strip_primes(smaller, -1)) % MOD

        return answer % MOD

    def _generate(self) -> None :
        """
        Sample a random instance and store it, with its reference answer, in self.parameter.

        Reads self.parameter["N"] (sequence length, must be >= 3) and writes
        "A", "MOD" and "reference_answer". self.parameter is not defined in
        this class; presumably the base class provides it -- confirm.
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 3, "N should be greater than or equal to 3"
        # Without this check random.randint below fails with an opaque empty-range error.
        assert self.max_MOD >= 2, "max_MOD should be greater than or equal to 2"

        # Order of the random draws (M, then A, then MOD) is kept stable so
        # that seeded generation stays reproducible.
        M = random.randint(2, N)
        A = self.parameter["A"] = [random.randint(1, M) for _ in range(N)]
        MOD = self.parameter["MOD"] = random.randint(2, self.max_MOD)

        self.parameter["reference_answer"] = self._count_smaller_permutations(A, MOD)

    def _prompt_generate(self) -> str :
        """Render the prompt with the stored sequence (comma-separated) and modulus."""
        return self.prompt_template.format(A = ", ".join(map(str, self.parameter["A"])), MOD = self.parameter["MOD"])

    def _process(self, answer : Optional[str]) -> Optional[int] :
        """Parse the answer text as an integer; return None if it is missing or not an integer."""
        if answer is None :
            return None
        answer = answer.strip()
        try :
            return int(answer)
        except ValueError :
            return None

    def scorer(self, output : str) -> float :
        """
        Score a raw model output against the stored reference answer.

        self.processor is not defined in this class; presumably the base
        class provides it and it yields the value parsed by _process -- confirm.

        Returns:
            rewards["wrong_format"] if no integer could be extracted,
            rewards["wrong_range"] if the integer is outside [0, MOD),
            rewards["correct_answer"] if it equals the reference answer,
            rewards["wrong_answer"] otherwise.
        """
        processed_result = self.processor(output)
        if processed_result is None :
            return self.rewards["wrong_format"]
        if not (0 <= processed_result < self.parameter["MOD"]) :
            return self.rewards["wrong_range"]
        if processed_result == self.parameter["reference_answer"] :
            return self.rewards["correct_answer"]
        return self.rewards["wrong_answer"]