Viewing file: _collections.py (28.47 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
# util/_collections.py # Copyright (C) 2005-2019 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Collection classes and helpers."""
from __future__ import absolute_import
import operator import types import weakref
from .compat import binary_types from .compat import collections_abc from .compat import itertools_filterfalse from .compat import py2k from .compat import string_types from .compat import threading
# Module-level empty ``frozenset`` instance, created once at import time.
EMPTY_SET = frozenset()
class AbstractKeyedTuple(tuple):
    """``tuple`` base type for the keyed-tuple classes in this module.

    The class itself does not define ``_fields``; :meth:`keys` reads it from
    the instance, so concrete subclasses have to supply it.
    """

    __slots__ = ()

    def keys(self):
        """Return a list of string key names for this :class:`.KeyedTuple`.

        .. seealso::

            :attr:`.KeyedTuple._fields`

        """
        # ``_fields`` is an iterable of names; hand back a fresh list copy.
        return [name for name in self._fields]
class KeyedTuple(AbstractKeyedTuple):
    """``tuple`` subclass that adds labeled names.

    E.g.::

        >>> k = KeyedTuple([1, 2, 3], labels=["one", "two", "three"])
        >>> k.one
        1
        >>> k.two
        2

    Result rows returned by :class:`.Query` that contain multiple
    ORM entities and/or column expressions make use of this
    class to return rows.

    The :class:`.KeyedTuple` exhibits similar behavior to the
    ``collections.namedtuple()`` construct provided in the Python
    standard library, however is architected very differently.
    Unlike ``collections.namedtuple()``, :class:`.KeyedTuple`
    does not rely on creation of custom subtypes in order to represent
    a new series of keys, instead each :class:`.KeyedTuple` instance
    receives its list of keys in place.   The subtype approach
    of ``collections.namedtuple()`` introduces significant complexity
    and performance overhead, which is not necessary for the
    :class:`.Query` object's use case.

    .. seealso::

        :ref:`ormtutorial_querying`

    """

    def __new__(cls, vals, labels=None):
        """Build the tuple from ``vals`` and expose ``labels`` as attributes.

        :param vals: iterable of values; may be a one-shot iterator.
        :param labels: optional sequence of names, positionally matched to
         the values.  ``None`` entries are kept in ``_labels`` but are left
         out of :attr:`._fields`.
        """
        t = tuple.__new__(cls, vals)
        if labels:
            # Pair the labels with the constructed tuple rather than with
            # ``vals``: if ``vals`` is an iterator it has already been
            # consumed by ``tuple.__new__`` and would yield nothing here.
            t.__dict__.update(zip(labels, t))
        else:
            labels = []
        # Written through __dict__ because __setattr__ always raises.
        t.__dict__["_labels"] = labels
        return t

    @property
    def _fields(self):
        """Return a tuple of string key names for this :class:`.KeyedTuple`.

        This method provides compatibility with ``collections.namedtuple()``.

        .. seealso::

            :meth:`.KeyedTuple.keys`

        """
        return tuple([l for l in self._labels if l is not None])

    def __setattr__(self, key, value):
        """Reject attribute assignment with ``AttributeError``."""
        raise AttributeError("Can't set attribute: %s" % key)

    def _asdict(self):
        """Return the contents of this :class:`.KeyedTuple` as a dictionary.

        This method provides compatibility with ``collections.namedtuple()``,
        with the exception that the dictionary returned is **not** ordered.

        """
        return {key: self.__dict__[key] for key in self.keys()}
class _LW(AbstractKeyedTuple):
    """Slot-less tuple base for classes built by ``lightweight_named_tuple``.

    The methods read a ``_real_fields`` attribute that this class does not
    define itself; it is expected to be present on the concrete subclass.
    """

    __slots__ = ()

    def __new__(cls, vals):
        return tuple.__new__(cls, vals)

    def __reduce__(self):
        # for pickling, degrade down to the regular
        # KeyedTuple, thus avoiding anonymous class pickling
        # difficulties
        values = list(self)
        return KeyedTuple, (values, self._real_fields)

    def _asdict(self):
        """Return the contents of this :class:`.KeyedTuple` as a dictionary."""
        # Entries whose field name is None are unnamed and are left out.
        return {
            name: value
            for name, value in zip(self._real_fields, self)
            if name is not None
        }
class ImmutableContainer(object):
    """Mixin whose item and attribute mutators raise ``TypeError``."""

    def _immutable(self, *arg, **kw):
        """Raise ``TypeError`` naming the concrete class; accepts any args."""
        raise TypeError("%s object is immutable" % self.__class__.__name__)

    # Route every mutation hook to the same raising implementation.
    __delitem__ = _immutable
    __setitem__ = _immutable
    __setattr__ = _immutable
class immutabledict(ImmutableContainer, dict):
    """A ``dict`` whose mutating methods raise ``TypeError``.

    Contents are populated once, in ``__new__``.
    """

    clear = pop = popitem = setdefault = update = ImmutableContainer._immutable

    def __new__(cls, *args):
        # Populate through dict.__init__ directly, since the mutator
        # methods on this class all raise.
        new = dict.__new__(cls)
        dict.__init__(new, *args)
        return new

    def __init__(self, *args):
        # Contents were already loaded in __new__; nothing to do here.
        pass

    def __reduce__(self):
        return immutabledict, (dict(self),)

    def union(self, d):
        """Return an immutabledict holding this dict's items updated by ``d``.

        When ``d`` is empty, ``self`` is returned as is; when ``self`` is
        empty and ``d`` is already an immutabledict, ``d`` is returned as is.
        """
        if not d:
            return self
        if self:
            merged = immutabledict(self)
            dict.update(merged, d)
            return merged
        if isinstance(d, immutabledict):
            return d
        return immutabledict(d)

    def __repr__(self):
        return "immutabledict(%s)" % dict.__repr__(self)
class Properties(object):
    """Provide a __getattr__/__setattr__ interface over a dict."""

    __slots__ = ("_data",)

    def __init__(self, data):
        # Bypass our own __setattr__, which would write into the dict.
        object.__setattr__(self, "_data", data)

    # -- pickling ---------------------------------------------------------

    def __getstate__(self):
        return {"_data": self._data}

    def __setstate__(self, state):
        object.__setattr__(self, "_data", state["_data"])

    # -- attribute access ---------------------------------------------------

    def __getattr__(self, key):
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, obj):
        self._data[key] = obj

    def __dir__(self):
        own_names = [str(k) for k in self._data.keys()]
        return dir(super(Properties, self)) + own_names

    # -- item access ----------------------------------------------------------

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, obj):
        self._data[key] = obj

    def __delitem__(self, key):
        del self._data[key]

    def __contains__(self, key):
        return key in self._data

    # -- container protocol ---------------------------------------------------

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        # Iterates over a snapshot of the stored *values*, not the keys.
        snapshot = list(self._data.values())
        return iter(snapshot)

    def __add__(self, other):
        return list(self) + list(other)

    # -- dict-like helpers ----------------------------------------------------

    def as_immutable(self):
        """Return an immutable proxy for this :class:`.Properties`."""
        return ImmutableProperties(self._data)

    def update(self, value):
        self._data.update(value)

    def get(self, key, default=None):
        return self[key] if key in self else default

    def keys(self):
        return list(self._data)

    def values(self):
        return list(self._data.values())

    def items(self):
        return list(self._data.items())

    def has_key(self, key):
        return key in self._data

    def clear(self):
        self._data.clear()
