TempoPFN / examples /generate_synthetic_data.py
Vladyslav Moroshan
Apply ruff formatting
0a58567
import importlib
import importlib.util
import logging
import os

import torch

from src.data.containers import BatchTimeSeriesContainer
from src.data.utils import sample_future_length
from src.plotting.plot_timeseries import plot_from_container
from src.synthetic_generation.anomalies.anomaly_generator_wrapper import (
AnomalyGeneratorWrapper,
)
from src.synthetic_generation.cauker.cauker_generator_wrapper import (
CauKerGeneratorWrapper,
)
from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import (
ForecastPFNGeneratorWrapper,
)
from src.synthetic_generation.generator_params import (
AnomalyGeneratorParams,
CauKerGeneratorParams,
FinancialVolatilityAudioParams,
ForecastPFNGeneratorParams,
GPGeneratorParams,
KernelGeneratorParams,
MultiScaleFractalAudioParams,
NetworkTopologyAudioParams,
OrnsteinUhlenbeckProcessGeneratorParams,
SawToothGeneratorParams,
SineWaveGeneratorParams,
SpikesGeneratorParams,
StepGeneratorParams,
StochasticRhythmAudioParams,
)
from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
KernelGeneratorWrapper,
)
from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
OrnsteinUhlenbeckProcessGeneratorWrapper,
)
from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import (
SawToothGeneratorWrapper,
)
from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
SineWaveGeneratorWrapper,
)
from src.synthetic_generation.spikes.spikes_generator_wrapper import (
SpikesGeneratorWrapper,
)
from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
def _detect_pyo() -> bool:
    """Return True when the optional ``pyo`` audio library is installed and importable.

    ``find_spec`` only reports that a ``pyo`` module can be located; actually importing it can
    still raise ``ImportError`` or ``OSError`` (presumably from its native extension failing to
    load), so the import is attempted and those failures are treated as "not available".
    """
    if importlib.util.find_spec("pyo") is None:
        return False
    try:
        importlib.import_module("pyo")
    except (ImportError, OSError):
        return False
    return True


# The audio generator wrappers need pyo; import them only when it loads so the rest of the
# example keeps working without it (generator_factory checks this flag before using them).
PYO_AVAILABLE = _detect_pyo()
if PYO_AVAILABLE:
    from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
        FinancialVolatilityAudioWrapper,
    )
    from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import (
        MultiScaleFractalAudioWrapper,
    )
    from src.synthetic_generation.audio_generators.network_topology_wrapper import (
        NetworkTopologyAudioWrapper,
    )
    from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import (
        StochasticRhythmAudioWrapper,
    )

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def visualize_batch_sample(
generator,
batch_size: int = 8,
output_dir: str = "outputs/plots",
sample_idx: int | None = None,
prefix: str = "",
seed: int | None = None,
) -> None:
os.makedirs(output_dir, exist_ok=True)
name = generator.__class__.__name__
logger.info(f"[{name}] Generating batch of size {batch_size}")
batch = generator.generate_batch(batch_size=batch_size, seed=seed)
values = torch.from_numpy(batch.values)
if values.ndim == 2:
values = values.unsqueeze(-1)
future_length = sample_future_length(range="gift_eval")
history_values = values[:, :-future_length, :]
future_values = values[:, -future_length:, :]
container = BatchTimeSeriesContainer(
history_values=history_values,
future_values=future_values,
start=batch.start,
frequency=batch.frequency,
)
indices = [sample_idx] if sample_idx is not None else range(batch_size)
for i in indices:
filename = f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png"
output_file = os.path.join(output_dir, filename)
title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})"
plot_from_container(container, sample_idx=i, output_file=output_file, show=False, title=title)
logger.info(f"[{name}] Saved plot to {output_file}")
def generator_factory(global_seed: int, total_length: int) -> list:
    """Instantiate every synthetic-series generator wrapper used by this example.

    Every params object receives the same ``global_seed`` and ``length``; CauKer is additionally
    configured with ``num_channels=5``. The pyo-based audio generators are appended only when
    ``PYO_AVAILABLE`` is true; otherwise a warning is logged and they are skipped.

    Args:
        global_seed: Seed passed to each generator's params object.
        total_length: Series length passed to each generator's params object.

    Returns:
        The generator wrapper instances, core generators first, then any audio generators.
    """
    shared = {"global_seed": global_seed, "length": total_length}

    # (wrapper class, params class, extra params) — list order is the iteration order callers see.
    core_specs = (
        (KernelGeneratorWrapper, KernelGeneratorParams, {}),
        (GPGeneratorWrapper, GPGeneratorParams, {}),
        (ForecastPFNGeneratorWrapper, ForecastPFNGeneratorParams, {}),
        (SineWaveGeneratorWrapper, SineWaveGeneratorParams, {}),
        (SawToothGeneratorWrapper, SawToothGeneratorParams, {}),
        (StepGeneratorWrapper, StepGeneratorParams, {}),
        (AnomalyGeneratorWrapper, AnomalyGeneratorParams, {}),
        (SpikesGeneratorWrapper, SpikesGeneratorParams, {}),
        (CauKerGeneratorWrapper, CauKerGeneratorParams, {"num_channels": 5}),
        (OrnsteinUhlenbeckProcessGeneratorWrapper, OrnsteinUhlenbeckProcessGeneratorParams, {}),
    )
    generators = [wrapper_cls(params_cls(**shared, **extra)) for wrapper_cls, params_cls, extra in core_specs]

    if not PYO_AVAILABLE:
        logger.warning("Audio generators skipped (pyo not available)")
        return generators

    # The audio wrapper names are only imported when pyo loaded, so reference them past the guard.
    audio_specs = (
        (StochasticRhythmAudioWrapper, StochasticRhythmAudioParams),
        (FinancialVolatilityAudioWrapper, FinancialVolatilityAudioParams),
        (MultiScaleFractalAudioWrapper, MultiScaleFractalAudioParams),
        (NetworkTopologyAudioWrapper, NetworkTopologyAudioParams),
    )
    generators += [wrapper_cls(params_cls(**shared)) for wrapper_cls, params_cls in audio_specs]
    return generators
if __name__ == "__main__":
    # Demo configuration: a small batch of long series per generator, all driven by one seed.
    global_seed = 2025
    total_length = 2048
    batch_size = 2
    output_dir = "outputs/plots"

    logger.info(f"Saving plots to {output_dir}")
    for gen in generator_factory(global_seed, total_length):
        # Generators whose params declare more than one channel get a "multivariate" tag.
        is_multivariate = getattr(gen.params, "num_channels", 1) > 1
        visualize_batch_sample(
            gen,
            batch_size=batch_size,
            output_dir=output_dir,
            prefix="multivariate" if is_multivariate else "",
            seed=global_seed,
        )