Layers
The primary trainable preprocessing layer is
ConnectedFilterPreprocessingLayer. The reference and legacy implementations
are documented for compatibility only; new experiments should use the
primary layer.
Primary CFP Layer
-
class mtlearn.layers.ConnectedFilterPreprocessingLayer(in_channels, filter_specs, *, device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, clamp=None, hybrid_k=3.0, hybrid_floor_a=0.05, attribute_dtype=None, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)[source]
Bases: Module
Learnable CFP layer defined by per-output filter specifications.
Each item in filter_specs defines one output per input channel by naming
a morphology tree, the attributes used for scoring, and the valuation to
reconstruct. The default valuation is CFPValuation.ALTITUDE; top-hat output
is selected with CFPValuation.ALTITUDE_TOPHAT. When set, clamp bounds
beta_f * logits before the sigmoid.
Tree construction and attribute computation happen outside autograd. The
trainable parameters are the per-filter weight vectors and scalar biases
that produce node-wise sigmoid gates from normalized attributes.
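The gate computation described above can be sketched in plain Python. This is a minimal illustration, not the layer's implementation: the function name node_gates and the (low, high) tuple form of clamp are assumptions made for the example, and lists stand in for tensors.

```python
import math


def stable_sigmoid(z):
    """Sigmoid that avoids overflow for large |z| (e.g. beta_f=1000)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def node_gates(attrs, weights, bias, beta_f=1.0, clamp=None):
    """Return one sigmoid gate per tree node.

    attrs:   list of per-node rows of K normalized attributes.
    weights: list of K floats (one trainable weight vector).
    bias:    trainable scalar bias.
    clamp:   optional (low, high) bounds on beta_f * logit.
             The tuple form is an assumption for this sketch.
    """
    gates = []
    for row in attrs:
        logit = sum(a * w for a, w in zip(row, weights)) + bias
        z = beta_f * logit
        if clamp is not None:
            low, high = clamp
            z = min(max(z, low), high)
        gates.append(stable_sigmoid(z))
    return gates


# Two nodes, K = 2 attributes each.
gates = node_gates([[0.0, 0.0], [1.0, 1.0]], [2.0, 3.0], -1.0)
sharp = node_gates([[1.0, 1.0]], [2.0, 3.0], -1.0, beta_f=1000.0, clamp=(-12.0, 12.0))
```

With a large beta_f the gate approaches a hard 0/1 decision, and the clamp keeps the sigmoid argument in a range where its gradient does not vanish completely.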
-
forward(x)[source]
Apply all filter specs and return a tensor shaped (B, C * S, H, W), where S
is the number of filter specs.
- Parameters:
x (Tensor) – Input tensor shaped (B, C, H, W) or the cached-loader form
(x, idx) produced by build_dataloader_cached.
- Return type:
Tensor
- Returns:
Tensor with one output channel per input channel and filter spec.
-
predict(x, beta_f=1000.0)[source]
Run inference with a caller-provided sigmoid gain.
The method temporarily switches the module to evaluation mode, runs
forward under torch.no_grad(), restores beta_f, and restores
the previous training/eval state.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
beta_f (float)
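The save-and-restore behavior of predict can be sketched as a context manager. ToyLayer and temporary_inference are hypothetical names used only for illustration; the real method additionally runs forward under torch.no_grad(), which is omitted here so the sketch needs only the standard library.

```python
from contextlib import contextmanager


class ToyLayer:
    """Stand-in with the two pieces of state that predict touches."""

    def __init__(self, beta_f=1.0):
        self.beta_f = beta_f
        self.training = True


@contextmanager
def temporary_inference(layer, beta_f):
    """Switch to eval mode with a temporary gain, then restore both."""
    prev_beta_f, prev_training = layer.beta_f, layer.training
    layer.beta_f = beta_f
    layer.training = False
    try:
        yield layer
    finally:
        # Restore even if the forward pass raised.
        layer.beta_f = prev_beta_f
        layer.training = prev_training


layer = ToyLayer(beta_f=1.0)
with temporary_inference(layer, beta_f=1000.0) as active:
    seen = (active.beta_f, active.training)
```

The try/finally block is the design point: the caller's training state and gain are restored even when inference fails partway through.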
-
inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)[source]
Return cached or direct attributes, valuations, and parameters per spec.
- Parameters:
img (Tensor) – Image tensor shaped (H, W) or (C, H, W).
channel (int) – Channel to inspect when img has multiple channels.
idx (int | None) – Optional stable dataset index used to look up cached payloads.
build_if_missing (bool) – Build a temporary tree payload when no cache entry
exists for idx.
- Returns:
Dictionary keyed by filter-spec name. Each entry contains raw and
normalized attributes, valuation increments, and current trainable
parameters.
-
freeze_ds_stats()[source]
Stop updating dataset-level normalization statistics.
-
unfreeze_ds_stats()[source]
Resume updating dataset-level normalization statistics.
-
refresh_cached_normalization()[source]
Recompute normalized attributes for all cached samples.
-
save_stats(path)[source]
Save dataset-level normalization statistics.
The payload is a torch-safe dictionary containing a format version,
scale_mode, and serialized dataset statistics. Per-sample caches are
not saved.
- Parameters:
path (str)
-
load_stats(path, refresh_cache=True)[source]
Load dataset-level normalization statistics.
- Parameters:
path (str)
refresh_cache (bool) – Recompute cached normalized attributes after loading.
-
get_config()[source]
Return the architecture/configuration needed to reconstruct the layer.
The returned dictionary is serializable and accepted by
from_config. It describes layer structure, filter specs, valuation
choices, normalization mode, sigmoid gain, clamp bounds, and hybrid
normalization constants. It does not include trainable weights or
dataset statistics.
- Return type:
dict[str, Any]
-
get_weight_contract()[source]
Return the CFP contract that defines parameter names and semantics.
Checkpoints use this contract to reject incompatible layer
architectures before loading state into differently shaped CFP
parameters.
- Return type:
dict[str, Any]
-
classmethod from_config(config, *, device=None)[source]
Reconstruct a layer from get_config() output.
- Return type:
ConnectedFilterPreprocessingLayer
- Parameters:
config (Mapping[str, Any])
-
get_extra_state()
Embed persistent CFP state in PyTorch checkpoints.
This includes the weight contract and dataset normalization statistics.
Per-sample tree/attribute caches are intentionally not persisted.
- Return type:
dict[str, Any]
-
set_extra_state(state)
Restore persistent CFP state from state_dict and validate compatibility.
- Return type:
None
- Parameters:
state (Any)
-
export_params(path)[source]
Export CFP parameters and metadata for inspection.
This is not the recommended training checkpoint API. Use
mtlearn.layers.save_checkpoint for full PyTorch models.
- Parameters:
path (str)
-
save_params(path)[source]
Compatibility alias for export_params.
- Parameters:
path (str)
-
get_params()[source]
Return CPU clones of the per-spec weight and bias tensors.
-
init_identity_with_bias(p0=0.995)
Initialize filters close to identity using only positive bias.
Weights are set to zero and each bias is chosen so the initial sigmoid
value is approximately p0.
- Parameters:
p0 (float)
-
init_identity_bias_zero(p0=0.99)
Initialize filters close to identity with zero bias.
This initialization assumes hybrid-normalized attributes with a
positive floor. Each weight receives the same positive value and biases
are set to zero.
- Parameters:
p0 (float)
-
build_dataloader_cached(dataloader)[source]
Wrap a DataLoader and precompute CFP caches/statistics.
The returned DataLoader yields ((x, idx), y) batches with stable
dataset indices. During the prepass, this layer builds tree payloads and
updates dataset-level statistics for every sample/channel/tree key.
Statistics are frozen and cached normalizations are refreshed before
the wrapped loader is returned.
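The ((x, idx), y) batch shape comes from wrapping the dataset so that every item carries its own index. A minimal sketch of such a wrapper follows; IndexedDataset is a hypothetical name, and the class only implements the __len__/__getitem__ protocol that a map-style dataset needs. It does not build trees or statistics.

```python
class IndexedDataset:
    """Wrap a map-style dataset so items become ((x, idx), y)."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        # The index travels with the sample, so shuffling does not
        # change which cache entry a sample maps to.
        return (x, i), y


# A toy dataset of (image, label) pairs; strings stand in for tensors.
base = [("img_a", 0), ("img_b", 1), ("img_c", 0)]
wrapped = IndexedDataset(base)
(sample, idx), label = wrapped[2]
```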
-
build_dataloader_cached_fixed_stats(dataloader, *, index_offset=0)[source]
Wrap a DataLoader and precompute CFP caches without updating stats.
Use this for validation/test splits after training statistics have been
built with build_dataloader_cached(...) or restored with
load_stats(...). The returned DataLoader yields
((x, idx + index_offset), y) so callers can keep split cache keys
disjoint.
- Parameters:
index_offset (int)
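The role of index_offset can be seen by comparing the cache keys two splits would produce. The helper below is purely illustrative and models cache keys as plain integers.

```python
def split_cache_keys(n_samples, index_offset=0):
    """Return the set of cache keys a split of n_samples would use."""
    return {i + index_offset for i in range(n_samples)}


train_keys = split_cache_keys(5)
# Without an offset, validation indices 0..2 collide with training.
val_keys_colliding = split_cache_keys(3)
# Offsetting by the training set size keeps the key ranges disjoint.
val_keys = split_cache_keys(3, index_offset=5)
```

Using the size of the training split as the offset is one simple choice; any offset that keeps the ranges from overlapping works.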
Checkpoint Helpers
-
mtlearn.layers.collect_cfp_configs(model)[source]
Return serializable configs for every primary CFP layer in model.
- Return type:
dict[str, dict[str, Any]]
- Parameters:
model (torch.nn.Module)
-
mtlearn.layers.save_checkpoint(path, model)[source]
Save a PyTorch checkpoint with automatically collected CFP configs.
The model parameters are saved with the normal PyTorch state_dict.
Primary CFP layers are discovered by module name and their constructor
configs are saved separately so a caller can reconstruct the model before
calling load_state_dict. The payload intentionally contains only model
weights and CFP configs that define the meaning and shape of CFP-related
weights.
- Return type:
dict[str, Any]
- Parameters:
model (torch.nn.Module)
-
mtlearn.layers.load_checkpoint(path, model_or_factory, *, device=None, strict=True, weights_only=True)[source]
Load a checkpoint saved by save_checkpoint.
model_or_factory can be either an already constructed module or a
callable that returns a module. Factories may accept no arguments when the
model constructor already hard-codes its CFP layers, or one positional
cfp_configs argument when they need the saved CFP configs.
- Return type:
tuple[Module, dict[str, Any]]
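One way to support both factory styles is to count the factory's required positional parameters. The sketch below assumes nothing about how load_checkpoint does this internally; resolve_model and ToyModel are hypothetical names, and module_type stands in for torch.nn.Module.

```python
import inspect


class ToyModel:
    """Stand-in for a torch.nn.Module subclass."""

    def __init__(self, cfp_configs=None):
        self.cfp_configs = cfp_configs


def resolve_model(model_or_factory, cfp_configs, module_type):
    """Return a model from an instance or a 0/1-argument factory."""
    # Check the type first: real modules are themselves callable,
    # so callable() alone cannot tell a model from a factory.
    if isinstance(model_or_factory, module_type):
        return model_or_factory
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    required = [
        p
        for p in inspect.signature(model_or_factory).parameters.values()
        if p.kind in positional and p.default is inspect.Parameter.empty
    ]
    if len(required) == 0:
        return model_or_factory()
    if len(required) == 1:
        return model_or_factory(cfp_configs)
    raise TypeError("factory must take zero or one positional argument")


configs = {"cfp": {"scale_mode": "hybrid"}}
existing = ToyModel()
from_instance = resolve_model(existing, configs, ToyModel)
from_plain = resolve_model(lambda: ToyModel(), configs, ToyModel)
from_configs = resolve_model(lambda cfgs: ToyModel(cfgs), configs, ToyModel)
```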
Reference Implementations
-
class mtlearn.layers.ConnectedFilterPreprocessingLayerLegacy(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=False, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)[source]
Bases: Module
Legacy learnable Connected Filter Preprocessing (CFP) layer.
For each attribute group g with K normalized attributes
A_g in R[num_nodes, K], the layer computes
sigmoid(beta_f * (A_g @ w_g + b_g)) as a node-wise filtering criterion.
The criterion is applied to the tree residues and reconstructed to pixels
through the implicit-Jacobian autograd function.
This legacy implementation is kept for reproducing experiments that used
the former global tree/output contract. It can run tensor operations on
CUDA when device="cuda" while still building morphology trees through
the CPU backend.
- Parameters:
in_channels – Number of input channels.
attributes_spec – Attribute groups. Each item is one group and must
contain at least one morphology attribute enum.
tree_type – "max-tree", "min-tree", "tree-of-shapes", or
the legacy "tos" alias.
device – Torch device used for parameters, cached tensors, and outputs.
scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or
"none".
eps (float) – Numerical floor for normalization denominators.
beta_f (float) – Forward sigmoid gain.
top_hat (bool) – If true, output the tree-type-specific top-hat residual.
clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].
hybrid_k (float) – Number of standard deviations used for hybrid clipping.
hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized
attributes to [a, 1].
tos_interpolation – Tree-of-shapes interpolation policy. Accepts
"self-dual", "min4c-max8c", "min8c-max4c", or the
corresponding morphology.ToSInterpolation enum.
tos_infinity_seed_row (int) – Row index of the infinity seed used by
the tree-of-shapes backend.
tos_infinity_seed_col (int) – Column index of the infinity seed used by
the tree-of-shapes backend.
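The parameter list does not spell out the formula behind the "hybrid" scale mode. One plausible reading of hybrid_k, hybrid_floor_a, and eps is a z-score that is clipped to [-k, k] and then remapped linearly to [a, 1]. The sketch below implements that reading only; the function name and the exact formula are assumptions, not the library's definition.

```python
def hybrid_normalize(values, mean, std, hybrid_k=3.0, hybrid_floor_a=0.05, eps=1e-6):
    """Assumed hybrid scaling: z-score, clip to [-k, k], remap to [a, 1]."""
    out = []
    for v in values:
        z = (v - mean) / (std + eps)  # eps guards a zero std
        z = max(-hybrid_k, min(hybrid_k, z))  # clip outliers at k sigmas
        unit = (z + hybrid_k) / (2.0 * hybrid_k)  # now in [0, 1]
        out.append(hybrid_floor_a + (1.0 - hybrid_floor_a) * unit)
    return out


# mean value, a low outlier, and a high outlier
scaled = hybrid_normalize([10.0, -1000.0, 1000.0], mean=10.0, std=2.0)
```

Under this reading, a positive floor a keeps every normalized attribute strictly positive, which is the property init_identity_bias_zero relies on.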
-
freeze_ds_stats()[source]
Stop collecting dataset statistics for future samples.
-
unfreeze_ds_stats()[source]
Resume collecting dataset statistics for future samples.
-
save_stats(path)[source]
Save normalization statistics and scale mode for reproducibility.
- Parameters:
path (str)
-
load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)[source]
Load normalization statistics and optionally refresh cached attributes.
- Parameters:
path (str)
refresh_cache (bool)
trusted_legacy_format (bool)
-
inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)[source]
Return cached or on-the-fly inspection data for one sample.
When idx is provided, the method inspects cached data under
f"{idx}_{channel}". Without an index, the tree and attributes are
computed on the fly and are not persisted.
- Parameters:
img (torch.Tensor)
channel (int)
idx (int | None)
build_if_missing (bool)
-
forward(x)[source]
Apply CFP to a batch and return (B, C * groups, H, W) output.
The input can be a tensor, (x, idx), or [x, idx] from a
DataLoader. Indexed inputs use persistent caches keyed by sample index
and channel; plain tensor inputs build trees and attributes on demand.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
-
predict(x, beta_f=1000.0)[source]
Run inference with a caller-provided forward sigmoid gain.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
beta_f (float)
-
save_params(path)[source]
Save all group weights and biases.
- Parameters:
path (str)
-
get_params()[source]
Return CPU clones of group weights and biases.
-
refresh_cached_normalization()[source]
Recompute normalized attributes for every cached sample.
-
init_identity_with_bias(p0=0.995)
Initialize near identity by using only a positive bias.
For each group, weights are set to zero and bias is set to
logit(p0) / beta_f. This keeps the initial filtering probability
near p0 while leaving trainable parameters free to move.
- Parameters:
p0 (float)
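The bias formula can be checked numerically: with zero weights the logit equals the bias, so sigmoid(beta_f * bias) should return p0 for any gain. The helper names below are illustrative only.

```python
import math


def logit(p):
    """Inverse of the sigmoid on (0, 1)."""
    return math.log(p / (1.0 - p))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def identity_bias(p0=0.995, beta_f=1.0):
    """Bias that yields an initial gate of p0 when weights are zero."""
    if not 0.0 < p0 < 1.0:
        raise ValueError("p0 must lie strictly between 0 and 1")
    return logit(p0) / beta_f


# The gain cancels: the initial gate is p0 for every beta_f.
initial_gates = [sigmoid(b * identity_bias(0.995, b)) for b in (1.0, 4.0, 10.0)]
```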
-
init_identity_bias_zero(p0=0.99)
Initialize near identity with zero bias under hybrid normalization.
This assumes normalized attributes live in [a, 1] where
a = hybrid_floor_a. Each group receives constant weights
c = logit(p0) / (beta_f * K * a) and zero bias, so the lower bound
on the group logits keeps the initial probability near p0.
- Parameters:
p0 (float)
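The lower bound follows from a short calculation: with all K weights equal to c and every attribute at least a, the logit is at least beta_f * K * a * c = logit(p0). The sketch below checks this numerically; the function names are illustrative.

```python
import math


def logit(p):
    return math.log(p / (1.0 - p))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def identity_weight(p0, beta_f, num_attrs, floor_a):
    """Constant weight c = logit(p0) / (beta_f * K * a)."""
    return logit(p0) / (beta_f * num_attrs * floor_a)


def gate(attrs, c, beta_f):
    """Gate for one node with equal weights c and zero bias."""
    return sigmoid(beta_f * c * sum(attrs))


p0, beta_f, K, a = 0.99, 1.0, 3, 0.05
c = identity_weight(p0, beta_f, K, a)
worst_case = gate([a] * K, c, beta_f)  # every attribute at the floor
typical = gate([0.5, a, 1.0], c, beta_f)  # any mix inside [a, 1]
```

The worst case, where every attribute sits at the floor, gives exactly p0; any larger attribute value only pushes the gate closer to 1.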
-
build_dataloader_cached(dataloader)[source]
Precompute tree metadata and attributes for a DataLoader.
The returned DataLoader wraps the original dataset and emits
((x, idx), y) batches, where idx is the original dataset index.
The layer uses those indexes to reuse cached preprocessing during
training.
-
class mtlearn.layers.ConnectedFilterPreprocessingLayerWithExplicitJacobian(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=False, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)[source]
Bases: Module
Reference CFP layer with an explicit dense Jacobian.
For each attribute group g with K normalized attributes
A_g in R[num_nodes, K], the layer computes
sigmoid(beta_f * (A_g @ w_g + b_g)) and reconstructs the filtered image
with jacobian.T @ (residues * sigmoid).
Use this implementation for debugging and mathematical comparison with the
implicit layer, not for memory-sensitive training on larger images.
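The reconstruction jacobian.T @ (residues * sigmoid) can be illustrated on a tiny hand-built max-tree. The sketch assumes the Jacobian is a dense 0/1 matrix of shape (num_nodes, num_pixels) whose entry is 1 when the pixel lies inside the node's connected component; that layout is an assumption for illustration, and nested lists stand in for tensors.

```python
def reconstruct(jacobian, residues, gates):
    """Compute jacobian.T @ (residues * gates) with plain lists."""
    gated = [r * g for r, g in zip(residues, gates)]
    num_pixels = len(jacobian[0])
    return [
        sum(jacobian[n][p] * gated[n] for n in range(len(jacobian)))
        for p in range(num_pixels)
    ]


# 1-D image [1, 3, 3, 2] as a max-tree with three nested nodes:
#   node 0 (level 1): pixels 0..3, residue 1
#   node 1 (level 2): pixels 1..3, residue 2 - 1 = 1
#   node 2 (level 3): pixels 1..2, residue 3 - 2 = 1
jacobian = [
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 0],
]
residues = [1.0, 1.0, 1.0]

identity = reconstruct(jacobian, residues, [1.0, 1.0, 1.0])
peak_removed = reconstruct(jacobian, residues, [1.0, 1.0, 0.0])
```

The dense matrix grows as num_nodes * num_pixels, which is why this form suits debugging on small images rather than memory-sensitive training.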
- Parameters:
in_channels – Number of input channels.
attributes_spec – Attribute groups. Each group must contain at least one
morphology attribute enum.
tree_type – "max-tree", "min-tree", "tree-of-shapes", or
the legacy "tos" alias.
device – Torch device used for parameters, cached tensors, and outputs.
scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or
"none".
eps (float) – Numerical floor for normalization denominators.
beta_f (float) – Forward sigmoid gain.
top_hat (bool) – If true, output the tree-type-specific top-hat residual.
clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].
hybrid_k (float) – Number of standard deviations used for hybrid clipping.
hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized
attributes to [a, 1].
tos_interpolation – Tree-of-shapes interpolation policy. Accepts
"self-dual", "min4c-max8c", "min8c-max4c", or the
corresponding morphology.ToSInterpolation enum.
tos_infinity_seed_row (int) – Row index of the infinity seed used by
the tree-of-shapes backend.
tos_infinity_seed_col (int) – Column index of the infinity seed used by
the tree-of-shapes backend.
-
freeze_ds_stats()[source]
Stop collecting dataset statistics for future samples.
-
unfreeze_ds_stats()[source]
Resume collecting dataset statistics for future samples.
-
save_stats(path)[source]
Save normalization statistics and scale mode for reproducibility.
- Parameters:
path (str)
-
load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)[source]
Load normalization statistics and optionally refresh cached attributes.
- Parameters:
path (str)
refresh_cache (bool)
trusted_legacy_format (bool)
-
inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)[source]
Return cached or on-the-fly inspection data for one sample.
When idx is provided, the method inspects cached data under
f"{idx}_{channel}". Without an index, the tree, dense Jacobian, and
attributes are computed on the fly and are not persisted.
- Parameters:
img (torch.Tensor)
channel (int)
idx (int | None)
build_if_missing (bool)
-
forward(x)[source]
Apply CFP to a batch and return (B, C * groups, H, W) output.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
-
predict(x, beta_f=1000.0)[source]
Run inference with a caller-provided forward sigmoid gain.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
beta_f (float)
-
save_params(path)[source]
Save all group weights and biases.
- Parameters:
path (str)
-
get_params()[source]
Return CPU clones of group weights and biases.
-
refresh_cached_normalization()[source]
Recompute normalized attributes for every cached sample.
-
init_identity_with_bias(p0=0.995)
Initialize near identity by using only a positive bias.
For each group, weights are set to zero and bias is set to
logit(p0) / beta_f.
- Parameters:
p0 (float)
-
init_identity_bias_zero(p0=0.99)
Initialize near identity with zero bias under hybrid normalization.
This assumes normalized attributes live in [a, 1] where
a = hybrid_floor_a. Each group receives constant weights
c = logit(p0) / (beta_f * K * a) and zero bias.
- Parameters:
p0 (float)
-
build_dataloader_cached(dataloader)[source]
Precompute trees, dense Jacobians, and attributes for a DataLoader.
The returned DataLoader wraps the original dataset and emits
((x, idx), y) batches, where idx is the original dataset index.
-
class mtlearn.layers.ConnectedFilterPreprocessingLayerWithCPUTreeTraversal(in_channels, attributes_spec, tree_type='max-tree', device='cpu', scale_mode='hybrid', eps=1e-06, beta_f=1.0, top_hat=False, clamp_logits=True, hybrid_k=3.0, hybrid_floor_a=0.05, tos_interpolation=None, tos_infinity_seed_row=0, tos_infinity_seed_col=0)[source]
Bases: Module
Reference CFP layer that uses C++ CPU tree traversal.
For each attribute group g with K normalized attributes
A_g in R[num_nodes, K], the layer computes
sigmoid(beta_f * (A_g @ w_g + b_g)) and sends that criterion to the C++
connected-filter backend. The backward pass asks the backend to traverse
the tree and compute dW and dB.
This implementation is CPU-bound and mainly serves as a reference for the
primary implicit-Jacobian layer.
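The quantities dW and dB follow from the chain rule through the sigmoid criterion. The sketch below is not the C++ traversal: it assumes the upstream gradient with respect to each node's gate is already available and shows only the last step, using the identity d sigmoid(z)/dz = s * (1 - s). All names are illustrative.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gates(attrs, w, b, beta_f):
    """Node-wise criterion sigmoid(beta_f * (A @ w + b))."""
    return [
        sigmoid(beta_f * (sum(a * wk for a, wk in zip(row, w)) + b))
        for row in attrs
    ]


def gate_param_grads(attrs, w, b, beta_f, grad_gates):
    """Return (dW, dB) given dLoss/dgate for every node."""
    dW = [0.0] * len(w)
    dB = 0.0
    for row, s, g in zip(attrs, gates(attrs, w, b, beta_f), grad_gates):
        local = g * beta_f * s * (1.0 - s)  # dLoss/dlogit for this node
        for k, a in enumerate(row):
            dW[k] += local * a
        dB += local
    return dW, dB


attrs = [[0.2, 0.9], [0.7, 0.1], [0.5, 0.5]]
w, b, beta_f = [0.3, -0.4], 0.1, 2.0
upstream = [1.0, -2.0, 0.5]  # assumed dLoss/dgate per node
dW, dB = gate_param_grads(attrs, w, b, beta_f, upstream)
```

A finite-difference comparison against the scalar loss sum(upstream[n] * gate[n]) is a quick way to validate the analytic gradients.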
- Parameters:
in_channels – Number of input channels.
attributes_spec – Attribute groups. Each group must contain at least one
morphology attribute enum.
tree_type – "max-tree", "min-tree", "tree-of-shapes", or
the legacy "tos" alias.
device – Torch device used for parameters, cached tensors, and outputs.
scale_mode (str) – "minmax01", "zscore_tree", "hybrid", or
"none".
eps (float) – Numerical floor for normalization denominators.
beta_f (float) – Forward sigmoid gain.
top_hat (bool) – If true, output the tree-type-specific top-hat residual.
clamp_logits (bool) – If true, clamp beta_f * logits to [-12, 12].
hybrid_k (float) – Number of standard deviations used for hybrid clipping.
hybrid_floor_a (float) – Lower bound used when remapping hybrid-normalized
attributes to [a, 1].
tos_interpolation – Tree-of-shapes interpolation policy. Accepts
"self-dual", "min4c-max8c", "min8c-max4c", or the
corresponding morphology.ToSInterpolation enum.
tos_infinity_seed_row (int) – Row index of the infinity seed used by
the tree-of-shapes backend.
tos_infinity_seed_col (int) – Column index of the infinity seed used by
the tree-of-shapes backend.
-
freeze_ds_stats()[source]
Stop collecting dataset statistics for future samples.
-
unfreeze_ds_stats()[source]
Resume collecting dataset statistics for future samples.
-
save_stats(path)[source]
Save normalization statistics and scale mode for reproducibility.
- Parameters:
path (str)
-
load_stats(path, refresh_cache=True, *, trusted_legacy_format=False)[source]
Load normalization statistics and optionally refresh cached attributes.
- Parameters:
path (str)
refresh_cache (bool)
trusted_legacy_format (bool)
-
inspect_training_sample(img, channel=0, idx=None, build_if_missing=True)[source]
Return an inspection package for one image.
The result contains the tree, raw and normalized attributes grouped by
attribute group name, and references to the current group weights and
biases.
- Parameters:
img – Tensor (H, W) or (C, H, W), or (img, idx) tuple.
channel (int) – Channel index used when img has more than one channel.
idx (int | None) – Optional sample index, also accepted through (img, idx).
build_if_missing (bool) – Build and cache preprocessing if it is missing.
- Raises:
KeyError – If build_if_missing is false and the key is absent.
Weights and biases are returned as parameter references; clone them if
immutable snapshots are needed.
-
forward(x)[source]
Apply CFP to a batch and return (B, C * groups, H, W) output.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
-
predict(x, beta_f=1.0)[source]
Run inference with a caller-provided forward sigmoid gain.
- Return type:
Tensor
- Parameters:
x (torch.Tensor)
beta_f (float)
-
save_params(path)[source]
Save all group weights and biases.
- Parameters:
path (str)
-
get_params()[source]
Return CPU clones of group weights and biases.
-
refresh_cached_normalization()[source]
Recompute normalized attributes for every cached sample.
-
init_identity_with_bias(p0=0.995)
Initialize near identity by using only a positive bias.
For each group, weights are set to zero and bias is set to
logit(p0) / beta_f.
- Parameters:
p0 (float)
-
init_identity_bias_zero(p0=0.99)
Initialize near identity with zero bias under hybrid normalization.
This assumes normalized attributes live in [a, 1] where
a = hybrid_floor_a. Each group receives constant weights
c = logit(p0) / (beta_f * K * a) and zero bias.
- Parameters:
p0 (float)
-
build_dataloader_cached(dataloader)[source]
Precompute trees and attributes for a DataLoader.
The returned DataLoader wraps the original dataset and emits
((x, idx), y) batches. This method builds per-sample trees and
attributes, updates dataset normalization statistics, refreshes cached
normalized attributes, and freezes statistics when preprocessing ends.