import logging
import os

import numpy as np
import torch
import torch.nn.functional as F

from onsets_and_frames.constants import (
    DTW_FACTOR,
    HOP_LENGTH,
    MAX_MIDI,
    MIN_MIDI,
    N_KEYS,
)


def cycle(iterable):
    """Yield the items of ``iterable`` forever.

    The iterable is re-iterated from scratch on every pass (nothing is
    cached), so an object that produces a different order on each iteration
    keeps doing so.

    NOTE(review): an empty iterable, or a one-shot iterator once it is
    exhausted, makes this generator spin forever without yielding.
    """
    while True:
        yield from iterable


def shift_label(label, shift):
    """Shift a label tensor along the key axis, zero-filling vacated keys.

    Args:
        label (torch.Tensor): 2-D tensor of shape (t, p); it is viewed as
            (t, p // N_KEYS, N_KEYS) for the shift.
        shift (int): Number of keys to shift by. Positive values move content
            towards higher key indices, negative towards lower ones; 0 returns
            ``label`` itself.

    Returns:
        torch.Tensor: Shifted labels with shape (t, p).
    """
    if shift == 0:
        return label
    assert len(label.shape) == 2
    t, p = label.shape
    keys, instruments = N_KEYS, p // N_KEYS
    # Allocate the padding on the label's own device so that torch.cat also
    # works for tensors that do not live on the CPU.
    label_zero_pad = torch.zeros(
        t, instruments, abs(shift), dtype=label.dtype, device=label.device
    )
    label = label.reshape(t, instruments, keys)
    if shift > 0:
        to_cat = (label_zero_pad, label[:, :, :-shift])
    else:
        to_cat = (label[:, :, -shift:], label_zero_pad)
    label = torch.cat(to_cat, dim=-1)
    return label.reshape(t, p)


def get_peaks(notes, win_size, gpu=False):
    """Mark entries that are >= all neighbours within ``win_size`` steps.

    Neighbours are taken along axis 0 (assumed to be the time axis).
    Positions that would fall outside the tensor are compared against 0.

    Args:
        notes (torch.Tensor): Input tensor; it is moved to the CPU first.
        win_size (int): Number of steps to look at in each direction.
        gpu (bool): If True, the result is moved to CUDA before returning.

    Returns:
        torch.Tensor: Boolean tensor with the same shape as ``notes``.
    """
    notes = notes.cpu()
    res = torch.ones(notes.shape, dtype=torch.bool)
    for i in range(1, win_size + 1):
        forward = torch.roll(notes, i, 0)
        forward[:i, ...] = 0  # discard values wrapped around from the end
        backward = torch.roll(notes, -i, 0)
        backward[-i:, ...] = 0  # discard values wrapped around from the start
        res = res & (notes >= forward) & (notes >= backward)
    return res.cuda() if gpu else res


def get_peaks_numpy(notes, win_size):
    """Detect peaks in a NumPy array based on a window size.

    An entry is a peak when it is >= every neighbour within ``win_size``
    steps along axis 0; positions outside the array are compared against 0.

    Args:
        notes (np.ndarray): Input array, shape (frames, ...).
        win_size (int): Window size for detecting peaks.

    Returns:
        np.ndarray: Boolean array indicating peaks, same shape as ``notes``.
    """
    notes = np.array(notes)  # ensure a NumPy array (this makes a copy)
    res = np.ones_like(notes, dtype=bool)
    for i in range(1, win_size + 1):
        forward = np.roll(notes, i, axis=0)
        backward = np.roll(notes, -i, axis=0)
        # Zero out the regions that wrapped around the array boundaries.
        forward[:i, ...] = 0
        backward[-i:, ...] = 0
        res &= notes >= forward
        res &= notes >= backward
    return res


def get_diff(notes, offset=True):
    """Find frame-to-frame transitions along axis 0.

    Args:
        notes (np.ndarray): Array supporting ``&`` and ``~`` (e.g. boolean).
            The frame before the first one is treated as 0.
        offset (bool): If True, mark frames where the previous frame was set
            and the current one is not; otherwise mark frames where the
            current frame is set and the previous one was not.

    Returns:
        np.ndarray: Array with the same shape as ``notes``.
    """
    rolled = np.roll(notes, 1, axis=0)
    rolled[0, ...] = 0
    return (rolled & (~notes)) if offset else (notes & (~rolled))


