semsynth.torch_compat

Compatibility helpers for optional PyTorch features.

Functions

ensure_npsum_compat()

Patch numpy.sum so it accepts generators (removed in NumPy ≥ 2).

ensure_torch_rmsnorm()

Ensure torch.nn exposes RMSNorm on versions that predate it.

ensure_trapz_compat()

Ensure numpy and scipy expose trapz (renamed to trapezoid in newer versions).

semsynth.torch_compat.ensure_npsum_compat() → None

Patch numpy.sum so it accepts generators (removed in NumPy ≥ 2).

Synthcity’s ctgan plugin calls np.sum(some_generator), which raises a TypeError under NumPy 2. Wrapping np.sum so that generators are materialised with list() before delegating to the original function restores the previous behaviour.
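A minimal sketch of the wrapping described above, assuming the patch only needs to intercept generator objects passed as the first positional argument. The `_semsynth_patched` marker attribute is a hypothetical idempotency guard, not necessarily what the module itself uses.

```python
import functools
import inspect

import numpy as np


def ensure_npsum_compat() -> None:
    """Wrap np.sum so that generator arguments are turned into lists first."""
    # Hypothetical marker so repeated calls do not stack wrappers.
    if getattr(np.sum, "_semsynth_patched", False):
        return

    original_sum = np.sum

    @functools.wraps(original_sum)
    def patched_sum(a, *args, **kwargs):
        # Materialise generators; every other input is passed through unchanged.
        if inspect.isgenerator(a):
            a = list(a)
        return original_sum(a, *args, **kwargs)

    patched_sum._semsynth_patched = True
    np.sum = patched_sum
```

After calling `ensure_npsum_compat()`, an expression such as `np.sum(x * x for x in range(4))` evaluates to 14, while array inputs and keyword arguments such as `axis` reach the original function untouched.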

semsynth.torch_compat.ensure_torch_rmsnorm() → None

Ensure torch.nn exposes RMSNorm on versions that predate it.

semsynth.torch_compat.ensure_trapz_compat() → None

Ensure numpy and scipy expose trapz (renamed to trapezoid in newer versions).

Some optional dependencies (e.g. xgbse, used by synthcity) import numpy.trapz and scipy.integrate.trapz, which are deprecated or missing in NumPy ≥ 2 and removed in SciPy ≥ 1.14.
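A minimal sketch, assuming that aliasing the old names to `numpy.trapezoid` and `scipy.integrate.trapezoid` is enough for the importing dependencies. It leaves any existing `trapz` untouched and treats SciPy as optional, which is an assumption rather than confirmed module behaviour.

```python
import numpy as np


def ensure_trapz_compat() -> None:
    # NumPy: alias trapz to trapezoid only when the old name is unavailable.
    if not hasattr(np, "trapz") and hasattr(np, "trapezoid"):
        np.trapz = np.trapezoid

    # SciPy is optional here; skip silently if it is not installed.
    try:
        import scipy.integrate as integrate
    except ImportError:
        return
    if not hasattr(integrate, "trapz") and hasattr(integrate, "trapezoid"):
        integrate.trapz = integrate.trapezoid
```

Because the alias is set as a module attribute, both `np.trapz(...)` and `from scipy.integrate import trapz` resolve once the function has run.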