Source code for rastervision.core.data.raster_transformer.cast_transformer

from typing import Optional

from rastervision.core.data.raster_transformer.raster_transformer \
    import RasterTransformer
from rastervision.pipeline.utils import repr_with_args

import numpy as np


[docs]class CastTransformer(RasterTransformer): """Casts chips to the specified dtype."""
[docs] def __init__(self, to_dtype: str): """Constructor. Args: to_dtype: (str) dtype to cast the chips to. """ self.to_dtype = np.dtype(to_dtype)
def __repr__(self): return repr_with_args(self, to_dtype=str(self.to_dtype))
[docs] def transform(self, chip: np.ndarray, channel_order: Optional[list] = None) -> np.ndarray: """Cast chip to self.to_dtype. Args: chip: ndarray of shape [height, width, channels] Returns: [height, width, channels] numpy array """ return chip.astype(self.to_dtype)