def compress_across_octave(notes):
    """Collapse a (time, instruments * keys) array into 12 columns.

    The maximum is taken over instruments first, then over every complete
    12-key group counted from key index 0.

    NOTE(review): ``keys // 12`` only covers complete groups, so when the
    number of keys is not a multiple of 12 the trailing keys do not
    contribute to the result -- confirm this is intended.

    Args:
        notes (np.ndarray): Array of shape (time, instruments * keys) with
            ``keys = MAX_MIDI - MIN_MIDI + 1``.

    Returns:
        np.ndarray: Array of shape (time, 12).
    """
    keys = MAX_MIDI - MIN_MIDI + 1
    time, instruments = notes.shape[0], notes.shape[1] // keys
    notes_reshaped = notes.reshape((time, instruments, keys))
    notes_reshaped = notes_reshaped.max(axis=1)
    octaves = keys // 12
    res = np.zeros((time, 12), dtype=np.uint8)
    for i in range(octaves):
        curr_octave = notes_reshaped[:, i * 12 : (i + 1) * 12]
        res = np.maximum(res, curr_octave)
    return res


def compress_time(notes, factor):
    """Downsample along axis 0 by max-pooling blocks of ``factor`` frames.

    Trailing frames that do not fill a complete block are dropped.

    Args:
        notes (np.ndarray): Array of shape (t, p).
        factor (int): Number of consecutive frames merged into one.

    Returns:
        np.ndarray: Array of shape (t // factor, p) with the dtype of
        ``notes``.
    """
    t, p = notes.shape
    n_blocks = t // factor
    blocks = notes[: n_blocks * factor].reshape(n_blocks, factor, p)
    return blocks.max(axis=1)


def get_matches(index1, index2):
    """Group the entries of ``index2`` by their paired entry in ``index1``.

    Args:
        index1: Iterable of hashable keys.
        index2: Iterable of values, paired with ``index1`` via ``zip``.

    Returns:
        dict: Maps each key of ``index1`` to the list of values paired with
        it, in order of appearance.
    """
    matches = {}
    for i1, i2 in zip(index1, index2):
        matches.setdefault(i1, []).append(i2)
    return matches


