Source code for rastervision.core.data.raster_transformer.stats_transformer

from typing import TYPE_CHECKING, List, Optional, Sequence

import numpy as np

from rastervision.core.data.raster_transformer import RasterTransformer
from rastervision.core.raster_stats import RasterStats
from rastervision.pipeline.utils import repr_with_args

if TYPE_CHECKING:
    from rastervision.core.data import RasterSource


class StatsTransformer(RasterTransformer):
    """Transforms non-uint8 to uint8 values using channel statistics.

    This works as follows:

    - Convert pixel values to z-scores using channel means and standard
      deviations.
    - Clip z-scores to the specified number of standard deviations
      (default 3) on each side.
    - Scale values to 0-255 and cast to uint8.

    This transformation is not applied to NODATA values, which are taken to
    be zeros in the input: every element that is exactly zero in the input
    chip is zero in the output chip.
    """

    def __init__(self,
                 means: Sequence[float],
                 stds: Sequence[float],
                 max_stds: float = 3.):
        """Construct a new StatsTransformer.

        Args:
            means (Sequence[float]): Channel means.
            stds (Sequence[float]): Channel standard deviations.
            max_stds (float): Number of standard deviations to clip the
                distribution to on both sides. Defaults to 3.
        """
        # Stored as 1-D float arrays of shape (num_channels,) so that they
        # broadcast against the trailing (channel) axis of a chip.
        self.means = np.array(means, dtype=float)
        self.stds = np.array(stds, dtype=float)
        self.max_stds = max_stds

    def transform(self,
                  chip: np.ndarray,
                  channel_order: Optional[Sequence[int]] = None
                  ) -> np.ndarray:
        """Transform a chip.

        Transforms non-uint8 to uint8 values using the stored channel means
        and standard deviations. A chip that is already uint8 is returned
        unchanged (the same object, not a copy).

        Args:
            chip: ndarray of shape [height, width, channels]. This is assumed
                to already have the channel_order applied to it if
                channel_order is set. In other words, channels should be
                equal to len(channel_order).
            channel_order: list of indices of channels that were extracted
                from the raw imagery. When given, it is used to select the
                matching entries of the stored means and stds.

        Returns:
            [height, width, channels] uint8 numpy array
        """
        if chip.dtype == np.uint8:
            return chip

        means = self.means
        stds = self.stds
        max_stds = self.max_stds
        if channel_order is not None:
            # Pick the stats of the channels actually present in the chip.
            means = means[channel_order]
            stds = stds[channel_order]

        # Don't transform NODATA zero values. The mask is element-wise, so it
        # is computed on the raw values before any arithmetic is applied.
        # NOTE(review): this masks every individual zero, not only pixels
        # whose channels are all zero -- confirm this is intended.
        nodata_mask = chip == 0

        # Subtract mean and divide by std to get zscores. astype() returns a
        # new float array, so the caller's chip is never modified in place.
        # NOTE(review): a channel with std == 0 produces inf/nan here.
        chip = chip.astype(float)
        chip -= means
        chip /= stds

        # Make zscores that fall between -max_stds and max_stds span 0 to 255.
        # range: [-max_stds, max_stds]
        chip = np.clip(chip, -max_stds, max_stds, out=chip)
        # range: [0, 2 * max_stds]
        chip += max_stds
        # range: [0, 1]
        chip /= (2 * max_stds)
        # range: [0, 255]
        chip *= 255
        chip = chip.astype(np.uint8)

        # Restore NODATA values.
        chip[nodata_mask] = 0

        return chip

    @classmethod
    def from_raster_sources(cls,
                            raster_sources: List['RasterSource'],
                            sample_prob: Optional[float] = 0.1,
                            max_stds: float = 3.,
                            chip_sz: int = 300) -> 'StatsTransformer':
        """Build with stats from the given raster sources.

        Args:
            raster_sources (List[RasterSource]): List of raster sources to
                compute stats from.
            sample_prob (float, optional): Fraction of each raster to sample
                for computing stats. For details see docs for
                RasterStats.compute(). Defaults to 0.1.
            max_stds (float, optional): Number of standard deviations to clip
                the distribution to on both sides. Defaults to 3.
            chip_sz (int, optional): Passed through unchanged to
                RasterStats.compute(); see its docs for details.
                Defaults to 300.

        Returns:
            StatsTransformer: An instance of the class this method was
            called on.
        """
        stats = RasterStats()
        stats.compute(
            raster_sources=raster_sources,
            sample_prob=sample_prob,
            chip_sz=chip_sz)
        # Go through cls (not the hard-coded base class) so that subclasses
        # get an instance of their own type.
        return cls.from_raster_stats(stats, max_stds=max_stds)

    @classmethod
    def from_stats_json(cls, uri: str, **kwargs) -> 'StatsTransformer':
        """Build with stats from a JSON file.

        The file is expected to be in the same format as written by
        :meth:`.RasterStats.save`.

        Args:
            uri (str): URI of the JSON file.
            **kwargs: Extra args for :meth:`.__init__`.

        Returns:
            StatsTransformer: An instance of the class this method was
            called on.
        """
        stats = RasterStats.load(uri)
        return cls.from_raster_stats(stats, **kwargs)

    @classmethod
    def from_raster_stats(cls, stats: RasterStats,
                          **kwargs) -> 'StatsTransformer':
        """Build with stats from a :class:`.RasterStats` instance.

        Args:
            stats (RasterStats): A :class:`.RasterStats` instance with
                non-None stats. Its ``means`` and ``stds`` attributes are
                passed to :meth:`.__init__`.
            **kwargs: Extra args for :meth:`.__init__`.

        Returns:
            StatsTransformer: An instance of the class this method was
            called on.
        """
        return cls(stats.means, stats.stds, **kwargs)

    @property
    def stats(self) -> RasterStats:
        """Current statistics as a :class:`.RasterStats` instance."""
        return RasterStats(self.means, self.stds)

    def __repr__(self) -> str:
        return repr_with_args(
            self, means=self.means, stds=self.stds, max_stds=self.max_stds)