Source code for run

try:
    # It is safer to import comet before all other imports.
    import comet_ml  # noqa
except ImportError:
    print(
        "Warning: package comet_ml not found. This may break things if you use a comet callback."
    )

import os
import sys
from enum import Enum
from glob import glob

import dotenv
import hydra
from omegaconf import DictConfig
from tqdm import tqdm

from myria3d.pctl.dataset.hdf5 import create_hdf5
from myria3d.pctl.dataset.utils import get_las_paths_by_split_dict
from myria3d.utils import utils

# Substring looked for in each command-line argument to find the requested task.
TASK_NAME_DETECTION_STRING = "task.task_name="
# Directory used as hydra `config_path` by `launch_predict`, and where the
# prediction `.env` file is looked up.
DEFAULT_DIRECTORY = "trained_model_assets/"
# Hydra `config_name` used by `launch_predict`.
DEFAULT_CONFIG_FILE = "FRACTAL-LidarHD_7cl_randlanet-inference-Myria3DV3.8.yaml"
# NOTE(review): not referenced in this module; presumably the checkpoint shipped
# next to DEFAULT_CONFIG_FILE - confirm against the packaged assets.
DEFAULT_CHECKPOINT = "FRACTAL-LidarHD_7cl_randlanet.ckpt"
# File name, joined to DEFAULT_DIRECTORY, of the env file loaded for the predict task.
DEFAULT_ENV = "placeholder.env"


class TASK_NAMES(Enum):
    """Task names accepted through the `task.task_name=<value>` command-line argument.

    The member values are the strings compared against the command line; the
    member order is the order in which valid choices are listed in error messages.
    """

    # Tasks dispatched to `launch_train`.
    FIT = "fit"
    TEST = "test"
    FINETUNE = "finetune"
    # Task dispatched to `launch_predict`.
    PREDICT = "predict"
    # Task dispatched to `launch_hdf5`.
    HDF5 = "create_hdf5"
# String value of the default task ("fit").
DEFAULT_TASK = TASK_NAMES.FIT.value

# Module-level logger obtained from the project's logging helper.
log = utils.get_logger(__name__)
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.1")
def launch_train(
    config: DictConfig,
):  # pragma: no cover (it's just an initializer of a class/method tested elsewhere)
    """Entry point for the fit, test and finetune tasks.

    Args:
        config: hydra configuration composed from `configs/config.yaml` and any
            command-line overrides.

    Returns:
        Whatever `myria3d.train.train` returns for this configuration.
    """
    # Project imports stay inside the hydra entry point to keep tab completion fast.
    # See https://github.com/facebookresearch/hydra/issues/934
    from myria3d.train import train

    # Hand the configuration to the project's `extras` hook before doing anything else.
    utils.extras(config)

    # Optionally display the configuration, with interpolations left unresolved.
    wants_config_dump = config.get("print_config")
    if wants_config_dump:
        utils.print_config(config, resolve=False)

    outcome = train(config)
    return outcome
