"""Source code for tensionflow.processing."""

import logging

import tensorflow as tf

from tensionflow import util

logger = logging.getLogger(__name__)

# tf.logging._logger.propagate = False


class Preprocessor:
    """A single ``tf.data`` preprocessing stage.

    Holds a mapping function that ``apply`` runs element-wise over a dataset,
    optionally followed by a step that splits every mapped element along its
    leading dimension.
    """

    def __init__(self, func, flatten=False):
        """Record the stage configuration.

        Args:
            func: Callable passed to ``dataset.map`` by ``apply``.
            flatten: When true, ``apply`` follows the map with a ``flat_map``
                that slices each mapped element along its first dimension.
        """
        self.flatten = flatten
        self.func = func

    def apply(self, dataset):
        """Run this stage over ``dataset`` and return the resulting dataset.

        Args:
            dataset: Dataset to transform (its ``map`` is called with
                ``self.func``).

        Returns:
            The mapped dataset, additionally flattened when ``self.flatten``
            is true.
        """
        mapped = dataset.map(self.func)
        if not self.flatten:
            return mapped

        def _split_leading_dim(*components):
            # The element's components are re-packed into a tuple before
            # slicing, so each produced element is a tuple with one entry per
            # component. NOTE(review): a single-tensor element therefore comes
            # out as a 1-tuple; confirm downstream code expects that.
            return tf.data.Dataset.from_tensor_slices(components)

        return mapped.flat_map(_split_leading_dim)
class PythonPreprocessor(Preprocessor):
    """Preprocessing stage whose function is run through ``util.wrap_tf_py_func``.

    After the wrapped call, every output component ``i`` is reshaped to
    ``output_shapes[i]``. Presumably this restores static shape information
    that a Python-function op does not carry; confirm against
    ``util.wrap_tf_py_func``.
    """

    def __init__(self, func, output_dtypes, output_shapes, *args, **kwargs):
        """Record the wrapper configuration.

        Args:
            func: Python callable to be wrapped by ``util.wrap_tf_py_func``.
            output_dtypes: Sequence of output dtypes. ``apply`` uses only the
                first N, where N is the number of components in an input
                element.
            output_shapes: Sequence of target shapes; entry ``i`` is the shape
                output component ``i`` is reshaped to.
            *args, **kwargs: Forwarded to ``Preprocessor.__init__``
                (e.g. ``flatten``).
        """
        self.output_dtypes = output_dtypes
        self.output_shapes = output_shapes
        super().__init__(func, *args, **kwargs)

    def apply(self, dataset):
        """Map the wrapped Python function over ``dataset``.

        Args:
            dataset: A dataset exposing ``output_shapes`` (the TF 1.x
                ``tf.data`` API).

        Returns:
            The transformed dataset, with outputs reshaped to
            ``self.output_shapes`` and, when ``self.flatten`` is true, split
            along the leading dimension.
        """
        shapes = dataset.output_shapes
        if isinstance(shapes, tf.TensorShape):
            # Single-component element. len() of a TensorShape is its rank
            # (and raises ValueError for unknown rank), not a component
            # count, so it must not be used here.
            num_inputs = 1
        else:
            num_inputs = max(len(shapes), 1)

        # NOTE(review): the number of requested outputs is tied to the number
        # of input components; confirm that is intended.
        wrapped = util.wrap_tf_py_func(self.func, Tout=self.output_dtypes[:num_inputs])
        ds = dataset.map(wrapped)
        ds = ds.map(
            lambda *args: tuple(
                tf.reshape(arg, self.output_shapes[i]) for i, arg in enumerate(args)
            )
        )
        if self.flatten:
            ds = ds.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
        return ds