torch_bsf package

Submodules

torch_bsf.active_learning module

torch_bsf.active_learning.suggest_next_points(models: Sequence[Module], n_suggestions: int = 1, n_candidates: int = 1000, method: Literal['qbc', 'density'] = 'qbc', params: Tensor | None = None, n_params: int | None = None) Tensor[source]

Suggest points on the simplex where new data should be sampled.

Parameters:
  • models (Sequence[nn.Module]) – An ensemble of models (e.g., from k-fold cross-validation). Each model must be callable with a tensor of shape (n_candidates, n_params) and return predictions. Accepts any Sequence of Module instances, including ModuleList.

  • n_suggestions (int, default=1) – The number of points to suggest.

  • n_candidates (int, default=1000) – The number of candidate points to evaluate.

  • method (Literal["qbc", "density"], default="qbc") – The method to use. "qbc" (Query-By-Committee) suggests the points where the models disagree most; "density" suggests the points that are furthest from the existing training points.

  • params (torch.Tensor, optional) – The existing training parameters. Required for method="density".

  • n_params (int, optional) – The number of simplex parameters (input dimension). When omitted, the value is inferred from models[0].n_params if that attribute exists. When provided and models[0] exposes an n_params attribute, the two values must agree; a ValueError is raised on mismatch.

Returns:

The suggested points in shape (n_suggestions, n_params).

Return type:

torch.Tensor
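The two selection criteria can be sketched with the standard library. This sketch makes several assumptions. Committee disagreement is measured as the variance of the predictions across models, summed over outputs. Density is measured as the distance to the nearest training point, and candidates are ranked once rather than picked greedily. Candidates are also passed in explicitly, whereas suggest_next_points draws n_candidates points itself. The real function may score candidates differently. All names below are hypothetical.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical stand-in for a trained model: maps one simplex point to a
# prediction vector. torch_bsf itself works on batched tensors instead.
Model = Callable[[Sequence[float]], Sequence[float]]


def qbc_scores(models: Sequence[Model], candidates: Sequence[Sequence[float]]) -> List[float]:
    """Committee disagreement: per-output variance across models, summed over outputs."""
    scores = []
    for t in candidates:
        preds = [list(m(t)) for m in models]
        n = len(preds)
        score = 0.0
        for j in range(len(preds[0])):
            column = [p[j] for p in preds]
            mean = sum(column) / n
            score += sum((v - mean) ** 2 for v in column) / n
        scores.append(score)
    return scores


def density_scores(candidates: Sequence[Sequence[float]], params: Sequence[Sequence[float]]) -> List[float]:
    """Euclidean distance from each candidate to its nearest training point."""
    return [min(math.dist(t, p) for p in params) for t in candidates]


def top_k(candidates: Sequence[Sequence[float]], scores: Sequence[float], k: int) -> List[List[float]]:
    """Return the k candidates with the largest scores (ties keep input order)."""
    order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return [list(candidates[i]) for i in order[:k]]


# Usage: two toy "models" that agree at the first and third vertex only.
models = [lambda t: [t[0]], lambda t: [t[0] + t[1]]]
candidates = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(top_k(candidates, qbc_scores(models, candidates), k=1))  # [[0.0, 1.0, 0.0]]
```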

torch_bsf.bezier_simplex module

class torch_bsf.bezier_simplex.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, *, _n_params: int | None = None, _degree: int | None = None, _n_values: int | None = None)[source]

Bases: LightningModule

A Bézier simplex model.

Parameters:
  • control_points – The control points of the Bézier simplex. Pass None only when reconstructing a model from a Lightning checkpoint via load_from_checkpoint() — in that case all three shape parameters (_n_params, _degree, _n_values) must be provided so that a correctly-shaped placeholder can be built before the saved state dict is loaded into it.

  • smoothness_weight – The weight of the smoothness penalty term added to the training loss. When greater than zero, adjacent control points are encouraged to have similar values. Defaults to 0.0 (no penalty).

  • _n_params – Checkpoint-reconstruction parameter; do not set manually. The number of parameters (source dimension + 1) used to build the placeholder control points when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

  • _degree – Checkpoint-reconstruction parameter; do not set manually. The degree of the Bézier simplex used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

  • _n_values – Checkpoint-reconstruction parameter; do not set manually. The number of values (target dimension) used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

Examples

>>> import lightning.pytorch as L
>>> import torch
>>> import torch_bsf
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()
configure_optimizers() Optimizer[source]
property degree: int[source]

The degree of the Bézier simplex.

forward(t: Tensor) Tensor[source]

Computes the forward pass: evaluates the Bézier simplex at each parameter vector in the minibatch.

Parameters:

t – A minibatch of parameter vectors \(\mathbf t\).

Returns:

A minibatch of value vectors.
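For reference, a Bézier simplex of degree D maps a parameter vector t to the sum over indices d with |d| = D of (D! / (d_1! ... d_M!)) t^d p_d, where p_d is the control point for d. The pure-Python sketch below evaluates that formula for a single point. It assumes forward computes the same quantity row by row over a minibatch. Storing the control points in a dict keyed by index tuples mirrors the ControlPoints example further down and is only for illustration. bezier_simplex_eval is a hypothetical name.

```python
import math
from typing import Dict, List, Sequence, Tuple

Index = Tuple[int, ...]


def bezier_simplex_eval(control_points: Dict[Index, Sequence[float]], t: Sequence[float]) -> List[float]:
    """Evaluate sum_d D!/(d_1!...d_M!) * t^d * p_d for a single parameter vector t."""
    degree = sum(next(iter(control_points)))  # every index sums to the degree D
    n_values = len(next(iter(control_points.values())))
    x = [0.0] * n_values
    for index, point in control_points.items():
        coeff = math.factorial(degree)
        for d in index:
            coeff //= math.factorial(d)  # exact: each partial quotient is an integer
        weight = coeff * math.prod(ti ** di for ti, di in zip(t, index))
        for j in range(n_values):
            x[j] += weight * point[j]
    return x


cps = {(1, 0): [0.0, 0.1, 0.2], (0, 1): [1.0, 1.1, 1.2]}
print(bezier_simplex_eval(cps, [0.25, 0.75]))  # approximately [0.75, 0.85, 0.95]
```

For degree 1 this reduces to linear interpolation between the two control points, which is an easy sanity check.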

freeze_row(index: torch_bsf.bezier_simplex.Index) None[source]

Freeze a control point so that its gradient is zeroed after every backward pass.

Parameters:

index – The index of the control point to freeze.

meshgrid(num: int = 100) tuple[Tensor, Tensor][source]

Computes a meshgrid of the Bézier simplex.

Parameters:

num – The number of grid points on each edge.

Returns:

  • ts – A parameter matrix of the mesh grid.

  • xs – A value matrix of the mesh grid.

property n_params: int[source]

The number of parameters, i.e., the source dimension + 1.

property n_values: int[source]

The number of values, i.e., the target dimension.

on_after_backward() None[source]

Zero gradients for frozen control-point rows after each backward pass.

smoothness_penalty() Tensor[source]

Computes the smoothness penalty of the Bézier simplex.

Returns:

The smoothness penalty.

test_step(batch, batch_idx) dict[str, Any][source]
training_step(batch, batch_idx) dict[str, Any][source]
validation_step(batch, batch_idx) None[source]
class torch_bsf.bezier_simplex.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')[source]

Bases: LightningDataModule

A data module for training a Bézier simplex.

Parameters:
  • params – The path to a parameter file.

  • values – The path to a value file.

  • header – The number of header rows in the parameter file and the value file. That many leading rows are skipped when each file is read.

  • batch_size – The size of each minibatch.

  • split_ratio – The fraction of the data assigned to training in the train/validation split. Must be greater than 0 and less than or equal to 1. If it is set to 1, all the data are used for training and the validation step is skipped.

  • normalize – The data normalization method. Either "max", "std", "quantile", or "none".

fit_transform(values: Tensor) Tensor[source]
inverse_transform(values: Tensor) Tensor[source]
load_data(path) Tensor[source]
load_params() Tensor[source]
load_values() Tensor[source]
setup(stage: str | None = None)[source]
test_dataloader() DataLoader[source]
train_dataloader() DataLoader[source]
val_dataloader() DataLoader[source]
torch_bsf.bezier_simplex.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) BezierSimplex[source]

Fits a Bézier simplex.

Parameters:
  • params – The parameter data on the simplex.

  • values – The label data.

  • degree – The degree of the Bézier simplex.

  • init – The initial values of a Bézier simplex or control points.

  • smoothness_weight – The weight of the smoothness penalty.

  • freeze – The indices of control points to exclude from training.

  • batch_size – The size of a minibatch.

  • seed – Random seed passed to lightning.pytorch.seed_everything() to set RNG seeds for improved reproducibility. When None (default), no seed is set. For full determinism, also set Trainer(deterministic=True) and use compatible settings.

  • kwargs – Additional keyword arguments passed to lightning.pytorch.Trainer.

Returns:

A trained Bézier simplex.

Raises:
  • TypeError – From Trainer or DataLoader.

  • MisconfigurationException – From Trainer.

Examples

>>> import torch
>>> import torch_bsf

Prepare training data

>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters

Train a model

>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)

Predict by the trained model

>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<...>)

See also

lightning.pytorch.Trainer

Argument descriptions.

torch.utils.data.DataLoader

Argument descriptions.

torch_bsf.bezier_simplex.fit_kfold(params: Tensor, values: Tensor, n_folds: int = 5, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) ModuleList[source]

Fits an ensemble of Bézier simplices using k-fold cross-validation.

Splits the training data into n_folds folds via KFoldTrainer and trains one model per fold on the data from the remaining n_folds - 1 folds. The resulting ModuleList can be passed directly to torch_bsf.active_learning.suggest_next_points() to drive a Query-By-Committee active learning loop.

If len(params) < n_folds, the actual number of folds is capped at len(params) to avoid empty training subsets.
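The cap on the fold count can be illustrated with a stdlib sketch of plain k-fold index splitting. It assumes contiguous, unshuffled folds, the same layout as scikit-learn's KFold with shuffle=False. KFoldTrainer's actual assignment, shuffling, and stratification options may differ. Each ensemble member would be trained on one train list. kfold_indices is a hypothetical helper.

```python
from typing import List, Tuple

Split = Tuple[List[int], List[int]]


def kfold_indices(n_samples: int, n_folds: int) -> List[Split]:
    """Contiguous, unshuffled k-fold splits with the fold count capped at n_samples."""
    if n_folds < 2:
        raise ValueError("n_folds must be at least 2")
    if n_samples < 2:
        raise ValueError("at least 2 samples are required")
    k = min(n_folds, n_samples)  # never create an empty held-out fold
    base, extra = divmod(n_samples, k)
    splits: List[Split] = []
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)  # the first `extra` folds get one more sample
        held_out = list(range(start, start + size))
        train = [i for i in range(n_samples) if not start <= i < start + size]
        splits.append((train, held_out))
        start += size
    return splits


for train, held_out in kfold_indices(n_samples=3, n_folds=5):
    print(train, held_out)  # 3 folds (capped from 5), one held-out sample each
```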

Parameters:
  • params – The parameter data on the simplex.

  • values – The label data.

  • n_folds – The number of cross-validation folds (committee size). Defaults to 5.

  • degree – The degree of the Bézier simplex.

  • init – The initial values of a Bézier simplex or control points.

  • smoothness_weight – The weight of the smoothness penalty.

  • freeze – The indices of control points to exclude from training.

  • batch_size – The size of a minibatch. Defaults to full-batch (consistent with fit()).

  • seed – Random seed passed to lightning.pytorch.seed_everything() for reproducible training. When None (default), no seed is set.

  • kwargs – All arguments for KFoldTrainer (which itself accepts all lightning.pytorch.Trainer arguments). For example, max_epochs=10, enable_progress_bar=False, logger=False, shuffle=True, or stratified=False. By default, num_sanity_val_steps=0 and limit_val_batches=0.0 are set internally to disable per-fold validation for speed; pass explicit values to override these defaults, e.g. num_sanity_val_steps=2, limit_val_batches=1.0.

Returns:

A ModuleList of min(n_folds, len(params)) trained BezierSimplex models, one per fold.

Return type:

torch.nn.ModuleList

Raises:

ValueError – If n_folds < 2, if len(params) < 2 (too few samples for any fold split), if both or neither of degree and init are provided, if batch_size is truthy but not a positive integer, or if the reserved argument num_folds is supplied via **kwargs.

Examples

>>> import torch
>>> import torch_bsf
>>> from torch_bsf.active_learning import suggest_next_points
>>> from torch_bsf.sampling import simplex_grid

Prepare training data

>>> params = simplex_grid(n_params=3, degree=3)
>>> values = params.pow(2).sum(dim=1, keepdim=True)

Build a 5-fold ensemble and suggest the 2 most uncertain points

>>> models = torch_bsf.fit_kfold(
...     params=params,
...     values=values,
...     degree=3,
...     max_epochs=1,
...     enable_progress_bar=False,
...     enable_model_summary=False,
...     logger=False,
... )
>>> suggestions = suggest_next_points(models, n_suggestions=2, method="qbc")
>>> suggestions.shape
torch.Size([2, 3])

See also

fit

Fit a single Bézier simplex.

torch_bsf.active_learning.suggest_next_points

Use the ensemble for active learning.

torch_bsf.bezier_simplex.load(path: str | Path, *, pt_weights_only: bool | None = None) BezierSimplex[source]

Loads a Bézier simplex from a file.

Parameters:
  • path – The path to a file.

  • pt_weights_only – Whether to load weights only. This parameter is only effective when loading PyTorch (.pt) files. For other formats (e.g., .json, .yml), data loading is inherently safe and this parameter is ignored. If None, it defaults to False.

Returns:

A Bézier simplex.

Raises:
  • ValueError – If the file type is unknown.

  • ValidationError – If the control points are invalid.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.load("tests/data/bezier_simplex.csv")
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
torch_bsf.bezier_simplex.monomial(variable: Iterable[float], degree: Iterable[int]) Tensor[source]

Computes a monomial \(\mathbf t^{\mathbf d} = t_1^{d_1} t_2^{d_2}\cdots t_M^{d_M}\).

Parameters:
  • variable – The bases \(\mathbf t\).

  • degree – The powers \(\mathbf d\).

Returns:

The monomial \(\mathbf t^{\mathbf d}\).

torch_bsf.bezier_simplex.polynom(degree: int, index: Iterable[int]) float[source]

Computes the multinomial coefficient \(\binom{D}{\mathbf d} = \frac{D!}{d_1!d_2!\cdots d_M!}\).

Parameters:
  • degree – The degree \(D\).

  • index – The index \(\mathbf d\).

Returns:

The multinomial coefficient \(\binom{D}{\mathbf d}\).
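The two helpers above can be re-implemented directly from their formulas. The sketch below is written from those formulas, not from the library source; the torch_bsf versions return a tensor and a float and may handle input types differently. The usage check relies on the multinomial theorem: summing coefficient times monomial over all indices with |d| = D gives (t_1 + ... + t_M)^D, which equals 1 on the simplex. This is why these products serve as weights that sum to one.

```python
import math
from itertools import product
from typing import Iterable


def monomial_py(variable: Iterable[float], degree: Iterable[int]) -> float:
    """t_1^{d_1} * t_2^{d_2} * ... * t_M^{d_M}."""
    return math.prod(t ** d for t, d in zip(variable, degree))


def multinomial_coeff(degree: int, index: Iterable[int]) -> float:
    """D! / (d_1! d_2! ... d_M!)."""
    return math.factorial(degree) / math.prod(math.factorial(d) for d in index)


# All indices d with d_1 + d_2 + d_3 = D, then the partition-of-unity check.
t, D = [0.2, 0.3, 0.5], 3
indices = [d for d in product(range(D + 1), repeat=len(t)) if sum(d) == D]
total = sum(multinomial_coeff(D, d) * monomial_py(t, d) for d in indices)
print(multinomial_coeff(3, (1, 1, 1)), round(total, 12))  # 6.0 1.0
```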

torch_bsf.bezier_simplex.rand(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]

Generates a random Bézier simplex.

The control points are initialized with random values drawn uniformly from [0, 1).

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bézier simplex.

  • smoothness_weight – The weight of the smoothness penalty.

Returns:

A random Bézier simplex.

Raises:

ValueError – If n_params, n_values, or degree is negative.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.rand(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
torch_bsf.bezier_simplex.randn(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]

Generates a random Bézier simplex.

The control points are initialized with random values drawn from the standard normal distribution (mean 0, standard deviation 1).

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bézier simplex.

  • smoothness_weight – The weight of the smoothness penalty.

Returns:

A random Bézier simplex.

Raises:

ValueError – If n_params, n_values, or degree is negative.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
torch_bsf.bezier_simplex.save(path: str | Path, data: BezierSimplex) None[source]

Saves a Bézier simplex to a file.

Parameters:
  • path – The file path to save.

  • data – The Bézier simplex to save.

Raises:

ValueError – If the file type is unknown.

Examples

>>> import torch_bsf
>>> bs = torch_bsf.bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.pt", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.csv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.tsv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.json", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.yml", bs)
torch_bsf.bezier_simplex.validate_control_points(data: dict[str, list[float]])[source]

Validates control points.

Parameters:

data – The control points.

Raises:

ValidationError – If the control points are invalid.

Examples

>>> from torch_bsf.bezier_simplex import validate_control_points
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
... })
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "0, 0, 1": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: '0, 0, 1' is not valid under any of the given schemas
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
...     "(0, 0)": [0.0, 0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1, 0)": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0, 1, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0, 1.0, 0.0]
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0]
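The failures in the examples above come down to three checks:

  • each key must be a parenthesized tuple of integers;
  • all index tuples must have the same length;
  • all value lists must have the same length.

The stdlib sketch below implements those checks. check_control_points is a hypothetical name, and it raises ValueError rather than jsonschema's ValidationError. The real validator may check more than this.

```python
import re
from typing import Dict, List, Optional

# "(1, 0, 0)"-style keys: parenthesized, comma-separated non-negative integers.
_KEY_PATTERN = re.compile(r"^\(\s*\d+(\s*,\s*\d+)*\s*\)$")


def check_control_points(data: Dict[str, List[float]]) -> None:
    index_dim: Optional[int] = None
    value_dim: Optional[int] = None
    for key, value in data.items():
        if not _KEY_PATTERN.match(key):
            raise ValueError(f"{key!r} is not a valid index key")
        index = tuple(int(part) for part in key.strip("()").split(","))
        if index_dim is None:
            # The first entry fixes the expected index and value lengths.
            index_dim, value_dim = len(index), len(value)
        if len(index) != index_dim:
            raise ValueError(f"Dimension mismatch: {index}")
        if len(value) != value_dim:
            raise ValueError(f"Dimension mismatch: {value}")


try:
    check_control_points({"(1, 0, 0)": [1.0, 0.0], "(0, 0)": [0.0, 1.0]})
except ValueError as err:
    print(err)  # Dimension mismatch: (0, 0)
```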
torch_bsf.bezier_simplex.zeros(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]

Generates a Bézier simplex whose control points are all at the origin.

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bézier simplex.

  • smoothness_weight – The weight of the smoothness penalty.

Returns:

A Bézier simplex filled with zeros.

Raises:

ValueError – If n_params, n_values, or degree is negative.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.zeros(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)

torch_bsf.control_points module

class torch_bsf.control_points.ControlPoints(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None)[source]

Bases: Module

Control points of a Bézier simplex stored as a single parameter matrix.

All control points are stored in one nn.Parameter matrix of shape (n_indices, n_values). This eliminates the Python-level loop and torch.stack call that would otherwise occur on every forward pass, giving a direct O(1) access path to the parameter data.

matrix[source]

nn.Parameter of shape (n_indices, n_values) holding all control points in canonical simplex-index order.

degree[source]

The degree of the Bézier simplex.

n_params[source]

The number of parameters (source dimension + 1).

n_values[source]

The number of values (target dimension).

Examples

>>> import torch_bsf
>>> control_points = torch_bsf.control_points.ControlPoints({
...     (1, 0): [0.0, 0.1, 0.2],
...     (0, 1): [1.0, 1.1, 1.2],
... })
>>> control_points.degree
1
>>> control_points.n_params
2
>>> control_points.n_values
3
>>> control_points[(1, 0)]
tensor([0.0000, 0.1000, 0.2000], grad_fn=<SelectBackward0>)
>>> control_points[(0, 1)]
tensor([1.0000, 1.1000, 1.2000], grad_fn=<SelectBackward0>)
extra_repr() str[source]
indices() Iterator[tuple[int, ...]][source]

Iterates the index of control points of the Bézier simplex.

Returns:

The indices in canonical order.

items() Iterator[tuple[str, Tensor]][source]

Iterates (str_key, value_tensor) pairs in canonical order.

Returns:

An iterator of (key, value) pairs.

keys() Iterator[str][source]

Iterates canonical string keys in canonical order.

torch_bsf.control_points.ControlPointsData: TypeAlias = dict[str, torch.Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], torch.Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[torch.Tensor, torch.Tensor] | dict[torch.Tensor, list[float]] | dict[torch.Tensor, tuple[float, ...]][source]

The data type of control points of a Bézier simplex.

torch_bsf.control_points.Index: TypeAlias = str | typing.Sequence[int] | torch.Tensor[source]

The index type of control points of a Bézier simplex.

torch_bsf.control_points.Value: TypeAlias = typing.Sequence[float] | torch.Tensor[source]

The value type of control points of a Bézier simplex.

torch_bsf.control_points.simplex_indices(n_params: int, degree: int) Iterable[tuple[int, ...]][source]

Iterates the index of control points of a Bézier simplex.

Parameters:
  • n_params – The tuple length of each index.

  • degree – The degree of the Bézier simplex.

Returns:

The indices.
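Below is a recursive sketch of the enumeration. It assumes each index is a tuple of n_params non-negative integers that sum to degree. The indices are yielded in descending lexicographic order, which matches the row order of the parameter matrix ts in the BezierSimplex example above. Whether this is exactly the library's canonical order is an assumption. enumerate_simplex_indices is a hypothetical name.

```python
from math import comb
from typing import Iterator, Tuple


def enumerate_simplex_indices(n_params: int, degree: int) -> Iterator[Tuple[int, ...]]:
    """Yield every n_params-tuple of non-negative ints that sums to degree."""
    if n_params < 1 or degree < 0:
        raise ValueError("n_params must be >= 1 and degree must be >= 0")
    if n_params == 1:
        yield (degree,)
        return
    # Fix the first entry (largest first), then recurse on the remaining budget.
    for first in range(degree, -1, -1):
        for rest in enumerate_simplex_indices(n_params - 1, degree - first):
            yield (first, *rest)


indices = list(enumerate_simplex_indices(n_params=3, degree=3))
print(indices[:4], len(indices) == comb(3 + 3 - 1, 3 - 1))
# [(3, 0, 0), (2, 1, 0), (2, 0, 1), (1, 2, 0)] True
```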

torch_bsf.control_points.to_parameterdict(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]]) dict[str, Tensor][source]

Convert data to a dictionary with canonical string keys.

Parameters:

data (dict) – Data to be converted.

Returns:

Converted data.

Return type:

dict

torch_bsf.control_points.to_parameterdict_key(index: str | Sequence[int] | Tensor) str[source]

Convert an index to a canonical string key.

Parameters:

index (str, Sequence[int], or torch.Tensor) – An index of control points.

Returns:

A canonical string key.

Return type:

str
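The "(1, 0, 0)" spelling used in the validate_control_points examples suggests what a canonical key looks like. The hypothetical to_key sketch below normalizes strings and integer sequences to that spelling. Tensor inputs are omitted. The documentation does not say how single-element indices are formatted, for example "(3)" versus "(3,)", so that case is a guess here.

```python
from typing import Sequence, Union


def to_key(index: Union[str, Sequence[int]]) -> str:
    """Normalize "(1,0,0)", "(1, 0, 0)", (1, 0, 0) or [1, 0, 0] to "(1, 0, 0)"."""
    if isinstance(index, str):
        parts = index.strip().strip("()").split(",")
        index = [int(p) for p in parts if p.strip()]  # int() tolerates surrounding spaces
    return "(" + ", ".join(str(int(i)) for i in index) + ")"


print(to_key([1, 0, 0]), to_key("(1,0,0)"))  # (1, 0, 0) (1, 0, 0)
```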

torch_bsf.control_points.to_parameterdict_value(value: Sequence[float] | Tensor) Tensor[source]

Convert a value to a tensor.

Parameters:

value (Sequence[float] or torch.Tensor) – A value.

Returns:

A tensor.

Return type:

torch.Tensor

torch_bsf.plotting module

torch_bsf.plotting.plot_bezier_simplex(model: BezierSimplex, num: int = 100, ax=None, show_control_points: bool = True, *, max_control_points: int = 500, max_pairwise_points: int = 2000, **kwargs)[source]

Plots the Bézier simplex.

Parameters:
  • model (BezierSimplex) – The Bézier simplex model to plot.

  • num (int) – The number of grid points for each edge. For model.n_params >= 4 this value is used only to decide whether to use a full meshgrid or random sampling: if the combinatorial meshgrid size comb(num + n_params - 1, n_params - 1) exceeds max_pairwise_points, exactly max_pairwise_points uniformly random simplex samples are drawn instead and num no longer controls the sample count.

  • ax (matplotlib.axes.Axes or None) – The matplotlib axes to plot on. If None, a new figure is created. Ignored when model.n_params >= 4; for pairwise plots a new figure is created only when model.n_values > 0 (when model.n_values == 0 an empty (0, 0) ndarray is returned without creating a figure).

  • show_control_points (bool) – Whether to show control points.

  • max_control_points (int) – Maximum number of control points to render in pairwise plots (model.n_params >= 4). When the model has more control points than this limit, a random subset of max_control_points is drawn instead to avoid combinatorial slowdowns. Defaults to 500. Ignored when show_control_points is False or model.n_params < 4. Must be a non-negative integer.

  • max_pairwise_points (int) – Maximum number of sample points for the pairwise scatter plot (model.n_params >= 4). When the combinatorial meshgrid size would exceed this limit, max_pairwise_points uniformly random simplex samples are drawn instead to bound memory usage and plot time. Defaults to 2000. Ignored when model.n_params < 4. Must be a non-negative integer.

  • **kwargs – Additional keyword arguments forwarded to the plot call. For model.n_params == 2, forwarded to ax.plot (curve). For model.n_params == 3 and model.n_values >= 3, forwarded to ax.plot_trisurf (3D surface). For model.n_params == 3 and model.n_values == 2, ignored. For model.n_params >= 4, forwarded to ax.scatter (pairwise).

Returns:

The axes containing the plot. For model.n_params <= 3 a single Axes (or Axes3D) is returned. For model.n_params >= 4 a 2-D numpy.ndarray of Axes with shape (n_values, n_values) is returned (pairwise scatter plot).

Return type:

matplotlib.axes.Axes or mpl_toolkits.mplot3d.axes3d.Axes3D or numpy.ndarray

Raises:
  • ImportError – If matplotlib is not installed. This dependency is required for all plotting backends used by this function.

  • ImportError – If SciPy is not installed and model.n_params == 3. SciPy is required for the triangulation-based plotting used in the Bézier triangle case.

  • ValueError – If model.n_params < 2. This function only supports Bézier simplex models with at least two parameters.
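For n_params >= 4, the interplay between num and max_pairwise_points described above comes down to a single comparison. The sketch below restates it; pairwise_sample_plan is a hypothetical helper, not part of torch_bsf.plotting. With the default num=100, a 4-parameter model already exceeds the default budget, since comb(103, 3) = 176851.

```python
from math import comb
from typing import Tuple


def pairwise_sample_plan(num: int, n_params: int, max_pairwise_points: int = 2000) -> Tuple[str, int]:
    """Return ("meshgrid", size) or ("random", max_pairwise_points) for n_params >= 4."""
    if max_pairwise_points < 0:
        raise ValueError("max_pairwise_points must be a non-negative integer")
    grid_size = comb(num + n_params - 1, n_params - 1)
    if grid_size > max_pairwise_points:
        return "random", max_pairwise_points  # num no longer controls the sample count
    return "meshgrid", grid_size


print(pairwise_sample_plan(num=100, n_params=4))  # ('random', 2000)
print(pairwise_sample_plan(num=10, n_params=4))   # ('meshgrid', 286)
```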

torch_bsf.preprocessing module

class torch_bsf.preprocessing.MinMaxScaler[source]

Bases: object

Min-max scaler that normalizes values to the [0, 1] range.

mins[source]

Minimum values per feature dimension.

Type:

torch.Tensor

scales[source]

Scale factors (max - min) per feature dimension. Dimensions where max equals min are set to 1 to avoid division by zero.

Type:

torch.Tensor

fit(values: Tensor) None[source]

Fit the scaler to the data.

Computes the per-feature minimum and scale (max - min) from values.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

fit_transform(values: Tensor) Tensor[source]

Fit the scaler to the data and return the normalized tensor.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

Returns:

Normalized tensor of the same shape as values, with each feature scaled to the [0, 1] range.

Return type:

torch.Tensor

inverse_transform(values: Tensor) Tensor[source]

Reverse the normalization applied by fit_transform().

Parameters:

values (torch.Tensor) – Normalized tensor of shape (n_samples, n_features).

Returns:

Tensor rescaled back to the original value range.

Return type:

torch.Tensor

mins: Tensor[source]
scales: Tensor[source]
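The fit/transform arithmetic described above, sketched with plain Python lists in place of tensors. MinMaxScalerSketch is a hypothetical name, and the real class works column-wise on torch tensors.

```python
from typing import List, Sequence


class MinMaxScalerSketch:
    """Pure-Python min-max scaler over rows of equal length (illustrative only)."""

    def fit(self, values: Sequence[Sequence[float]]) -> None:
        columns = list(zip(*values))
        self.mins = [min(c) for c in columns]
        # Guard constant features: a zero range would divide by zero.
        self.scales = [(max(c) - min(c)) if max(c) != min(c) else 1.0 for c in columns]

    def fit_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        self.fit(values)
        return [[(v - m) / s for v, m, s in zip(row, self.mins, self.scales)] for row in values]

    def inverse_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        return [[v * s + m for v, m, s in zip(row, self.mins, self.scales)] for row in values]


scaler = MinMaxScalerSketch()
print(scaler.fit_transform([[1.0, 10.0], [3.0, 10.0], [2.0, 10.0]]))
# [[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]]
```

The second column is constant, so its scale falls back to 1 and every entry maps to 0.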
class torch_bsf.preprocessing.NoneScaler[source]

Bases: object

Pass-through scaler that leaves values unchanged.

Useful as a no-op placeholder when normalization is not required, while still providing the same fit(), fit_transform(), and inverse_transform() interface as the other scalers.

fit(values: Tensor) None[source]

No-op fit method included for API compatibility.

Parameters:

values (torch.Tensor) – Input tensor (ignored).

fit_transform(values: Tensor) Tensor[source]

Return values unchanged.

Parameters:

values (torch.Tensor) – Input tensor of any shape.

Returns:

The same tensor as values without any modification.

Return type:

torch.Tensor

inverse_transform(values: Tensor) Tensor[source]

Return values unchanged.

Parameters:

values (torch.Tensor) – Input tensor of any shape.

Returns:

The same tensor as values without any modification.

Return type:

torch.Tensor

class torch_bsf.preprocessing.QuantileScaler[source]

Bases: object

Quantile-based scaler that normalizes values using percentile-based ranges.

Values are scaled so that the q-quantile maps to 0 and the (1 - q)-quantile maps to 1; values outside this range are not clipped. This makes the scaler more robust to outliers than MinMaxScaler.

q[source]

The lower quantile fraction used as the effective minimum. Defaults to 0.05 (the 5th percentile), so the bottom and top 5% of values do not affect the fitted range; those values are still scaled, not discarded.

Type:

float

mins[source]

Per-feature values at the q-quantile, computed during fit().

Type:

torch.Tensor

scales[source]

Per-feature scale factors ((1-q)-quantile minus q-quantile). Dimensions where the scale is zero are set to 1 to avoid division by zero.

Type:

torch.Tensor

fit(values: Tensor) None[source]

Fit the scaler to the data.

Computes per-feature quantile bounds from values.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

fit_transform(values: Tensor) Tensor[source]

Fit the scaler to the data and return the scaled tensor.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

Returns:

Scaled tensor of the same shape as values. Values outside the fitted quantile range may exceed [0, 1].

Return type:

torch.Tensor

inverse_transform(values: Tensor) Tensor[source]

Reverse the scaling applied by fit_transform().

Parameters:

values (torch.Tensor) – Scaled tensor of shape (n_samples, n_features).

Returns:

Tensor rescaled back to the original value range.

Return type:

torch.Tensor

mins: Tensor[source]
q: float = 0.05[source]
scales: Tensor[source]
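Below is a list-based sketch of the quantile scaling. It assumes the quantiles use linear interpolation between order statistics, which is the default of torch.quantile; the documentation does not state the library's exact interpolation rule. QuantileScalerSketch is a hypothetical name. The usage shows that values beyond the (1 - q)-quantile map above 1 instead of being clipped.

```python
import math
from typing import List, Sequence


def _quantile(sorted_column: Sequence[float], q: float) -> float:
    """Linear interpolation between the two order statistics around q * (n - 1)."""
    pos = q * (len(sorted_column) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return sorted_column[lo] + (sorted_column[hi] - sorted_column[lo]) * (pos - lo)


class QuantileScalerSketch:
    def __init__(self, q: float = 0.05) -> None:
        self.q = q

    def fit(self, values: Sequence[Sequence[float]]) -> None:
        columns = [sorted(c) for c in zip(*values)]
        self.mins = [_quantile(c, self.q) for c in columns]
        uppers = [_quantile(c, 1 - self.q) for c in columns]
        # A zero-width quantile range would divide by zero, so use 1 instead.
        self.scales = [u - m if u != m else 1.0 for u, m in zip(uppers, self.mins)]

    def fit_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        self.fit(values)
        return [[(v - m) / s for v, m, s in zip(row, self.mins, self.scales)] for row in values]

    def inverse_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        return [[v * s + m for v, m, s in zip(row, self.mins, self.scales)] for row in values]


data = [[float(i)] for i in range(101)]  # 0, 1, ..., 100
scaled = QuantileScalerSketch().fit_transform(data)
print(round(scaled[5][0], 6), round(scaled[95][0], 6), round(scaled[100][0], 6))  # 0.0 1.0 1.055556
```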
class torch_bsf.preprocessing.Scaler(*args, **kwargs)[source]

Bases: Protocol

Protocol for data scalers with a fit/transform interface.

All concrete scaler classes (MinMaxScaler, StdScaler, QuantileScaler, NoneScaler) satisfy this protocol via structural subtyping — no explicit inheritance is required.

fit(values: Tensor) None[source]

Fit the scaler to values.

fit_transform(values: Tensor) Tensor[source]

Fit the scaler to values and return the transformed tensor.

inverse_transform(values: Tensor) Tensor[source]

Reverse the transformation applied by fit_transform().
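Structural subtyping means that any object with matching methods is accepted wherever the protocol is expected, with no inheritance required. The sketch below shows this with typing.Protocol. ScalerLike and IdentityScaler are hypothetical names. @runtime_checkable is added here so that isinstance works; the library's Scaler may not be decorated this way.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class ScalerLike(Protocol):
    def fit(self, values: Any) -> None: ...
    def fit_transform(self, values: Any) -> Any: ...
    def inverse_transform(self, values: Any) -> Any: ...


class IdentityScaler:
    """Satisfies ScalerLike without inheriting from it."""

    def fit(self, values: Any) -> None:
        pass

    def fit_transform(self, values: Any) -> Any:
        return values

    def inverse_transform(self, values: Any) -> Any:
        return values


def round_trip(scaler: ScalerLike, values: List[List[float]]) -> Any:
    # Any object providing the three methods is accepted here.
    return scaler.inverse_transform(scaler.fit_transform(values))


print(isinstance(IdentityScaler(), ScalerLike), round_trip(IdentityScaler(), [[1.0, 2.0]]))
# True [[1.0, 2.0]]
```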

class torch_bsf.preprocessing.StdScaler[source]

Bases: object

Standard-score (z-score) scaler that normalizes values to zero mean and unit variance.

means[source]

Per-feature means computed during fit().

Type:

torch.Tensor

stds[source]

Per-feature standard deviations computed during fit(). Dimensions with zero standard deviation are set to 1 to avoid division by zero.

Type:

torch.Tensor

fit(values: Tensor) None[source]

Fit the scaler to the data.

Computes the per-feature mean and standard deviation from values.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

fit_transform(values: Tensor) Tensor[source]

Fit the scaler to the data and return the standardized tensor.

Parameters:

values (torch.Tensor) – Input tensor of shape (n_samples, n_features).

Returns:

Standardized tensor of the same shape as values, with each feature having zero mean and unit standard deviation.

Return type:

torch.Tensor

inverse_transform(values: Tensor) Tensor[source]

Reverse the standardization applied by fit_transform().

Parameters:

values (torch.Tensor) – Standardized tensor of shape (n_samples, n_features).

Returns:

Tensor rescaled back to the original value range.

Return type:

torch.Tensor

means: Tensor[source]
stds: Tensor[source]
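Below is a list-based sketch of z-score scaling. It assumes the sample standard deviation (denominator n - 1), which matches torch.std's default correction; the library might use the population standard deviation instead. StdScalerSketch is a hypothetical name.

```python
import statistics
from typing import List, Sequence


class StdScalerSketch:
    def fit(self, values: Sequence[Sequence[float]]) -> None:
        columns = list(zip(*values))
        self.means = [statistics.mean(c) for c in columns]
        # Sample standard deviation (n - 1 denominator); constant columns get 1.
        stds = [statistics.stdev(c) for c in columns]
        self.stds = [s if s != 0 else 1.0 for s in stds]

    def fit_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        self.fit(values)
        return [[(v - m) / s for v, m, s in zip(row, self.means, self.stds)] for row in values]

    def inverse_transform(self, values: Sequence[Sequence[float]]) -> List[List[float]]:
        return [[v * s + m for v, m, s in zip(row, self.means, self.stds)] for row in values]


scaler = StdScalerSketch()
z = scaler.fit_transform([[1.0, 5.0], [3.0, 5.0]])
print(z)  # first column approx. -0.707 and 0.707; constant second column maps to 0.0
```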

torch_bsf.sampling module

torch_bsf.sampling.simplex_grid(n_params: int, degree: int) Tensor[source]

Generates a uniform grid on a simplex.

Parameters:
  • n_params (int) – The number of parameters (vertices of the simplex).

  • degree (int) – The number of equal subdivisions of each edge; every grid coordinate is a multiple of 1/degree.

Returns:

Array of grid points in shape (N, n_params), where N = binom(degree + n_params - 1, n_params - 1).

Return type:

torch.Tensor
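The count N = binom(degree + n_params - 1, n_params - 1) is the number of ways to split degree equal units among n_params coordinates ("stars and bars"). The following minimal stdlib sketch enumerates exactly those points. It assumes that the grid consists of all points whose coordinates are multiples of 1/degree, and the row order it produces may differ from simplex_grid's.

```python
from __future__ import annotations

from itertools import combinations
from math import comb


def simplex_grid_sketch(n_params: int, degree: int) -> list[list[float]]:
    """All points with coordinates k/degree (k = 0..degree) that sum to 1."""
    if n_params < 1 or degree < 1:
        raise ValueError("n_params and degree must be positive in this sketch")
    slots = degree + n_params - 1
    points = []
    # Stars and bars: choosing positions for n_params - 1 bars among `slots`
    # positions splits `degree` stars into n_params non-negative counts.
    for bars in combinations(range(slots), n_params - 1):
        counts, prev = [], -1
        for b in bars:
            counts.append(b - prev - 1)
            prev = b
        counts.append(slots - prev - 1)
        points.append([c / degree for c in counts])
    return points


grid = simplex_grid_sketch(3, 3)
print(len(grid), comb(3 + 3 - 1, 3 - 1))  # both equal N
```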

torch_bsf.sampling.simplex_random(n_params: int, n_samples: int, seed: int | None = None) Tensor[source]

Generates random points on a simplex using a Dirichlet distribution.

Parameters:
  • n_params (int) – The number of parameters (vertices of the simplex).

  • n_samples (int) – The number of samples.

  • seed (int or None, optional) – Random seed for reproducibility. When provided, a new numpy.random.Generator is created with this seed so the global NumPy random state is not affected. When None (default), the global numpy.random state is used.

Returns:

Array of sample points in shape (n_samples, n_params).

Return type:

torch.Tensor

Raises:

ValueError – If n_params is not positive or n_samples is negative.

torch_bsf.sampling.simplex_sobol(n_params: int, n_samples: int, seed: int | None = None) Tensor[source]

Generates quasi-random points on a simplex using a Sobol sequence.

Uses a scrambled Sobol sequence projected onto the simplex via the sorted-differences mapping. Sobol sequences are low-discrepancy: they fill space more uniformly than pseudo-random draws, so the error of sample averages (numerical integration) converges at roughly O((log N)^(d-1) / N) instead of the O(1/sqrt(N)) rate of Monte Carlo sampling (where d = n_params - 1 is the dimension of the Sobol sequence).
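The sorted-differences mapping sorts the n_params - 1 coordinates of a unit-cube point, brackets them with 0 and 1, and takes the consecutive gaps as barycentric coordinates. The following minimal stdlib sketch shows only this mapping. The function name is hypothetical, and simplex_sobol feeds it scrambled Sobol points rather than the hand-picked point shown here.

```python
def sorted_differences(point: list[float]) -> list[float]:
    """Map a point of the unit cube [0, 1]^(n_params - 1) onto the simplex."""
    cuts = sorted(point)
    edges = [0.0, *cuts, 1.0]
    # The n_params gaps between consecutive edges are non-negative and sum to 1.
    return [hi - lo for lo, hi in zip(edges, edges[1:])]


print(sorted_differences([0.7, 0.2]))  # approximately [0.2, 0.5, 0.3]
```

With i.i.d. uniform inputs this mapping produces uniformly distributed points on the simplex. With low-discrepancy inputs it carries their even coverage over to the simplex.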

Note

Power-of-two sample sizes are strongly recommended. Sobol sequences are constructed in base 2 and achieve their best uniformity guarantees when n_samples is an exact power of 2 (e.g. 64, 128, 256, …). When n_samples is not a power of 2, the strongest low-discrepancy guarantees no longer apply and the coverage of the simplex can be somewhat less uniform. A UserWarning is emitted automatically when a non-power-of-two value is requested.
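A power-of-two test is a one-line bit trick, and the warning behaviour described above can be sketched as follows. The helper names are hypothetical, and skipping the warning for n_samples = 0 is an assumption of this sketch, not documented behaviour.

```python
import warnings


def is_power_of_two(n: int) -> bool:
    # A positive power of two has exactly one bit set; n & (n - 1) clears it.
    return n > 0 and (n & (n - 1)) == 0


def check_sobol_size(n_samples: int) -> None:
    if n_samples < 0:
        raise ValueError("n_samples must be non-negative")
    if n_samples > 0 and not is_power_of_two(n_samples):
        warnings.warn(
            f"n_samples={n_samples} is not a power of 2; "
            "low-discrepancy coverage is weakened",
            UserWarning,
        )
```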

Note

scipy is required. This function relies on scipy.stats.qmc.Sobol. Install it with pip install scipy or pip install pytorch-bsf[sampling].

Parameters:
  • n_params (int) – The number of parameters (vertices of the simplex). Must be at least 2. The Sobol sequence is drawn in n_params - 1 dimensions and then mapped to the simplex.

  • n_samples (int) – The number of samples. For best coverage, use a power of 2 (e.g. 64, 128, 256).

  • seed (int or None, optional) – Random seed for the scrambled Sobol sequence, passed directly to scipy.stats.qmc.Sobol as its seed argument. When None (default), a random scramble is used on each call. Pass an integer for reproducible sequences.

Returns:

Array of sample points in shape (n_samples, n_params). Each row is non-negative and sums to 1.

Return type:

torch.Tensor

Raises:
  • ImportError – If SciPy is not installed.

  • ValueError – If n_params is less than 2 or n_samples is negative.

Warns:

UserWarning – If n_samples is not a power of 2. The samples are still returned, but the low-discrepancy coverage guarantee is weakened.

Examples

>>> import torch
>>> from torch_bsf.sampling import simplex_sobol
>>> pts = simplex_sobol(n_params=3, n_samples=128)
>>> pts.shape
torch.Size([128, 3])
>>> pts.sum(dim=1).allclose(torch.ones(128))
True

torch_bsf.sklearn module

class torch_bsf.sklearn.BezierSimplexRegressor(degree: int = 3, smoothness_weight: float = 0.0, init: BezierSimplex | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, max_epochs: int = 1000, accelerator: str = 'auto', devices: int | str = 'auto', precision: str = '32-true', trainer_kwargs: dict[str, Any] | None = None)[source]

Bases: BaseEstimator, RegressorMixin

Scikit-learn wrapper for Bézier Simplex Fitting.

Parameters:
  • degree (int, default=3) – The degree of the Bézier simplex.

  • smoothness_weight (float, default=0.0) – The weight of the smoothness penalty.

  • init (BezierSimplex | ControlPointsData | None, default=None) – Initial control points or model.

  • freeze (Iterable[Index] | None, default=None) – Indices of control points to freeze during training.

  • batch_size (int | None, default=None) – Size of minibatches.

  • max_epochs (int, default=1000) – Maximum number of epochs to train.

  • accelerator (str, default="auto") – Hardware accelerator to use (“cpu”, “gpu”, “auto”, etc.).

  • devices (int | str, default="auto") – Number of devices to use.

  • precision (str, default="32-true") – Floating point precision.

  • trainer_kwargs (dict | None, default=None) – Additional keyword arguments for lightning.pytorch.Trainer.

fit(X: Any, y: Any)[source]

Fit the Bézier simplex model.

Parameters:
  • X (array-like of shape (n_samples, n_params)) – Training data (parameters on a simplex).

  • y (array-like of shape (n_samples, n_values)) – Target values.

Returns:

self – Fitted estimator.

Return type:

object

predict(X: Any) ndarray[source]

Predict using the Bézier simplex model.

Parameters:

X (array-like of shape (n_samples, n_params)) – Input parameters.

Returns:

y_pred – Predicted values.

Return type:

ndarray of shape (n_samples, n_values)

score(X: Any, y: Any, sample_weight: Any = None) float[source]

Return the coefficient of determination R^2 of the prediction.

Parameters:
  • X (array-like of shape (n_samples, n_params)) – Test samples.

  • y (array-like of shape (n_samples, n_values)) – True values.

  • sample_weight (array-like of shape (n_samples,), default=None) – Sample weights.

Returns:

score – R^2 of self.predict(X) with respect to y.

Return type:

float
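For multi-output targets, R^2 is computed per output column as 1 - SS_res / SS_tot and then averaged. The following stdlib sketch assumes standard scikit-learn semantics (uniform averaging over outputs, optional sample weights, and 1.0 or 0.0 for a constant column). It does not claim that BezierSimplexRegressor.score is implemented this way.

```python
def r2_score_sketch(y_true, y_pred, sample_weight=None) -> float:
    """R^2 per output column, averaged uniformly over the columns."""
    w = [1.0] * len(y_true) if sample_weight is None else list(sample_weight)
    total_w = sum(w)
    scores = []
    for col in range(len(y_true[0])):
        yt = [row[col] for row in y_true]
        yp = [row[col] for row in y_pred]
        mean = sum(wi * v for wi, v in zip(w, yt)) / total_w
        ss_res = sum(wi * (a - b) ** 2 for wi, a, b in zip(w, yt, yp))
        ss_tot = sum(wi * (a - mean) ** 2 for wi, a in zip(w, yt))
        if ss_tot == 0.0:
            # Constant target column: perfect if no residual, else 0.
            scores.append(1.0 if ss_res == 0.0 else 0.0)
        else:
            scores.append(1.0 - ss_res / ss_tot)
    return sum(scores) / len(scores)


print(r2_score_sketch([[1.0], [2.0], [3.0]], [[1.0], [2.0], [4.0]]))  # 0.5
```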

set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') BezierSimplexRegressor[source]

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:

sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in score.

Returns:

self – The updated object.

Return type:

object

torch_bsf.validator module

torch_bsf.validator.index_list(val: str) list[list[int]][source]

Parse val into a list of indices.

Parameters:

val – A string expression of a list of indices.

Returns:

The parsed indices.

torch_bsf.validator.indices_schema(n_params: int, degree: int) dict[str, Any][source]

Generate a JSON schema for indices of the control points with given n_params and degree.

Parameters:
  • n_params – The number of index elements of control points.

  • degree – The degree of a Bézier simplex.

Returns:

A JSON schema.

Raises:

ValueError – If n_params or degree is negative.

See also

validate_simplex_indices

Validate that an instance is a list of control-point indices with the given n_params and degree.

torch_bsf.validator.int_or_str(val: str) int | str[source]

Try to parse val as an int. Return the int value if parsing succeeds; otherwise, return the original string.

Parameters:

val – The value to try to convert into int.

Returns:

The converted integer or the original value.

Return type:

Union[int, str]
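The behaviour described above fits in a few lines. The following is a minimal sketch, and it assumes that only ValueError (raised by int() for non-numeric strings) signals a failed parse. A typical use is a command-line option that accepts either a number or a keyword, such as a hypothetical devices value of "2" or "auto".

```python
from __future__ import annotations


def int_or_str(val: str) -> int | str:
    try:
        return int(val)  # "42" -> 42, " 7 " -> 7
    except ValueError:
        return val  # "auto" stays "auto"
```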

torch_bsf.validator.validate_simplex_indices(instance: object, n_params: int, degree: int) None[source]

Validate that an instance is a list of control-point indices with the given n_params and degree.

Parameters:
  • instance – An index list of a Bézier simplex.

  • n_params – The n_params of a Bézier simplex.

  • degree – The degree of a Bézier simplex.

Raises:

ValidationError – If instance does not comply with the following schema or has an inner array whose sum is not equal to degree:

    {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "array",
        "items": {
            "type": "array",
            "items": {
                "type": "integer",
                "minimum": 0,
                "maximum": degree,
            },
            "minItems": n_params,
            "maxItems": n_params,
        },
    }
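The schema constraints plus the sum rule can be checked by hand. The following is a minimal stdlib sketch that does not depend on jsonschema. It raises ValueError instead of jsonschema's ValidationError, the function name is hypothetical, and it is slightly stricter than JSON Schema because it rejects integral floats such as 1.0.

```python
def validate_simplex_indices_sketch(instance: object, n_params: int, degree: int) -> None:
    if not isinstance(instance, list):
        raise ValueError("instance must be a list of indices")
    for index in instance:
        # "minItems"/"maxItems": each index has exactly n_params entries.
        if not isinstance(index, list) or len(index) != n_params:
            raise ValueError(f"{index!r} must be a list of {n_params} integers")
        for a in index:
            # bool is a subclass of int in Python but not a JSON Schema integer.
            if isinstance(a, bool) or not isinstance(a, int) or not 0 <= a <= degree:
                raise ValueError(f"{a!r} is not an integer in [0, {degree}]")
        # Extra rule beyond the schema: entries must sum to the degree.
        if sum(index) != degree:
            raise ValueError(f"{index!r} does not sum to degree={degree}")
```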

Module contents

torch_bsf: PyTorch implementation of Bézier simplex fitting.

class torch_bsf.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, *, _n_params: int | None = None, _degree: int | None = None, _n_values: int | None = None)[source]

Bases: LightningModule

A Bézier simplex model.

Parameters:
  • control_points – The control points of the Bézier simplex. Pass None only when reconstructing a model from a Lightning checkpoint via load_from_checkpoint() — in that case all three shape parameters (_n_params, _degree, _n_values) must be provided so that a correctly-shaped placeholder can be built before the saved state dict is loaded into it.

  • smoothness_weight – The weight of the smoothness penalty term added to the training loss. When greater than zero, adjacent control points are encouraged to have similar values. Defaults to 0.0 (no penalty).

  • _n_params – Checkpoint-reconstruction parameter; do not set manually. The number of parameters (source dimension + 1) used to build the placeholder control points when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

  • _degree – Checkpoint-reconstruction parameter; do not set manually. The degree of the Bézier simplex used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

  • _n_values – Checkpoint-reconstruction parameter; do not set manually. The number of values (target dimension) used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.

Examples

>>> import torch
>>> import torch_bsf
>>> import lightning.pytorch as L
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()
configure_optimizers() Optimizer[source]
property degree: int[source]

The degree of the Bézier simplex.

forward(t: Tensor) Tensor[source]

Compute the forward pass: evaluate the Bézier simplex at a minibatch of parameter vectors.

Parameters:

t – A minibatch of parameter vectors \(\mathbf t\).

Returns:

A minibatch of value vectors.
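A Bézier simplex of degree n evaluates B(t) = Σ_{|α|=n} (n! / (α_1! ⋯ α_m!)) t_1^α_1 ⋯ t_m^α_m b_α, where the sum runs over all multi-indices α with entries summing to n and b_α is the control point stored under α. This Bernstein form agrees with the identity B(e_i) = b_{n·e_i} used by longest_edge_criterion. The following is a minimal stdlib sketch for one parameter vector. The function names are hypothetical, and the real forward() operates on batched tensors.

```python
from math import factorial, prod


def multinomial(alpha: tuple[int, ...]) -> int:
    """n! / (alpha_1! * ... * alpha_m!) with n = sum(alpha)."""
    coef = factorial(sum(alpha))
    for a in alpha:
        coef //= factorial(a)  # exact: every partial quotient is an integer
    return coef


def evaluate_bezier_simplex(control_points: dict, t) -> list[float]:
    """Evaluate B(t) for one parameter vector t (entries sum to 1)."""
    n_values = len(next(iter(control_points.values())))
    out = [0.0] * n_values
    for alpha, b in control_points.items():
        # Bernstein basis polynomial for multi-index alpha at t.
        weight = multinomial(alpha) * prod(tk ** ak for tk, ak in zip(t, alpha))
        for k in range(n_values):
            out[k] += weight * b[k]
    return out


print(evaluate_bezier_simplex({(1, 0): [0.0], (0, 1): [1.0]}, (0.5, 0.5)))  # [0.5]
```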

freeze_row(index: torch_bsf.bezier_simplex.Index) None[source]

Freeze a control point so its gradient is zeroed after every backward.

Parameters:

index – The index of the control point to freeze.

meshgrid(num: int = 100) tuple[Tensor, Tensor][source]

Computes a meshgrid of the Bézier simplex.

Parameters:

num – The number of grid points on each edge.

Returns:

  • ts – A parameter matrix of the mesh grid.

  • xs – A value matrix of the mesh grid.

property n_params: int[source]

The number of parameters, i.e., the source dimension + 1.

property n_values: int[source]

The number of values, i.e., the target dimension.

on_after_backward() None[source]

Zero gradients for frozen control-point rows after each backward pass.

smoothness_penalty() Tensor[source]

Computes the smoothness penalty of the Bézier simplex.

Returns:

The smoothness penalty.

test_step(batch, batch_idx) dict[str, Any][source]
training_step(batch, batch_idx) dict[str, Any][source]
validation_step(batch, batch_idx) None[source]
class torch_bsf.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')[source]

Bases: LightningDataModule

A data module for training a Bézier simplex.

Parameters:
  • params – The path to a parameter file.

  • values – The path to a value file.

  • header – The number of header rows in the parameter file and the value file. The first header rows are skipped when reading the files.

  • batch_size – The size of each minibatch.

  • split_ratio – The fraction of the data used for training in the train/validation split. Must be greater than 0 and less than or equal to 1. If it is set to 1, all the data are used for training and the validation step is skipped.

  • normalize – The data normalization method. Either "max", "std", "quantile", or "none".

fit_transform(values: Tensor) Tensor[source]
inverse_transform(values: Tensor) Tensor[source]
load_data(path) Tensor[source]
load_params() Tensor[source]
load_values() Tensor[source]
setup(stage: str | None = None)[source]
test_dataloader() DataLoader[source]
train_dataloader() DataLoader[source]
val_dataloader() DataLoader[source]
torch_bsf.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) BezierSimplex[source]

Fits a Bézier simplex.

Parameters:
  • params – The parameter vectors on the simplex, of shape (n_samples, n_params).

  • values – The target values, of shape (n_samples, n_values).

  • degree – The degree of the Bézier simplex.

  • init – The initial values of a Bézier simplex or control points.

  • smoothness_weight – The weight of the smoothness penalty.

  • freeze – The indices of control points to exclude from training.

  • batch_size – The size of each minibatch.

  • seed – Random seed passed to lightning.pytorch.seed_everything() to set RNG seeds for improved reproducibility. When None (default), no seed is set. For full determinism, also set Trainer(deterministic=True) and use compatible settings.

  • kwargs – Additional keyword arguments passed to lightning.pytorch.Trainer.

Returns:

A trained Bézier simplex.

Raises:
  • TypeError – From Trainer or DataLoader.

  • MisconfigurationException – From Trainer.

Examples

>>> import torch
>>> import torch_bsf

Prepare training data

>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters

Train a model

>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)

Predict by the trained model

>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<...>)

See also

lightning.pytorch.Trainer

Argument descriptions.

torch.utils.data.DataLoader

Argument descriptions.

torch_bsf.fit_kfold(params: Tensor, values: Tensor, n_folds: int = 5, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) ModuleList[source]

Fits an ensemble of Bézier simplices using k-fold cross-validation.

Splits the training data into n_folds folds via KFoldTrainer and trains one model per fold on the data from the remaining n_folds - 1 folds. The resulting ModuleList can be passed directly to torch_bsf.active_learning.suggest_next_points() to drive a Query-By-Committee active learning loop.

If len(params) < n_folds, the actual number of folds is capped at len(params) to avoid empty training subsets.
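The fold logic, including the cap at len(params), can be sketched as index bookkeeping. Each committee member would train on one train list. This stdlib sketch assumes contiguous, unshuffled folds in the style of scikit-learn's KFold. KFoldTrainer's actual assignment, including its shuffle and stratified options, may differ, and kfold_splits is a hypothetical name.

```python
def kfold_splits(n_samples: int, n_folds: int = 5) -> list[tuple[list[int], list[int]]]:
    if n_folds < 2:
        raise ValueError("n_folds must be at least 2")
    if n_samples < 2:
        raise ValueError("need at least 2 samples")
    k = min(n_folds, n_samples)  # cap so that no fold is empty
    base, extra = divmod(n_samples, k)
    splits, start = [], 0
    for f in range(k):
        size = base + (1 if f < extra else 0)  # first `extra` folds get one more
        held_out = list(range(start, start + size))
        train = list(range(start)) + list(range(start + size, n_samples))
        splits.append((train, held_out))
        start += size
    return splits


print(len(kfold_splits(3, n_folds=5)))  # 3 folds, not 5
```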

Parameters:
  • params – The parameter data on the simplex.

  • values – The target values.

  • n_folds – The number of cross-validation folds (committee size). Defaults to 5.

  • degree – The degree of the Bézier simplex.

  • init – The initial values of a Bézier simplex or control points.

  • smoothness_weight – The weight of the smoothness penalty.

  • freeze – The indices of control points to exclude from training.

  • batch_size – The size of a minibatch. Defaults to full-batch (consistent with fit()).

  • seed – Random seed passed to lightning.pytorch.seed_everything() for reproducible training. When None (default), no seed is set.

  • kwargs – All arguments for KFoldTrainer (which itself accepts all lightning.pytorch.Trainer arguments). For example, max_epochs=10, enable_progress_bar=False, logger=False, shuffle=True, or stratified=False. By default, num_sanity_val_steps=0 and limit_val_batches=0.0 are set internally to disable per-fold validation for speed; pass explicit values to override these defaults, e.g. num_sanity_val_steps=2, limit_val_batches=1.0.

Returns:

A ModuleList of min(n_folds, len(params)) trained BezierSimplex models, one per fold.

Return type:

torch.nn.ModuleList

Raises:

ValueError – If n_folds < 2, if len(params) < 2 (too few samples for any fold split), if neither / both of degree and init are provided, if batch_size is truthy but not a positive integer, or if the reserved argument num_folds is supplied via **kwargs.

Examples

>>> import torch
>>> import torch_bsf
>>> from torch_bsf.active_learning import suggest_next_points
>>> from torch_bsf.sampling import simplex_grid

Prepare training data

>>> params = simplex_grid(n_params=3, degree=3)
>>> values = params.pow(2).sum(dim=1, keepdim=True)

Build a 5-fold ensemble and suggest the 2 most uncertain points

>>> models = torch_bsf.fit_kfold(
...     params=params,
...     values=values,
...     degree=3,
...     max_epochs=1,
...     enable_progress_bar=False,
...     enable_model_summary=False,
...     logger=False,
... )
>>> suggestions = suggest_next_points(models, n_suggestions=2, method="qbc")
>>> suggestions.shape
torch.Size([2, 3])
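Query-By-Committee picks the candidates on which the ensemble disagrees most. The following stdlib sketch scores disagreement as the population variance of the committee's predictions, summed over output dimensions, and returns the top-scoring candidates. The exact disagreement measure and candidate generation used by suggest_next_points are not stated here. The committee members are modelled as plain callables, and qbc_select is a hypothetical name.

```python
from statistics import pvariance


def qbc_select(committee, candidates, n_suggestions: int = 1):
    """Return the candidates with the largest committee disagreement."""
    scores = []
    for x in candidates:
        preds = [model(x) for model in committee]  # each: list of n_values floats
        # Variance across models, per output dimension, summed.
        scores.append(sum(pvariance(col) for col in zip(*preds)))
    order = sorted(range(len(candidates)), key=lambda k: scores[k], reverse=True)
    return [candidates[k] for k in order[:n_suggestions]]


committee = [lambda t: [t[0]], lambda t: [2 * t[0]]]
print(qbc_select(committee, [(0.0, 1.0), (0.5, 0.5), (1.0, 0.0)]))  # [(1.0, 0.0)]
```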

See also

fit

Fit a single Bézier simplex.

torch_bsf.active_learning.suggest_next_points

Use the ensemble for active learning.

torch_bsf.longest_edge_criterion(bs: BezierSimplex, s: float = 0.5) tuple[int, int, float][source]

Select the edge with the greatest value-space length and split it at s.

The “length” of edge \((i, j)\) is the Euclidean distance between the Bézier simplex values at the two vertices of the parameter domain:

\[\ell_{ij} = \|B(e_i) - B(e_j)\|_2 = \|b_{n \cdot e_i} - b_{n \cdot e_j}\|_2\]

where \(n\) is the degree and \(e_k\) is the \(k\)-th unit vector.

Parameters:
  • bs – The Bézier simplex.

  • s – Split parameter included verbatim in the returned tuple (defaults to 0.5).

Returns:

The edge (i, j) with the greatest value-space length, together with s.

Return type:

(i, j, s)

Examples

>>> import torch
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import longest_edge_criterion, split_by_criterion
>>> bs = rand(n_params=3, n_values=2, degree=2)
>>> i, j, s = longest_edge_criterion(bs)
>>> 0 <= i < j < bs.n_params
True
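Because B(e_i) = b_{n·e_i}, the edge lengths depend only on the vertex control points. The following stdlib sketch applies the criterion to a dict of control points keyed by multi-index tuples. The dict representation and the name longest_edge_sketch are hypothetical, and ties resolve to the first edge in (i, j) order.

```python
from itertools import combinations
from math import dist


def longest_edge_sketch(control_points: dict, s: float = 0.5) -> tuple[int, int, float]:
    first_index = next(iter(control_points))
    degree, n_params = sum(first_index), len(first_index)

    def vertex_value(k: int):
        # Control point at multi-index n * e_k, which equals B(e_k).
        return control_points[tuple(degree if m == k else 0 for m in range(n_params))]

    i, j = max(
        combinations(range(n_params), 2),
        key=lambda e: dist(vertex_value(e[0]), vertex_value(e[1])),
    )
    return i, j, s
```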
torch_bsf.max_error_criterion(params: Tensor, values: Tensor, grid_size: int = 10) Callable[[BezierSimplex], tuple[int, int, float]][source]

Build a criterion that minimises the combined approximation error.

For each candidate edge (i, j) and split parameter s drawn from a uniform grid over (0, 1), the combined mean-squared error

\[E(i, j, s) = \mathrm{MSE}_A + \mathrm{MSE}_B\]

is computed, where \(\mathrm{MSE}_A\) and \(\mathrm{MSE}_B\) are evaluated on the portions of the data that fall within each sub-simplex. The candidate (i, j, s) minimising E is returned.

Parameters:
  • params – Parameter vectors of shape (N, n_params).

  • values – Target value vectors of shape (N, n_values).

  • grid_size – Number of candidate s values in the grid search (default 10). Larger values give a finer search at higher cost.

Returns:

A callable criterion(bs) -> (i, j, s).

Return type:

SplitCriterion

Examples

>>> import torch
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import max_error_criterion, split_by_criterion
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x...>
>>> params = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> values = torch.tensor([[0.0], [0.5], [1.0]])
>>> bs = rand(n_params=2, n_values=1, degree=1)
>>> criterion = max_error_criterion(params, values)
>>> i, j, s = criterion(bs)
>>> i == 0 and j == 1
True
torch_bsf.reparametrize(t: Tensor, i: int, j: int, s: float, subsimplex: str) tuple[Tensor, Tensor][source]

Re-parameterise points from the original simplex to a sub-simplex.

After splitting edge (i, j) at s, each data point on the original simplex belongs to one of the two sub-simplices. This function converts the original barycentric coordinates to the sub-simplex’s local barycentric coordinates.

Parameters:
  • t – Parameter vectors of shape (N, n_params) on the original simplex (each row sums to 1).

  • i – Index of the first edge vertex, as used in split().

  • j – Index of the second edge vertex, as used in split().

  • s – Split parameter.

  • subsimplex – "A" for the sub-simplex covering \(\{t : t_j / (t_i + t_j) \le s\}\), or "B" for the complementary region.

Returns:

(u, mask)

  • u — local barycentric coordinates on the sub-simplex, shape (N, n_params).

  • mask — boolean tensor of shape (N,) indicating which input points belong to the requested sub-simplex. Points where \(t_i = t_j = 0\) belong to both sub-simplices and are included in both masks.

Return type:

tuple[torch.Tensor, torch.Tensor]

Notes

Sub-simplex A transformation:

\[u_j = \frac{t_j}{s}, \quad u_i = t_i - \frac{1-s}{s}\,t_j, \quad u_k = t_k \; (k \ne i, j).\]

Sub-simplex B transformation:

\[u_i = \frac{t_i}{1-s}, \quad u_j = t_j - \frac{s}{1-s}\,t_i, \quad u_k = t_k \; (k \ne i, j).\]
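The two transformations and the membership masks translate directly into code. The following stdlib sketch works on lists of rows instead of tensors. It computes each mask exactly as t_j ≤ s(t_i + t_j) for A and t_j ≥ s(t_i + t_j) for B, so points with t_i = t_j = 0 land in both. Whether reparametrize applies a numerical tolerance is not stated here.

```python
def reparametrize_sketch(t, i: int, j: int, s: float, subsimplex: str):
    us, mask = [], []
    for row in t:
        u = list(row)  # u_k = t_k for k != i, j
        ti, tj = row[i], row[j]
        if subsimplex == "A":  # new vertex replaces vertex j
            u[j] = tj / s
            u[i] = ti - (1 - s) / s * tj
            inside = tj <= s * (ti + tj)
        elif subsimplex == "B":  # new vertex replaces vertex i
            u[i] = ti / (1 - s)
            u[j] = tj - s / (1 - s) * ti
            inside = tj >= s * (ti + tj)
        else:
            raise ValueError('subsimplex must be "A" or "B"')
        us.append(u)
        mask.append(inside)
    return us, mask
```

Rows outside the requested sub-simplex still receive coordinates, but at least one of them is negative, which is why the mask is returned alongside u.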
torch_bsf.split(bs: BezierSimplex, i: int, j: int, s: float = 0.5) tuple[BezierSimplex, BezierSimplex][source]

Split a Bézier simplex along edge (i, j) using the de Casteljau algorithm.

A new vertex is inserted on the edge between vertex i and vertex j of the parameter domain at the relative position s. The original simplex is thereby subdivided into two sub-simplices that together cover the entire original domain.

Parameters:
  • bs – The Bézier simplex to split.

  • i – Index of the first vertex of the split edge (0-indexed, i ≠ j).

  • j – Index of the second vertex of the split edge (0-indexed, i ≠ j).

  • s – Split parameter in the open interval (0, 1). s = 0.5 produces a midpoint split. The new vertex is located at \((1-s)\,v_i + s\,v_j\) in the parameter domain.

Returns:

(bs_A, bs_B)

  • bs_A — sub-simplex that replaces vertex \(j\) with the new vertex. It covers the sub-domain \(\{t : t_j / (t_i + t_j) \le s\}\).

  • bs_B — sub-simplex that replaces vertex \(i\) with the new vertex. It covers the sub-domain \(\{t : t_j / (t_i + t_j) \ge s\}\).

Return type:

tuple[BezierSimplex, BezierSimplex]

Notes

The algorithm runs \(n\) de Casteljau steps along the chosen edge direction, where \(n\) is the degree of the Bézier simplex.

At each step \(r = 1, \ldots, n\) the control-point matrix is updated as

\[\begin{split}c^{(r)}_\alpha = \begin{cases} s \cdot c^{(r-1)}_\alpha + (1-s) \cdot c^{(r-1)}_{\alpha + e_i - e_j} & \text{if } \alpha_j \ge 1 \\ c^{(r-1)}_\alpha & \text{otherwise,} \end{cases}\end{split}\]

and rows with \(\alpha_j = r\) are saved as control points of bs_A. An analogous recursion with the roles of \(i\) and \(j\) swapped, and \(s\) replaced by \(1-s\), gives bs_B.

The two sub-simplices share the split point — both evaluate to the same value at the new vertex.

Examples

Split the identity Bézier curve at the midpoint:

>>> import torch
>>> from torch_bsf.bezier_simplex import BezierSimplex
>>> from torch_bsf.splitting import split
>>> bs = BezierSimplex({(1, 0): [0.0], (0, 1): [1.0]})
>>> bs_A, bs_B = split(bs, i=0, j=1, s=0.5)
>>> float(bs_A.control_points[(0, 1)].item())  # split-point value
0.5
>>> float(bs_B.control_points[(1, 0)].item())  # same split point from other side
0.5
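The recursion in the Notes can be run on a dict of control points keyed by multi-index tuples. The following stdlib sketch implements it for bs_A and obtains bs_B by calling the same routine with i and j swapped and s replaced by 1 - s. The names split_sketch and _half are hypothetical, and the small evaluate helper exists only so that the halves can be checked against the original.

```python
from math import factorial, prod


def evaluate(cps: dict, t) -> list[float]:
    """Bernstein-form evaluation, used here only to check the split."""
    n_values = len(next(iter(cps.values())))
    out = [0.0] * n_values
    for alpha, b in cps.items():
        coef = factorial(sum(alpha))
        for a in alpha:
            coef //= factorial(a)
        w = coef * prod(tk ** ak for tk, ak in zip(t, alpha))
        for k in range(n_values):
            out[k] += w * b[k]
    return out


def _half(cps: dict, i: int, j: int, s: float) -> dict:
    """Sub-simplex whose vertex j is replaced by the new vertex (bs_A)."""
    degree = sum(next(iter(cps)))
    c = {a: list(v) for a, v in cps.items()}
    out = {a: v for a, v in c.items() if a[j] == 0}  # face opposite j is unchanged
    for r in range(1, degree + 1):
        prev, c = c, {}
        for a, v in prev.items():
            if a[j] >= 1:
                nb = list(a)
                nb[i] += 1
                nb[j] -= 1  # alpha + e_i - e_j
                c[a] = [s * x + (1 - s) * y for x, y in zip(v, prev[tuple(nb)])]
            else:
                c[a] = v
        out.update({a: v for a, v in c.items() if a[j] == r})  # save alpha_j = r rows
    return out


def split_sketch(cps: dict, i: int, j: int, s: float = 0.5):
    return _half(cps, i, j, s), _half(cps, j, i, 1 - s)
```

For the identity curve {(1, 0): [0.0], (0, 1): [1.0]} and s = 0.5, both halves store the split-point value 0.5, as in the doctest above.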
torch_bsf.split_by_criterion(bs: BezierSimplex, criterion: Callable[[BezierSimplex], tuple[int, int, float]]) tuple[BezierSimplex, BezierSimplex][source]

Split a Bézier simplex using a split criterion.

Convenience wrapper that calls criterion(bs) to obtain the edge indices and split parameter, then delegates to split().

Parameters:
  • bs – The Bézier simplex to split.

  • criterion – A SplitCriterion callable that returns (i, j, s).

Returns:

(bs_A, bs_B) – The two sub-Bézier-simplices; see split() for details.

Return type:

tuple[BezierSimplex, BezierSimplex]

Examples

>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import longest_edge_criterion, split_by_criterion
>>> bs = rand(n_params=2, n_values=3, degree=2)
>>> bs_A, bs_B = split_by_criterion(bs, longest_edge_criterion)
>>> bs_A.degree == bs.degree
True
torch_bsf.validate_control_points(data: dict[str, list[float]])[source]

Validates control points.

Parameters:

data – The control points.

Raises:

ValidationError – If the control points are invalid.

Examples

>>> from torch_bsf.bezier_simplex import validate_control_points
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
... })
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "0, 0, 1": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: '0, 0, 1' is not valid under any of the given schemas
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
...     "(0, 0)": [0.0, 0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1, 0)": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0, 1, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0, 1.0, 0.0]
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0]