class OrderedProperties(Properties):
    """Provide a __getattr__/__setattr__ interface with an OrderedDict
    as backing store."""

    __slots__ = ()

    def __init__(self):
        # Start from a fresh, empty ordered mapping.
        backing = OrderedDict()
        Properties.__init__(self, backing)
class ImmutableProperties(ImmutableContainer, Properties):
    """Provide immutable dict/object attribute to an underlying dictionary.

    ``ImmutableContainer`` is listed first so that its raising
    ``__setitem__`` / ``__delitem__`` / ``__setattr__`` take precedence over
    the writable versions defined on ``Properties``.
    """

    __slots__ = ()
class OrderedDict(dict):
    """A dict that returns keys/values/items in the order they were added."""

    # ``_list`` holds the keys in insertion order.
    __slots__ = ("_list",)

    def __reduce__(self):
        return OrderedDict, (self.items(),)

    def __init__(self, ____sequence=None, **kwargs):
        self._list = []
        if ____sequence is None:
            if kwargs:
                self.update(**kwargs)
        else:
            self.update(____sequence, **kwargs)

    def clear(self):
        self._list = []
        dict.clear(self)

    def copy(self):
        return self.__copy__()

    def __copy__(self):
        return OrderedDict(self)

    def sort(self, *arg, **kw):
        """Sort the key ordering in place; arguments go to ``list.sort``."""
        self._list.sort(*arg, **kw)

    def update(self, ____sequence=None, **kwargs):
        """Add items from a mapping or an iterable of pairs, then ``kwargs``.

        An argument with a ``keys`` attribute is treated as a mapping;
        anything else is iterated as ``(key, value)`` pairs.
        """
        if ____sequence is not None:
            if hasattr(____sequence, "keys"):
                for key in ____sequence.keys():
                    self.__setitem__(key, ____sequence[key])
            else:
                for key, value in ____sequence:
                    self[key] = value
        if kwargs:
            self.update(kwargs)

    def setdefault(self, key, value=None):
        """Return ``self[key]``, first storing ``value`` if ``key`` is absent.

        ``value`` defaults to ``None`` so the signature matches
        ``dict.setdefault``; existing two-argument calls are unaffected.
        """
        if key not in self:
            self.__setitem__(key, value)
            return value
        else:
            return self.__getitem__(key)

    def __iter__(self):
        return iter(self._list)

    def keys(self):
        return list(self)

    def values(self):
        return [self[key] for key in self._list]

    def items(self):
        return [(key, self[key]) for key in self._list]

    if py2k:

        def itervalues(self):
            return iter(self.values())

        def iterkeys(self):
            return iter(self)

        def iteritems(self):
            return iter(self.items())

    def __setitem__(self, key, obj):
        if key not in self:
            try:
                self._list.append(key)
            except AttributeError:
                # work around Python pickle loads() with
                # dict subclass (seems to ignore __setstate__?)
                self._list = [key]
        dict.__setitem__(self, key, obj)

    def __delitem__(self, key):
        # Delete from the dict first so a missing key raises KeyError
        # before the ordering list is touched.
        dict.__delitem__(self, key)
        self._list.remove(key)

    def pop(self, key, *default):
        present = key in self
        value = dict.pop(self, key, *default)
        if present:
            self._list.remove(key)
        return value

    def popitem(self):
        item = dict.popitem(self)
        self._list.remove(item[0])
        return item
