context.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. import subprocess
  3. import contextlib
  4. import functools
  5. import tempfile
  6. import shutil
  7. import operator
  8. @contextlib.contextmanager
  9. def pushd(dir):
  10. orig = os.getcwd()
  11. os.chdir(dir)
  12. try:
  13. yield dir
  14. finally:
  15. os.chdir(orig)
  16. @contextlib.contextmanager
  17. def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
  18. """
  19. Get a tarball, extract it, change to that directory, yield, then
  20. clean up.
  21. `runner` is the function to invoke commands.
  22. `pushd` is a context manager for changing the directory.
  23. """
  24. if target_dir is None:
  25. target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
  26. if runner is None:
  27. runner = functools.partial(subprocess.check_call, shell=True)
  28. # In the tar command, use --strip-components=1 to strip the first path and
  29. # then
  30. # use -C to cause the files to be extracted to {target_dir}. This ensures
  31. # that we always know where the files were extracted.
  32. runner('mkdir {target_dir}'.format(**vars()))
  33. try:
  34. getter = 'wget {url} -O -'
  35. extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
  36. cmd = ' | '.join((getter, extract))
  37. runner(cmd.format(compression=infer_compression(url), **vars()))
  38. with pushd(target_dir):
  39. yield target_dir
  40. finally:
  41. runner('rm -Rf {target_dir}'.format(**vars()))
  42. def infer_compression(url):
  43. """
  44. Given a URL or filename, infer the compression code for tar.
  45. """
  46. # cheat and just assume it's the last two characters
  47. compression_indicator = url[-2:]
  48. mapping = dict(gz='z', bz='j', xz='J')
  49. # Assume 'z' (gzip) if no match
  50. return mapping.get(compression_indicator, 'z')
  51. @contextlib.contextmanager
  52. def temp_dir(remover=shutil.rmtree):
  53. """
  54. Create a temporary directory context. Pass a custom remover
  55. to override the removal behavior.
  56. """
  57. temp_dir = tempfile.mkdtemp()
  58. try:
  59. yield temp_dir
  60. finally:
  61. remover(temp_dir)
  62. @contextlib.contextmanager
  63. def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
  64. """
  65. Check out the repo indicated by url.
  66. If dest_ctx is supplied, it should be a context manager
  67. to yield the target directory for the check out.
  68. """
  69. exe = 'git' if 'git' in url else 'hg'
  70. with dest_ctx() as repo_dir:
  71. cmd = [exe, 'clone', url, repo_dir]
  72. if branch:
  73. cmd.extend(['--branch', branch])
  74. devnull = open(os.path.devnull, 'w')
  75. stdout = devnull if quiet else None
  76. subprocess.check_call(cmd, stdout=stdout)
  77. yield repo_dir
  78. @contextlib.contextmanager
  79. def null():
  80. yield
  81. class ExceptionTrap:
  82. """
  83. A context manager that will catch certain exceptions and provide an
  84. indication they occurred.
  85. >>> with ExceptionTrap() as trap:
  86. ... raise Exception()
  87. >>> bool(trap)
  88. True
  89. >>> with ExceptionTrap() as trap:
  90. ... pass
  91. >>> bool(trap)
  92. False
  93. >>> with ExceptionTrap(ValueError) as trap:
  94. ... raise ValueError("1 + 1 is not 3")
  95. >>> bool(trap)
  96. True
  97. >>> with ExceptionTrap(ValueError) as trap:
  98. ... raise Exception()
  99. Traceback (most recent call last):
  100. ...
  101. Exception
  102. >>> bool(trap)
  103. False
  104. """
  105. exc_info = None, None, None
  106. def __init__(self, exceptions=(Exception,)):
  107. self.exceptions = exceptions
  108. def __enter__(self):
  109. return self
  110. @property
  111. def type(self):
  112. return self.exc_info[0]
  113. @property
  114. def value(self):
  115. return self.exc_info[1]
  116. @property
  117. def tb(self):
  118. return self.exc_info[2]
  119. def __exit__(self, *exc_info):
  120. type = exc_info[0]
  121. matches = type and issubclass(type, self.exceptions)
  122. if matches:
  123. self.exc_info = exc_info
  124. return matches
  125. def __bool__(self):
  126. return bool(self.type)
  127. def raises(self, func, *, _test=bool):
  128. """
  129. Wrap func and replace the result with the truth
  130. value of the trap (True if an exception occurred).
  131. First, give the decorator an alias to support Python 3.8
  132. Syntax.
  133. >>> raises = ExceptionTrap(ValueError).raises
  134. Now decorate a function that always fails.
  135. >>> @raises
  136. ... def fail():
  137. ... raise ValueError('failed')
  138. >>> fail()
  139. True
  140. """
  141. @functools.wraps(func)
  142. def wrapper(*args, **kwargs):
  143. with ExceptionTrap(self.exceptions) as trap:
  144. func(*args, **kwargs)
  145. return _test(trap)
  146. return wrapper
  147. def passes(self, func):
  148. """
  149. Wrap func and replace the result with the truth
  150. value of the trap (True if no exception).
  151. First, give the decorator an alias to support Python 3.8
  152. Syntax.
  153. >>> passes = ExceptionTrap(ValueError).passes
  154. Now decorate a function that always fails.
  155. >>> @passes
  156. ... def fail():
  157. ... raise ValueError('failed')
  158. >>> fail()
  159. False
  160. """
  161. return self.raises(func, _test=operator.not_)
class suppress(contextlib.suppress, contextlib.ContextDecorator):
    """
    A version of contextlib.suppress with decorator support.

    Combines the exception suppression of ``contextlib.suppress`` with
    ``contextlib.ContextDecorator``. An instance can therefore be used
    either in a ``with`` statement or as a function decorator. In the
    example below, the KeyError raised inside the decorated function is
    swallowed.

    >>> @suppress(KeyError)
    ... def key_error():
    ...     {}['']
    >>> key_error()
    """