Source code for rastervision.core.data.raster_transformer.stats_transformer
from typing import TYPE_CHECKING, Sequence
import numpy as np
from rastervision.core.data.raster_transformer import RasterTransformer
from rastervision.core.raster_stats import RasterStats
from rastervision.pipeline.utils import repr_with_args
if TYPE_CHECKING:
from typing import Self
from rastervision.core.data import RasterSource
[docs]class StatsTransformer(RasterTransformer):
"""Transforms non-uint8 to uint8 values using channel statistics.
This works as follows:
- Convert pixel values to z-scores using channel means and standard
deviations.
- Clip z-scores to the specified number of standard deviations (default 3)
on each side.
- Scale values to 0-255 and cast to uint8.
This transformation is not applied to NODATA pixels (assumed to be pixels
with all values equal to zero).
"""
[docs] def __init__(self,
means: Sequence[float],
stds: Sequence[float],
max_stds: float = 3.):
"""Construct a new StatsTransformer.
Args:
means: Channel means.
means: Channel standard deviations.
max_stds: Number of standard deviations to clip the distribution to
on both sides. Defaults to 3.
"""
# shape = (1, 1, num_channels)
self.means = np.array(means, dtype=float)
self.stds = np.array(stds, dtype=float)
self.max_stds = max_stds
[docs] def transform(self, chip: np.ndarray) -> np.ndarray:
"""Clip values to +-max_stds std devs and convert to uint8 (0-255).
Args:
chip: Array of shape (..., H, W, C).
Returns:
Array of shape (..., H, W, C)
"""
if chip.dtype == np.uint8:
return chip
means = self.means
stds = self.stds
max_stds = self.max_stds
# Don't transform NODATA zero values.
nodata_mask = chip == 0
# Subtract mean and divide by std to get zscores.
chip = chip.astype(float)
chip -= means
chip /= stds
# Make zscores that fall between -max_stds and max_stds span 0 to 255.
# range: (-max_stds, max_stds)
chip = np.clip(chip, -max_stds, max_stds, out=chip)
# range: [0, 2 * max_stds]
chip += max_stds
# range: [0, 1]
chip /= (2 * max_stds)
# range: [0, 255]
chip *= 255
chip = chip.astype(np.uint8)
chip[nodata_mask] = 0
return chip
[docs] @classmethod
def from_raster_sources(cls,
raster_sources: list['RasterSource'],
sample_prob: float | None = 0.1,
max_stds: float = 3.,
chip_sz: int = 300) -> 'Self':
"""Build with stats from the given raster sources.
Args:
raster_sources (list[RasterSource]): List of raster sources to
compute stats from.
sample_prob (float): Fraction of each raster to sample
for computing stats. For details see docs for
RasterStats.compute(). Defaults to 0.1.
max_stds (float): Number of standard deviations to clip
the distribution to on both sides. Defaults to 3.
Returns:
StatsTransformer: A StatsTransformer.
"""
stats = RasterStats()
stats.compute(
raster_sources=raster_sources,
sample_prob=sample_prob,
chip_sz=chip_sz)
stats_transformer = StatsTransformer.from_raster_stats(
stats, max_stds=max_stds)
return stats_transformer
[docs] @classmethod
def from_stats_json(cls,
uri: str,
channel_order: list[int] | None = None,
**kwargs) -> 'Self':
"""Build with stats from a JSON file.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
uri: URI of the JSON file.
channel_order: Channel order to apply to the means and stds in the
file.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
A StatsTransformer.
"""
stats = RasterStats.load(uri)
stats_transformer = StatsTransformer.from_raster_stats(
stats, channel_order=channel_order, **kwargs)
return stats_transformer
[docs] @classmethod
def from_raster_stats(cls,
stats: RasterStats,
channel_order: list[int] | None = None,
**kwargs) -> 'Self':
"""Build with stats from a :class:`.RasterStats` instance.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
stats: A :class:`.RasterStats` instance with non-None stats.
channel_order: Channel order to apply to the means and stds in the
:class:`.RasterStats`.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
A StatsTransformer.
"""
means, stds = stats.means, stats.stds
if channel_order is not None:
means = means[channel_order]
stds = stds[channel_order]
stats_transformer = StatsTransformer(means, stds, **kwargs)
return stats_transformer
@property
def stats(self) -> RasterStats:
"""Current statistics as a :class:`.RasterStats` instance."""
return RasterStats(self.means, self.stds)
def __repr__(self) -> str:
return repr_with_args(
self, means=self.means, stds=self.stds, max_stds=self.max_stds)