# nested.py (2.7 KB)
  1. """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
  2. def nested_compare(t, u):
  3. """
  4. Return whether nested structure of t1 and t2 matches.
  5. """
  6. if isinstance(t, (list, tuple)):
  7. if not isinstance(u, type(t)):
  8. return False
  9. if len(t) != len(u):
  10. return False
  11. for a, b in zip(t, u):
  12. if not nested_compare(a, b):
  13. return False
  14. return True
  15. if isinstance(t, dict):
  16. if not isinstance(u, dict):
  17. return False
  18. if set(t.keys()) != set(u.keys()):
  19. return False
  20. for k in t:
  21. if not nested_compare(t[k], u[k]):
  22. return False
  23. return True
  24. else:
  25. return True
  26. def nested_flatten(t):
  27. """
  28. Turn nested list/tuple/dict into a flat iterator.
  29. """
  30. if isinstance(t, (list, tuple)):
  31. for x in t:
  32. yield from nested_flatten(x)
  33. elif isinstance(t, dict):
  34. for k, v in sorted(t.items()):
  35. yield from nested_flatten(v)
  36. else:
  37. yield t
  38. def nested_pack(flat, structure):
  39. """
  40. Restore nested structure from flattened state
  41. :param flat: result of nested_flatten
  42. :param structure: used as example when recovering structure
  43. :returns: nested structure like :structure: filled with elements of :flat:
  44. """
  45. return _nested_pack(iter(flat), structure)
  46. def _nested_pack(flat_iter, structure):
  47. if is_namedtuple(structure):
  48. return type(structure)(*[
  49. _nested_pack(flat_iter, x)
  50. for x in structure]
  51. )
  52. elif isinstance(structure, (list, tuple)):
  53. return type(structure)(
  54. _nested_pack(flat_iter, x)
  55. for x in structure
  56. )
  57. elif isinstance(structure, dict):
  58. return {
  59. k: _nested_pack(flat_iter, v)
  60. for k, v in sorted(structure.items())
  61. }
  62. else:
  63. return next(flat_iter)
  64. def is_namedtuple(x):
  65. """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
  66. t = type(x)
  67. b = t.__bases__
  68. if len(b) != 1 or b[0] != tuple: return False
  69. f = getattr(t, '_fields', None)
  70. if not isinstance(f, tuple): return False
  71. return all(type(n) == str for n in f)
  72. def nested_map(fn, *t):
  73. # Check arguments.
  74. if not t:
  75. raise ValueError('Expected 2+ arguments, got 1')
  76. for i in range(1, len(t)):
  77. if not nested_compare(t[0], t[i]):
  78. msg = 'Nested structure of %r and %r differs'
  79. raise ValueError(msg % (t[0], t[i]))
  80. # Map.
  81. flat = map(nested_flatten, t)
  82. return nested_pack(map(fn, *flat), t[0])