def get_margin(
    t_sources, max_len, WINDOW_SIZE_SRC=11 * (512 // HOP_LENGTH) + 2 * DTW_FACTOR
):
    """Extend a temporal range to WINDOW_SIZE_SRC if it is shorter than that.

    WINDOW_SIZE_SRC defaults to 28 frames for 256 hop length (assuming
    DTW_FACTOR=3), which according to the original note is ~0.5 second.

    Up to ``margin = (WINDOW_SIZE_SRC - len(t_sources)) // 2`` frames are
    added on each side, clipped to the valid index range.

    Args:
        t_sources: Sequence of consecutive frame indices.
        max_len (int): Upper bound for indices; assumed to be the number of
            frames, so ``max_len - 1`` is the last valid index.
        WINDOW_SIZE_SRC (int): Target window length in frames.

    Returns:
        list: Left margin, the original indices, then the right margin.
        An empty ``t_sources`` yields an empty list.
    """
    if len(t_sources) == 0:
        return []
    margin = max(0, (WINDOW_SIZE_SRC - len(t_sources)) // 2)
    first, last = t_sources[0], t_sources[-1]
    t_sources_left = list(range(max(first - margin, 0), first))
    # Start one past the last source frame: starting at ``last`` itself would
    # duplicate it and leave the right margin one frame short of the left one.
    t_sources_right = list(range(last + 1, min(last + margin, max_len - 1) + 1))
    return t_sources_left + list(t_sources) + t_sources_right


def get_inactive_instruments(target_onsets, T):
    """Build a mask that is True for every instrument with no onset at all.

    Args:
        target_onsets (np.ndarray): Array of shape (time, instruments * keys)
            with ``keys = MAX_MIDI - MIN_MIDI + 1``.
        T (int): Number of rows of the returned mask.

    Returns:
        tuple: ``(mask, active_instruments)`` where ``mask`` is a boolean
        array of shape (T, instruments * keys) that is True in all columns
        of instruments whose maximum over time and keys is 0, and
        ``active_instruments`` is that per-instrument maximum.
    """
    keys = MAX_MIDI - MIN_MIDI + 1
    time, instruments = target_onsets.shape[0], target_onsets.shape[1] // keys
    notes_reshaped = target_onsets.reshape((time, instruments, keys))
    active_instruments = notes_reshaped.max(axis=(0, 2))
    res = np.zeros((T, instruments, keys), dtype=bool)
    res[:, active_instruments == 0, :] = True
    return res.reshape((T, instruments * keys)), active_instruments


def max_inst(probs, threshold_vec=None):
    """Assign every detected pitch to its most probable instrument.

    If the last dimension equals ``N_KEYS`` or ``2 * N_KEYS`` the input is
    treated as pitch-only and returned unchanged (no thresholding applied).

    Otherwise ``probs`` is viewed as (time, instruments, keys). A pitch is
    detected at a frame when its maximum over instruments reaches the
    threshold. For each detection the arg-max instrument among all channels
    except the last is set to 1, and the last channel is set to 1 as well.

    Args:
        probs (np.ndarray): Array of shape (time, instruments * keys).
        threshold_vec: Scalar or array compared against the (time, keys)
            maxima; defaults to 0.5.

    Returns:
        np.ndarray: ``probs`` itself in the pitch-only case, otherwise a
        uint8 array of shape (time, instruments * keys).
    """
    if threshold_vec is None:
        threshold_vec = 0.5
    if probs.shape[-1] == N_KEYS or probs.shape[-1] == N_KEYS * 2:
        # there is only pitch
        return probs
    keys = MAX_MIDI - MIN_MIDI + 1
    instruments = probs.shape[1] // keys
    time = len(probs)
    probs = probs.reshape((time, instruments, keys))
    notes = probs.max(axis=1) >= threshold_vec
    # The last channel is excluded from the instrument arg-max.
    max_instruments = np.argmax(probs[:, :-1, :], axis=1)
    res = np.zeros(probs.shape, dtype=np.uint8)
    t_idx, p_idx = notes.nonzero()
    res[t_idx, max_instruments[t_idx, p_idx], p_idx] = 1
    res[t_idx, -1, p_idx] = 1
    return res.reshape((time, instruments * keys))


def smooth_labels(onset_tensor):
    """Smooth onset labels with a triangular kernel along the time axis.

    Args:
        onset_tensor (torch.Tensor): A (T, F) tensor where T = time steps and
            F = pitches.

    Returns:
        torch.Tensor: Smoothed onset tensor with the same shape (T, F).
    """
    # 5-tap triangular kernel, created with the input's dtype and on the
    # input's device so the convolution works wherever the tensor lives.
    kernel = torch.tensor(
        [0.33, 0.67, 1, 0.67, 0.33],
        dtype=onset_tensor.dtype,
        device=onset_tensor.device,
    ).view(1, 1, -1)
    onset_tensor = onset_tensor.T.unsqueeze(1)  # now shape is (F, 1, T)
    # Pad by half the (odd) kernel length so the output keeps T time steps.
    padding = kernel.shape[-1] // 2
    smoothed = F.conv1d(onset_tensor, kernel, padding=padding)
    # Reshape back to the original shape (T, F).
    return smoothed.squeeze(1).T


def _replace_handlers(name, handlers):
    """Return logger ``name`` at INFO level with exactly ``handlers`` attached."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Close the handlers being replaced so that repeated initialisation does
    # not leak open log files.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    for handler in handlers:
        logger.addHandler(handler)
    return logger


def initialize_logging_system(logdir):
    """Initialize the logging system with named loggers for train and dataset.

    Both loggers share one file handler writing to ``<logdir>/training.log``
    and one console handler, each at INFO level. Handlers previously attached
    to the two loggers are removed and closed, so the function can be called
    again safely.

    Args:
        logdir (str): Directory for the log file; created if it is missing.

    Returns:
        tuple: ``(train_logger, dataset_logger)``.
    """
    if logdir:
        os.makedirs(logdir, exist_ok=True)
    log_file = os.path.join(logdir, "training.log")

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # File handler (shared by all loggers)
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler (shared by all loggers)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    shared_handlers = (file_handler, console_handler)
    train_logger = _replace_handlers("train", shared_handlers)
    dataset_logger = _replace_handlers("dataset", shared_handlers)
    return train_logger, dataset_logger


def get_logger(name):
    """Get a named logger. Call initialize_logging_system first."""
    return logging.getLogger(name)