class OrderedSet(set):
    """A ``set`` that also tracks insertion order in ``self._list``.

    Iteration, indexing and ``repr()`` follow ``_list``; membership tests
    use the underlying ``set``.
    """

    def __init__(self, d=None):
        set.__init__(self)
        if d is not None:
            self._list = unique_list(d)
            set.update(self, self._list)
        else:
            self._list = []

    def add(self, element):
        if element not in self:
            self._list.append(element)
        set.add(self, element)

    def remove(self, element):
        # set.remove goes first so a missing element raises KeyError
        # before the ordering list is touched.
        set.remove(self, element)
        self._list.remove(element)

    def insert(self, pos, element):
        """Insert ``element`` at index ``pos`` unless it is already present."""
        if element not in self:
            self._list.insert(pos, element)
        set.add(self, element)

    def discard(self, element):
        if element in self:
            self._list.remove(element)
            set.remove(self, element)

    def clear(self):
        set.clear(self)
        self._list = []

    def __getitem__(self, key):
        return self._list[key]

    def __iter__(self):
        return iter(self._list)

    def __add__(self, other):
        return self.union(other)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, iterable):
        """Add the not-yet-present items of ``iterable``; returns ``self``."""
        for e in iterable:
            if e not in self:
                self._list.append(e)
                set.add(self, e)
        return self

    __ior__ = update

    def union(self, other):
        result = self.__class__(self)
        result.update(other)
        return result

    __or__ = union

    def intersection(self, other):
        other = set(other)
        return self.__class__(a for a in self if a in other)

    __and__ = intersection

    def symmetric_difference(self, other):
        other = set(other)
        result = self.__class__(a for a in self if a not in other)
        result.update(a for a in other if a not in self)
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        other = set(other)
        return self.__class__(a for a in self if a not in other)

    __sub__ = difference

    def intersection_update(self, other):
        other = set(other)
        set.intersection_update(self, other)
        self._list = [a for a in self._list if a in other]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        """Keep only the items found in exactly one of ``self`` / ``other``.

        ``other`` may be any iterable, not just an :class:`.OrderedSet`.
        Surviving items keep their position; items new from ``other`` are
        appended in the order ``other`` yields them.  Returns ``self``.
        """
        # Materialize once so one-shot iterators can be read twice, and so
        # that a plain iterable does not need a ``_list`` attribute.
        if isinstance(other, OrderedSet):
            other_items = other._list
        else:
            other_items = list(other)

        set.symmetric_difference_update(self, other_items)

        ordered = [a for a in self._list if a in self]
        placed = set(ordered)
        for a in other_items:
            # ``other_items`` may contain repeats; append each only once.
            if a in self and a not in placed:
                placed.add(a)
                ordered.append(a)
        self._list = ordered
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        set.difference_update(self, other)
        self._list = [a for a in self._list if a in self]
        return self

    __isub__ = difference_update
