import os
import textwrap
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import tqdm
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")

from .time_rcd.dataset import ChatTSTimeRCDPretrainDataset
from .time_rcd.TimeRCD_pretrain_multi import TimeSeriesPretrainModel, create_random_mask, collate_fn, test_collate_fn
from .time_rcd.time_rcd_config import TimeRCDConfig, default_config
from utils.dataset import TimeRCDDataset
class TimeRCDPretrainTester:
    """Tester class for visualizing pretrained model results."""

    def __init__(self, checkpoint_path: str, config: TimeRCDConfig):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.win_size = config.win_size
        self.batch_size = config.batch_size

        # Load the pretrained model and switch to inference mode
        self.model = TimeSeriesPretrainModel(config).to(self.device)
        self.load_checkpoint(checkpoint_path)
        self.model.eval()
        print(f"Model loaded on device: {self.device}")
    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights from a checkpoint file."""
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Remove the 'module.' prefix (left over from DDP training), if present
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key[len('module.'):] if key.startswith('module.') else key
            new_state_dict[new_key] = value

        self.model.load_state_dict(new_state_dict)
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
    def predict(self, batch):
        """Run inference on a batch."""
        with torch.no_grad():
            # Move data to device
            time_series = batch['time_series'].to(self.device)
            normal_time_series = batch['normal_time_series'].to(self.device)
            masked_time_series = batch['masked_time_series'].to(self.device)
            attribute = batch['attribute']
            batch_size, seq_len, num_features = time_series.shape

            # Standardize each series along the time dimension
            time_series = (time_series - time_series.mean(dim=1, keepdim=True)) / (time_series.std(dim=1, keepdim=True) + 1e-8)
            masked_time_series = (masked_time_series - masked_time_series.mean(dim=1, keepdim=True)) / (masked_time_series.std(dim=1, keepdim=True) + 1e-8)

            mask = batch['mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)

            # Get embeddings
            local_embeddings = self.model(
                time_series=time_series,
                mask=attention_mask)

            # Get reconstruction
            reconstructed = self.model.reconstruction_head(local_embeddings)
            reconstructed = reconstructed.view(batch_size, seq_len, num_features)  # (B, seq_len, num_features)

            # Get anomaly predictions
            anomaly_logits = self.model.anomaly_head(local_embeddings)
            anomaly_logits = torch.mean(anomaly_logits, dim=-2)  # (B, seq_len, 2)
            anomaly_probs = F.softmax(anomaly_logits, dim=-1)[..., 1]  # Probability of anomaly (B, seq_len)

            return {
                'original': time_series.cpu(),
                'normal': normal_time_series.cpu(),
                'masked': masked_time_series.cpu(),
                'reconstructed': reconstructed.cpu(),
                'mask': mask.cpu(),
                'anomaly_probs': anomaly_probs.cpu(),
                'true_labels': labels.cpu(),
                'attention_mask': attention_mask.cpu(),
                'attribute': attribute
            }
    def visualize_single_sample(self, results, sample_idx=0, save_path=None):
        """Visualize results for a single time series sample."""
        # Extract data for the specified sample
        original = results['original'][sample_idx].squeeze(-1).numpy()  # (seq_len, num_features) or (seq_len,)
        normal = results['normal'][sample_idx].squeeze(-1).numpy()
        masked = results['masked'][sample_idx].squeeze(-1).numpy()
        reconstructed = results['reconstructed'][sample_idx].squeeze(-1).numpy()
        mask = results['mask'][sample_idx].numpy().astype(bool)
        anomaly_probs = results['anomaly_probs'][sample_idx].numpy()  # (seq_len,)
        true_labels = results['true_labels'][sample_idx].numpy()  # (seq_len,)
        attention_mask = results['attention_mask'][sample_idx].numpy().astype(bool)
        attribute = results['attribute'][sample_idx]

        # Only consider the valid (unpadded) portion of the sequence
        valid_length = attention_mask.sum()
        original = original[:valid_length]
        normal = normal[:valid_length]
        masked = masked[:valid_length]
        reconstructed = reconstructed[:valid_length]
        mask = mask[:valid_length]
        anomaly_probs = anomaly_probs[:valid_length]
        true_labels = true_labels[:valid_length]

        # Create time axis
        time_axis = np.arange(len(original))
        assert original.ndim == normal.ndim == reconstructed.ndim == masked.ndim, \
            "Original, normal, reconstructed, and masked time series must have the same dimensions."
        if original.ndim == 1:
            # Univariate case: three stacked panels
            fig, axes = plt.subplots(3, 1, figsize=(15, 12))

            # 1. Reconstruction visualization
            ax1 = axes[0]
            ax1.plot(time_axis, original, 'b-', label='Original', linewidth=2, alpha=0.8)
            ax1.plot(time_axis, masked, 'g--', label='Masked Input', linewidth=1.5, alpha=0.7)
            ax1.plot(time_axis[mask], reconstructed[mask], 'ro',
                     label='Reconstructed', markersize=4, alpha=0.8)

            # Collect masked positions as contiguous (start, end) spans
            mask_regions = []
            in_mask = False
            start_idx = 0
            for i, is_masked in enumerate(mask):
                if is_masked and not in_mask:
                    start_idx = i
                    in_mask = True
                elif not is_masked and in_mask:
                    mask_regions.append((start_idx, i - 1))
                    in_mask = False
            if in_mask:  # Handle a mask that continues to the end
                mask_regions.append((start_idx, len(mask) - 1))

            for start, end in mask_regions:
                ax1.axvspan(start, end, alpha=0.2, color='red',
                            label='Masked Region' if start == mask_regions[0][0] else "")
            ax1.set_title('Time Series Reconstruction', fontsize=14, fontweight='bold')
            ax1.set_xlabel('Time Steps')
            ax1.set_ylabel('Value')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

            # 2. Anomaly detection visualization
            ax2 = axes[1]
            ax2.plot(time_axis, normal, 'g-', label='Normal Time Series', linewidth=1, alpha=0.6)
            ax2.plot(time_axis, original, 'b-', label='Anomalous Time Series', linewidth=1, alpha=0.6)

            # Color the background where the true labels mark anomalies
            anomaly_regions = []
            in_anomaly = False
            start_idx = 0
            for i, is_anomaly in enumerate(true_labels > 0.5):
                if is_anomaly and not in_anomaly:
                    start_idx = i
                    in_anomaly = True
                elif not is_anomaly and in_anomaly:
                    anomaly_regions.append((start_idx, i - 1))
                    in_anomaly = False
            if in_anomaly:
                anomaly_regions.append((start_idx, len(true_labels) - 1))

            for start, end in anomaly_regions:
                ax2.axvspan(start, end, alpha=0.3, color='red',
                            label='True Anomaly' if start == anomaly_regions[0][0] else "")

            # Plot predicted anomaly probabilities on a secondary axis
            ax2_twin = ax2.twinx()
            ax2_twin.plot(time_axis, anomaly_probs, 'r-', label='Anomaly Probability',
                          linewidth=2, alpha=0.8)
            ax2_twin.axhline(y=0.5, color='orange', linestyle='--', alpha=0.7,
                             label='Threshold (0.5)')
            ax2_twin.set_ylabel('Anomaly Probability', color='red')
            ax2_twin.set_ylim(0, 1)
            ax2.set_title('Anomaly Detection Results', fontsize=14, fontweight='bold')
            ax2.set_xlabel('Time Steps')
            ax2.set_ylabel('Time Series Value', color='blue')

            # Combine legends from both axes
            lines1, labels1 = ax2.get_legend_handles_labels()
            lines2, labels2 = ax2_twin.get_legend_handles_labels()
            ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
            ax2.grid(True, alpha=0.3)

            # 3. Performance metrics visualization
            ax3 = axes[2]
            # Reconstruction error over the masked positions only
            if mask.sum() > 0:
                recon_error = np.abs(original[mask] - reconstructed[mask])
                ax3.bar(np.arange(len(recon_error)), recon_error,
                        alpha=0.7, color='orange', label='Reconstruction Error')
                ax3.set_title('Reconstruction Error (Masked Regions Only)',
                              fontsize=14, fontweight='bold')
                ax3.set_xlabel('Masked Time Step Index')
                ax3.set_ylabel('Absolute Error')
                ax3.legend()
                ax3.grid(True, alpha=0.3)
            else:
                ax3.text(0.5, 0.5, 'No masked regions in this sample',
                         ha='center', va='center', transform=ax3.transAxes, fontsize=12)
                ax3.set_title('Reconstruction Error', fontsize=14, fontweight='bold')

            plt.tight_layout()
            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()
        elif original.ndim == 2:
            # Multivariate case: one panel per feature
            _, num_features = original.shape
            fig_height = 4 * num_features + 2
            fig, axes = plt.subplots(num_features, 1, figsize=(16, fig_height))
            plt.subplots_adjust(top=0.85, hspace=0.2, left=0.08, right=0.92, bottom=0.08)

            # Collect true-anomaly spans once; they are shared across all panels
            anomaly_regions = []
            in_anomaly = False
            start_idx = 0
            for i, is_anomaly in enumerate(true_labels > 0.5):
                if is_anomaly and not in_anomaly:
                    start_idx = i
                    in_anomaly = True
                elif not is_anomaly and in_anomaly:
                    anomaly_regions.append((start_idx, i - 1))
                    in_anomaly = False
            if in_anomaly:
                anomaly_regions.append((start_idx, len(true_labels) - 1))

            for feature_idx in range(num_features):
                ax = axes[feature_idx]
                ax.plot(time_axis, original[:, feature_idx], 'b-',
                        linewidth=1, label='Anomalous Time Series', alpha=0.8)
                ax.plot(time_axis, normal[:, feature_idx], 'g-',
                        linewidth=1, label='Normal Time Series', alpha=0.8)

                # Extend the y-range downward so the series sits in the upper
                # half of the panel, leaving room for the score curve below
                y_min, y_max = ax.get_ylim()
                shift = y_max - y_min
                ax.set_ylim(y_min - shift, y_max)

                for start, end in anomaly_regions:
                    label = ('True Anomaly Region'
                             if start == anomaly_regions[0][0] and feature_idx == 0 else "")
                    if start == end:
                        # Widen single-step anomalies so they remain visible
                        ax.axvspan(start - 0.5, start + 0.5, alpha=0.3, color='grey', label=label)
                    else:
                        ax.axvspan(start, end, alpha=0.3, color='grey', label=label)

                ax2 = ax.twinx()
                ax2.plot(time_axis, anomaly_probs, 'r-', linewidth=1,
                         label='Anomaly Score', alpha=0.9)
                ax2.set_ylim(0, 1.5)
                ax2.set_ylabel('Anomaly Score', fontsize=12)

                ax.set_ylabel('Value', fontsize=12)
                if feature_idx == num_features - 1:
                    ax.set_xlabel('Time Steps', fontsize=12)
                else:
                    ax.set_xticklabels([])
                ax.set_title(f'Feature {feature_idx} - Time Series & Anomaly Score',
                             fontsize=16, pad=10)
                ax.grid(True, alpha=0.3)

                if feature_idx == 0:
                    lines1, labels1 = ax.get_legend_handles_labels()
                    lines2, labels2 = ax2.get_legend_handles_labels()
                    ax.legend(lines1 + lines2, labels1 + labels2,
                              loc='upper right', bbox_to_anchor=(0.98, 0.98), fontsize=14)

            # Build a title from the sample's anomaly attributes and causal edges
            anomalies = []
            isendo = attribute['is_endogenous']
            edges = attribute['dag']
            for idx, item in enumerate(attribute['attribute_list']):
                for k, v in item['anomalies'].items():
                    anomalies.append((f"feature_{idx}_{k[2:]}", v))
            anomalies_str = ', '.join([f"{k}: {v}" for k, v in anomalies])
            wrap_width = 100
            wrapped_anomalies = textwrap.fill(f"Anomalies: {anomalies_str}", width=wrap_width)
            wrapped_edges = textwrap.fill(f"Edges: {str(edges)}", width=wrap_width)
            title = f"Multivariate Time Series Visualization\n{isendo}_{wrapped_anomalies}\n{wrapped_edges}"
            fig.suptitle(title, fontsize=22, y=0.95, ha='center', va='top')

            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
            plt.show()
        else:
            raise ValueError(f"Unsupported original data shape: {original.shape}")
    def test_model(self, data_path: str, filename: str, num_samples: int = 5, save_dir: str = None,
                   max_test_data: int = 100):
        """Test the model on a dataset and visualize results."""
        # Load test dataset
        full_test_dataset = ChatTSTimeRCDPretrainDataset(data_path, filename, split="test", train_ratio=0)
        print(f'Length of dataset: {len(full_test_dataset)}')

        # Limit to at most max_test_data randomly chosen samples
        if len(full_test_dataset) > max_test_data:
            indices = torch.randperm(len(full_test_dataset))[:max_test_data].tolist()
            test_dataset = torch.utils.data.Subset(full_test_dataset, indices)
            print(f"Randomly subsampled {max_test_data} test examples")
        else:
            test_dataset = full_test_dataset

        # Visualization loader: one sample at a time for detailed plots
        vis_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=0
        )

        # Visualize individual samples one by one
        num_visualize = min(num_samples, len(test_dataset))
        vis_iter = iter(vis_loader)
        for i in range(num_visualize):
            try:
                vis_batch = next(vis_iter)
                # Run inference for this single sample
                vis_results = self.predict(vis_batch)
                save_path = None
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"sample_{i + 1}_results.png")
                self.visualize_single_sample(vis_results, sample_idx=0, save_path=save_path)
            except StopIteration:
                break
    def zero_shot(self, data):
        """Run zero-shot inference on the provided data."""
        # Shrink the window to the series length for short inputs
        if len(data) <= self.win_size:
            self.win_size = len(data)
        test_loader = DataLoader(
            dataset=TimeRCDDataset(data, window_size=self.win_size, stride=self.win_size, normalize=True),
            batch_size=self.batch_size,
            collate_fn=test_collate_fn,
            num_workers=0,
            shuffle=False)
        loop = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
        scores = []
        logits = []
        with torch.no_grad():
            for i, batch in loop:
                # Move data to device; the dataset already normalizes each window
                time_series = batch['time_series'].to(self.device)
                batch_size, seq_len, num_features = time_series.shape
                attention_mask = batch['attention_mask'].to(self.device)

                # Get embeddings
                local_embeddings = self.model(
                    time_series=time_series,
                    mask=attention_mask)

                # Get anomaly predictions
                anomaly_logits = self.model.anomaly_head(local_embeddings)
                anomaly_logits = torch.mean(anomaly_logits, dim=-2)  # (B, seq_len, 2)
                anomaly_probs = F.softmax(anomaly_logits, dim=-1)[..., 1]  # Probability of anomaly (B, seq_len)
                scores.append(anomaly_probs.cpu().numpy())
                logit = anomaly_logits[..., 1] - anomaly_logits[..., 0]  # Anomaly logit margin (B, seq_len)
                logits.append(logit.cpu().numpy())
        return scores, logits
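    # A minimal usage sketch for `zero_shot` (illustrative only; `tester` and
    # `series` are assumed names, not part of this module). Each element of
    # `scores` is a (batch, win_size) array of per-step anomaly probabilities,
    # so a full-length score series can be recovered by flattening the windows
    # and trimming any tail padding:
    #
    #     scores, logits = tester.zero_shot(series)   # series: (seq_len, num_features)
    #     flat = np.concatenate([s.reshape(-1) for s in scores])
    #     per_step_scores = flat[:len(series)]        # drop window padding, if any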
    def evaluate(self, time_series, mask):
        """Score a batch by masked reconstruction error."""
        with torch.no_grad():
            time_series = time_series.to(self.device)
            mask = mask.to(self.device)
            local_embeddings = self.model(time_series=time_series, mask=mask)
            reconstructed = self.model.reconstruction_head(local_embeddings)  # (B, seq_len, num_features, 1)
            reconstructed = reconstructed.squeeze(-1)
            # Squared error at the valid (unpadded) positions, flattened to 1-D
            mask_expand = mask.unsqueeze(-1).expand(-1, -1, reconstructed.shape[-1])
            anomaly_probs = ((reconstructed - time_series) ** 2)[mask_expand]
        return anomaly_probs, reconstructed
    def zero_shot_reconstruct(self, data, visualize=True, data_index=None):
        """Run zero-shot inference on the provided data via reconstruction error."""
        if len(data) <= self.win_size:
            self.win_size = len(data)
        test_loader = DataLoader(
            dataset=Dataset_UCR(data, window_size=self.win_size),
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False)
        loop = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
        scores = []
        with torch.no_grad():
            for i, (x, mask) in loop:
                # Move data to device; as_tensor handles both numpy arrays and
                # the tensors produced by the default collate function
                time_series = torch.as_tensor(x, dtype=torch.float32).to(self.device)  # (B, seq_len, num_features)
                mask_tensor = torch.as_tensor(mask, dtype=torch.bool).to(self.device)
                score, reconstructed = self.evaluate(time_series, mask_tensor)
                scores.append(score)
                # Visualize the first sample of each batch if requested
                if visualize:
                    self.visualize_reconstruction(
                        original=time_series[0].cpu().numpy(),
                        reconstructed=reconstructed.cpu().numpy(),
                        mask=mask_tensor[0].cpu().numpy(),
                        scores=score.cpu().numpy(),
                        save_path="/home/lihaoyang/Huawei/TSB-AD/Synthetic/random_mask_anomaly_head_Time_RCD_Reconstruction_5000/plot/",
                        index=data_index)
        return scores
    def visualize_reconstruction(self, original, reconstructed, mask, scores, index, save_path=None):
        """Visualize reconstruction results for a single sample."""
        seq_len = len(original)
        time_axis = np.arange(seq_len)

        # Squeeze singleton dimensions
        original = original.squeeze()
        reconstructed = reconstructed.squeeze(0).squeeze(-1)
        scores = scores.squeeze()

        fig, axes = plt.subplots(2, 1, figsize=(15, 10))

        # 1. Reconstruction plot
        ax1 = axes[0]
        ax1.plot(time_axis, original, 'b-', label='Original', linewidth=2, alpha=0.8)
        ax1.plot(time_axis, reconstructed, 'r--', label='Reconstructed', linewidth=2, alpha=0.8)

        # Collect masked positions as contiguous (start, end) spans
        mask_regions = []
        in_mask = False
        start_idx = 0
        for i, is_masked in enumerate(mask):
            if is_masked and not in_mask:
                start_idx = i
                in_mask = True
            elif not is_masked and in_mask:
                mask_regions.append((start_idx, i - 1))
                in_mask = False
        if in_mask:
            mask_regions.append((start_idx, len(mask) - 1))

        for start, end in mask_regions:
            ax1.axvspan(start, end, alpha=0.2, color='red',
                        label='Masked Region' if start == mask_regions[0][0] else "")
        ax1.set_title('Time Series Reconstruction', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Time Steps')
        ax1.set_ylabel('Value')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. Reconstruction error plot
        ax2 = axes[1]
        recon_error = np.abs(original - reconstructed)
        ax2.plot(time_axis, recon_error, 'g-', label='Reconstruction Error', linewidth=2, alpha=0.8)

        # Scores cover only the masked points; map them back onto the full sequence
        if len(scores) == mask.sum():
            full_scores = np.zeros(seq_len)
            full_scores[mask] = scores
            ax2_twin = ax2.twinx()
            ax2_twin.plot(time_axis, full_scores, 'orange', label='Anomaly Scores', linewidth=1.5, alpha=0.7)
            ax2_twin.set_ylabel('Anomaly Score', color='orange')
            ax2_twin.legend(loc='upper right')
        ax2.set_title('Reconstruction Error', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Time Steps')
        ax2.set_ylabel('Absolute Error')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path, f"reconstruction_sample_{index}_results.png")
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print("Visualization saved to:", save_path)
class Dataset_UCR(Dataset):
    """Windowed dataset for zero-shot reconstruction over a single series."""

    def __init__(self, data, window_size: int = 1000):
        super().__init__()
        self.data = data.reshape(-1, 1) if len(data.shape) == 1 else data
        self.window_size = window_size
        self._load_data()
        self._process_windows()

    def _load_data(self):
        # Standardize the series using statistics fitted on the series itself
        scaler = StandardScaler()
        scaler.fit(self.data)
        self.raw_test = scaler.transform(self.data)

    def _process_windows(self):
        if len(self.raw_test) <= self.window_size:
            # A single window covering the whole series
            self.test = np.expand_dims(self.raw_test, axis=0)
            self.mask = np.expand_dims(np.ones(len(self.raw_test), dtype=bool), axis=0)
        else:
            # Pad the tail to a multiple of window_size and mark padded steps invalid.
            # E.g. seq_len=2500, window_size=1000 -> padding=500, giving 3 windows.
            self.raw_masks = np.ones(len(self.raw_test), dtype=bool)
            padding = self.window_size - (len(self.raw_test) % self.window_size)
            if padding < self.window_size:
                self.raw_test = np.pad(self.raw_test, ((0, padding), (0, 0)), mode='constant')
                self.raw_masks = np.pad(self.raw_masks, (0, padding), mode='constant')
            self.test = self.raw_test.reshape(-1, self.window_size, self.raw_test.shape[1])
            self.mask = self.raw_masks.reshape(-1, self.window_size)
        assert self.test.shape[0] == self.mask.shape[0], "Inconsistent window counts"

    def __len__(self):
        return len(self.test)

    def __getitem__(self, index):
        return np.float32(self.test[index]), self.mask[index]
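

# Minimal end-to-end sketch (not part of the original module). All paths and
# filenames below are placeholders, and this assumes `default_config` is a
# ready-to-use TimeRCDConfig instance; replace them before running.
if __name__ == "__main__":
    config = default_config
    tester = TimeRCDPretrainTester(
        checkpoint_path="checkpoints/time_rcd_pretrain.pt",  # placeholder checkpoint path
        config=config,
    )
    tester.test_model(
        data_path="data/",            # placeholder data directory
        filename="test_split.json",   # placeholder dataset filename
        num_samples=3,
        save_dir="outputs/plots/",    # where sample_*.png figures are written
        max_test_data=100,
    )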