Layers

The primary trainable preprocessing layer is ConnectedFilterPreprocessingLayer. Reference and legacy implementations are documented as compatibility surfaces, but new experiments should prefer the primary layer.

Primary CFP Layer

class mtlearn.layers.ConnectedFilterPreprocessingLayer(in_channels, filter_specs, *, device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, clamp=None, hybrid_k=3.0, hybrid_floor_a=0.05, attribute_dtype=None, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)

Bases: Module

Learnable CFP layer defined by per-output filter specifications.

Each item in filter_specs defines one output per input channel by specifying a morphology tree, the scoring attributes, and the reconstructed valuation. The default valuation is CFPValuation.ALTITUDE; CFPValuation.ALTITUDE_TOPHAT selects top-hat output. When set, clamp bounds beta_f * logits before the sigmoid.

Tree construction and attribute computation happen outside autograd. The trainable parameters are the per-filter weight vectors and scalar biases that produce node-wise sigmoid gates from normalized attributes.
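The gating described above can be sketched in plain Python. This is a minimal illustration, not the library's code: it assumes the primary layer uses the same form as the reference layers, sigmoid(beta_f * (A @ w + b)), and that clamp is a (low, high) pair. The helper names and toy values are hypothetical.

```python
import math


def stable_sigmoid(z):
    """Sigmoid that avoids overflow for large negative inputs."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def node_gates(attrs, weights, bias, beta_f=1.0, clamp=None):
    """Return one sigmoid gate per tree node.

    attrs   : list of per-node rows of already-normalized attributes
    weights : one weight per attribute (hypothetical per-filter vector)
    bias    : scalar bias
    clamp   : optional (low, high) bounds on beta_f * logit (assumed form)
    """
    gates = []
    for row in attrs:
        logit = sum(a * w for a, w in zip(row, weights)) + bias
        scaled = beta_f * logit
        if clamp is not None:
            low, high = clamp
            scaled = min(max(scaled, low), high)
        gates.append(stable_sigmoid(scaled))
    return gates


# Three nodes with two normalized attributes each (toy data).
attrs = [[0.05, 0.05], [0.5, 0.5], [1.0, 1.0]]
gates = node_gates(attrs, weights=[2.0, 2.0], bias=-2.0, beta_f=4.0,
                   clamp=(-12.0, 12.0))
```

Nodes with small attribute values receive gates near 0 and are suppressed, while nodes with large values receive gates near 1 and are kept.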

Parameters:
  • scale_mode (str)

  • eps (float)

  • beta_f (float)

  • hybrid_k (float)

  • hybrid_floor_a (float)

  • tos_infinity_seed_row (int)

  • tos_infinity_seed_col (int)

forward(x)

Apply all filter specs and return a tensor shaped (B, C * S, H, W), where S is the number of filter specs.

Parameters:

x (Tensor) – Input tensor shaped (B, C, H, W) or the cached-loader form (x, idx) produced by build_dataloader_cached.

Return type:

Tensor

Returns:

Tensor with one output channel per input channel and filter spec.

predict(x, beta_f=1000.0)

Run inference with a caller-provided sigmoid gain.

The method temporarily sets beta_f to the supplied value and switches the module to evaluation mode, runs forward under torch.no_grad(), then restores the original beta_f and the previous training/eval state.

Return type:

Tensor

Parameters:
  • x (torch.Tensor)

  • beta_f (float)
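To illustrate why a large gain is useful at inference time and how a temporary override can be restored safely, here is a minimal stand-in in plain Python. ToyGate and its attributes are hypothetical and only mimic the documented behavior (override beta_f, leave training mode, restore both); this is not the mtlearn implementation, and torch.no_grad() is omitted.

```python
import math


def stable_sigmoid(z):
    """Sigmoid that avoids overflow for large negative inputs."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class ToyGate:
    """Single-attribute gate with a training flag, for illustration only."""

    def __init__(self, weight, bias, beta_f=1.0):
        self.weight = weight
        self.bias = bias
        self.beta_f = beta_f
        self.training = True

    def forward(self, attr):
        return stable_sigmoid(self.beta_f * (self.weight * attr + self.bias))

    def predict(self, attr, beta_f=1000.0):
        # Remember the current state so it can be restored afterwards.
        old_beta, old_training = self.beta_f, self.training
        self.beta_f, self.training = beta_f, False
        try:
            return self.forward(attr)
        finally:
            # Restore even if forward raises.
            self.beta_f, self.training = old_beta, old_training


gate = ToyGate(weight=1.0, bias=-0.5)
soft = gate.forward(0.6)   # smooth gate used while training
hard = gate.predict(0.6)   # near-binary decision with the large gain
```

With the default gain of 1000 the sigmoid behaves almost like a step function, so the gate approximates a hard keep-or-remove decision for each node.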

inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)

Return cached or freshly computed attributes, valuations, and parameters for each filter spec.

Parameters:
  • img (Tensor) – Image tensor shaped (H, W) or (C, H, W).

  • channel (int) – Channel to inspect when img has multiple channels.

  • idx (int | None) – Optional stable dataset index used to look up cached payloads.

  • build_if_missing (bool) – Build a temporary tree payload when no cache entry exists for idx.

Returns:

Dictionary keyed by filter-spec name. Each entry contains raw and normalized attributes, valuation increments, and current trainable parameters.

freeze_ds_stats()

Stop updating dataset-level normalization statistics.

unfreeze_ds_stats()

Resume updating dataset-level normalization statistics.

refresh_cached_normalization()

Recompute normalized attributes for all cached samples.

save_stats(path)

Save dataset-level normalization statistics.

The payload is a torch-safe dictionary containing a format version, scale_mode, and serialized dataset statistics. Per-sample caches are not saved.

Parameters:

path (str)

load_stats(path, refresh_cache=True)

Load dataset-level normalization statistics.

Parameters:
  • path (str) – File previously written by save_stats.

  • refresh_cache (bool) – Whether to recompute normalized cached attributes immediately after loading.

get_config()

Return the architecture/configuration needed to reconstruct the layer.

The returned dictionary is serializable and accepted by from_config. It describes layer structure, filter specs, valuation choices, normalization mode, sigmoid gain, clamp bounds, and hybrid normalization constants. It does not include trainable weights or dataset statistics.

Return type:

dict[str, Any]

get_weight_contract()

Return the CFP contract that defines parameter names and semantics.

Checkpoints use this contract to reject incompatible layer architectures before loading state into differently shaped CFP parameters.

Return type:

dict[str, Any]

classmethod from_config(config, *, device=None)

Reconstruct a layer from get_config() output.

Return type:

ConnectedFilterPreprocessingLayer

Parameters:

config (Mapping[str, Any])
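Because get_config() excludes trainable weights and dataset statistics, rebuilding a layer and copying its state are two separate steps. The sketch below combines them. clone_layer is a hypothetical helper, and it assumes the layer's state_dict() carries the weights plus the extra state described under get_extra_state, as is usual for a torch.nn.Module.

```python
def clone_layer(layer, device=None):
    """Rebuild a layer from its config, then copy weights and extra state.

    `layer` is expected to expose the documented get_config() and
    from_config() methods plus the standard state_dict() and
    load_state_dict() methods of torch.nn.Module.
    """
    # Step 1: architecture only (fresh parameters, no dataset statistics).
    clone = type(layer).from_config(layer.get_config(), device=device)
    # Step 2: parameters and persistent state travel through state_dict.
    clone.load_state_dict(layer.state_dict())
    return clone
```

Per-sample caches are not part of either step, so a cloned layer would need its caches rebuilt before cached loaders are used with it.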

get_extra_state()

Embed persistent CFP state in PyTorch checkpoints.

This includes the weight contract and dataset normalization statistics. Per-sample tree/attribute caches are intentionally not persisted.

Return type:

dict[str, Any]

set_extra_state(state)

Restore persistent CFP state from state_dict and validate compatibility.

Return type:

None

Parameters:

state (Any)

export_params(path)

Export CFP parameters and metadata for inspection.

This is not the recommended training checkpoint API. Use mtlearn.layers.save_checkpoint for full PyTorch models.

Parameters:

path (str)

save_params(path)

Compatibility alias for export_params.

Parameters:

path (str)

get_params()

Return CPU clones of the per-spec weight and bias tensors.

init_identity_with_bias(p0=0.995)

Initialize filters close to identity using only positive bias.

Weights are set to zero and each bias is chosen so the initial sigmoid value is approximately p0.

Parameters:

p0 (float)

init_identity_bias_zero(p0=0.99)

Initialize filters close to identity with zero bias.

This initialization assumes hybrid-normalized attributes with a positive floor. Each weight receives the same positive value and biases are set to zero.

Parameters:

p0 (float)

build_dataloader_cached(dataloader)

Wrap a DataLoader and precompute CFP caches/statistics.

The returned DataLoader yields ((x, idx), y) batches with stable dataset indices. During the prepass, this layer builds tree payloads and updates dataset-level statistics for every sample/channel/tree key. Statistics are frozen and cached normalizations are refreshed before the wrapped loader is returned.
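To make the ((x, idx), y) form concrete, here is a minimal sketch of an index-carrying dataset wrapper in plain Python. IndexedDataset is a hypothetical stand-in for whatever the layer builds internally; it only shows how a stable dataset index can travel with each sample so that a cache can be looked up later.

```python
class IndexedDataset:
    """Wrap a map-style dataset so every item also carries its index."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        # The index rides along with the input, not with the label.
        return (x, i), y


# Toy dataset of (input, label) pairs; strings stand in for image tensors.
base = [("img0", 0), ("img1", 1), ("img2", 0)]
wrapped = IndexedDataset(base)
(x, idx), y = wrapped[2]
```

Because the index comes from the dataset position rather than the batch position, it stays the same across epochs even when the loader shuffles.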

build_dataloader_cached_fixed_stats(dataloader, *, index_offset=0)

Wrap a DataLoader and precompute CFP caches without updating stats.

Use this for validation/test splits after training statistics have been built with build_dataloader_cached(...) or restored with load_stats(...). The returned DataLoader yields ((x, idx + index_offset), y) so callers can keep split cache keys disjoint.

Parameters:

index_offset (int)
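The split workflow described above might be wired up as follows. This is a sketch that only calls methods documented on this page, on a layer passed in by the caller; the function name, its arguments, and the choice of offset are illustrative assumptions.

```python
def prepare_cached_loaders(layer, train_loader, val_loader, stats_path,
                           val_offset):
    """Build cached loaders for a training split and a validation split."""
    # Training split: builds caches, updates dataset statistics, then
    # freezes them and refreshes the cached normalizations.
    cached_train = layer.build_dataloader_cached(train_loader)
    # Persist the statistics so a later run can restore them with load_stats.
    layer.save_stats(stats_path)
    # Validation split: reuses the frozen statistics. The offset keeps its
    # cache keys disjoint from the training indices.
    cached_val = layer.build_dataloader_cached_fixed_stats(
        val_loader, index_offset=val_offset
    )
    return cached_train, cached_val
```

A natural choice for val_offset is the number of training samples, so that validation indices start where training indices end.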

Checkpoint Helpers

mtlearn.layers.collect_cfp_configs(model)

Return serializable configs for every primary CFP layer in model.

Return type:

dict[str, dict[str, Any]]

Parameters:

model (torch.nn.Module)

mtlearn.layers.save_checkpoint(path, model)

Save a PyTorch checkpoint with automatically collected CFP configs.

The model parameters are saved with the normal PyTorch state_dict. Primary CFP layers are discovered by module name and their constructor configs are saved separately so a caller can reconstruct the model before calling load_state_dict. The payload intentionally contains only model weights and CFP configs that define the meaning and shape of CFP-related weights.

Return type:

dict[str, Any]

Parameters:

model (torch.nn.Module)

mtlearn.layers.load_checkpoint(path, model_or_factory, *, device=None, strict=True, weights_only=True)

Load a checkpoint saved by save_checkpoint.

model_or_factory can be either an already constructed module or a callable that returns a module. Factories may accept no arguments when the model constructor already hard-codes its CFP layers, or one positional cfp_configs argument when they need the saved CFP configs.

Return type:

tuple[Module, dict[str, Any]]

Parameters:
  • model_or_factory (torch.nn.Module | Callable[[...], torch.nn.Module])

  • strict (bool)

  • weights_only (bool)
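The accepted shapes of model_or_factory can be told apart by inspecting the callable's signature. The sketch below shows one way such a dispatch could work. resolve_model is a hypothetical helper, not the library's implementation, and a small Module class stands in for torch.nn.Module so the example runs without PyTorch.

```python
import inspect


class Module:
    """Stand-in for torch.nn.Module (assumption made for this sketch)."""


def resolve_model(model_or_factory, cfp_configs):
    """Return a model from an instance, a zero-arg or a one-arg factory."""
    if isinstance(model_or_factory, Module):
        return model_or_factory  # already constructed
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    required = [
        p
        for p in inspect.signature(model_or_factory).parameters.values()
        if p.kind in positional_kinds and p.default is inspect.Parameter.empty
    ]
    if not required:
        return model_or_factory()  # constructor hard-codes its CFP layers
    return model_or_factory(cfp_configs)  # factory consumes saved configs


class FixedNet(Module):
    """Model whose CFP layers are hard-coded in the constructor."""


class ConfigurableNet(Module):
    """Model that builds its CFP layers from the saved configs."""

    def __init__(self, cfp_configs):
        self.cfp_configs = cfp_configs


saved = {"cfp": {"scale_mode": "hybrid"}}
fixed = resolve_model(FixedNet, saved)
configurable = resolve_model(ConfigurableNet, saved)
```

An explicit isinstance check comes first because module instances are themselves callable, so callable() alone cannot distinguish an instance from a factory.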

Reference Implementations

class mtlearn.layers.ConnectedFilterPreprocessingLayerLegacy(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=False, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)

Bases: Module

Legacy learnable Connected Filter Preprocessing (CFP) layer.

For each attribute group g with K normalized attributes A_g in R[num_nodes, K], the layer computes sigmoid(beta_f * (A_g @ w_g + b_g)) as a node-wise filtering criterion. The criterion is applied to the tree residues and reconstructed to pixels through the implicit-Jacobian autograd function.

This legacy implementation is kept for reproducing experiments that used the former global tree/output contract. It can run tensor operations on CUDA when device="cuda" while still building morphology trees through the CPU backend.

Parameters:
  • in_channels – Number of input channels.

  • attributes_spec – Attribute groups. Each item is one group and must contain at least one morphology attribute enum.

  • tree_type – "max-tree", "min-tree", "tree-of-shapes", or the legacy "tos" alias.

  • device – Torch device used for parameters, cached tensors, and outputs.

  • scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or "none".

  • eps (float) – Numerical floor for normalization denominators.

  • beta_f (float) – Forward sigmoid gain.

  • top_hat (bool) – If true, output the tree-type-specific top-hat residual.

  • clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].

  • hybrid_k (float) – Number of standard deviations used for hybrid clipping.

  • hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized attributes to [a, 1].

  • tos_interpolation – Tree-of-shapes interpolation policy. Accepts "self-dual", "min4c-max8c", "min8c-max4c", or the corresponding morphology.ToSInterpolation enum.

  • tos_infinity_seed_row (int) – Infinity seed used by the tree-of-shapes backend.

  • tos_infinity_seed_col (int) – Infinity seed used by the tree-of-shapes backend.
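The exact hybrid normalization formula is not given on this page. As one plausible reading of the eps, hybrid_k, and hybrid_floor_a descriptions above, the sketch below standardizes a value, clips it to hybrid_k standard deviations, and remaps the result linearly to [a, 1]. Treat every step as an assumption rather than the library's definition; hybrid_normalize and its arguments are hypothetical.

```python
def hybrid_normalize(values, mean, std, hybrid_k=3.0, floor_a=0.05, eps=1e-6):
    """Map raw attribute values to [floor_a, 1] (assumed formula)."""
    out = []
    for v in values:
        # eps acts as a floor on the denominator for constant attributes.
        z = (v - mean) / max(std, eps)
        # Clip outliers to +/- hybrid_k standard deviations.
        z = min(max(z, -hybrid_k), hybrid_k)
        # Remap [-k, k] to [0, 1], then to [floor_a, 1].
        unit = (z + hybrid_k) / (2.0 * hybrid_k)
        out.append(floor_a + (1.0 - floor_a) * unit)
    return out


# An outlier (1000.0) saturates at the upper bound instead of dominating.
normalized = hybrid_normalize([0.0, 10.0, 1000.0], mean=10.0, std=5.0)
```

A strictly positive floor matters for init_identity_bias_zero, which relies on normalized attributes never reaching zero.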

freeze_ds_stats()

Stop collecting dataset statistics for future samples.

unfreeze_ds_stats()

Resume collecting dataset statistics for future samples.

save_stats(path)

Save normalization statistics and scale mode for reproducibility.

Parameters:

path (str)

load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)

Load normalization statistics and optionally refresh cached attributes.

Parameters:
  • path (str)

  • refresh_cache (bool)

  • trusted_legacy_format (bool)

inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)

Return cached or on-the-fly inspection data for one sample.

When idx is provided, the method inspects cached data under f"{idx}_{channel}". Without an index, the tree and attributes are computed on the fly and are not persisted.

Parameters:
  • img (torch.Tensor)

  • channel (int)

  • idx (int | None)

  • build_if_missing (bool)

forward(x)

Apply CFP to a batch and return (B, C * groups, H, W) output.

The input can be a tensor, (x, idx), or [x, idx] from a DataLoader. Indexed inputs use persistent caches keyed by sample index and channel; plain tensor inputs build trees and attributes on demand.

Return type:

Tensor

Parameters:

x (torch.Tensor)

predict(x, beta_f=1000.0)

Run inference with a caller-provided forward sigmoid gain.

Return type:

Tensor

Parameters:
  • x (torch.Tensor)

  • beta_f (float)

save_params(path)

Save all group weights and biases.

Parameters:

path (str)

get_params()

Return CPU clones of group weights and biases.

refresh_cached_normalization()

Recompute normalized attributes for every cached sample.

init_identity_with_bias(p0=0.995)

Initialize near identity by using only a positive bias.

For each group, weights are set to zero and bias is set to logit(p0) / beta_f. This keeps the initial filtering probability near p0 while leaving trainable parameters free to move.

Parameters:

p0 (float)
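A quick numeric check of the formula above: with zero weights the logit reduces to the bias, so bias = logit(p0) / beta_f gives sigmoid(beta_f * bias) = p0 regardless of the attribute values. The function name below is hypothetical.

```python
import math


def identity_bias(p0=0.995, beta_f=1.0):
    """Bias b such that sigmoid(beta_f * b) equals p0 when weights are zero."""
    logit_p0 = math.log(p0 / (1.0 - p0))
    return logit_p0 / beta_f


beta_f = 2.0
bias = identity_bias(0.995, beta_f)
# With zero weights the attributes drop out of the logit entirely.
initial_gate = 1.0 / (1.0 + math.exp(-beta_f * bias))
```

Dividing by beta_f keeps the initial gate at p0 for any forward gain, since the gain multiplies the bias again inside the sigmoid.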

init_identity_bias_zero(p0=0.99)

Initialize near identity with zero bias under hybrid normalization.

This assumes normalized attributes live in [a, 1] where a = hybrid_floor_a. Each group receives constant weights c = logit(p0) / (beta_f * K * a) and zero bias, so the lower bound on the group logits keeps the initial probability near p0.

Parameters:

p0 (float)
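The lower-bound argument above can be checked numerically. With every attribute at the floor a, the logit is K * a * c, and the chosen c makes beta_f times that equal to logit(p0); larger attribute values only raise the gate. The helper names and toy values below are hypothetical.

```python
import math


def identity_weight(p0=0.99, beta_f=1.0, num_attrs=1, floor_a=0.05):
    """Constant weight c = logit(p0) / (beta_f * K * a)."""
    logit_p0 = math.log(p0 / (1.0 - p0))
    return logit_p0 / (beta_f * num_attrs * floor_a)


def gate(attrs, weight, beta_f):
    """Sigmoid gate for one node with a shared weight and zero bias."""
    logit = sum(weight * a for a in attrs)
    return 1.0 / (1.0 + math.exp(-beta_f * logit))


K, a, beta_f, p0 = 3, 0.05, 1.0, 0.99
c = identity_weight(p0, beta_f, K, a)
worst_case = gate([a] * K, c, beta_f)    # every attribute at the floor
best_case = gate([1.0] * K, c, beta_f)   # every attribute at 1
```

The worst case lands on p0 and every other node sits above it, which is why this initialization needs a strictly positive floor.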

build_dataloader_cached(dataloader)

Precompute tree metadata and attributes for a DataLoader.

The returned DataLoader wraps the original dataset and emits ((x, idx), y) batches, where idx is the original dataset index. The layer uses those indexes to reuse cached preprocessing during training.

class mtlearn.layers.ConnectedFilterPreprocessingLayerWithExplicitJacobian(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=False, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)

Bases: Module

Reference CFP layer with an explicit dense Jacobian.

For each attribute group g with K normalized attributes A_g in R[num_nodes, K], the layer computes sigmoid(beta_f * (A_g @ w_g + b_g)) and reconstructs the filtered image with jacobian.T @ (residues * sigmoid).

Use this implementation for debugging and mathematical comparison with the implicit layer, not for memory-sensitive training on larger images.
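The reconstruction jacobian.T @ (residues * sigmoid) can be traced by hand on a tiny example. The sketch below assumes the dense Jacobian is a node-by-pixel membership matrix (1 where a pixel belongs to a node's connected component) and that each residue is the level difference between a node and its parent. Neither definition is spelled out here, so treat both as assumptions; the data is a toy 1-D signal.

```python
def reconstruct(jacobian, residues, gates):
    """Compute jacobian.T @ (residues * gates) using plain lists."""
    weighted = [r * g for r, g in zip(residues, gates)]
    num_pixels = len(jacobian[0])
    return [
        sum(jacobian[n][p] * weighted[n] for n in range(len(weighted)))
        for p in range(num_pixels)
    ]


# Max-tree of the signal [1, 3, 3, 2]:
#   node 0: root at level 1, covers all pixels
#   node 1: level 2, covers pixels 1..3
#   node 2: level 3, covers pixels 1..2 (the peak)
jacobian = [
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 0],
]
residues = [1.0, 1.0, 1.0]  # level minus parent level (root: its own level)

identity = reconstruct(jacobian, residues, [1.0, 1.0, 1.0])
filtered = reconstruct(jacobian, residues, [1.0, 1.0, 0.0])  # drop the peak
```

With all gates at 1 the original signal is recovered, and setting the gate of node 2 to 0 flattens the peak to its parent's level. The matrix has num_nodes * num_pixels entries, which is why the dense form is unsuitable for larger images.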

Parameters:
  • in_channels – Number of input channels.

  • attributes_spec – Attribute groups. Each group must contain at least one morphology attribute enum.

  • tree_type – "max-tree", "min-tree", "tree-of-shapes", or the legacy "tos" alias.

  • device – Torch device used for parameters, cached tensors, and outputs.

  • scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or "none".

  • eps (float) – Numerical floor for normalization denominators.

  • beta_f (float) – Forward sigmoid gain.

  • top_hat (bool) – If true, output the tree-type-specific top-hat residual.

  • clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].

  • hybrid_k (float) – Number of standard deviations used for hybrid clipping.

  • hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized attributes to [a, 1].

  • tos_interpolation – Tree-of-shapes interpolation policy. Accepts "self-dual", "min4c-max8c", "min8c-max4c", or the corresponding morphology.ToSInterpolation enum.

  • tos_infinity_seed_row (int) – Infinity seed used by the tree-of-shapes backend.

  • tos_infinity_seed_col (int) – Infinity seed used by the tree-of-shapes backend.

freeze_ds_stats()

Stop collecting dataset statistics for future samples.

unfreeze_ds_stats()

Resume collecting dataset statistics for future samples.

save_stats(path)

Save normalization statistics and scale mode for reproducibility.

Parameters:

path (str)

load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)

Load normalization statistics and optionally refresh cached attributes.

Parameters:
  • path (str)

  • refresh_cache (bool)

  • trusted_legacy_format (bool)

inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)

Return cached or on-the-fly inspection data for one sample.

When idx is provided, the method inspects cached data under f"{idx}_{channel}". Without an index, the tree, dense Jacobian, and attributes are computed on the fly and are not persisted.

Parameters:
  • img (torch.Tensor)

  • channel (int)

  • idx (int | None)

  • build_if_missing (bool)

forward(x)

Apply CFP to a batch and return (B, C * groups, H, W) output.

Return type:

Tensor

Parameters:

x (torch.Tensor)

predict(x, beta_f=1000.0)

Run inference with a caller-provided forward sigmoid gain.

Return type:

Tensor

Parameters:
  • x (torch.Tensor)

  • beta_f (float)

save_params(path)

Save all group weights and biases.

Parameters:

path (str)

get_params()

Return CPU clones of group weights and biases.

refresh_cached_normalization()

Recompute normalized attributes for every cached sample.

init_identity_with_bias(p0=0.995)

Initialize near identity by using only a positive bias.

For each group, weights are set to zero and bias is set to logit(p0) / beta_f.

Parameters:

p0 (float)

init_identity_bias_zero(p0=0.99)

Initialize near identity with zero bias under hybrid normalization.

This assumes normalized attributes live in [a, 1] where a = hybrid_floor_a. Each group receives constant weights c = logit(p0) / (beta_f * K * a) and zero bias.

Parameters:

p0 (float)

build_dataloader_cached(dataloader)

Precompute trees, dense Jacobians, and attributes for a DataLoader.

The returned DataLoader wraps the original dataset and emits ((x, idx), y) batches, where idx is the original dataset index.

class mtlearn.layers.ConnectedFilterPreprocessingLayerWithCPUTreeTraversal(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=True, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)

Bases: Module

Reference CFP layer that uses C++ CPU tree traversal.

For each attribute group g with K normalized attributes A_g in R[num_nodes, K], the layer computes sigmoid(beta_f * (A_g @ w_g + b_g)) and sends that criterion to the C++ connected-filter backend. The backward pass asks the backend to traverse the tree and compute dW and dB.

This implementation is CPU-bound and mainly serves as a reference for the primary implicit-Jacobian layer.
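The backend's tree traversal is not shown on this page. As a reference for what dW and dB mean, the sketch below applies the chain rule to the criterion sigmoid(beta_f * (A_g @ w_g + b_g)), assuming the gradient of the loss with respect to each node's gate is already available. The function names and toy values are hypothetical, and this is not the C++ backend's algorithm.

```python
import math


def gate_values(attrs, w, b, beta_f):
    """Node-wise criterion sigmoid(beta_f * (A @ w + b))."""
    gates = []
    for row in attrs:
        logit = sum(a * wk for a, wk in zip(row, w)) + b
        gates.append(1.0 / (1.0 + math.exp(-beta_f * logit)))
    return gates


def gate_param_grads(attrs, w, b, beta_f, grad_gates):
    """Return (dW, dB) given dLoss/dGate for every node."""
    d_w = [0.0] * len(w)
    d_b = 0.0
    for row, s, g in zip(attrs, gate_values(attrs, w, b, beta_f), grad_gates):
        # d sigmoid(beta_f * z) / dz = beta_f * s * (1 - s)
        d_logit = g * beta_f * s * (1.0 - s)
        d_b += d_logit
        for k, a in enumerate(row):
            d_w[k] += d_logit * a
    return d_w, d_b


# Three nodes, two attributes each, with an arbitrary upstream gradient.
attrs = [[0.2, 0.7], [0.9, 0.1], [0.5, 0.5]]
w, b, beta_f = [0.3, -0.4], 0.1, 2.0
upstream = [1.0, -2.0, 0.5]
d_w, d_b = gate_param_grads(attrs, w, b, beta_f, upstream)
```

The factor s * (1 - s) vanishes as gates saturate, which is one reason a moderate beta_f is used during training while predict can afford a much larger gain.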

Parameters:
  • in_channels – Number of input channels.

  • attributes_spec – Attribute groups. Each group must contain at least one morphology attribute enum.

  • tree_type – "max-tree", "min-tree", "tree-of-shapes", or the legacy "tos" alias.

  • device – Torch device used for parameters, cached tensors, and outputs.

  • scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or "none".

  • eps (float) – Numerical floor for normalization denominators.

  • beta_f (float) – Forward sigmoid gain.

  • top_hat (bool) – If true, output the tree-type-specific top-hat residual.

  • clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].

  • hybrid_k (float) – Number of standard deviations used for hybrid clipping.

  • hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized attributes to [a, 1].

  • tos_interpolation – Tree-of-shapes interpolation policy. Accepts "self-dual", "min4c-max8c", "min8c-max4c", or the corresponding morphology.ToSInterpolation enum.

  • tos_infinity_seed_row (int) – Infinity seed used by the tree-of-shapes backend.

  • tos_infinity_seed_col (int) – Infinity seed used by the tree-of-shapes backend.

freeze_ds_stats()

Stop collecting dataset statistics for future samples.

unfreeze_ds_stats()

Resume collecting dataset statistics for future samples.

save_stats(path)

Save normalization statistics and scale mode for reproducibility.

Parameters:

path (str)

load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)

Load normalization statistics and optionally refresh cached attributes.

Parameters:
  • path (str)

  • refresh_cache (bool)

  • trusted_legacy_format (bool)

inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)

Return an inspection package for one image.

The result contains the tree, raw and normalized attributes grouped by attribute group name, and references to the current group weights and biases.

Parameters:
  • img – Tensor (H, W) or (C, H, W), or (img, idx) tuple.

  • channel (int) – Channel index used when img has more than one channel.

  • idx (int | None) – Optional sample index, also accepted through (img, idx).

  • build_if_missing (bool) – Build and cache preprocessing if it is missing.

Raises:

KeyError – If build_if_missing is false and the key is absent.

Weights and biases are returned as parameter references; clone them if immutable snapshots are needed.

forward(x)

Apply CFP to a batch and return (B, C * groups, H, W) output.

Return type:

Tensor

Parameters:

x (torch.Tensor)

predict(x, beta_f=1.0)

Run inference with a caller-provided forward sigmoid gain.

Return type:

Tensor

Parameters:
  • x (torch.Tensor)

  • beta_f (float)

save_params(path)

Save all group weights and biases.

Parameters:

path (str)

get_params()

Return CPU clones of group weights and biases.

refresh_cached_normalization()

Recompute normalized attributes for every cached sample.

init_identity_with_bias(p0=0.995)

Initialize near identity by using only a positive bias.

For each group, weights are set to zero and bias is set to logit(p0) / beta_f.

Parameters:

p0 (float)

init_identity_bias_zero(p0=0.99)

Initialize near identity with zero bias under hybrid normalization.

This assumes normalized attributes live in [a, 1] where a = hybrid_floor_a. Each group receives constant weights c = logit(p0) / (beta_f * K * a) and zero bias.

Parameters:

p0 (float)

build_dataloader_cached(dataloader)

Precompute trees and attributes for a DataLoader.

The returned DataLoader wraps the original dataset and emits ((x, idx), y) batches. This method builds per-sample trees and attributes, updates dataset normalization statistics, refreshes cached normalized attributes, and freezes statistics when preprocessing ends.