Source code for rastervision.core.raster_stats

from typing import (TYPE_CHECKING, Iterable, Iterator, Optional, Sequence,
                    Tuple, Union)

import numpy as np
from import tqdm

from rastervision.pipeline.utils import repr_with_args
from rastervision.pipeline.file_system import file_to_json, json_to_file
from import ensure_json_serializable

    from import Box
    from import RasterSource

class RasterStats:
    """Band-wise means and standard deviations.

    The stats can be supplied directly, loaded from a JSON file via
    :meth:`load`, or computed (and incrementally updated) from raster sources
    via :meth:`compute`.
    """

    def __init__(self,
                 means: Optional[np.ndarray] = None,
                 stds: Optional[np.ndarray] = None,
                 counts: Optional[np.ndarray] = None):
        """Constructor.

        Args:
            means (Optional[np.ndarray]): Band means. Defaults to None.
            stds (Optional[np.ndarray]): Band standard deviations. Defaults to
                None.
            counts (Optional[np.ndarray]): Band pixel counts (used to compute
                the specified means and stds). Defaults to None.
        """
        # Values are stored exactly as given (they may be plain lists when
        # loaded from JSON); they are coerced to arrays where arithmetic is
        # performed on them.
        self.means = means
        self.stds = stds
        self.counts = counts

    @classmethod
    def load(cls, stats_uri: str) -> 'RasterStats':
        """Load stats from file.

        Args:
            stats_uri (str): URI of a JSON file with "means" and "stds" keys
                and, optionally, a "counts" key.

        Returns:
            RasterStats: An instance of ``cls`` populated from the file.
        """
        stats_json = file_to_json(stats_uri)
        assert 'means' in stats_json and 'stds' in stats_json
        # Use cls rather than a hard-coded class so subclasses load as
        # themselves.
        stats = cls(
            means=stats_json['means'],
            stds=stats_json['stds'],
            counts=stats_json.get('counts'))
        return stats

    def compute(self,
                raster_sources: Sequence['RasterSource'],
                sample_prob: Optional[float] = None,
                chip_sz: int = 300,
                stride: Optional[int] = None,
                nodata_value: Optional[float] = 0) -> None:
        """Compute the mean and stds over all the raster_sources.

        This ignores NODATA values if nodata_value is not None.

        If sample_prob is set, then a subset of each scene is used to compute
        stats which speeds up the computation. Roughly speaking, if
        sample_prob=0.5, then half the pixels in the scene will be used. More
        precisely, the number of chips is equal to
        sample_prob * (width * height / chip_sz^2), or 1, whichever is greater.
        Each chip is uniformly sampled from the scene with replacement.
        If a chip is larger than the scene, the whole scene is used as a single
        chip. Otherwise (sample_prob is None), it uses a sliding window over
        the entire scene to compute stats.

        If this instance already holds means, stds and counts, the newly
        computed stats are merged into them.

        Args:
            raster_sources (Sequence['RasterSource']): List of RasterSources.
            sample_prob (Optional[float]): Pixel sampling probability. See
                notes above. Defaults to None.
            chip_sz (int): Side length of the square chips. Defaults to 300.
            stride (Optional[int]): Sliding window stride. Only used if
                sample_prob is None. If None, chip_sz is used. Defaults to
                None.
            nodata_value (Optional[float]): NODATA value. If set, these pixels
                will be ignored when computing stats. Defaults to 0.

        Raises:
            ValueError: If no usable chips were found, or if only some of the
                existing means/stds/counts are set.
        """
        if sample_prob is None:
            if stride is None:
                stride = chip_sz
            chip_stream = sliding_chip_stream(
                raster_sources, chip_sz, stride, nodata_value=nodata_value)
        else:
            chip_stream = random_chip_stream(
                raster_sources,
                chip_sz,
                sample_prob,
                nodata_value=nodata_value)

        means, variances, counts = self.compute_from_chips(
            chip_stream,
            running_mean=self.means,
            running_var=self.vars,
            running_count=self.counts)

        if means is None or variances is None:
            raise ValueError('No valid chips found in raster sources to '
                             'compute stats from. This may be because all '
                             'sampled chips were entirely composed of NODATA '
                             'pixels.')

        self.means = means
        self.stds = np.sqrt(variances)
        self.counts = counts

    def compute_from_chips(
            self,
            chips: Iterable[np.ndarray],
            running_mean: Optional[np.ndarray] = None,
            running_var: Optional[np.ndarray] = None,
            running_count: Optional[np.ndarray] = None) -> Union[Tuple[
                None, None, None], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Compute running mean and var from chips in stream.

        Args:
            chips (Iterable[np.ndarray]): Chips whose last axis is the channel
                axis.
            running_mean (Optional[np.ndarray]): Means to start from.
            running_var (Optional[np.ndarray]): Variances to start from.
            running_count (Optional[np.ndarray]): Pixel counts to start from.

        Returns:
            The updated (mean, var, count). If ``chips`` is empty, the running
            stats are returned as passed in (all None by default).
        """
        with tqdm(chips, desc='Analyzing chips') as bar:
            for chip in bar:
                num_channels = chip.shape[-1]
                # (..., H, W, C) --> (... * H * W, C)
                pixels = chip.reshape(-1, num_channels)
                stats = self.compute_from_pixels(pixels, running_mean,
                                                 running_var, running_count)
                running_mean, running_var, running_count = stats

        return running_mean, running_var, running_count

    def compute_from_pixels(self,
                            pixels: np.ndarray,
                            running_mean: Optional[np.ndarray] = None,
                            running_var: Optional[np.ndarray] = None,
                            running_count: Optional[np.ndarray] = None
                            ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Update running mean and var from pixel values.

        NaN pixels are ignored. A channel with no valid pixels in ``pixels``
        leaves the running stats for that channel unchanged, and a channel
        whose running count is zero simply takes on the new stats.

        Args:
            pixels (np.ndarray): Array of shape (num_pixels, num_channels).
            running_mean (Optional[np.ndarray]): Per-channel means so far.
            running_var (Optional[np.ndarray]): Per-channel variances so far.
            running_count (Optional[np.ndarray]): Per-channel counts so far.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Updated per-channel
            (mean, var, count). The input arrays are not modified.

        Raises:
            ValueError: If some, but not all, running stats are provided.
        """
        running_stats = [running_mean, running_var, running_count]
        has_running_stats = any(s is not None for s in running_stats)
        has_all_running_stats = all(s is not None for s in running_stats)
        if has_running_stats and not has_all_running_stats:
            raise ValueError('Provide either none or all running stats.')

        # NOTE(review): np.nanvar returns the population variance (ddof=0),
        # whereas parallel_variance rebuilds sums of squares as
        # var * (count - 1). The mismatch is negligible for large counts, but
        # confirm it is intended.
        channel_means = np.nanmean(pixels, axis=0)
        channel_vars = np.nanvar(pixels, axis=0)
        channel_counts = np.sum(~np.isnan(pixels), axis=0)

        if not has_running_stats:
            return channel_means, channel_vars, channel_counts

        # Running stats may be plain lists (e.g. when loaded from JSON).
        running_mean = np.asarray(running_mean, dtype=float)
        running_var = np.asarray(running_var, dtype=float)
        running_count = np.asarray(running_count)

        chip_is_empty = channel_counts == 0
        running_is_empty = running_count == 0

        # Channels that are empty on either side produce NaN/inf in the
        # merge formulas; those entries are replaced below, so silence the
        # floating point warnings here.
        with np.errstate(divide='ignore', invalid='ignore'):
            merged_var = parallel_variance(channel_means, channel_counts,
                                           channel_vars, running_mean,
                                           running_count, running_var)
            merged_mean = parallel_mean(channel_means, channel_counts,
                                        running_mean, running_count)

        # An all-NaN channel has mean NaN and count 0; merging it would turn
        # the running stats into NaN permanently. Keep whichever side actually
        # has data instead.
        new_var = np.where(chip_is_empty, running_var,
                           np.where(running_is_empty, channel_vars,
                                    merged_var))
        new_mean = np.where(chip_is_empty, running_mean,
                            np.where(running_is_empty, channel_means,
                                     merged_mean))
        # Not in-place: avoid mutating the caller's array.
        new_count = running_count + channel_counts

        return new_mean, new_var, new_count

    def to_dict(self) -> dict:
        """Return the means, stds and counts as a dict."""
        stats_dict = dict(means=self.means, stds=self.stds, counts=self.counts)
        return stats_dict

    def save(self, stats_uri: str) -> None:
        """Save stats to file.

        Args:
            stats_uri (str): URI of the JSON file to write.
        """
        assert self.means is not None and self.stds is not None
        stats_dict = self.to_dict()
        stats_dict = ensure_json_serializable(stats_dict)
        json_to_file(stats_dict, stats_uri)

    @property
    def vars(self) -> Optional[np.ndarray]:
        """Channel variances, if self.stds is set."""
        if self.stds is None:
            return None
        # np.asarray so that stds given as a list (e.g. from JSON) also work.
        return np.asarray(self.stds)**2

    def __repr__(self) -> str:
        return repr_with_args(self, **self.to_dict())
def parallel_variance(mean_a, count_a, var_a, mean_b, count_b, var_b):
    """Combine the variances of two partitions of a data set.

    Uses the pairwise-combination formula described under "Parallel
    algorithm" in Wikipedia's "Algorithms for calculating variance" article.
    The sums of squared deviations are reconstructed as ``var * (count - 1)``
    and the combined sum is divided by ``count_a + count_b - 1``.

    Args:
        mean_a: the mean of partition a
        count_a: the number of elements in partition a
        var_a: the variance of partition a
        mean_b: the mean of partition b
        count_b: the number of elements in partition b
        var_b: the variance of partition b

    Return:
        the variance of the two partitions if they were combined
    """
    total_count = count_a + count_b
    mean_gap = mean_b - mean_a
    # Per-partition sums of squared deviations.
    sq_dev_a = var_a * (count_a - 1)
    sq_dev_b = var_b * (count_b - 1)
    # Between-partition contribution.
    cross_term = mean_gap**2 * count_a * count_b / total_count
    combined_sq_dev = sq_dev_a + sq_dev_b + cross_term
    return combined_sq_dev / (total_count - 1)
def parallel_mean(mean_a, count_a, mean_b, count_b):
    """Combine the means of two partitions of a data set.

    The result is the count-weighted average of the two partition means.

    Args:
        mean_a: the mean of partition a
        count_a: the number of elements in partition a
        mean_b: the mean of partition b
        count_b: the number of elements in partition b

    Return:
        the mean of the two partitions if they were combined
    """
    weighted_total = count_a * mean_a + count_b * mean_b
    return weighted_total / (count_a + count_b)
def sliding_chip_stream(
        raster_sources: Iterable['RasterSource'],
        chip_sz: int,
        stride: int,
        nodata_value: Optional[float] = 0) -> Iterator[np.ndarray]:
    """Yield chips from each raster source using a sliding window.

    Windows come from ``extent.get_windows(chip_sz, stride)`` of each source.
    Windows for which ``get_chip`` returns None are skipped.

    Args:
        raster_sources (Iterable['RasterSource']): Sources to read from.
        chip_sz (int): Window size passed to ``get_windows``.
        stride (int): Stride passed to ``get_windows``.
        nodata_value (Optional[float]): NODATA value forwarded to
            ``get_chip``. Defaults to 0.

    Yields:
        np.ndarray: One chip per usable window.
    """
    for source in raster_sources:
        for win in source.extent.get_windows(chip_sz, stride):
            chip = get_chip(source, win, nodata_value=nodata_value)
            if chip is not None:
                yield chip
def random_chip_stream(
        raster_sources: Iterable['RasterSource'],
        chip_sz: int,
        sample_prob: float,
        nodata_value: Optional[float] = 0) -> Iterator[np.ndarray]:
    """Yield randomly located chips from each raster source.

    The number of chips per source is given by ``get_num_chips_to_sample``.
    When that number is 0, the source's full extent is read as a single chip.
    Windows for which ``get_chip`` returns None are skipped.

    Args:
        raster_sources (Iterable['RasterSource']): Sources to read from.
        chip_sz (int): Side length of the random square windows.
        sample_prob (float): Sampling probability used to decide the number
            of chips.
        nodata_value (Optional[float]): NODATA value forwarded to
            ``get_chip``. Defaults to 0.

    Yields:
        np.ndarray: One chip per usable window.
    """
    for source in raster_sources:
        full_extent = source.extent
        num_samples = get_num_chips_to_sample(full_extent, chip_sz,
                                              sample_prob)
        # All windows are drawn up front, before any chip is read.
        sampled_windows = ([full_extent] if num_samples == 0 else [
            full_extent.make_random_square(chip_sz)
            for _ in range(num_samples)
        ])
        for win in sampled_windows:
            chip = get_chip(source, win, nodata_value=nodata_value)
            if chip is not None:
                yield chip
def get_chip(raster_source: 'RasterSource',
             window: 'Box',
             nodata_value: Optional[float] = 0) -> Optional[np.ndarray]:
    """Read a raw chip as float, with NODATA pixels replaced by NaN.

    Args:
        raster_source ('RasterSource'): Source to read from (via
            ``get_raw_chip``).
        window ('Box'): Window to read.
        nodata_value (Optional[float]): Value to treat as NODATA. If None,
            the chip is returned without any masking. Defaults to 0.

    Returns:
        Optional[np.ndarray]: The float chip, or None if every value in it is
        NODATA.
    """
    chip = raster_source.get_raw_chip(window).astype(float)
    if nodata_value is not None:
        nodata_mask = chip == nodata_value
        chip[nodata_mask] = np.nan
        if np.isnan(chip).all():
            return None
    return chip
def get_num_chips_to_sample(extent: 'Box', chip_sz: int,
                            sample_prob: float) -> int:
    """Decide how many chips of size chip_sz to sample from an extent.

    Args:
        extent ('Box'): Extent whose ``area`` gives the total pixel count.
        chip_sz (int): Side length of a square chip.
        sample_prob (float): Fraction of the chips that fit in the extent
            (by area) to sample.

    Returns:
        int: 0 if a single chip has more pixels than the extent; otherwise
        ``round(sample_prob * area / chip_sz**2)``, but at least 1.
    """
    scene_area = extent.area
    chip_area = chip_sz**2
    if chip_area > scene_area:
        return 0
    chips_in_scene = scene_area / chip_area
    return max(1, round(sample_prob * chips_in_scene))