Source code for run

try:
    # It is safer to import comet before all other imports.
    import comet_ml  # noqa
except ImportError:
    print(
        "Warning: package comet_ml not found. This may break things if you use a comet callback."
    )

from enum import Enum

import os
import sys
from glob import glob
import dotenv
import hydra
from omegaconf import DictConfig
from tqdm import tqdm

from myria3d.utils import utils
from myria3d.pctl.dataset.hdf5 import create_hdf5
from myria3d.pctl.dataset.utils import get_las_paths_by_split_dict

# Substring searched for in the command-line arguments to find out which task was requested.
TASK_NAME_DETECTION_STRING = "task.task_name="
# Directory used as the hydra config path of `launch_predict`; DEFAULT_ENV is looked up in it too.
DEFAULT_DIRECTORY = "trained_model_assets/"
# Name of the hydra config file used by `launch_predict`.
DEFAULT_CONFIG_FILE = "proto151_V2.0_epoch_100_Myria3DV3.1.0_predict_config_V3.7.0.yaml"
# NOTE(review): presumably the checkpoint that goes with DEFAULT_CONFIG_FILE; it is not
# referenced in the visible part of this module - confirm where it is used.
DEFAULT_CHECKPOINT = "proto151_V2.0_epoch_100_Myria3DV3.1.0.ckpt"
# Name of the dotenv file, inside DEFAULT_DIRECTORY, that is loaded before the predict task runs.
DEFAULT_ENV = "placeholder.env"


class TASK_NAMES(Enum):
    """Names of the tasks that this entry point knows how to launch.

    The value of each member is the string that is compared with the task name
    found on the command line (``task.task_name=<value>``).
    """

    # Tasks dispatched to `launch_train`.
    FIT = "fit"
    TEST = "test"
    FINETUNE = "finetune"

    # Task dispatched to `launch_predict`.
    PREDICT = "predict"

    # Task dispatched to `launch_hdf5`.
    HDF5 = "create_hdf5"
# Module-level logger, obtained through the project's logging helper.
log = utils.get_logger(__name__)

# Name of the default task: the string value of TASK_NAMES.FIT.
DEFAULT_TASK = TASK_NAMES.FIT.value
@hydra.main(config_path="configs/", config_name="config.yaml")
def launch_train(config: DictConfig):  # pragma: no cover (it's just an initializer of a class/method tested elsewhere)
    """Training, evaluation, testing, or finetuning of a neural network.

    Hydra entry point: the configuration is composed by hydra from ``configs/config.yaml``
    and handed to ``myria3d.train.train``.

    Args:
        config: configuration composed by hydra. The optional ``print_config`` key
            controls whether the configuration is printed before training starts.

    Returns:
        Whatever ``train(config)`` returns.

    """
    # Imports should be nested inside @hydra.main to optimize tab completion
    # Read more here: https://github.com/facebookresearch/hydra/issues/934
    from myria3d.train import train

    # Project helper applied to the config before anything else (see myria3d.utils.utils).
    utils.extras(config)

    # Pretty print config using Rich library
    wants_config_printed = config.get("print_config")
    if wants_config_printed:
        utils.print_config(config, resolve=False)

    outcome = train(config)
    return outcome
