torch_bsf package

Submodules

torch_bsf.bezier_simplex module

class torch_bsf.bezier_simplex.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]])

Bases: LightningModule

A Bezier simplex model.

Parameters:

control_points – The control points of the Bezier simplex.

Examples

>>> import lightning.pytorch as L
>>> import torch
>>> import torch_bsf
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()

configure_optimizers() → Optimizer
property degree: int

The degree of the Bezier simplex.

forward(t: Tensor) → Tensor

Performs a forward pass, evaluating the Bezier simplex at each parameter vector in the minibatch.

Parameters:

t – A minibatch of parameter vectors \(\mathbf t\).

Returns:

A minibatch of value vectors.
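The forward pass evaluates the Bezier simplex at each parameter vector. Assuming the standard Bernstein form \(\mathbf b(\mathbf t) = \sum_{|\mathbf d| = D} \binom{D}{\mathbf d} \mathbf t^{\mathbf d} \mathbf p_{\mathbf d}\), where \(\mathbf p_{\mathbf d}\) is the control point with index \(\mathbf d\) (the coefficient and monomial are what polynom and monomial below compute), evaluation at a single vector can be sketched in plain Python. This is an illustration, not the library's batched Tensor code; multinomial and evaluate_bezier_simplex are hypothetical names.

```python
from math import factorial, prod
from typing import Dict, List, Sequence, Tuple


def multinomial(degree: int, index: Sequence[int]) -> float:
    """D! / (d_1! d_2! ... d_M!) for an index d whose entries sum to D."""
    return factorial(degree) / prod(factorial(d) for d in index)


def evaluate_bezier_simplex(
    control_points: Dict[Tuple[int, ...], List[float]],
    t: Sequence[float],
) -> List[float]:
    """Hypothetical sketch: b(t) = sum over d of binom(D, d) * t^d * p_d."""
    first_index, first_point = next(iter(control_points.items()))
    degree = sum(first_index)  # every index of a degree-D simplex sums to D
    x = [0.0] * len(first_point)
    for index, point in control_points.items():
        # Bernstein weight of this control point at t.
        weight = multinomial(degree, index) * prod(ti**di for ti, di in zip(t, index))
        for k, p in enumerate(point):
            x[k] += weight * p
    return x


# Degree-1 control points taken from the ControlPoints example.
cp = {(1, 0): [0.0, 0.1, 0.2], (0, 1): [1.0, 1.1, 1.2]}
print(evaluate_bezier_simplex(cp, (0.5, 0.5)))  # approximately [0.5, 0.6, 0.7]
```

At a vertex such as t = (1, 0), only one weight is nonzero, so the simplex passes through the corresponding corner control point.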

meshgrid(num: int = 100) → tuple[Tensor, Tensor]

Computes a meshgrid of the Bezier simplex.

Parameters:

num – The number of grid points on each edge.

Returns:

  • ts – A parameter matrix of the mesh grid.

  • xs – A value matrix of the mesh grid.
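One way to build a grid with num points on each edge is to take every parameter vector whose entries are multiples of 1 / (num - 1) and sum to 1. The sketch below assumes that construction; simplex_grid is a hypothetical helper, and the library's point order may differ. Under this assumption, ts corresponds to such a matrix and xs to the model evaluated at each of its rows.

```python
from itertools import product
from typing import List


def simplex_grid(n_params: int, num: int) -> List[List[float]]:
    """Hypothetical regular grid on the simplex with `num` points on each edge.

    A grid point is k / (num - 1) for every vector k of non-negative
    integers with sum(k) == num - 1.
    """
    if n_params < 1 or num < 2:
        raise ValueError("need n_params >= 1 and num >= 2")
    steps = num - 1
    grid = []
    # Enumerate the first n_params - 1 coordinates; the last one is fixed by the sum.
    for head in product(range(num), repeat=n_params - 1):
        rest = steps - sum(head)
        if rest >= 0:
            grid.append([k / steps for k in (*head, rest)])
    return grid
```

With n_params=3 and num=4, this yields the same ten points used as ts in the examples above.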

property n_params: int

The number of parameters, i.e., the source dimension + 1.

property n_values: int

The number of values, i.e., the target dimension.

test_step(batch, batch_idx) → dict[str, Any]
training_step(batch, batch_idx) → dict[str, Any]
validation_end(outputs) → dict[str, Any]
validation_step(batch, batch_idx) → dict[str, Any]

class torch_bsf.bezier_simplex.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')

Bases: LightningDataModule

A data module for training a Bezier simplex.

Parameters:
  • params – The path to a parameter file.

  • values – The path to a value file.

  • header – The number of header rows in the parameter file and the value file. This many leading rows are skipped when reading each file.

  • batch_size – The size of each minibatch.

  • split_ratio – The fraction of the data used for training in the train-validation split. Must be greater than 0 and less than or equal to 1. If it is 1, all the data are used for training and the validation step is skipped.

  • normalize – The data normalization method. Either "max", "std", "quantile", or "none".

fit_transform(values: Tensor) → Tensor
inverse_transform(values: Tensor) → Tensor
load_data(path) → Tensor
load_params() → Tensor
load_values() → Tensor
setup(stage: str | None = None)
test_dataloader() → DataLoader
train_dataloader() → DataLoader
val_dataloader() → DataLoader

torch_bsf.bezier_simplex.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, fix: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, **kwargs) → BezierSimplex

Fits a Bezier simplex.

Parameters:
  • params – The parameter vectors (inputs) of the training data.

  • values – The value vectors (targets) corresponding to params.

  • degree – The degree of the Bezier simplex.

  • init – The initial Bezier simplex or control points.

  • fix – The indices of control points to exclude from training.

  • batch_size – The size of each minibatch.

  • kwargs – Additional keyword arguments passed to lightning.pytorch.Trainer.

Returns:

A trained Bezier simplex.

Raises:
  • TypeError – From Trainer or DataLoader.

  • MisconfigurationException – From Trainer.

Examples

>>> import torch
>>> import torch_bsf

Prepare training data

>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters

Train a model

>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)

Predict by the trained model

>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

See also

lightning.pytorch.Trainer

Argument descriptions.

torch.utils.data.DataLoader

Argument descriptions.

torch_bsf.bezier_simplex.load(path: str | Path, *, pt_weights_only: bool | None = None) → BezierSimplex

Loads a Bezier simplex from a file.

Parameters:
  • path – The path to a file.

  • pt_weights_only – Whether to load weights only. This parameter is only effective when loading PyTorch (.pt) files. For other formats (e.g., .json, .yml), data loading is inherently safe and this parameter is ignored. If None, it defaults to False.

Returns:

A Bezier simplex.

Raises:
  • ValueError – If the file type is unknown.

  • ValidationError – If the control points are invalid.

Notes

Setting pt_weights_only=True will fail if the model contains classes not allowed by PyTorch’s WeightsUnpickler (like lightning’s AttributeDict), even if they are in the safe globals list.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.load("tests/data/bezier_simplex.csv")
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(
      ((0, 2)): Parameter containing: [torch.FloatTensor of size 3]
      ((1, 1)): Parameter containing: [torch.FloatTensor of size 3]
      ((2, 0)): Parameter containing: [torch.FloatTensor of size 3]
  )
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

torch_bsf.bezier_simplex.monomial(variable: Iterable[float], degree: Iterable[int]) → Tensor

Computes a monomial \(\mathbf t^{\mathbf d} = t_1^{d_1} t_2^{d_2}\cdots t_M^{d_M}\).

Parameters:
  • variable – The bases \(\mathbf t\).

  • degree – The powers \(\mathbf d\).

Returns:

The monomial \(\mathbf t^{\mathbf d}\).
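For a single parameter vector, the monomial reduces to a product of powers. A minimal scalar sketch follows; monomial_sketch is a hypothetical name, and unlike the library function it returns a float rather than a Tensor.

```python
from math import prod
from typing import Iterable


def monomial_sketch(variable: Iterable[float], degree: Iterable[int]) -> float:
    """t_1^{d_1} * t_2^{d_2} * ... * t_M^{d_M} for one parameter vector t."""
    t = list(variable)
    d = list(degree)
    if len(t) != len(d):
        raise ValueError("variable and degree must have the same length")
    # Python defines 0.0 ** 0 == 1.0, so a zero coordinate with a zero exponent contributes a factor of 1.
    return prod(ti**di for ti, di in zip(t, d))
```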

torch_bsf.bezier_simplex.polynom(degree: int, index: Iterable[int]) → float

Computes the multinomial coefficient \(\binom{D}{\mathbf d} = \frac{D!}{d_1!d_2!\cdots d_M!}\).

Parameters:
  • degree – The degree \(D\).

  • index – The index \(\mathbf d\).

Returns:

The multinomial coefficient \(\binom{D}{\mathbf d}\).
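The coefficient follows directly from factorials. A minimal sketch with a hypothetical name, which additionally checks that the index entries are non-negative and sum to D:

```python
from math import factorial, prod
from typing import Iterable


def multinomial_coefficient(degree: int, index: Iterable[int]) -> float:
    """D! / (d_1! d_2! ... d_M!), returned as a float like polynom."""
    d = list(index)
    if any(di < 0 for di in d) or sum(d) != degree:
        raise ValueError("index entries must be non-negative and sum to degree")
    return factorial(degree) / prod(factorial(di) for di in d)
```

By the multinomial theorem, the coefficients over all indices of degree D with M entries sum to \(M^D\); for degree 2 with three parameters, 1 + 1 + 1 + 2 + 2 + 2 = 9.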

torch_bsf.bezier_simplex.rand(n_params: int, n_values: int, degree: int) → BezierSimplex

Generates a random Bezier simplex.

The control points are initialized with random values drawn uniformly from [0, 1).

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bezier simplex.

Returns:

A random Bezier simplex.

Raises:

ValueError – If n_params or n_values or degree is negative.
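The initialization can be pictured as drawing one random value vector per control-point index. The sketch below uses plain Python lists instead of Tensor parameters; random_control_points and its sample argument are hypothetical. Passing random.random mirrors the [0, 1) draw described for rand, lambda: random.gauss(0.0, 1.0) mirrors randn, and lambda: 0.0 mirrors zeros.

```python
import random
from itertools import product
from typing import Callable, Dict, List, Tuple


def random_control_points(
    n_params: int,
    n_values: int,
    degree: int,
    sample: Callable[[], float] = random.random,
) -> Dict[Tuple[int, ...], List[float]]:
    """Hypothetical sketch: one sampled value vector per simplex index."""
    if n_params < 0 or n_values < 0 or degree < 0:
        raise ValueError("n_params, n_values, and degree must be non-negative")
    # All tuples of n_params non-negative ints that sum to degree.
    indices = [k for k in product(range(degree + 1), repeat=n_params) if sum(k) == degree]
    return {k: [sample() for _ in range(n_values)] for k in indices}
```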

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.rand(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(
      ((0, 2)): Parameter containing: [torch.FloatTensor of size 3]
      ((1, 1)): Parameter containing: [torch.FloatTensor of size 3]
      ((2, 0)): Parameter containing: [torch.FloatTensor of size 3]
  )
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

torch_bsf.bezier_simplex.randn(n_params: int, n_values: int, degree: int) → BezierSimplex

Generates a random Bezier simplex.

The control points are initialized with random values drawn from the standard normal distribution (mean 0, standard deviation 1).

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bezier simplex.

Returns:

A random Bezier simplex.

Raises:

ValueError – If n_params or n_values or degree is negative.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(
      ((0, 2)): Parameter containing: [torch.FloatTensor of size 3]
      ((1, 1)): Parameter containing: [torch.FloatTensor of size 3]
      ((2, 0)): Parameter containing: [torch.FloatTensor of size 3]
  )
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

torch_bsf.bezier_simplex.save(path: str | Path, data: BezierSimplex) → None

Saves a Bezier simplex to a file.

Parameters:
  • path – The path of the file to save to.

  • data – The Bezier simplex to save.

Raises:

ValueError – If the file type is unknown.

Examples

>>> import torch_bsf
>>> bs = torch_bsf.bezier_simplex.randn(n_params=2, n_values=3, degree=2)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.pt", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.csv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.tsv", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.json", bs)
>>> torch_bsf.bezier_simplex.save("tests/data/bezier_simplex.yml", bs)

torch_bsf.bezier_simplex.validate_control_points(data: dict[str, list[float]])

Validates control points.

Parameters:

data – The control points.

Raises:

ValidationError – If the control points are invalid.

Examples

>>> from torch_bsf.bezier_simplex import validate_control_points
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
... })
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "0, 0, 1": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: '0, 0, 1' is not valid under any of the given schemas
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0],
...     "(0, 0)": [0.0, 0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1, 0)": [0.0, 0.0, 1.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: (0, 0, 1, 0)
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0, 1.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0, 1.0, 0.0]
>>> validate_control_points({
...     "(1, 0, 0)": [1.0, 0.0, 0.0],
...     "(0, 1, 0)": [0.0, 1.0, 0.0],
...     "(0, 0, 1)": [0.0, 0.0],
... })
Traceback (most recent call last):
    ...
jsonschema.exceptions.ValidationError: Dimension mismatch: [0.0, 0.0]

torch_bsf.bezier_simplex.zeros(n_params: int, n_values: int, degree: int) → BezierSimplex

Generates a Bezier simplex whose control points are all at the origin.

Parameters:
  • n_params – The number of parameters, i.e., the source dimension + 1.

  • n_values – The number of values, i.e., the target dimension.

  • degree – The degree of the Bezier simplex.

Returns:

A Bezier simplex filled with zeros.

Raises:

ValueError – If n_params or n_values or degree is negative.

Examples

>>> import torch
>>> from torch_bsf import bezier_simplex
>>> bs = bezier_simplex.zeros(n_params=2, n_values=3, degree=2)
>>> print(bs)
BezierSimplex(
  (control_points): ControlPoints(
      ((0, 2)): Parameter containing: [torch.FloatTensor of size 3]
      ((1, 1)): Parameter containing: [torch.FloatTensor of size 3]
      ((2, 0)): Parameter containing: [torch.FloatTensor of size 3]
  )
)
>>> print(bs(torch.tensor([[0.2, 0.8]])))
tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

torch_bsf.control_points module

class torch_bsf.control_points.ControlPoints(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None)

Bases: ParameterDict

Control points of a Bezier simplex.

degree

The degree of the Bezier simplex.

n_params

The number of parameters.

n_values

The number of values.

Examples

>>> import torch_bsf
>>> control_points = torch_bsf.control_points.ControlPoints({
...     (1, 0): [0.0, 0.1, 0.2],
...     (0, 1): [1.0, 1.1, 1.2],
... })
>>> control_points.degree
1
>>> control_points.n_params
2
>>> control_points.n_values
3
>>> control_points[(1, 0)]
Parameter containing:
tensor([0.0000, 0.1000, 0.2000], requires_grad=True)
>>> control_points[(0, 1)]
Parameter containing:
tensor([1.0000, 1.1000, 1.2000], requires_grad=True)
>>> control_points[(1, 0)].requires_grad = False
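Each of these attributes can be read off any single entry, assuming all entries are consistent (which validate_control_points checks). A plain-dict sketch with a hypothetical helper name:

```python
from typing import Dict, List, Tuple


def describe_control_points(
    data: Dict[Tuple[int, ...], List[float]],
) -> Tuple[int, int, int]:
    """Return (degree, n_params, n_values) inferred from one entry."""
    index, value = next(iter(data.items()))
    degree = sum(index)  # every index of a degree-D simplex sums to D
    n_params = len(index)  # one exponent per parameter
    n_values = len(value)  # length of each control point vector
    return degree, n_params, n_values
```

Applied to the dictionary in the example above, it returns (1, 2, 3), matching degree, n_params, and n_values.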

indices() → Iterable[tuple[int, ...]]

Iterates over the indices of the control points of the Bezier simplex.

Returns:

The indices.

torch_bsf.control_points.ControlPointsData: TypeAlias = dict[str, torch.Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], torch.Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[torch.Tensor, torch.Tensor] | dict[torch.Tensor, list[float]] | dict[torch.Tensor, tuple[float, ...]]

The data type of control points of a Bezier simplex.

torch_bsf.control_points.Index: TypeAlias = str | typing.Sequence[int] | torch.Tensor

The index type of control points of a Bezier simplex.

torch_bsf.control_points.Value: TypeAlias = typing.Sequence[float] | torch.Tensor

The value type of control points of a Bezier simplex.

torch_bsf.control_points.simplex_indices(n_params: int, degree: int) → Iterable[tuple[int, ...]]

Iterates over the indices of the control points of a Bezier simplex.

Parameters:
  • n_params – The tuple length of each index.

  • degree – The degree of the Bezier simplex.

Returns:

The indices.
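The indices are all tuples of n_params non-negative integers that sum to degree, and there are \(\binom{D + M - 1}{M - 1}\) of them for degree \(D\) and \(M\) = n_params. A recursive sketch follows; simplex_indices_sketch is a hypothetical name, and the library's iteration order may differ.

```python
from typing import Iterator, Tuple


def simplex_indices_sketch(n_params: int, degree: int) -> Iterator[Tuple[int, ...]]:
    """Yield every tuple of n_params non-negative ints whose sum is degree."""
    if n_params < 1 or degree < 0:
        raise ValueError("n_params must be positive and degree non-negative")
    if n_params == 1:
        # A single entry must carry the whole degree.
        yield (degree,)
        return
    for first in range(degree + 1):
        # Distribute the remaining degree over the other n_params - 1 entries.
        for rest in simplex_indices_sketch(n_params - 1, degree - first):
            yield (first, *rest)
```

For n_params=3 and degree=3 this yields \(\binom{5}{2} = 10\) indices, one per row of the ts matrix in the bezier_simplex examples (each row is an index divided by 3).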

torch_bsf.control_points.to_parameterdict(data: dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]]) → dict[str, Tensor]

Convert data to a dictionary of parameters.

Parameters:

data (dict) – Data to be converted.

Returns:

Converted data.

Return type:

dict

torch_bsf.control_points.to_parameterdict_key(index: str | Sequence[int] | Tensor) → str

Convert an index to a key of a ParameterDict.

Parameters:

index (str, Sequence[int], or torch.Tensor) – An index of a ParameterDict.

Returns:

A key of a ParameterDict.

Return type:

str
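The control-point keys in the validate_control_points examples have the form "(1, 0, 0)". Assuming to_parameterdict_key produces that same string form, the conversion for string and integer-sequence indices can be sketched as follows; to_key_sketch is a hypothetical name, and Tensor indices are not handled here.

```python
import ast
from typing import Sequence, Union


def to_key_sketch(index: Union[str, Sequence[int]]) -> str:
    """Hypothetical: normalize an index to the "(i, j, k)" string form."""
    if isinstance(index, str):
        # Parse a literal such as "(1,0,0)" or "[1, 0, 0]" before re-emitting it.
        index = ast.literal_eval(index)
    return str(tuple(int(i) for i in index))
```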

torch_bsf.control_points.to_parameterdict_value(value: Sequence[float] | Tensor) → Tensor

Convert a value to a value of a ParameterDict.

Parameters:

value (Sequence[float] or torch.Tensor) – A value of a ParameterDict.

Returns:

A value of a ParameterDict.

Return type:

torch.Tensor

torch_bsf.validator module

torch_bsf.validator.index_list(val: str) → list[list[int]]

Parse val into a list of indices.

Parameters:

val – A string expression of a list of indices.

Returns:

The parsed indices.

torch_bsf.validator.indices_schema(n_params: int, degree: int) → dict[str, Any]

Generate a JSON schema for indices of the control points with given n_params and degree.

Parameters:
  • n_params – The number of index elements of control points.

  • degree – The degree of a Bezier simplex.

Returns:

A JSON schema.

Raises:

ValueError – If n_params or degree is negative.

See also

validate_simplex_indices

Validate that an instance has the appropriate n_params and degree.
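Assuming indices_schema returns the schema quoted under validate_simplex_indices, it can be built as an ordinary dictionary. The sketch below is hypothetical. The schema bounds each entry and the length of each index but does not express the sum constraint, which validate_simplex_indices checks separately.

```python
from typing import Any, Dict


def indices_schema_sketch(n_params: int, degree: int) -> Dict[str, Any]:
    """Hypothetical: JSON schema for a list of control-point indices."""
    if n_params < 0 or degree < 0:
        raise ValueError("n_params and degree must be non-negative")
    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "array",
        "items": {
            "type": "array",
            # Each entry of an index is an exponent between 0 and degree.
            "items": {"type": "integer", "minimum": 0, "maximum": degree},
            # Each index has exactly n_params entries.
            "minItems": n_params,
            "maxItems": n_params,
        },
    }
```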

torch_bsf.validator.int_or_str(val: str) → int | str

Try to parse val as an int. Return the int value if parsing succeeds; otherwise, return the original string.

Parameters:

val – The value to try to convert into int.

Returns:

The converted integer or the original value.

Return type:

Union[int, str]
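A minimal sketch of this behavior using the built-in int; int_or_str_sketch is a hypothetical name.

```python
from typing import Union


def int_or_str_sketch(val: str) -> Union[int, str]:
    """Return int(val) when val parses as an integer, else val unchanged."""
    try:
        return int(val)
    except ValueError:
        # Not an integer literal: keep the original string.
        return val
```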

torch_bsf.validator.validate_simplex_indices(instance: object, n_params: int, degree: int) → None

Validate that an instance has the appropriate n_params and degree.

Parameters:
  • instance – An index list of a Bezier simplex.

  • n_params – The n_params of a Bezier simplex.

  • degree – The degree of a Bezier simplex.

Raises:

ValidationError – If instance does not comply with the following schema or has an inner array whose sum is not equal to degree.

    {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "array",
        "items": {
            "type": "array",
            "items": {
                "type": "integer",
                "minimum": 0,
                "maximum": degree,
            },
            "minItems": n_params,
            "maxItems": n_params,
        },
    }
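The same rules can be checked without jsonschema. This plain-Python sketch (check_simplex_indices is a hypothetical name) raises ValueError rather than ValidationError and accepts only Python int entries, excluding booleans.

```python
from typing import Sequence


def check_simplex_indices(
    instance: Sequence[Sequence[int]], n_params: int, degree: int
) -> None:
    """Hypothetical sketch of the checks described for validate_simplex_indices."""
    for index in instance:
        if len(index) != n_params:
            raise ValueError(f"expected {n_params} entries, got {list(index)}")
        for d in index:
            # bool is a subclass of int in Python, so exclude it explicitly.
            if isinstance(d, bool) or not isinstance(d, int) or not 0 <= d <= degree:
                raise ValueError(f"entries must be integers in [0, {degree}]: {list(index)}")
        if sum(index) != degree:
            raise ValueError(f"entries must sum to {degree}: {list(index)}")
```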

Module contents

torch_bsf: PyTorch implementation of Bezier simplex fitting.

class torch_bsf.BezierSimplex(control_points: ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]])

Bases: LightningModule

A Bezier simplex model.

Parameters:

control_points – The control points of the Bezier simplex.

Examples

>>> import lightning.pytorch as L
>>> import torch
>>> import torch_bsf
>>> from lightning.pytorch.callbacks.early_stopping import EarlyStopping
>>> from torch.utils.data import DataLoader, TensorDataset
>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters
>>> dl = DataLoader(TensorDataset(ts, xs))
>>> bs = torch_bsf.bezier_simplex.randn(
...     n_params=int(ts.shape[1]),
...     n_values=int(xs.shape[1]),
...     degree=3,
... )
>>> trainer = L.Trainer(
...     callbacks=[EarlyStopping(monitor="train_mse")],
...     enable_progress_bar=False,
... )
>>> trainer.fit(bs, dl)
>>> ts, xs = bs.meshgrid()

configure_optimizers() → Optimizer
property degree: int

The degree of the Bezier simplex.

forward(t: Tensor) → Tensor

Performs a forward pass, evaluating the Bezier simplex at each parameter vector in the minibatch.

Parameters:

t – A minibatch of parameter vectors \(\mathbf t\).

Returns:

A minibatch of value vectors.

meshgrid(num: int = 100) → tuple[Tensor, Tensor]

Computes a meshgrid of the Bezier simplex.

Parameters:

num – The number of grid points on each edge.

Returns:

  • ts – A parameter matrix of the mesh grid.

  • xs – A value matrix of the mesh grid.

property n_params: int

The number of parameters, i.e., the source dimension + 1.

property n_values: int

The number of values, i.e., the target dimension.

test_step(batch, batch_idx) → dict[str, Any]
training_step(batch, batch_idx) → dict[str, Any]
validation_end(outputs) → dict[str, Any]
validation_step(batch, batch_idx) → dict[str, Any]

class torch_bsf.BezierSimplexDataModule(params: Path, values: Path, header: int = 0, batch_size: int | None = None, split_ratio: float = 1.0, normalize: Literal['max', 'std', 'quantile', 'none'] = 'none')

Bases: LightningDataModule

A data module for training a Bezier simplex.

Parameters:
  • params – The path to a parameter file.

  • values – The path to a value file.

  • header – The number of header rows in the parameter file and the value file. This many leading rows are skipped when reading each file.

  • batch_size – The size of each minibatch.

  • split_ratio – The fraction of the data used for training in the train-validation split. Must be greater than 0 and less than or equal to 1. If it is 1, all the data are used for training and the validation step is skipped.

  • normalize – The data normalization method. Either "max", "std", "quantile", or "none".

fit_transform(values: Tensor) → Tensor
inverse_transform(values: Tensor) → Tensor
load_data(path) → Tensor
load_params() → Tensor
load_values() → Tensor
setup(stage: str | None = None)
test_dataloader() → DataLoader
train_dataloader() → DataLoader
val_dataloader() → DataLoader

torch_bsf.fit(params: Tensor, values: Tensor, degree: int | None = None, init: BezierSimplex | ControlPoints | dict[str, Tensor] | dict[str, list[float]] | dict[str, tuple[float, ...]] | dict[tuple[int, ...], Tensor] | dict[tuple[int, ...], list[float]] | dict[tuple[int, ...], tuple[float, ...]] | dict[Tensor, Tensor] | dict[Tensor, list[float]] | dict[Tensor, tuple[float, ...]] | None = None, fix: Iterable[str | Sequence[int] | Tensor] | None = None, batch_size: int | None = None, **kwargs) → BezierSimplex

Fits a Bezier simplex.

Parameters:
  • params – The parameter vectors (inputs) of the training data.

  • values – The value vectors (targets) corresponding to params.

  • degree – The degree of the Bezier simplex.

  • init – The initial Bezier simplex or control points.

  • fix – The indices of control points to exclude from training.

  • batch_size – The size of each minibatch.

  • kwargs – Additional keyword arguments passed to lightning.pytorch.Trainer.

Returns:

A trained Bezier simplex.

Raises:
  • TypeError – From Trainer or DataLoader.

  • MisconfigurationException – From Trainer.

Examples

>>> import torch
>>> import torch_bsf

Prepare training data

>>> ts = torch.tensor(  # parameters on a simplex
...     [
...         [3/3, 0/3, 0/3],
...         [2/3, 1/3, 0/3],
...         [2/3, 0/3, 1/3],
...         [1/3, 2/3, 0/3],
...         [1/3, 1/3, 1/3],
...         [1/3, 0/3, 2/3],
...         [0/3, 3/3, 0/3],
...         [0/3, 2/3, 1/3],
...         [0/3, 1/3, 2/3],
...         [0/3, 0/3, 3/3],
...     ]
... )
>>> xs = 1 - ts * ts  # values corresponding to the parameters

Train a model

>>> bs = torch_bsf.fit(params=ts, values=xs, degree=3)

Predict by the trained model

>>> t = [[0.2, 0.3, 0.5]]
>>> x = bs(t)
>>> print(f"{t} -> {x}")
[[0.2, 0.3, 0.5]] -> tensor([[..., ..., ...]], grad_fn=<AddBackward0>)

See also

lightning.pytorch.Trainer

Argument descriptions.

torch.utils.data.DataLoader

Argument descriptions.