Source code for mtlearn.datasets

from __future__ import annotations

# %% setup
import os, glob, time
import numpy as np
import cv2
import torch
import abc
from . import morphology
from torch.utils.data import Dataset


def _split_indices(
    num_samples: int,
    test_size: float | int = 0.25,
    *,
    shuffle: bool = True,
    random_state: int | np.random.RandomState | None = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """Split dataset indices without requiring scikit-learn at import time."""

    if num_samples <= 0:
        raise ValueError("Cannot split an empty dataset")

    if isinstance(test_size, float):
        if not 0.0 < test_size < 1.0:
            raise ValueError("test_size as a float must be between 0 and 1")
        num_test = int(np.ceil(num_samples * test_size))
    elif isinstance(test_size, int):
        if not 0 < test_size < num_samples:
            raise ValueError("test_size as an int must be between 1 and len(dataset) - 1")
        num_test = test_size
    else:
        raise TypeError("test_size must be a float or int")

    num_train = num_samples - num_test
    if num_train <= 0:
        raise ValueError("test_size leaves no samples for training")

    if shuffle:
        rng = (
            random_state
            if isinstance(random_state, np.random.RandomState)
            else np.random.RandomState(random_state)
        )
        permutation = rng.permutation(num_samples)
        test_idx = permutation[:num_test]
        train_idx = permutation[num_test:num_test + num_train]
    else:
        indices = np.arange(num_samples)
        train_idx = indices[:num_train]
        test_idx = indices[num_train:num_train + num_test]

    return train_idx, test_idx


# --------------------------
# 1) AttributeFilterDataset
# --------------------------
[docs] class AttributeFilterDataset(torch.utils.data.Dataset, abc.ABC): def __init__( self, root, tree_type, attributes: list, thresholds: dict, top_hat: bool = False, numRows: int = None, numCols: int = None, tos_interpolation=None, tos_infinity_seed_row: int = 0, tos_infinity_seed_col: int = 0, ): super(torch.utils.data.Dataset, self).__init__() self.root = root exts = ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tif", "*.tiff") paths = [] for e in exts: paths.extend(glob.glob(os.path.join(root, e))) if not paths: raise FileNotFoundError(f"Nenhuma imagem encontrada em {root}") paths.sort() self.paths = paths self.thresholds = thresholds self.attributes = attributes self.tree_type = morphology.normalize_tree_type(tree_type) if self.tree_type == "tree-of-shapes": self.tos_interpolation = morphology.normalize_tos_interpolation(tos_interpolation) else: self.tos_interpolation = tos_interpolation self.tos_infinity_seed_row = int(tos_infinity_seed_row) self.tos_infinity_seed_col = int(tos_infinity_seed_col) self.numRows=numRows self.numCols=numCols self.top_hat = top_hat def __len__(self): return len(self.paths) def __getitem__(self, idx): path = self.paths[idx] img_u8 = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if self.numRows != None and self.numCols != None: img_u8 = cv2.resize(img_u8, (self.numCols, self.numRows)) if img_u8 is None: raise RuntimeError(f"Falha ao ler imagem: {path}") # ----- Filtragem ----- tree = morphology.build_tree( img_u8, self.tree_type, tos_interpolation=self.tos_interpolation, tos_infinity_seed_row=self.tos_infinity_seed_row, tos_infinity_seed_col=self.tos_infinity_seed_col, ) attr_idx, attr_values = morphology.compute_attributes(tree, self.attributes) # (numNodes, numAttributes) criterion = np.ones(attr_values.shape[0], dtype=bool) for attr_type in self.attributes: name = attr_type.name criterion = criterion & (attr_values[:, attr_idx[name]] > self.thresholds[name]) filter = morphology.create_attribute_filter(tree) img_out_u8 = filter.filteringSubtractiveRule(criterion) # (numRows,numCols) uint8 if self.top_hat == True: if(self.tree_type == "min-tree"): img_out_u8 = img_out_u8 - img_u8 elif(self.tree_type == "max-tree"): img_out_u8 = img_u8 - img_out_u8 else: img_out_u8 = np.abs(img_out_u8 - img_u8) img_out = torch.from_numpy(img_out_u8).to(torch.float32).unsqueeze(0) # (1,numRows,numCols) img_in = torch.from_numpy(img_u8).to(torch.float32).unsqueeze(0) # (1,numRows,numCols) return img_in, img_out, os.path.basename(path)
[docs] def train_test_split(self, test_size=0.25, shuffle=True, random_state=42): """ Divide este dataset em (train_dataset, test_dataset), preservando o __getitem__ e toda a lógica atual. Retorna Subset(self, indices). Parâmetros ---------- test_size : float|int Igual ao do sklearn: fração (0,1] ou número absoluto de amostras no teste. shuffle : bool Se True, embaralha antes de dividir. random_state : int Semente para reprodutibilidade. Retorna ------- (train_subset, test_subset) : (torch.utils.data.Subset, torch.utils.data.Subset) """ train_idx, test_idx = _split_indices( len(self), test_size=test_size, shuffle=shuffle, random_state=random_state # stratify=None # adicione aqui se precisar balancear por classe ) # Convertemos para listas para evitar problemas com Subset em algumas versões train_subset = torch.utils.data.Subset(self, train_idx.tolist()) test_subset = torch.utils.data.Subset(self, test_idx.tolist()) return train_subset, test_subset
[docs] class PairedImageDataset(Dataset): """ Lê pares de imagens no formato: 01_in.jpg, 01_target.jpg, 02_in.jpg, 02_target.jpg, ... Parâmetros ---------- root_dir : str Diretório contendo as imagens. numRows, numCols : int | None Tamanho desejado (aplica apenas se ambos definidos). grayscale_in : bool Se True, INPUT é carregado em escala de cinza (1 canal). grayscale_target : bool Se True, TARGET é carregado em escala de cinza (1 canal). invert_in : bool Se True, aplica negativo (255 - img) **antes** da normalização na entrada. invert_target : bool Se True, aplica negativo (255 - img) **antes** da normalização no alvo. extensions : tuple[str, ...] Extensões suportadas. dtype : torch.dtype Tipo dos tensores de saída (padrão float32). scale_0_1 : bool Se True, normaliza para [0, 1]; caso contrário, mantém [0,255] (float). prefix_in, prefix_target, suffix_in, suffix_target : str Regras de nome para localizar pares. """ def __init__( self, root_dir: str, numRows: int | None = None, numCols: int | None = None, *, grayscale_in: bool = True, grayscale_target: bool = True, invert_in: bool = False, invert_target: bool = False, extensions: tuple = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"), dtype: torch.dtype = torch.float32, scale_in: bool = True, scale_out: bool = True, prefix_in: str = "", prefix_target: str = "", suffix_in: str = "_in", suffix_target: str = "_target", ): self.root_dir = root_dir self.numRows = numRows self.numCols = numCols self.grayscale_in = bool(grayscale_in) self.grayscale_target = bool(grayscale_target) self.invert_in = bool(invert_in) self.invert_target = bool(invert_target) self.extensions = tuple(e.lower() for e in extensions) self.dtype = dtype self.scale_in = bool(scale_in) self.scale_out = bool(scale_out) self.prefix_in = prefix_in self.prefix_target = prefix_target self.suffix_in = suffix_in self.suffix_target = suffix_target if (self.numRows is None) ^ (self.numCols is None): print("[PairedImageDataset] Aviso: numRows e numCols devem ser ambos None ou ambos definidos. " "Resize será ignorado porque apenas um foi fornecido.") self.numRows = None self.numCols = None self.pairs = self._scan_pairs() if not self.pairs: raise RuntimeError( f"Nenhum par *_in / *_target encontrado em {root_dir} " f"com extensões {self.extensions}." ) # --------- API Dataset --------- def __len__(self): return len(self.pairs) def __getitem__(self, idx: int): in_path, tgt_path = self.pairs[idx] # Leitura com modo de cor independente img_in = self._read_image(in_path, self.grayscale_in) img_tg = self._read_image(tgt_path, self.grayscale_target) # Invertemos apenas conforme os flags correspondentes if self.invert_in: img_in = self._invert(img_in) if self.invert_target: img_tg = self._invert(img_tg) # Redimensionamento (se definido) if self.numRows is not None and self.numCols is not None: img_in = cv2.resize(img_in, (self.numCols, self.numRows), interpolation=cv2.INTER_AREA) img_tg = cv2.resize(img_tg, (self.numCols, self.numRows), interpolation=cv2.INTER_NEAREST) # Conversão para tensor (mantendo flags separados) tin = self._to_tensor(img_in, self.grayscale_in) ttg = self._to_tensor(img_tg, self.grayscale_target) # Normalização if self.scale_in: tin = tin / 255.0 if self.scale_out: ttg = ttg / 255.0 return tin, ttg, os.path.basename(in_path) # --------- Helpers --------- def _scan_pairs(self): pairs = [] suffix_in_len = len(self.suffix_in) for ext in self.extensions: pattern = os.path.join(self.root_dir, f"{self.prefix_in}*{self.suffix_in}{ext}") for fpath in glob.glob(pattern): base = os.path.basename(fpath) if not base.startswith(self.prefix_in) or not base.endswith(self.suffix_in + ext): continue stem = base[len(self.prefix_in):-suffix_in_len - len(ext)] target_path = None for ext_t in self.extensions: candidate = f"{self.prefix_target}{stem}{self.suffix_target}{ext_t}" cand = os.path.join(self.root_dir, candidate) if os.path.exists(cand): target_path = cand break if target_path is not None: pairs.append((fpath, target_path)) pairs.sort(key=lambda p: p[0]) return pairs def _read_image(self, path: str, grayscale_flag: bool) -> np.ndarray: """Lê com cv2 e retorna ndarray (H,W) se grayscale_flag, ou (H,W,3) RGB se color.""" if grayscale_flag: img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if img is None: raise RuntimeError(f"Falha ao ler (grayscale) {path}") return img else: img = cv2.imread(path, cv2.IMREAD_COLOR) if img is None: raise RuntimeError(f"Falha ao ler (color) {path}") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img @staticmethod def _invert(img: np.ndarray) -> np.ndarray: if img.dtype != np.uint8: img_u8 = np.clip(img, 0, 255).astype(np.uint8) return 255 - img_u8 return 255 - img def _to_tensor(self, img: np.ndarray, grayscale_flag: bool) -> torch.Tensor: """Converte ndarray para tensor (C,H,W): (1,H,W) se grayscale, senão (3,H,W).""" if grayscale_flag: tensor = torch.from_numpy(img).unsqueeze(0) # (1,H,W) else: tensor = torch.from_numpy(img).permute(2, 0, 1) # (3,H,W) tensor = tensor.to(self.dtype) return tensor
[docs] def train_test_split(self, test_size=0.25, shuffle=True, random_state=42): train_idx, test_idx = _split_indices( len(self), test_size=test_size, shuffle=shuffle, random_state=random_state ) return ( torch.utils.data.Subset(self, train_idx.tolist()), torch.utils.data.Subset(self, test_idx.tolist()), )