class IdentitySet(object):
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types- it's possible to have
    two 'foo' strings in one of these sets, for example.  Use sparingly.

    """

    # Set type used to hold the intermediate ``(id, object)`` tuples while
    # computing union / difference / intersection / symmetric difference.
    _working_set = set

    def __init__(self, iterable=None):
        # Maps id(obj) -> obj.
        self._members = dict()
        if iterable:
            for o in iterable:
                self.add(o)

    def add(self, value):
        self._members[id(value)] = value

    def __contains__(self, value):
        return id(value) in self._members

    def remove(self, value):
        del self._members[id(value)]

    def discard(self, value):
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self):
        try:
            pair = self._members.popitem()
            return pair[1]
        except KeyError:
            raise KeyError("pop from an empty set")

    def clear(self):
        self._members.clear()

    def __cmp__(self, other):
        raise TypeError("cannot compare sets using cmp()")

    def __eq__(self, other):
        if isinstance(other, IdentitySet):
            return self._members == other._members
        else:
            return False

    def __ne__(self, other):
        if isinstance(other, IdentitySet):
            return self._members != other._members
        else:
            return True

    def issubset(self, iterable):
        """Return True if every member is also present in ``iterable``."""
        other = type(self)(iterable)

        if len(self) > len(other):
            return False
        return all(key in other._members for key in self._members)

    def __le__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    def issuperset(self, iterable):
        """Return True if every item of ``iterable`` is a member."""
        other = type(self)(iterable)

        if len(self) < len(other):
            return False
        return all(key in self._members for key in other._members)

    def __ge__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    def union(self, iterable):
        result = type(self)()
        # testlib.pragma exempt:__hash__
        members = self._member_id_tuples()
        other = ((id(o), o) for o in iterable)
        result._members.update(self._working_set(members).union(other))
        return result

    def __or__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    def update(self, iterable):
        self._members = self.union(iterable)._members

    def __ior__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    def difference(self, iterable):
        result = type(self)()
        # testlib.pragma exempt:__hash__
        members = self._member_id_tuples()
        other = ((id(o), o) for o in iterable)
        result._members.update(self._working_set(members).difference(other))
        return result

    def __sub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    def difference_update(self, iterable):
        self._members = self.difference(iterable)._members

    def __isub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    def intersection(self, iterable):
        result = type(self)()
        # testlib.pragma exempt:__hash__
        members = self._member_id_tuples()
        other = ((id(o), o) for o in iterable)
        result._members.update(self._working_set(members).intersection(other))
        return result

    def __and__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    def intersection_update(self, iterable):
        self._members = self.intersection(iterable)._members

    def __iand__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    def symmetric_difference(self, iterable):
        result = type(self)()
        # testlib.pragma exempt:__hash__
        members = self._member_id_tuples()
        other = ((id(o), o) for o in iterable)
        result._members.update(
            self._working_set(members).symmetric_difference(other)
        )
        return result

    def _member_id_tuples(self):
        """Yield ``(id(obj), obj)`` for each current member."""
        return ((id(v), v) for v in self._members.values())

    def __xor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference_update(self, iterable):
        self._members = self.symmetric_difference(iterable)._members

    def __ixor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # Must be the in-place variant: the non-mutating
        # symmetric_difference() returns a new set and would leave
        # ``self`` unchanged, making ``a ^= b`` a silent no-op.
        self.symmetric_difference_update(other)
        return self

    def copy(self):
        return type(self)(iter(self._members.values()))

    __copy__ = copy

    def __len__(self):
        return len(self._members)

    def __iter__(self):
        return iter(self._members.values())

    def __hash__(self):
        raise TypeError("set objects are unhashable")

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, list(self._members.values()))
class WeakSequence(object):
    """A list-like sequence that holds only weak references to its items.

    Each stored reference is created with :meth:`_remove` as its callback,
    which takes the reference back out of the underlying list.
    """

    def __init__(self, __elements=()):
        self._storage = [
            weakref.ref(element, self._remove) for element in __elements
        ]

    def append(self, item):
        self._storage.append(weakref.ref(item, self._remove))

    def _remove(self, ref):
        """weakref callback: remove ``ref`` from the storage list."""
        self._storage.remove(ref)

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Skip references whose referent is already gone.
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the referent at ``index``.

        :raises IndexError: if ``index`` is outside the stored range.
        """
        try:
            obj = self._storage[index]
        except IndexError:
            # ``_storage`` is a list, so an out-of-range lookup raises
            # IndexError (not KeyError); re-raise with our own message.
            raise IndexError("Index %s out of range" % index)
        else:
            return obj()
