Source code for myria3d.pctl.transforms.compose

from typing import Callable, List

from torch_geometric.transforms import BaseTransform


class CustomCompose(BaseTransform):
    """
    Composes several transforms together.
    Edited to bypass downstream transforms if None is returned by a transform.

    ``data`` may be a single sample or a list/tuple of samples. For a
    list/tuple, each transform is applied to every sample, and samples that
    come back as ``None`` or with zero nodes are dropped. A transform applied
    to a single sample may also return a list/tuple of samples. That output is
    filtered the same way and then processed sample by sample by the remaining
    transforms.

    Args:
        transforms (List[Callable]): List of transforms to compose.

    Returns (from ``__call__``):
        The transformed sample. When the input or an intermediate result is a
        list/tuple, the result is a list of the surviving samples. Returns
        ``None`` as soon as nothing survives a transform, and the remaining
        transforms are skipped.
    """

    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    @staticmethod
    def _is_discarded(sample) -> bool:
        """Return True if ``sample`` is None or has no nodes and must be dropped."""
        return sample is None or sample.num_nodes == 0

    def __call__(self, data):
        for transform in self.transforms:
            if isinstance(data, (list, tuple)):
                data = [transform(d) for d in data]
            else:
                data = transform(data)

            # Check the shape of the *result*, not the input. A transform that
            # splits one sample into several returns a list, and calling
            # `.num_nodes` on that list would raise AttributeError.
            if isinstance(data, (list, tuple)):
                data = [d for d in data if not self._is_discarded(d)]
                if not data:
                    return None
            elif self._is_discarded(data):
                # Short-circuit: skip all downstream transforms.
                return None
        return data