structs.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import itertools
  2. from .compat import collections_abc
  class DirectedGraph(object):
      """A graph structure with directed edges.

      Vertices are kept in a set; every vertex also owns one entry in a
      forward adjacency map (vertex -> children) and one entry in a backward
      adjacency map (vertex -> parents).  Keys must therefore be hashable.
      """

      def __init__(self):
          self._vertices = set()
          self._forwards = {}  # <key> -> Set[<key>]
          self._backwards = {}  # <key> -> Set[<key>]

      def __iter__(self):
          """Iterate over the vertices of the graph."""
          return iter(self._vertices)

      def __len__(self):
          """Return the number of vertices in the graph."""
          return len(self._vertices)

      def __contains__(self, key):
          """Return whether `key` is a vertex of the graph."""
          return key in self._vertices

      def copy(self):
          """Return a shallow copy of this graph.

          The vertex set and every adjacency set are duplicated, so edits to
          the copy do not affect this graph; the keys themselves are shared.
          """
          other = DirectedGraph()
          other._vertices = set(self._vertices)
          other._forwards = {k: set(v) for k, v in self._forwards.items()}
          other._backwards = {k: set(v) for k, v in self._backwards.items()}
          return other

      def add(self, key):
          """Add a new vertex to the graph.

          Raises `ValueError` if the vertex is already present.
          """
          if key in self._vertices:
              raise ValueError("vertex exists")
          self._vertices.add(key)
          self._forwards[key] = set()
          self._backwards[key] = set()

      def remove(self, key):
          """Remove a vertex from the graph, disconnecting all edges from/to it.

          Raises `KeyError` if `key` is not a vertex.
          """
          self._vertices.remove(key)
          # Drop `key` from the parent set of each of its children ...
          for f in self._forwards.pop(key):
              self._backwards[f].remove(key)
          # ... and from the child set of each of its parents.
          for t in self._backwards.pop(key):
              self._forwards[t].remove(key)

      def connected(self, f, t):
          """Return whether there is an edge going from `f` to `t`.

          Raises `KeyError` if `t` is not a vertex.
          """
          return f in self._backwards[t] and t in self._forwards[f]

      def connect(self, f, t):
          """Connect two existing vertices.

          Nothing happens if the vertices are already connected.

          Raises `KeyError` if either `f` or `t` is not a vertex.
          """
          # Check `t` up front so that a bad target cannot leave a dangling
          # forward edge behind; a bad `f` fails on the lookup right below.
          if t not in self._vertices:
              raise KeyError(t)
          self._forwards[f].add(t)
          self._backwards[t].add(f)

      def iter_edges(self):
          """Yield every edge of the graph as a `(from, to)` pair."""
          for f, children in self._forwards.items():
              for t in children:
                  yield f, t

      def iter_children(self, key):
          """Iterate over the vertices `key` has an edge to."""
          return iter(self._forwards[key])

      def iter_parents(self, key):
          """Iterate over the vertices that have an edge to `key`."""
          return iter(self._backwards[key])
  class IteratorMapping(collections_abc.Mapping):
      """Read-only mapping whose values are produced as iterators.

      Looking up a key found in `mapping` yields `accessor(value)` followed
      by any extra items registered for that key in `appends`.  A key found
      only in `appends` yields just those extra items.
      """

      def __init__(self, mapping, accessor, appends=None):
          self._mapping = mapping
          self._accessor = accessor
          self._appends = appends or {}

      def __repr__(self):
          return "IteratorMapping({!r}, {!r}, {!r})".format(
              self._mapping,
              self._accessor,
              self._appends,
          )

      def __bool__(self):
          """Return whether either underlying mapping has any entry."""
          return bool(self._mapping or self._appends)

      __nonzero__ = __bool__  # XXX: Python 2.

      def __contains__(self, key):
          return key in self._mapping or key in self._appends

      def __getitem__(self, k):
          """Return an iterator over the items for `k`.

          Raises `KeyError` if `k` is in neither underlying mapping.
          """
          try:
              v = self._mapping[k]
          except KeyError:
              # Not a base key: only the appended items (if any) apply.
              return iter(self._appends[k])
          return itertools.chain(self._accessor(v), self._appends.get(k, ()))

      def __iter__(self):
          """Iterate over base keys first, then append-only keys."""
          more = (k for k in self._appends if k not in self._mapping)
          return itertools.chain(self._mapping, more)

      def __len__(self):
          """Return the number of distinct keys across both mappings."""
          more = sum(1 for k in self._appends if k not in self._mapping)
          return len(self._mapping) + more
  class _FactoryIterableView(object):
      """Wrap an iterator factory returned by `find_matches()`.

      Calling `iter()` on this class would invoke the underlying iterator
      factory, making it a "collection with ordering" that can be iterated
      through multiple times, but lacks random access methods presented in
      built-in Python sequence types.
      """

      def __init__(self, factory):
          self._factory = factory
          self._iterable = None  # Replayable source; set on first iteration.

      def __repr__(self):
          return "{}({})".format(type(self).__name__, list(self))

      def __bool__(self):
          """Return whether the view produces at least one item."""
          try:
              next(iter(self))
          except StopIteration:
              return False
          return True

      __nonzero__ = __bool__  # XXX: Python 2.

      def __iter__(self):
          """Return a fresh iterator over the items.

          The factory is only called while no iterable has been stored yet.
          Each call splits the current source with `itertools.tee`: one half
          is kept for later iterations, the other is handed to the caller.
          """
          iterable = (
              self._factory() if self._iterable is None else self._iterable
          )
          self._iterable, current = itertools.tee(iterable)
          return current
  class _SequenceIterableView(object):
      """Wrap an iterable returned by find_matches().

      This is essentially just a proxy to the underlying sequence that provides
      the same interface as `_FactoryIterableView`.
      """

      def __init__(self, sequence):
          self._sequence = sequence

      def __repr__(self):
          return "{}({})".format(type(self).__name__, self._sequence)

      def __bool__(self):
          """Return the truth value of the wrapped sequence."""
          return bool(self._sequence)

      __nonzero__ = __bool__  # XXX: Python 2.

      def __iter__(self):
          """Return a new iterator over the wrapped sequence."""
          return iter(self._sequence)
  def build_iter_view(matches):
      """Build an iterable view from the value returned by `find_matches()`.

      A callable is treated as an iterator factory and wrapped in a
      `_FactoryIterableView`.  Anything else is wrapped in a
      `_SequenceIterableView`, after being copied into a list when it is
      not already a sequence.
      """
      if callable(matches):
          return _FactoryIterableView(matches)
      if not isinstance(matches, collections_abc.Sequence):
          # A non-sequence iterable may be single-pass; store its items in a
          # list so the view can be iterated repeatedly.
          matches = list(matches)
      return _SequenceIterableView(matches)