@hydra.main(config_path=DEFAULT_DIRECTORY, config_name=DEFAULT_CONFIG_FILE)
def launch_predict(config: DictConfig):
    """Infer probabilities and automate semantic segmentation decisions on unseen data.

    ``config.predict.src_las`` is used as a glob pattern. Each matching file is predicted
    in turn: ``config.predict.src_las`` is overwritten with the path of the file, then
    ``predict(config)`` is called. A warning is logged when the pattern matches no file.

    Args:
        config: configuration composed by hydra. Reads ``predict.ckpt_path``,
            ``predict.src_las`` and the optional ``print_config`` key.

    """
    # Imports should be nested inside @hydra.main to optimize tab completion
    # Read more here: https://github.com/facebookresearch/hydra/issues/934
    from myria3d.predict import predict

    # hydra changes current directory, so we make sure the checkpoint has an absolute path
    if not os.path.isabs(config.predict.ckpt_path):
        config.predict.ckpt_path = os.path.join(
            os.path.dirname(__file__), config.predict.ckpt_path
        )

    # Pretty print config using Rich library
    if config.get("print_config"):
        utils.print_config(config, resolve=False)

    src_las_pattern = config.predict.src_las
    src_las_paths = glob(src_las_pattern)
    if not src_las_paths:
        # Without this message, a wrong path or pattern makes the task end silently
        # without predicting anything.
        log.warning(f"No file matches '{src_las_pattern}': nothing to predict.")
        return

    # Iterate over the files and predict.
    for src_las in tqdm(src_las_paths):
        # predict() reads the file to process from the config itself.
        config.predict.src_las = src_las
        predict(config)
@hydra.main(config_path="configs/", config_name="config.yaml")
def launch_hdf5(config: DictConfig):
    """Build an HDF5 file from a directory with las files.

    Every parameter is read from the ``datamodule`` section of the configuration and
    forwarded to ``create_hdf5``.

    Args:
        config: configuration composed by hydra. Reads the optional ``print_config`` key
            and, under ``datamodule``: ``data_dir``, ``split_csv_path``, ``hdf5_file_path``,
            ``epsg``, ``tile_width``, ``subtile_width``, ``pre_filter``,
            ``subtile_overlap_train`` and ``points_pre_transform``.

    """
    # Pretty print config using Rich library
    if config.get("print_config"):
        utils.print_config(config, resolve=False)

    datamodule_cfg = config.datamodule

    # Las files to include, grouped by split.
    split_to_las_paths = get_las_paths_by_split_dict(
        datamodule_cfg.get("data_dir"), datamodule_cfg.get("split_csv_path")
    )

    # These two entries are config nodes that hydra turns into objects.
    pre_filter = hydra.utils.instantiate(datamodule_cfg.get("pre_filter"))
    points_pre_transform = hydra.utils.instantiate(datamodule_cfg.get("points_pre_transform"))

    create_hdf5(
        las_paths_by_split_dict=split_to_las_paths,
        hdf5_file_path=datamodule_cfg.get("hdf5_file_path"),
        epsg=datamodule_cfg.get("epsg"),
        tile_width=datamodule_cfg.get("tile_width"),
        subtile_width=datamodule_cfg.get("subtile_width"),
        pre_filter=pre_filter,
        subtile_overlap_train=datamodule_cfg.get("subtile_overlap_train"),
        points_pre_transform=points_pre_transform,
    )
if __name__ == "__main__":
    # The task has to be known before hydra runs, because it decides which hydra entry
    # point is launched. So the command line is scanned here for "task.task_name=<name>".
    task_name = DEFAULT_TASK
    for arg in sys.argv:
        if TASK_NAME_DETECTION_STRING in arg:
            # Keep everything that follows the detection string. Unlike unpacking
            # arg.split("="), this cannot fail when the argument holds another "=".
            task_name = arg.partition(TASK_NAME_DETECTION_STRING)[2]
            break

    log.info(f"Task: {task_name}")

    if task_name in [TASK_NAMES.FIT.value, TASK_NAMES.TEST.value, TASK_NAMES.FINETUNE.value]:
        # load environment variables from `.env` file if it exists
        # recursively searches for `.env` in all folders starting from work dir
        dotenv.load_dotenv(override=True)
        launch_train()

    elif task_name == TASK_NAMES.PREDICT.value:
        # Prediction uses the dotenv file stored next to the default config.
        dotenv.load_dotenv(os.path.join(DEFAULT_DIRECTORY, DEFAULT_ENV))
        launch_predict()

    elif task_name == TASK_NAMES.HDF5.value:
        launch_hdf5()

    else:
        choices = ", ".join(task.value for task in TASK_NAMES)
        raise ValueError(
            f"Task '{task_name}' is not known. Specify a valid task name via task.task_name. Valid choices are: {choices}"
        )