try:
# It is safer to import comet before all other imports.
import comet_ml # noqa
except ImportError:
print(
"Warning: package comet_ml not found. This may break things if you use a comet callback."
)
import os
import sys
from enum import Enum
from glob import glob
import dotenv
import hydra
from omegaconf import DictConfig
from tqdm import tqdm
from myria3d.pctl.dataset.hdf5 import create_hdf5
from myria3d.pctl.dataset.utils import get_las_paths_by_split_dict
from myria3d.utils import utils
TASK_NAME_DETECTION_STRING = "task.task_name="
DEFAULT_DIRECTORY = "trained_model_assets/"
DEFAULT_CONFIG_FILE = "FRACTAL-LidarHD_7cl_randlanet-inference-Myria3DV3.8.yaml"
DEFAULT_CHECKPOINT = "FRACTAL-LidarHD_7cl_randlanet.ckpt"
DEFAULT_ENV = "placeholder.env"
[docs]
class TASK_NAMES(Enum):
FIT = "fit"
TEST = "test"
FINETUNE = "finetune"
PREDICT = "predict"
HDF5 = "create_hdf5"
DEFAULT_TASK = TASK_NAMES.FIT.value
log = utils.get_logger(__name__)
[docs]
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.1")
def launch_train(
config: DictConfig,
): # pragma: no cover (it's just an initialyzer of a class/method tested elsewhere)
"""Training, evaluation, testing, or finetuning of a neural network."""
# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from myria3d.train import train
utils.extras(config)
# Pretty print config using Rich library
if config.get("print_config"):
utils.print_config(config, resolve=False)
return train(config)
[docs]
@hydra.main(config_path=DEFAULT_DIRECTORY, config_name=DEFAULT_CONFIG_FILE, version_base="1.1")
def launch_predict(config: DictConfig):
"""Infer probabilities and automate semantic segmentation decisions on unseen data."""
# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from myria3d.predict import predict
# hydra changes current directory, so we make sure the checkpoint has an absolute path
if not os.path.isabs(config.predict.ckpt_path):
config.predict.ckpt_path = os.path.join(
os.path.dirname(__file__), config.predict.ckpt_path
)
# Pretty print config using Rich library
if config.get("print_config"):
utils.print_config(config, resolve=False)
# Iterate over the files and predict.
src_las_iterable = glob(config.predict.src_las)
for config.predict.src_las in tqdm(src_las_iterable):
predict(config)
[docs]
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.1")
def launch_hdf5(config: DictConfig):
"""Build an HDF5 file from a directory with las files."""
# Pretty print config using Rich library
if config.get("print_config"):
utils.print_config(config, resolve=False)
las_paths_by_split_dict = get_las_paths_by_split_dict(
config.datamodule.get("data_dir"), config.datamodule.get("split_csv_path")
)
create_hdf5(
las_paths_by_split_dict=las_paths_by_split_dict,
hdf5_file_path=config.datamodule.get("hdf5_file_path"),
epsg=config.datamodule.get("epsg"),
tile_width=config.datamodule.get("tile_width"),
subtile_width=config.datamodule.get("subtile_width"),
pre_filter=hydra.utils.instantiate(config.datamodule.get("pre_filter")),
subtile_overlap_train=config.datamodule.get("subtile_overlap_train"),
points_pre_transform=hydra.utils.instantiate(
config.datamodule.get("points_pre_transform")
),
)
if __name__ == "__main__":
task_name = "fit"
for arg in sys.argv:
if TASK_NAME_DETECTION_STRING in arg:
_, task_name = arg.split("=")
break
log.info(f"Task: {task_name}")
if task_name in [TASK_NAMES.FIT.value, TASK_NAMES.TEST.value, TASK_NAMES.FINETUNE.value]:
# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)
launch_train()
elif task_name == TASK_NAMES.PREDICT.value:
dotenv.load_dotenv(os.path.join(DEFAULT_DIRECTORY, DEFAULT_ENV))
launch_predict()
elif task_name == TASK_NAMES.HDF5.value:
launch_hdf5()
else:
choices = ", ".join(task.value for task in TASK_NAMES)
raise ValueError(
f"Task '{task_name}' is not known. Specify a valid task name via task.task_name. Valid choices are: {choices})"
)