class OrderedIdentitySet(IdentitySet):
    """An :class:`.IdentitySet` whose member store is an ``OrderedDict``."""

    class _working_set(OrderedSet):
        # a testing pragma: exempt the OIDS working set from the test suite's
        # "never call the user's __hash__" assertions.  this is a big hammer,
        # but it's safe here: IDS operates on (id, instance) tuples in the
        # working set.
        __sa_hash_exempt__ = True

    def __init__(self, iterable=None):
        IdentitySet.__init__(self)
        # Replace the plain dict created by the base initializer.
        self._members = OrderedDict()
        if not iterable:
            return
        for member in iterable:
            self.add(member)
class PopulateDict(dict):
    """A dict which populates missing values via a creation function.

    Note the creation function takes a key, unlike
    collections.defaultdict.

    """

    def __init__(self, creator):
        self.creator = creator

    def __missing__(self, key):
        # Build the value for the absent key, store it, and return it.
        created = self.creator(key)
        self[key] = created
        return created
# Define collections that are capable of storing # ColumnElement objects as hashable keys/elements. # At this point, these are mostly historical, things # used to be more complicated. column_set = set column_dict = dict ordered_column_set = OrderedSet populate_column_dict = PopulateDict
_getters = PopulateDict(operator.itemgetter)
_property_getters = PopulateDict( lambda idx: property(operator.itemgetter(idx)) )
def unique_list(seq, hashfunc=None):
    """Return the items of ``seq`` with duplicates removed, order preserved.

    :param seq: iterable of items.
    :param hashfunc: optional callable producing the uniqueness key for an
     item; when omitted the item itself is the key.  It is called exactly
     once per item.
    :return: a new list holding the first occurrence of each distinct key.
    """
    seen = set()
    seen_add = seen.add
    if not hashfunc:
        # ``seen_add`` returns None, so ``not seen_add(x)`` is always true
        # and is only there to record x as a side effect.
        return [x for x in seq if x not in seen and not seen_add(x)]

    result = []
    for x in seq:
        # Compute the key once and reuse it for both test and insert.
        key = hashfunc(x)
        if key not in seen:
            seen_add(key)
            result.append(x)
    return result
