torch_bsf package
Submodules
torch_bsf.active_learning module
- torch_bsf.active_learning.suggest_next_points(models: Sequence[Module], n_suggestions: int = 1, n_candidates: int = 1000, method: Literal['qbc', 'density'] = 'qbc', params: Tensor | None = None, n_params: int | None = None) Tensor[source]
Suggest points on the simplex where new data should be sampled.
- Parameters:
models (Sequence[nn.Module]) – An ensemble of models (e.g., from k-fold cross-validation). Each model must be callable with a tensor of shape (n_candidates, n_params) and return predictions. Accepts any Sequence of Module instances, including ModuleList.
n_suggestions (int, default=1) – The number of points to suggest.
n_candidates (int, default=1000) – The number of candidate points to evaluate.
method (Literal["qbc", "density"], default="qbc") – The method to use:
- "qbc": Query-By-Committee. Suggests points where models disagree most.
- "density": Suggests points that are furthest from existing training points.
params (torch.Tensor, optional) – The existing training parameters. Required for method="density".
n_params (int, optional) – The number of simplex parameters (input dimension). When omitted, the value is inferred from models[0].n_params if that attribute exists. When provided and models[0] exposes an n_params attribute, the two values must agree; a ValueError is raised on mismatch.
- Returns:
The suggested points in shape (n_suggestions, n_params).
- Return type:
torch.Tensor
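The "qbc" criterion can be illustrated with a stdlib-only sketch. The helper name qbc_select is hypothetical: the actual function operates on torch tensors and returns the suggested points themselves, while this sketch ranks candidate indices by the variance of committee predictions.

```python
from statistics import pvariance

def qbc_select(committee_predictions, n_suggestions=1):
    """Rank candidates by committee disagreement (prediction variance).

    committee_predictions: one list of per-candidate predictions per model
    (shape n_models x n_candidates). Returns the indices of the
    n_suggestions highest-variance candidates.
    """
    n_candidates = len(committee_predictions[0])
    # Disagreement score: variance across the committee at each candidate.
    scores = [
        pvariance([preds[i] for preds in committee_predictions])
        for i in range(n_candidates)
    ]
    ranked = sorted(range(n_candidates), key=lambda i: scores[i], reverse=True)
    return ranked[:n_suggestions]

# Three committee members, four candidates: disagreement peaks at candidate 2.
preds = [
    [0.1, 0.5, 0.0, 0.3],
    [0.1, 0.6, 1.0, 0.3],
    [0.1, 0.4, 0.5, 0.3],
]
print(qbc_select(preds, n_suggestions=2))  # -> [2, 1]
```

The "density" method instead ranks candidates by their distance to the existing training parameters, preferring under-sampled regions.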
torch_bsf.bezier_simplex module
- class torch_bsf.bezier_simplex.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, *, _n_params: int | None = None, _degree: int | None = None, _n_values: int | None = None)[source]
Bases: LightningModule
A Bézier simplex model.
- Parameters:
control_points – The control points of the Bézier simplex. Pass None only when reconstructing a model from a Lightning checkpoint via load_from_checkpoint() — in that case all three shape parameters (_n_params, _degree, _n_values) must be provided so that a correctly-shaped placeholder can be built before the saved state dict is loaded into it.
smoothness_weight – The weight of the smoothness penalty term added to the training loss. When greater than zero, adjacent control points are encouraged to have similar values. Defaults to 0.0 (no penalty).
_n_params – Checkpoint-reconstruction parameter — do not set manually. The number of parameters (source dimension + 1) used to build the placeholder control points when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
_degree – Checkpoint-reconstruction parameter — do not set manually. The degree of the Bézier simplex used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
_n_values – Checkpoint-reconstruction parameter — do not set manually. The number of values (target dimension) used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
Examples
>>> import torch
>>> import torch_bsf
>>> import lightning.pytorch as L
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()
- forward(t: Tensor) Tensor[source]
Computes a forward pass of the model.
- Parameters:
t – A minibatch of parameter vectors \(\mathbf t\).
- Return type:
A minibatch of value vectors.
- freeze_row(index: torch_bsf.bezier_simplex.Index) None[source]
Freeze a control point so its gradient is zeroed after every backward.
- Parameters:
index – The index of the control point to freeze.
- meshgrid(num: int = 100) tuple[Tensor, Tensor][source]
Computes a meshgrid of the Bézier simplex.
- Parameters:
num – The number of grid points on each edge.
- Returns:
ts – A parameter matrix of the mesh grid.
xs – A value matrix of the mesh grid.
- on_after_backward() None[source]
Zero gradients for frozen control-point rows after each backward pass.
- class torch_bsf.bezier_simplex.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')[source]
Bases: LightningDataModule
A data module for training a Bézier simplex.
- Parameters:
params – The path to a parameter file.
values – The path to a value file.
header – The number of header rows in the parameter file and the value file. The first header rows are skipped when reading the files.
batch_size – The size of each minibatch.
split_ratio – The ratio of train-val split. Must be greater than 0 and less than or equal to 1. If it is set to 1, then all the data are used for training and the validation step will be skipped.
normalize – The data normalization method. Either "max", "std", "quantile", or "none".
- test_dataloader() DataLoader[source]
- train_dataloader() DataLoader[source]
- val_dataloader() DataLoader[source]
- torch_bsf.bezier_simplex.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) BezierSimplex[source]
Fits a Bézier simplex.
- Parameters:
params – The data.
values – The label data.
degree – The degree of the Bézier simplex.
init – The initial values of a Bézier simplex or control points.
smoothness_weight – The weight of smoothness penalty.
freeze – The indices of control points to exclude from training.
batch_size – The size of minibatch.
seed – Random seed passed to lightning.pytorch.seed_everything() to set RNG seeds for improved reproducibility. When None (default), no seed is set. For full determinism, also set Trainer(deterministic=True) and use compatible settings.
kwargs – All arguments for lightning.pytorch.Trainer.
- Return type:
A trained Bézier simplex.
- Raises:
TypeError – From Trainer or DataLoader.
MisconfigurationException – From Trainer.
Examples
>>> import torch
>>> import torch_bsf
Prepare training data
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
Train a model
>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)
Predict by the trained model
>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<...>)
See also
lightning.pytorch.Trainer – Argument descriptions.
torch.utils.data.DataLoader – Argument descriptions.
- torch_bsf.bezier_simplex.fit_kfold(params: Tensor, values: Tensor, n_folds: int = 5, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) ModuleList[source]
Fits an ensemble of Bézier simplices using k-fold cross-validation.
Splits the training data into n_folds folds via KFoldTrainer and trains one model per fold on the data from the remaining n_folds - 1 folds. The resulting ModuleList can be passed directly to torch_bsf.active_learning.suggest_next_points() to drive a Query-By-Committee active learning loop.
If len(params) < n_folds, the actual number of folds is capped at len(params) to avoid empty training subsets.
- Parameters:
params – The parameter data on the simplex.
values – The label data.
n_folds – The number of cross-validation folds (committee size). Defaults to 5.
degree – The degree of the Bézier simplex.
init – The initial values of a Bézier simplex or control points.
smoothness_weight – The weight of the smoothness penalty.
freeze – The indices of control points to exclude from training.
batch_size – The size of a minibatch. Defaults to full-batch (consistent with fit()).
seed – Random seed passed to lightning.pytorch.seed_everything() for reproducible training. When None (default), no seed is set.
kwargs – All arguments for KFoldTrainer (which itself accepts all lightning.pytorch.Trainer arguments). For example, max_epochs=10, enable_progress_bar=False, logger=False, shuffle=True, or stratified=False. By default, num_sanity_val_steps=0 and limit_val_batches=0.0 are set internally to disable per-fold validation for speed; pass explicit values to override these defaults, e.g. num_sanity_val_steps=2, limit_val_batches=1.0.
- Returns:
A ModuleList of min(n_folds, len(params)) trained BezierSimplex models, one per fold.
- Return type:
ModuleList
- Raises:
ValueError – If n_folds < 2, if len(params) < 2 (too few samples for any fold split), if neither or both of degree and init are provided, if batch_size is truthy but not a positive integer, or if the reserved argument num_folds is supplied via **kwargs.
Examples
>>> import torch
>>> import torch_bsf
>>> from torch_bsf.active_learning import suggest_next_points
>>> from torch_bsf.sampling import simplex_grid
Prepare training data
>>> params = simplex_grid(n_params=3, degree=3)
>>> values = params.pow(2).sum(dim=1, keepdim=True)
Build a 5-fold ensemble and suggest the 2 most uncertain points
>>> models = torch_bsf.fit_kfold(
...     params=params,
...     values=values,
...     degree=3,
...     max_epochs=1,
...     enable_progress_bar=False,
...     enable_model_summary=False,
...     logger=False,
... )
>>> suggestions = suggest_next_points(models, n_suggestions=2, method="qbc")
>>> suggestions.shape
torch.Size([2, 3])
See also
fit – Fit a single Bézier simplex.
torch_bsf.active_learning.suggest_next_points – Use the ensemble for active learning.
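The per-fold data split behind k-fold training can be illustrated with a stdlib-only sketch. This is not the library's KFoldTrainer (which may shuffle or stratify depending on its options); it shows contiguous, unshuffled folds, including the fold cap for small datasets.

```python
def kfold_indices(n_samples, n_folds):
    """Split range(n_samples) into n_folds (train, val) index pairs.

    Folds are contiguous blocks; each sample appears in exactly one
    validation fold. Folds are capped at n_samples to avoid empty splits.
    """
    n_folds = min(n_folds, n_samples)
    # Distribute the remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    splits, start = [], 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        val_set = set(val)
        train = [i for i in range(n_samples) if i not in val_set]
        splits.append((train, val))
        start += size
    return splits

splits = kfold_indices(n_samples=10, n_folds=5)
print(len(splits))   # 5 folds
print(splits[0][1])  # first validation fold: [0, 1]
```

Each fold's training subset is what one committee member sees, which is why the resulting models disagree most where data is scarce.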
- torch_bsf.bezier_simplex.load(path: str | Path, *, pt_weights_only: bool | None = None) BezierSimplex[source]
Loads a Bézier simplex from a file.
- Parameters:
path – The path to a file.
pt_weights_only – Whether to load weights only. This parameter is only effective when loading PyTorch (.pt) files. For other formats (e.g., .json, .yml), data loading is inherently safe and this parameter is ignored. If None, it defaults to False.
- Return type:
A Bézier simplex.
- Raises:
ValueError – If the file type is unknown.
ValidationError – If the control points are invalid.
Examples
>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.load("tests/data/bezier_simplex.csv")
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
- torch_bsf.bezier_simplex.monomial(variable: Iterable[float], degree: Iterable[int]) Tensor[source]
Computes a monomial \(\mathbf t^{\mathbf d} = t_1^{d_1} t_2^{d_2}\cdots t_M^{d_M}\).
- Parameters:
variable – The bases \(\mathbf t\).
degree – The powers \(\mathbf d\).
- Return type:
The monomial \(\mathbf t^{\mathbf d}\).
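The computation itself is a product of per-coordinate powers. A stdlib-only sketch (the library version accepts iterables of floats/ints and returns a torch.Tensor):

```python
import math

def monomial(variable, degree):
    """Compute t1**d1 * t2**d2 * ... * tM**dM for paired bases and powers."""
    return math.prod(t ** d for t, d in zip(variable, degree))

print(monomial([0.5, 2.0, 3.0], [2, 1, 0]))  # 0.5**2 * 2.0**1 * 3.0**0 = 0.5
```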
- torch_bsf.bezier_simplex.polynom(degree: int, index: Iterable[int]) float[source]
Computes a polynomial coefficient \(\binom{D}{\mathbf d} = \frac{D!}{d_1!d_2!\cdots d_M!}\).
- Parameters:
degree – The degree \(D\).
index – The index \(\mathbf d\).
- Return type:
The polynomial coefficient \(\binom{D}{\mathbf d}\).
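The multinomial coefficient can be computed directly from factorials. A stdlib-only sketch of the formula above:

```python
import math

def polynom(degree, index):
    """Multinomial coefficient D! / (d1! d2! ... dM!)."""
    coefficient = math.factorial(degree)
    for d in index:
        coefficient //= math.factorial(d)  # each d_i divides D! exactly
    return coefficient

print(polynom(3, (1, 1, 1)))  # 3! / (1! 1! 1!) = 6
print(polynom(3, (3, 0, 0)))  # 3! / 3! = 1
```

These coefficients weight the Bernstein basis terms of the Bézier simplex.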
- torch_bsf.bezier_simplex.rand(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]
Generates a random Bézier simplex.
The control points are initialized by random values. The values are uniformly distributed in [0, 1).
- Parameters:
n_params – The number of parameters, i.e., the source dimension + 1.
n_values – The number of values, i.e., the target dimension.
degree – The degree of the Bézier simplex.
smoothness_weight – The weight of smoothness penalty.
- Return type:
A random Bézier simplex.
- Raises:
ValueError – If n_params, n_values, or degree is negative.
Examples
>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.rand(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
- torch_bsf.bezier_simplex.randn(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]
Generates a random Bézier simplex.
The control points are initialized by random values. The values are normally distributed with mean 0 and standard deviation 1.
- Parameters:
n_params – The number of parameters, i.e., the source dimension + 1.
n_values – The number of values, i.e., the target dimension.
degree – The degree of the Bézier simplex.
smoothness_weight – The weight of smoothness penalty.
- Return type:
A random Bézier simplex.
- Raises:
ValueError – If n_params, n_values, or degree is negative.
Examples
>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
- torch_bsf.bezier_simplex.save(path: str | Path, data: BezierSimplex) None[source]
Saves a Bézier simplex to a file.
- Parameters:
path – The file path to save.
data – The Bézier simplex to save.
- Raises:
ValueError – If the file type is unknown.
Examples
>>> import torch_bsf
>>> bs = torch_bsf.bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.pt", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.csv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.tsv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.json", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.yml", bs)
- torch_bsf.bezier_simplex.validate_control_points(data: dict[str, list[float]])[source]
Validates control points.
- Parameters:
data – The control points.
- Raises:
ValidationError – If the control points are invalid.
Examples
>>> from torch_bsf.bezier_simplex import validate_control_points
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
... })
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "0, 0, 1": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
  ...
jsonschema.exceptions.ValidationError: '0, 0, 1' is not valid under any of the given schemas
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
...     "(0, 0)": [0.0, 0.0, 0.0],
... })
Traceback (most recent call last):
  ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1, 0)": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
  ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0, 1, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0, 0.0],
... })
Traceback (most recent call last):
  ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0, 1.0, 0.0]
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0],
... })
Traceback (most recent call last):
  ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0]
- torch_bsf.bezier_simplex.zeros(n_params: int, n_values: int, degree: int, smoothness_weight: float = 0.0) BezierSimplex[source]
Generates a Bézier simplex with control points at origin.
- Parameters:
n_params – The number of parameters, i.e., the source dimension + 1.
n_values – The number of values, i.e., the target dimension.
degree – The degree of the Bézier simplex.
smoothness_weight – The weight of smoothness penalty.
- Return type:
A Bézier simplex filled with zeros.
- Raises:
ValueError – If n_params, n_values, or degree is negative.
Examples
>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.zeros(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(n_params=2, degree=2, n_values=3)
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<...>)
torch_bsf.control_points module
- class torch_bsf.control_points.ControlPoints(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None)[source]
Bases: Module
Control points of a Bézier simplex stored as a single parameter matrix.
All control points are stored in one nn.Parameter matrix of shape (n_indices, n_values). This eliminates the Python-level loop and torch.stack call that would otherwise occur on every forward pass, giving a direct O(1) access path to the parameter data.
- matrix[source]
nn.Parameter of shape (n_indices, n_values) holding all control points in canonical simplex-index order.
Examples
>>> import torch_bsf
>>> control_points = torch_bsf.control_points.ControlPoints({
...     (1, 0): [0.0, 0.1, 0.2],
...     (0, 1): [1.0, 1.1, 1.2],
... })
>>> control_points.degree
1
>>> control_points.n_params
2
>>> control_points.n_values
3
>>> control_points[(1, 0)]
tensor([0.0000, 0.1000, 0.2000], grad_fn=<SelectBackward0>)
>>> control_points[(0, 1)]
tensor([1.0000, 1.1000, 1.2000], grad_fn=<SelectBackward0>)
- indices() Iterator[tuple[int, ...]][source]
Iterates the index of control points of the Bézier simplex.
- Return type:
The indices in canonical order.
- torch_bsf.control_points.ControlPointsData: TypeAlias = dict[str, torch.Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], torch.Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[torch.Tensor, torch.Tensor] | dict[torch.Tensor, list[float]] | dict[torch.Tensor, tuple[float, ...]][source]
The data type of control points of a Bézier simplex.
- torch_bsf.control_points.Index: TypeAlias = str | typing.Sequence[int] | torch.Tensor[source]
The index type of control points of a Bézier simplex.
- torch_bsf.control_points.Value: TypeAlias = typing.Sequence[float] | torch.Tensor[source]
The value type of control points of a Bézier simplex.
- torch_bsf.control_points.simplex_indices(n_params: int, degree: int) Iterable[tuple[int, ...]][source]
Iterates the index of control points of a Bézier simplex.
- Parameters:
n_params – The tuple length of each index.
degree – The degree of the Bézier simplex.
- Return type:
The indices.
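The index set consists of all compositions of degree into n_params non-negative parts. A recursive stdlib-only sketch of one such enumeration (the library's canonical order may differ):

```python
def simplex_indices(n_params, degree):
    """Yield all tuples of n_params non-negative ints summing to degree.

    Enumerated by descending first coordinate; each tuple indexes one
    control point of a degree-D Bézier simplex.
    """
    if n_params == 1:
        yield (degree,)
        return
    for d in range(degree, -1, -1):
        for rest in simplex_indices(n_params - 1, degree - d):
            yield (d,) + rest

print(list(simplex_indices(3, 2)))
# 6 tuples, e.g. (2, 0, 0), (1, 1, 0), ..., (0, 0, 2)
```

The count is binom(degree + n_params - 1, n_params - 1), the usual stars-and-bars formula.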
- torch_bsf.control_points.to_parameterdict(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]]) dict[str, Tensor][source]
Convert data to a dictionary with canonical string keys.
torch_bsf.plotting module
- torch_bsf.plotting.plot_bezier_simplex(model: BezierSimplex, num: int = 100, ax=None, show_control_points: bool = True, *, max_control_points: int = 500, max_pairwise_points: int = 2000, **kwargs)[source]
Plots the Bézier simplex.
- Parameters:
model (BezierSimplex) – The Bézier simplex model to plot.
num (int) – The number of grid points for each edge. For model.n_params >= 4 this value is used only to decide whether to use a full meshgrid or random sampling: if the combinatorial meshgrid size comb(num + n_params - 1, n_params - 1) exceeds max_pairwise_points, exactly max_pairwise_points uniformly random simplex samples are drawn instead and num no longer controls the sample count.
ax (matplotlib.axes.Axes or None) – The matplotlib axes to plot on. If None, a new figure is created. Ignored when model.n_params >= 4; for pairwise plots a new figure is created only when model.n_values > 0 (when model.n_values == 0 an empty (0, 0) ndarray is returned without creating a figure).
show_control_points (bool) – Whether to show control points.
max_control_points (int) – Maximum number of control points to render in pairwise plots (model.n_params >= 4). When the model has more control points than this limit, a random subset of max_control_points is drawn instead to avoid combinatorial slowdowns. Defaults to 500. Ignored when show_control_points is False or model.n_params < 4. Must be a non-negative integer.
max_pairwise_points (int) – Maximum number of sample points for the pairwise scatter plot (model.n_params >= 4). When the combinatorial meshgrid size would exceed this limit, max_pairwise_points uniformly random simplex samples are drawn instead to bound memory usage and plot time. Defaults to 2000. Ignored when model.n_params < 4. Must be a non-negative integer.
**kwargs – Additional keyword arguments forwarded to the plot call. For model.n_params == 2, forwarded to ax.plot (curve). For model.n_params == 3 and model.n_values >= 3, forwarded to ax.plot_trisurf (3D surface). For model.n_params == 3 and model.n_values == 2, ignored. For model.n_params >= 4, forwarded to ax.scatter (pairwise).
- Returns:
The axes containing the plot. For model.n_params <= 3 a single Axes (or Axes3D) is returned. For model.n_params >= 4 a 2-D numpy.ndarray of Axes with shape (n_values, n_values) is returned (pairwise scatter plot).
- Return type:
matplotlib.axes.Axes or mpl_toolkits.mplot3d.axes3d.Axes3D or numpy.ndarray
- Raises:
ImportError – If matplotlib is not installed. This dependency is required for all plotting backends used by this function.
ImportError – If SciPy is not installed and model.n_params == 3. SciPy is required for the triangulation-based plotting used in the Bézier triangle case.
ValueError – If model.n_params < 2. This function only supports Bézier simplex models with at least two parameters.
torch_bsf.preprocessing module
- class torch_bsf.preprocessing.MinMaxScaler[source]
Bases: object
Min-max scaler that normalizes values to the [0, 1] range.
- scales[source]
Scale factors (max - min) per feature dimension. Dimensions where max equals min are set to 1 to avoid division by zero.
- Type:
torch.Tensor
- fit(values: Tensor) None[source]
Fit the scaler to the data.
Computes the per-feature minimum and scale (max - min) from values.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- fit_transform(values: Tensor) Tensor[source]
Fit the scaler to the data and return the normalized tensor.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- Returns:
Normalized tensor of the same shape as values, with each feature scaled to the [0, 1] range.
- Return type:
torch.Tensor
- inverse_transform(values: Tensor) Tensor[source]
Reverse the normalization applied by fit_transform().
- Parameters:
values (torch.Tensor) – Normalized tensor of shape (n_samples, n_features).
- Returns:
Tensor rescaled back to the original value range.
- Return type:
torch.Tensor
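The min-max transform itself is simple; a stdlib-only sketch of the fit-and-transform step on a list of rows (the real scaler operates on torch.Tensor inputs and keeps the fitted minima and scales for inverse_transform):

```python
def minmax_fit_transform(rows):
    """Scale each column of a list-of-lists to [0, 1].

    Columns with max == min get scale 1 to avoid division by zero,
    mirroring the behaviour documented for MinMaxScaler.scales.
    """
    cols = list(zip(*rows))
    mins = [min(c) for c in cols]
    # `or 1` replaces a zero scale with 1 (constant columns map to 0).
    scales = [(max(c) - lo) or 1 for c, lo in zip(cols, mins)]
    return [[(v - lo) / s for v, lo, s in zip(row, mins, scales)]
            for row in rows]

data = [[0.0, 10.0], [5.0, 10.0], [10.0, 10.0]]
print(minmax_fit_transform(data))
# -> [[0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]
```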
- class torch_bsf.preprocessing.NoneScaler[source]
Bases: object
Pass-through scaler that leaves values unchanged.
Useful as a no-op placeholder when normalization is not required, while still providing the same fit(), fit_transform(), and inverse_transform() interface as the other scalers.
- fit(values: Tensor) None[source]
No-op fit method included for API compatibility.
- Parameters:
values (torch.Tensor) – Input tensor (ignored).
- fit_transform(values: Tensor) Tensor[source]
Return values unchanged.
- Parameters:
values (torch.Tensor) – Input tensor of any shape.
- Returns:
The same tensor as values without any modification.
- Return type:
torch.Tensor
- inverse_transform(values: Tensor) Tensor[source]
Return values unchanged.
- Parameters:
values (torch.Tensor) – Input tensor of any shape.
- Returns:
The same tensor as values without any modification.
- Return type:
torch.Tensor
- class torch_bsf.preprocessing.QuantileScaler[source]
Bases: object
Quantile-based scaler that normalizes values using percentile-based ranges.
Values are scaled so that the q-th percentile maps to 0 and the (1 - q)-th percentile maps to 1, without clipping values to this range. This makes the scaler robust to outliers compared to MinMaxScaler.
- q[source]
The lower quantile fraction used as the effective minimum. Defaults to 0.05 (5th percentile), ignoring the bottom and top 5% of values.
- Type:
float
- scales[source]
Per-feature scale factors ((1-q)-quantile minus q-quantile). Dimensions where the scale is zero are set to 1 to avoid division by zero.
- Type:
torch.Tensor
- fit(values: Tensor) None[source]
Fit the scaler to the data.
Computes per-feature quantile bounds from values.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- fit_transform(values: Tensor) Tensor[source]
Fit the scaler to the data and return the scaled tensor.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- Returns:
Scaled tensor of the same shape as values. Values outside the fitted quantile range may exceed [0, 1].
- Return type:
torch.Tensor
- inverse_transform(values: Tensor) Tensor[source]
Reverse the scaling applied by fit_transform().
- Parameters:
values (torch.Tensor) – Scaled tensor of shape (n_samples, n_features).
- Returns:
Tensor rescaled back to the original value range.
- Return type:
torch.Tensor
- class torch_bsf.preprocessing.Scaler(*args, **kwargs)[source]
Bases: Protocol
Protocol for data scalers with a fit/transform interface.
All concrete scaler classes (MinMaxScaler, StdScaler, QuantileScaler, NoneScaler) satisfy this protocol via structural subtyping — no explicit inheritance is required.
- fit_transform(values: Tensor) Tensor[source]
Fit the scaler to values and return the transformed tensor.
- inverse_transform(values: Tensor) Tensor[source]
Reverse the transformation applied by fit_transform().
- class torch_bsf.preprocessing.StdScaler[source]
Bases: object
Standard-score (z-score) scaler that normalizes values to zero mean and unit variance.
- stds[source]
Per-feature standard deviations computed during fit(). Dimensions with zero standard deviation are set to 1 to avoid division by zero.
- Type:
torch.Tensor
- fit(values: Tensor) None[source]
Fit the scaler to the data.
Computes the per-feature mean and standard deviation from values.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- fit_transform(values: Tensor) Tensor[source]
Fit the scaler to the data and return the standardized tensor.
- Parameters:
values (torch.Tensor) – Input tensor of shape (n_samples, n_features).
- Returns:
Standardized tensor of the same shape as values, with each feature having zero mean and unit standard deviation.
- Return type:
torch.Tensor
- inverse_transform(values: Tensor) Tensor[source]
Reverse the standardization applied by fit_transform().
- Parameters:
values (torch.Tensor) – Standardized tensor of shape (n_samples, n_features).
- Returns:
Tensor rescaled back to the original value range.
- Return type:
torch.Tensor
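The z-score transform can be sketched with the stdlib. This sketch uses the population standard deviation and list-of-lists input, which are assumptions; the real scaler operates on torch.Tensor inputs and its exact variance convention may differ.

```python
import math

def std_fit_transform(rows):
    """Standardize each column to zero mean and (population) unit variance.

    A zero standard deviation is replaced by 1, as documented for
    StdScaler.stds, so constant columns map to 0 instead of dividing by 0.
    """
    cols = list(zip(*rows))
    means = [sum(c) / len(c) for c in cols]
    stds = [math.sqrt(sum((v - m) ** 2 for v in c) / len(c)) or 1
            for c, m in zip(cols, means)]
    return [[(v - m) / s for v, m, s in zip(row, means, stds)]
            for row in rows]

data = [[1.0, 7.0], [3.0, 7.0]]
print(std_fit_transform(data))
# -> [[-1.0, 0.0], [1.0, 0.0]]  (second column is constant)
```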
torch_bsf.sampling module
- torch_bsf.sampling.simplex_grid(n_params: int, degree: int) Tensor[source]
Generates a uniform grid on a simplex.
- Parameters:
n_params (int) – The number of parameters (vertices of the simplex).
degree (int) – The grid resolution; grid points lie at multiples of 1/degree along each coordinate.
- Returns:
Array of grid points in shape (N, n_params), where N = binom(degree + n_params - 1, n_params - 1).
- Return type:
torch.Tensor
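The grid is every simplex index tuple divided by degree. A stdlib-only sketch of the construction (the library returns a torch.Tensor rather than nested lists):

```python
def simplex_grid(n_params, degree):
    """Uniform grid on the simplex: each index tuple divided by degree."""
    def indices(m, total):
        # All tuples of m non-negative ints summing to total.
        if m == 1:
            yield (total,)
            return
        for d in range(total, -1, -1):
            for rest in indices(m - 1, total - d):
                yield (d,) + rest
    return [[d / degree for d in idx] for idx in indices(n_params, degree)]

grid = simplex_grid(n_params=3, degree=2)
print(len(grid))  # binom(2 + 3 - 1, 3 - 1) = 6
print(all(abs(sum(p) - 1.0) < 1e-12 for p in grid))  # True: rows sum to 1
```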
- torch_bsf.sampling.simplex_random(n_params: int, n_samples: int, seed: int | None = None) Tensor[source]
Generates random points on a simplex using Dirichlet distribution.
- Parameters:
n_params (int) – The number of parameters (vertices of the simplex).
n_samples (int) – The number of samples.
seed (int or None, optional) – Random seed for reproducibility. When provided, a new numpy.random.Generator is created with this seed so the global NumPy random state is not affected. When None (default), the global numpy.random state is used.
- Returns:
Array of sample points in shape (n_samples, n_params).
- Return type:
torch.Tensor
- Raises:
ValueError – If n_params is not positive or n_samples is negative.
- torch_bsf.sampling.simplex_sobol(n_params: int, n_samples: int, seed: int | None = None) Tensor[source]
Generates quasi-random points on a simplex using Sobol sequence.
Uses a scrambled Sobol sequence projected onto the simplex via the sorted-differences mapping. Sobol sequences are low-discrepancy: they fill space more uniformly than pseudo-random draws, giving a convergence rate of roughly O((log N)^(d-1) / N) instead of the O(1/sqrt(N)) rate of Monte Carlo sampling (where d = n_params - 1).
Note
Power-of-two sample sizes are strongly recommended. Sobol sequences are constructed in base 2 and achieve their best uniformity guarantees when n_samples is an exact power of 2 (e.g. 64, 128, 256, …). When n_samples is not a power of 2, the strongest low-discrepancy guarantees no longer apply and the coverage of the simplex can be somewhat less uniform. A UserWarning is emitted automatically when a non-power-of-two value is requested.
Note
SciPy is required. This function relies on scipy.stats.qmc.Sobol. Install it with pip install scipy or pip install pytorch-bsf[sampling].
- Parameters:
n_params (int) – The number of parameters (vertices of the simplex). Must be at least 2. The Sobol sequence is drawn in n_params - 1 dimensions and then mapped to the simplex.
n_samples (int) – The number of samples. For best coverage, use a power of 2 (e.g. 64, 128, 256).
seed (int or None, optional) – Random seed for the scrambled Sobol sequence, passed directly to scipy.stats.qmc.Sobol as its seed argument. When None (default), a random scramble is used on each call. Pass an integer for reproducible sequences.
- Returns:
Array of sample points in shape (n_samples, n_params). Each row is non-negative and sums to 1.
- Return type:
torch.Tensor
- Raises:
ImportError – If SciPy is not installed.
ValueError – If n_params is less than 2 or n_samples is negative.
- Warns:
UserWarning – If n_samples is not a power of 2. The samples are still returned, but the low-discrepancy coverage guarantee is weakened.
Examples
>>> import torch
>>> from torch_bsf.sampling import simplex_sobol
>>> pts = simplex_sobol(n_params=3, n_samples=128)
>>> pts.shape
torch.Size([128, 3])
>>> pts.sum(dim=1).allclose(torch.ones(128))
True
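The sorted-differences mapping mentioned above can be sketched in isolation: sort the (n_params - 1)-dimensional draw, pad it with 0 and 1, and take consecutive differences. This is a stdlib sketch of the standard construction; the library applies it to Sobol draws as tensors.

```python
def uniform_to_simplex(u):
    """Map a point u in [0, 1]^(d-1) to the (d-1)-simplex.

    The sorted values split [0, 1] into d segments; the segment lengths
    are non-negative and sum to 1, so they form a simplex point.
    """
    anchors = [0.0] + sorted(u) + [1.0]
    return [b - a for a, b in zip(anchors, anchors[1:])]

w = uniform_to_simplex([0.7, 0.2])
print(len(w), abs(sum(w) - 1.0) < 1e-9)  # 3 coordinates summing to 1
```

When u is a low-discrepancy sequence, the resulting simplex points inherit its even coverage.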
torch_bsf.sklearn module
- class torch_bsf.sklearn.BezierSimplexRegressor(degree: int | None = None, smoothness_weight: float = 0.0, init: BezierSimplex | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, max_epochs: int = 3, accelerator: str = 'auto', devices: int | str = 'auto', precision: str = '32-true', trainer_kwargs: dict[str, Any] | None = None)[source]
Bases:
BaseEstimator, RegressorMixin
Scikit-learn wrapper for Bézier Simplex Fitting.
- Parameters:
degree (int | None, default=None) – The degree of the Bézier simplex.
smoothness_weight (float, default=0.0) – The weight of smoothness penalty.
init (BezierSimplex | ControlPointsData | None, default=None) – Initial control points or model.
freeze (Iterable[Index] | None, default=None) – Indices of control points to freeze during training.
batch_size (int | None, default=None) – Size of minibatches.
max_epochs (int, default=3) – Maximum number of epochs to train.
accelerator (str, default="auto") – Hardware accelerator to use (“cpu”, “gpu”, “auto”, etc.).
devices (int | str, default="auto") – Number of devices to use.
precision (str, default="32-true") – Floating point precision.
trainer_kwargs (dict | None, default=None) – Additional keyword arguments for lightning.pytorch.Trainer.
- fit(X: Any, y: Any)[source]
Fit the Bézier simplex model.
- Parameters:
X (array-like of shape (n_samples, n_params)) – Training data (parameters on a simplex).
y (array-like of shape (n_samples, n_values)) – Target values.
- Returns:
self – Fitted estimator.
- Return type:
- predict(X: Any) ndarray[source]
Predict using the Bézier simplex model.
- Parameters:
X (array-like of shape (n_samples, n_params)) – Input parameters.
- Returns:
y_pred – Predicted values.
- Return type:
ndarray of shape (n_samples, n_values)
- score(X: Any, y: Any, sample_weight: Any = None) float[source]
Return the coefficient of determination R^2 of the prediction.
- Parameters:
X (array-like of shape (n_samples, n_params)) – Test samples.
y (array-like of shape (n_samples, n_values)) – True values.
sample_weight (array-like of shape (n_samples,), default=None) – Sample weights.
- Returns:
score – R^2 of self.predict(X) w.r.t. y.
- Return type:
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') BezierSimplexRegressor[source]
Configure whether metadata should be requested to be passed to the
score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (
sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
torch_bsf.splitting module
Bézier simplex splitting via the de Casteljau algorithm.
This module provides functions to split a Bézier simplex along a single edge by inserting a new vertex and applying the de Casteljau algorithm to compute the control points of the two resulting sub-simplices.
The split point can be chosen explicitly or determined automatically by
optimising a SplitCriterion.
Examples
Split a Bézier curve at its midpoint:
>>> import torch
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import split
>>> bs = rand(n_params=2, n_values=3, degree=3)
>>> bs_A, bs_B = split(bs, i=0, j=1, s=0.5)
>>> bs_A.n_params == bs.n_params and bs_A.degree == bs.degree
True
Use the longest-edge criterion to choose the split automatically:
>>> from torch_bsf.splitting import longest_edge_criterion, split_by_criterion
>>> bs_A, bs_B = split_by_criterion(bs, longest_edge_criterion)
- torch_bsf.splitting.SplitCriterion[source]
a callable that accepts a
BezierSimplex and returns (i, j, s): the edge vertex indices and the split parameter.
- Type:
Type of a split criterion
alias of
Callable[[BezierSimplex], tuple[int, int, float]]
- torch_bsf.splitting.longest_edge_criterion(bs: BezierSimplex, s: float = 0.5) tuple[int, int, float][source]
Select the edge with the greatest value-space length and split it at
s.
The “length” of edge \((i, j)\) is the Euclidean distance between the Bézier simplex values at the two vertices of the parameter domain:
\[\ell_{ij} = \|B(e_i) - B(e_j)\|_2 = \|b_{n \cdot e_i} - b_{n \cdot e_j}\|_2\]
where \(n\) is the degree and \(e_k\) is the \(k\)-th unit vector.
- Parameters:
bs – The Bézier simplex.
s – Split parameter included verbatim in the returned tuple (defaults to
0.5).
- Returns:
The edge
(i, j) with the greatest value-space length, together with s.
- Return type:
(i, j, s)
Examples
>>> import torch
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import longest_edge_criterion, split_by_criterion
>>> bs = rand(n_params=3, n_values=2, degree=2)
>>> i, j, s = longest_edge_criterion(bs)
>>> 0 <= i < j < bs.n_params
True
- torch_bsf.splitting.max_error_criterion(params: Tensor, values: Tensor, grid_size: int = 10) Callable[[BezierSimplex], tuple[int, int, float]][source]
Build a criterion that minimises the combined approximation error.
For each candidate edge (i, j) and split parameter s drawn from a uniform grid over (0, 1), the combined mean-squared error
\[E(i, j, s) = \mathrm{MSE}_A + \mathrm{MSE}_B\]
is computed, where \(\mathrm{MSE}_A\) and \(\mathrm{MSE}_B\) are evaluated on the portions of the data that fall within each sub-simplex. The candidate (i, j, s) minimising E is returned.
- Parameters:
params – Parameter vectors of shape (N, n_params).
values – Target value vectors of shape (N, n_values).
grid_size – Number of candidate s values in the grid search (default 10). Larger values give a finer search at higher cost.
- Returns:
A callable
criterion(bs) -> (i, j, s).
- Return type:
SplitCriterion
Examples
>>> import torch
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import max_error_criterion, split_by_criterion
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x...>
>>> params = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> values = torch.tensor([[0.0], [0.5], [1.0]])
>>> bs = rand(n_params=2, n_values=1, degree=1)
>>> criterion = max_error_criterion(params, values)
>>> i, j, s = criterion(bs)
>>> i == 0 and j == 1
True
- torch_bsf.splitting.reparametrize(t: Tensor, i: int, j: int, s: float, subsimplex: str) tuple[Tensor, Tensor][source]
Re-parameterise points from the original simplex to a sub-simplex.
After splitting edge (i, j) at s, each data point on the original simplex belongs to one of the two sub-simplices. This function converts the original barycentric coordinates to the sub-simplex’s local barycentric coordinates.
- Parameters:
t – Parameter vectors of shape (N, n_params) on the original simplex (each row sums to 1).
i – Index of the first edge vertex used in split().
j – Index of the second edge vertex used in split().
s – Split parameter.
subsimplex – "A" for the sub-simplex covering \(\{t : t_j / (t_i + t_j) \le s\}\), or "B" for the complementary region.
- Returns:
(u, mask) –
u – local barycentric coordinates on the sub-simplex, shape (N, n_params).
mask – boolean tensor of shape (N,) indicating which input points belong to the requested sub-simplex. Points where \(t_i = t_j = 0\) belong to both sub-simplices and are included in both masks.
- Return type:
Notes
Sub-simplex A transformation:
\[u_j = \frac{t_j}{s}, \quad u_i = t_i - \frac{1-s}{s}\,t_j, \quad u_k = t_k \; (k \ne i, j).\]
Sub-simplex B transformation:
\[u_i = \frac{t_i}{1-s}, \quad u_j = t_j - \frac{s}{1-s}\,t_i, \quad u_k = t_k \; (k \ne i, j).\]
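The two transformations in the Notes can be checked numerically in a few lines. The helper below is a hypothetical standalone sketch of those formulas for a single point, not the library's reparametrize() (which works on batches and also returns the membership mask):

```python
def to_subsimplex(t, i, j, s, subsimplex):
    """Apply the sub-simplex reparameterisation formulas from the Notes.

    t is one barycentric coordinate vector (a list summing to 1).
    Returns the local coordinates on sub-simplex "A" or "B".
    """
    u = list(t)
    if subsimplex == "A":  # region where t_j / (t_i + t_j) <= s
        u[j] = t[j] / s
        u[i] = t[i] - (1 - s) / s * t[j]
    else:                  # sub-simplex "B" (complementary region)
        u[i] = t[i] / (1 - s)
        u[j] = t[j] - s / (1 - s) * t[i]
    return u

# [0.2, 0.3, 0.5] has t_j / (t_i + t_j) = 0.6 > 0.5, so it lies in B.
u = to_subsimplex([0.2, 0.3, 0.5], i=0, j=1, s=0.5, subsimplex="B")
assert abs(sum(u) - 1.0) < 1e-9  # the result is still barycentric
```

Note that both maps preserve the coordinate sum (u_i + u_j = t_i + t_j), so valid barycentric coordinates stay valid.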
- torch_bsf.splitting.split(bs: BezierSimplex, i: int, j: int, s: float = 0.5) tuple[BezierSimplex, BezierSimplex][source]
Split a Bézier simplex along edge
(i, j) using the de Casteljau algorithm.
A new vertex is inserted on the edge between vertex i and vertex j of the parameter domain at the relative position s. The original simplex is thereby subdivided into two sub-simplices that together cover the entire original domain.
bs – The Bézier simplex to split.
i – Index of the first vertex of the split edge (0-indexed, i ≠ j).
j – Index of the second vertex of the split edge (0-indexed, i ≠ j).
s – Split parameter in the open interval (0, 1). s = 0.5 produces a midpoint split. The new vertex is located at \((1-s)\,v_i + s\,v_j\) in the parameter domain.
- Returns:
(bs_A, bs_B) –
bs_A — sub-simplex that replaces vertex \(j\) with the new vertex. It covers the sub-domain \(\{t : t_j / (t_i + t_j) \le s\}\).
bs_B — sub-simplex that replaces vertex \(i\) with the new vertex. It covers the sub-domain \(\{t : t_j / (t_i + t_j) \ge s\}\).
- Return type:
Notes
The algorithm runs \(n\) de Casteljau steps along the chosen edge direction, where \(n\) is the degree of the Bézier simplex.
At each step \(r = 1, \ldots, n\) the control-point matrix is updated as
\[\begin{split}c^{(r)}_\alpha = \begin{cases} s \cdot c^{(r-1)}_\alpha + (1-s) \cdot c^{(r-1)}_{\alpha + e_i - e_j} & \text{if } \alpha_j \ge 1 \\ c^{(r-1)}_\alpha & \text{otherwise,} \end{cases}\end{split}\]and rows with \(\alpha_j = r\) are saved as control points of bs_A. An analogous recursion with the roles of \(i\) and \(j\) swapped gives bs_B.
The two sub-simplices share the split point — both evaluate to the same value at the new vertex.
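For a Bézier curve (the n_params = 2 case), the recursion above reduces to the classic de Casteljau subdivision, which can be sketched in plain Python. This is a standalone illustration, not the library's implementation; indexing and weighting conventions may differ from the multi-index formula above.

```python
def de_casteljau_split(controls, s):
    """Split a 1-D Bézier curve at parameter s.

    controls: list of scalar control points b_0..b_n.
    Returns (left, right): control points of the two sub-curves that
    together reproduce the original curve and share the split-point value.
    """
    left, right = [controls[0]], [controls[-1]]
    pts = list(controls)
    while len(pts) > 1:  # one de Casteljau step per degree
        pts = [(1 - s) * a + s * b for a, b in zip(pts, pts[1:])]
        left.append(pts[0])    # first point of each level -> left curve
        right.append(pts[-1])  # last point of each level -> right curve
    right.reverse()
    return left, right

# Identity curve with control points 0 and 1, split at the midpoint:
left, right = de_casteljau_split([0.0, 1.0], s=0.5)
assert left == [0.0, 0.5] and right == [0.5, 1.0]
```

As in split(), the last control point of the left piece equals the first of the right piece: both evaluate to the curve's value at the split parameter.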
Examples
Split the identity Bézier curve at the midpoint:
>>> import torch
>>> from torch_bsf.bezier_simplex import BezierSimplex
>>> from torch_bsf.splitting import split
>>> bs = BezierSimplex({(1, 0): [0.0], (0, 1): [1.0]})
>>> bs_A, bs_B = split(bs, i=0, j=1, s=0.5)
>>> float(bs_A.control_points[(0, 1)].item())  # split-point value
0.5
>>> float(bs_B.control_points[(1, 0)].item())  # same split point from the other side
0.5
- torch_bsf.splitting.split_by_criterion(bs: BezierSimplex, criterion: Callable[[BezierSimplex], tuple[int, int, float]]) tuple[BezierSimplex, BezierSimplex][source]
Split a Bézier simplex using a split criterion.
Convenience wrapper that calls
criterion(bs) to obtain the edge indices and split parameter, then delegates to split().
- Parameters:
bs – The Bézier simplex to split.
criterion – A
SplitCriterion callable that returns (i, j, s).
- Returns:
(bs_A, bs_B) – The two sub-Bézier-simplices; see
split() for details.
- Return type:
Examples
>>> from torch_bsf.bezier_simplex import rand
>>> from torch_bsf.splitting import longest_edge_criterion, split_by_criterion
>>> bs = rand(n_params=2, n_values=3, degree=2)
>>> bs_A, bs_B = split_by_criterion(bs, longest_edge_criterion)
>>> bs_A.degree == bs.degree
True
torch_bsf.validator module
- torch_bsf.validator.index_list(val: str) list[list[int]][source]
Parse
val into a list of indices.
- Parameters:
val – A string expression of a list of indices.
- Return type:
The parsed indices.
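The exact string grammar index_list accepts is not documented here. As a rough illustration of the parse-and-normalise idea, the hypothetical helper below assumes Python-literal syntax for the input string:

```python
import ast

def parse_index_list(val):
    """Parse a string such as "[[1, 0, 2], [0, 3, 0]]" into a list of
    integer index lists. This is a sketch only: the grammar accepted by
    torch_bsf's index_list may differ.
    """
    parsed = ast.literal_eval(val)
    return [[int(i) for i in row] for row in parsed]

assert parse_index_list("[[1, 0, 2], [0, 3, 0]]") == [[1, 0, 2], [0, 3, 0]]
```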
- torch_bsf.validator.indices_schema(n_params: int, degree: int) dict[str, Any][source]
Generate a JSON schema for indices of the control points with given
n_params and degree.
- Parameters:
n_params – The number of index elements of control points.
degree – The degree of a Bézier simplex.
- Return type:
A JSON schema.
- Raises:
ValueError – If
n_params or degree is negative.
See also
validate_simplex_indices
Validate an instance that has appropriate n_params and degree.
- torch_bsf.validator.int_or_str(val: str) int | str[source]
Try to parse val as an int. Return the int value if parsing succeeds; otherwise return the original string.
- torch_bsf.validator.validate_simplex_indices(instance: object, n_params: int, degree: int) None[source]
Validate an instance that has appropriate n_params and degree.
- Parameters:
instance – An index list of a Bézier simplex.
n_params – The n_params of a Bézier simplex.
degree – The degree of a Bézier simplex.
- Raises:
ValidationError – If instance does not comply with the following schema or has an inner array whose sum is not equal to degree.
{
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "type": "array",
  "items": {
    "type": "array",
    "items": {
      "type": "integer",
      "minimum": 0,
      "maximum": degree,
    },
    "minItems": n_params,
    "maxItems": n_params,
  },
}
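The schema plus the degree-sum constraint can be mirrored in a few lines of plain Python. This is a sketch of the same checks, not the library's jsonschema-based implementation (note it raises ValueError, not jsonschema's ValidationError):

```python
def check_simplex_indices(instance, n_params, degree):
    """Check the constraints described above: instance must be a list of
    integer lists, each of length n_params, with entries in [0, degree]
    that sum exactly to degree.
    """
    if not isinstance(instance, list):
        raise ValueError("instance must be a list of index lists")
    for row in instance:
        if not isinstance(row, list) or len(row) != n_params:
            raise ValueError(f"each index list must have {n_params} elements: {row}")
        if not all(isinstance(i, int) and 0 <= i <= degree for i in row):
            raise ValueError(f"indices must be integers in [0, {degree}]: {row}")
        if sum(row) != degree:
            raise ValueError(f"indices must sum to {degree}: {row}")

check_simplex_indices([[3, 0, 0], [1, 1, 1]], n_params=3, degree=3)  # passes
```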
Module contents
torch_bsf: PyTorch implementation of Bézier simplex fitting.
- class torch_bsf.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, *, _n_params: int | None = None, _degree: int | None = None, _n_values: int | None = None)[source]
Bases:
LightningModule
A Bézier simplex model.
- Parameters:
control_points – The control points of the Bézier simplex. Pass None only when reconstructing a model from a Lightning checkpoint via load_from_checkpoint(): in that case all three shape parameters (_n_params, _degree, _n_values) must be provided so that a correctly-shaped placeholder can be built before the saved state dict is loaded into it.
smoothness_weight – The weight of the smoothness penalty term added to the training loss. When greater than zero, adjacent control points are encouraged to have similar values. Defaults to 0.0 (no penalty).
_n_params – Checkpoint-reconstruction parameter; do not set manually. The number of parameters (source dimension + 1) used to build the placeholder control points when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
_degree – Checkpoint-reconstruction parameter; do not set manually. The degree of the Bézier simplex used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
_n_values – Checkpoint-reconstruction parameter; do not set manually. The number of values (target dimension) used to build the placeholder when control_points is None. Automatically saved to, and restored from, Lightning checkpoints.
Examples
>>> import torch
>>> import torch_bsf
>>> import lightning.pytorch as L
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()
- forward(t: Tensor) Tensor[source]
Compute a forward pass: evaluate the Bézier simplex at a minibatch of parameter vectors.
- Parameters:
t – A minibatch of parameter vectors \(\mathbf t\).
- Return type:
A minibatch of value vectors.
- freeze_row(index: torch_bsf.bezier_simplex.Index) None[source]
Freeze a control point so its gradient is zeroed after every backward.
- Parameters:
index – The index of the control point to freeze.
- meshgrid(num: int = 100) tuple[Tensor, Tensor][source]
Computes a meshgrid of the Bézier simplex.
- Parameters:
num – The number of grid points on each edge.
- Returns:
ts – A parameter matrix of the mesh grid.
xs – A value matrix of the mesh grid.
- on_after_backward() None[source]
Zero gradients for frozen control-point rows after each backward pass.
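The freeze mechanism described above (freeze_row() registers an index, on_after_backward() zeroes its gradient after every backward pass) can be sketched generically for any parameter matrix. This is a hypothetical standalone example of the pattern, not the library's implementation:

```python
import torch

# A stand-in for the control-point matrix: 4 control points, 3 values each.
control_points = torch.nn.Parameter(torch.ones(4, 3))
frozen_rows = {0, 2}  # rows registered for freezing, as via freeze_row()

loss = (control_points ** 2).sum()
loss.backward()  # every row now has gradient 2 * control_points

with torch.no_grad():
    for r in frozen_rows:  # what on_after_backward() does per frozen row
        control_points.grad[r].zero_()

# Frozen rows receive no update from a subsequent optimizer step.
assert torch.all(control_points.grad[0] == 0)
assert torch.all(control_points.grad[1] == 2)
```

Zeroing gradients after backward (rather than detaching the rows) keeps the parameter as a single tensor, which is convenient when only a subset of control points should stay fixed.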
- class torch_bsf.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')[source]
Bases:
LightningDataModule
A data module for training a Bézier simplex.
- Parameters:
params – The path to a parameter file.
values – The path to a value file.
header – The number of header rows in the parameter file and the value file. The first
header rows are skipped when reading the files.
batch_size – The size of each minibatch.
split_ratio – The ratio of train-val split. Must be greater than 0 and less than or equal to 1. If it is set to 1, then all the data are used for training and the validation step will be skipped.
normalize – The data normalization method. Either
"max", "std", "quantile", or "none".
- test_dataloader() DataLoader[source]
- train_dataloader() DataLoader[source]
- val_dataloader() DataLoader[source]
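The exact formulas behind the normalize options are not spelled out in this reference. The sketch below shows common interpretations of the four option names (an assumption, not the library's implementation), applied column-wise to a tensor:

```python
import torch

def normalize(x, method):
    """Plausible column-wise interpretations of the normalize options:
    "max" scales each column to [0, 1], "std" standardises to zero mean
    and unit variance, "quantile" uses a robust median/IQR scaling, and
    "none" returns the data unchanged. The library's exact formulas may
    differ.
    """
    if method == "none":
        return x
    if method == "max":
        lo, hi = x.min(dim=0).values, x.max(dim=0).values
        return (x - lo) / (hi - lo)
    if method == "std":
        return (x - x.mean(dim=0)) / x.std(dim=0)
    if method == "quantile":
        q1 = torch.quantile(x, 0.25, dim=0)
        q3 = torch.quantile(x, 0.75, dim=0)
        return (x - x.median(dim=0).values) / (q3 - q1)
    raise ValueError(f"unknown method: {method}")

x = torch.tensor([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]])
scaled = normalize(x, "max")
assert scaled.min().item() == 0.0 and scaled.max().item() == 1.0
```

Robust ("quantile") scaling is less sensitive to outliers than "max" or "std", which is why data modules often offer all three.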
- torch_bsf.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) BezierSimplex[source]
Fits a Bézier simplex.
- Parameters:
params – The parameter data on the simplex.
values – The label data.
degree – The degree of the Bézier simplex.
init – The initial values of a Bézier simplex or control points.
smoothness_weight – The weight of smoothness penalty.
freeze – The indices of control points to exclude from training.
batch_size – The size of each minibatch.
seed – Random seed passed to lightning.pytorch.seed_everything() to set RNG seeds for improved reproducibility. When None (default), no seed is set. For full determinism, also set Trainer(deterministic=True) and use compatible settings.
kwargs – All arguments for lightning.pytorch.Trainer.
- Return type:
A trained Bézier simplex.
- Raises:
TypeError – From Trainer or DataLoader.
MisconfigurationException – From Trainer.
Examples
>>> import torch
>>> import torch_bsf
Prepare training data
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
Train a model
>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)
Predict by the trained model
>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<...>)
See also
lightning.pytorch.Trainer
Argument descriptions.
torch.DataLoader
Argument descriptions.
- torch_bsf.fit_kfold(params: Tensor, values: Tensor, n_folds: int = 5, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, smoothness_weight: float = 0.0, freeze: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, seed: int | None = None, **kwargs) ModuleList[source]
Fits an ensemble of Bézier simplices using k-fold cross-validation.
Splits the training data into
n_folds folds via KFoldTrainer and trains one model per fold on the data from the remaining n_folds - 1 folds. The resulting ModuleList can be passed directly to torch_bsf.active_learning.suggest_next_points() to drive a Query-By-Committee active learning loop.
If len(params) < n_folds, the actual number of folds is capped at len(params) to avoid empty training subsets.
- Parameters:
params – The parameter data on the simplex.
values – The label data.
n_folds – The number of cross-validation folds (committee size). Defaults to 5.
degree – The degree of the Bézier simplex.
init – The initial values of a Bézier simplex or control points.
smoothness_weight – The weight of the smoothness penalty.
freeze – The indices of control points to exclude from training.
batch_size – The size of a minibatch. Defaults to full-batch (consistent with
fit()).
seed – Random seed passed to lightning.pytorch.seed_everything() for reproducible training. When None (default), no seed is set.
kwargs – All arguments for KFoldTrainer (which itself accepts all lightning.pytorch.Trainer arguments). For example, max_epochs=10, enable_progress_bar=False, logger=False, shuffle=True, or stratified=False. By default, num_sanity_val_steps=0 and limit_val_batches=0.0 are set internally to disable per-fold validation for speed; pass explicit values to override these defaults, e.g. num_sanity_val_steps=2, limit_val_batches=1.0.
- Returns:
A
ModuleList of min(n_folds, len(params)) trained BezierSimplex models, one per fold.
- Return type:
- Raises:
ValueError – If
n_folds < 2, if len(params) < 2 (too few samples for any fold split), if neither or both of degree and init are provided, if batch_size is truthy but not a positive integer, or if the reserved argument num_folds is supplied via **kwargs.
Examples
>>> import torch
>>> import torch_bsf
>>> from torch_bsf.active_learning import suggest_next_points
>>> from torch_bsf.sampling import simplex_grid
Prepare training data
>>> params = simplex_grid(n_params=3, degree=3)
>>> values = params.pow(2).sum(dim=1, keepdim=True)
Build a 5-fold ensemble and suggest the 2 most uncertain points
>>> models = torch_bsf.fit_kfold(
...     params=params,
...     values=values,
...     degree=3,
...     max_epochs=1,
...     enable_progress_bar=False,
...     enable_model_summary=False,
...     logger=False,
... )
>>> suggestions = suggest_next_points(models, n_suggestions=2, method="qbc")
>>> suggestions.shape
torch.Size([2, 3])
See also
fit
Fit a single Bézier simplex.
torch_bsf.active_learning.suggest_next_points
Use the ensemble for active learning.
- torch_bsf.validate_control_points(data: dict[str, list[float]])[source]
Validates control points.
- Parameters:
data – The control points.
- Raises:
ValidationError – If the control points are invalid.
Examples
>>> from torch_bsf.bezier_simplex import validate_control_points
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
... })
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "0, 0, 1": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: '0, 0, 1' is not valid under any of the given schemas
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
...     "(0, 0)": [0.0, 0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1, 0)": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0, 1, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0, 1.0, 0.0]
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0]