Source code for rastervision.core.data.raster_transformer.nan_transformer

import numpy as np

from rastervision.core.data.raster_transformer.raster_transformer \
    import RasterTransformer


class NanTransformer(RasterTransformer):
    """Raster transformer that substitutes a fixed value for NaN entries."""

    def __init__(self, to_value: float = 0.0):
        """Constructor.

        Args:
            to_value: Value written in place of every NaN entry.
        """
        self.to_value = to_value

    def transform(self, chip):
        """Replace NaN entries of the chip with ``self.to_value``.

        The replacement is written into ``chip`` itself, and that same
        array object is returned.

        Args:
            chip: Array of shape (..., H, W, C).

        Returns:
            Array of shape (..., H, W, C)
        """
        # Boolean-mask assignment touches only the NaN positions.
        chip[np.isnan(chip)] = self.to_value
        return chip