class UniqueAppender(object):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored.  Membership is
    determined by identity (``is a``) not equality (``==``).
    """

    def __init__(self, data, via=None):
        self.data = data
        # id(item) -> True for every item already forwarded.
        self._unique = {}
        # Pick the method used to add to ``data``: an explicitly named one,
        # otherwise ``append``, otherwise ``add``.
        if via:
            self._data_appender = getattr(data, via)
        elif hasattr(data, "append"):
            self._data_appender = data.append
        elif hasattr(data, "add"):
            self._data_appender = data.add

    def append(self, item):
        ident = id(item)
        if ident in self._unique:
            return
        self._data_appender(item)
        self._unique[ident] = True

    def __iter__(self):
        return iter(self.data)
def coerce_generator_arg(arg):
    """Expand a lone generator argument into a list.

    If ``arg`` has exactly one element and that element is a generator, the
    generator's items are returned as a list; otherwise ``arg`` is returned
    unchanged.
    """
    is_single_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    if not is_single_generator:
        return arg
    return list(arg[0])
def to_list(x, default=None):
    """Coerce ``x`` to a list.

    ``None`` yields ``default``; a non-iterable, or a string/bytes value,
    is wrapped in a one-element list; an existing list is returned as is;
    any other iterable is copied into a new list.
    """
    if x is None:
        return default

    is_scalar = not isinstance(x, collections_abc.Iterable) or isinstance(
        x, string_types + binary_types
    )
    if is_scalar:
        return [x]
    if isinstance(x, list):
        return x
    return list(x)
def has_intersection(set_, iterable):
    r"""return True if any items of set\_ are present in iterable.

    Goes through special effort to ensure __hash__ is not called
    on items in iterable that don't support it.

    """
    # TODO: optimize, write in C, etc.
    # Items whose ``__hash__`` attribute is falsy are left out before the
    # set operation runs.
    hashable_items = [item for item in iterable if item.__hash__]
    common = set_.intersection(hashable_items)
    return bool(common)
def to_set(x):
    """Coerce ``x`` to a ``set``.

    ``None`` yields a new empty set, an existing ``set`` is returned as is,
    and anything else goes through ``to_list()`` before being converted.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
