# Source code for mtlearn.layers.ConnectedFilterPreprocessingLayerWithCPUTreeTraversal

"""Reference CFP implementation backed by CPU tree traversal.

This module keeps a CPU-oriented baseline for connected-filter preprocessing.
It delegates the connected filtering and parameter-gradient computation to the
C++ ``ConnectedFilterPreprocessingTreeTraversal`` helper instead of using the implicit
or dense-Jacobian PyTorch formulations. The implementation is useful for
comparison with historical experiments and backend behavior.
"""

from __future__ import annotations

import math
import torch
import numpy as np
from .. import morphology
import mtlearn
from ._helpers import (
    group_name,
    to_numpy_u8,
    build_tree,
    update_ds_stats,
    normalize_with_ds_stats,
    maybe_refresh_norm_for_key,
    make_stats_payload,
    load_stats_payload,
    IndexedDatasetWrapper,
    normalize_attributes_spec,
    validate_attributes_for_tree_type,
)






class ConnectedFilterPreprocessingCPUTreeTraversalFunction(torch.autograd.Function):
    """Custom autograd op: C++ connected filtering forward, C++ tree-traversal backward.

    Of the six ``forward`` inputs, only ``weight`` and ``bias`` receive
    gradients; ``tree``, ``attrs2d``, ``beta_f`` and ``clamp_logits`` get
    ``None`` in :meth:`backward`.
    """

    @staticmethod
    def forward(ctx, tree, attrs2d: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, beta_f: float = 1000.0, clamp_logits: bool = True):
        """Filter one tree with the sigmoid criterion of one attribute group.

        Args:
            ctx: Autograd context used to stash state for :meth:`backward`.
            tree: Morphology-tree handle from the C++ backend.
            attrs2d: Normalized attributes with shape ``(num_nodes, K)``.
            weight: Learnable group weight vector with shape ``(K,)``.
            bias: Learnable scalar bias, either ``()`` or ``(1,)``.
            beta_f: Forward sigmoid gain.
            clamp_logits: Whether to clamp ``beta_f * logits`` to ``[-12, 12]``
                before the sigmoid.

        Returns:
            Whatever ``ConnectedFilterPreprocessingTreeTraversal.filtering``
            returns for this tree (documented as an ``(H, W)`` image).

        Raises:
            AssertionError: If ``attrs2d`` is not 2-D or ``weight`` is not 1-D.
        """
        assert attrs2d.dim() == 2, "attrs2d must have shape (num_nodes, K)"
        assert weight.dim() == 1, "weight must have shape (K,)"

        # One affine score per tree node, then the gain.
        node_scores = attrs2d @ weight.view(-1) + bias.view(())
        gained = beta_f * node_scores
        if clamp_logits:
            gained = torch.clamp(gained, -12.0, 12.0)
        keep_prob = torch.sigmoid(gained)

        filtered = mtlearn.ConnectedFilterPreprocessingTreeTraversal.filtering(tree, keep_prob)

        # The tree handle is not a tensor, so it rides on ctx directly.
        ctx.tree = tree
        ctx.beta_f = beta_f
        ctx.save_for_backward(attrs2d, keep_prob)
        return filtered

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate ``(dW, dB)`` to the C++ tree traversal.

        NOTE(review): ``ctx.beta_f`` is saved in :meth:`forward` but is not
        passed to ``gradients``; whether the C++ routine applies the
        ``beta_f`` chain-rule factor (and zeroes nodes saturated by
        ``clamp_logits``) cannot be seen here — confirm against the backend.

        Returns:
            One entry per ``forward`` input:
            ``(None, None, dW, dB, None, None)``.
        """
        attrs2d, keep_prob = ctx.saved_tensors
        grad_weight, grad_bias = mtlearn.ConnectedFilterPreprocessingTreeTraversal.gradients(
            ctx.tree, attrs2d, keep_prob, grad_output
        )
        return None, None, grad_weight, grad_bias, None, None


[docs] class ConnectedFilterPreprocessingLayerWithCPUTreeTraversal(torch.nn.Module): """Reference CFP layer that uses C++ CPU tree traversal. For each attribute group ``g`` with ``K`` normalized attributes ``A_g in R[num_nodes, K]``, the layer computes ``sigmoid(beta_f * (A_g @ w_g + b_g))`` and sends that criterion to the C++ connected-filter backend. The backward pass asks the backend to traverse the tree and compute ``dW`` and ``dB``. This implementation is CPU-bound and mainly serves as a reference for the primary implicit-Jacobian layer. Args: in_channels: Number of input channels. attributes_spec: Attribute groups. Each group must contain at least one morphology attribute enum. tree_type: ``"max-tree"``, ``"min-tree"``, ``"tree-of-shapes"``, or the legacy ``"tos"`` alias. device: Torch device used for parameters, cached tensors, and outputs. scale_mode: ``"minmax01"``, ``"zscore_tree"``, ``"hybrid"``, or ``"none"``. eps: Numerical floor for normalization denominators. beta_f: Forward sigmoid gain. top_hat: If true, output the tree-type-specific top-hat residual. clamp_logits: If true, clamp ``beta_f * logits`` to ``[-12, 12]``. hybrid_k: Number of standard deviations used for hybrid clipping. hybrid_floor_a: Lower bound used when remapping hybrid-normalized attributes to ``[a, 1]``. tos_interpolation: Tree-of-shapes interpolation policy. Accepts ``"self-dual"``, ``"min4c-max8c"``, ``"min8c-max4c"``, or the corresponding ``morphology.ToSInterpolation`` enum. tos_infinity_seed_row, tos_infinity_seed_col: Infinity seed used by the tree-of-shapes backend. """ def __init__(self, in_channels, attributes_spec, tree_type="max-tree", device="cpu", scale_mode: str = "hybrid", eps: float = 1e-6, beta_f: float = 1.0, top_hat: bool = False, clamp_logits: bool = True, hybrid_k: float = 3.0, hybrid_floor_a: float = 0.05, tos_interpolation=None, tos_infinity_seed_row: int = 0, tos_infinity_seed_col: int = 0, ): """Initialize CPU-tree-traversal CFP caches and parameters. 
The constructor keeps the same public configuration as the primary CFP layer, but cache entries store backend tree handles because both filtering and gradient computation are delegated to C++ tree traversal. """ super().__init__() # Hybrid normalization configuration. self.hybrid_k = float(hybrid_k) self.hybrid_floor_a = float(hybrid_floor_a) self.in_channels = int(in_channels) self.tree_type = morphology.normalize_tree_type(tree_type) self.device = torch.device(device) self.scale_mode = str(scale_mode) self.eps = float(eps) self.beta_f = float(beta_f) self.top_hat = bool(top_hat) self.clamp_logits = bool(clamp_logits) if self.tree_type == "tree-of-shapes": self.tos_interpolation = morphology.normalize_tos_interpolation(tos_interpolation) else: self.tos_interpolation = tos_interpolation self.tos_infinity_seed_row = int(tos_infinity_seed_row) self.tos_infinity_seed_col = int(tos_infinity_seed_col) # Attribute groups and the flat set of scalar attribute types used by them. self.group_defs, self._all_attr_types = normalize_attributes_spec(attributes_spec, self.tree_type) validate_attributes_for_tree_type(self._all_attr_types, self.tree_type) self.num_groups = len(self.group_defs) self.out_channels = self.in_channels * self.num_groups # Tree, attribute, and normalization cache state. self._trees = {} self._base_attrs = {} self._norm_attrs = {} self._stats_epoch = 0 self._norm_epoch_by_key = {} self._ds_stats = {} self._stats_frozen = False # Learnable parameters: one weight vector and one bias per group. self._weights = torch.nn.ParameterDict() self._biases = torch.nn.ParameterDict() for g, group in enumerate(self.group_defs): k = len(group) gname = "+".join([t.name for t in group]) w = torch.empty(k, dtype=torch.float32, device=self.device) b = torch.empty(1, dtype=torch.float32, device=self.device) # Xavier-like initialization for a one-dimensional parameter vector. 
fan_in, fan_out = k, 1 gain = 1.0 std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) a = math.sqrt(3.0) * std torch.nn.init.uniform_(w, -a, a) torch.nn.init.constant_(b, 0.0) self._weights[gname] = torch.nn.Parameter(w, requires_grad=True) self._biases[gname] = torch.nn.Parameter(b, requires_grad=True) # ---------- helpers ---------- def _group_name(self, group): """Return the stable parameter/cache name for an attribute group.""" return group_name(group) def _to_numpy_u8(self, img2d_t: torch.Tensor) -> np.ndarray: """Convert one image channel to the backend's ``np.uint8`` format.""" return to_numpy_u8(img2d_t) def _build_tree(self, img_np: np.ndarray): """Build the configured morphology tree for one ``np.uint8`` image.""" return build_tree( img_np, self.tree_type, tos_interpolation=self.tos_interpolation, tos_infinity_seed_row=self.tos_infinity_seed_row, tos_infinity_seed_col=self.tos_infinity_seed_col, ) # ---------- normalization with hybrid support ---------- def _update_ds_stats(self, attr_type, a_raw_1d: torch.Tensor): """Update dataset statistics for one raw attribute vector.""" if getattr(self, "_stats_frozen", False): return smode = self.scale_mode if smode == "hybrid": smode = "zscore_tree" changed = update_ds_stats(self._ds_stats, smode, attr_type, a_raw_1d) if changed: self._stats_epoch += 1 def _normalize_with_ds_stats(self, attr_type, a_raw_1d: torch.Tensor) -> torch.Tensor: """Normalize a raw attribute vector according to ``scale_mode``. Hybrid mode applies z-score normalization, clipping, and remapping to ``[hybrid_floor_a, 1]``. """ if self.scale_mode != "hybrid": return normalize_with_ds_stats(self._ds_stats, self.scale_mode, self.eps, attr_type, a_raw_1d) # Hybrid mode. st = self._ds_stats.get(attr_type, None) if st is None: # Before stats exist, preserve the raw values so early calls do not fail. return a_raw_1d # Dataset-level z-score statistics. 
count = st["count"].to(torch.float32) mean = (st["sum"] / torch.clamp(count, min=1.0)) if count.item() > 0 else torch.tensor(0.0, device=a_raw_1d.device) var = (st["sumsq"] / torch.clamp(count, min=1.0) - mean * mean) if count.item() > 0 else torch.tensor(0.0, device=a_raw_1d.device) std = torch.sqrt(torch.clamp(var, min=self.eps)) # 1) z-score x = (a_raw_1d - mean) / std # 2) clip em [-k, +k] k = torch.tensor(self.hybrid_k, dtype=x.dtype, device=x.device) x = torch.clamp(x, -k, k) # 3) rescale to [a, 1]. a = torch.tensor(self.hybrid_floor_a, dtype=x.dtype, device=x.device) x01 = a + (1.0 - a) * ((x + k) / (2.0 * k)) return x01 def _maybe_refresh_norm_for_key(self, key: str): """Refresh normalized cached attributes if dataset stats changed.""" # Nothing to refresh until raw attributes are cached. if key not in self._base_attrs: return # Cache is already current for this stats epoch. if self._norm_epoch_by_key.get(key, -1) == self._stats_epoch: return if self.scale_mode == "hybrid": # Reapply hybrid normalization attribute by attribute. per_attr_raw = self._base_attrs[key] # dict[attr_type] -> (numNodes,1) per_attr_norm = {} for attr_type, a_raw_2d in per_attr_raw.items(): a_raw_1d = a_raw_2d.view(-1) # (numNodes,) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_norm[attr_type] = a_norm self._norm_attrs[key] = per_attr_norm self._norm_epoch_by_key[key] = self._stats_epoch else: # Non-hybrid modes are shared across CFP implementations. maybe_refresh_norm_for_key( key, self._base_attrs, self._norm_attrs, self._all_attr_types, self._ds_stats, self.scale_mode, self.eps, self._norm_epoch_by_key, self._stats_epoch )
[docs] def freeze_ds_stats(self): """Stop collecting dataset statistics for future samples.""" self._stats_frozen = True
[docs] def unfreeze_ds_stats(self): """Resume collecting dataset statistics for future samples.""" self._stats_frozen = False
[docs] def save_stats(self, path: str): """Save normalization statistics and scale mode for reproducibility.""" payload = make_stats_payload(self._ds_stats, self.scale_mode) torch.save(payload, path) print(f"[ConnectedLinearUnit] stats saved to {path}")
[docs] def load_stats(self, path: str, refresh_cache: bool = True, *, trusted_legacy_format: bool = False): """Load normalization statistics and optionally refresh cached attrs.""" payload = load_stats_payload(path, self.device, trusted_legacy_format=trusted_legacy_format) self._ds_stats = payload.get("ds_stats", {}) # Invalidate normalized values derived from older statistics. self._stats_epoch += 1 if refresh_cache: self.refresh_cached_normalization()
# ---------- tree and attribute construction ---------- def _ensure_tree_and_attr(self, key: str, img_t: torch.Tensor): """Ensure tree and raw/normalized attributes exist in cache.""" if key in self._trees: return img_np = self._to_numpy_u8(img_t.detach()) tree = self._build_tree(img_np) self._trees[key] = tree per_attr_raw, per_attr_norm = {}, {} for attr_type in self._all_attr_types: attr_np = morphology.compute_attributes(tree, [attr_type])[1] a_raw_1d = torch.as_tensor(attr_np, device=self.device).squeeze(1) self._update_ds_stats(attr_type, a_raw_1d) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_raw[attr_type] = a_raw_1d.unsqueeze(1) per_attr_norm[attr_type] = a_norm self._base_attrs[key] = per_attr_raw self._norm_attrs[key] = per_attr_norm self._norm_epoch_by_key[key] = self._stats_epoch # ---------- inspection ----------
[docs] def inspect_training_sample(self, img, channel: int = 0, idx: int | None = None, build_if_missing: bool = True): """Return an inspection package for one image. The result contains the tree, raw and normalized attributes grouped by attribute group name, and references to the current group weights and biases. Args: img: Tensor ``(H, W)`` or ``(C, H, W)``, or ``(img, idx)`` tuple. channel: Channel index used when ``img`` has more than one channel. idx: Optional sample index, also accepted through ``(img, idx)``. build_if_missing: Build and cache preprocessing if it is missing. Raises: KeyError: If ``build_if_missing`` is false and the key is absent. Weights and biases are returned as parameter references; clone them if immutable snapshots are needed. """ # Detect whether img was passed as (img, idx). if isinstance(img, tuple) and len(img) == 2 and isinstance(img[0], torch.Tensor) and ( isinstance(img[1], int) or (isinstance(img[1], torch.Tensor) and img[1].numel() == 1) ): img_tensor = img[0] idx_val = int(img[1]) if not isinstance(img[1], torch.Tensor) else int(img[1].item()) else: img_tensor = img idx_val = idx # Normalize image layout to (C, H, W). if img_tensor.dim() == 2: imgCHW = img_tensor.unsqueeze(0) # (1,H,W) elif img_tensor.dim() == 3: imgCHW = img_tensor # (C,H,W) else: raise ValueError(f"img must be (H, W) or (C, H, W); got {tuple(img_tensor.shape)}") C, H, W = imgCHW.shape if C != self.in_channels: # A single-channel input is valid even if the layer has more channels. if C != 1: raise AssertionError(f"in_channels={self.in_channels}, input C={C}") c = channel if C > 1 else 0 t_u8 = imgCHW[c] # Use persistent cache when an index is available. 
if idx_val is not None: key = f"{idx_val}_{c}" if build_if_missing: print(f"[inspect_training_sample] Using cache key '{key}'.") self._ensure_tree_and_attr(key, t_u8) else: print(f"[inspect_training_sample] Using cache key '{key}' without build_if_missing.") tree = self._trees[key] # Refresh normalized attributes if dataset statistics changed. self._maybe_refresh_norm_for_key(key) base_attrs_by_group = {} norm_attrs_by_group = {} weights_by_group = {} bias_by_group = {} for group in self.group_defs: gname = self._group_name(group) # Stack one column per attribute in the group. cols_raw = [self._base_attrs[key][attr_type].view(-1, 1) for attr_type in group] cols_norm = [self._norm_attrs[key][attr_type].view(-1, 1) for attr_type in group] A_raw = torch.cat(cols_raw, dim=1) A_norm = torch.cat(cols_norm, dim=1) base_attrs_by_group[gname] = A_raw norm_attrs_by_group[gname] = A_norm weights_by_group[gname] = self._weights[gname] bias_by_group[gname] = self._biases[gname] else: print("[inspect_training_sample] Running without cache; computing tree and attributes directly.") img_np = self._to_numpy_u8(t_u8.detach()) tree = self._build_tree(img_np) per_attr_raw, per_attr_norm = {}, {} for attr_type in self._all_attr_types: attr_np = morphology.compute_attributes(tree, [attr_type])[1] a_raw_1d = torch.as_tensor(attr_np, device=self.device).squeeze(1) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_raw[attr_type] = a_raw_1d.unsqueeze(1) per_attr_norm[attr_type] = a_norm base_attrs_by_group = {} norm_attrs_by_group = {} weights_by_group = {} bias_by_group = {} for group in self.group_defs: gname = self._group_name(group) cols_raw = [per_attr_raw[attr_type].view(-1, 1) for attr_type in group] cols_norm = [per_attr_norm[attr_type].view(-1, 1) for attr_type in group] A_raw = torch.cat(cols_raw, dim=1) A_norm = torch.cat(cols_norm, dim=1) base_attrs_by_group[gname] = A_raw norm_attrs_by_group[gname] = A_norm weights_by_group[gname] = self._weights[gname] 
bias_by_group[gname] = self._biases[gname] return { "tree": tree, "base_attrs_by_group": base_attrs_by_group, "norm_attrs_by_group": norm_attrs_by_group, "weights_by_group": weights_by_group, "bias_by_group": bias_by_group, }
# ---------- forward ----------
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply CFP to a batch and return ``(B, C * groups, H, W)`` output.""" # Support (x, idx) tuples and [x, idx] lists emitted by DataLoader. if isinstance(x, tuple) and len(x) == 2: x, idx = x use_cache = True elif isinstance(x, list) and len(x) == 2 and isinstance(x[1], torch.Tensor) and x[1].dim() == 1: x, idx = x[0], x[1] use_cache = True else: # Handle a list of tensors without explicit indexes. if isinstance(x, list): x = torch.stack(x, dim=0) idx = torch.arange(x.size(0), device=x.device) use_cache = False #print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] forward: use_cache={use_cache}") assert x.dim() == 4, f"expected (B, C, H, W), got {tuple(x.shape)}" B, C, H, W = x.shape assert C == self.in_channels, f"in_channels={self.in_channels}, input C={C}" out = torch.empty((B, self.out_channels, H, W), dtype=torch.float32, device=self.device) for b in range(B): for c in range(C): if use_cache: key = f"{int(idx[b])}_{c}" self._ensure_tree_and_attr(key, x[b, c]) tree = self._trees[key] self._maybe_refresh_norm_for_key(key) for g, group in enumerate(self.group_defs): gname = self._group_name(group) # Build A_norm by stacking one normalized column per group attribute. cols = [self._norm_attrs[key][attr_type].view(-1, 1) for attr_type in group] A_norm = torch.cat(cols, dim=1) # (numNodes, K) y_ch = ConnectedFilterPreprocessingCPUTreeTraversalFunction.apply( tree, A_norm, self._weights[gname], self._biases[gname], self.beta_f, self.clamp_logits ) if self.top_hat: x_bc = x[b, c].to(dtype=torch.float32, device=self.device) tt = self.tree_type if tt == "max-tree": y_out = x_bc - y_ch elif tt == "min-tree": y_out = y_ch - x_bc else: y_out = torch.abs(y_ch - x_bc) else: y_out = y_ch out[b, c * self.num_groups + g].copy_(y_out, non_blocking=True) else: # No persistent key was provided; build tree and attributes directly. 
#print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] forward: computing tree/attrs directly for sample {b}, channel {c}") img_np = self._to_numpy_u8(x[b, c].detach()) tree = self._build_tree(img_np) per_attr_norm = {} for attr_type in self._all_attr_types: attr_np = morphology.compute_attributes(tree, [attr_type])[1] a_raw_1d = torch.as_tensor(attr_np, device=self.device).squeeze(1) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_norm[attr_type] = a_norm for g, group in enumerate(self.group_defs): gname = self._group_name(group) cols = [per_attr_norm[attr_type].view(-1, 1) for attr_type in group] A_norm = torch.cat(cols, dim=1) y_ch = ConnectedFilterPreprocessingCPUTreeTraversalFunction.apply( tree, A_norm, self._weights[gname], self._biases[gname], self.beta_f, self.clamp_logits ) if self.top_hat: x_bc = x[b, c].to(dtype=torch.float32, device=self.device) tt = self.tree_type if tt == "max-tree": y_out = x_bc - y_ch elif tt == "min-tree": y_out = y_ch - x_bc else: y_out = torch.abs(y_ch - x_bc) else: y_out = y_ch out[b, c * self.num_groups + g].copy_(y_out, non_blocking=True) return out
# ---------- prediction / inference ----------
[docs] def predict(self, x: torch.Tensor, beta_f: float = 1.0) -> torch.Tensor: """Run inference with a caller-provided forward sigmoid gain.""" was_training = self.training self.eval() with torch.no_grad(): # Same input handling as forward. if isinstance(x, tuple) and len(x) == 2: x, idx = x use_cache = True elif isinstance(x, list) and len(x) == 2 and isinstance(x[1], torch.Tensor) and x[1].dim() == 1: x, idx = x[0], x[1] use_cache = True else: if isinstance(x, list): x = torch.stack(x, dim=0) idx = torch.arange(x.size(0), device=x.device) use_cache = False #print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] predict: use_cache={use_cache}") B, C, H, W = x.shape out = torch.empty((B, self.out_channels, H, W), dtype=torch.float32, device=self.device) for b in range(B): for c in range(C): if use_cache: key = f"{int(idx[b])}_{c}" self._ensure_tree_and_attr(key, x[b, c]) tree = self._trees[key] self._maybe_refresh_norm_for_key(key) for g, group in enumerate(self.group_defs): gname = self._group_name(group) cols = [self._norm_attrs[key][attr_type].view(-1, 1) for attr_type in group] A_norm = torch.cat(cols, dim=1) # (numNodes, K) y_ch = ConnectedFilterPreprocessingCPUTreeTraversalFunction.apply( tree, A_norm, self._weights[gname], self._biases[gname], beta_f, self.clamp_logits # caller-provided beta_f ) if self.top_hat: x_bc = x[b, c].to(dtype=torch.float32, device=self.device) tt = self.tree_type if tt == "max-tree": y_out = x_bc - y_ch elif tt == "min-tree": y_out = y_ch - x_bc else: y_out = torch.abs(y_ch - x_bc) else: y_out = y_ch out[b, c * self.num_groups + g].copy_(y_out, non_blocking=True) else: #print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] predict: computing tree/attrs directly for sample {b}, channel {c}") img_np = self._to_numpy_u8(x[b, c].detach()) tree = self._build_tree(img_np) per_attr_norm = {} for attr_type in self._all_attr_types: attr_np = morphology.compute_attributes(tree, [attr_type])[1] a_raw_1d = 
torch.as_tensor(attr_np, device=self.device).squeeze(1) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_norm[attr_type] = a_norm for g, group in enumerate(self.group_defs): gname = self._group_name(group) cols = [per_attr_norm[attr_type].view(-1, 1) for attr_type in group] A_norm = torch.cat(cols, dim=1) y_ch = ConnectedFilterPreprocessingCPUTreeTraversalFunction.apply( tree, A_norm, self._weights[gname], self._biases[gname], beta_f, self.clamp_logits ) if self.top_hat: x_bc = x[b, c].to(dtype=torch.float32, device=self.device) tt = self.tree_type if tt == "max-tree": y_out = x_bc - y_ch elif tt == "min-tree": y_out = y_ch - x_bc else: y_out = torch.abs(y_ch - x_bc) else: y_out = y_ch out[b, c * self.num_groups + g].copy_(y_out, non_blocking=True) if was_training: self.train() else: self.eval() return out
# ---------- save / load ----------
[docs] def save_params(self, path: str): """Save all group weights and biases.""" payload = { "weights": { name: p.detach().cpu() for name, p in self._weights.items() }, "biases": { name: p.detach().cpu() for name, p in self._biases.items() }, "scale_mode": self.scale_mode, } torch.save(payload, path)
[docs] def get_params(self): """Return CPU clones of group weights and biases.""" return ( { name: p.detach().cpu().clone() for name, p in self._weights.items() }, { name: p.detach().cpu().clone() for name, p in self._biases.items() }, )
# ---------- cached-normalization utilities ----------
[docs] def refresh_cached_normalization(self): """Recompute normalized attributes for every cached sample.""" for key, per_attr_raw in self._base_attrs.items(): per_attr_norm = {} for attr_type, a_raw_2d in per_attr_raw.items(): a_raw_1d = a_raw_2d.view(-1) a_norm = self._normalize_with_ds_stats(attr_type, a_raw_1d) per_attr_norm[attr_type] = a_norm self._norm_attrs[key] = per_attr_norm self._norm_epoch_by_key[key] = self._stats_epoch
# ---------- initialization helpers ---------- @staticmethod def _logit(p: float) -> float: """Return a numerically clipped logit for probability ``p``.""" p = max(min(float(p), 1.0 - 1e-6), 1e-6) return math.log(p / (1.0 - p)) @torch.no_grad() def init_identity_with_bias(self, p0: float = 0.995): """Initialize near identity by using only a positive bias. For each group, weights are set to zero and bias is set to ``logit(p0) / beta_f``. """ L = self._logit(p0) / float(self.beta_f) for group in self.group_defs: gname = self._group_name(group) self._weights[gname].zero_() self._biases[gname].fill_(L) @torch.no_grad() def init_identity_bias_zero(self, p0: float = 0.99): """Initialize near identity with zero bias under hybrid normalization. This assumes normalized attributes live in ``[a, 1]`` where ``a = hybrid_floor_a``. Each group receives constant weights ``c = logit(p0) / (beta_f * K * a)`` and zero bias. """ if self.scale_mode != "hybrid": print("[init_identity_bias_zero] Warning: this initializer assumes scale_mode == 'hybrid'.") a = max(min(self.hybrid_floor_a, 1.0), 1e-6) L = self._logit(p0) / float(self.beta_f) for group in self.group_defs: gname = self._group_name(group) K = len(group) c = L / (K * a) self._weights[gname].fill_(c) self._biases[gname].zero_()
[docs] def build_dataloader_cached(self, dataloader): """Precompute trees and attributes for a DataLoader. The returned DataLoader wraps the original dataset and emits ``((x, idx), y)`` batches. This method builds per-sample trees and attributes, updates dataset normalization statistics, refreshes cached normalized attributes, and freezes statistics when preprocessing ends. """ from torch.utils.data import Dataset, DataLoader dataset_wrapped = IndexedDatasetWrapper(dataloader.dataset) new_loader = DataLoader( dataset_wrapped, batch_size=dataloader.batch_size, shuffle=False, num_workers=dataloader.num_workers, pin_memory=dataloader.pin_memory, drop_last=False, collate_fn=dataloader.collate_fn, persistent_workers=getattr(dataloader, "persistent_workers", False), ) print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] Preprocessing dataset using mode '{self.scale_mode}'...") self._stats_frozen = False total_batches = len(new_loader) with torch.no_grad(): for batch_i, ((x, idx), y) in enumerate(new_loader): B, C, H, W = x.shape for b in range(B): for c in range(C): key = f"{int(idx[b])}_{c}" self._ensure_tree_and_attr(key, x[b, c]) if (batch_i + 1) % 10 == 0 or batch_i == total_batches - 1: print(f" [{batch_i+1}/{total_batches}] batches processed.") self.freeze_ds_stats() self.refresh_cached_normalization() print(f"[ConnectedFilterPreprocessingLayerWithCPUTreeTraversal] Full and normalized cache with '{self.scale_mode}'.") return new_loader
# Shorter alias for the CPU-tree-traversal CFP layer.
CFPLayerWithCPUTreeTraversal = ConnectedFilterPreprocessingLayerWithCPUTreeTraversal

# Names exported by ``from ... import *``.
__all__ = [
    "ConnectedFilterPreprocessingLayerWithCPUTreeTraversal",
    "CFPLayerWithCPUTreeTraversal",
    "ConnectedFilterPreprocessingCPUTreeTraversalFunction",
]