# BeatNet / scripts / patch_madmom.py
# Author: ellagranger
# Commit 384d531: "Docker output fix & BeatNet fix"
# scripts/patch_madmom.py
import re
import site
from pathlib import Path
# Exact ``__future__`` import statement that is (re)inserted into
# madmom/utils/__init__.py by this script.
FUTURE_LINE = "from __future__ import absolute_import, division, print_function"

# (compiled pattern, replacement) pairs.  Each pattern matches the bare
# ``np.<name>`` alias as a whole word (so ``np.float32`` or ``np.int_`` are
# left alone) and is replaced by the builtin of the same name.
ALIASES = [
    (re.compile(r"\bnp\." + builtin_name + r"\b"), builtin_name)
    for builtin_name in ("float", "int", "bool", "object", "complex")
]
def find_madmom_root() -> Path:
    """Locate the directory of the installed ``madmom`` package.

    The global site-packages directories are searched first.  If none of
    them contains a ``madmom`` directory, the import system is asked where
    it would load ``madmom`` from; this also covers per-user installs
    (``pip install --user``, pip's default when site-packages is not
    writable), editable installs and ``PYTHONPATH`` entries.

    Returns:
        Path to the ``madmom`` package directory.

    Raises:
        RuntimeError: if ``madmom`` cannot be located.
    """
    # Local import: only this function needs the import machinery.
    import importlib.util

    for site_dir in site.getsitepackages():
        cand = Path(site_dir) / "madmom"
        # is_dir(): a stray *file* called "madmom" is not the package.
        if cand.is_dir():
            return cand

    # Fallback: resolve the package the way ``import madmom`` would.
    # find_spec() on a top-level name does not execute the package.
    try:
        spec = importlib.util.find_spec("madmom")
    except (ImportError, ValueError):
        spec = None
    if spec is not None and spec.submodule_search_locations:
        locations = list(spec.submodule_search_locations)
        if locations:
            return Path(locations[0])

    raise RuntimeError("madmom package not found in site-packages")
def patch_processors(madmom_root: Path) -> None:
    """Point madmom's ``MutableSequence`` import at ``collections.abc``.

    Rewrites ``<madmom_root>/processors.py`` in place, replacing
    ``from collections import MutableSequence`` with the
    ``collections.abc`` form.  Nothing happens if the file is missing, and
    the file is only written when the text actually changed.

    Args:
        madmom_root: directory of the installed ``madmom`` package.
    """
    target = madmom_root / "processors.py"
    if target.exists():
        original = target.read_text(encoding="utf-8")
        updated = original.replace(
            "from collections import MutableSequence",
            "from collections.abc import MutableSequence",
        )
        if updated != original:
            target.write_text(updated, encoding="utf-8")
def patch_numpy_aliases(madmom_root: Path) -> None:
    """Replace ``np.float``-style aliases with builtins across madmom.

    Every ``*.py`` file below *madmom_root* is read as UTF-8, all
    substitutions from ``ALIASES`` are applied in order, and the file is
    written back only if its text changed.  Files that cannot be read are
    skipped.  A summary line with the number of rewritten files is printed.

    Args:
        madmom_root: directory of the installed ``madmom`` package.
    """
    changed_files = 0
    for source_file in madmom_root.rglob("*.py"):
        try:
            before = source_file.read_text(encoding="utf-8")
        except Exception:
            # Best effort: an unreadable or undecodable file is left as is.
            continue

        after = before
        for pattern, builtin_name in ALIASES:
            after = pattern.sub(builtin_name, after)

        if after == before:
            continue
        source_file.write_text(after, encoding="utf-8")
        changed_files += 1

    print(f"Patched NumPy aliases in {changed_files} madmom files")
# Matches the upstream/previously inserted future import as a whole line,
# including a final line that has no trailing newline.  Only horizontal
# whitespace is consumed around it so neighbouring blank lines survive.
_FUTURE_IMPORT_RE = re.compile(
    r"^[ \t]*from __future__ import absolute_import,\s*division,\s*print_function"
    r"[ \t]*(?:\r?\n|\Z)",
    flags=re.M,
)
# Encoding declaration comment, e.g. ``# -*- coding: utf-8 -*-``.
_CODING_COOKIE_RE = re.compile(r"^#.*coding[:=]\s*[-\w.]+")
# Start of a triple-quoted string with an optional r/u prefix; group 1 is
# the bare quote token (without the prefix).
_DOCSTRING_OPEN_RE = re.compile(r'^\s*[rRuU]?("""|\'\'\')')


