Source code for tensionflow.util

import collections
import collections.abc
import logging

import numpy as np
import tensorflow as tf
from bidict import bidict

logger = logging.getLogger(__name__)


def identity_func(*args):
    """Return the positional arguments exactly as received, packed in a tuple."""
    packed = args
    return packed
def default_of_type(dtype=int):
    """Return a placeholder default value for ``dtype``.

    Resolution order:
      1. ``tf.string`` -> the literal ``'default_value'``.
      2. ``dtype()`` when calling it with no arguments works (e.g. ``int`` -> 0);
         only ``TypeError`` is tolerated here.
      3. ``dtype.as_numpy_dtype()`` for objects exposing ``as_numpy_dtype``
         (presumably TF DTypes); ``AttributeError``/``TypeError`` are tolerated.
      4. Otherwise ``0``.
    """
    if dtype == tf.string:
        return 'default_value'

    # Each attempt is paired with the exceptions that mean "try the next one".
    attempts = (
        (lambda: dtype(), (TypeError,)),
        (lambda: dtype.as_numpy_dtype(), (AttributeError, TypeError)),
    )
    for attempt, tolerated in attempts:
        try:
            return attempt()
        except tolerated:
            continue
    return 0
def indexify(labels, label_dict=None):
    """Map labels to integer indices through a label -> index ``bidict``.

    Each element of ``labels`` is either a single label or a collection of
    labels (a multi-label sample). ``str`` and ``bytes`` values, and any
    non-iterable value (e.g. an int), are treated as single labels; every
    other iterable is treated as a collection of labels. This matches the
    rule used by ``map_if_collection`` when the indices are produced.

    Args:
        labels: iterable of samples (single labels or collections of labels).
            May be a one-shot iterator.
        label_dict: optional existing ``bidict`` mapping label -> index. When
            omitted, one is built from the sorted distinct labels, with
            indices ``0..n-1`` in sorted order.

    Returns:
        ``(y, label_dict)``. ``y`` is a list with one entry per sample: an
        index for a single label, or a tuple of indices for a collection.
        Labels absent from a supplied ``label_dict`` map to ``None``
        (``bidict.get`` semantics).
    """
    # Materialize once: the labels are traversed twice when building the dict,
    # which would silently produce an empty ``y`` for a generator.
    labels = list(labels)

    if label_dict is None:
        distinct = set()
        for sample in labels:
            is_single = (isinstance(sample, (str, bytes))
                         or not isinstance(sample, collections.abc.Iterable))
            if is_single:
                distinct.add(sample)
            else:
                distinct.update(sample)
        label_dict = bidict((label, index) for index, label in enumerate(sorted(distinct)))
        # min()/max() raise on an empty sequence, so only log when there is a range.
        if label_dict:
            logger.info('Converting labels to index range: [%s-%s]',
                        min(label_dict.values()), max(label_dict.values()))

    y = [map_if_collection(label_dict.get, label) for label in labels]
    return y, label_dict
def one_hotify(labels, label_dict=None):
    """Encode labels as a stacked 2-D one-hot (or multi-hot) uint8 matrix.

    Args:
        labels: samples to encode. When ``label_dict`` is None they are first
            converted to indices with ``indexify``. When ``label_dict`` is
            supplied they are passed straight to ``one_hot``, so they must
            already be index-encoded.
            NOTE(review): this differs from ``indexify``, which maps raw labels
            through a supplied dict. Confirm the intended contract.
        label_dict: optional label -> index ``bidict``. Its size determines the
            number of columns.

    Returns:
        ``(y, label_dict)`` where ``y`` has shape ``(len(labels), num_labels)``.
        Empty ``labels`` yield a ``(0, num_labels)`` uint8 array.
    """
    if label_dict is None:
        labels, label_dict = indexify(labels, label_dict)
    num_labels = len(label_dict.inv)
    rows = [one_hot(label, num_labels) for label in labels]
    if not rows:
        # np.stack raises ValueError on an empty sequence; return an empty batch.
        return np.zeros((0, num_labels), dtype=np.dtype('uint8')), label_dict
    return np.stack(rows), label_dict