def to_column_set(x):
    """Coerce ``x`` to a ``column_set``.

    ``None`` yields a new empty ``column_set``, an existing ``column_set``
    is returned as is, and anything else goes through ``to_list()`` first.
    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
def update_copy(d, _new=None, **kw):
    """Copy the given dict and update with the given values."""
    # The input is never modified; all updates go to the copy.
    result = d.copy()
    if _new:
        result.update(_new)
    result.update(**kw)
    return result
def flatten_iterator(x):
    """Given an iterator of which further sub-elements may also be
    iterators, flatten the sub-elements into a single iterator.

    """
    for elem in x:
        # ``str`` values are yielded whole; anything else exposing
        # ``__iter__`` is descended into recursively.
        is_leaf = isinstance(elem, str) or not hasattr(elem, "__iter__")
        if is_leaf:
            yield elem
        else:
            for nested in flatten_iterator(elem):
                yield nested
class LRUCache(dict):
    """Dictionary with 'squishy' removal of least recently used items.

    Note that either get() or [] should be used here, but
    generally it's not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"

    def __init__(self, capacity=100, threshold=0.5, size_alert=None):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()

    def _inc_counter(self):
        """Advance and return the use counter."""
        self._counter += 1
        return self._counter

    # Each stored entry is a list ``[key, value, last_used_counter]``.

    def get(self, key, default=None):
        entry = dict.get(self, key, default)
        if entry is default:
            return default
        entry[2] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key):
        entry = dict.__getitem__(self, key)
        entry[2] = self._inc_counter()
        return entry[1]

    def values(self):
        return [entry[1] for entry in dict.values(self)]

    def setdefault(self, key, value):
        if key in self:
            return self[key]
        self[key] = value
        return value

    def __setitem__(self, key, value):
        entry = dict.get(self, key)
        if entry is not None:
            # Existing key: replace the value, keep its counter.
            entry[1] = value
        else:
            entry = [key, value, self._inc_counter()]
            dict.__setitem__(self, key, entry)
        self._manage_size()

    @property
    def size_threshold(self):
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self):
        """Trim back to ``capacity`` entries once the threshold is exceeded."""
        # Non-blocking acquire: if another caller holds the lock, skip.
        if not self._mutex.acquire(False):
            return
        try:
            pending_alert = bool(self.size_alert)
            while len(self) > self.size_threshold:
                if pending_alert:
                    # Fire the callback at most once per call.
                    pending_alert = False
                    self.size_alert(self)
                newest_first = sorted(
                    dict.values(self),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                # Everything past the ``capacity`` most recently used
                # entries is dropped.
                for stale in newest_first[self.capacity :]:
                    try:
                        del self[stale[0]]
                    except KeyError:
                        # deleted elsewhere; skip
                        continue
        finally:
            self._mutex.release()
# Cache (capacity 100) of the classes generated by
# ``lightweight_named_tuple()``, keyed by ``(name,) + tuple(fields)``.
_lw_tuples = LRUCache(100)
def lightweight_named_tuple(name, fields):
    """Return a ``_LW`` subclass exposing ``fields`` as named properties.

    Classes are cached in ``_lw_tuples`` under ``(name,) + tuple(fields)``,
    so repeated calls with the same arguments return the same class.
    ``None`` entries in ``fields`` get no named property.
    """
    cache_key = (name,) + tuple(fields)
    cached = _lw_tuples.get(cache_key)
    if cached:
        return cached

    # One positional property per named field, plus empty __slots__.
    namespace = {
        field: _property_getters[idx]
        for idx, field in enumerate(fields)
        if field is not None
    }
    namespace["__slots__"] = ()
    tp_cls = type(name, (_LW,), namespace)

    # ``_real_fields`` keeps the None placeholders; ``_fields`` drops them.
    tp_cls._real_fields = fields
    tp_cls._fields = tuple([f for f in fields if f is not None])

    _lw_tuples[cache_key] = tp_cls
    return tp_cls
class ScopedRegistry(object):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    def __init__(self, createfunc, scopefunc):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.registry = {}
        self.scopefunc = scopefunc
        self.createfunc = createfunc

    def __call__(self):
        scope_key = self.scopefunc()
        try:
            return self.registry[scope_key]
        except KeyError:
            # setdefault() returns whatever ends up stored under the key,
            # which may be a value placed there in the meantime.
            return self.registry.setdefault(scope_key, self.createfunc())

    def has(self):
        """Return True if an object is present in the current scope."""
        current = self.scopefunc()
        return current in self.registry

    def set(self, obj):
        """Set the value for the current scope."""
        self.registry[self.scopefunc()] = obj

    def clear(self):
        """Clear the current scope, if any."""
        try:
            del self.registry[self.scopefunc()]
        except KeyError:
            # Nothing stored for this scope.
            pass
class ThreadLocalRegistry(ScopedRegistry):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    """

    def __init__(self, createfunc):
        self.createfunc = createfunc
        # The object is kept on the ``value`` attribute of this local.
        self.registry = threading.local()

    def __call__(self):
        try:
            return self.registry.value
        except AttributeError:
            # No ``value`` attribute yet: create one and remember it.
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self):
        return hasattr(self.registry, "value")

    def set(self, obj):
        self.registry.value = obj

    def clear(self):
        try:
            del self.registry.value
        except AttributeError:
            # Nothing stored.
            pass
def _iter_id(iterable): """Generator: ((id(o), o) for o in iterable)."""
for item in iterable: yield id(item), item
def has_dupes(sequence, target):
    """Given a sequence and search object, return True if there's more
    than one, False if zero or one of them.

    Matching is by identity (``is``), not equality.

    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed.  At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    matches = 0
    for candidate in sequence:
        if candidate is not target:
            continue
        matches += 1
        if matches > 1:
            # Second hit found; no need to scan the rest.
            return True
    return False
# .index version. the two __contains__ calls as well # as .index() and isinstance() slow this down. # def has_dupes(sequence, target): # if target not in sequence: # return False # elif not isinstance(sequence, collections_abc.Sequence): # return False # # idx = sequence.index(target) # return target in sequence[idx + 1:]
|