# Connected Filter Preprocessing
`ConnectedFilterPreprocessingLayer` turns morphology-tree attribute filtering
into a trainable PyTorch module. It builds a tree for each sample/channel,
computes node attributes outside autograd, learns node-wise sigmoid gates, and
reconstructs one output image per input channel and filter spec.
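To make the gating idea concrete, here is a minimal sketch in plain PyTorch of one way node-wise sigmoid gates can drive a differentiable reconstruction on a precomputed max-tree. It is not mtlearn's implementation. The function `soft_reconstruct`, the parent-array tree layout, and the rule that each node adds its gated level jump on top of its parent's value are all assumptions made for illustration.

```python
import torch

def soft_reconstruct(level, parent, node_of_pixel, attrs, weight, bias):
    """Hypothetical soft connected filter on a precomputed max-tree.

    Assumes node 0 is the root and parent[i] < i for every other node,
    so parents are always processed before their children.
    """
    # One gate per node: sigmoid of a learned affine score of its attributes.
    gates = torch.sigmoid(attrs @ weight + bias)
    values = [level[0]]  # the root keeps its own gray level
    for i in range(1, level.shape[0]):
        p = int(parent[i])
        # Each node adds its gated level jump above its parent's value.
        values.append(values[p] + gates[i] * (level[i] - level[p]))
    node_values = torch.stack(values)
    # Every pixel takes the value of the tree node it belongs to.
    return node_values[node_of_pixel]

# 1D "image" [0, 2, 2, 5]: root (level 0) -> node 1 (level 2) -> node 2 (level 5).
level = torch.tensor([0.0, 2.0, 5.0])
parent = torch.tensor([0, 0, 1])
node_of_pixel = torch.tensor([0, 1, 1, 2])
attrs = torch.tensor([[4.0], [3.0], [1.0]])  # e.g. the area of each node
weight = torch.zeros(1, requires_grad=True)
bias = torch.tensor(50.0, requires_grad=True)  # gates saturate near 1
y = soft_reconstruct(level, parent, node_of_pixel, attrs, weight, bias)
```

With all gates near 1 this sketch reproduces the input, and with all gates near 0 every pixel collapses to the root level. Gradients reach `weight` and `bias`, while the tree and its attributes are treated as fixed inputs.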
## Minimal Layer
Each filter spec defines a tree type and the attributes used to score its nodes. The examples below use the default reconstructed signal.
```python
import torch
from mtlearn import morphology
from mtlearn.layers import ConnectedFilterPreprocessingLayer

layer = ConnectedFilterPreprocessingLayer(
    in_channels=1,
    filter_specs=[
        {
            "name": "area_opening",
            "tree_type": morphology.TreeType.MAX_TREE,
            "attributes": [
                morphology.AttributeType.AREA,
                morphology.AttributeType.GRAY_HEIGHT,
            ],
        },
    ],
    scale_mode="minmax01",
)

x = torch.rand(4, 1, 32, 32)
y = layer(x)
assert y.shape == (4, 1, 32, 32)
```
With `N` specs and `C` input channels, output channels are ordered by input
channel first, then by spec index, so the output shape is `(B, C * N, H, W)`.
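The ordering rule reduces to a simple index formula. The sketch below assumes 0-based indices. `output_channel`, `channel_layout`, and the spec name `height_filter` are illustrative only and not part of mtlearn. Because channels are input-major, a `(B, C * N, H, W)` tensor can be viewed as `(B, C, N, H, W)`.

```python
import torch

def output_channel(c: int, n: int, num_specs: int) -> int:
    """Output channel for input channel c and filter spec n (both 0-based)."""
    return c * num_specs + n

def channel_layout(num_in_channels: int, spec_names: list[str]) -> list[tuple[int, str]]:
    """(input channel, spec name) pairs listed in output-channel order."""
    return [(c, name) for c in range(num_in_channels) for name in spec_names]

# "height_filter" is a hypothetical second spec name.
layout = channel_layout(2, ["area_opening", "height_filter"])
print(layout)
# [(0, 'area_opening'), (0, 'height_filter'), (1, 'area_opening'), (1, 'height_filter')]

# Input-major ordering means a plain view separates the channel and spec axes.
B, C, N, H, W = 4, 2, 3, 8, 8
y = torch.randn(B, C * N, H, W)  # stand-in for the layer output
y5 = y.view(B, C, N, H, W)
assert torch.equal(y5[:, 1, 2], y[:, output_channel(1, 2, N)])
```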
## Filter Specs
A spec has these user-facing fields:
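| Field | Required | Meaning |
|---|---|---|
| `name` | No | Stable key for weights, biases, exported params, and checkpoints. |
| `tree_type` | Yes | Tree to build, e.g. `morphology.TreeType.MAX_TREE` or a string such as `"max-tree"` or `"min-tree"`. |
| `attributes` | Yes | One scalar attribute, one attribute group (e.g. `morphology.AttributeGroup.SHAPE`), or a list/tuple of scalar attributes. |
| | No | Per-spec tree-of-shapes interpolation override. |
| | No | Per-spec tree-of-shapes infinity seed row. |
| | No | Per-spec tree-of-shapes infinity seed column. |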
Multiple specs can share the same tree: mtlearn caches tree metadata per tree key and computes each distinct tree only once per sample, channel, and cache key.
```python
from mtlearn import morphology

filter_specs = [
    {
        "name": "bright_shape",
        "tree_type": "max-tree",
        "attributes": morphology.AttributeGroup.SHAPE,
    },
    {
        "name": "dark_topology",
        "tree_type": "min-tree",
        "attributes": morphology.AttributeGroup.TREE_TOPOLOGY,
    },
]
```
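The tree-sharing behaviour described above can be sketched as a dictionary cache keyed by sample, channel, and tree key. This is a hypothetical illustration, not mtlearn's cache. The class `TreeCache`, the placeholder builder, and the assumption that the tree key is just the tree type are made up for this example. In practice, the key may also need to cover tree-of-shapes options such as the interpolation and seed settings.

```python
from collections import Counter

class TreeCache:
    """Hypothetical cache: build each distinct tree once per (sample, channel, tree key)."""

    def __init__(self, build_fn):
        self._build_fn = build_fn  # called only on a cache miss
        self._trees = {}
        self.builds = Counter()    # how many times each tree key was built

    def get(self, sample_idx, channel, tree_key, image):
        key = (sample_idx, channel, tree_key)
        if key not in self._trees:
            self._trees[key] = self._build_fn(tree_key, image)
            self.builds[tree_key] += 1
        return self._trees[key]

specs = [
    {"name": "area_opening", "tree_type": "max-tree"},
    {"name": "bright_shape", "tree_type": "max-tree"},
    {"name": "dark_topology", "tree_type": "min-tree"},
]
# Placeholder builder standing in for real tree construction.
cache = TreeCache(lambda tree_key, image: f"{tree_key} for {image}")
for spec in specs:
    cache.get(sample_idx=0, channel=0, tree_key=spec["tree_type"], image="img0")
print(dict(cache.builds))  # {'max-tree': 1, 'min-tree': 1}
```

Three specs trigger only two tree builds here, because the two max-tree specs hit the same cache entry.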
## Normalization and Caching
The default `scale_mode` is `"hybrid"`. It uses dataset-level z-score
statistics, clips values to `[-hybrid_k, hybrid_k]`, and rescales them into a
positive interval controlled by `hybrid_floor_a`.
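The three steps can be sketched as follows. The z-score and clipping follow the description above. The final mapping is an assumption: this sketch maps `[-k, k]` linearly onto `[hybrid_floor_a, 1]`, which may differ from mtlearn's actual interval. `hybrid_normalize` and its default values are hypothetical.

```python
import torch

def hybrid_normalize(attr, mean, std, hybrid_k=3.0, hybrid_floor_a=0.1, eps=1e-8):
    """Hypothetical hybrid scaling: z-score, clip to [-k, k], map into [a, 1]."""
    z = (attr - mean) / (std + eps)          # dataset-level z-score
    z = z.clamp(-hybrid_k, hybrid_k)         # clip outliers to [-k, k]
    u = (z + hybrid_k) / (2 * hybrid_k)      # rescale to [0, 1]
    # Assumed target interval [a, 1]: every value stays strictly positive.
    return hybrid_floor_a + (1.0 - hybrid_floor_a) * u

# Dataset-level statistics come from a prepass over many samples.
dataset_areas = torch.tensor([1.0, 4.0, 9.0, 16.0, 400.0])
mean, std = dataset_areas.mean(), dataset_areas.std()
scaled = hybrid_normalize(torch.tensor([0.0, 9.0, 1e6]), mean, std)
```

The huge outlier is clipped to the top of the interval instead of dominating the scale. This robustness is the reason the mode needs dataset statistics up front.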
For `"hybrid"`, call `build_dataloader_cached` or `load_stats` before normal
forward passes.
```python
from torch.utils.data import DataLoader

# `dataset` is your own Dataset; the cached loader yields ((x, idx), target).
loader = DataLoader(dataset, batch_size=16, shuffle=False)
cached_loader = layer.build_dataloader_cached(loader)

for (x, idx), target in cached_loader:
    y = layer((x, idx))
```
For quick experiments that should not require a stats prepass, use
`scale_mode="minmax01"`, `"zscore_tree"`, or `"none"`.
```python
debug_layer = ConnectedFilterPreprocessingLayer(
    in_channels=1,
    filter_specs=filter_specs,
    scale_mode="minmax01",
)
```
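These modes avoid a prepass because they can take their statistics from the current tree itself. The sketch below assumes that `"minmax01"` and `"zscore_tree"` are computed over the nodes of a single tree and that `"none"` passes raw attributes through. The helper names are hypothetical and not mtlearn functions.

```python
import torch

def minmax01(attr: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical per-tree min-max scaling of one attribute into [0, 1]."""
    lo, hi = attr.min(), attr.max()
    return (attr - lo) / (hi - lo + eps)

def zscore_tree(attr: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical per-tree z-score: mean and std come from this tree's nodes only."""
    mean = attr.mean()
    std = ((attr - mean) ** 2).mean().sqrt()  # population standard deviation
    return (attr - mean) / (std + eps)

# Areas of the nodes of one tree; no dataset-level statistics are involved.
areas = torch.tensor([1.0, 3.0, 10.0, 1024.0])
m = minmax01(areas)
z = zscore_tree(areas)
```

The trade-off is that per-tree scaling changes from image to image, so the same raw area can map to different normalized values in different samples.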
## Initialization
Two helpers initialize filters close to identity.
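```python
# Call one of the two helpers, not both.

# Option 1: identity-like initialization through the bias.
layer.init_identity_with_bias(p0=0.995)

# Option 2 (alternative): for hybrid-normalized attributes with a positive floor.
layer.init_identity_bias_zero(p0=0.99)
```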
Use an identity-like initialization when the CFP layer is placed before a pretrained or sensitive downstream network. Use random initialization when the preprocessing block should discover strong filtering behavior from scratch.
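One way to read `p0` is as a target gate value near 1, so that almost every node is kept. Under that assumption, and assuming gates of the form `sigmoid(w · a + b)`, the arithmetic is a logit. This worked sketch is not mtlearn's code. Both variants, as well as the names `logit`, `sigmoid`, and `floor_a`, are illustrative. Variant B shows why a positive attribute floor allows a zero bias.

```python
import math

def logit(p: float) -> float:
    return math.log(p / (1.0 - p))

def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))

# Variant A (hypothetical): zero weights, bias = logit(p0), so every gate equals p0.
p0 = 0.995
bias = logit(p0)                    # ln(199), about 5.293
gate_a = sigmoid(0.0 * 123.0 + bias)  # the attribute value no longer matters

# Variant B (hypothetical): zero bias, and every attribute is >= floor_a > 0.
# With n attributes and equal weights w >= 0, the score is at least n * w * floor_a,
# so w = logit(p0) / (n * floor_a) makes every gate at least p0.
p0_b, n_attrs, floor_a = 0.99, 2, 0.1
w = logit(p0_b) / (n_attrs * floor_a)
gate_b_min = sigmoid(n_attrs * w * floor_a)  # node whose attributes sit at the floor
gate_b_high = sigmoid(n_attrs * w * 0.8)     # node with larger attributes
```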
## Inference
`predict` temporarily switches the layer to evaluation mode, runs without
gradients, and uses a caller-provided sigmoid gain `beta_f`. A large `beta_f`
pushes the gates closer to hard keep/discard decisions.
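```python
# Forward pass with the layer's current gain, without gradients.
with torch.no_grad():
    y_soft = layer(x)

# predict() switches to eval mode and disables gradients itself.
y_hard = layer.predict(x, beta_f=1000.0)
```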
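Both parts of that description can be sketched in plain PyTorch. First, multiplying node scores by a large gain drives `sigmoid` toward 0 or 1. Second, a predict-style helper switches to eval mode, disables gradients, and restores the previous mode afterwards. `predict_like` is a hypothetical helper and not mtlearn's `predict`. The assumption that `beta_f` multiplies the gate score is likewise illustrative.

```python
import torch
from torch import nn

# Steepness: the same node scores passed through sigmoid gates with two gains.
scores = torch.tensor([-0.5, -0.01, 0.01, 0.5])
soft = torch.sigmoid(1.0 * scores)     # values stay near 0.5
hard = torch.sigmoid(1000.0 * scores)  # values approach 0 or 1

def predict_like(module: nn.Module, *args, **kwargs):
    """Hypothetical helper: run in eval mode without gradients, then restore the mode."""
    was_training = module.training
    module.eval()
    try:
        with torch.no_grad():
            return module(*args, **kwargs)
    finally:
        module.train(was_training)  # restore train/eval mode even on error

net = nn.Sequential(nn.Linear(3, 3), nn.Dropout(0.5))
net.train()
out = predict_like(net, torch.ones(2, 3))
```

Restoring the mode in `finally` means a training loop can call such a helper mid-epoch without silently leaving the model in eval mode.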
## Inspect One Sample
Use `inspect_training_sample` to debug attributes, normalized attributes,
tree payloads, and current trainable parameters.
```python
report = layer.inspect_training_sample(x[0], channel=0, idx=0)

for name, spec_report in report["specs"].items():
    print(name)
    print(spec_report["attributes"])
    print(spec_report["weight"])
    print(spec_report["bias"])
```
## Save Stats and Params
Dataset statistics are separate from ordinary model weights.
```python
# Dataset statistics (needed by scale_mode="hybrid").
layer.save_stats("cfp-stats.pt")
layer.load_stats("cfp-stats.pt")

# Trainable gate parameters.
weights, biases = layer.get_params()
layer.export_params("cfp-params.pt")
```
For full model checkpoints, use the helpers documented in PyTorch Integration.
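This separation matters because a module's `state_dict()` only covers registered parameters and buffers. Statistics held elsewhere must be saved explicitly. The sketch below uses a hypothetical `StatsHolder` module that keeps its statistics in a plain dict. mtlearn may store its statistics differently. The example only shows the general PyTorch behaviour.

```python
import os
import tempfile

import torch
from torch import nn

class StatsHolder(nn.Module):
    """Hypothetical module: trainable gate params plus separate dataset statistics."""

    def __init__(self, num_attrs: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(num_attrs))  # updated by the optimizer
        self.bias = nn.Parameter(torch.zeros(()))
        self.stats = {}  # plain dict: not a parameter, so not in state_dict()

    def fit_stats(self, attrs: torch.Tensor) -> None:
        # Per-attribute statistics over a dataset prepass (rows = nodes).
        self.stats = {"mean": attrs.mean(dim=0), "std": attrs.std(dim=0)}

m = StatsHolder(2)
m.fit_stats(torch.tensor([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]))
print(sorted(m.state_dict().keys()))  # ['bias', 'weight']

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "stats.pt")
    torch.save(m.stats, path)  # statistics need their own file
    restored = torch.load(path)
```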