def _future_import_index(lines: list) -> int:
    """Return the index in *lines* at which the future import belongs.

    That is directly after the module docstring when there is one, otherwise
    directly after the optional shebang and encoding-cookie lines.
    """
    total = len(lines)
    header_end = 0
    if header_end < total and lines[header_end].startswith("#!"):
        header_end += 1
    if header_end < total and _CODING_COOKIE_RE.match(lines[header_end]):
        header_end += 1

    # Comments (linter pragmas, licence text) and blank lines may sit between
    # the header and the docstring; look past them for the docstring.
    doc_start = header_end
    while doc_start < total and (
        not lines[doc_start].strip() or lines[doc_start].lstrip().startswith("#")
    ):
        doc_start += 1

    opening = _DOCSTRING_OPEN_RE.match(lines[doc_start]) if doc_start < total else None
    if opening is None:
        return header_end

    quote = opening.group(1)
    # Docstring opened and closed on the same line.
    if quote in lines[doc_start][opening.end():]:
        return doc_start + 1

    # Multi-line docstring: advance to the line holding the closing quotes.
    doc_end = doc_start + 1
    while doc_end < total and quote not in lines[doc_end]:
        doc_end += 1
    return min(doc_end + 1, total)


def patch_utils_future_safe(madmom_root: Path) -> None:
    """Patch ``madmom/utils/__init__.py`` for Python 3.

    The file is rewritten in place:

    * any existing ``FUTURE_LINE``-style import is removed and a single one
      is re-inserted after the shebang/encoding header and module docstring
      (one-line, multi-line and ``r``/``u``-prefixed docstrings are
      recognised, also when comment lines precede them);
    * a small shim defining ``basestring`` and ``integer`` is added right
      after that import unless the file already contains one;
    * ``string_types = basestring`` and ``integer_types = (int, integer)``
      assignments are replaced by ``(str,)`` / ``(int,)``.

    Running the function again on an already patched file leaves it
    unchanged.

    Args:
        madmom_root: directory of the installed ``madmom`` package.

    Raises:
        RuntimeError: if ``utils/__init__.py`` does not exist.
    """
    u = madmom_root / "utils" / "__init__.py"
    if not u.exists():
        raise RuntimeError("madmom/utils/__init__.py not found")

    s = u.read_text(encoding="utf-8")

    # Remove any existing future import (we'll reinsert correctly)
    s = _FUTURE_IMPORT_RE.sub("", s)

    lines = s.splitlines(True)
    index = _future_import_index(lines)
    # If the preceding line is the unterminated last line of the file,
    # terminate it so the import does not get glued onto it.
    if index > 0 and not lines[index - 1].endswith(("\n", "\r")):
        lines[index - 1] += "\n"
    lines.insert(index, FUTURE_LINE + "\n")
    s = "".join(lines)

    compat = """
# --- Compatibility for Python 3 ---
try:
    basestring
except NameError:
    basestring = str
try:
    integer
except NameError:
    integer = int
# ---------------------------------
"""
    if "basestring = str" not in s:
        # count=1: only the import inserted above gets the shim.
        s = s.replace(FUTURE_LINE + "\n", FUTURE_LINE + "\n" + compat, 1)

    s = re.sub(r"string_types\s*=\s*basestring", "string_types = (str,)", s)
    s = re.sub(r"integer_types\s*=\s*\(int,\s*integer\)", "integer_types = (int,)", s)
    u.write_text(s, encoding="utf-8")
def main() -> None:
    """Locate the installed madmom package and apply the patches in order.

    Applies the ``processors.py`` import fix and the ``utils/__init__.py``
    compatibility patch, then prints a confirmation with the package path.
    """
    package_dir = find_madmom_root()

    patch_processors(package_dir)
    # The NumPy alias rewrite is intentionally not run at the moment:
    # patch_numpy_aliases(package_dir)
    patch_utils_future_safe(package_dir)

    print("madmom patched OK:", package_dir)


if __name__ == "__main__":
    main()