File size: 7,731 Bytes
05d6e12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F

from onsets_and_frames.constants import (
    DTW_FACTOR,
    HOP_LENGTH,
    MAX_MIDI,
    MIN_MIDI,
    N_KEYS,
)


def cycle(iterable):
    """Yield items from *iterable* endlessly, restarting each time it ends."""
    while True:
        yield from iterable


def shift_label(label, shift):
    """Pitch-shift a piano-roll label along the key axis.

    Args:
        label (torch.Tensor): 2-D tensor of shape (T, instruments * N_KEYS).
        shift (int): number of semitones to shift; positive shifts keys up,
            negative shifts down. Keys shifted past the edge are dropped and
            the vacated positions are zero-filled.

    Returns:
        torch.Tensor: shifted tensor with the same shape and dtype as `label`.
    """
    if shift == 0:
        return label
    assert len(label.shape) == 2
    t, p = label.shape
    keys, instruments = N_KEYS, p // N_KEYS
    # Allocate the zero padding on the input's device so GPU labels work too
    # (previously this was always CPU, making torch.cat fail on CUDA tensors).
    label_zero_pad = torch.zeros(
        t, instruments, abs(shift), dtype=label.dtype, device=label.device
    )
    label = label.reshape(t, instruments, keys)
    to_cat = (
        (label_zero_pad, label[:, :, :-shift])
        if shift > 0
        else (label[:, :, -shift:], label_zero_pad)
    )
    label = torch.cat(to_cat, dim=-1)
    return label.reshape(t, p)


def get_peaks(notes, win_size, gpu=False):
    """Mask frames that are local maxima within a +/- win_size window.

    A frame is a peak when its value is >= every neighbor within `win_size`
    frames in both directions along axis 0; out-of-range neighbors count as 0.

    Args:
        notes (torch.Tensor): input with time on axis 0.
        win_size (int): half-window size in frames.
        gpu (bool): if True, return the result on CUDA.

    Returns:
        torch.Tensor: boolean mask with the same shape as `notes`.
    """
    notes = notes.cpu()
    peaks = torch.ones(notes.shape, dtype=bool)
    for offset in range(1, win_size + 1):
        earlier = torch.roll(notes, offset, 0)
        earlier[:offset, ...] = 0  # time axis is 0
        later = torch.roll(notes, -offset, 0)
        later[-offset:, ...] = 0
        peaks = peaks & (notes >= earlier)
        peaks = peaks & (notes >= later)
    return peaks.cuda() if gpu else peaks


def get_peaks_numpy(notes, win_size):
    """
    Detect local maxima along axis 0 within a +/- win_size window.

    A frame is a peak when its value is >= all neighbors within `win_size`
    frames before and after it; neighbors outside the array count as zero.

    Args:
        notes (np.ndarray): Input array, shape (frames, ...).
        win_size (int): Half-window size for detecting peaks.

    Returns:
        np.ndarray: Boolean array indicating peaks, same shape as `notes`.
    """
    notes = np.array(notes)  # ensure ndarray input
    peaks = np.ones(notes.shape, dtype=bool)

    for step in range(1, win_size + 1):
        # Shifted copies; zero out wrapped-around (invalid) regions.
        earlier = np.roll(notes, step, axis=0)
        earlier[:step, ...] = 0
        later = np.roll(notes, -step, axis=0)
        later[-step:, ...] = 0

        peaks &= notes >= earlier
        peaks &= notes >= later

    return peaks


def get_diff(notes, offset=True):
    """Frame-wise transition detector for a boolean piano roll.

    With offset=True, True where a note was on in the previous frame but is
    off now (note end); with offset=False, True where a note turns on
    (note start). The first frame's "previous" state is all zeros.
    """
    prev = np.roll(notes, 1, axis=0)
    prev[0, ...] = 0
    if offset:
        return prev & (~notes)
    return notes & (~prev)