@hydra.main(config_path=DEFAULT_DIRECTORY, config_name=DEFAULT_CONFIG_FILE, version_base="1.1")
def launch_predict(config: DictConfig):
    """Infer probabilities and automate semantic segmentation decisions on unseen data.

    `config.predict.src_las` is treated as a glob pattern: `predict` is called once
    per matching file, with `config.predict.src_las` set to that file's path. A
    warning is logged when the pattern matches no file.

    Args:
        config: hydra configuration. Reads `predict.ckpt_path`, `predict.src_las`
            and the optional `print_config` flag; `predict.ckpt_path` and
            `predict.src_las` are modified in place.
    """
    # Imports should be nested inside @hydra.main to optimize tab completion
    # Read more here: https://github.com/facebookresearch/hydra/issues/934
    from myria3d.predict import predict

    # hydra changes the current directory, so a relative checkpoint path is
    # anchored to the directory of this file to make it absolute.
    if not os.path.isabs(config.predict.ckpt_path):
        config.predict.ckpt_path = os.path.join(
            os.path.dirname(__file__), config.predict.ckpt_path
        )

    # Optionally display the configuration, with interpolations left unresolved.
    if config.get("print_config"):
        utils.print_config(config, resolve=False)

    # Expand the pattern once, before the config entry is overwritten below.
    src_las_pattern = config.predict.src_las
    src_las_paths = glob(src_las_pattern)
    if not src_las_paths:
        # Without this message, a typo in the pattern makes the task a silent no-op.
        log.warning(f"No file matches predict.src_las={src_las_pattern!r}: nothing to predict.")

    # Predict each file in turn; `predict` reads the current path from the config.
    for src_las_path in tqdm(src_las_paths):
        config.predict.src_las = src_las_path
        predict(config)
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.1")
def launch_hdf5(config: DictConfig):
    """Build an HDF5 file from a directory with las files.

    Every setting is read from the `datamodule` section of the configuration and
    forwarded to `create_hdf5`.

    Args:
        config: hydra configuration composed from `configs/config.yaml` and any
            command-line overrides.
    """
    # Optionally display the configuration, with interpolations left unresolved.
    if config.get("print_config"):
        utils.print_config(config, resolve=False)

    datamodule_cfg = config.datamodule

    # LAS paths by split, from the data directory and the split CSV.
    paths_by_split = get_las_paths_by_split_dict(
        datamodule_cfg.get("data_dir"), datamodule_cfg.get("split_csv_path")
    )

    # Settings are gathered in the same order as they are passed to `create_hdf5`.
    hdf5_path = datamodule_cfg.get("hdf5_file_path")
    epsg = datamodule_cfg.get("epsg")
    tile_width = datamodule_cfg.get("tile_width")
    subtile_width = datamodule_cfg.get("subtile_width")
    # `pre_filter` and `points_pre_transform` are config nodes instantiated by hydra.
    pre_filter = hydra.utils.instantiate(datamodule_cfg.get("pre_filter"))
    overlap_train = datamodule_cfg.get("subtile_overlap_train")
    pre_transform = hydra.utils.instantiate(datamodule_cfg.get("points_pre_transform"))

    create_hdf5(
        las_paths_by_split_dict=paths_by_split,
        hdf5_file_path=hdf5_path,
        epsg=epsg,
        tile_width=tile_width,
        subtile_width=subtile_width,
        pre_filter=pre_filter,
        subtile_overlap_train=overlap_train,
        points_pre_transform=pre_transform,
    )
if __name__ == "__main__":
    # Scan the raw command line for a `task.task_name=<value>` argument to select
    # which entry point to run. Fall back to the default task when none is given.
    task_name = DEFAULT_TASK
    for arg in sys.argv:
        if TASK_NAME_DETECTION_STRING in arg:
            # Keep everything after the first "task.task_name=". A value that itself
            # contains "=" then reaches the "not known" error below instead of
            # failing on tuple unpacking.
            _, _, task_name = arg.partition(TASK_NAME_DETECTION_STRING)
            break

    log.info(f"Task: {task_name}")

    if task_name in [TASK_NAMES.FIT.value, TASK_NAMES.TEST.value, TASK_NAMES.FINETUNE.value]:
        # Load environment variables from a `.env` file if one is found;
        # override=True lets its values replace variables that are already set.
        dotenv.load_dotenv(override=True)
        launch_train()

    elif task_name == TASK_NAMES.PREDICT.value:
        # Prediction uses the env file stored with the default model assets.
        dotenv.load_dotenv(os.path.join(DEFAULT_DIRECTORY, DEFAULT_ENV))
        launch_predict()

    elif task_name == TASK_NAMES.HDF5.value:
        launch_hdf5()

    else:
        choices = ", ".join(task.value for task in TASK_NAMES)
        raise ValueError(
            f"Task '{task_name}' is not known. Specify a valid task name via task.task_name. "
            f"Valid choices are: {choices}"
        )