def one_hot(index_labels, num_labels):
    """Return a uint8 vector of length ``num_labels`` with 1 at the given indices.

    Args:
        index_labels: a single integer index, or a collection of integer
            indices (e.g. a tuple produced by ``indexify`` for a multi-label
            sample). An empty collection yields an all-zero vector.
        num_labels: length of the output vector.
    """
    y = np.zeros([num_labels], dtype=np.dtype('uint8'))
    indices = np.asarray(index_labels)
    # np.asarray(()) is float64, which numpy rejects as an index array; an empty
    # collection just means "no active labels".
    if indices.size:
        y[indices] = 1
    return y
def one_hot_to_some_hot(Y):
    """Normalize each row of a 2-D label matrix so that it sums to 1.

    Rows whose sum is zero (samples with no active labels) are divided by 1,
    i.e. left as-is, instead of producing NaN from 0/0.

    Args:
        Y: 2-D array, one row per sample.

    Returns:
        The row-normalized array. The dtype follows numpy true division of
        ``Y`` by its row sums.
    """
    totals = Y.sum(axis=1, keepdims=True)
    # Replace zero totals with 1 (same dtype as totals) to avoid 0/0 -> NaN.
    safe_totals = np.where(totals == 0, np.ones_like(totals), totals)
    return Y / safe_totals
def map_if_collection(func, obj):
    """Apply ``func`` element-wise to a collection, or directly to a single value.

    ``str`` and ``bytes`` are treated as single values even though they are
    iterable. Any other iterable is treated as a collection.

    Returns:
        ``tuple(map(func, obj))`` for a non-string iterable, else ``func(obj)``.
    """
    # collections.Iterable was a deprecated alias removed in Python 3.10;
    # collections.abc.Iterable is the supported location.
    if isinstance(obj, (str, bytes)) or not isinstance(obj, collections.abc.Iterable):
        return func(obj)
    return tuple(map(func, obj))
def _dtype_feature(ndarray):
    """Match the appropriate tf.train.Feature class with the dtype of ``ndarray``.

    Floating dtypes map to ``FloatList``, integer dtypes to ``Int64List``, and
    anything else to ``BytesList``. A non-ndarray input is first wrapped as a
    one-element array.
    """
    if not isinstance(ndarray, np.ndarray):
        logger.info('Converting %s to ndarray', ndarray)
        ndarray = np.array([ndarray])
    dtype = ndarray.dtype
    # np.float was only an alias of builtin float and was removed in NumPy 1.24.
    # Also, issubdtype(x, float) means "float64 exactly", so float32/float16
    # arrays would fall through to BytesList. np.floating covers every float dtype.
    if np.issubdtype(dtype, np.floating):
        return tf.train.Feature(float_list=tf.train.FloatList(value=ndarray))
    if np.issubdtype(dtype, np.integer):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=ndarray))
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=ndarray))
def wrap_tf_py_func(py_func, Tout):
    """Wrap a Python function so that calling the wrapper runs it via TF py_func.

    Args:
        py_func: Python callable that the py_func op invokes on its inputs.
        Tout: TF dtype (or list of dtypes) the op returns.

    Returns:
        A function ``f(*args)`` that passes ``args`` as the op's ``inp`` and
        returns the op's output.
    """
    # tf.py_func exists only in the TF1 top-level namespace; TF2 exposes the
    # same op as tf.compat.v1.py_func.
    py_func_op = getattr(tf, 'py_func', None)
    if py_func_op is None:
        py_func_op = tf.compat.v1.py_func

    def f(*args):
        return py_func_op(py_func, inp=args, Tout=Tout)
    return f
def get_all_subclasses(cls):
    """Collect every direct and indirect subclass of ``cls``, keyed by ``__name__``.

    Subclasses are visited depth-first in pre-order. When two classes share a
    name, the one visited last wins, and key order follows first visit.
    """
    found = {}
    # Explicit stack instead of recursion; children are pushed reversed so the
    # first child is popped (visited) first, preserving pre-order.
    pending = list(reversed(cls.__subclasses__()))
    while pending:
        current = pending.pop()
        found[current.__name__] = current
        pending.extend(reversed(current.__subclasses__()))
    return found