def compress_across_octave(notes):
    """Collapse a (T, instruments * keys) roll onto 12 pitch classes.

    Takes the max over instruments, then folds complete octaves together
    with an element-wise max. Keys beyond the last full octave (keys % 12)
    are ignored, matching the integer-division octave count.

    Returns:
        np.ndarray: (T, 12) uint8 array of pitch-class activations.
    """
    keys = MAX_MIDI - MIN_MIDI + 1
    time = notes.shape[0]
    instruments = notes.shape[1] // keys
    per_pitch = notes.reshape((time, instruments, keys)).max(axis=1)
    folded = np.zeros((time, 12), dtype=np.uint8)
    for octave in range(keys // 12):
        folded = np.maximum(folded, per_pitch[:, octave * 12 : (octave + 1) * 12])
    return folded


def compress_time(notes, factor):
    """Downsample along time by max-pooling consecutive `factor` frames.

    Trailing frames that do not fill a complete group are dropped, so the
    output has t // factor rows.

    Args:
        notes (np.ndarray): (T, P) array.
        factor (int): number of frames pooled into each output frame.

    Returns:
        np.ndarray: (T // factor, P) array, same dtype as `notes`.
    """
    t, p = notes.shape
    n_groups = t // factor
    trimmed = notes[: n_groups * factor]
    return trimmed.reshape(n_groups, factor, p).max(axis=1)


def get_matches(index1, index2):
    """Group aligned index pairs.

    Maps each value of `index1` to the list of `index2` values paired with
    it, preserving encounter order.

    Returns:
        dict: {i1: [i2, ...]}.
    """
    matches = {}
    for src, tgt in zip(index1, index2):
        matches.setdefault(src, []).append(tgt)
    return matches


"""
Extend a temporal range to WINDOW_SIZE_SRC if it is shorter than that.
WINDOW_SIZE_SRC defaults to 28 frames for 256 hop length (assuming DTW_FACTOR=3), which is ~0.5 second.
"""


def get_margin(
    t_sources, max_len, WINDOW_SIZE_SRC=11 * (512 // HOP_LENGTH) + 2 * DTW_FACTOR
):
    """Extend a run of frame indices so it spans at least WINDOW_SIZE_SRC frames.

    Args:
        t_sources (list): ascending frame indices; assumed non-empty since
            t_sources[0] and t_sources[-1] are read unconditionally.
        max_len (int): total number of frames; the right extension is clipped
            so no index exceeds it.
        WINDOW_SIZE_SRC (int): minimum desired span in frames (~0.5 s by
            default, per the module-level note).

    Returns:
        list: left padding + t_sources + right padding.
    """
    # Half the shortfall goes on each side; 0 when already long enough.
    margin = max(0, (WINDOW_SIZE_SRC - len(t_sources)) // 2)
    t_sources_left = list(range(max(t_sources[0] - margin, 0), t_sources[0]))
    # NOTE(review): when margin > 0 this range starts AT t_sources[-1], so
    # that index appears twice in the result; and because range() excludes
    # its stop, the largest reachable index is max_len - 2. Both look like
    # off-by-ones — confirm against callers before changing.
    t_sources_right = list(
        range(t_sources[-1], min(t_sources[-1] + margin, max_len - 1))
    )
    t_sources_extended = t_sources_left + t_sources + t_sources_right
    return t_sources_extended


def get_inactive_instruments(target_onsets, T):
    """Mask instruments that never play in `target_onsets`.

    Args:
        target_onsets (np.ndarray): (time, instruments * keys) onset roll.
        T (int): number of frames for the output mask.

    Returns:
        tuple: (boolean mask of shape (T, instruments * keys), True for every
        key of each silent instrument; per-instrument max-activity vector).
    """
    keys = MAX_MIDI - MIN_MIDI + 1
    time = target_onsets.shape[0]
    instruments = target_onsets.shape[1] // keys
    reshaped = target_onsets.reshape((time, instruments, keys))
    active_instruments = reshaped.max(axis=(0, 2))
    inactive = np.zeros((T, instruments, keys), dtype=bool)
    # Flag all keys/frames of instruments with zero activity.
    inactive[:, active_instruments == 0, :] = 1
    return inactive.reshape((T, instruments * keys)), active_instruments


def max_inst(probs, threshold_vec=None):
    """Binarize note probabilities, keeping one instrument per detected note.

    Args:
        probs (np.ndarray): (T, instruments * keys) probabilities. If the
            last dimension equals N_KEYS or 2 * N_KEYS the input is
            pitch-only and is returned unchanged.
        threshold_vec: scalar or per-key threshold; defaults to 0.5.

    Returns:
        np.ndarray: (T, instruments * keys) uint8 activations. For each
        (time, pitch) cell whose max-over-instruments probability passes the
        threshold, the argmax instrument (excluding the last channel) and
        the last channel are set to 1.
    """
    if threshold_vec is None:
        threshold_vec = 0.5
    if probs.shape[-1] in (N_KEYS, N_KEYS * 2):
        # Pitch-only input: no instrument dimension to disambiguate.
        return probs
    keys = MAX_MIDI - MIN_MIDI + 1
    instruments = probs.shape[1] // keys
    time = len(probs)
    probs = probs.reshape((time, instruments, keys))
    note_hits = probs.max(axis=1) >= threshold_vec
    # Argmax over the real instruments; the last channel is presumably an
    # instrument-agnostic "any note" layer — it is always set alongside the
    # winner below.
    best_inst = np.argmax(probs[:, :-1, :], axis=1)
    activations = np.zeros(probs.shape, dtype=np.uint8)
    for t, p in zip(*(note_hits.nonzero())):
        activations[t, best_inst[t, p], p] = 1
        activations[t, -1, p] = 1
    return activations.reshape((time, instruments * keys))


def smooth_labels(onset_tensor):
    """
    Smooths onset labels using a triangular kernel with 1D convolution along the time axis.

    Args:
        onset_tensor (torch.Tensor): A (T, F) tensor where T = time steps and F = pitches.

    Returns:
        torch.Tensor: Smoothed onset tensor with the same shape (T, F).
    """
    # 5-tap triangular smoothing kernel centered on the onset frame.
    # Built on the input's device/dtype so CUDA tensors work as well.
    kernel = torch.tensor(
        [0.33, 0.67, 1, 0.67, 0.33],
        dtype=onset_tensor.dtype,
        device=onset_tensor.device,
    ).view(1, 1, -1)

    # conv1d expects (batch, channels, length): treat each pitch as a batch item.
    onset_tensor = onset_tensor.T.unsqueeze(1)  # Now shape is (F, 1, T)

    # 'same' padding (odd kernel length) keeps the time dimension unchanged.
    padding = kernel.shape[-1] // 2
    smoothed = F.conv1d(onset_tensor, kernel, padding=padding)

    # Reshape back to original shape (T, F)
    return smoothed.squeeze(1).T


def initialize_logging_system(logdir):
    """Initialize the logging system once with named loggers for train and dataset.

    Both loggers share one file handler (logdir/training.log) and one console
    handler, log at INFO level, and have propagation disabled so messages are
    not duplicated through the root logger's handlers.

    Args:
        logdir (str): directory for the log file; created if missing.

    Returns:
        tuple: (train_logger, dataset_logger).
    """
    os.makedirs(logdir, exist_ok=True)
    log_file = os.path.join(logdir, "training.log")

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # File handler (shared by all loggers)
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler (shared by all loggers)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    loggers = []
    for name in ("train", "dataset"):
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        # Re-initialization safe: drop any handlers from a previous call.
        logger.handlers.clear()
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        # Without this, records also reach the root logger and may be
        # printed twice when the root logger has its own handlers.
        logger.propagate = False
        loggers.append(logger)

    return tuple(loggers)


def get_logger(name):
    """Return the logger registered under *name*.

    Call initialize_logging_system first so the named loggers have handlers.
    """
    logger = logging.getLogger(name)
    return logger