扫码登录,获取 cookies (scan the QR code to log in and retrieve cookies)

This commit is contained in:
2026-03-09 16:10:29 +08:00
parent 754e720ba7
commit 8229208165
7775 changed files with 1150053 additions and 208 deletions

View File

@@ -0,0 +1,56 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Hypothesis is a library for writing unit tests which are parametrized by
some source of data.
It verifies your code against a wide range of input and minimizes any
failing examples it finds.
"""
from hypothesis._settings import HealthCheck, Phase, Verbosity, settings
from hypothesis.control import (
assume,
currently_in_test_context,
event,
note,
reject,
target,
)
from hypothesis.core import example, find, given, reproduce_failure, seed
from hypothesis.entry_points import run
from hypothesis.internal.entropy import register_random
from hypothesis.utils.conventions import infer
from hypothesis.version import __version__, __version_info__
# Names re-exported as the public API of the top-level ``hypothesis`` package.
__all__ = [
    "HealthCheck",
    "Phase",
    "Verbosity",
    "assume",
    "currently_in_test_context",
    "event",
    "example",
    "find",
    "given",
    "infer",
    "note",
    "register_random",
    "reject",
    "reproduce_failure",
    "seed",
    "settings",
    "target",
    "__version__",
    "__version_info__",
]

# Execute any third-party setup hooks registered via entry points, then drop
# the helper name so it is not exposed as part of the public API.
run()
del run

View File

@@ -0,0 +1,738 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""The settings module configures runtime options for Hypothesis.
Either an explicit settings object can be used or the default object on
this module can be modified.
"""
import contextlib
import datetime
import inspect
import os
import warnings
from enum import Enum, EnumMeta, IntEnum, unique
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Collection,
Dict,
List,
Optional,
TypeVar,
Union,
)
import attr
from hypothesis.errors import (
HypothesisDeprecationWarning,
InvalidArgument,
InvalidState,
)
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.internal.validation import check_type, try_convert
from hypothesis.utils.conventions import not_set
from hypothesis.utils.dynamicvariables import DynamicVariable
if TYPE_CHECKING:
from hypothesis.database import ExampleDatabase
__all__ = ["settings"]
all_settings: Dict[str, "Setting"] = {}
T = TypeVar("T")
class settingsProperty:
    """Data descriptor for a single named setting on the ``settings`` class.

    Instances are created by ``settings._define_setting``; values are read
    from the owning object's ``__dict__``, and the docstring is built lazily
    from the registered ``Setting``'s description and current default.
    """

    def __init__(self, name, show_default):
        # name: the setting's identifier; show_default: whether the real
        # default value appears in the generated docstring.
        self.name = name
        self.show_default = show_default

    def __get__(self, obj, type=None):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        else:
            try:
                result = obj.__dict__[self.name]
                # This is a gross hack, but it preserves the old behaviour that
                # you can change the storage directory and it will be reflected
                # in the default database.
                if self.name == "database" and result is not_set:
                    from hypothesis.database import ExampleDatabase

                    result = ExampleDatabase(not_set)
                return result
            except KeyError:
                raise AttributeError(self.name) from None

    def __set__(self, obj, value):
        obj.__dict__[self.name] = value

    def __delete__(self, obj):
        raise AttributeError(f"Cannot delete attribute {self.name}")

    @property
    def __doc__(self):
        # Build the docstring on demand so it reflects the currently loaded
        # profile's default value.
        description = all_settings[self.name].description
        default = (
            repr(getattr(settings.default, self.name))
            if self.show_default
            else "(dynamically calculated)"
        )
        return f"{description}\n\ndefault value: ``{default}``"
default_variable = DynamicVariable(None)
class settingsMeta(type):
    """Metaclass for ``settings``: provides the class-level ``default``
    property and protects the class against accidental attribute assignment."""

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default(cls):
        # The dynamically scoped current default; may be overridden
        # temporarily via local_settings().
        v = default_variable.value
        if v is not None:
            return v
        # No default installed yet: (re)load the active profile lazily.
        if hasattr(settings, "_current_profile"):
            settings.load_profile(settings._current_profile)
            assert default_variable.value is not None
        return default_variable.value

    def _assign_default_internal(cls, value):
        # Internal hook used by load_profile and local_settings.
        default_variable.value = value

    def __setattr__(cls, name, value):
        if name == "default":
            raise AttributeError(
                "Cannot assign to the property settings.default - "
                "consider using settings.load_profile instead."
            )
        elif not (isinstance(value, settingsProperty) or name.startswith("_")):
            # Only settingsProperty descriptors (added by _define_setting) and
            # private attributes may be set on the class.
            raise AttributeError(
                f"Cannot assign hypothesis.settings.{name}={value!r} - the settings "
                "class is immutable. You can change the global default "
                "settings with settings.load_profile, or use @settings(...) "
                "to decorate your test instead."
            )
        return super().__setattr__(name, value)
class settings(metaclass=settingsMeta):
    """A settings object configures options including verbosity, runtime controls,
    persistence, determinism, and more.

    Default values are picked up from the settings.default object and
    changes made there will be picked up in newly created settings.
    """

    # Once True, _define_setting refuses new definitions (set at module end).
    __definitions_are_locked = False
    # Registry of named profiles created via register_profile().
    _profiles: ClassVar[Dict[str, "settings"]] = {}
    __module__ = "hypothesis"

    def __getattr__(self, name):
        # Only reached when the instance dict lacks the attribute: fall back
        # to the library default for any defined setting.
        if name in all_settings:
            return all_settings[name].default
        else:
            raise AttributeError(f"settings has no attribute {name}")

    def __init__(
        self,
        parent: Optional["settings"] = None,
        *,
        # This looks pretty strange, but there's good reason: we want Mypy to detect
        # bad calls downstream, but not to freak out about the `= not_set` part even
        # though it's not semantically valid to pass that as an argument value.
        # The intended use is "like **kwargs, but more tractable for tooling".
        max_examples: int = not_set,  # type: ignore
        derandomize: bool = not_set,  # type: ignore
        database: Optional["ExampleDatabase"] = not_set,  # type: ignore
        verbosity: "Verbosity" = not_set,  # type: ignore
        phases: Collection["Phase"] = not_set,  # type: ignore
        stateful_step_count: int = not_set,  # type: ignore
        report_multiple_bugs: bool = not_set,  # type: ignore
        suppress_health_check: Collection["HealthCheck"] = not_set,  # type: ignore
        deadline: Union[int, float, datetime.timedelta, None] = not_set,  # type: ignore
        print_blob: bool = not_set,  # type: ignore
    ) -> None:
        if parent is not None:
            check_type(settings, parent, "parent")
        if derandomize not in (not_set, False):
            # Derandomized runs must not use a database, since replaying the
            # same deterministic examples makes stored failures redundant.
            if database not in (not_set, None):  # type: ignore
                raise InvalidArgument(
                    "derandomize=True implies database=None, so passing "
                    f"{database=} too is invalid."
                )
            database = None
        defaults = parent or settings.default
        if defaults is not None:
            for setting in all_settings.values():
                # Each keyword parameter above is named after its setting, so
                # locals() lets us look the supplied value up generically.
                value = locals()[setting.name]
                if value is not_set:
                    object.__setattr__(
                        self, setting.name, getattr(defaults, setting.name)
                    )
                else:
                    # object.__setattr__ bypasses our immutability guard.
                    object.__setattr__(self, setting.name, setting.validator(value))

    def __call__(self, test: T) -> T:
        """Make the settings object (self) an attribute of the test.

        The settings are later discovered by looking them up on the test itself.
        """
        # Aliasing as Any avoids mypy errors (attr-defined) when accessing and
        # setting custom attributes on the decorated function or class.
        _test: Any = test
        # Using the alias here avoids a mypy error (return-value) later when
        # ``test`` is returned, because this check results in type refinement.
        if not callable(_test):
            raise InvalidArgument(
                "settings objects can be called as a decorator with @given, "
                f"but decorated {test=} is not callable."
            )
        if inspect.isclass(test):
            from hypothesis.stateful import RuleBasedStateMachine

            if issubclass(_test, RuleBasedStateMachine):
                attr_name = "_hypothesis_internal_settings_applied"
                if getattr(test, attr_name, False):
                    raise InvalidArgument(
                        "Applying the @settings decorator twice would "
                        "overwrite the first version; merge their arguments "
                        "instead."
                    )
                setattr(test, attr_name, True)
                # State machines run via their generated TestCase, which is
                # where @given will look the settings up.
                _test.TestCase.settings = self
                return test  # type: ignore
            else:
                raise InvalidArgument(
                    "@settings(...) can only be used as a decorator on "
                    "functions, or on subclasses of RuleBasedStateMachine."
                )
        if hasattr(_test, "_hypothesis_internal_settings_applied"):
            # Can't use _hypothesis_internal_use_settings as an indicator that
            # @settings was applied, because @given also assigns that attribute.
            descr = get_pretty_function_description(test)
            raise InvalidArgument(
                f"{descr} has already been decorated with a settings object.\n"
                f" Previous: {_test._hypothesis_internal_use_settings!r}\n"
                f" This: {self!r}"
            )
        _test._hypothesis_internal_use_settings = self
        _test._hypothesis_internal_settings_applied = True
        return test

    @classmethod
    def _define_setting(
        cls,
        name,
        description,
        *,
        default,
        options=None,
        validator=None,
        show_default=True,
    ):
        """Add a new setting.

        - name is the name of the property that will be used to access the
          setting. This must be a valid python identifier.
        - description will appear in the property's docstring
        - default is the default value. This may be a zero argument
          function in which case it is evaluated and its result is stored
          the first time it is accessed on any given settings object.
        """
        if settings.__definitions_are_locked:
            raise InvalidState(
                "settings have been locked and may no longer be defined."
            )
        if options is not None:
            options = tuple(options)
            assert default in options

            # A closed option set implies a simple membership validator; it
            # deliberately shadows the (unused) validator parameter.
            def validator(value):
                if value not in options:
                    msg = f"Invalid {name}, {value!r}. Valid options: {options!r}"
                    raise InvalidArgument(msg)
                return value

        else:
            assert validator is not None
        all_settings[name] = Setting(
            name=name,
            description=description.strip(),
            default=default,
            validator=validator,
        )
        setattr(settings, name, settingsProperty(name, show_default))

    @classmethod
    def lock_further_definitions(cls):
        # Called once at the bottom of this module, after all settings exist.
        settings.__definitions_are_locked = True

    def __setattr__(self, name, value):
        # Instances are immutable; __init__ assigns via object.__setattr__.
        raise AttributeError("settings objects are immutable")

    def __repr__(self):
        bits = sorted(f"{name}={getattr(self, name)!r}" for name in all_settings)
        return "settings({})".format(", ".join(bits))

    def show_changed(self):
        # Summarise only the settings that differ from the library defaults,
        # shortest descriptions first.
        bits = []
        for name, setting in all_settings.items():
            value = getattr(self, name)
            if value != setting.default:
                bits.append(f"{name}={value!r}")
        return ", ".join(sorted(bits, key=len))

    @staticmethod
    def register_profile(
        name: str,
        parent: Optional["settings"] = None,
        **kwargs: Any,
    ) -> None:
        """Registers a collection of values to be used as a settings profile.

        Settings profiles can be loaded by name - for example, you might
        create a 'fast' profile which runs fewer examples, keep the 'default'
        profile, and create a 'ci' profile that increases the number of
        examples and uses a different database to store failures.

        The arguments to this method are exactly as for
        :class:`~hypothesis.settings`: optional ``parent`` settings, and
        keyword arguments for each setting that will be set differently to
        parent (or settings.default, if parent is None).
        """
        check_type(str, name, "name")
        settings._profiles[name] = settings(parent=parent, **kwargs)

    @staticmethod
    def get_profile(name: str) -> "settings":
        """Return the profile with the given name."""
        check_type(str, name, "name")
        try:
            return settings._profiles[name]
        except KeyError:
            raise InvalidArgument(f"Profile {name!r} is not registered") from None

    @staticmethod
    def load_profile(name: str) -> None:
        """Loads in the settings defined in the profile provided.

        If the profile does not exist, InvalidArgument will be raised.
        Any setting not defined in the profile will be the library
        defined default for that setting.
        """
        check_type(str, name, "name")
        settings._current_profile = name
        settings._assign_default_internal(settings.get_profile(name))
@contextlib.contextmanager
def local_settings(s):
    # Temporarily install ``s`` as settings.default for the dynamic extent of
    # the with-block; the previous default is restored on exit.
    with default_variable.with_value(s):
        yield s
@attr.s()
class Setting:
    # Registry record for one defined setting: its identifier, the description
    # used for generated docstrings, the library default, and the validator
    # applied to user-supplied values.
    name = attr.ib()
    description = attr.ib()
    default = attr.ib()
    validator = attr.ib()
def _max_examples_validator(x):
    """Validate the ``max_examples`` setting: an int of at least one."""
    check_type(int, x, name="max_examples")
    if x >= 1:
        return x
    raise InvalidArgument(
        f"max_examples={x!r} should be at least one. You can disable "
        "example generation with the `phases` setting instead."
    )
settings._define_setting(
"max_examples",
default=100,
validator=_max_examples_validator,
description="""
Once this many satisfying examples have been considered without finding any
counter-example, Hypothesis will stop looking.
Note that we might call your test function fewer times if we find a bug early
or can tell that we've exhausted the search space; or more if we discard some
examples due to use of .filter(), assume(), or a few other things that can
prevent the test case from completing successfully.
The default value is chosen to suit a workflow where the test will be part of
a suite that is regularly executed locally or on a CI server, balancing total
running time against the chance of missing a bug.
If you are writing one-off tests, running tens of thousands of examples is
quite reasonable as Hypothesis may miss uncommon bugs with default settings.
For very complex code, we have observed Hypothesis finding novel bugs after
*several million* examples while testing :pypi:`SymPy`.
If you are running more than 100k examples for a test, consider using our
:ref:`integration for coverage-guided fuzzing <fuzz_one_input>` - it really
shines when given minutes or hours to run.
""",
)
settings._define_setting(
"derandomize",
default=False,
options=(True, False),
description="""
If True, seed Hypothesis' random number generator using a hash of the test
function, so that every run will test the same set of examples until you
update Hypothesis, Python, or the test function.
This allows you to `check for regressions and look for bugs
<https://blog.nelhage.com/post/two-kinds-of-testing/>`__ using
:ref:`separate settings profiles <settings_profiles>` - for example running
quick deterministic tests on every commit, and a longer non-deterministic
nightly testing run.
""",
)
def _validate_database(db):
    """Validate the ``database`` setting: ``None`` or an ExampleDatabase."""
    from hypothesis.database import ExampleDatabase

    if db is not None and not isinstance(db, ExampleDatabase):
        raise InvalidArgument(
            "Arguments to the database setting must be None or an instance of "
            f"ExampleDatabase. Try passing database=ExampleDatabase({db!r}), or "
            "construct and use one of the specific subclasses in "
            "hypothesis.database"
        )
    return db
settings._define_setting(
"database",
default=not_set,
show_default=False,
description="""
An instance of :class:`~hypothesis.database.ExampleDatabase` that will be
used to save examples to and load previous examples from. May be ``None``
in which case no storage will be used.
See the :doc:`example database documentation <database>` for a list of built-in
example database implementations, and how to define custom implementations.
""",
validator=_validate_database,
)
@unique
class Phase(IntEnum):
    """The phases of a Hypothesis run, in execution order."""

    explicit = 0  #: controls whether explicit examples are run.
    reuse = 1  #: controls whether previous examples will be reused.
    generate = 2  #: controls whether new examples will be generated.
    target = 3  #: controls whether examples will be mutated for targeting.
    shrink = 4  #: controls whether examples will be shrunk.
    explain = 5  #: controls whether Hypothesis attempts to explain test failures.

    def __repr__(self):
        return "Phase." + self.name
class HealthCheckMeta(EnumMeta):
    """Enum metaclass that hides deprecated members when iterating."""

    def __iter__(self):
        hidden = (HealthCheck.return_value, HealthCheck.not_a_test_method)
        for member in super().__iter__():
            if member not in hidden:
                yield member
@unique
class HealthCheck(Enum, metaclass=HealthCheckMeta):
    """Arguments for :attr:`~hypothesis.settings.suppress_health_check`.

    Each member of this enum is a type of health check to suppress.
    """

    def __repr__(self):
        return f"{self.__class__.__name__}.{self.name}"

    @classmethod
    def all(cls) -> List["HealthCheck"]:
        # Skipping of deprecated attributes is handled in HealthCheckMeta.__iter__
        note_deprecation(
            "`Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.",
            since="2023-04-16",
            has_codemod=True,
            stacklevel=1,
        )
        return list(HealthCheck)

    data_too_large = 1
    """Checks if too many examples are aborted for being too large.

    This is measured by the number of random choices that Hypothesis makes
    in order to generate something, not the size of the generated object.
    For example, choosing a 100MB object from a predefined list would take
    only a few bits, while generating 10KB of JSON from scratch might trigger
    this health check.
    """

    filter_too_much = 2
    """Check for when the test is filtering out too many examples, either
    through use of :func:`~hypothesis.assume()` or :ref:`filter() <filtering>`,
    or occasionally for Hypothesis internal reasons."""

    too_slow = 3
    """Check for when your data generation is extremely slow and likely to hurt
    testing."""

    return_value = 5
    """Deprecated; we always error if a test returns a non-None value."""

    large_base_example = 7
    """Checks if the natural example to shrink towards is very large."""

    not_a_test_method = 8
    """Deprecated; we always error if :func:`@given <hypothesis.given>` is applied
    to a method defined by :class:`python:unittest.TestCase` (i.e. not a test)."""

    function_scoped_fixture = 9
    """Checks if :func:`@given <hypothesis.given>` has been applied to a test
    with a pytest function-scoped fixture. Function-scoped fixtures run once
    for the whole function, not once per example, and this is usually not what
    you want.

    Because of this limitation, tests that need to set up or reset
    state for every example need to do so manually within the test itself,
    typically using an appropriate context manager.

    Suppress this health check only in the rare case that you are using a
    function-scoped fixture that does not need to be reset between individual
    examples, but for some reason you cannot use a wider fixture scope
    (e.g. session scope, module scope, class scope).

    This check requires the :ref:`Hypothesis pytest plugin<pytest-plugin>`,
    which is enabled by default when running Hypothesis inside pytest."""

    differing_executors = 10
    """Checks if :func:`@given <hypothesis.given>` has been applied to a test
    which is executed by different :ref:`executors<custom-function-execution>`.
    If your test function is defined as a method on a class, that class will be
    your executor, and subclasses executing an inherited test is a common way
    for things to go wrong.

    The correct fix is often to bring the executor instance under the control
    of hypothesis by explicit parametrization over, or sampling from,
    subclasses, or to refactor so that :func:`@given <hypothesis.given>` is
    specified on leaf subclasses."""
@unique
class Verbosity(IntEnum):
    """How much output Hypothesis prints while running tests."""

    quiet = 0
    normal = 1
    verbose = 2
    debug = 3

    def __repr__(self):
        return "Verbosity." + self.name
settings._define_setting(
"verbosity",
options=tuple(Verbosity),
default=Verbosity.normal,
description="Control the verbosity level of Hypothesis messages",
)
def _validate_phases(phases):
    """Validate ``phases``: each element must be a Phase member.

    The result is returned in canonical Phase order, regardless of input order.
    """
    wanted = tuple(phases)
    for a in wanted:
        if not isinstance(a, Phase):
            raise InvalidArgument(f"{a!r} is not a valid phase")
    return tuple(p for p in Phase if p in wanted)
settings._define_setting(
"phases",
default=tuple(Phase),
description=(
"Control which phases should be run. "
"See :ref:`the full documentation for more details <phases>`"
),
validator=_validate_phases,
)
def _validate_stateful_step_count(x):
    """Validate the ``stateful_step_count`` setting: an int of at least one."""
    check_type(int, x, name="stateful_step_count")
    if x >= 1:
        return x
    raise InvalidArgument(f"stateful_step_count={x!r} must be at least one.")
settings._define_setting(
name="stateful_step_count",
default=50,
validator=_validate_stateful_step_count,
description="""
Number of steps to run a stateful program for before giving up on it breaking.
""",
)
settings._define_setting(
name="report_multiple_bugs",
default=True,
options=(True, False),
description="""
Because Hypothesis runs the test many times, it can sometimes find multiple
bugs in a single run. Reporting all of them at once is usually very useful,
but replacing the exceptions can occasionally clash with debuggers.
If disabled, only the exception with the smallest minimal example is raised.
""",
)
def validate_health_check_suppressions(suppressions):
    """Validate ``suppress_health_check``: a collection of HealthCheck members.

    A deprecation warning is emitted for members which are now always errors.
    """
    suppressions = try_convert(list, suppressions, "suppress_health_check")
    deprecated = (HealthCheck.return_value, HealthCheck.not_a_test_method)
    for s in suppressions:
        if not isinstance(s, HealthCheck):
            raise InvalidArgument(
                f"Non-HealthCheck value {s!r} of type {type(s).__name__} "
                "is invalid in suppress_health_check."
            )
        if s in deprecated:
            note_deprecation(
                f"The {s.name} health check is deprecated, because this is always an error.",
                since="2023-03-15",
                has_codemod=False,
                stacklevel=2,
            )
    return suppressions
settings._define_setting(
"suppress_health_check",
default=(),
description="""A list of :class:`~hypothesis.HealthCheck` items to disable.""",
validator=validate_health_check_suppressions,
)
class duration(datetime.timedelta):
    """A timedelta subclass whose repr reports milliseconds."""

    def __repr__(self):
        millis = self.total_seconds() * 1000
        if millis == int(millis):
            # Drop the trailing ".0" for whole-millisecond values.
            millis = int(millis)
        return f"timedelta(milliseconds={millis!r})"
def _validate_deadline(x):
    # Validate the ``deadline`` setting: None disables it; ints and floats are
    # interpreted as milliseconds; timedeltas must be strictly positive. The
    # value is normalised to our ``duration`` subclass for a nicer repr.
    if x is None:
        return x
    invalid_deadline_error = InvalidArgument(
        f"deadline={x!r} (type {type(x).__name__}) must be a timedelta object, "
        "an integer or float number of milliseconds, or None to disable the "
        "per-test-case deadline."
    )
    if isinstance(x, (int, float)):
        # bool is an int subclass, but True/False are not valid deadlines.
        if isinstance(x, bool):
            raise invalid_deadline_error
        try:
            x = duration(milliseconds=x)
        except OverflowError:
            raise InvalidArgument(
                f"deadline={x!r} is invalid, because it is too large to represent "
                "as a timedelta. Use deadline=None to disable deadlines."
            ) from None
    # Note: the numeric branch above converts into this one on purpose.
    if isinstance(x, datetime.timedelta):
        if x <= datetime.timedelta(0):
            raise InvalidArgument(
                f"deadline={x!r} is invalid, because it is impossible to meet a "
                "deadline <= 0. Use deadline=None to disable deadlines."
            )
        return duration(seconds=x.total_seconds())
    raise invalid_deadline_error
settings._define_setting(
"deadline",
default=duration(milliseconds=200),
validator=_validate_deadline,
description="""
If set, a duration (as timedelta, or integer or float number of milliseconds)
that each individual example (i.e. each time your test
function is called, not the whole decorated test) within a test is not
allowed to exceed. Tests which take longer than that may be converted into
errors (but will not necessarily be if close to the deadline, to allow some
variability in test run time).
Set this to ``None`` to disable this behaviour entirely.
""",
)
def is_in_ci() -> bool:
    """Best-effort detection of a CI environment via well-known env vars.

    GitHub Actions, Travis CI and AppVeyor set "CI"; Azure Pipelines sets
    "TF_BUILD".
    """
    environment = os.environ
    return any(marker in environment for marker in ("CI", "TF_BUILD"))
settings._define_setting(
"print_blob",
default=is_in_ci(),
show_default=False,
options=(True, False),
description="""
If set to ``True``, Hypothesis will print code for failing examples that can be used with
:func:`@reproduce_failure <hypothesis.reproduce_failure>` to reproduce the failing example.
The default is ``True`` if the ``CI`` or ``TF_BUILD`` env vars are set, ``False`` otherwise.
""",
)
settings.lock_further_definitions()
def note_deprecation(
    message: str, *, since: str, has_codemod: bool, stacklevel: int = 0
) -> None:
    """Emit a HypothesisDeprecationWarning for *message*.

    ``since`` is an ISO release date, or the "RELEASEDAY" placeholder before
    release; ``has_codemod`` appends a hint about the automatic codemod tool;
    ``stacklevel`` shifts the warning's attribution further up the stack.
    """
    if since != "RELEASEDAY":
        # Sanity-check that the date is well-formed and plausibly recent.
        assert datetime.date.fromisoformat(since) >= datetime.date(2021, 1, 1)
    if has_codemod:
        message += (
            "\n The `hypothesis codemod` command-line tool can automatically "
            "refactor your code to fix this warning."
        )
    warnings.warn(HypothesisDeprecationWarning(message), stacklevel=2 + stacklevel)
# Install the built-in "default" profile and make it the active one.
settings.register_profile("default", settings())
settings.load_profile("default")
assert settings.default is not None

# Check that the kwonly args to settings.__init__ is the same as the set of
# defined settings - in case we've added or remove something from one but
# not the other.
assert set(all_settings) == {
    p.name
    for p in inspect.signature(settings.__init__).parameters.values()
    if p.kind == inspect.Parameter.KEYWORD_ONLY
}

View File

@@ -0,0 +1,31 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import os
from pathlib import Path
# Fallback location for Hypothesis' on-disk storage, relative to the cwd at
# import time; and the (lazily resolved) currently configured home directory.
__hypothesis_home_directory_default = Path.cwd() / ".hypothesis"
__hypothesis_home_directory = None


def set_hypothesis_home_dir(directory):
    """Override the directory used for Hypothesis' on-disk storage.

    Passing ``None`` resets to the default resolution logic.
    """
    global __hypothesis_home_directory
    __hypothesis_home_directory = Path(directory) if directory is not None else None


def storage_directory(*names):
    """Return the storage path with *names* joined onto the home directory.

    Resolution order: an explicit set_hypothesis_home_dir() value, then the
    HYPOTHESIS_STORAGE_DIRECTORY environment variable, then ./.hypothesis.
    """
    global __hypothesis_home_directory
    if not __hypothesis_home_directory:
        env_dir = os.getenv("HYPOTHESIS_STORAGE_DIRECTORY")
        if env_dir:
            __hypothesis_home_directory = Path(env_dir)
        else:
            __hypothesis_home_directory = __hypothesis_home_directory_default
    return __hypothesis_home_directory.joinpath(*names)

View File

@@ -0,0 +1,262 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import math
from collections import defaultdict
from typing import NoReturn, Union
from weakref import WeakKeyDictionary
from hypothesis import Verbosity, settings
from hypothesis._settings import note_deprecation
from hypothesis.errors import InvalidArgument, UnsatisfiedAssumption
from hypothesis.internal.compat import BaseExceptionGroup
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.internal.validation import check_type
from hypothesis.reporting import report, verbose_report
from hypothesis.utils.dynamicvariables import DynamicVariable
from hypothesis.vendor.pretty import IDKey
def reject() -> NoReturn:
    """Reject the current example as invalid, exactly as if a call to
    :func:`~hypothesis.assume` had failed: the example is discarded rather
    than counted as a test failure."""
    if _current_build_context.value is None:
        note_deprecation(
            "Using `reject` outside a property-based test is deprecated",
            since="2023-09-25",
            has_codemod=False,
        )
    raise UnsatisfiedAssumption
def assume(condition: object) -> bool:
    """Calling ``assume`` is like an :ref:`assert <python:assert>` that marks
    the example as bad, rather than failing the test.

    This lets you state properties that you *assume* will be true, and
    Hypothesis will try to avoid similar examples in future.
    """
    if _current_build_context.value is None:
        note_deprecation(
            "Using `assume` outside a property-based test is deprecated",
            since="2023-09-25",
            has_codemod=False,
        )
    if condition:
        return True
    raise UnsatisfiedAssumption
_current_build_context = DynamicVariable(None)
def currently_in_test_context() -> bool:
    """Return ``True`` if the calling code is currently running inside an
    :func:`@given <hypothesis.given>` or :doc:`stateful <stateful>` test,
    ``False`` otherwise.

    Useful for third-party integrations and assertion helpers that may run
    under either traditional or property-based tests, but may only use
    :func:`~hypothesis.assume` or :func:`~hypothesis.target` in the latter.
    """
    context = _current_build_context.value
    return context is not None
def current_build_context() -> "BuildContext":
    """Return the active BuildContext, raising InvalidArgument if none."""
    ctx = _current_build_context.value
    if ctx is not None:
        return ctx
    raise InvalidArgument("No build context registered")
class BuildContext:
    """Per-test-case context: holds the active ConjectureData, registered
    cleanup tasks, and pretty-printer metadata for values drawn or recorded
    during the test case."""

    def __init__(self, data, *, is_final=False, close_on_capture=True):
        assert isinstance(data, ConjectureData)
        self.data = data
        # Teardown callables registered via cleanup(); run in __exit__.
        self.tasks = []
        self.is_final = is_final
        self.close_on_capture = close_on_capture
        self.close_on_del = False
        # Use defaultdict(list) here to handle the possibility of having multiple
        # functions registered for the same object (due to caching, small ints, etc).
        # The printer will discard duplicates which return different representations.
        self.known_object_printers = defaultdict(list)

    def record_call(self, obj, func, args, kwargs, arg_slices=None):
        # Remember how ``obj`` was produced, so the pretty-printer can show the
        # originating call instead of the object's (possibly opaque) repr.
        name = get_pretty_function_description(func)
        self.known_object_printers[IDKey(obj)].append(
            lambda obj, p, cycle: (
                p.text("<...>")
                if cycle
                else p.repr_call(name, args, kwargs, arg_slices=arg_slices)
            )
        )

    def prep_args_kwargs_from_strategies(self, arg_strategies, kwarg_strategies):
        # Draw every argument's value, recording the (start, end) choice-index
        # span of each draw for later which-parts-matter reporting.
        arg_labels = {}
        all_s = [(None, s) for s in arg_strategies] + list(kwarg_strategies.items())
        args = []
        kwargs = {}
        for i, (k, s) in enumerate(all_s):
            start_idx = self.data.index
            obj = self.data.draw(s)
            end_idx = self.data.index
            assert k is not None
            kwargs[k] = obj
            # This high up the stack, we can't see or really do much with the conjecture
            # Example objects - not least because they're only materialized after the
            # test case is completed. Instead, we'll stash the (start_idx, end_idx)
            # pair on our data object for the ConjectureRunner engine to deal with, and
            # pass a dict of such out so that the pretty-printer knows where to place
            # the which-parts-matter comments later.
            if start_idx != end_idx:
                arg_labels[k or i] = (start_idx, end_idx)
                self.data.arg_slices.add((start_idx, end_idx))
        return args, kwargs, arg_labels

    def __enter__(self):
        # Install self as the ambient build context for the with-block.
        self.assign_variable = _current_build_context.with_value(self)
        self.assign_variable.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.assign_variable.__exit__(exc_type, exc_value, tb)
        # Run all registered cleanup tasks, collecting (not swallowing) errors.
        errors = []
        for task in self.tasks:
            try:
                task()
            except BaseException as err:
                errors.append(err)
        if errors:
            if len(errors) == 1:
                raise errors[0] from exc_value
            raise BaseExceptionGroup("Cleanup failed", errors) from exc_value
def cleanup(teardown):
    """Register *teardown* to be called when the current test case finishes.

    Any exceptions thrown in teardown will be printed but not rethrown.
    Inside a test this isn't very interesting, because you can just use a
    finally block, but you can also use it inside map, flatmap, etc. to
    e.g. insist that a value is closed at the end.
    """
    ctx = _current_build_context.value
    if ctx is None:
        raise InvalidArgument("Cannot register cleanup outside of build context")
    ctx.tasks.append(teardown)
def should_note():
    """True when note() output should be printed: on the final (minimal)
    example, or whenever verbosity is at least ``verbose``."""
    ctx = _current_build_context.value
    if ctx is None:
        raise InvalidArgument("Cannot make notes outside of a test")
    if ctx.is_final:
        return True
    return settings.default.verbosity >= Verbosity.verbose
def note(value: str) -> None:
    """Report this value for the minimal failing example."""
    if not should_note():
        return
    report(value)
def event(value: str, payload: Union[str, int, float] = "") -> None:
    """Record an event that occurred during this test. Statistics on the number of test
    runs with each event will be reported at the end if you run Hypothesis in
    statistics reporting mode.

    Event values should be strings or convertible to them. If an optional
    payload is given, it will be included in the string for :ref:`statistics`.
    """
    context = _current_build_context.value
    if context is None:
        # Fixed wording: previously read "Cannot make record events ...".
        raise InvalidArgument("Cannot record events outside of a test")
    # Coerce the payload (and the event name) to strings before storing.
    payload = _event_to_string(payload, (str, int, float))
    context.data.events[_event_to_string(value)] = payload
# Cache of str() conversions, keyed weakly so we don't keep events alive.
_events_to_strings: WeakKeyDictionary = WeakKeyDictionary()


def _event_to_string(event, allowed_types=str):
    """Coerce *event* to a string (pass-through for ``allowed_types``),
    caching per-object results where the object supports weak references."""
    if isinstance(event, allowed_types):
        return event
    try:
        return _events_to_strings[event]
    except (KeyError, TypeError):
        # Not cached yet, or unhashable/unweakrefable - fall through.
        pass
    result = str(event)
    try:
        _events_to_strings[event] = result
    except TypeError:
        # Objects that can't be weakly referenced simply aren't cached.
        pass
    return result
def target(observation: Union[int, float], *, label: str = "") -> Union[int, float]:
    """Give Hypothesis a numeric score to maximise while searching for inputs.

    ``observation`` must be a finite ``int`` or ``float``.  Hypothesis will
    try to maximize the observed value over several examples, in addition to
    all the usual heuristics; almost any metric works so long as it makes
    sense to increase it.  For example, ``-abs(error)`` is a metric that
    increases as ``error`` approaches zero.

    Example metrics:

    - Number of elements in a collection, or tasks in a queue
    - Mean or maximum runtime of a task (or both, if you use ``label``)
    - Compression ratio for data (perhaps per-algorithm or per-level)
    - Number of steps taken by a state machine

    The optional ``label`` argument distinguishes - and therefore separately
    optimises - distinct observations, such as the mean and standard
    deviation of a dataset.  It is an error to call ``target()`` with any
    label more than once per test case.

    .. note::
        **The more examples you run, the better this technique works.**

        As a rule of thumb, the targeting effect is noticeable above
        :obj:`max_examples=1000 <hypothesis.settings.max_examples>`,
        and immediately obvious by around ten thousand examples
        *per label* used by your test.

    :ref:`statistics` include the best score seen for each label,
    which can help avoid `the threshold problem
    <https://hypothesis.works/articles/threshold-problem/>`__ when the minimal
    example shrinks right down to the threshold of failure (:issue:`2180`).
    """
    check_type((int, float), observation, "observation")
    if not math.isfinite(observation):
        raise InvalidArgument(f"{observation=} must be a finite float.")
    check_type(str, label, "label")

    context = _current_build_context.value
    if context is None:
        raise InvalidArgument(
            "Calling target() outside of a test is invalid. "
            "Consider guarding this call with `if currently_in_test_context(): ...`"
        )
    verbose_report(f"Saw target({observation!r}, {label=})")

    observations = context.data.target_observations
    if label in observations:
        raise InvalidArgument(
            f"Calling target({observation!r}, {label=}) would overwrite "
            f"target({observations[label]!r}, {label=})"
        )
    observations[label] = observation
    return observation

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,663 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import abc
import binascii
import json
import os
import sys
import warnings
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from hashlib import sha384
from os import getenv
from pathlib import Path, PurePath
from typing import Dict, Iterable, Optional, Set
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
from zipfile import BadZipFile, ZipFile
from hypothesis.configuration import storage_directory
from hypothesis.errors import HypothesisException, HypothesisWarning
from hypothesis.utils.conventions import not_set
# Public names exported by this module.  Note that ``ExampleDatabase`` also
# acts as a factory (see ``_EDMeta.__call__`` / ``_db_for_path`` below).
__all__ = [
    "DirectoryBasedExampleDatabase",
    "ExampleDatabase",
    "InMemoryExampleDatabase",
    "MultiplexedDatabase",
    "ReadOnlyDatabase",
    "GitHubArtifactDatabase",
]
def _usable_dir(path: Path) -> bool:
"""
Returns True iff the desired path can be used as database path because
either the directory exists and can be used, or its root directory can
be used and we can make the directory as needed.
"""
while not path.exists():
# Loop terminates because the root dir ('/' on unix) always exists.
path = path.parent
return path.is_dir() and os.access(path, os.R_OK | os.W_OK | os.X_OK)
def _db_for_path(path=None):
    # Factory behind ``ExampleDatabase(...)`` (dispatched via _EDMeta.__call__):
    # picks a concrete database implementation from the ``path`` argument.
    if path is not_set:
        # No path configured at all: use the default on-disk location, unless
        # it is unusable, in which case fall back to an in-memory database
        # rather than crashing.
        if os.getenv("HYPOTHESIS_DATABASE_FILE") is not None:  # pragma: no cover
            raise HypothesisException(
                "The $HYPOTHESIS_DATABASE_FILE environment variable no longer has any "
                "effect. Configure your database location via a settings profile instead.\n"
                "https://hypothesis.readthedocs.io/en/latest/settings.html#settings-profiles"
            )
        path = storage_directory("examples")
        if not _usable_dir(path):  # pragma: no cover
            warnings.warn(
                "The database setting is not configured, and the default "
                "location is unusable - falling back to an in-memory "
                f"database for this session. {path=}",
                HypothesisWarning,
                stacklevel=3,
            )
            return InMemoryExampleDatabase()
    # An explicit None or ":memory:" means a non-persistent database.
    if path in (None, ":memory:"):
        return InMemoryExampleDatabase()
    return DirectoryBasedExampleDatabase(str(path))
class _EDMeta(abc.ABCMeta):
    """Metaclass that turns ``ExampleDatabase(...)`` into a factory call.

    Instantiating the abstract base class delegates to ``_db_for_path`` so
    that users get a concrete database chosen from their ``path`` argument;
    subclasses are constructed normally.
    """
    def __call__(self, *args, **kwargs):
        if self is ExampleDatabase:
            return _db_for_path(*args, **kwargs)
        return super().__call__(*args, **kwargs)
# This __call__ method is picked up by Sphinx as the signature of all ExampleDatabase
# subclasses, which is accurate, reasonable, and unhelpful. Fortunately Sphinx
# maintains a list of metaclass-call-methods to ignore, and while they would prefer
# not to maintain it upstream (https://github.com/sphinx-doc/sphinx/pull/8262) we
# can insert ourselves here.
#
# This code only runs if Sphinx has already been imported; and it would live in our
# docs/conf.py except that we would also like it to work for anyone documenting
# downstream ExampleDatabase subclasses too.
if "sphinx" in sys.modules:
    try:
        from sphinx.ext.autodoc import _METACLASS_CALL_BLACKLIST
        _METACLASS_CALL_BLACKLIST.append("hypothesis.database._EDMeta.__call__")
    except Exception:
        # Best-effort: a private Sphinx API may change shape; docs just get
        # slightly noisier signatures in that case.
        pass
class ExampleDatabase(metaclass=_EDMeta):
    """An abstract base class for storing examples in Hypothesis' internal format.

    An ExampleDatabase maps each ``bytes`` key to many distinct ``bytes``
    values, like a ``Mapping[bytes, AbstractSet[bytes]]``.

    Note: via the ``_EDMeta`` metaclass, calling ``ExampleDatabase(...)``
    directly returns a concrete implementation chosen by ``_db_for_path``.
    """
    @abc.abstractmethod
    def save(self, key: bytes, value: bytes) -> None:
        """Save ``value`` under ``key``.

        If this value is already present for this key, silently do nothing.
        """
        raise NotImplementedError(f"{type(self).__name__}.save")
    @abc.abstractmethod
    def fetch(self, key: bytes) -> Iterable[bytes]:
        """Return an iterable over all values matching this key."""
        raise NotImplementedError(f"{type(self).__name__}.fetch")
    @abc.abstractmethod
    def delete(self, key: bytes, value: bytes) -> None:
        """Remove this value from this key.

        If this value is not present, silently do nothing.
        """
        raise NotImplementedError(f"{type(self).__name__}.delete")
    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        """Move ``value`` from key ``src`` to key ``dest``. Equivalent to
        ``delete(src, value)`` followed by ``save(dest, value)``, but may
        have a more efficient implementation.

        Note that ``value`` will be inserted at ``dest`` regardless of whether
        it is currently present at ``src``.
        """
        if src == dest:
            self.save(src, value)
            return
        self.delete(src, value)
        self.save(dest, value)
class InMemoryExampleDatabase(ExampleDatabase):
    """A non-persistent example database backed by a plain dict of sets.

    Useful when calling a test function several times within one session,
    or for testing other database implementations - but because nothing
    survives between runs we do not recommend it for general use.
    """

    def __init__(self):
        self.data = {}

    def __repr__(self) -> str:
        return f"InMemoryExampleDatabase({self.data!r})"

    def fetch(self, key: bytes) -> Iterable[bytes]:
        for value in self.data.get(key, ()):
            yield value

    def save(self, key: bytes, value: bytes) -> None:
        values = self.data.setdefault(key, set())
        values.add(bytes(value))

    def delete(self, key: bytes, value: bytes) -> None:
        self.data.get(key, set()).discard(bytes(value))
def _hash(key):
return sha384(key).hexdigest()[:16]
class DirectoryBasedExampleDatabase(ExampleDatabase):
    """Use a directory to store Hypothesis examples as files.

    Each test corresponds to a directory, and each example to a file within that
    directory. While the contents are fairly opaque, a
    ``DirectoryBasedExampleDatabase`` can be shared by checking the directory
    into version control, for example with the following ``.gitignore``::

        # Ignore files cached by Hypothesis...
        .hypothesis/*
        # except for the examples directory
        !.hypothesis/examples/

    Note however that this only makes sense if you also pin to an exact version of
    Hypothesis, and we would usually recommend implementing a shared database with
    a network datastore - see :class:`~hypothesis.database.ExampleDatabase`, and
    the :class:`~hypothesis.database.MultiplexedDatabase` helper.
    """
    def __init__(self, path: os.PathLike) -> None:
        self.path = Path(path)
        # Memoised key -> directory path mapping, to avoid re-hashing keys.
        self.keypaths: Dict[bytes, Path] = {}
    def __repr__(self) -> str:
        return f"DirectoryBasedExampleDatabase({self.path!r})"
    def _key_path(self, key: bytes) -> Path:
        # Directory for this key: <base>/<16-hex-char hash of key>.
        try:
            return self.keypaths[key]
        except KeyError:
            pass
        self.keypaths[key] = self.path / _hash(key)
        return self.keypaths[key]
    def _value_path(self, key, value):
        # File for this value: <key dir>/<16-hex-char hash of value>.
        return self._key_path(key) / _hash(value)
    def fetch(self, key: bytes) -> Iterable[bytes]:
        kp = self._key_path(key)
        if not kp.is_dir():
            return
        for path in os.listdir(kp):
            try:
                yield (kp / path).read_bytes()
            except OSError:
                # e.g. the file was concurrently deleted; skip it.
                pass
    def save(self, key: bytes, value: bytes) -> None:
        # Note: we attempt to create the dir in question now. We
        # already checked for permissions, but there can still be other issues,
        # e.g. the disk is full
        self._key_path(key).mkdir(exist_ok=True, parents=True)
        path = self._value_path(key, value)
        if not path.exists():
            # Write to a random temp name then rename, so concurrent readers
            # never observe a partially-written value file.
            suffix = binascii.hexlify(os.urandom(16)).decode("ascii")
            tmpname = path.with_suffix(f"{path.suffix}.{suffix}")
            tmpname.write_bytes(value)
            try:
                tmpname.rename(path)
            except OSError:  # pragma: no cover
                tmpname.unlink()
                assert not tmpname.exists()
    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        if src == dest:
            self.save(src, value)
            return
        try:
            # Fast path: rename the value file directly between key dirs.
            os.renames(
                self._value_path(src, value),
                self._value_path(dest, value),
            )
        except OSError:
            # Fall back to the generic delete-then-save implementation.
            self.delete(src, value)
            self.save(dest, value)
    def delete(self, key: bytes, value: bytes) -> None:
        try:
            self._value_path(key, value).unlink()
        except OSError:
            # Not present (or concurrently removed): silently do nothing.
            pass
class ReadOnlyDatabase(ExampleDatabase):
    """Wrap another database so it can only be read, never written.

    ``fetch`` is forwarded to the wrapped database, while ``save``,
    ``delete``, and (via the base class) ``move`` become silent no-ops.

    Note that this disables Hypothesis' automatic discarding of stale
    examples.  It is designed to let local machines consume a shared
    database (e.g. populated by CI servers) without propagating changes
    back from a local or in-development branch.
    """

    def __init__(self, db: ExampleDatabase) -> None:
        assert isinstance(db, ExampleDatabase)
        self._wrapped = db

    def __repr__(self) -> str:
        return f"ReadOnlyDatabase({self._wrapped!r})"

    def fetch(self, key: bytes) -> Iterable[bytes]:
        for value in self._wrapped.fetch(key):
            yield value

    def save(self, key: bytes, value: bytes) -> None:
        pass  # read-only: writes are dropped

    def delete(self, key: bytes, value: bytes) -> None:
        pass  # read-only: deletions are dropped
class MultiplexedDatabase(ExampleDatabase):
    """A wrapper that fans every operation out over multiple databases.

    Each ``save``, ``fetch``, ``move``, or ``delete`` operation runs against
    all of the wrapped databases.  ``fetch`` never yields duplicate values,
    even if the same value is present in two or more of the wrapped databases.

    This combines well with a :class:`ReadOnlyDatabase`, as follows:

    .. code-block:: python

        local = DirectoryBasedExampleDatabase("/tmp/hypothesis/examples/")
        shared = CustomNetworkDatabase()

        settings.register_profile("ci", database=shared)
        settings.register_profile(
            "dev", database=MultiplexedDatabase(local, ReadOnlyDatabase(shared))
        )
        settings.load_profile("ci" if os.environ.get("CI") else "dev")

    So your CI system or fuzzing runs can populate a central shared database;
    while local runs on development machines can reproduce any failures from CI
    but will only cache their own failures locally and cannot remove examples
    from the shared database.
    """

    def __init__(self, *dbs: ExampleDatabase) -> None:
        assert all(isinstance(db, ExampleDatabase) for db in dbs)
        self._wrapped = dbs

    def __repr__(self) -> str:
        inner = ", ".join(repr(db) for db in self._wrapped)
        return f"MultiplexedDatabase({inner})"

    def fetch(self, key: bytes) -> Iterable[bytes]:
        emitted = set()
        for db in self._wrapped:
            for value in db.fetch(key):
                if value in emitted:
                    continue
                emitted.add(value)
                yield value

    def save(self, key: bytes, value: bytes) -> None:
        for db in self._wrapped:
            db.save(key, value)

    def delete(self, key: bytes, value: bytes) -> None:
        for db in self._wrapped:
            db.delete(key, value)

    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        for db in self._wrapped:
            db.move(src, dest, value)
class GitHubArtifactDatabase(ExampleDatabase):
    """
    A file-based database loaded from a `GitHub Actions <https://docs.github.com/en/actions>`_ artifact.

    You can use this for sharing example databases between CI runs and developers, allowing
    the latter to get read-only access to the former. This is particularly useful for
    continuous fuzzing (i.e. with `HypoFuzz <https://hypofuzz.com/>`_),
    where the CI system can help find new failing examples through fuzzing,
    and developers can reproduce them locally without any manual effort.

    .. note::
        You must provide ``GITHUB_TOKEN`` as an environment variable. In CI, Github Actions provides
        this automatically, but it needs to be set manually for local usage. In a developer machine,
        this would usually be a `Personal Access Token <https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token>`_.
        If the repository is private, it's necessary for the token to have `repo` scope
        in the case of a classic token, or `actions:read` in the case of a fine-grained token.

    In most cases, this will be used
    through the :class:`~hypothesis.database.MultiplexedDatabase`,
    by combining a local directory-based database with this one. For example:

    .. code-block:: python

        local = DirectoryBasedExampleDatabase(".hypothesis/examples")
        shared = ReadOnlyDatabase(GitHubArtifactDatabase("user", "repo"))

        settings.register_profile("ci", database=local)
        settings.register_profile("dev", database=MultiplexedDatabase(local, shared))
        # We don't want to use the shared database in CI, only to populate its local one.
        # which the workflow should then upload as an artifact.
        settings.load_profile("ci" if os.environ.get("CI") else "dev")

    .. note::
        Because this database is read-only, you always need to wrap it with the
        :class:`ReadOnlyDatabase`.

    A setup like this can be paired with a GitHub Actions workflow including
    something like the following:

    .. code-block:: yaml

        - name: Download example database
          uses: dawidd6/action-download-artifact@v2.24.3
          with:
            name: hypothesis-example-db
            path: .hypothesis/examples
            if_no_artifact_found: warn
            workflow_conclusion: completed

        - name: Run tests
          run: pytest

        - name: Upload example database
          uses: actions/upload-artifact@v3
          if: always()
          with:
            name: hypothesis-example-db
            path: .hypothesis/examples

    In this workflow, we use `dawidd6/action-download-artifact <https://github.com/dawidd6/action-download-artifact>`_
    to download the latest artifact given that the official `actions/download-artifact <https://github.com/actions/download-artifact>`_
    does not support downloading artifacts from previous workflow runs.

    The database automatically implements a simple file-based cache with a default expiration period
    of 1 day. You can adjust this through the `cache_timeout` property.

    For mono-repo support, you can provide a unique `artifact_name` (e.g. `hypofuzz-example-db-frontend`).
    """

    def __init__(
        self,
        owner: str,
        repo: str,
        artifact_name: str = "hypothesis-example-db",
        cache_timeout: timedelta = timedelta(days=1),
        path: Optional[os.PathLike] = None,
    ):
        self.owner = owner
        self.repo = repo
        self.artifact_name = artifact_name
        self.cache_timeout = cache_timeout

        # Get the GitHub token from the environment
        # It's unnecessary to use a token if the repo is public
        self.token: Optional[str] = getenv("GITHUB_TOKEN")

        if path is None:
            self.path: Path = Path(
                storage_directory(f"github-artifacts/{self.artifact_name}/")
            )
        else:
            self.path = Path(path)

        # We don't want to initialize the cache until we need to
        self._initialized: bool = False
        self._disabled: bool = False

        # This is the path to the artifact in usage
        # .hypothesis/github-artifacts/<artifact-name>/<modified_isoformat>.zip
        self._artifact: Optional[Path] = None
        # This caches the artifact structure
        self._access_cache: Optional[Dict[PurePath, Set[PurePath]]] = None

        # Message to display if user doesn't wrap around ReadOnlyDatabase.
        # Bug fix: the original concatenation was missing ", " and rendered
        # as "...ReadOnlyDatabasei.e. ...".
        self._read_only_message = (
            "This database is read-only. "
            "Please wrap this class with ReadOnlyDatabase, "
            "i.e. ReadOnlyDatabase(GitHubArtifactDatabase(...))."
        )

    def __repr__(self) -> str:
        return (
            f"GitHubArtifactDatabase(owner={self.owner!r}, "
            f"repo={self.repo!r}, artifact_name={self.artifact_name!r})"
        )

    def _prepare_for_io(self) -> None:
        """Validate the downloaded artifact and build the in-memory index
        of key-directories to value-files.  Disables the database on a bad
        zip rather than raising."""
        assert self._artifact is not None, "Artifact not loaded."

        if self._initialized:  # pragma: no cover
            return

        # Test that the artifact is valid
        try:
            with ZipFile(self._artifact) as f:
                if f.testzip():  # pragma: no cover
                    raise BadZipFile

            # Turns out that testzip() doesn't work quite well
            # doing the cache initialization here instead
            # will give us more coverage of the artifact.

            # Cache the files inside each keypath
            self._access_cache = {}
            with ZipFile(self._artifact) as zf:
                namelist = zf.namelist()
                # Iterate over files in the artifact
                for filename in namelist:
                    fileinfo = zf.getinfo(filename)
                    if fileinfo.is_dir():
                        self._access_cache[PurePath(filename)] = set()
                    else:
                        # Get the keypath from the filename
                        keypath = PurePath(filename).parent
                        # Add the file to the keypath.  Bug fix: use setdefault,
                        # because zip archives are not guaranteed to contain an
                        # explicit entry for every directory.
                        self._access_cache.setdefault(keypath, set()).add(
                            PurePath(filename)
                        )
        except BadZipFile:
            warnings.warn(
                "The downloaded artifact from GitHub is invalid. "
                "This could be because the artifact was corrupted, "
                "or because the artifact was not created by Hypothesis. ",
                HypothesisWarning,
                stacklevel=3,
            )
            self._disabled = True

        self._initialized = True

    def _initialize_db(self) -> None:
        """Find or download a fresh artifact zip, then index it.

        Keeps at most one cached artifact on disk; falls back to an expired
        artifact (with a warning) if a new one cannot be fetched, and
        disables the database entirely if nothing is available.
        """
        # Create the cache directory if it doesn't exist
        self.path.mkdir(exist_ok=True, parents=True)

        # Get all artifacts
        cached_artifacts = sorted(
            self.path.glob("*.zip"),
            key=lambda a: datetime.fromisoformat(a.stem.replace("_", ":")),
        )

        # Remove all but the latest artifact
        for artifact in cached_artifacts[:-1]:
            artifact.unlink()

        try:
            found_artifact = cached_artifacts[-1]
        except IndexError:
            found_artifact = None

        # Check if the latest artifact is a cache hit
        if found_artifact is not None and (
            datetime.now(timezone.utc)
            - datetime.fromisoformat(found_artifact.stem.replace("_", ":"))
            < self.cache_timeout
        ):
            self._artifact = found_artifact
        else:
            # Download the latest artifact from GitHub
            new_artifact = self._fetch_artifact()

            if new_artifact:
                if found_artifact is not None:
                    found_artifact.unlink()
                self._artifact = new_artifact
            elif found_artifact is not None:
                warnings.warn(
                    "Using an expired artifact as a fallback for the database: "
                    f"{found_artifact}",
                    HypothesisWarning,
                    stacklevel=2,
                )
                self._artifact = found_artifact
            else:
                warnings.warn(
                    "Couldn't acquire a new or existing artifact. Disabling database.",
                    HypothesisWarning,
                    stacklevel=2,
                )
                self._disabled = True
                return

        self._prepare_for_io()

    def _get_bytes(self, url: str) -> Optional[bytes]:  # pragma: no cover
        """GET *url* from the GitHub API, returning the body or None (with a
        warning) on any network/authorization failure."""
        request = Request(
            url,
            headers={
                "Accept": "application/vnd.github+json",
                # Bug fix: the version value previously had a trailing space.
                "X-GitHub-Api-Version": "2022-11-28",
                "Authorization": f"Bearer {self.token}",
            },
        )
        warning_message = None
        response_bytes: Optional[bytes] = None
        try:
            with urlopen(request) as response:
                response_bytes = response.read()
        except HTTPError as e:
            if e.code == 401:
                warning_message = (
                    "Authorization failed when trying to download artifact from GitHub. "
                    "Check that you have a valid GITHUB_TOKEN set in your environment."
                )
            else:
                # Bug fix: removed the doubled word "because because".
                warning_message = (
                    "Could not get the latest artifact from GitHub. "
                    "This could be because the repository "
                    "or artifact does not exist. "
                )
        except URLError:
            warning_message = "Could not connect to GitHub to get the latest artifact. "
        except TimeoutError:
            warning_message = (
                "Could not connect to GitHub to get the latest artifact "
                "(connection timed out)."
            )

        if warning_message is not None:
            warnings.warn(warning_message, HypothesisWarning, stacklevel=4)
            return None

        return response_bytes

    def _fetch_artifact(self) -> Optional[Path]:  # pragma: no cover
        """Download the newest matching artifact zip into the cache directory
        and return its path, or None if anything fails."""
        # Get the list of artifacts from GitHub
        url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/artifacts"
        response_bytes = self._get_bytes(url)
        if response_bytes is None:
            return None

        artifacts = json.loads(response_bytes)["artifacts"]
        artifacts = [a for a in artifacts if a["name"] == self.artifact_name]

        if not artifacts:
            return None

        # Get the latest artifact from the list
        artifact = max(artifacts, key=lambda a: a["created_at"])
        url = artifact["archive_download_url"]

        # Download the artifact
        artifact_bytes = self._get_bytes(url)
        if artifact_bytes is None:
            return None

        # Save the artifact to the cache
        # We replace ":" with "_" to ensure the filenames are compatible
        # with Windows filesystems
        timestamp = datetime.now(timezone.utc).isoformat().replace(":", "_")
        artifact_path = self.path / f"{timestamp}.zip"
        try:
            artifact_path.write_bytes(artifact_bytes)
        except OSError:
            warnings.warn(
                "Could not save the latest artifact from GitHub. ",
                HypothesisWarning,
                stacklevel=3,
            )
            return None

        return artifact_path

    @staticmethod
    @lru_cache
    def _key_path(key: bytes) -> PurePath:
        # Directory-within-zip for this key (PurePath drops the trailing "/").
        return PurePath(_hash(key) + "/")

    def fetch(self, key: bytes) -> Iterable[bytes]:
        if self._disabled:
            return

        if not self._initialized:
            self._initialize_db()
            if self._disabled:
                return

        assert self._artifact is not None
        assert self._access_cache is not None

        kp = self._key_path(key)

        with ZipFile(self._artifact) as zf:
            # Get the all files in the the kp from the cache
            filenames = self._access_cache.get(kp, ())
            for filename in filenames:
                with zf.open(filename.as_posix()) as f:
                    yield f.read()

    # Read-only interface
    def save(self, key: bytes, value: bytes) -> None:
        raise RuntimeError(self._read_only_message)

    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        raise RuntimeError(self._read_only_message)

    def delete(self, key: bytes, value: bytes) -> None:
        raise RuntimeError(self._read_only_message)

View File

@@ -0,0 +1,37 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Run all functions registered for the "hypothesis" entry point.
This can be used with `st.register_type_strategy` to register strategies for your
custom types, running the relevant code when *hypothesis* is imported instead of
your package.
"""
import importlib.metadata
import os
def get_entry_points():
    """Yield the entry points registered under the "hypothesis" group."""
    try:
        entries = importlib.metadata.entry_points(group="hypothesis")
    except TypeError:  # pragma: no cover
        # Load-time selection requires Python >= 3.10. See also
        # https://importlib-metadata.readthedocs.io/en/latest/using.html
        entries = importlib.metadata.entry_points().get("hypothesis", [])
    for entry_point in entries:
        yield entry_point
def run():
    """Load and invoke every registered "hypothesis" plugin hook, unless the
    HYPOTHESIS_NO_PLUGINS environment variable is set (non-empty)."""
    if os.environ.get("HYPOTHESIS_NO_PLUGINS"):
        return
    for entry in get_entry_points():  # pragma: no cover
        hook = entry.load()
        if callable(hook):
            hook()

View File

@@ -0,0 +1,183 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
# Exception hierarchy: everything Hypothesis raises derives from
# HypothesisException; _Trimmable subclasses additionally get their
# tracebacks trimmed even when raised from internal code.
class HypothesisException(Exception):
    """Generic parent class for exceptions thrown by Hypothesis."""
class _Trimmable(HypothesisException):
    """Hypothesis can trim these tracebacks even if they're raised internally."""
class UnsatisfiedAssumption(HypothesisException):
    """An internal error raised by assume.

    If you're seeing this error something has gone wrong.
    """
class NoSuchExample(HypothesisException):
    """The condition we have been asked to satisfy appears to be always false.

    This does not guarantee that no example exists, only that we were
    unable to find one.
    """
    def __init__(self, condition_string, extra=""):
        super().__init__(f"No examples found of condition {condition_string}{extra}")
class Unsatisfiable(_Trimmable):
    """We ran out of time or examples before we could find enough examples
    which satisfy the assumptions of this hypothesis.

    This could be because the function is too slow. If so, try upping
    the timeout. It could also be because the function is using assume
    in a way that is too hard to satisfy. If so, try writing a custom
    strategy or using a better starting point (e.g if you are requiring
    a list has unique values you could instead filter out all duplicate
    values from the list)
    """
class Flaky(_Trimmable):
    """This function appears to fail non-deterministically: We have seen it
    fail when passed this example at least once, but a subsequent invocation
    did not fail.

    Common causes for this problem are:
    1. The function depends on external state. e.g. it uses an external
       random number generator. Try to make a version that passes all the
       relevant state in from Hypothesis.
    2. The function is suffering from too much recursion and its failure
       depends sensitively on where it's been called from.
    3. The function is timing sensitive and can fail or pass depending on
       how long it takes. Try breaking it up into smaller functions which
       don't do that and testing those instead.
    """
# Errors about how Hypothesis itself was used, plus warning categories.
class InvalidArgument(_Trimmable, TypeError):
    """Used to indicate that the arguments to a Hypothesis function were in
    some manner incorrect."""
class ResolutionFailed(InvalidArgument):
    """Hypothesis had to resolve a type to a strategy, but this failed.

    Type inference is best-effort, so this only happens when an
    annotation exists but could not be resolved for a required argument
    to the target of ``builds()``, or where the user passed ``...``.
    """
class InvalidState(HypothesisException):
    """The system is not in a state where you were allowed to do that."""
class InvalidDefinition(_Trimmable, TypeError):
    """Used to indicate that a class definition was not well put together and
    has something wrong with it."""
class HypothesisWarning(HypothesisException, Warning):
    """A generic warning issued by Hypothesis."""
class FailedHealthCheck(_Trimmable):
    """Raised when a test fails a healthcheck."""
class NonInteractiveExampleWarning(HypothesisWarning):
    """SearchStrategy.example() is designed for interactive use,
    but should never be used in the body of a test.
    """
class HypothesisDeprecationWarning(HypothesisWarning, FutureWarning):
    """A deprecation warning issued by Hypothesis.

    Actually inherits from FutureWarning, because DeprecationWarning is
    hidden by the default warnings filter.

    You can configure the Python :mod:`python:warnings` to handle these
    warnings differently to others, either turning them into errors or
    suppressing them entirely. Obviously we would prefer the former!
    """
class Frozen(HypothesisException):
    """Raised when a mutation method has been called on a ConjectureData object
    after freeze() has been called."""
def __getattr__(name):
if name == "MultipleFailures":
from hypothesis._settings import note_deprecation
from hypothesis.internal.compat import BaseExceptionGroup
note_deprecation(
"MultipleFailures is deprecated; use the builtin `BaseExceptionGroup` type "
"instead, or `exceptiongroup.BaseExceptionGroup` before Python 3.11",
since="2022-08-02",
has_codemod=False, # This would be a great PR though!
stacklevel=1,
)
return BaseExceptionGroup
raise AttributeError(f"Module 'hypothesis.errors' has no attribute {name}")
class DeadlineExceeded(_Trimmable):
    """Raised when an individual test body has taken too long to run."""

    def __init__(self, runtime, deadline):
        # Both arguments are timedeltas; the message reports milliseconds.
        message = "Test took {:.2f}ms, which exceeds the deadline of {:.2f}ms".format(
            runtime.total_seconds() * 1000, deadline.total_seconds() * 1000
        )
        super().__init__(message)
        self.runtime = runtime
        self.deadline = deadline

    def __reduce__(self):
        # Support pickling despite the custom __init__ signature.
        return (type(self), (self.runtime, self.deadline))
class StopTest(BaseException):
    """Raised when a test should stop running and hand control back to the
    Hypothesis engine, which then continues normally.

    Deliberately a BaseException so that ordinary ``except Exception``
    handlers in user code do not swallow it.
    """

    def __init__(self, testcounter):
        super().__init__(repr(testcounter))
        self.testcounter = testcounter
class DidNotReproduce(HypothesisException):
    # NOTE(review): presumably raised when a stored failure cannot be
    # reproduced on replay - confirm at call sites.
    pass
class Found(Exception):
    """Signal that the example matches condition. Internal use only."""
    # Flag checked elsewhere in the engine; name suggests this exception is
    # never escalated into a user-visible error.
    hypothesis_internal_never_escalate = True
class RewindRecursive(Exception):
    """Signal that the type inference should be rewound due to recursive types. Internal use only."""
    def __init__(self, target):
        self.target = target
class SmallSearchSpaceWarning(HypothesisWarning):
    """Indicates that an inferred strategy does not span the search space
    in a meaningful way, for example by only creating default instances."""

View File

@@ -0,0 +1,9 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

View File

@@ -0,0 +1,685 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import re
from typing import NamedTuple, Optional, Tuple, Union
from hypothesis import assume, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.utils import _calc_p_continue
from hypothesis.internal.coverage import check_function
from hypothesis.internal.validation import check_type, check_valid_interval
from hypothesis.strategies._internal.utils import defines_strategy
from hypothesis.utils.conventions import UniqueIdentifier, not_set
# Public names of this helper module: shared shape/index strategies and
# validation utilities for the NumPy and Array API extras.
__all__ = [
    "NDIM_MAX",
    "Shape",
    "BroadcastableShapes",
    "BasicIndex",
    "check_argument",
    "order_check",
    "check_valid_dims",
    "array_shapes",
    "valid_tuple_axes",
    "broadcastable_shapes",
    "mutually_broadcastable_shapes",
    "MutuallyBroadcastableShapesStrategy",
    "BasicIndexStrategy",
]
# An array shape, e.g. ``(3, 4)``.
Shape = Tuple[int, ...]
# We silence flake8 here because it disagrees with mypy about `ellipsis` (`type(...)`)
BasicIndex = Tuple[Union[int, slice, None, "ellipsis"], ...]  # noqa: F821
class BroadcastableShapes(NamedTuple):
    """A set of input shapes together with the shape that broadcasting
    them yields."""
    input_shapes: Tuple[Shape, ...]
    result_shape: Shape
@check_function
def check_argument(condition, fail_message, *f_args, **f_kwargs):
    """Raise InvalidArgument with the formatted *fail_message* unless
    *condition* is truthy."""
    if condition:
        return
    raise InvalidArgument(fail_message.format(*f_args, **f_kwargs))
@check_function
def order_check(name, floor, min_, max_):
    """Check that ``floor <= min_ <= max_``, raising InvalidArgument with a
    message naming the ``min_<name>``/``max_<name>`` arguments otherwise."""
    if floor > min_:
        raise InvalidArgument(f"min_{name} must be at least {floor} but was {min_}")
    if min_ > max_:
        raise InvalidArgument(f"min_{name}={min_} is larger than max_{name}={max_}")
# 32 is a dimension limit specific to NumPy, and does not necessarily apply to
# other array/tensor libraries. Historically these strategies were built for the
# NumPy extra, so it's nice to keep these limits, and it's seemingly unlikely
# someone would want to generate >32 dim arrays anyway.
# See https://github.com/HypothesisWorks/hypothesis/pull/3067.
# Used by check_valid_dims() and as the cap on default max_dims values below.
NDIM_MAX = 32
@check_function
def check_valid_dims(dims, name):
    """Raise InvalidArgument if *dims* exceeds the NDIM_MAX (=32) limit."""
    if dims > NDIM_MAX:
        raise InvalidArgument(
            f"{name}={dims}, but Hypothesis does not support arrays with "
            f"more than {NDIM_MAX} dimensions"
        )
@defines_strategy()
def array_shapes(
    *,
    min_dims: int = 1,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[Shape]:
    """Return a strategy for array shapes (tuples of int >= 1).

    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``min_dims + 2``.
    * ``min_side`` is the smallest size that a dimension can possess.
    * ``max_side`` is the largest size that a dimension can possess,
      defaulting to ``min_side + 5``.
    """
    check_type(int, min_dims, "min_dims")
    check_type(int, min_side, "min_side")
    check_valid_dims(min_dims, "min_dims")
    if max_dims is None:
        # Keep the default within Hypothesis' hard cap on dimensionality.
        max_dims = min(min_dims + 2, NDIM_MAX)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        max_side = min_side + 5
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    # Shrinks towards shorter shapes with smaller sides.
    return st.lists(
        st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims
    ).map(tuple)
@defines_strategy()
def valid_tuple_axes(
    ndim: int,
    *,
    min_size: int = 0,
    max_size: Optional[int] = None,
) -> st.SearchStrategy[Tuple[int, ...]]:
    """All tuples will have a length >= ``min_size`` and <= ``max_size``. The default
    value for ``max_size`` is ``ndim``.

    Examples from this strategy shrink towards an empty tuple, which render most
    sequential functions as no-ops.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> [valid_tuple_axes(3).example() for i in range(4)]
        [(-3, 1), (0, 1, -1), (0, 2), (0, -2, 2)]

    ``valid_tuple_axes`` can be joined with other strategies to generate
    any type of valid axis object, i.e. integers, tuples, and ``None``:

    .. code-block:: python

        any_axis_strategy = none() | integers(-ndim, ndim - 1) | valid_tuple_axes(ndim)
    """
    check_type(int, ndim, "ndim")
    check_type(int, min_size, "min_size")
    if max_size is None:
        max_size = ndim
    check_type(int, max_size, "max_size")
    order_check("size", 0, min_size, max_size)
    check_valid_interval(max_size, ndim, "max_size", "ndim")
    # Draw from [0, 2*ndim - 1]; values >= ndim are mapped onto the negative
    # axes [-ndim, -1], so examples shrink towards small non-negative axes.
    axes = st.integers(0, max(0, 2 * ndim - 1)).map(
        lambda x: x if x < ndim else x - 2 * ndim
    )
    # `x % ndim` treats axis k and k - ndim as the same axis, so a tuple can
    # never name the same dimension twice.
    return st.lists(
        axes, min_size=min_size, max_size=max_size, unique_by=lambda x: x % ndim
    ).map(tuple)
@defines_strategy()
def broadcastable_shapes(
    shape: Shape,
    *,
    min_dims: int = 0,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[Shape]:
    """Return a strategy for shapes that are broadcast-compatible with the
    provided shape.

    Examples from this strategy shrink towards a shape with length ``min_dims``.
    The size of an aligned dimension shrinks towards size ``1``. The size of an
    unaligned dimension shrink towards ``min_side``.

    * ``shape`` is a tuple of integers.
    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``max(len(shape), min_dims) + 2``.
    * ``min_side`` is the smallest size that an unaligned dimension can possess.
    * ``max_side`` is the largest size that an unaligned dimension can possess,
      defaulting to 2 plus the size of the largest aligned dimension.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> [broadcastable_shapes(shape=(2, 3)).example() for i in range(5)]
        [(1, 3), (), (2, 3), (2, 1), (4, 1, 3)]
    """
    check_type(tuple, shape, "shape")
    check_type(int, min_side, "min_side")
    check_type(int, min_dims, "min_dims")
    check_valid_dims(min_dims, "min_dims")
    strict_check = max_side is None or max_dims is None
    if max_dims is None:
        max_dims = min(max(len(shape), min_dims) + 2, NDIM_MAX)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        # Default: large enough to cover every aligned dimension, plus slack.
        max_side = max(shape[-max_dims:] + (min_side,)) + 2
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    if strict_check:
        dims = max_dims
        bound_name = "max_dims"
    else:
        dims = min_dims
        bound_name = "min_dims"
    # check for unsatisfiable min_side
    if not all(min_side <= s for s in shape[::-1][:dims] if s != 1):
        raise InvalidArgument(
            f"Given shape={shape}, there are no broadcast-compatible "
            f"shapes that satisfy: {bound_name}={dims} and min_side={min_side}"
        )
    # check for unsatisfiable [min_side, max_side]
    if not (
        min_side <= 1 <= max_side or all(s <= max_side for s in shape[::-1][:dims])
    ):
        raise InvalidArgument(
            # Fixed: this message previously said "base_shape=", but the
            # parameter of this function is named `shape` (cf. the message above).
            f"Given shape={shape}, there are no broadcast-compatible "
            f"shapes that satisfy all of {bound_name}={dims}, "
            f"min_side={min_side}, and max_side={max_side}"
        )
    if not strict_check:
        # reduce max_dims to exclude unsatisfiable dimensions
        for n, s in zip(range(max_dims), shape[::-1]):
            if s < min_side and s != 1:
                max_dims = n
                break
            elif not (min_side <= 1 <= max_side or s <= max_side):
                max_dims = n
                break
    # Delegate to the N-shapes strategy with num_shapes=1 and unwrap the result.
    return MutuallyBroadcastableShapesStrategy(
        num_shapes=1,
        base_shape=shape,
        min_dims=min_dims,
        max_dims=max_dims,
        min_side=min_side,
        max_side=max_side,
    ).map(lambda x: x.input_shapes[0])
# See https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html
# Implementation based on numpy.lib.function_base._parse_gufunc_signature
# with minor upgrades to handle numeric and optional dimensions. Examples:
#
# add (),()->() binary ufunc
# sum1d (i)->() reduction
# inner1d (i),(i)->() vector-vector multiplication
# matmat (m,n),(n,p)->(m,p) matrix multiplication
# vecmat (n),(n,p)->(p) vector-matrix multiplication
# matvec (m,n),(n)->(m) matrix-vector multiplication
# matmul (m?,n),(n,p?)->(m?,p?) combination of the four above
# cross1d (3),(3)->(3) cross product with frozen dimensions
#
# Note that while no examples of such usage are given, Numpy does allow
# generalised ufuncs that have *multiple output arrays*. This is not
# currently supported by Hypothesis - please contact us if you would use it!
#
# We are unsure if gufuncs allow frozen dimensions to be optional, but it's
# easy enough to support here - and so we will unless we learn otherwise.
# A dimension is a name (or digits for a frozen size), optionally marked "?"
# for gufunc optional dimensions.
_DIMENSION = r"\w+\??"  # Note that \w permits digits too!
# A shape is a parenthesised, comma-separated list of at most 32 dimensions,
# matching the NDIM_MAX limit above.
_SHAPE = rf"\((?:{_DIMENSION}(?:,{_DIMENSION}){{0,31}})?\)"
_ARGUMENT_LIST = f"{_SHAPE}(?:,{_SHAPE})*"
# Full gufunc signature: one or more input shapes -> a single output shape.
_SIGNATURE = rf"^{_ARGUMENT_LIST}->{_SHAPE}$"
# Only used to emit a clearer error for (unsupported) multiple-output signatures.
_SIGNATURE_MULTIPLE_OUTPUT = rf"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
class _GUfuncSig(NamedTuple):
    """Parsed gufunc signature: one tuple of dimension names per input, plus
    the single output shape."""

    input_shapes: Tuple[Shape, ...]
    result_shape: Shape
def _hypothesis_parse_gufunc_signature(signature):
    """Parse a generalised-ufunc signature string into a ``_GUfuncSig``.

    Raises InvalidArgument for: multiple-output signatures, signatures whose
    shapes exceed NDIM_MAX dimensions, frozen optional dimensions (e.g. "3?"),
    and non-frozen dimension names that appear only in the output shape.
    """
    # Disable all_checks to better match the Numpy version, for testing
    if not re.match(_SIGNATURE, signature):
        if re.match(_SIGNATURE_MULTIPLE_OUTPUT, signature):
            raise InvalidArgument(
                "Hypothesis does not yet support generalised ufunc signatures "
                "with multiple output arrays - mostly because we don't know of "
                "anyone who uses them! Please get in touch with us to fix that."
                f"\n ({signature=})"
            )
        if re.match(
            (
                # Taken from np.lib.function_base._SIGNATURE
                r"^\((?:\w+(?:,\w+)*)?\)(?:,\((?:\w+(?:,\w+)*)?\))*->"
                r"\((?:\w+(?:,\w+)*)?\)(?:,\((?:\w+(?:,\w+)*)?\))*$"
            ),
            signature,
        ):
            raise InvalidArgument(
                f"{signature=} matches Numpy's regex for gufunc signatures, "
                f"but contains shapes with more than {NDIM_MAX} dimensions and is thus invalid."
            )
        raise InvalidArgument(f"{signature!r} is not a valid gufunc signature")
    input_shapes, output_shapes = (
        tuple(tuple(re.findall(_DIMENSION, a)) for a in re.findall(_SHAPE, arg_list))
        for arg_list in signature.split("->")
    )
    assert len(output_shapes) == 1
    result_shape = output_shapes[0]
    # Check that there are no names in output shape that do not appear in inputs.
    # (kept out of parser function for easier generation of test values)
    # We also disallow frozen optional dimensions - this is ambiguous as there is
    # no way to share an un-named dimension between shapes. Maybe just padding?
    # Anyway, we disallow it pending clarification from upstream.
    for shape in (*input_shapes, result_shape):
        for name in shape:
            try:
                int(name.strip("?"))
                if "?" in name:
                    raise InvalidArgument(
                        f"Got dimension {name!r}, but handling of frozen optional dimensions "
                        "is ambiguous. If you know how this should work, please "
                        # BUG FIX: this string and the message below were missing
                        # their f-prefixes, so "{signature=}" / "{name!r}" were
                        # emitted literally instead of being interpolated.
                        f"contact us to get this fixed and documented ({signature=})."
                    )
            except ValueError:
                names_in = {n.strip("?") for shp in input_shapes for n in shp}
                names_out = {n.strip("?") for n in result_shape}
                if name.strip("?") in (names_out - names_in):
                    raise InvalidArgument(
                        f"The {name!r} dimension only appears in the output shape, and is "
                        f"not frozen, so the size is not determined ({signature=})."
                    ) from None
    return _GUfuncSig(input_shapes=input_shapes, result_shape=result_shape)
@defines_strategy()
def mutually_broadcastable_shapes(
    *,
    num_shapes: Union[UniqueIdentifier, int] = not_set,
    signature: Union[UniqueIdentifier, str] = not_set,
    base_shape: Shape = (),
    min_dims: int = 0,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[BroadcastableShapes]:
    """Return a strategy for a specified number of shapes N that are
    mutually-broadcastable with one another and with the provided base shape.

    * ``num_shapes`` is the number of mutually broadcast-compatible shapes to generate.
    * ``base_shape`` is the shape against which all generated shapes can broadcast.
      The default shape is empty, which corresponds to a scalar and thus does
      not constrain broadcasting at all.
    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``max(len(shape), min_dims) + 2``.
    * ``min_side`` is the smallest size that an unaligned dimension can possess.
    * ``max_side`` is the largest size that an unaligned dimension can possess,
      defaulting to 2 plus the size of the largest aligned dimension.

    The strategy will generate a :obj:`python:typing.NamedTuple` containing:

    * ``input_shapes`` as a tuple of the N generated shapes.
    * ``result_shape`` as the resulting shape produced by broadcasting the N shapes
      with the base shape.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> # Draw three shapes where each shape is broadcast-compatible with (2, 3)
        ... strat = mutually_broadcastable_shapes(num_shapes=3, base_shape=(2, 3))
        >>> for _ in range(5):
        ...     print(strat.example())
        BroadcastableShapes(input_shapes=((4, 1, 3), (4, 2, 3), ()), result_shape=(4, 2, 3))
        BroadcastableShapes(input_shapes=((3,), (1, 3), (2, 3)), result_shape=(2, 3))
        BroadcastableShapes(input_shapes=((), (), ()), result_shape=())
        BroadcastableShapes(input_shapes=((3,), (), (3,)), result_shape=(3,))
        BroadcastableShapes(input_shapes=((1, 2, 3), (3,), ()), result_shape=(1, 2, 3))
    """
    arg_msg = "Pass either the `num_shapes` or the `signature` argument, but not both."
    if num_shapes is not not_set:
        check_argument(signature is not_set, arg_msg)
        check_type(int, num_shapes, "num_shapes")
        assert isinstance(num_shapes, int)  # for mypy
        parsed_signature = None
        sig_dims = 0
    else:
        check_argument(signature is not not_set, arg_msg)
        if signature is None:
            raise InvalidArgument(
                "Expected a string, but got invalid signature=None. "
                "(maybe .signature attribute of an element-wise ufunc?)"
            )
        check_type(str, signature, "signature")
        parsed_signature = _hypothesis_parse_gufunc_signature(signature)
        all_shapes = (*parsed_signature.input_shapes, parsed_signature.result_shape)
        # Reserve room for the smallest core shape, so loop dims can't overflow.
        sig_dims = min(len(s) for s in all_shapes)
        num_shapes = len(parsed_signature.input_shapes)
    if num_shapes < 1:
        raise InvalidArgument(f"num_shapes={num_shapes} must be at least 1")
    check_type(tuple, base_shape, "base_shape")
    check_type(int, min_side, "min_side")
    check_type(int, min_dims, "min_dims")
    check_valid_dims(min_dims, "min_dims")
    strict_check = max_dims is not None
    if max_dims is None:
        max_dims = min(max(len(base_shape), min_dims) + 2, NDIM_MAX - sig_dims)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        max_side = max(base_shape[-max_dims:] + (min_side,)) + 2
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    if signature is not None and max_dims > NDIM_MAX - sig_dims:
        raise InvalidArgument(
            # Fixed: this message previously interpolated `signature` where the
            # offending `max_dims` value belongs, and was missing the space
            # before the word "limit".
            f"max_dims={max_dims} would exceed the {NDIM_MAX}-dimension "
            "limit Hypothesis imposes on array shapes, "
            f"given signature={parsed_signature!r}"
        )
    if strict_check:
        dims = max_dims
        bound_name = "max_dims"
    else:
        dims = min_dims
        bound_name = "min_dims"
    # check for unsatisfiable min_side
    if not all(min_side <= s for s in base_shape[::-1][:dims] if s != 1):
        raise InvalidArgument(
            f"Given base_shape={base_shape}, there are no broadcast-compatible "
            f"shapes that satisfy: {bound_name}={dims} and min_side={min_side}"
        )
    # check for unsatisfiable [min_side, max_side]
    if not (
        min_side <= 1 <= max_side or all(s <= max_side for s in base_shape[::-1][:dims])
    ):
        raise InvalidArgument(
            f"Given base_shape={base_shape}, there are no broadcast-compatible "
            f"shapes that satisfy all of {bound_name}={dims}, "
            f"min_side={min_side}, and max_side={max_side}"
        )
    if not strict_check:
        # reduce max_dims to exclude unsatisfiable dimensions
        for n, s in zip(range(max_dims), base_shape[::-1]):
            if s < min_side and s != 1:
                max_dims = n
                break
            elif not (min_side <= 1 <= max_side or s <= max_side):
                max_dims = n
                break
    return MutuallyBroadcastableShapesStrategy(
        num_shapes=num_shapes,
        signature=parsed_signature,
        base_shape=base_shape,
        min_dims=min_dims,
        max_dims=max_dims,
        min_side=min_side,
        max_side=max_side,
    )
class MutuallyBroadcastableShapesStrategy(st.SearchStrategy):
    """Strategy drawing BroadcastableShapes: ``num_shapes`` shapes that are
    mutually broadcast-compatible with each other and with ``base_shape``,
    optionally constrained by a parsed gufunc ``signature``."""

    def __init__(
        self,
        num_shapes,
        signature=None,
        base_shape=(),
        min_dims=0,
        max_dims=None,
        min_side=1,
        max_side=None,
    ):
        super().__init__()
        self.base_shape = base_shape
        self.side_strat = st.integers(min_side, max_side)
        self.num_shapes = num_shapes
        self.signature = signature
        self.min_dims = min_dims
        self.max_dims = max_dims
        self.min_side = min_side
        self.max_side = max_side
        # Whether singleton (broadcasting) dimensions may be generated at all.
        self.size_one_allowed = self.min_side <= 1 <= self.max_side

    def do_draw(self, data):
        # We don't usually have a gufunc signature; do the common case first & fast.
        if self.signature is None:
            return self._draw_loop_dimensions(data)
        # When we *do*, draw the core dims, then draw loop dims, and finally combine.
        core_in, core_res = self._draw_core_dimensions(data)
        # If some core shape has omitted optional dimensions, it's an error to add
        # loop dimensions to it. We never omit core dims if min_dims >= 1.
        # This ensures that we respect Numpy's gufunc broadcasting semantics and user
        # constraints without needing to check whether the loop dims will be
        # interpreted as an invalid substitute for the omitted core dims.
        # We may implement this check later!
        use = [None not in shp for shp in core_in]
        loop_in, loop_res = self._draw_loop_dimensions(data, use=use)

        def add_shape(loop, core):
            # Concatenate loop + core dims, dropping omitted (None) core dims
            # and capping the total length at NDIM_MAX.
            return tuple(x for x in (loop + core)[-NDIM_MAX:] if x is not None)

        return BroadcastableShapes(
            input_shapes=tuple(add_shape(l_in, c) for l_in, c in zip(loop_in, core_in)),
            result_shape=add_shape(loop_res, core_res),
        )

    def _draw_core_dimensions(self, data):
        # Draw gufunc core dimensions, with None standing for optional dimensions
        # that will not be present in the final shape. We track omitted dims so
        # that we can do an accurate per-shape length cap.
        dims = {}
        shapes = []
        for shape in (*self.signature.input_shapes, self.signature.result_shape):
            shapes.append([])
            for name in shape:
                if name.isdigit():
                    # Frozen (numeric) dimension: its size is fixed by the signature.
                    shapes[-1].append(int(name))
                    continue
                if name not in dims:
                    # First sighting of this name: draw a size, and decide once
                    # whether the "name?" optional form is present or omitted.
                    dim = name.strip("?")
                    dims[dim] = data.draw(self.side_strat)
                    if self.min_dims == 0 and not data.draw_boolean(7 / 8):
                        dims[dim + "?"] = None
                    else:
                        dims[dim + "?"] = dims[dim]
                shapes[-1].append(dims[name])
        return tuple(tuple(s) for s in shapes[:-1]), tuple(shapes[-1])

    def _draw_loop_dimensions(self, data, use=None):
        # All shapes are handled in column-major order; i.e. they are reversed
        base_shape = self.base_shape[::-1]
        result_shape = list(base_shape)
        shapes = [[] for _ in range(self.num_shapes)]
        if use is None:
            use = [True for _ in range(self.num_shapes)]
        else:
            assert len(use) == self.num_shapes
            assert all(isinstance(x, bool) for x in use)

        # Probability of growing each shape by one more dimension, tuned so the
        # average length lands midway between min_dims and max_dims.
        _gap = self.max_dims - self.min_dims
        p_keep_extending_shape = _calc_p_continue(desired_avg=_gap / 2, max_size=_gap)

        for dim_count in range(1, self.max_dims + 1):
            dim = dim_count - 1

            # We begin by drawing a valid dimension-size for the given
            # dimension. This restricts the variability across the shapes
            # at this dimension such that they can only choose between
            # this size and a singleton dimension.
            if len(base_shape) < dim_count or base_shape[dim] == 1:
                # dim is unrestricted by the base-shape: shrink to min_side
                dim_side = data.draw(self.side_strat)
            elif base_shape[dim] <= self.max_side:
                # dim is aligned with non-singleton base-dim
                dim_side = base_shape[dim]
            else:
                # only a singleton is valid in alignment with the base-dim
                dim_side = 1

            allowed_sides = sorted([1, dim_side])  # shrink to 0 when available
            for shape_id, shape in enumerate(shapes):
                # Populating this dimension-size for each shape, either
                # the drawn size is used or, if permitted, a singleton
                # dimension.
                if dim <= len(result_shape) and self.size_one_allowed:
                    # aligned: shrink towards size 1
                    side = data.draw(st.sampled_from(allowed_sides))
                else:
                    side = dim_side

                # Use a trick where a biased coin is queried to see
                # if the given shape-tuple will continue to be grown. All
                # of the relevant draws will still be made for the given
                # shape-tuple even if it is no longer being added to.
                # This helps to ensure more stable shrinking behavior.
                if self.min_dims < dim_count:
                    use[shape_id] &= data.draw_boolean(p_keep_extending_shape)

                if use[shape_id]:
                    shape.append(side)
                    if len(result_shape) < len(shape):
                        result_shape.append(shape[-1])
                    elif shape[-1] != 1 and result_shape[dim] == 1:
                        result_shape[dim] = shape[-1]
            if not any(use):
                break

        result_shape = result_shape[: max(map(len, [self.base_shape, *shapes]))]

        assert len(shapes) == self.num_shapes
        assert all(self.min_dims <= len(s) <= self.max_dims for s in shapes)
        assert all(self.min_side <= s <= self.max_side for side in shapes for s in side)

        # Undo the column-major (reversed) bookkeeping before returning.
        return BroadcastableShapes(
            input_shapes=tuple(tuple(reversed(shape)) for shape in shapes),
            result_shape=tuple(reversed(result_shape)),
        )
class BasicIndexStrategy(st.SearchStrategy):
    """Strategy for basic (non-advanced) indexers into an array of ``shape``:
    tuples of integers, slices, Ellipsis, and None (np.newaxis)."""

    def __init__(
        self,
        shape,
        min_dims,
        max_dims,
        allow_ellipsis,
        allow_newaxis,
        allow_fewer_indices_than_dims,
    ):
        self.shape = shape
        self.min_dims = min_dims
        self.max_dims = max_dims
        self.allow_ellipsis = allow_ellipsis
        self.allow_newaxis = allow_newaxis
        # allow_fewer_indices_than_dims=False will disable generating indices
        # that don't cover all axes, i.e. indices that will flat index arrays.
        # This is necessary for the Array API as such indices are not supported.
        self.allow_fewer_indices_than_dims = allow_fewer_indices_than_dims

    def do_draw(self, data):
        # General plan: determine the actual selection up front with a straightforward
        # approach that shrinks well, then complicate it by inserting other things.
        result = []
        for dim_size in self.shape:
            if dim_size == 0:
                # Only a full slice is valid for a zero-length axis.
                result.append(slice(None))
                continue
            strategy = st.integers(-dim_size, dim_size - 1) | st.slices(dim_size)
            result.append(data.draw(strategy))
        # Insert some number of new size-one dimensions if allowed
        result_dims = sum(isinstance(idx, slice) for idx in result)
        while (
            self.allow_newaxis
            and result_dims < self.max_dims
            and (result_dims < self.min_dims or data.draw(st.booleans()))
        ):
            i = data.draw(st.integers(0, len(result)))
            result.insert(i, None)  # Note that `np.newaxis is None`
            result_dims += 1
        # Check that we'll have the right number of dimensions; reject if not.
        # It's easy to do this by construction if you don't care about shrinking,
        # which is really important for array shapes. So we filter instead.
        assume(self.min_dims <= result_dims <= self.max_dims)
        # This is a quick-and-dirty way to insert ..., xor shorten the indexer,
        # but it means we don't have to do any structural analysis.
        if self.allow_ellipsis and data.draw(st.booleans()):
            # Choose an index; then replace all adjacent whole-dimension slices.
            i = j = data.draw(st.integers(0, len(result)))
            while i > 0 and result[i - 1] == slice(None):
                i -= 1
            while j < len(result) and result[j] == slice(None):
                j += 1
            result[i:j] = [Ellipsis]
        elif self.allow_fewer_indices_than_dims:  # pragma: no cover
            # Drop trailing full slices, which are redundant for plain indexing.
            while result[-1:] == [slice(None, None)] and data.draw(st.integers(0, 7)):
                result.pop()
        if len(result) == 1 and data.draw(st.booleans()):
            # Sometimes generate bare element equivalent to a length-one tuple
            return result[0]
        return tuple(result)

View File

@@ -0,0 +1,225 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
Write patches which add @example() decorators for discovered test cases.
Requires `hypothesis[codemods,ghostwriter]` installed, i.e. black and libcst.
This module is used by Hypothesis' builtin pytest plugin for failing examples
discovered during testing, and by HypoFuzz for _covering_ examples discovered
during fuzzing.
"""
import difflib
import hashlib
import inspect
import re
import sys
from ast import literal_eval
from contextlib import suppress
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from hypothesis.configuration import storage_directory
from hypothesis.version import __version__
try:
import black
except ImportError:
black = None # type: ignore
# git format-patch style header for generated patch files; the doubled braces
# survive this f-string and are filled in later via HEADER.format(msg=, when=).
HEADER = f"""\
From HEAD Mon Sep 17 00:00:00 2001
From: Hypothesis {__version__} <no-reply@hypothesis.works>
Date: {{when:%a, %d %b %Y %H:%M:%S}}
Subject: [PATCH] {{msg}}
---
"""
# Commit-message subject used for failing-example patches.
FAIL_MSG = "discovered failure"
# Lines consisting only of spaces, and the leading-space prefix of each
# non-blank line, respectively.
_space_only_re = re.compile("^ +$", re.MULTILINE)
_leading_space_re = re.compile("(^[ ]*)(?:[^ \n])", re.MULTILINE)
def dedent(text):
    """Strip the common leading-space prefix from valid Python source.

    Returns a ``(dedented_text, prefix)`` pair, so the prefix can be
    re-applied after transformation.
    """
    # Simplified textwrap.dedent, for valid Python source code only
    blanked = _space_only_re.sub("", text)
    prefix = min(_leading_space_re.findall(blanked), key=len)
    dedented = re.sub(r"(?m)^" + prefix, "", blanked)
    return dedented, prefix
def indent(text: str, prefix: str) -> str:
    """Prepend ``prefix`` to every line of ``text``, preserving line endings."""
    pieces = []
    for line in text.splitlines(keepends=True):
        pieces.append(prefix + line)
    return "".join(pieces)
class AddExamplesCodemod(VisitorBasedCodemodCommand):
    """libcst codemod that appends @example(...).via(...) decorators to named
    functions, first stripping existing decorators whose via-string is in
    ``strip_via``."""

    DESCRIPTION = "Add explicit examples to failing tests."

    def __init__(self, context, fn_examples, strip_via=(), dec="example", width=88):
        """Add @example() decorator(s) for failing test(s).

        ``context`` is the libcst CodemodContext for the module being edited.
        ``fn_examples`` is a dict of function name to list of
        (example Call node, via string) pairs.
        ``strip_via`` lists via-strings whose decorators should be removed.
        ``dec`` is the decorator expression to call, e.g. "example" or
        "hypothesis.example".
        ``width`` is the line length used when reformatting with black.
        """
        assert fn_examples, "This codemod does nothing without fn_examples."
        super().__init__(context)
        self.decorator_func = cst.parse_expression(dec)
        self.line_length = width
        # Matcher for @example(...).via("<one of strip_via>") decorators.
        value_in_strip_via = m.MatchIfTrue(lambda x: literal_eval(x.value) in strip_via)
        self.strip_matching = m.Call(
            m.Attribute(m.Call(), m.Name("via")),
            [m.Arg(m.SimpleString() & value_in_strip_via)],
        )
        # Codemod the failing examples to Call nodes usable as decorators
        self.fn_examples = {
            k: tuple(self.__call_node_to_example_dec(ex, via) for ex, via in nodes)
            for k, nodes in fn_examples.items()
        }

    def __call_node_to_example_dec(self, node, via):
        """Turn a parsed example Call into an @example(...).via(...) Decorator."""
        # If we have black installed, remove trailing comma, _unless_ there's a comment
        node = node.with_changes(
            func=self.decorator_func,
            args=[
                a.with_changes(
                    comma=a.comma
                    if m.findall(a.comma, m.Comment())
                    else cst.MaybeSentinel.DEFAULT
                )
                for a in node.args
            ]
            if black
            else node.args,
        )
        # Note: calling a method on a decorator requires PEP-614, i.e. Python 3.9+,
        # but plumbing two cases through doesn't seem worth the trouble :-/
        via = cst.Call(
            func=cst.Attribute(node, cst.Name("via")),
            args=[cst.Arg(cst.SimpleString(repr(via)))],
        )
        if black:  # pragma: no branch
            # Pretty-print the generated decorator so the patch is readable.
            pretty = black.format_str(
                cst.Module([]).code_for_node(via),
                mode=black.FileMode(line_length=self.line_length),
            )
            via = cst.parse_expression(pretty.strip())
        return cst.Decorator(via)

    def leave_FunctionDef(self, _, updated_node):
        """Strip matching .via() decorators and append this function's examples."""
        return updated_node.with_changes(
            # TODO: improve logic for where in the list to insert this decorator
            decorators=tuple(
                d
                for d in updated_node.decorators
                # `findall()` to see through the identity function workaround on py38
                if not m.findall(d, self.strip_matching)
            )
            + self.fn_examples.get(updated_node.name.value, ())
        )
def get_patch_for(func, failing_examples, *, strip_via=()):
    """Return a ``(filename, before, after)`` triple that adds @example()
    decorators for ``failing_examples`` to ``func``, or None if no patch
    can be produced.

    ``failing_examples`` is an iterable of (example-call repr, via) pairs.
    """
    # Skip this if we're unable to find the location or source of this function.
    try:
        module = sys.modules[func.__module__]
        fname = Path(module.__file__).relative_to(Path.cwd())
        before = inspect.getsource(func)
    except Exception:
        return None
    # The printed examples might include object reprs which are invalid syntax,
    # so we parse here and skip over those. If _none_ are valid, there's no patch.
    call_nodes = []
    for ex, via in set(failing_examples):
        with suppress(Exception):
            node = cst.parse_expression(ex)
            assert isinstance(node, cst.Call), node
            # Check for st.data(), which doesn't support explicit examples
            data = m.Arg(m.Call(m.Name("data"), args=[m.Arg(m.Ellipsis())]))
            if m.matches(node, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])):
                return None
            call_nodes.append((node, via))
    if not call_nodes:
        return None
    # Fully qualify the decorator when the test module imports `hypothesis`
    # itself but not `given` (so bare `example` is unlikely to be in scope).
    if (
        module.__dict__.get("hypothesis") is sys.modules["hypothesis"]
        and "given" not in module.__dict__  # more reliably present than `example`
    ):
        decorator_func = "hypothesis.example"
    else:
        decorator_func = "example"
    # Do the codemod and return a triple containing location and replacement info.
    dedented, prefix = dedent(before)
    try:
        node = cst.parse_module(dedented)
    except Exception:  # pragma: no cover
        # inspect.getsource() sometimes returns a decorator alone, which is invalid
        return None
    after = AddExamplesCodemod(
        CodemodContext(),
        fn_examples={func.__name__: call_nodes},
        strip_via=strip_via,
        dec=decorator_func,
        width=88 - len(prefix),  # to match Black's default formatting
    ).transform_module(node)
    return (str(fname), before, indent(after.code, prefix=prefix))
def make_patch(triples, *, msg="Hypothesis: add explicit examples", when=None):
    """Create a patch for (fname, before, after) triples."""
    assert triples, "attempted to create empty patch"
    timestamp = when or datetime.now(tz=timezone.utc)

    # Group the (before, after) replacement pairs by the file they apply to.
    grouped = {}
    for fname, before, after in triples:
        grouped.setdefault(Path(fname), []).append((before, after))

    parts = [HEADER.format(msg=msg, when=timestamp)]
    for path, changes in sorted(grouped.items()):
        original = path.read_text(encoding="utf-8")
        updated = original
        for before, after in changes:
            updated = updated.replace(before.rstrip(), after.rstrip(), 1)
        delta = difflib.unified_diff(
            original.splitlines(keepends=True),
            updated.splitlines(keepends=True),
            fromfile=str(path),
            tofile=str(path),
        )
        parts.append("".join(delta))
    return "".join(parts)
def save_patch(patch: str, *, slug: str = "") -> Path:  # pragma: no cover
    """Write ``patch`` under the Hypothesis storage directory and return its
    path relative to the current working directory.

    The filename embeds today's date, the slug, and a short hash of the patch
    with the Date: header masked out - so the same change made on different
    days produces the same hash.
    """
    assert re.fullmatch(r"|[a-z]+-", slug), f"malformed {slug=}"
    now = date.today().isoformat()
    cleaned = re.sub(r"^Date: .+?$", "", patch, count=1, flags=re.MULTILINE)
    hash8 = hashlib.sha1(cleaned.encode()).hexdigest()[:8]
    fname = Path(storage_directory("patches", f"{now}--{slug}{hash8}.patch"))
    fname.parent.mkdir(parents=True, exist_ok=True)
    fname.write_text(patch, encoding="utf-8")
    # NOTE(review): assumes the storage directory lives under cwd - confirm.
    return fname.relative_to(Path.cwd())
def gc_patches(slug: str = "") -> None:  # pragma: no cover
    """Delete saved patch files for ``slug`` that are more than a week old."""
    cutoff = date.today() - timedelta(days=7)
    # Glob matches the filename layout produced by save_patch() above.
    for fname in Path(storage_directory("patches")).glob(
        f"????-??-??--{slug}????????.patch"
    ):
        if date.fromisoformat(fname.stem.split("--")[0]) < cutoff:
            fname.unlink()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,345 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
.. _hypothesis-cli:
----------------
hypothesis[cli]
----------------
::
$ hypothesis --help
Usage: hypothesis [OPTIONS] COMMAND [ARGS]...
Options:
--version Show the version and exit.
-h, --help Show this message and exit.
Commands:
codemod `hypothesis codemod` refactors deprecated or inefficient code.
fuzz [hypofuzz] runs tests with an adaptive coverage-guided fuzzer.
write `hypothesis write` writes property-based tests for you!
This module requires the :pypi:`click` package, and provides Hypothesis' command-line
interface, for e.g. :doc:`'ghostwriting' tests <ghostwriter>` via the terminal.
It's also where `HypoFuzz <https://hypofuzz.com/>`__ adds the :command:`hypothesis fuzz`
command (`learn more about that here <https://hypofuzz.com/docs/quickstart.html>`__).
"""
import builtins
import importlib
import inspect
import sys
import types
from difflib import get_close_matches
from functools import partial
from multiprocessing import Pool
from pathlib import Path
try:
import pytest
except ImportError:
pytest = None # type: ignore
# Shown when an optional CLI dependency is missing; "{}" is filled in with
# the package name at the point of use.
MESSAGE = """
The Hypothesis command-line interface requires the `{}` package,
which you do not have installed. Run:
python -m pip install --upgrade 'hypothesis[cli]'
and try again.
"""
try:
import click
except ImportError:
def main():
"""If `click` is not installed, tell the user to install it then exit."""
sys.stderr.write(MESSAGE.format("click"))
sys.exit(1)
else:
# Ensure that Python scripts in the current working directory are importable,
# on the principle that Ghostwriter should 'just work' for novice users. Note
# that we append rather than prepend to the module search path, so this will
# never shadow the stdlib or installed packages.
sys.path.append(".")
@click.group(context_settings={"help_option_names": ("-h", "--help")})
@click.version_option()
def main():
pass
def obj_name(s: str) -> object:
    """This "type" imports whatever object is named by a dotted string.

    Tries, in order: importing ``s`` as a module; splitting off a trailing
    attribute name; splitting off a trailing class + attribute name. Raises
    click.UsageError with suggestions on failure.
    """
    s = s.strip()
    # Catch filesystem paths early - this takes dotted module names, not paths.
    if "/" in s or "\\" in s:
        raise click.UsageError(
            "Remember that the ghostwriter should be passed the name of a module, not a path."
        ) from None
    try:
        return importlib.import_module(s)
    except ImportError:
        pass
    classname = None
    if "." not in s:
        modulename, module, funcname = "builtins", builtins, s
    else:
        modulename, funcname = s.rsplit(".", 1)
        try:
            module = importlib.import_module(modulename)
        except ImportError as err:
            try:
                # Maybe the last dotted component is a class, not a module.
                modulename, classname = modulename.rsplit(".", 1)
                module = importlib.import_module(modulename)
            except (ImportError, ValueError):
                if s.endswith(".py"):
                    raise click.UsageError(
                        "Remember that the ghostwriter should be passed the name of a module, not a file."
                    ) from None
                raise click.UsageError(
                    f"Failed to import the {modulename} module for introspection. "
                    "Check spelling and your Python import path, or use the Python API?"
                ) from err

    def describe_close_matches(
        module_or_class: types.ModuleType, objname: str
    ) -> str:
        # Suggest similarly-named public attributes for typo'd lookups.
        public_names = [
            name for name in vars(module_or_class) if not name.startswith("_")
        ]
        matches = get_close_matches(objname, public_names)
        if matches:
            return f" Closest matches: {matches!r}"
        else:
            return ""

    if classname is None:
        try:
            return getattr(module, funcname)
        except AttributeError as err:
            if funcname == "py":
                # Likely attempted to pass a local file (Eg., "myscript.py") instead of a module name
                raise click.UsageError(
                    "Remember that the ghostwriter should be passed the name of a module, not a file."
                    f"\n\tTry: hypothesis write {s[:-3]}"
                ) from None
            raise click.UsageError(
                f"Found the {modulename!r} module, but it doesn't have a "
                f"{funcname!r} attribute."
                + describe_close_matches(module, funcname)
            ) from err
    else:
        try:
            func_class = getattr(module, classname)
        except AttributeError as err:
            raise click.UsageError(
                f"Found the {modulename!r} module, but it doesn't have a "
                f"{classname!r} class." + describe_close_matches(module, classname)
            ) from err
        try:
            return getattr(func_class, funcname)
        except AttributeError as err:
            if inspect.isclass(func_class):
                func_class_is = "class"
            else:
                func_class_is = "attribute"
            raise click.UsageError(
                f"Found the {modulename!r} module and {classname!r} {func_class_is}, "
                f"but it doesn't have a {funcname!r} attribute."
                + describe_close_matches(func_class, funcname)
            ) from err
def _refactor(func, fname):
try:
oldcode = Path(fname).read_text(encoding="utf-8")
except (OSError, UnicodeError) as err:
# Permissions or encoding issue, or file deleted, etc.
return f"skipping {fname!r} due to {err}"
if "hypothesis" not in oldcode:
return # This is a fast way to avoid running slow no-op codemods
try:
newcode = func(oldcode)
except Exception as err:
from libcst import ParserSyntaxError
if isinstance(err, ParserSyntaxError):
from hypothesis.extra._patching import indent
msg = indent(str(err).replace("\n\n", "\n"), " ").strip()
return f"skipping {fname!r} due to {msg}"
raise
if newcode != oldcode:
Path(fname).write_text(newcode, encoding="utf-8")
@main.command()  # type: ignore  # Click adds the .command attribute
@click.argument("path", type=str, required=True, nargs=-1)
def codemod(path):
    """`hypothesis codemod` refactors deprecated or inefficient code.

    It adapts `python -m libcst.tool`, removing many features and config options
    which are rarely relevant for this purpose. If you need more control, we
    encourage you to use the libcst CLI directly; if not this one is easier.

    PATH is the file(s) or directories of files to format in place, or
    "-" to read from stdin and write to stdout.
    """
    try:
        from libcst.codemod import gather_files

        from hypothesis.extra import codemods
    except ImportError:
        # Codemods need libcst, an optional dependency.
        sys.stderr.write(
            "You are missing required dependencies for this option. Run:\n\n"
            "    python -m pip install --upgrade hypothesis[codemods]\n\n"
            "and try again."
        )
        sys.exit(1)

    # Special case for stdin/stdout usage
    if "-" in path:
        if len(path) > 1:
            raise Exception(
                "Cannot specify multiple paths when reading from stdin!"
            )
        print("Codemodding from stdin", file=sys.stderr)
        print(codemods.refactor(sys.stdin.read()))
        return 0

    # Find all the files to refactor, and then codemod them
    files = gather_files(path)
    errors = set()
    if len(files) <= 1:
        # A single file (or none): process in-process, no worker pool needed.
        errors.add(_refactor(codemods.refactor, *files))
    else:
        # Many files: farm the work out across processes.
        with Pool() as pool:
            for msg in pool.imap_unordered(
                partial(_refactor, codemods.refactor), files
            ):
                errors.add(msg)
    # _refactor returns None on success; only real messages remain.
    errors.discard(None)
    for msg in errors:
        print(msg, file=sys.stderr)
    return 1 if errors else 0
@main.command()  # type: ignore  # Click adds the .command attribute
@click.argument("func", type=obj_name, required=True, nargs=-1)
@click.option(
    "--roundtrip",
    "writer",
    flag_value="roundtrip",
    help="start by testing write/read or encode/decode!",
)
@click.option(
    "--equivalent",
    "writer",
    flag_value="equivalent",
    help="very useful when optimising or refactoring code",
)
@click.option(
    "--errors-equivalent",
    "writer",
    flag_value="errors-equivalent",
    help="--equivalent, but also allows consistent errors",
)
@click.option(
    "--idempotent",
    "writer",
    flag_value="idempotent",
    help="check that f(x) == f(f(x))",
)
@click.option(
    "--binary-op",
    "writer",
    flag_value="binary_operation",
    help="associativity, commutativity, identity element",
)
# Note: we deliberately omit a --ufunc flag, because the magic()
# detection of ufuncs is both precise and complete.
@click.option(
    "--style",
    type=click.Choice(["pytest", "unittest"]),
    default="pytest" if pytest else "unittest",
    help="pytest-style function, or unittest-style method?",
)
@click.option(
    "-e",
    "--except",
    "except_",
    type=obj_name,
    multiple=True,
    help="dotted name of exception(s) to ignore",
)
@click.option(
    "--annotate/--no-annotate",
    default=None,
    help="force ghostwritten tests to be type-annotated (or not). "
    "By default, match the code to test.",
)
def write(func, writer, except_, style, annotate):  # \b disables autowrap
    """`hypothesis write` writes property-based tests for you!

    Type annotations are helpful but not required for our advanced introspection
    and templating logic. Try running the examples below to see how it works:

    \b
        hypothesis write gzip
        hypothesis write numpy.matmul
        hypothesis write pandas.from_dummies
        hypothesis write re.compile --except re.error
        hypothesis write --equivalent ast.literal_eval eval
        hypothesis write --roundtrip json.dumps json.loads
        hypothesis write --style=unittest --idempotent sorted
        hypothesis write --binary-op operator.add
    """
    # NOTE: if you want to call this function from Python, look instead at the
    # ``hypothesis.extra.ghostwriter`` module. Click-decorated functions have
    # a different calling convention, and raise SystemExit instead of returning.
    kwargs = {"except_": except_ or (), "style": style, "annotate": annotate}
    if writer is None:
        # No flag given: let the ghostwriter's magic() pick per function.
        writer = "magic"
    elif writer == "idempotent" and len(func) > 1:
        raise click.UsageError("Test functions for idempotence one at a time.")
    elif writer == "roundtrip" and len(func) == 1:
        # Round-tripping a single function is the same as testing idempotence.
        writer = "idempotent"
    elif "equivalent" in writer and len(func) == 1:
        # Equivalence needs at least two functions; with one we just fuzz it.
        writer = "fuzz"
    if writer == "errors-equivalent":
        # Same writer as --equivalent, with an extra flag to tolerate errors.
        writer = "equivalent"
        kwargs["allow_same_errors"] = True

    try:
        from hypothesis.extra import ghostwriter
    except ImportError:
        # The ghostwriter requires `black` to format its output.
        sys.stderr.write(MESSAGE.format("black"))
        sys.exit(1)
    code = getattr(ghostwriter, writer)(*func, **kwargs)

    # Pretty-print with `rich` when available, falling back to plain output.
    try:
        from rich.console import Console
        from rich.syntax import Syntax

        from hypothesis.utils.terminal import guess_background_color
    except ImportError:
        print(code)
    else:
        try:
            theme = "default" if guess_background_color() == "light" else "monokai"
            code = Syntax(code, "python", background_color="default", theme=theme)
            Console().print(code, soft_wrap=True)
        except Exception:
            # Highlighting is best-effort; always emit the generated code.
            print("# Error while syntax-highlighting code", file=sys.stderr)
            print(code)

View File

@@ -0,0 +1,284 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
.. _codemods:
--------------------
hypothesis[codemods]
--------------------
This module provides codemods based on the :pypi:`LibCST` library, which can
both detect *and automatically fix* issues with code that uses Hypothesis,
including upgrading from deprecated features to our recommended style.
You can run the codemods via our CLI::
$ hypothesis codemod --help
Usage: hypothesis codemod [OPTIONS] PATH...
`hypothesis codemod` refactors deprecated or inefficient code.
It adapts `python -m libcst.tool`, removing many features and config
options which are rarely relevant for this purpose. If you need more
control, we encourage you to use the libcst CLI directly; if not this one
is easier.
PATH is the file(s) or directories of files to format in place, or "-" to
read from stdin and write to stdout.
Options:
-h, --help Show this message and exit.
Alternatively you can use ``python -m libcst.tool``, which offers more control
at the cost of additional configuration (adding ``'hypothesis.extra'`` to the
``modules`` list in ``.libcst.codemod.yaml``) and `some issues on Windows
<https://github.com/Instagram/LibCST/issues/435>`__.
.. autofunction:: refactor
"""
import functools
import importlib
from inspect import Parameter, signature
from typing import ClassVar, List
import libcst as cst
import libcst.matchers as m
from libcst.codemod import VisitorBasedCodemodCommand
def refactor(code: str) -> str:
    """Update a source code string from deprecated to modern Hypothesis APIs.

    This may not fix *all* the deprecation warnings in your code, but we're
    confident that it will be easier than doing it all by hand.

    We recommend using the CLI, but if you want a Python function here it is.
    """
    context = cst.codemod.CodemodContext()
    module = cst.parse_module(code)
    # Apply each codemod in sequence.  Order matters: the min_magnitude fix
    # documents that it must run after the positional-to-keyword fix, so that
    # `st.complex_numbers(None)` is first turned into a keyword argument.
    command_classes: List[VisitorBasedCodemodCommand] = [
        HypothesisFixPositionalKeywonlyArgs(context),
        HypothesisFixComplexMinMagnitude(context),
        HypothesisFixHealthcheckAll(context),
        HypothesisFixCharactersArguments(context),
    ]
    for command in command_classes:
        module = command.transform_module(module)
    return module.code
def match_qualname(name):
    """Build a LibCST matcher for nodes whose fully-qualified name is ``name``.

    We use the metadata to get qualname instead of matching directly on function
    name, because this handles some scope and "from x import y as z" issues.
    """

    def all_names_agree(qualnames):
        # If there are multiple possible qualnames, e.g. due to conditional
        # imports, be conservative.  Better to leave the user to fix a few
        # things by hand than to break their code while attempting to refactor!
        return all(qn.name == name for qn in qualnames)

    return m.MatchMetadataIfTrue(
        cst.metadata.QualifiedNameProvider,
        all_names_agree,
    )
class HypothesisFixComplexMinMagnitude(VisitorBasedCodemodCommand):
    """Fix a deprecated min_magnitude=None argument for complex numbers::

        st.complex_numbers(min_magnitude=None) -> st.complex_numbers(min_magnitude=0)

    Note that this should be run *after* ``HypothesisFixPositionalKeywonlyArgs``,
    in order to handle ``st.complex_numbers(None)``.
    """

    DESCRIPTION = "Fix a deprecated min_magnitude=None argument for complex numbers."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)

    # Only fire on arguments of calls that resolve (via qualname metadata) to
    # hypothesis.strategies.complex_numbers.
    @m.call_if_inside(
        m.Call(metadata=match_qualname("hypothesis.strategies.complex_numbers"))
    )
    def leave_Arg(self, original_node, updated_node):
        # Replace the literal `min_magnitude=None` with the equivalent `0`.
        if m.matches(
            updated_node, m.Arg(keyword=m.Name("min_magnitude"), value=m.Name("None"))
        ):
            return updated_node.with_changes(value=cst.Integer("0"))
        return updated_node
@functools.lru_cache
def get_fn(import_path):
    """Resolve a dotted path like ``"pkg.mod.fn"`` to the named object, cached."""
    module_path, attr_name = import_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
class HypothesisFixPositionalKeywonlyArgs(VisitorBasedCodemodCommand):
    """Fix positional arguments for newly keyword-only parameters, e.g.::

        st.fractions(0, 1, 9) -> st.fractions(0, 1, max_denominator=9)

    Applies to a majority of our public API, since keyword-only parameters are
    great but we couldn't use them until after we dropped support for Python 2.
    """

    DESCRIPTION = "Fix positional arguments for newly keyword-only parameters."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)

    # Qualified names of public functions whose signatures gained
    # keyword-only parameters.
    kwonly_functions = (
        "hypothesis.target",
        "hypothesis.find",
        "hypothesis.extra.lark.from_lark",
        "hypothesis.extra.numpy.arrays",
        "hypothesis.extra.numpy.array_shapes",
        "hypothesis.extra.numpy.unsigned_integer_dtypes",
        "hypothesis.extra.numpy.integer_dtypes",
        "hypothesis.extra.numpy.floating_dtypes",
        "hypothesis.extra.numpy.complex_number_dtypes",
        "hypothesis.extra.numpy.datetime64_dtypes",
        "hypothesis.extra.numpy.timedelta64_dtypes",
        "hypothesis.extra.numpy.byte_string_dtypes",
        "hypothesis.extra.numpy.unicode_string_dtypes",
        "hypothesis.extra.numpy.array_dtypes",
        "hypothesis.extra.numpy.nested_dtypes",
        "hypothesis.extra.numpy.valid_tuple_axes",
        "hypothesis.extra.numpy.broadcastable_shapes",
        "hypothesis.extra.pandas.indexes",
        "hypothesis.extra.pandas.series",
        "hypothesis.extra.pandas.columns",
        "hypothesis.extra.pandas.data_frames",
        "hypothesis.provisional.domains",
        "hypothesis.stateful.run_state_machine_as_test",
        "hypothesis.stateful.rule",
        "hypothesis.stateful.initialize",
        "hypothesis.strategies.floats",
        "hypothesis.strategies.lists",
        "hypothesis.strategies.sets",
        "hypothesis.strategies.frozensets",
        "hypothesis.strategies.iterables",
        "hypothesis.strategies.dictionaries",
        "hypothesis.strategies.characters",
        "hypothesis.strategies.text",
        "hypothesis.strategies.from_regex",
        "hypothesis.strategies.binary",
        "hypothesis.strategies.fractions",
        "hypothesis.strategies.decimals",
        "hypothesis.strategies.recursive",
        "hypothesis.strategies.complex_numbers",
        "hypothesis.strategies.shared",
        "hypothesis.strategies.uuids",
        "hypothesis.strategies.runner",
        "hypothesis.strategies.functions",
        "hypothesis.strategies.datetimes",
        "hypothesis.strategies.times",
    )

    def leave_Call(self, original_node, updated_node):
        """Convert positional to keyword arguments."""
        metadata = self.get_metadata(cst.metadata.QualifiedNameProvider, original_node)
        qualnames = {qn.name for qn in metadata}
        # If this isn't one of our known functions, or it has no posargs, stop there.
        if (
            len(qualnames) != 1
            or not qualnames.intersection(self.kwonly_functions)
            or not m.matches(
                updated_node,
                m.Call(
                    func=m.DoesNotMatch(m.Call()),
                    args=[m.Arg(keyword=None), m.ZeroOrMore()],
                ),
            )
        ):
            return updated_node
        # Get the actual function object so that we can inspect the signature.
        # This does e.g. incur a dependency on Numpy to fix Numpy-dependent code,
        # but having a single source of truth about the signatures is worth it.
        try:
            params = signature(get_fn(*qualnames)).parameters.values()
        except ModuleNotFoundError:
            return updated_node
        # st.floats() has a new allow_subnormal kwonly argument not at the end,
        # so we do a bit more of a dance here.
        if qualnames == {"hypothesis.strategies.floats"}:
            params = [p for p in params if p.name != "allow_subnormal"]
        if len(updated_node.args) > len(params):
            # More arguments than parameters: malformed call, leave it alone.
            return updated_node
        # Create new arg nodes with the newly required keywords
        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        # Only rewrite bare positional args that now map to keyword-only params.
        newargs = [
            arg
            if arg.keyword or arg.star or p.kind is not Parameter.KEYWORD_ONLY
            else arg.with_changes(keyword=cst.Name(p.name), equal=assign_nospace)
            for p, arg in zip(params, updated_node.args)
        ]
        return updated_node.with_changes(args=newargs)
class HypothesisFixHealthcheckAll(VisitorBasedCodemodCommand):
    """Replace Healthcheck.all() with list(Healthcheck)"""

    DESCRIPTION = "Replace Healthcheck.all() with list(Healthcheck)"

    # Matches a literal `Healthcheck.all()` call with no arguments; this is a
    # purely syntactic match (no qualname metadata), keyed on the bare name.
    @m.leave(m.Call(func=m.Attribute(m.Name("Healthcheck"), m.Name("all")), args=[]))
    def replace_healthcheck(self, original_node, updated_node):
        # Rewrite in place: Healthcheck.all() -> list(Healthcheck)
        return updated_node.with_changes(
            func=cst.Name("list"),
            args=[cst.Arg(value=cst.Name("Healthcheck"))],
        )
class HypothesisFixCharactersArguments(VisitorBasedCodemodCommand):
    """Fix deprecated white/blacklist arguments to characters::

        st.characters(whitelist_categories=...) -> st.characters(categories=...)
        st.characters(blacklist_categories=...) -> st.characters(exclude_categories=...)
        st.characters(whitelist_characters=...) -> st.characters(include_characters=...)
        st.characters(blacklist_characters=...) -> st.characters(exclude_characters=...)

    Additionally, we drop `exclude_categories=` if `categories=` is present,
    because this argument is always redundant (or an error).
    """

    DESCRIPTION = "Fix deprecated white/blacklist arguments to characters."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)
    # Old keyword name -> new keyword name.
    _replacements: ClassVar = {
        "whitelist_categories": "categories",
        "blacklist_categories": "exclude_categories",
        "whitelist_characters": "include_characters",
        "blacklist_characters": "exclude_characters",
    }

    # Matches st.characters(...) calls containing at least one deprecated keyword.
    @m.leave(
        m.Call(
            metadata=match_qualname("hypothesis.strategies.characters"),
            args=[
                m.ZeroOrMore(),
                m.Arg(keyword=m.OneOf(*map(m.Name, _replacements))),
                m.ZeroOrMore(),
            ],
        ),
    )
    def fn(self, original_node, updated_node):
        # Update to the new names
        # NOTE(review): this assumes every arg has a keyword (arg.keyword.value
        # would fail on a bare positional) -- presumably safe because
        # characters() is listed in kwonly_functions above; confirm.
        newargs = []
        for arg in updated_node.args:
            kw = self._replacements.get(arg.keyword.value, arg.keyword.value)
            newargs.append(arg.with_changes(keyword=cst.Name(kw)))
        # Drop redundant exclude_categories, which is now an error
        if any(m.matches(arg, m.Arg(keyword=m.Name("categories"))) for arg in newargs):
            ex = m.Arg(keyword=m.Name("exclude_categories"))
            newargs = [a for a in newargs if m.matches(a, ~ex)]
        return updated_node.with_changes(args=newargs)

View File

@@ -0,0 +1,64 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
--------------------
hypothesis[dateutil]
--------------------
This module provides :pypi:`dateutil <python-dateutil>` timezones.
You can use this strategy to make :func:`~hypothesis.strategies.datetimes`
and :func:`~hypothesis.strategies.times` produce timezone-aware values.
"""
import datetime as dt
from dateutil import tz, zoneinfo # type: ignore
from hypothesis import strategies as st
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["timezones"]
def __zone_sort_key(zone):
    """Sort by absolute UTC offset at reference date,
    positive first, with ties broken by name.
    """
    assert zone is not None
    # Use a fixed reference instant so the ordering is deterministic.
    offset = zone.utcoffset(dt.datetime(2000, 1, 1))
    # NOTE(review): 999 is an int sentinel while real offsets are timedeltas;
    # a sort mixing both kinds would raise TypeError when comparing key tuples.
    # Presumably utcoffset() never returns None for dateutil zones -- confirm.
    offset = 999 if offset is None else offset
    return (abs(offset), -offset, str(zone))
@cacheable
@defines_strategy()
def timezones() -> st.SearchStrategy[dt.tzinfo]:
    """Any timezone from :pypi:`dateutil <python-dateutil>`.

    This strategy minimises to UTC, or the timezone with the smallest offset
    from UTC as of 2000-01-01, and is designed for use with
    :py:func:`~hypothesis.strategies.datetimes`.

    Note that the timezones generated by the strategy may vary depending on the
    configuration of your machine. See the dateutil documentation for more
    information.
    """
    # Sort every named zone by __zone_sort_key so that shrinking moves towards
    # smaller UTC offsets, then force UTC itself to the very front.
    all_timezones = sorted(
        (tz.gettz(t) for t in zoneinfo.get_zonefile_instance().zones),
        key=__zone_sort_key,
    )
    all_timezones.insert(0, tz.UTC)
    # We discard Nones in the list comprehension because Mypy knows that
    # tz.gettz may return None. However this should never happen for known
    # zone names, so we assert that it's impossible first.
    assert None not in all_timezones
    return st.sampled_from([z for z in all_timezones if z is not None])

View File

@@ -0,0 +1,30 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from hypothesis.extra.django._fields import from_field, register_field_strategy
from hypothesis.extra.django._impl import (
LiveServerTestCase,
StaticLiveServerTestCase,
TestCase,
TransactionTestCase,
from_form,
from_model,
)
__all__ = [
"LiveServerTestCase",
"StaticLiveServerTestCase",
"TestCase",
"TransactionTestCase",
"from_field",
"from_model",
"register_field_strategy",
"from_form",
]

View File

@@ -0,0 +1,343 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import re
import string
from datetime import timedelta
from decimal import Decimal
from functools import lru_cache
from typing import Any, Callable, Dict, Type, TypeVar, Union
import django
from django import forms as df
from django.contrib.auth.forms import UsernameField
from django.core.validators import (
validate_ipv4_address,
validate_ipv6_address,
validate_ipv46_address,
)
from django.db import models as dm
from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument, ResolutionFailed
from hypothesis.internal.validation import check_type
from hypothesis.provisional import urls
from hypothesis.strategies import emails
AnyField = Union[dm.Field, df.Field]
F = TypeVar("F", bound=AnyField)
def numeric_bounds_from_validators(
    field, min_value=float("-inf"), max_value=float("inf")
):
    """Narrow ``(min_value, max_value)`` using the field's Min/MaxValueValidators.

    Returns the tightest bounds implied by the starting bounds plus any
    MinValueValidator / MaxValueValidator attached to ``field``.
    """
    for validator in field.validators:
        if isinstance(validator, django.core.validators.MinValueValidator):
            min_value = max(min_value, validator.limit_value)
        elif isinstance(validator, django.core.validators.MaxValueValidator):
            max_value = min(max_value, validator.limit_value)
    return min_value, max_value
def integers_for_field(min_value, max_value):
    """Return a strategy factory for integer fields with the given base bounds.

    The returned callable narrows the bounds further from the field's own
    validators before building the ``st.integers`` strategy.
    """

    def build_strategy(field):
        lo, hi = numeric_bounds_from_validators(field, min_value, max_value)
        return st.integers(lo, hi)

    return build_strategy
@lru_cache
def timezones():
    """Return the timezones strategy matching the active Django configuration.

    Requires USE_TZ; chooses between pytz-based and zoneinfo-based timezones.
    """
    # From Django 4.0, the default is to use zoneinfo instead of pytz.
    assert getattr(django.conf.settings, "USE_TZ", False)
    if getattr(django.conf.settings, "USE_DEPRECATED_PYTZ", True):
        from hypothesis.extra.pytz import timezones
    else:
        from hypothesis.strategies import timezones
    return timezones()
# Mapping of field types, to strategy objects or functions of (type) -> strategy
_FieldLookUpType = Dict[
    Type[AnyField],
    Union[st.SearchStrategy, Callable[[Any], st.SearchStrategy]],
]
_global_field_lookup: _FieldLookUpType = {
    # Integer bounds match the value ranges of the corresponding column types;
    # each factory also narrows them from the field's own validators.
    dm.SmallIntegerField: integers_for_field(-32768, 32767),
    dm.IntegerField: integers_for_field(-2147483648, 2147483647),
    dm.BigIntegerField: integers_for_field(-9223372036854775808, 9223372036854775807),
    dm.PositiveIntegerField: integers_for_field(0, 2147483647),
    dm.PositiveSmallIntegerField: integers_for_field(0, 32767),
    dm.BooleanField: st.booleans(),
    dm.DateField: st.dates(),
    dm.EmailField: emails(),
    dm.FloatField: st.floats(),
    dm.NullBooleanField: st.one_of(st.none(), st.booleans()),
    dm.URLField: urls(),
    dm.UUIDField: st.uuids(),
    # Form-field equivalents; the form FloatField also respects numeric
    # validators and excludes nan/inf, which the form would reject.
    df.DateField: st.dates(),
    df.DurationField: st.timedeltas(),
    df.EmailField: emails(),
    df.FloatField: lambda field: st.floats(
        *numeric_bounds_from_validators(field), allow_nan=False, allow_infinity=False
    ),
    df.IntegerField: integers_for_field(-2147483648, 2147483647),
    df.NullBooleanField: st.one_of(st.none(), st.booleans()),
    df.URLField: urls(),
    df.UUIDField: st.uuids(),
}
# IPv6 addresses as strings, in both compressed (str) and fully-exploded forms.
_ipv6_strings = st.one_of(
    st.ip_addresses(v=6).map(str),
    st.ip_addresses(v=6).map(lambda addr: addr.exploded),
)
def register_for(field_type):
    """Decorator factory: register the decorated function as the strategy
    factory for ``field_type`` in the global field-to-strategy lookup."""

    def register(strategy_factory):
        _global_field_lookup[field_type] = strategy_factory
        return strategy_factory

    return register
@register_for(dm.DateTimeField)
@register_for(df.DateTimeField)
def _for_datetime(field):
    # Generate timezone-aware datetimes only when the project enables USE_TZ.
    if getattr(django.conf.settings, "USE_TZ", False):
        return st.datetimes(timezones=timezones())
    return st.datetimes()
def using_sqlite():
    """True if the default database engine is SQLite; None if unconfigured."""
    try:
        return (
            getattr(django.conf.settings, "DATABASES", {})
            .get("default", {})
            .get("ENGINE", "")
            .endswith(".sqlite3")
        )
    except django.core.exceptions.ImproperlyConfigured:
        # Settings not configured at all; callers treat None as "not sqlite".
        return None
@register_for(dm.TimeField)
def _for_model_time(field):
    # SQLITE supports TZ-aware datetimes, but not TZ-aware times.
    if getattr(django.conf.settings, "USE_TZ", False) and not using_sqlite():
        return st.times(timezones=timezones())
    return st.times()
@register_for(df.TimeField)
def _for_form_time(field):
    # Unlike the model TimeField above, form fields don't hit the database,
    # so no SQLite special case is needed here.
    if getattr(django.conf.settings, "USE_TZ", False):
        return st.times(timezones=timezones())
    return st.times()
@register_for(dm.DurationField)
def _for_duration(field):
    # SQLite stores timedeltas as six bytes of microseconds
    if using_sqlite():
        delta = timedelta(microseconds=2**47 - 1)
        return st.timedeltas(-delta, delta)
    return st.timedeltas()
@register_for(dm.SlugField)
@register_for(df.SlugField)
def _for_slug(field):
    # Allow the empty string only for optional fields (blank model fields or
    # non-required form fields).
    min_size = 1
    if getattr(field, "blank", False) or not getattr(field, "required", True):
        min_size = 0
    return st.text(
        alphabet=string.ascii_letters + string.digits,
        min_size=min_size,
        max_size=field.max_length,
    )
@register_for(dm.GenericIPAddressField)
def _for_model_ip(field):
    # Model fields declare their protocol directly; dispatch on it.
    return {
        "ipv4": st.ip_addresses(v=4).map(str),
        "ipv6": _ipv6_strings,
        "both": st.ip_addresses(v=4).map(str) | _ipv6_strings,
    }[field.protocol.lower()]
@register_for(df.GenericIPAddressField)
def _for_form_ip(field):
    # the IP address form fields have no direct indication of which type
    # of address they want, so direct comparison with the validator
    # function has to be used instead. Sorry for the potato logic here
    if validate_ipv46_address in field.default_validators:
        # Check the combined validator first: a "both" field also matches it.
        return st.ip_addresses(v=4).map(str) | _ipv6_strings
    if validate_ipv4_address in field.default_validators:
        return st.ip_addresses(v=4).map(str)
    if validate_ipv6_address in field.default_validators:
        return _ipv6_strings
    raise ResolutionFailed(f"No IP version validator on {field=}")
@register_for(dm.DecimalField)
@register_for(df.DecimalField)
def _for_decimal(field):
    min_value, max_value = numeric_bounds_from_validators(field)
    # Largest value representable with max_digits digits at decimal_places
    # precision, e.g. max_digits=4, decimal_places=2 -> 99.99.
    bound = Decimal(10**field.max_digits - 1) / (10**field.decimal_places)
    return st.decimals(
        min_value=max(min_value, -bound),
        max_value=min(max_value, bound),
        places=field.decimal_places,
    )
def length_bounds_from_validators(field):
    """Return (min_size, max_size) for a field, from max_length plus any
    Min/MaxLengthValidators; min_size defaults to 1 and max_size may be None."""
    min_size = 1
    max_size = field.max_length
    for v in field.validators:
        if isinstance(v, django.core.validators.MinLengthValidator):
            min_size = max(min_size, v.limit_value)
        elif isinstance(v, django.core.validators.MaxLengthValidator):
            # `max_size or v.limit_value` handles a None max_length.
            max_size = min(max_size or v.limit_value, v.limit_value)
    return min_size, max_size
@register_for(dm.BinaryField)
def _for_binary(field):
    min_size, max_size = length_bounds_from_validators(field)
    # Optional fields may also be the empty bytestring.
    if getattr(field, "blank", False) or not getattr(field, "required", True):
        return st.just(b"") | st.binary(min_size=min_size, max_size=max_size)
    return st.binary(min_size=min_size, max_size=max_size)
@register_for(dm.CharField)
@register_for(dm.TextField)
@register_for(df.CharField)
@register_for(df.RegexField)
@register_for(UsernameField)
def _for_text(field):
    # We can infer a vastly more precise strategy by considering the
    # validators as well as the field type. This is a minimal proof of
    # concept, but we intend to leverage the idea much more heavily soon.
    # See https://github.com/HypothesisWorks/hypothesis-python/issues/1116
    regexes = [
        re.compile(v.regex, v.flags) if isinstance(v.regex, str) else v.regex
        for v in field.validators
        if isinstance(v, django.core.validators.RegexValidator) and not v.inverse_match
    ]
    if regexes:
        # This strategy generates according to one of the regexes, and
        # filters using the others. It can therefore learn to generate
        # from the most restrictive and filter with permissive patterns.
        # Not maximally efficient, but it makes pathological cases rarer.
        # If you want a challenge: extend https://qntm.org/greenery to
        # compute intersections of the full Python regex language.
        return st.one_of(*(st.from_regex(r) for r in regexes))
    # If there are no (usable) regexes, we use a standard text strategy.
    min_size, max_size = length_bounds_from_validators(field)
    # Exclude NUL and surrogates, and reject strings that become too short
    # after stripping (Django strips surrounding whitespace on CharFields).
    strategy = st.text(
        alphabet=st.characters(exclude_characters="\x00", exclude_categories=("Cs",)),
        min_size=min_size,
        max_size=max_size,
    ).filter(lambda s: min_size <= len(s.strip()))
    if getattr(field, "blank", False) or not getattr(field, "required", True):
        return st.just("") | strategy
    return strategy
@register_for(df.BooleanField)
def _for_form_boolean(field):
    # A *required* form BooleanField only ever validates as True (think
    # "you must accept the terms"), so True is the only value worth generating.
    return st.just(True) if field.required else st.booleans()
def register_field_strategy(
    field_type: Type[AnyField], strategy: st.SearchStrategy
) -> None:
    """Add an entry to the global field-to-strategy lookup used by
    :func:`~hypothesis.extra.django.from_field`.

    ``field_type`` must be a subtype of :class:`django.db.models.Field` or
    :class:`django.forms.Field`, which must not already be registered.
    ``strategy`` must be a :class:`~hypothesis.strategies.SearchStrategy`.

    Raises InvalidArgument for a non-Field type, a duplicate registration,
    or an AutoField (whose values the database assigns).
    """
    if not issubclass(field_type, (dm.Field, df.Field)):
        raise InvalidArgument(f"{field_type=} must be a subtype of Field")
    check_type(st.SearchStrategy, strategy, "strategy")
    if field_type in _global_field_lookup:
        # Registrations are permanent, so refuse to silently overwrite one.
        raise InvalidArgument(
            f"{field_type=} already has a registered "
            f"strategy ({_global_field_lookup[field_type]!r})"
        )
    if issubclass(field_type, dm.AutoField):
        raise InvalidArgument("Cannot register a strategy for an AutoField")
    _global_field_lookup[field_type] = strategy
def from_field(field: F) -> st.SearchStrategy[Union[F, None]]:
    """Return a strategy for values that fit the given field.

    This function is used by :func:`~hypothesis.extra.django.from_form` and
    :func:`~hypothesis.extra.django.from_model` for any fields that require
    a value, or for which you passed ``...`` (:obj:`python:Ellipsis`) to infer
    a strategy from an annotation.

    It's pretty similar to the core :func:`~hypothesis.strategies.from_type`
    function, with a subtle but important difference: ``from_field`` takes a
    Field *instance*, rather than a Field *subtype*, so that it has access to
    instance attributes such as string length and validators.
    """
    check_type((dm.Field, df.Field), field, "field")
    if getattr(field, "choices", False):
        # Fields with explicit choices: sample from the declared values,
        # flattening optgroups into their member choices.
        choices: list = []
        for value, name_or_optgroup in field.choices:
            if isinstance(name_or_optgroup, (list, tuple)):
                choices.extend(key for key, _ in name_or_optgroup)
            else:
                choices.append(value)
        # form fields automatically include an empty choice, strip it out
        if "" in choices:
            choices.remove("")
        min_size = 1
        if isinstance(field, (dm.CharField, dm.TextField)) and field.blank:
            # Optional text fields may still be blank.
            choices.insert(0, "")
        elif isinstance(field, (df.Field)) and not field.required:
            choices.insert(0, "")
            min_size = 0
        strategy = st.sampled_from(choices)
        if isinstance(field, (df.MultipleChoiceField, df.TypedMultipleChoiceField)):
            # Multiple-choice fields take a list of choices rather than one.
            strategy = st.lists(st.sampled_from(choices), min_size=min_size)
    else:
        # No choices: look the field type up in the global registry.
        if type(field) not in _global_field_lookup:
            if getattr(field, "null", False):
                # Unknown but nullable: None is always a valid value.
                return st.none()
            raise ResolutionFailed(f"Could not infer a strategy for {field!r}")
        strategy = _global_field_lookup[type(field)]  # type: ignore
        if not isinstance(strategy, st.SearchStrategy):
            # Registry entries may be factories of (field) -> strategy.
            strategy = strategy(field)
    assert isinstance(strategy, st.SearchStrategy)
    if field.validators:

        def validate(value):
            # Keep only examples that pass every validator on the field.
            try:
                field.run_validators(value)
                return True
            except django.core.exceptions.ValidationError:
                return False

        strategy = strategy.filter(validate)

    if getattr(field, "null", False):
        return st.none() | strategy
    return strategy

View File

@@ -0,0 +1,217 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import sys
import unittest
from functools import partial
from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
from django import forms as df, test as dt
from django.contrib.staticfiles import testing as dst
from django.core.exceptions import ValidationError
from django.db import IntegrityError, models as dm
from hypothesis import reject, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra.django._fields import from_field
from hypothesis.strategies._internal.utils import defines_strategy
# `...` (Ellipsis) marks "infer this field's strategy" in from_model()'s
# signature.  Its type is only importable from `types` on Python 3.10+,
# so fall back to the builtin alias (type checking) or type(Ellipsis).
if sys.version_info >= (3, 10):
    from types import EllipsisType as EllipsisType
elif TYPE_CHECKING:
    from builtins import ellipsis as EllipsisType
else:
    EllipsisType = type(Ellipsis)
ModelT = TypeVar("ModelT", bound=dm.Model)
class HypothesisTestCase:
    """Mixin that wires Hypothesis's per-example hooks into Django test setup.

    Hypothesis calls setup_example/teardown_example around each generated
    example, so every example runs with a fresh Django test environment
    rather than one shared across the whole test method.
    """

    def setup_example(self):
        # Run Django's per-test setup before each Hypothesis example.
        self._pre_setup()

    def teardown_example(self, example):
        # Undo Django's per-test setup after each Hypothesis example.
        self._post_teardown()

    def __call__(self, result=None):
        # For Hypothesis-driven methods, bypass Django's SimpleTestCase call
        # protocol (which would do a single setup/teardown around the whole
        # run) and use the plain unittest protocol instead; the per-example
        # hooks above take over the setup/teardown duties.
        testMethod = getattr(self, self._testMethodName)
        if getattr(testMethod, "is_hypothesis_test", False):
            return unittest.TestCase.__call__(self, result)
        else:
            return dt.SimpleTestCase.__call__(self, result)
class TestCase(HypothesisTestCase, dt.TestCase):
    """:class:`django.test.TestCase`, with per-example setup/teardown."""

    pass
class TransactionTestCase(HypothesisTestCase, dt.TransactionTestCase):
    """:class:`django.test.TransactionTestCase`, with per-example setup/teardown."""

    pass
class LiveServerTestCase(HypothesisTestCase, dt.LiveServerTestCase):
    """:class:`django.test.LiveServerTestCase`, with per-example setup/teardown."""

    pass
class StaticLiveServerTestCase(HypothesisTestCase, dst.StaticLiveServerTestCase):
    """StaticLiveServerTestCase from staticfiles, with per-example setup/teardown."""

    pass
@defines_strategy()
def from_model(
    model: Type[ModelT], /, **field_strategies: Union[st.SearchStrategy, EllipsisType]
) -> st.SearchStrategy[ModelT]:
    """Return a strategy for examples of ``model``.

    .. warning::
        Hypothesis creates saved models. This will run inside your testing
        transaction when using the test runner, but if you use the dev console
        this will leave debris in your database.

    ``model`` must be a subclass of :class:`~django:django.db.models.Model`.
    Strategies for fields may be passed as keyword arguments, for example
    ``is_staff=st.just(False)``. In order to support models with fields named
    "model", this is a positional-only parameter.

    Hypothesis can often infer a strategy based on the field type and
    validators, and will attempt to do so for any required fields. No strategy
    will be inferred for an :class:`~django:django.db.models.AutoField`,
    nullable field, foreign key, or field for which a keyword
    argument is passed to ``from_model()``. For example,
    a Shop type with a foreign key to Company could be generated with::

        shop_strategy = from_model(Shop, company=from_model(Company))

    Like for :func:`~hypothesis.strategies.builds`, you can pass
    ``...`` (:obj:`python:Ellipsis`) as a keyword argument to infer a strategy
    for a field which has a default value instead of using the default.
    """
    if not issubclass(model, dm.Model):
        raise InvalidArgument(f"{model=} must be a subtype of Model")
    fields_by_name = {f.name: f for f in model._meta.concrete_fields}
    # Replace any explicit ``...`` values with an inferred field strategy.
    for name, value in sorted(field_strategies.items()):
        if value is ...:
            field_strategies[name] = from_field(fields_by_name[name])
    # Infer strategies for any remaining required fields (not auto-created,
    # and without a model-level default).
    for name, field in sorted(fields_by_name.items()):
        if (
            name not in field_strategies
            and not field.auto_created
            and field.default is dm.fields.NOT_PROVIDED
        ):
            field_strategies[name] = from_field(field)
    for field in field_strategies:
        if model._meta.get_field(field).primary_key:
            # The primary key is generated as part of the strategy. We
            # want to find any existing row with this primary key and
            # overwrite its contents.
            kwargs = {field: field_strategies.pop(field)}
            kwargs["defaults"] = st.fixed_dictionaries(field_strategies)  # type: ignore
            return _models_impl(st.builds(model.objects.update_or_create, **kwargs))

    # The primary key is not generated as part of the strategy, so we
    # just match against any row that has the same value for all
    # fields.
    return _models_impl(st.builds(model.objects.get_or_create, **field_strategies))
@st.composite
def _models_impl(draw, strat):
    """Handle the nasty part of drawing a value for models().

    ``get_or_create``/``update_or_create`` return ``(instance, created)``
    tuples, so we index out the instance.  An IntegrityError from generated
    data (e.g. a violated unique constraint) marks the example as invalid,
    and Hypothesis retries with different values.
    """
    try:
        return draw(strat)[0]
    except IntegrityError:
        reject()
@defines_strategy()
def from_form(
    form: Type[df.Form],
    form_kwargs: Optional[dict] = None,
    **field_strategies: Union[st.SearchStrategy, EllipsisType],
) -> st.SearchStrategy[df.Form]:
    """Return a strategy for examples of ``form``.

    ``form`` must be a subclass of :class:`~django:django.forms.Form`.
    Strategies for fields may be passed as keyword arguments, for example
    ``is_staff=st.just(False)``.

    Hypothesis can often infer a strategy based on the field type and
    validators, and will attempt to do so for any required fields. No strategy
    will be inferred for a disabled field or field for which a keyword argument
    is passed to ``from_form()``.

    This function uses the fields of an unbound ``form`` instance to determine
    field strategies; any keyword arguments needed to instantiate the unbound
    ``form`` instance can be passed into ``from_form()`` as a dict with the
    keyword ``form_kwargs``. E.g.::

        shop_strategy = from_form(Shop, form_kwargs={"company_id": 5})

    Like for :func:`~hypothesis.strategies.builds`, you can pass
    ``...`` (:obj:`python:Ellipsis`) as a keyword argument to infer a strategy
    for a field which has a default value instead of using the default.
    """
    # currently unsupported:
    # ComboField
    # FilePathField
    # FileField
    # ImageField
    form_kwargs = form_kwargs or {}
    if not issubclass(form, df.BaseForm):
        raise InvalidArgument(f"{form=} must be a subtype of Form")
    # Forms are a little bit different from models. Model classes have
    # all their fields defined, whereas forms may have different fields
    # per-instance. So, we ought to instantiate the form and get the
    # fields from the instance, thus we need to accept the kwargs for
    # instantiation as well as the explicitly defined strategies
    unbound_form = form(**form_kwargs)
    fields_by_name = {}
    for name, field in unbound_form.fields.items():
        if isinstance(field, df.MultiValueField):
            # PS: So this is a little strange, but MultiValueFields must
            # have their form data encoded in a particular way for the
            # values to actually be picked up by the widget instances'
            # ``value_from_datadict``.
            # E.g. if a MultiValueField named 'mv_field' has 3
            # sub-fields then the ``value_from_datadict`` will look for
            # 'mv_field_0', 'mv_field_1', and 'mv_field_2'. Here I'm
            # decomposing the individual sub-fields into the names that
            # the form validation process expects
            for i, _field in enumerate(field.fields):
                fields_by_name[f"{name}_{i}"] = _field
        else:
            fields_by_name[name] = field
    # Replace explicit ``...`` values with strategies inferred from the field.
    for name, value in sorted(field_strategies.items()):
        if value is ...:
            field_strategies[name] = from_field(fields_by_name[name])
    # Infer strategies for every remaining enabled field.
    for name, field in sorted(fields_by_name.items()):
        if name not in field_strategies and not field.disabled:
            field_strategies[name] = from_field(field)
    return _forms_impl(
        st.builds(
            partial(form, **form_kwargs),
            data=st.fixed_dictionaries(field_strategies),  # type: ignore
        )
    )
@st.composite
def _forms_impl(draw, strat):
    """Handle the nasty part of drawing a value for from_form().

    Form construction may run validators over generated data; a
    ValidationError marks the example as invalid so Hypothesis retries.
    """
    try:
        return draw(strat)
    except ValidationError:
        reject()

View File

@@ -0,0 +1,53 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
-----------------------
hypothesis[dpcontracts]
-----------------------
This module provides tools for working with the :pypi:`dpcontracts` library,
because `combining contracts and property-based testing works really well
<https://hillelwayne.com/talks/beyond-unit-tests/>`_.
It requires ``dpcontracts >= 0.4``.
"""
from dpcontracts import PreconditionError
from hypothesis import reject
from hypothesis.errors import InvalidArgument
from hypothesis.internal.reflection import proxies
def fulfill(contract_func):
    """Decorate ``contract_func`` so that calls violating its preconditions
    are rejected and retried with different arguments.

    This is a convenience helper for testing internal code that uses
    :pypi:`dpcontracts`: arguments that the public interface would refuse
    are filtered out before they can trigger a contract error.

    It can be used as ``builds(fulfill(func), ...)`` or inline in a test,
    e.g. ``assert fulfill(func)(*args)``.
    """
    if not hasattr(contract_func, "__contract_wrapped_func__"):
        raise InvalidArgument(
            f"{contract_func.__name__} has no dpcontracts preconditions"
        )

    @proxies(contract_func)
    def precondition_filtered(*args, **kwargs):
        # A violated precondition means this input was never valid for the
        # public interface: treat the example as invalid rather than failing.
        try:
            return contract_func(*args, **kwargs)
        except PreconditionError:
            reject()

    return precondition_filtered

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
----------------
hypothesis[lark]
----------------
This extra can be used to generate strings matching any context-free grammar,
using the `Lark parser library <https://github.com/lark-parser/lark>`_.
It currently only supports Lark's native EBNF syntax, but we plan to extend
this to support other common syntaxes such as ANTLR and :rfc:`5234` ABNF.
Lark already `supports loading grammars
<https://lark-parser.readthedocs.io/en/latest/nearley.html>`_
from `nearley.js <https://nearley.js.org/>`_, so you may not have to write
your own at all.
"""
from inspect import signature
from typing import Dict, Optional
import lark
from lark.grammar import NonTerminal, Terminal
from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.validation import check_type
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["from_lark"]
def get_terminal_names(terminals, rules, ignore_names):
    """Collect the names of every terminal used anywhere in the grammar.

    The arguments come from ``Lark.grammar.compile()``.  You would expect
    ``terminals`` plus ``ignore_names`` to be exhaustive, but terminals
    created with ``@declare`` only show up inside nonterminal expansions,
    so we also walk every rule's expansion.
    """
    found = set(ignore_names)
    found.update(t.name for t in terminals)
    for rule in rules:
        found.update(
            sym.name for sym in rule.expansion if isinstance(sym, Terminal)
        )
    return found
class LarkStrategy(st.SearchStrategy):
    """Low-level strategy implementation wrapping a Lark grammar.

    Generates strings by recursively expanding a start symbol: terminals
    are generated from their regular expressions, nonterminals by sampling
    one of their productions.  See ``from_lark`` for details.
    """

    def __init__(self, grammar, start, explicit):
        assert isinstance(grammar, lark.lark.Lark)
        if start is None:
            start = grammar.options.start
        if not isinstance(start, list):
            start = [start]
        self.grammar = grammar

        # This is a total hack, but working around the changes is a nicer user
        # experience than breaking for anyone who doesn't instantly update their
        # installation of Lark alongside Hypothesis.
        compile_args = signature(grammar.grammar.compile).parameters
        if "terminals_to_keep" in compile_args:
            terminals, rules, ignore_names = grammar.grammar.compile(start, ())
        elif "start" in compile_args:  # pragma: no cover
            # Support lark <= 0.10.0, without the terminals_to_keep argument.
            terminals, rules, ignore_names = grammar.grammar.compile(start)
        else:  # pragma: no cover
            # This branch is to support lark <= 0.7.1, without the start argument.
            terminals, rules, ignore_names = grammar.grammar.compile()

        # Map every symbol name to its Terminal/NonTerminal object, so the
        # user-facing ``start`` names can be translated into grammar symbols.
        self.names_to_symbols = {}
        for r in rules:
            t = r.origin
            self.names_to_symbols[t.name] = t
        for t in terminals:
            self.names_to_symbols[t.name] = Terminal(t.name)
        self.start = st.sampled_from([self.names_to_symbols[s] for s in start])
        self.ignored_symbols = tuple(self.names_to_symbols[n] for n in ignore_names)
        # Each terminal is generated as an exact (fullmatch) regex match.
        self.terminal_strategies = {
            t.name: st.from_regex(t.pattern.to_regexp(), fullmatch=True)
            for t in terminals
        }
        unknown_explicit = set(explicit) - get_terminal_names(
            terminals, rules, ignore_names
        )
        if unknown_explicit:
            raise InvalidArgument(
                "The following arguments were passed as explicit_strategies, "
                "but there is no such terminal production in this grammar: "
                + repr(sorted(unknown_explicit))
            )
        self.terminal_strategies.update(explicit)
        # Group productions by origin nonterminal; sorting by expansion
        # length lets shrinking prefer the simplest production.
        nonterminals = {}
        for rule in rules:
            nonterminals.setdefault(rule.origin.name, []).append(tuple(rule.expansion))
        for v in nonterminals.values():
            v.sort(key=len)
        self.nonterminal_strategies = {
            k: st.sampled_from(v) for k, v in nonterminals.items()
        }
        self.__rule_labels = {}

    def do_draw(self, data):
        # Expand one start symbol, appending generated terminal strings to
        # ``state``, then join into the final output string.
        state = []
        start = data.draw(self.start)
        self.draw_symbol(data, start, state)
        return "".join(state)

    def rule_label(self, name):
        # Cache a stable example label per grammar rule name.
        try:
            return self.__rule_labels[name]
        except KeyError:
            return self.__rule_labels.setdefault(
                name, calc_label_from_name(f"LARK:{name}")
            )

    def draw_symbol(self, data, symbol, draw_state):
        """Recursively generate text for ``symbol`` into ``draw_state``."""
        if isinstance(symbol, Terminal):
            try:
                strategy = self.terminal_strategies[symbol.name]
            except KeyError:
                # Terminals introduced via %declare have no pattern to
                # generate from, so the user must supply one explicitly.
                raise InvalidArgument(
                    "Undefined terminal %r. Generation does not currently support "
                    "use of %%declare unless you pass `explicit`, a dict of "
                    'names-to-strategies, such as `{%r: st.just("")}`'
                    % (symbol.name, symbol.name)
                ) from None
            draw_state.append(data.draw(strategy))
        else:
            assert isinstance(symbol, NonTerminal)
            data.start_example(self.rule_label(symbol.name))
            expansion = data.draw(self.nonterminal_strategies[symbol.name])
            for e in expansion:
                self.draw_symbol(data, e, draw_state)
                self.gen_ignore(data, draw_state)
            data.stop_example()

    def gen_ignore(self, data, draw_state):
        # Occasionally (p = 1/4) interleave one of the grammar's ignored
        # symbols (e.g. whitespace) between generated tokens.
        if self.ignored_symbols and data.draw_boolean(1 / 4):
            emit = data.draw(st.sampled_from(self.ignored_symbols))
            self.draw_symbol(data, emit, draw_state)

    def calc_has_reusable_values(self, recur):
        # Generated values are immutable strings, safe to cache and reuse.
        return True
def check_explicit(name):
    """Return a pass-through validator for user-supplied ``explicit`` strategies.

    The returned callable raises (via ``check_type``) if a drawn value is
    not a string, naming the offending argument, and otherwise returns the
    value unchanged.
    """

    def validate(value):
        check_type(str, value, "value drawn from " + name)
        return value

    return validate
@cacheable
@defines_strategy(force_reusable_values=True)
def from_lark(
    grammar: lark.lark.Lark,
    *,
    start: Optional[str] = None,
    explicit: Optional[Dict[str, st.SearchStrategy[str]]] = None,
) -> st.SearchStrategy[str]:
    """A strategy for strings accepted by the given context-free grammar.

    ``grammar`` must be a ``Lark`` object, which wraps an EBNF specification.
    The Lark EBNF grammar reference can be found
    `here <https://lark-parser.readthedocs.io/en/latest/grammar.html>`_.

    ``from_lark`` will automatically generate strings matching the
    nonterminal ``start`` symbol in the grammar, which was supplied as an
    argument to the Lark class. To generate strings matching a different
    symbol, including terminals, you can override this by passing the
    ``start`` argument to ``from_lark``. Note that Lark may remove unreachable
    productions when the grammar is compiled, so you should probably pass the
    same value for ``start`` to both.

    Currently ``from_lark`` does not support grammars that need custom lexing.
    Any lexers will be ignored, and any undefined terminals from the use of
    ``%declare`` will result in generation errors. To define strategies for
    such terminals, pass a dictionary mapping their name to a corresponding
    strategy as the ``explicit`` argument.

    The :pypi:`hypothesmith` project includes a strategy for Python source,
    based on a grammar and careful post-processing.
    """
    check_type(lark.lark.Lark, grammar, "grammar")
    if explicit is None:
        explicit = {}
    else:
        check_type(dict, explicit, "explicit")
        # Wrap each user strategy so that drawing a non-string raises a
        # helpful error naming the offending explicit[...] entry.
        explicit = {
            k: v.map(check_explicit(f"explicit[{k!r}]={v!r}"))
            for k, v in explicit.items()
        }
    return LarkStrategy(grammar, start, explicit)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from hypothesis.extra.pandas.impl import (
column,
columns,
data_frames,
indexes,
range_indexes,
series,
)
__all__ = ["indexes", "range_indexes", "series", "column", "columns", "data_frames"]

View File

@@ -0,0 +1,756 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from collections import OrderedDict, abc
from copy import copy
from datetime import datetime, timedelta
from typing import Any, List, Optional, Sequence, Set, Union
import attr
import numpy as np
import pandas
from hypothesis import strategies as st
from hypothesis._settings import note_deprecation
from hypothesis.control import reject
from hypothesis.errors import InvalidArgument
from hypothesis.extra import numpy as npst
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.coverage import check, check_function
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.internal.validation import (
check_type,
check_valid_interval,
check_valid_size,
try_convert,
)
from hypothesis.strategies._internal.strategies import Ex, check_strategy
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
try:
from pandas.core.arrays.integer import IntegerDtype
except ImportError:
IntegerDtype = ()
def dtype_for_elements_strategy(s):
    """Strategy for the pandas dtype inferred from one draw of ``s``.

    The inference puts a single drawn value into a ``pandas.Series`` and
    reads the resulting dtype.  Sharing on the elements strategy means the
    same dtype is reused everywhere within a single test run.
    """
    infer_dtype = s.map(lambda value: pandas.Series([value]).dtype)
    shared_key = ("hypothesis.extra.pandas.dtype_for_elements_strategy", s)
    return st.shared(infer_dtype, key=shared_key)
def infer_dtype_if_necessary(dtype, values, elements, draw):
    """Return ``dtype``, inferring one from ``elements`` only when required.

    Inference happens only when no dtype was specified *and* no values were
    generated — with data present, pandas can infer the dtype itself.
    """
    if dtype is not None or values:
        return dtype
    return draw(dtype_for_elements_strategy(elements))
@check_function
def elements_and_dtype(elements, dtype, source=None):
    """Validate and normalise an ``(elements, dtype)`` argument pair.

    Shared by the index/series/column strategies.  At least one of
    ``elements`` and ``dtype`` must be provided; whichever is missing is
    inferred from the other.  Returns ``(elements_strategy, numpy_dtype)``
    where the dtype may be None if only ``elements`` was given.

    ``source`` is an argument-name prefix (e.g. ``"columns[0]"``) used so
    error messages point at the right user-supplied argument.
    """
    if source is None:
        prefix = ""
    else:
        prefix = f"{source}."
    if elements is not None:
        check_strategy(elements, f"{prefix}elements")
    else:
        with check("dtype is not None"):
            if dtype is None:
                raise InvalidArgument(
                    f"At least one of {prefix}elements or {prefix}dtype must be provided."
                )
    with check("isinstance(dtype, CategoricalDtype)"):
        if pandas.api.types.CategoricalDtype.is_dtype(dtype):
            raise InvalidArgument(
                f"{prefix}dtype is categorical, which is currently unsupported"
            )
    if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
        # BUG FIX: the two string fragments were concatenated without a
        # separating space, producing "…of this class.Otherwise…".
        raise InvalidArgument(
            f"Passed {dtype=} is a dtype class, please pass in an instance of this class. "
            "Otherwise it would be treated as dtype=object"
        )
    if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
        err_msg = f"Passed {dtype=} is not a valid Pandas dtype."
        if issubclass(dtype, datetime):
            err_msg += ' To generate valid datetimes, pass `dtype="datetime64[ns]"`'
            raise InvalidArgument(err_msg)
        elif issubclass(dtype, timedelta):
            err_msg += ' To generate valid timedeltas, pass `dtype="timedelta64[ns]"`'
            raise InvalidArgument(err_msg)
        # Any other object-kind dtype is deprecated rather than rejected.
        note_deprecation(
            f"{err_msg} We'll treat it as "
            "dtype=object for now, but this will be an error in a future version.",
            since="2021-12-31",
            has_codemod=False,
            stacklevel=1,
        )
    if isinstance(dtype, st.SearchStrategy):
        raise InvalidArgument(
            f"Passed {dtype=} is a strategy, but we require a concrete dtype "
            "here. See https://stackoverflow.com/q/74355937 for workaround patterns."
        )
    # Map a nullable-integer dtype *name* (e.g. "Int64") onto an instance of
    # the corresponding IntegerDtype subclass.
    _get_subclasses = getattr(IntegerDtype, "__subclasses__", list)
    dtype = {t.name: t() for t in _get_subclasses()}.get(dtype, dtype)
    if isinstance(dtype, IntegerDtype):
        # Pandas nullable integers: generate values with the plain numpy
        # dtype, but also allow None (stored by pandas as missing).
        is_na_dtype = True
        dtype = np.dtype(dtype.name.lower())
    elif dtype is not None:
        is_na_dtype = False
        dtype = try_convert(np.dtype, dtype, "dtype")
    else:
        is_na_dtype = False
    if elements is None:
        elements = npst.from_dtype(dtype)
        if is_na_dtype:
            elements = st.none() | elements
    elif dtype is not None:
        # User gave both elements and dtype: coerce each drawn value to the
        # dtype, raising a clear error for unconvertible values.
        def convert_element(value):
            if is_na_dtype and value is None:
                return None
            name = f"draw({prefix}elements)"
            try:
                return np.array([value], dtype=dtype)[0]
            except (TypeError, ValueError):
                raise InvalidArgument(
                    "Cannot convert %s=%r of type %s to dtype %s"
                    % (name, value, type(value).__name__, dtype.str)
                ) from None

        elements = elements.map(convert_element)
    assert elements is not None
    return elements, dtype
class ValueIndexStrategy(st.SearchStrategy):
    """Strategy generating a :class:`pandas.Index` from an elements strategy.

    Draws between ``min_size`` and ``max_size`` elements, optionally
    enforcing uniqueness, and infers the dtype from the elements strategy
    when none was given.
    """

    def __init__(self, elements, dtype, min_size, max_size, unique, name):
        super().__init__()
        self.elements = elements
        self.dtype = dtype
        self.min_size = min_size
        self.max_size = max_size
        self.unique = unique
        self.name = name

    def do_draw(self, data):
        result = []
        seen = set()
        # cu.many drives a size-biased loop: it decides when to stop, and
        # allows rejecting individual draws without ending the sequence.
        iterator = cu.many(
            data,
            min_size=self.min_size,
            max_size=self.max_size,
            average_size=(self.min_size + self.max_size) / 2,
        )
        while iterator.more():
            elt = data.draw(self.elements)
            if self.unique:
                if elt in seen:
                    # Duplicate under unique=True: discard this draw and
                    # ask the iterator for a replacement.
                    iterator.reject()
                    continue
                seen.add(elt)
            result.append(elt)
        dtype = infer_dtype_if_necessary(
            dtype=self.dtype, values=result, elements=self.elements, draw=data.draw
        )
        # tupleize_cols=False: tuple elements produce an Index of tuples,
        # not a MultiIndex.
        return pandas.Index(
            result, dtype=dtype, tupleize_cols=False, name=data.draw(self.name)
        )
DEFAULT_MAX_SIZE = 10
@cacheable
@defines_strategy()
def range_indexes(
    min_size: int = 0,
    max_size: Optional[int] = None,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.RangeIndex]:
    """Provides a strategy which generates an :class:`~pandas.Index` whose
    values are 0, 1, ..., n for some n.

    Arguments:

    * min_size is the smallest number of elements the index can have.
    * max_size is the largest number of elements the index can have. If None
      it will default to some suitable value based on min_size.
    * name is the name of the index. If st.none(), the index will have no name.
    """
    check_valid_size(min_size, "min_size")
    check_valid_size(max_size, "max_size")
    if max_size is None:
        # Default to a small range above min_size, capped at the largest
        # length a RangeIndex can represent (2**63 - 1).
        max_size = min(min_size + DEFAULT_MAX_SIZE, 2**63 - 1)
    check_valid_interval(min_size, max_size, "min_size", "max_size")
    check_strategy(name)
    return st.builds(pandas.RangeIndex, st.integers(min_size, max_size), name=name)
@cacheable
@defines_strategy()
def indexes(
    *,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    dtype: Any = None,
    min_size: int = 0,
    max_size: Optional[int] = None,
    unique: bool = True,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.Index]:
    """Provides a strategy for producing a :class:`pandas.Index`.

    Arguments:

    * elements is a strategy which will be used to generate the individual
      values of the index. If None, it will be inferred from the dtype. Note:
      even if the elements strategy produces tuples, the generated value
      will not be a MultiIndex, but instead be a normal index whose elements
      are tuples.
    * dtype is the dtype of the resulting index. If None, it will be inferred
      from the elements strategy. At least one of dtype or elements must be
      provided.
    * min_size is the minimum number of elements in the index.
    * max_size is the maximum number of elements in the index. If None then it
      will default to a suitable small size. If you want larger indexes you
      should pass a max_size explicitly.
    * unique specifies whether all of the elements in the resulting index
      should be distinct.
    * name is a strategy for strings or ``None``, which will be passed to
      the :class:`pandas.Index` constructor.
    """
    check_valid_size(min_size, "min_size")
    check_valid_size(max_size, "max_size")
    check_valid_interval(min_size, max_size, "min_size", "max_size")
    check_type(bool, unique, "unique")
    # Validates the pair and fills in whichever of elements/dtype is missing.
    elements, dtype = elements_and_dtype(elements, dtype)
    if max_size is None:
        max_size = min_size + DEFAULT_MAX_SIZE
    return ValueIndexStrategy(elements, dtype, min_size, max_size, unique, name)
@defines_strategy()
def series(
    *,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    dtype: Any = None,
    index: Optional[st.SearchStrategy[Union[Sequence, pandas.Index]]] = None,
    fill: Optional[st.SearchStrategy[Ex]] = None,
    unique: bool = False,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.Series]:
    """Provides a strategy for producing a :class:`pandas.Series`.

    Arguments:

    * elements: a strategy that will be used to generate the individual
      values in the series. If None, we will attempt to infer a suitable
      default from the dtype.
    * dtype: the dtype of the resulting series and may be any value
      that can be passed to :class:`numpy.dtype`. If None, will use
      pandas's standard behaviour to infer it from the type of the elements
      values. Note that if the type of values that comes out of your
      elements strategy varies, then so will the resulting dtype of the
      series.
    * index: If not None, a strategy for generating indexes for the
      resulting Series. This can generate either :class:`pandas.Index`
      objects or any sequence of values (which will be passed to the
      Index constructor).

      You will probably find it most convenient to use the
      :func:`~hypothesis.extra.pandas.indexes` or
      :func:`~hypothesis.extra.pandas.range_indexes` function to produce
      values for this argument.
    * fill: a default value for elements of the series. See
      :func:`~hypothesis.extra.numpy.arrays` for a full explanation.
    * unique: whether all of the generated values should be distinct.
    * name: is a strategy for strings or ``None``, which will be passed to
      the :class:`pandas.Series` constructor.

    Usage:

    .. code-block:: pycon

        >>> series(dtype=int).example()
        0   -2001747478
        1    1153062837
    """
    if index is None:
        index = range_indexes()
    else:
        check_strategy(index, "index")
    elements, np_dtype = elements_and_dtype(elements, dtype)
    index_strategy = index

    # if it is converted to an object, use object for series type
    if (
        np_dtype is not None
        and np_dtype.kind == "O"
        and not isinstance(dtype, IntegerDtype)
    ):
        dtype = np_dtype

    @st.composite
    def result(draw):
        index = draw(index_strategy)
        if len(index) > 0:
            # Generate values through an object-dtype numpy array so that
            # fill/unique handling is shared with npst.arrays, then let the
            # Series constructor apply the target dtype.
            if dtype is not None:
                result_data = draw(
                    npst.arrays(
                        dtype=object,
                        elements=elements,
                        shape=len(index),
                        fill=fill,
                        unique=unique,
                    )
                ).tolist()
            else:
                result_data = list(
                    draw(
                        npst.arrays(
                            dtype=object,
                            elements=elements,
                            shape=len(index),
                            fill=fill,
                            unique=unique,
                        )
                    ).tolist()
                )
            return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
        else:
            # Empty index: still need a concrete dtype, so infer one from
            # the elements strategy if none was specified.
            return pandas.Series(
                (),
                index=index,
                dtype=dtype
                if dtype is not None
                else draw(dtype_for_elements_strategy(elements)),
                name=draw(name),
            )

    return result()
@attr.s(slots=True)
class column:
    """Data object for describing a column in a DataFrame.

    Arguments:

    * name: the column name, or None to default to the column position. Must
      be hashable, but can otherwise be any value supported as a pandas column
      name.
    * elements: the strategy for generating values in this column, or None
      to infer it from the dtype.
    * dtype: the dtype of the column, or None to infer it from the element
      strategy. At least one of dtype or elements must be provided.
    * fill: A default value for elements of the column. See
      :func:`~hypothesis.extra.numpy.arrays` for a full explanation.
    * unique: If all values in this column should be distinct.
    """

    # Column name; None means "use the column's position as its name".
    name = attr.ib(default=None)
    # Strategy for the column's values, or None to infer from dtype.
    elements = attr.ib(default=None)
    # dtype, or None to infer from elements; repr via the pretty-printer
    # since a dtype may be supplied as a class or callable.
    dtype = attr.ib(default=None, repr=get_pretty_function_description)
    # Fill strategy, passed through to npst.arrays.
    fill = attr.ib(default=None)
    # Whether all values in this column must be distinct.
    unique = attr.ib(default=False)
def columns(
    names_or_number: Union[int, Sequence[str]],
    *,
    dtype: Any = None,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    fill: Optional[st.SearchStrategy[Ex]] = None,
    unique: bool = False,
) -> List[column]:
    """A convenience function producing several :class:`column` objects that
    share the same general shape.

    ``names_or_number`` is either a sequence of values, each used as the
    name of one column, or a number, in which case that many unnamed
    columns are created.  All remaining arguments are passed through
    verbatim to each :class:`column`.
    """
    if isinstance(names_or_number, (int, float)):
        # A bare count: create that many positional (unnamed) columns.
        names: List[Union[int, str, None]] = [None] * names_or_number
    else:
        names = list(names_or_number)
    make = lambda n: column(
        name=n, dtype=dtype, elements=elements, fill=fill, unique=unique
    )
    return [make(n) for n in names]
@defines_strategy()
def data_frames(
columns: Optional[Sequence[column]] = None,
*,
rows: Optional[st.SearchStrategy[Union[dict, Sequence[Any]]]] = None,
index: Optional[st.SearchStrategy[Ex]] = None,
) -> st.SearchStrategy[pandas.DataFrame]:
"""Provides a strategy for producing a :class:`pandas.DataFrame`.
Arguments:
* columns: An iterable of :class:`column` objects describing the shape
of the generated DataFrame.
* rows: A strategy for generating a row object. Should generate
either dicts mapping column names to values or a sequence mapping
column position to the value in that position (note that unlike the
:class:`pandas.DataFrame` constructor, single values are not allowed
here. Passing e.g. an integer is an error, even if there is only one
column).
At least one of rows and columns must be provided. If both are
provided then the generated rows will be validated against the
columns and an error will be raised if they don't match.
Caveats on using rows:
* In general you should prefer using columns to rows, and only use
rows if the columns interface is insufficiently flexible to
describe what you need - you will get better performance and
example quality that way.
* If you provide rows and not columns, then the shape and dtype of
the resulting DataFrame may vary. e.g. if you have a mix of int
and float in the values for one column in your row entries, the
column will sometimes have an integral dtype and sometimes a float.
* index: If not None, a strategy for generating indexes for the
resulting DataFrame. This can generate either :class:`pandas.Index`
objects or any sequence of values (which will be passed to the
Index constructor).
You will probably find it most convenient to use the
:func:`~hypothesis.extra.pandas.indexes` or
:func:`~hypothesis.extra.pandas.range_indexes` function to produce
values for this argument.
Usage:
The expected usage pattern is that you use :class:`column` and
:func:`columns` to specify a fixed shape of the DataFrame you want as
follows. For example the following gives a two column data frame:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import column, data_frames
>>> data_frames([
... column('A', dtype=int), column('B', dtype=float)]).example()
A B
0 2021915903 1.793898e+232
1 1146643993 inf
2 -2096165693 1.000000e+07
If you want the values in different columns to interact in some way you
can use the rows argument. For example the following gives a two column
DataFrame where the value in the first column is always at most the value
in the second:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import column, data_frames
>>> import hypothesis.strategies as st
>>> data_frames(
... rows=st.tuples(st.floats(allow_nan=False),
... st.floats(allow_nan=False)).map(sorted)
... ).example()
0 1
0 -3.402823e+38 9.007199e+15
1 -1.562796e-298 5.000000e-01
You can also combine the two:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import columns, data_frames
>>> import hypothesis.strategies as st
>>> data_frames(
... columns=columns(["lo", "hi"], dtype=float),
... rows=st.tuples(st.floats(allow_nan=False),
... st.floats(allow_nan=False)).map(sorted)
... ).example()
lo hi
0 9.314723e-49 4.353037e+45
1 -9.999900e-01 1.000000e+07
2 -2.152861e+134 -1.069317e-73
(Note that the column dtype must still be specified and will not be
inferred from the rows. This restriction may be lifted in future).
Combining rows and columns has the following behaviour:
* The column names and dtypes will be used.
* If the column is required to be unique, this will be enforced.
* Any values missing from the generated rows will be provided using the
column's fill.
* Any values in the row not present in the column specification (if
dicts are passed, if there are keys with no corresponding column name,
if sequences are passed if there are too many items) will result in
InvalidArgument being raised.
"""
if index is None:
index = range_indexes()
else:
check_strategy(index, "index")
index_strategy = index
if columns is None:
if rows is None:
raise InvalidArgument("At least one of rows and columns must be provided")
else:
@st.composite
def rows_only(draw):
    """Strategy body for the rows-without-columns case: draw the index first
    (fixing the number of rows), then one row per index entry."""
    index = draw(index_strategy)

    @check_function
    def row():
        # Every drawn row must be iterable so it can become a DataFrame row.
        result = draw(rows)
        check_type(abc.Iterable, result, "draw(row)")
        return result

    if len(index) > 0:
        return pandas.DataFrame([row() for _ in index], index=index)
    else:
        # If we haven't drawn any rows we need to draw one row and
        # then discard it so that we get a consistent shape for the
        # DataFrame.
        base = pandas.DataFrame([row()])
        return base.drop(0)
return rows_only()
assert columns is not None
cols = try_convert(tuple, columns, "columns")
rewritten_columns = []
column_names: Set[str] = set()
for i, c in enumerate(cols):
check_type(column, c, f"columns[{i}]")
c = copy(c)
if c.name is None:
label = f"columns[{i}]"
c.name = i
else:
label = c.name
try:
hash(c.name)
except TypeError:
raise InvalidArgument(
f"Column names must be hashable, but columns[{i}].name was "
f"{c.name!r} of type {type(c.name).__name__}, which cannot be hashed."
) from None
if c.name in column_names:
raise InvalidArgument(f"duplicate definition of column name {c.name!r}")
column_names.add(c.name)
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)
if c.dtype is None and rows is not None:
raise InvalidArgument(
"Must specify a dtype for all columns when combining rows with columns."
)
c.fill = npst.fill_for(
fill=c.fill, elements=c.elements, unique=c.unique, name=label
)
rewritten_columns.append(c)
if rows is None:
@st.composite
def just_draw_columns(draw):
    """Strategy body for the columns-without-rows case."""
    index = draw(index_strategy)
    local_index_strategy = st.just(index)
    # Preserve the user's column order; values are filled in below.
    data = OrderedDict((c.name, None) for c in rewritten_columns)

    # Depending on how the columns are going to be generated we group
    # them differently to get better shrinking. For columns with fill
    # enabled, the elements can be shrunk independently of the size,
    # so we can just shrink by shrinking the index then shrinking the
    # length and are generally much more free to move data around.

    # For columns with no filling the problem is harder, and drawing
    # them like that would result in rows being very far apart from
    # each other in the underlying data stream, which gets in the way
    # of shrinking. So what we do is reorder and draw those columns
    # row wise, so that the values of each row are next to each other.
    # This makes life easier for the shrinker when deleting blocks of
    # data.
    columns_without_fill = [c for c in rewritten_columns if c.fill.is_empty]

    if columns_without_fill:
        for c in columns_without_fill:
            data[c.name] = pandas.Series(
                np.zeros(shape=len(index), dtype=object),
                index=index,
                dtype=c.dtype,
            )
        # Track values already used by each unique column so we can retry.
        seen = {c.name: set() for c in columns_without_fill if c.unique}

        for i in range(len(index)):
            for c in columns_without_fill:
                if c.unique:
                    # Up to 5 attempts to draw a fresh value, else reject
                    # the whole example.
                    for _ in range(5):
                        value = draw(c.elements)
                        if value not in seen[c.name]:
                            seen[c.name].add(value)
                            break
                    else:
                        reject()
                else:
                    value = draw(c.elements)
                try:
                    data[c.name][i] = value
                except ValueError as err:  # pragma: no cover
                    # This just works in Pandas 1.4 and later, but gives
                    # a confusing error on previous versions.
                    if c.dtype is None and not isinstance(
                        value, (float, int, str, bool, datetime, timedelta)
                    ):
                        raise ValueError(
                            f"Failed to add {value=} to column "
                            f"{c.name} with dtype=None. Maybe passing "
                            "dtype=object would help?"
                        ) from err
                    # Unclear how this could happen, but users find a way...
                    raise

    # Columns with a fill strategy can be drawn whole-sale via series().
    for c in rewritten_columns:
        if not c.fill.is_empty:
            data[c.name] = draw(
                series(
                    index=local_index_strategy,
                    dtype=c.dtype,
                    elements=c.elements,
                    fill=c.fill,
                    unique=c.unique,
                )
            )

    return pandas.DataFrame(data, index=index)
return just_draw_columns()
else:
@st.composite
def assign_rows(draw):
    """Strategy body for the rows-combined-with-columns case: build an
    all-zeros frame of the right dtypes, then overwrite it row by row,
    padding short rows from each column's fill strategy and enforcing
    column uniqueness with retry-then-reject."""
    index = draw(index_strategy)

    result = pandas.DataFrame(
        OrderedDict(
            (
                c.name,
                pandas.Series(
                    np.zeros(dtype=c.dtype, shape=len(index)), dtype=c.dtype
                ),
            )
            for c in rewritten_columns
        ),
        index=index,
    )

    # Cache one drawn fill value per column position, reused for every row
    # that is missing that column.
    fills = {}

    any_unique = any(c.unique for c in rewritten_columns)
    if any_unique:
        # all_seen[i] is the set of values used so far in unique column i,
        # or None for non-unique columns; trailing Nones are trimmed.
        all_seen = [set() if c.unique else None for c in rewritten_columns]
        while all_seen[-1] is None:
            all_seen.pop()

    for row_index in range(len(index)):
        # Up to 5 attempts to draw a row compatible with the uniqueness
        # constraints, else reject the whole example.
        for _ in range(5):
            original_row = draw(rows)
            row = original_row
            if isinstance(row, dict):
                as_list = [None] * len(rewritten_columns)
                for i, c in enumerate(rewritten_columns):
                    try:
                        as_list[i] = row[c.name]
                    except KeyError:
                        try:
                            as_list[i] = fills[i]
                        except KeyError:
                            if c.fill.is_empty:
                                raise InvalidArgument(
                                    f"Empty fill strategy in {c!r} cannot "
                                    f"complete row {original_row!r}"
                                ) from None
                            fills[i] = draw(c.fill)
                            as_list[i] = fills[i]
                for k in row:
                    if k not in column_names:
                        # BUGFIX: message previously ended with a stray ")".
                        raise InvalidArgument(
                            "Row %r contains column %r not in columns %r"
                            % (row, k, [c.name for c in rewritten_columns])
                        )
                row = as_list
            if any_unique:
                has_duplicate = False
                for seen, value in zip(all_seen, row):
                    if seen is None:
                        continue
                    if value in seen:
                        has_duplicate = True
                        break
                    seen.add(value)
                if has_duplicate:
                    continue
            row = list(try_convert(tuple, row, "draw(rows)"))
            if len(row) > len(rewritten_columns):
                raise InvalidArgument(
                    f"Row {original_row!r} contains too many entries. Has "
                    f"{len(row)} but expected at most {len(rewritten_columns)}"
                )
            # Pad short rows from the remaining columns' fill strategies.
            while len(row) < len(rewritten_columns):
                c = rewritten_columns[len(row)]
                if c.fill.is_empty:
                    raise InvalidArgument(
                        f"Empty fill strategy in {c!r} cannot "
                        f"complete row {original_row!r}"
                    )
                row.append(draw(c.fill))
            result.iloc[row_index] = row
            break
        else:
            reject()
    return result
return assign_rows()

View File

@@ -0,0 +1,19 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
Stub for users who manually load our pytest plugin.
The plugin implementation is now located in a top-level module outside the main
hypothesis tree, so that Pytest can load the plugin without thereby triggering
the import of Hypothesis itself (and thus loading our own plugins).
"""
from _hypothesis_pytestplugin import * # noqa

View File

@@ -0,0 +1,54 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
----------------
hypothesis[pytz]
----------------
This module provides :pypi:`pytz` timezones.
You can use this strategy to make
:py:func:`hypothesis.strategies.datetimes` and
:py:func:`hypothesis.strategies.times` produce timezone-aware values.
"""
import datetime as dt
import pytz
from pytz.tzfile import StaticTzInfo # type: ignore # considered private by typeshed
from hypothesis import strategies as st
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["timezones"]
@cacheable
@defines_strategy()
def timezones() -> st.SearchStrategy[dt.tzinfo]:
    """Any timezone in the Olson database, as a pytz tzinfo object.

    This strategy minimises to UTC, or the smallest possible fixed
    offset, and is designed for use with
    :py:func:`hypothesis.strategies.datetimes`.
    """
    everything = [pytz.timezone(name) for name in pytz.all_timezones]
    # UTC shrinks first.  After it come the zones that have always had one
    # fixed UTC offset, ordered by the size of that offset: the smaller the
    # absolute offset, the simpler the zone.
    reference = dt.datetime(2000, 1, 1)
    fixed_offset = [tz for tz in everything if isinstance(tz, StaticTzInfo)]
    fixed_offset.sort(key=lambda tz: abs(tz.utcoffset(reference)))
    static: list = [pytz.UTC, *fixed_offset]
    # Zones whose UTC offset has varied over time shrink last, in name order.
    dynamic = [tz for tz in everything if tz not in static]
    return st.sampled_from(static + dynamic)

View File

@@ -0,0 +1,78 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from contextlib import contextmanager
from datetime import timedelta
from typing import Iterable
from redis import Redis
from hypothesis.database import ExampleDatabase
from hypothesis.internal.validation import check_type
class RedisExampleDatabase(ExampleDatabase):
    """Store Hypothesis examples as sets in the given :class:`~redis.Redis` datastore.

    This is particularly useful for shared databases, as per the recipe
    for a :class:`~hypothesis.database.MultiplexedDatabase`.

    .. note::

        If a test has not been run for ``expire_after``, those examples will be allowed
        to expire. The default time-to-live persists examples between weekly runs.
    """

    def __init__(
        self,
        redis: Redis,
        *,
        expire_after: timedelta = timedelta(days=8),
        key_prefix: bytes = b"hypothesis-example:",
    ):
        # Validate eagerly so misconfiguration fails at construction time,
        # not at first use.
        check_type(Redis, redis, "redis")
        check_type(timedelta, expire_after, "expire_after")
        check_type(bytes, key_prefix, "key_prefix")
        self.redis = redis
        self._expire_after = expire_after
        self._prefix = key_prefix

    def __repr__(self) -> str:
        return (
            f"RedisExampleDatabase({self.redis!r}, expire_after={self._expire_after!r})"
        )

    @contextmanager
    def _pipeline(self, *reset_expire_keys, transaction=False, auto_execute=True):
        # Context manager to batch updates and expiry reset, reducing TCP roundtrips
        pipe = self.redis.pipeline(transaction=transaction)
        yield pipe
        # After the caller has queued its commands, refresh the TTL on each
        # touched key so actively-used examples do not expire.
        for key in reset_expire_keys:
            pipe.expire(self._prefix + key, self._expire_after)
        if auto_execute:
            pipe.execute()

    def fetch(self, key: bytes) -> Iterable[bytes]:
        # Generator: fetches all stored examples for ``key`` via one SMEMBERS.
        with self._pipeline(key, auto_execute=False) as pipe:
            pipe.smembers(self._prefix + key)
            # NOTE(review): execute() runs here, *before* _pipeline queues the
            # expire command on exit; with auto_execute=False that expire is
            # never sent, so fetching does not refresh the TTL -- confirm
            # whether that is intentional.
            yield from pipe.execute()[0]

    def save(self, key: bytes, value: bytes) -> None:
        # SADD is idempotent, so duplicate saves are harmless.
        with self._pipeline(key) as pipe:
            pipe.sadd(self._prefix + key, value)

    def delete(self, key: bytes, value: bytes) -> None:
        with self._pipeline(key) as pipe:
            pipe.srem(self._prefix + key, value)

    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        # Remove from src and add to dest in a single pipeline roundtrip.
        with self._pipeline(src, dest) as pipe:
            pipe.srem(self._prefix + src, value)
            pipe.sadd(self._prefix + dest, value)

View File

@@ -0,0 +1,9 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

View File

@@ -0,0 +1,277 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import attr
@attr.s(slots=True)
class Entry:
    """A single cache slot: a key/value pair plus its eviction score and the
    number of times it has been pinned against eviction."""

    key = attr.ib()
    value = attr.ib()
    score = attr.ib()
    pins = attr.ib(default=0)

    @property
    def sort_key(self):
        # Heap ordering: unpinned entries compare by their score, while every
        # pinned entry sorts after all unpinned ones (their relative order is
        # deliberately unspecified).
        return (0, self.score) if self.pins == 0 else (1,)
class GenericCache:
    """Generic supertype for cache implementations.

    Defines a dict-like mapping with a maximum size, where as well as mapping
    to a value, each key also maps to a score. When a write would cause the
    dict to exceed its maximum size, it first evicts the existing key with
    the smallest score, then adds the new key to the map.

    A key has the following lifecycle:

    1. key is written for the first time, the key is given the score
       self.new_entry(key, value)
    2. whenever an existing key is read or written, self.on_access(key, value,
       score) is called. This returns a new score for the key.
    3. When a key is evicted, self.on_evict(key, value, score) is called.

    The cache will be in a valid state in all of these cases.

    Implementations are expected to implement new_entry and optionally
    on_access and on_evict to implement a specific scoring strategy.
    """

    __slots__ = ("keys_to_indices", "data", "max_size", "__pinned_entry_count")

    def __init__(self, max_size):
        self.max_size = max_size

        # Implementation: We store a binary heap of Entry objects in self.data,
        # with the heap property requiring that a parent's score is <= that of
        # its children. keys_to_index then maps keys to their index in the
        # heap. We keep these two in sync automatically - the heap is never
        # reordered without updating the index.
        self.keys_to_indices = {}
        self.data = []
        self.__pinned_entry_count = 0

    def __len__(self):
        assert len(self.keys_to_indices) == len(self.data)
        return len(self.data)

    def __contains__(self, key):
        return key in self.keys_to_indices

    def __getitem__(self, key):
        i = self.keys_to_indices[key]
        result = self.data[i]
        # BUGFIX: the contract above (lifecycle step 2) says on_access returns
        # the key's new score, and __setitem__ stores that return value -- but
        # here it was previously discarded, so scorers returning a fresh score
        # object (rather than mutating `score` in place) never saw reads
        # recorded.  Store it before rebalancing the heap around this entry.
        result.score = self.on_access(result.key, result.value, result.score)
        self.__balance(i)
        return result.value

    def __setitem__(self, key, value):
        if self.max_size == 0:
            return
        evicted = None
        try:
            i = self.keys_to_indices[key]
        except KeyError:
            # New key: make room (evicting the heap root if full), or append.
            if self.max_size == self.__pinned_entry_count:
                raise ValueError(
                    "Cannot increase size of cache where all keys have been pinned."
                ) from None
            entry = Entry(key, value, self.new_entry(key, value))
            if len(self.data) >= self.max_size:
                evicted = self.data[0]
                assert evicted.pins == 0
                del self.keys_to_indices[evicted.key]
                i = 0
                self.data[0] = entry
            else:
                i = len(self.data)
                self.data.append(entry)
            self.keys_to_indices[key] = i
        else:
            # Existing key: update value and refresh its score.
            entry = self.data[i]
            assert entry.key == key
            entry.value = value
            entry.score = self.on_access(entry.key, entry.value, entry.score)
        self.__balance(i)
        if evicted is not None:
            if self.data[0] is not entry:
                assert evicted.score <= self.data[0].score
            # Notify only after the heap is consistent again.
            self.on_evict(evicted.key, evicted.value, evicted.score)

    def __iter__(self):
        return iter(self.keys_to_indices)

    def pin(self, key):
        """Mark ``key`` as pinned. That is, it may not be evicted until
        ``unpin(key)`` has been called. The same key may be pinned multiple
        times and will not be unpinned until the same number of calls to
        unpin have been made."""
        i = self.keys_to_indices[key]
        entry = self.data[i]
        entry.pins += 1
        if entry.pins == 1:
            # First pin changes the entry's sort_key, so rebalance.
            self.__pinned_entry_count += 1
            assert self.__pinned_entry_count <= self.max_size
            self.__balance(i)

    def unpin(self, key):
        """Undo one previous call to ``pin(key)``. Once all calls are
        undone this key may be evicted as normal."""
        i = self.keys_to_indices[key]
        entry = self.data[i]
        if entry.pins == 0:
            raise ValueError(f"Key {key!r} has not been pinned")
        entry.pins -= 1
        if entry.pins == 0:
            # Last unpin changes the entry's sort_key, so rebalance.
            self.__pinned_entry_count -= 1
            self.__balance(i)

    def is_pinned(self, key):
        """Returns True if the key is currently pinned."""
        i = self.keys_to_indices[key]
        return self.data[i].pins > 0

    def clear(self):
        """Remove all keys, clearing their pinned status."""
        del self.data[:]
        self.keys_to_indices.clear()
        self.__pinned_entry_count = 0

    def __repr__(self):
        return "{" + ", ".join(f"{e.key!r}: {e.value!r}" for e in self.data) + "}"

    def new_entry(self, key, value):
        """Called when a key is written that does not currently appear in the
        map.

        Returns the score to associate with the key.
        """
        raise NotImplementedError

    def on_access(self, key, value, score):
        """Called every time a key that is already in the map is read or
        written.

        Returns the new score for the key.
        """
        return score

    def on_evict(self, key, value, score):
        """Called after a key has been evicted, with the score it had had at
        the point of eviction."""

    def check_valid(self):
        """Debugging method for use in tests.

        Asserts that all of the cache's invariants hold. When everything
        is working correctly this should be an expensive no-op.
        """
        for i, e in enumerate(self.data):
            assert self.keys_to_indices[e.key] == i
            for j in [i * 2 + 1, i * 2 + 2]:
                if j < len(self.data):
                    assert e.score <= self.data[j].score, self.data

    def __swap(self, i, j):
        assert i < j
        assert self.data[j].sort_key < self.data[i].sort_key
        self.data[i], self.data[j] = self.data[j], self.data[i]
        # Keep the key->index map in sync with the heap reorder.
        self.keys_to_indices[self.data[i].key] = i
        self.keys_to_indices[self.data[j].key] = j

    def __balance(self, i):
        """When we have made a modification to the heap such that means that
        the heap property has been violated locally around i but previously
        held for all other indexes (and no other values have been modified),
        this fixes the heap so that the heap property holds everywhere."""
        # Sift up while the entry sorts before its parent.
        while i > 0:
            parent = (i - 1) // 2
            if self.__out_of_order(parent, i):
                self.__swap(parent, i)
                i = parent
            else:
                # NOTE(review): the original comment claimed this branch is
                # never taken on insertion-ordered dicts (pypy or cpython
                # >= 3.7); that looks stale -- confirm before relying on it.
                break  # pragma: no cover
        # Sift down while the entry sorts after a child, preferring the
        # lower-scored child so the heap property is restored in one pass.
        while True:
            children = [j for j in (2 * i + 1, 2 * i + 2) if j < len(self.data)]
            if len(children) == 2:
                children.sort(key=lambda j: self.data[j].score)
            for j in children:
                if self.__out_of_order(i, j):
                    self.__swap(i, j)
                    i = j
                    break
            else:
                break

    def __out_of_order(self, i, j):
        """Returns True if the indices i, j are in the wrong order.

        i must be the parent of j.
        """
        assert i == (j - 1) // 2
        return self.data[j].sort_key < self.data[i].sort_key
class LRUReusedCache(GenericCache):
    """The only concrete implementation of GenericCache we use outside of tests
    currently.

    Eviction policy: evict the least-recently-used key, except that keys which
    have only ever been touched once are always preferred for eviction over
    keys that have been reused (how *many* times a key was reused is ignored).
    This keeps most of the benefit of plain LRU while adding scan-resistance:
    streaming through a large set of one-off keys cannot flush the
    frequently-reused entries.
    """

    __slots__ = ("__tick",)

    def __init__(self, max_size):
        super().__init__(max_size)
        self.__tick = 0

    def tick(self):
        # Monotonic counter used as a recency stamp.
        self.__tick += 1
        return self.__tick

    def new_entry(self, key, value):
        # Scores are [tier, recency]: fresh keys start in the "used once"
        # tier, so they are evicted before anything that has been reused.
        return [1, self.tick()]

    def on_access(self, key, value, score):
        # Any reuse promotes the key to tier 2 and refreshes its recency;
        # the score list is mutated in place and handed back.
        score[0], score[1] = 2, self.tick()
        return score

    def pin(self, key):
        try:
            super().pin(key)
        except KeyError:
            # The whole point of an LRU cache is that it might drop things for you
            assert key not in self.keys_to_indices

    def unpin(self, key):
        try:
            super().unpin(key)
        except KeyError:
            assert key not in self.keys_to_indices

View File

@@ -0,0 +1,62 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from math import fabs, inf, isinf, isnan, nan, sqrt
from sys import float_info
def cathetus(h, a):
    """Given the lengths of the hypotenuse and one side of a right triangle,
    return the length of the other side.

    A companion to the C99 hypot() function. Naively computing
    sqrt(h*h - a*a) would underflow for tiny arguments and overflow for
    huge ones, so the subtraction is factored instead. Non-finite
    arguments (NaNs and infinities) are handled as consistently as
    possible with the C99 hypot() specification.

    This function relies on the system ``sqrt`` and so, like it, may be
    inaccurate up to a relative error of (around) floating-point epsilon.

    Based on the C99 implementation https://github.com/jjgreen/cathetus
    """
    if isnan(h):
        return nan
    if isinf(h):
        # C99 mandates hypot(inf, nan) == inf, so an infinite hypotenuse
        # dominates everything except another infinity (which is nan).
        return nan if isinf(a) else inf

    h, a = fabs(h), fabs(a)
    if h < a:
        # No real right triangle has a side longer than its hypotenuse.
        return nan

    # Pick a factoring of h*h - a*a that stays in range:
    # - near-overflow inputs: halve inside the sqrt and multiply by sqrt(2)
    # - merely-large or tiny inputs: split into sqrt(h-a) * sqrt(h+a)
    # - ordinary inputs: the direct product is fine
    if h > sqrt(float_info.max):
        if h > float_info.max / 2:
            b = sqrt(h - a) * sqrt(h / 2 + a / 2) * sqrt(2)
        else:
            b = sqrt(h - a) * sqrt(h + a)
    elif h < sqrt(float_info.min):
        b = sqrt(h - a) * sqrt(h + a)
    else:
        b = sqrt((h - a) * (h + a))

    # Rounding across several operations can rarely yield a side longer than
    # the hypotenuse, which is clearly wrong: clip to h as the best estimate.
    return min(b, h)

View File

@@ -0,0 +1,282 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import codecs
import gzip
import json
import os
import sys
import tempfile
import unicodedata
from functools import lru_cache
from typing import Dict, Tuple
from hypothesis.configuration import storage_directory
from hypothesis.errors import InvalidArgument
from hypothesis.internal.intervalsets import IntervalSet
intervals = Tuple[Tuple[int, int], ...]
cache_type = Dict[Tuple[Tuple[str, ...], int, int, intervals], IntervalSet]
def charmap_file(fname="charmap"):
    """Return the path of the gzip-compressed JSON cache file for ``fname``,
    namespaced by this interpreter's Unicode version so caches are never
    shared across different Unicode databases."""
    compressed_name = f"{fname}.json.gz"
    return storage_directory("unicode_data", unicodedata.unidata_version, compressed_name)
_charmap = None
def charmap():
    """Return a dict that maps a Unicode category, to a tuple of 2-tuples
    covering the codepoint intervals for characters in that category.

    >>> charmap()['Co']
    ((57344, 63743), (983040, 1048573), (1048576, 1114109))
    """
    global _charmap
    # Best-effort caching in the face of missing files and/or unwritable
    # filesystems is fairly simple: check if loaded, else try loading,
    # else calculate and try writing the cache.
    if _charmap is None:
        f = charmap_file()
        try:
            with gzip.GzipFile(f, "rb") as i:
                tmp_charmap = dict(json.load(i))

        except Exception:
            # This loop is reduced to using only local variables for performance;
            # indexing and updating containers is a ~3x slowdown. This doesn't fix
            # https://github.com/HypothesisWorks/hypothesis/issues/2108 but it helps.
            category = unicodedata.category  # Local variable -> ~20% speedup!
            tmp_charmap = {}
            last_cat = category(chr(0))
            last_start = 0
            # Scan every codepoint once, emitting a [start, end] run each time
            # the category changes.  (`i` here shadows the failed GzipFile
            # binding above; it is a plain int in this branch.)
            for i in range(1, sys.maxunicode + 1):
                cat = category(chr(i))
                if cat != last_cat:
                    tmp_charmap.setdefault(last_cat, []).append([last_start, i - 1])
                    last_cat, last_start = cat, i
            tmp_charmap.setdefault(last_cat, []).append([last_start, sys.maxunicode])

            try:
                # Write the Unicode table atomically
                tmpdir = storage_directory("tmp")
                tmpdir.mkdir(exist_ok=True, parents=True)
                fd, tmpfile = tempfile.mkstemp(dir=tmpdir)
                os.close(fd)
                # Explicitly set the mtime to get reproducible output
                with gzip.GzipFile(tmpfile, "wb", mtime=1) as o:
                    result = json.dumps(sorted(tmp_charmap.items()))
                    o.write(result.encode())

                os.renames(tmpfile, f)
            except Exception:
                # Cache write failures are non-fatal: recompute next time.
                pass

        # convert between lists and tuples
        _charmap = {
            k: tuple(tuple(pair) for pair in pairs) for k, pairs in tmp_charmap.items()
        }
        # each value is a tuple of 2-tuples (that is, tuples of length 2)
        # and that both elements of that tuple are integers.
        for vs in _charmap.values():
            ints = list(sum(vs, ()))
            assert all(isinstance(x, int) for x in ints)
            assert ints == sorted(ints)
            assert all(len(tup) == 2 for tup in vs)

    assert _charmap is not None
    return _charmap
@lru_cache(maxsize=None)
def intervals_from_codec(codec_name: str) -> IntervalSet:  # pragma: no cover
    """Return an IntervalSet of characters which are part of this codec."""
    # Callers are expected to pass the canonical codec name.
    assert codec_name == codecs.lookup(codec_name).name
    fname = charmap_file(f"codec-{codec_name}")
    try:
        with gzip.GzipFile(fname) as gzf:
            encodable_intervals = json.load(gzf)

    except Exception:
        # This loop is kinda slow, but hopefully we don't need to do it very often!
        encodable_intervals = []
        for i in range(sys.maxunicode + 1):
            try:
                chr(i).encode(codec_name)
            except Exception:  # usually _but not always_ UnicodeEncodeError
                pass
            else:
                encodable_intervals.append((i, i))

    res = IntervalSet(encodable_intervals)
    # NOTE(review): unioning the set with itself looks like a normalisation
    # step that merges the adjacent single-codepoint intervals built above --
    # confirm against IntervalSet.union.
    res = res.union(res)
    try:
        # Write the Unicode table atomically
        tmpdir = storage_directory("tmp")
        tmpdir.mkdir(exist_ok=True, parents=True)
        fd, tmpfile = tempfile.mkstemp(dir=tmpdir)
        os.close(fd)
        # Explicitly set the mtime to get reproducible output
        with gzip.GzipFile(tmpfile, "wb", mtime=1) as o:
            o.write(json.dumps(res.intervals).encode())
        os.renames(tmpfile, fname)
    except Exception:
        # Cache write failures are non-fatal: recompute next time.
        pass
    return res
_categories = None
def categories():
    """Return a tuple of Unicode categories in a normalised order.

    Categories are ordered by the number of codepoint intervals they contain,
    except that Control ('Cc') and Surrogate ('Cs') always come last.

    >>> categories() # doctest: +ELLIPSIS
    ('Zl', 'Zp', 'Co', 'Me', 'Pc', ..., 'Cc', 'Cs')
    """
    global _categories
    if _categories is None:
        cm = charmap()
        by_interval_count = sorted(cm.keys(), key=lambda c: len(cm[c]))
        # Push Other-Control and Other-Surrogate to the very end, in that
        # order, so they are only generated when explicitly requested.
        special = ["Cc", "Cs"]
        _categories = [c for c in by_interval_count if c not in special] + special
    return tuple(_categories)
def as_general_categories(cats, name="cats"):
    """Return a tuple of Unicode categories in a normalised order.

    One-letter designations of a major class are expanded to include all
    of that class's subclasses:

    >>> as_general_categories(['N'])
    ('Nd', 'Nl', 'No')

    See section 4.5 of the Unicode standard for more on classes:
    https://www.unicode.org/versions/Unicode10.0.0/ch04.pdf

    If ``cats`` contains anything that is neither a major class nor a
    specific category, InvalidArgument is raised.
    """
    if cats is None:
        return None
    major_classes = ("L", "M", "N", "P", "S", "Z", "C")
    known = categories()
    selected = set(cats)
    for c in cats:
        if c in major_classes:
            # Replace the one-letter class with all of its subclasses.
            selected.discard(c)
            selected.update(sub for sub in known if sub.startswith(c))
        elif c not in known:
            raise InvalidArgument(
                f"In {name}={cats!r}, {c!r} is not a valid Unicode category."
            )
    # Emit in the canonical ordering produced by categories().
    return tuple(c for c in known if c in selected)
category_index_cache = {(): ()}
def _category_key(cats):
    """Return a normalised tuple of all Unicode categories in ``cats``.

    If ``cats`` is None, default to including all categories.  The result
    preserves the canonical ordering produced by :func:`categories`, so
    equal selections always map to the same (hashable) cache key.

    >>> _category_key(['Lu', 'Me', 'Cs'])
    ('Me', 'Lu', 'Cs')
    """
    cs = categories()
    if cats is None:
        cats = set(cs)
    return tuple(c for c in cs if c in cats)
def _query_for_key(key):
    """Return a tuple of codepoint intervals covering characters that match one
    or more categories in the tuple of categories `key`.

    >>> _query_for_key(categories())
    ((0, 1114111),)
    >>> _query_for_key(('Zl', 'Zp', 'Co'))
    ((8232, 8233), (57344, 63743), (983040, 1048573), (1048576, 1114109))
    """
    try:
        # Results are memoised per category tuple (including every prefix,
        # thanks to the recursion below).
        return category_index_cache[key]
    except KeyError:
        pass
    assert key
    if set(key) == set(categories()):
        # Requesting every category covers all of Unicode.
        result = IntervalSet([(0, sys.maxunicode)])
    else:
        # Recurse on the prefix and union in the final category's intervals;
        # each prefix result lands in the cache for later reuse.
        result = IntervalSet(_query_for_key(key[:-1])).union(
            IntervalSet(charmap()[key[-1]])
        )
    assert isinstance(result, IntervalSet)
    category_index_cache[key] = result.intervals
    return result.intervals
limited_category_index_cache: cache_type = {}
def query(
    *,
    categories=None,
    min_codepoint=None,
    max_codepoint=None,
    include_characters="",
    exclude_characters="",
):
    """Return a tuple of intervals covering the codepoints for all characters
    that meet the criteria.

    >>> query()
    ((0, 1114111),)
    >>> query(min_codepoint=0, max_codepoint=128)
    ((0, 128),)
    >>> query(min_codepoint=0, max_codepoint=128, categories=['Lu'])
    ((65, 90),)
    >>> query(min_codepoint=0, max_codepoint=128, categories=['Lu'],
    ...       include_characters='\u2603')
    ((65, 90), (9731, 9731))
    """
    if min_codepoint is None:
        min_codepoint = 0
    if max_codepoint is None:
        max_codepoint = sys.maxunicode
    catkey = _category_key(categories)
    character_intervals = IntervalSet.from_string(include_characters or "")
    exclude_intervals = IntervalSet.from_string(exclude_characters or "")
    # Cache key: every argument, normalised into hashable tuples.
    qkey = (
        catkey,
        min_codepoint,
        max_codepoint,
        character_intervals.intervals,
        exclude_intervals.intervals,
    )
    try:
        return limited_category_index_cache[qkey]
    except KeyError:
        pass
    # Clip the category intervals to the requested codepoint window.
    base = _query_for_key(catkey)
    result = []
    for u, v in base:
        if v >= min_codepoint and u <= max_codepoint:
            result.append((max(u, min_codepoint), min(v, max_codepoint)))
    # Force-include requested characters, then strip exclusions (exclusions
    # win over inclusions).
    result = (IntervalSet(result) | character_intervals) - exclude_intervals
    limited_category_index_cache[qkey] = result
    return result

View File

@@ -0,0 +1,235 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import codecs
import copy
import dataclasses
import inspect
import platform
import sys
import typing
from functools import partial
from typing import Any, ForwardRef, get_args
try:
BaseExceptionGroup = BaseExceptionGroup
ExceptionGroup = ExceptionGroup # pragma: no cover
except NameError:
from exceptiongroup import (
BaseExceptionGroup as BaseExceptionGroup,
ExceptionGroup as ExceptionGroup,
)
if typing.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Concatenate as Concatenate, ParamSpec as ParamSpec
else:
try:
from typing import Concatenate as Concatenate, ParamSpec as ParamSpec
except ImportError:
try:
from typing_extensions import (
Concatenate as Concatenate,
ParamSpec as ParamSpec,
)
except ImportError:
Concatenate, ParamSpec = None, None
PYPY = platform.python_implementation() == "PyPy"
GRAALPY = platform.python_implementation() == "GraalVM"
WINDOWS = platform.system() == "Windows"
def add_note(exc, note):
    """Attach ``note`` to the exception ``exc``, portably.

    Uses :meth:`BaseException.add_note` where it exists (Python 3.11+), and
    otherwise emulates PEP 678 by maintaining the ``__notes__`` list by hand.
    """
    native_add_note = getattr(exc, "add_note", None)
    if native_add_note is not None:
        native_add_note(note)
        return
    if not hasattr(exc, "__notes__"):
        exc.__notes__ = []
    exc.__notes__.append(note)
def escape_unicode_characters(s: str) -> str:
    """Return ``s`` with non-ASCII and control characters replaced by
    backslash escape sequences, yielding a pure-ASCII string."""
    return s.encode("unicode_escape").decode("ascii")
def int_from_bytes(data: typing.Union[bytes, bytearray]) -> int:
    """Interpret ``data`` as a big-endian unsigned integer."""
    return int.from_bytes(data, byteorder="big")


def int_to_bytes(i: int, size: int) -> bytes:
    """Serialise the non-negative integer ``i`` as exactly ``size``
    big-endian bytes."""
    return i.to_bytes(size, byteorder="big")


def int_to_byte(i: int) -> bytes:
    """Return the single byte with ordinal ``i`` (0 <= i <= 255)."""
    return bytes((i,))
def is_typed_named_tuple(cls):
    """Return True if cls is probably a subtype of `typing.NamedTuple`.

    Unfortunately types created with `class T(NamedTuple):` actually
    subclass `tuple` directly rather than NamedTuple, so we can only
    check for the characteristic attributes and hope nobody defines a
    different tuple subclass that happens to look the same.
    """
    if not issubclass(cls, tuple):
        return False
    has_fields = hasattr(cls, "_fields")
    has_types = hasattr(cls, "_field_types") or hasattr(cls, "__annotations__")
    return has_fields and has_types
def _hint_and_args(x):
return (x, *get_args(x))
def get_type_hints(thing):
    """Like the typing version, but tries harder and never errors.

    Tries harder: if the thing to inspect is a class but typing.get_type_hints
    raises an error or returns no hints, then this function will try calling it
    on the __init__ method. This second step often helps with user-defined
    classes on older versions of Python. The third step we take is trying
    to fetch types from the __signature__ property.
    They override any other ones we found earlier.

    Never errors: instead of raising TypeError for uninspectable objects, or
    NameError for unresolvable forward references, just return an empty dict.
    """
    if isinstance(thing, partial):
        from hypothesis.internal.reflection import get_signature

        # For functools.partial, report hints for the underlying function
        # minus any parameters the partial has already bound.
        bound = set(get_signature(thing.func).parameters).difference(
            get_signature(thing).parameters
        )
        return {k: v for k, v in get_type_hints(thing.func).items() if k not in bound}

    # include_extras (which keeps typing.Annotated metadata) only exists on
    # Python 3.9+.
    kwargs = {} if sys.version_info[:2] < (3, 9) else {"include_extras": True}

    try:
        hints = typing.get_type_hints(thing, **kwargs)
    except (AttributeError, TypeError, NameError):  # pragma: no cover
        hints = {}

    if inspect.isclass(thing):
        try:
            # Fall through to __init__'s annotations for classes.
            hints.update(typing.get_type_hints(thing.__init__, **kwargs))
        except (TypeError, NameError, AttributeError):
            pass

    try:
        if hasattr(thing, "__signature__"):
            # It is possible for the signature and annotations attributes to
            # differ on an object due to renamed arguments.
            from hypothesis.internal.reflection import get_signature
            from hypothesis.strategies._internal.types import is_a_type

            vkinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
            for p in get_signature(thing).parameters.values():
                if (
                    p.kind not in vkinds
                    and is_a_type(p.annotation)
                    and p.annotation is not p.empty
                ):
                    p_hint = p.annotation
                    # Defer to `get_type_hints` if signature annotation is, or
                    # contains, a forward reference that is otherwise resolved.
                    if any(
                        isinstance(sig_hint, ForwardRef)
                        and not isinstance(hint, ForwardRef)
                        for sig_hint, hint in zip(
                            _hint_and_args(p.annotation),
                            _hint_and_args(hints.get(p.name, Any)),
                        )
                    ):
                        p_hint = hints[p.name]
                    if p.default is None:
                        # A default of None implies the parameter is Optional.
                        hints[p.name] = typing.Optional[p_hint]
                    else:
                        hints[p.name] = p_hint
    except (AttributeError, TypeError, NameError):  # pragma: no cover
        pass

    return hints
# Under Python 2, math.floor and math.ceil returned floats, which cannot
# represent large integers - eg `float(2**53) == float(2**53 + 1)`.
# We therefore implement them entirely in (long) integer operations.
# We still use the same trick on Python 3, because Numpy values and other
# custom __floor__ or __ceil__ methods may convert via floats.
# See issue #1667, Numpy issue 9068.
def floor(x):
    """Integer floor of ``x``, computed without any float round-trip.

    Implemented with pure integer truncation so that values whose
    ``__floor__`` goes via float (e.g. some numpy scalars) cannot lose
    precision for large magnitudes.
    """
    truncated = int(x)
    return truncated - 1 if (x < 0 and truncated != x) else truncated
def ceil(x):
    """Integer ceiling of ``x``, computed without any float round-trip.

    Mirror image of ``floor`` above: integer truncation plus an adjustment
    for positive non-integers, avoiding float conversion entirely.
    """
    truncated = int(x)
    return truncated + 1 if (x > 0 and truncated != x) else truncated
def bad_django_TestCase(runner):
    # Returns True when ``runner`` is a Django TransactionTestCase that does
    # NOT also mix in Hypothesis' HypothesisTestCase; such runners roll back
    # database state between examples and interact badly with Hypothesis.
    if runner is None or "django.test" not in sys.modules:
        # Django isn't even imported, so this can't be a Django test case.
        return False
    else: # pragma: no cover
        if not isinstance(runner, sys.modules["django.test"].TransactionTestCase):
            return False
        from hypothesis.extra.django._impl import HypothesisTestCase
        return not isinstance(runner, HypothesisTestCase)
# see issue #3812
if sys.version_info[:2] < (3, 12):
    def dataclass_asdict(obj, *, dict_factory=dict):
        """
        A vendored variant of dataclasses.asdict. Includes the bugfix for
        defaultdicts (cpython/32056) for all versions. See also issues/3812.
        This should be removed whenever we drop support for 3.11. We can use the
        standard dataclasses.asdict after that point.
        """
        if not dataclasses._is_dataclass_instance(obj): # pragma: no cover
            raise TypeError("asdict() should be called on dataclass instances")
        return _asdict_inner(obj, dict_factory)
else: # pragma: no cover
    # Python 3.12+ ships the fixed implementation, so use it directly.
    dataclass_asdict = dataclasses.asdict
def _asdict_inner(obj, dict_factory):
if dataclasses._is_dataclass_instance(obj):
return dict_factory(
(f.name, _asdict_inner(getattr(obj, f.name), dict_factory))
for f in dataclasses.fields(obj)
)
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
elif isinstance(obj, (list, tuple)):
return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
elif isinstance(obj, dict):
if hasattr(type(obj), "default_factory"):
result = type(obj)(obj.default_factory)
for k, v in obj.items():
result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
return result
return type(obj)(
(_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
for k, v in obj.items()
)
else:
return copy.deepcopy(obj)

View File

@@ -0,0 +1,9 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

View File

@@ -0,0 +1,160 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from collections import defaultdict
from random import Random
from typing import Callable, Dict, Iterable, List, Optional, Sequence
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy, pop_random
def prefix_selection_order(
    prefix: Sequence[int],
) -> Callable[[int, int], Iterable[int]]:
    """Build a selection order that follows ``prefix`` while it lasts.

    For depths covered by ``prefix`` we start from the prefix's choice
    (clamped into range), sweep leftwards down to zero, then wrap around
    and take the remaining choices from the right. Beyond the prefix we
    simply enumerate choices from largest to smallest.
    """

    def selection_order(depth: int, n: int) -> Iterable[int]:
        if depth >= len(prefix):
            yield from range(n - 1, -1, -1)
            return
        start = min(prefix[depth], n - 1)
        yield from range(start, -1, -1)
        yield from range(n - 1, start, -1)

    return selection_order
def random_selection_order(random: Random) -> Callable[[int, int], Iterable[int]]:
    """Select choices uniformly at random."""
    def selection_order(depth: int, n: int) -> Iterable[int]:
        # Lazily copy range(n) and repeatedly pop a random element, yielding
        # a uniform random permutation without materialising it up front.
        pending = LazySequenceCopy(range(n))
        while pending:
            yield pop_random(random, pending)
    return selection_order
class Chooser:
    """A source of nondeterminism for use in shrink passes."""
    def __init__(
        self,
        tree: "ChoiceTree",
        selection_order: Callable[[int, int], Iterable[int]],
    ):
        # Strategy deciding the order in which options are tried at each depth.
        self.__selection_order = selection_order
        # Path of TreeNodes from the root to the current position.
        self.__node_trail = [tree.root]
        # Indices chosen so far, parallel to __node_trail minus the root.
        self.__choices: "List[int]" = []
        self.__finished = False
    def choose(
        self,
        values: Sequence[int],
        condition: Callable[[int], bool] = lambda x: True,
    ) -> int:
        """Return some element of values satisfying the condition
        that will not lead to an exhausted branch, or raise DeadBranch
        if no such element exists.
        """
        assert not self.__finished
        node = self.__node_trail[-1]
        if node.live_child_count is None:
            # First visit to this node: record its arity.
            node.live_child_count = len(values)
            node.n = len(values)
        assert node.live_child_count > 0 or len(values) == 0
        for i in self.__selection_order(len(self.__choices), len(values)):
            if node.live_child_count == 0:
                break
            if not node.children[i].exhausted:
                v = values[i]
                if condition(v):
                    self.__choices.append(i)
                    self.__node_trail.append(node.children[i])
                    return v
                else:
                    # A value rejected by the condition is treated as dead so
                    # that we never offer it from this node again.
                    node.children[i] = DeadNode
                    node.live_child_count -= 1
        assert node.live_child_count == 0
        raise DeadBranch
    def finish(self) -> Sequence[int]:
        """Record the decisions made in the underlying tree and return
        a prefix that can be used for the next Chooser to be used."""
        self.__finished = True
        assert len(self.__node_trail) == len(self.__choices) + 1
        result = tuple(self.__choices)
        # Mark the terminal node exhausted, then unwind: each newly-exhausted
        # node kills the corresponding edge in its parent.
        self.__node_trail[-1].live_child_count = 0
        while len(self.__node_trail) > 1 and self.__node_trail[-1].exhausted:
            self.__node_trail.pop()
            assert len(self.__node_trail) == len(self.__choices)
            i = self.__choices.pop()
            target = self.__node_trail[-1]
            target.children[i] = DeadNode
            assert target.live_child_count is not None
            target.live_child_count -= 1
        return result
class ChoiceTree:
    """Records sequences of choices made during shrinking so that we
    can track what parts of a pass has run. Used to create Chooser
    objects that are the main interface that a pass uses to make
    decisions about what to do.
    """
    def __init__(self) -> None:
        self.root = TreeNode()
    @property
    def exhausted(self) -> bool:
        # True once every sequence of choices below the root has been tried.
        return self.root.exhausted
    def step(
        self,
        selection_order: Callable[[int, int], Iterable[int]],
        f: Callable[[Chooser], None],
    ) -> Sequence[int]:
        """Run ``f`` once with a fresh Chooser and record the choices it made.

        Returns the prefix of choices, suitable for seeding the selection
        order of a subsequent step.
        """
        assert not self.exhausted
        chooser = Chooser(self, selection_order)
        try:
            f(chooser)
        except DeadBranch:
            # Every continuation from some point was exhausted; the partial
            # trail is still recorded below by chooser.finish().
            pass
        return chooser.finish()
class TreeNode:
    """A node of a ChoiceTree.

    ``children`` lazily materialises child nodes on first access;
    ``live_child_count`` counts children not yet exhausted (``None`` until
    the node's arity is first observed); ``n`` records that arity.
    """

    children: Dict[int, "TreeNode"]
    live_child_count: Optional[int]
    n: Optional[int]

    def __init__(self) -> None:
        self.children = defaultdict(TreeNode)
        self.live_child_count = None
        self.n = None

    @property
    def exhausted(self) -> bool:
        # A node whose arity is still unknown (live_child_count is None)
        # is never considered exhausted.
        return self.live_child_count == 0
# Shared sentinel for exhausted branches: a single TreeNode permanently
# marked exhausted (live_child_count == 0), so dead subtrees need no storage.
DeadNode = TreeNode()
DeadNode.live_child_count = 0
class DeadBranch(Exception):
    # Raised by Chooser.choose when every remaining option is exhausted.
    pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,427 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import attr
from hypothesis.errors import Flaky, HypothesisException, StopTest
from hypothesis.internal.compat import int_to_bytes
from hypothesis.internal.conjecture.data import (
ConjectureData,
DataObserver,
Status,
bits_to_bytes,
)
from hypothesis.internal.conjecture.junkdrawer import IntList
class PreviouslyUnseenBehaviour(HypothesisException):
    # Raised when simulating a test function reaches a part of the tree
    # that no previously recorded execution has explored.
    pass
def inconsistent_generation():
    """Raise Flaky for test functions whose data generation is unstable."""
    raise Flaky(
        "Inconsistent data generation! Data generation behaved differently "
        "between different runs. Is your data generation depending on external "
        "state?"
    )
# Shared immutable empty set, returned as the default ``forced`` value.
EMPTY: frozenset = frozenset()
@attr.s(slots=True)
class Killed:
    """Represents a transition to part of the tree which has been marked as
    "killed", meaning we want to treat it as not worth exploring, so it will
    be treated as if it were completely explored for the purposes of
    exhaustion."""
    # The node that execution continues into after the kill marker.
    next_node = attr.ib()
@attr.s(slots=True)
class Branch:
    """Represents a transition where multiple choices can be made as to what
    to be drawn."""
    bit_length = attr.ib()  # number of bits drawn at this branch point
    children = attr.ib(repr=False)  # mapping of drawn value -> child TreeNode
    @property
    def max_children(self):
        # A draw of ``bit_length`` bits has exactly 2**bit_length outcomes.
        return 1 << self.bit_length
@attr.s(slots=True, frozen=True)
class Conclusion:
    """Represents a transition to a finished state."""
    status = attr.ib()  # the Status the recorded test run finished with
    interesting_origin = attr.ib()  # origin associated with the conclusion
@attr.s(slots=True)
class TreeNode:
    """Node in a tree that corresponds to previous interactions with
    a ``ConjectureData`` object according to some fixed test function.
    This is functionally a variant patricia trie.
    See https://en.wikipedia.org/wiki/Radix_tree for the general idea,
    but what this means in particular here is that we have a very deep
    but very lightly branching tree and rather than store this as a fully
    recursive structure we flatten prefixes and long branches into
    lists. This significantly compacts the storage requirements.
    A single ``TreeNode`` corresponds to a previously seen sequence
    of calls to ``ConjectureData`` which we have never seen branch,
    followed by a ``transition`` which describes what happens next.
    """
    # Records the previous sequence of calls to ``data.draw_bits``,
    # with the ``n_bits`` argument going in ``bit_lengths`` and the
    # values seen in ``values``. These should always have the same
    # length.
    bit_lengths = attr.ib(factory=IntList)
    values = attr.ib(factory=IntList)
    # The indices of the calls to ``draw_bits`` that we have stored
    # where ``forced`` is not None. Stored as None if no indices
    # have been forced, purely for space saving reasons (we force
    # quite rarely).
    __forced = attr.ib(default=None, init=False)
    # What happens next after observing this sequence of calls.
    # Either:
    #
    # * ``None``, indicating we don't know yet.
    # * A ``Branch`` object indicating that there is a ``draw_bits``
    # call that we have seen take multiple outcomes there.
    # * A ``Conclusion`` object indicating that ``conclude_test``
    # was called here.
    transition = attr.ib(default=None)
    # A tree node is exhausted if every possible sequence of
    # draws below it has been explored. We store this information
    # on a field and update it when performing operations that
    # could change the answer.
    #
    # A node may start exhausted, e.g. because it leads
    # immediately to a conclusion, but can only go from
    # non-exhausted to exhausted when one of its children
    # becomes exhausted or it is marked as a conclusion.
    #
    # Therefore we only need to check whether we need to update
    # this field when the node is first created in ``split_at``
    # or when we have walked a path through this node to a
    # conclusion in ``TreeRecordingObserver``.
    is_exhausted = attr.ib(default=False, init=False)
    @property
    def forced(self):
        # Expose the lazily-created forced-index set, defaulting to the
        # shared empty frozenset to avoid allocation in the common case.
        if not self.__forced:
            return EMPTY
        return self.__forced
    def mark_forced(self, i):
        """Note that the value at index ``i`` was forced."""
        assert 0 <= i < len(self.values)
        if self.__forced is None:
            self.__forced = set()
        self.__forced.add(i)
    def split_at(self, i):
        """Splits the tree so that it can incorporate
        a decision at the ``draw_bits`` call corresponding
        to position ``i``, or raises ``Flaky`` if that was
        meant to be a forced node."""
        if i in self.forced:
            inconsistent_generation()
        assert not self.is_exhausted
        key = self.values[i]
        # Everything after position i moves into a new child node.
        child = TreeNode(
            bit_lengths=self.bit_lengths[i + 1 :],
            values=self.values[i + 1 :],
            transition=self.transition,
        )
        self.transition = Branch(bit_length=self.bit_lengths[i], children={key: child})
        if self.__forced is not None:
            # Re-index forced positions relative to the split point.
            child.__forced = {j - i - 1 for j in self.__forced if j > i}
            self.__forced = {j for j in self.__forced if j < i}
        child.check_exhausted()
        del self.values[i:]
        del self.bit_lengths[i:]
        assert len(self.values) == len(self.bit_lengths) == i
    def check_exhausted(self):
        """Recalculates ``self.is_exhausted`` if necessary then returns
        it."""
        if (
            not self.is_exhausted
            and len(self.forced) == len(self.values)
            and self.transition is not None
        ):
            if isinstance(self.transition, (Conclusion, Killed)):
                self.is_exhausted = True
            elif len(self.transition.children) == self.transition.max_children:
                # A full branch is exhausted only once all children are.
                self.is_exhausted = all(
                    v.is_exhausted for v in self.transition.children.values()
                )
        return self.is_exhausted
class DataTree:
    """Tracks the tree structure of a collection of ConjectureData
    objects, for use in ConjectureRunner."""
    def __init__(self):
        self.root = TreeNode()
    @property
    def is_exhausted(self):
        """Returns True if every possible node is dead and thus the language
        described must have been fully explored."""
        return self.root.is_exhausted
    def generate_novel_prefix(self, random):
        """Generate a short random string that (after rewriting) is not
        a prefix of any buffer previously added to the tree.
        The resulting prefix is essentially arbitrary - it would be nice
        for it to be uniform at random, but previous attempts to do that
        have proven too expensive.
        """
        assert not self.is_exhausted
        novel_prefix = bytearray()
        def append_int(n_bits, value):
            # Encode ``value`` into the prefix as a big-endian block of
            # the smallest whole number of bytes covering ``n_bits``.
            novel_prefix.extend(int_to_bytes(value, bits_to_bytes(n_bits)))
        current_node = self.root
        while True:
            assert not current_node.is_exhausted
            for i, (n_bits, value) in enumerate(
                zip(current_node.bit_lengths, current_node.values)
            ):
                if i in current_node.forced:
                    # Forced draws must be reproduced exactly.
                    append_int(n_bits, value)
                else:
                    # Draw any value other than the recorded one: that
                    # guarantees the prefix diverges from all known buffers.
                    while True:
                        k = random.getrandbits(n_bits)
                        if k != value:
                            append_int(n_bits, k)
                            break
                    # We've now found a value that is allowed to
                    # vary, so what follows is not fixed.
                    return bytes(novel_prefix)
            else:
                assert not isinstance(current_node.transition, (Conclusion, Killed))
                if current_node.transition is None:
                    return bytes(novel_prefix)
                branch = current_node.transition
                assert isinstance(branch, Branch)
                n_bits = branch.bit_length
                check_counter = 0
                while True:
                    k = random.getrandbits(n_bits)
                    try:
                        child = branch.children[k]
                    except KeyError:
                        # A branch value never seen before: novel by definition.
                        append_int(n_bits, k)
                        return bytes(novel_prefix)
                    if not child.is_exhausted:
                        append_int(n_bits, k)
                        current_node = child
                        break
                    check_counter += 1
                    # We don't expect this assertion to ever fire, but coverage
                    # wants the loop inside to run if you have branch checking
                    # on, hence the pragma.
                    assert ( # pragma: no cover
                        check_counter != 1000
                        or len(branch.children) < (2**n_bits)
                        or any(not v.is_exhausted for v in branch.children.values())
                    )
    def rewrite(self, buffer):
        """Use previously seen ConjectureData objects to return a tuple of
        the rewritten buffer and the status we would get from running that
        buffer with the test function. If the status cannot be predicted
        from the existing values it will be None."""
        buffer = bytes(buffer)
        data = ConjectureData.for_buffer(buffer)
        try:
            self.simulate_test_function(data)
            return (data.buffer, data.status)
        except PreviouslyUnseenBehaviour:
            return (buffer, None)
    def simulate_test_function(self, data):
        """Run a simulated version of the test function recorded by
        this tree. Note that this does not currently call ``stop_example``
        or ``start_example`` as these are not currently recorded in the
        tree. This will likely change in future."""
        node = self.root
        try:
            while True:
                for i, (n_bits, previous) in enumerate(
                    zip(node.bit_lengths, node.values)
                ):
                    v = data.draw_bits(
                        n_bits, forced=node.values[i] if i in node.forced else None
                    )
                    if v != previous:
                        # The buffer diverges from this node's recorded run.
                        raise PreviouslyUnseenBehaviour
                if isinstance(node.transition, Conclusion):
                    t = node.transition
                    data.conclude_test(t.status, t.interesting_origin)
                elif node.transition is None:
                    raise PreviouslyUnseenBehaviour
                elif isinstance(node.transition, Branch):
                    v = data.draw_bits(node.transition.bit_length)
                    try:
                        node = node.transition.children[v]
                    except KeyError as err:
                        raise PreviouslyUnseenBehaviour from err
                else:
                    assert isinstance(node.transition, Killed)
                    data.observer.kill_branch()
                    node = node.transition.next_node
        except StopTest:
            pass
    def new_observer(self):
        # Observer that records a fresh test execution into this tree.
        return TreeRecordingObserver(self)
class TreeRecordingObserver(DataObserver):
    """DataObserver that records each draw of a test run into a DataTree."""
    def __init__(self, tree):
        self.__current_node = tree.root
        self.__index_in_current_node = 0
        # Every distinct node visited, in order, for exhaustion updates.
        self.__trail = [self.__current_node]
        self.killed = False
    def draw_bits(self, n_bits, forced, value):
        i = self.__index_in_current_node
        self.__index_in_current_node += 1
        node = self.__current_node
        assert len(node.bit_lengths) == len(node.values)
        if i < len(node.bit_lengths):
            if n_bits != node.bit_lengths[i]:
                inconsistent_generation()
            # Note that we don't check whether a previously
            # forced value is now free. That will be caught
            # if we ever split the node there, but otherwise
            # may pass silently. This is acceptable because it
            # means we skip a hash set lookup on every
            # draw and that's a pretty niche failure mode.
            if forced and i not in node.forced:
                inconsistent_generation()
            if value != node.values[i]:
                # The run diverges here: split the node and move into a
                # fresh child for the newly seen value.
                node.split_at(i)
                assert i == len(node.values)
                new_node = TreeNode()
                branch = node.transition
                branch.children[value] = new_node
                self.__current_node = new_node
                self.__index_in_current_node = 0
        else:
            trans = node.transition
            if trans is None:
                # Unexplored territory: extend this node's recorded run.
                node.bit_lengths.append(n_bits)
                node.values.append(value)
                if forced:
                    node.mark_forced(i)
            elif isinstance(trans, Conclusion):
                assert trans.status != Status.OVERRUN
                # We tried to draw where history says we should have
                # stopped
                inconsistent_generation()
            else:
                assert isinstance(trans, Branch), trans
                if n_bits != trans.bit_length:
                    inconsistent_generation()
                try:
                    self.__current_node = trans.children[value]
                except KeyError:
                    self.__current_node = trans.children.setdefault(value, TreeNode())
                self.__index_in_current_node = 0
        if self.__trail[-1] is not self.__current_node:
            self.__trail.append(self.__current_node)
    def kill_branch(self):
        """Mark this part of the tree as not worth re-exploring."""
        if self.killed:
            return
        self.killed = True
        if self.__index_in_current_node < len(self.__current_node.values) or (
            self.__current_node.transition is not None
            and not isinstance(self.__current_node.transition, Killed)
        ):
            inconsistent_generation()
        if self.__current_node.transition is None:
            self.__current_node.transition = Killed(TreeNode())
            self.__update_exhausted()
        self.__current_node = self.__current_node.transition.next_node
        self.__index_in_current_node = 0
        self.__trail.append(self.__current_node)
    def conclude_test(self, status, interesting_origin):
        """Says that ``status`` occurred at node ``node``. This updates the
        node if necessary and checks for consistency."""
        if status == Status.OVERRUN:
            return
        i = self.__index_in_current_node
        node = self.__current_node
        if i < len(node.values) or isinstance(node.transition, Branch):
            inconsistent_generation()
        new_transition = Conclusion(status, interesting_origin)
        if node.transition is not None and node.transition != new_transition:
            # As an, I'm afraid, horrible bodge, we deliberately ignore flakiness
            # where tests go from interesting to valid, because it's much easier
            # to produce good error messages for these further up the stack.
            if isinstance(node.transition, Conclusion) and (
                node.transition.status != Status.INTERESTING
                or new_transition.status != Status.VALID
            ):
                raise Flaky(
                    f"Inconsistent test results! Test case was {node.transition!r} "
                    f"on first run but {new_transition!r} on second"
                )
        else:
            node.transition = new_transition
        assert node is self.__trail[-1]
        node.check_exhausted()
        assert len(node.values) > 0 or node.check_exhausted()
        if not self.killed:
            self.__update_exhausted()
    def __update_exhausted(self):
        for t in reversed(self.__trail):
            # Any node we've traversed might have now become exhausted.
            # We check from the right. As soon as we hit a node that
            # isn't exhausted, this automatically implies that all of
            # its parents are not exhausted, so we stop.
            if not t.check_exhausted():
                break

View File

@@ -0,0 +1,674 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import threading
from collections import Counter, defaultdict, deque
from math import inf
from hypothesis.internal.reflection import proxies
def cached(fn):
    """Decorator memoising a DFA method per-thread, keyed on its arguments."""
    @proxies(fn)
    def wrapped(self, *args):
        # Fetch the DFA instance's thread-local cache dict for this method
        # (name-mangled access to DFA.__cache from outside the class).
        cache = self._DFA__cache(fn.__name__)
        try:
            return cache[args]
        except KeyError:
            return cache.setdefault(args, fn(self, *args))
    return wrapped
class DFA:
"""Base class for implementations of deterministic finite
automata.
This is abstract to allow for the possibility of states
being calculated lazily as we traverse the DFA (which
we make heavy use of in our L* implementation - see
lstar.py for details).
States can be of any hashable type.
"""
def __init__(self):
self.__caches = threading.local()
def __cache(self, name):
try:
cache = getattr(self.__caches, name)
except AttributeError:
cache = {}
setattr(self.__caches, name, cache)
return cache
@property
def start(self):
"""Returns the starting state."""
raise NotImplementedError
def is_accepting(self, i):
"""Returns if state ``i`` is an accepting one."""
raise NotImplementedError
def transition(self, i, c):
"""Returns the state that i transitions to on reading
character c from a string."""
raise NotImplementedError
@property
def alphabet(self):
return range(256)
def transitions(self, i):
"""Iterates over all pairs (byte, state) of transitions
which do not lead to dead states."""
for c, j in self.raw_transitions(i):
if not self.is_dead(j):
yield c, j
@cached
def transition_counts(self, state):
counts = Counter()
for _, j in self.transitions(state):
counts[j] += 1
return list(counts.items())
def matches(self, s):
"""Returns whether the string ``s`` is accepted
by this automaton."""
i = self.start
for c in s:
i = self.transition(i, c)
return self.is_accepting(i)
def all_matching_regions(self, string):
"""Return all pairs ``(u, v)`` such that ``self.matches(string[u:v])``."""
# Stack format: (k, state, indices). After reading ``k`` characters
# starting from any i in ``indices`` the DFA would be at ``state``.
stack = [(0, self.start, range(len(string)))]
results = []
while stack:
k, state, indices = stack.pop()
# If the state is dead, abort early - no point continuing on
# from here where there will be no more matches.
if self.is_dead(state):
continue
# If the state is accepting, then every one of these indices
# has a matching region of length ``k`` starting from it.
if self.is_accepting(state):
results.extend([(i, i + k) for i in indices])
next_by_state = defaultdict(list)
for i in indices:
if i + k < len(string):
c = string[i + k]
next_by_state[self.transition(state, c)].append(i)
for next_state, next_indices in next_by_state.items():
stack.append((k + 1, next_state, next_indices))
return results
def max_length(self, i):
"""Returns the maximum length of a string that is
accepted when starting from i."""
if self.is_dead(i):
return 0
cache = self.__cache("max_length")
try:
return cache[i]
except KeyError:
pass
# Naively we can calculate this as 1 longer than the
# max length of the non-dead states this can immediately
# transition to, but a) We don't want unbounded recursion
# because that's how you get RecursionErrors and b) This
# makes it hard to look for cycles. So we basically do
# the recursion explicitly with a stack, but we maintain
# a parallel set that tracks what's already on the stack
# so that when we encounter a loop we can immediately
# determine that the max length here is infinite.
stack = [i]
stack_set = {i}
def pop():
"""Remove the top element from the stack, maintaining
the stack set appropriately."""
assert len(stack) == len(stack_set)
j = stack.pop()
stack_set.remove(j)
assert len(stack) == len(stack_set)
while stack:
j = stack[-1]
assert not self.is_dead(j)
# If any of the children have infinite max_length we don't
# need to check all of them to know that this state does
# too.
if any(cache.get(k) == inf for k in self.successor_states(j)):
cache[j] = inf
pop()
continue
# Recurse to the first child node that we have not yet
# calculated max_length for.
for k in self.successor_states(j):
if k in stack_set:
# k is part of a loop and is known to be live
# (since we never push dead states on the stack),
# so it can reach strings of unbounded length.
assert not self.is_dead(k)
cache[k] = inf
break
elif k not in cache and not self.is_dead(k):
stack.append(k)
stack_set.add(k)
break
else:
# All of j's successors have a known max_length or are dead,
# so we can now compute a max_length for j itself.
cache[j] = max(
(
1 + cache[k]
for k in self.successor_states(j)
if not self.is_dead(k)
),
default=0,
)
# j is live so it must either be accepting or have a live child.
assert self.is_accepting(j) or cache[j] > 0
pop()
return cache[i]
@cached
def has_strings(self, state, length):
"""Returns if any strings of length ``length`` are accepted when
starting from state ``state``."""
assert length >= 0
cache = self.__cache("has_strings")
try:
return cache[state, length]
except KeyError:
pass
pending = [(state, length)]
seen = set()
i = 0
while i < len(pending):
s, n = pending[i]
i += 1
if n > 0:
for t in self.successor_states(s):
key = (t, n - 1)
if key not in cache and key not in seen:
pending.append(key)
seen.add(key)
while pending:
s, n = pending.pop()
if n == 0:
cache[s, n] = self.is_accepting(s)
else:
cache[s, n] = any(
cache.get((t, n - 1)) for t in self.successor_states(s)
)
return cache[state, length]
def count_strings(self, state, length):
"""Returns the number of strings of length ``length``
that are accepted when starting from state ``state``."""
assert length >= 0
cache = self.__cache("count_strings")
try:
return cache[state, length]
except KeyError:
pass
pending = [(state, length)]
seen = set()
i = 0
while i < len(pending):
s, n = pending[i]
i += 1
if n > 0:
for t in self.successor_states(s):
key = (t, n - 1)
if key not in cache and key not in seen:
pending.append(key)
seen.add(key)
while pending:
s, n = pending.pop()
if n == 0:
cache[s, n] = int(self.is_accepting(s))
else:
cache[s, n] = sum(
cache[t, n - 1] * k for t, k in self.transition_counts(s)
)
return cache[state, length]
@cached
def successor_states(self, state):
"""Returns all of the distinct states that can be reached via one
transition from ``state``, in the lexicographic order of the
smallest character that reaches them."""
seen = set()
result = []
for _, j in self.raw_transitions(state):
if j not in seen:
seen.add(j)
result.append(j)
return tuple(result)
def is_dead(self, state):
"""Returns True if no strings can be accepted
when starting from ``state``."""
return not self.is_live(state)
def is_live(self, state):
"""Returns True if any strings can be accepted
when starting from ``state``."""
if self.is_accepting(state):
return True
# We work this out by calculating is_live for all nodes
# reachable from state which have not already had it calculated.
cache = self.__cache("is_live")
try:
return cache[state]
except KeyError:
pass
# roots are states that we know already must be live,
# either because we have previously calculated them to
# be or because they are an accepting state.
roots = set()
# We maintain a backwards graph where ``j in backwards_graph[k]``
# if there is a transition from j to k. Thus if a key in this
# graph is live, so must all its values be.
backwards_graph = defaultdict(set)
# First we find all reachable nodes from i which have not
# already been cached, noting any which are roots and
# populating the backwards graph.
explored = set()
queue = deque([state])
while queue:
j = queue.popleft()
if cache.get(j, self.is_accepting(j)):
# If j can be immediately determined to be live
# then there is no point in exploring beneath it,
# because any effect of states below it is screened
# off by the known answer for j.
roots.add(j)
continue
if j in cache:
# Likewise if j is known to be dead then there is
# no point exploring beneath it because we know
# that all nodes reachable from it must be dead.
continue
if j in explored:
continue
explored.add(j)
for k in self.successor_states(j):
backwards_graph[k].add(j)
queue.append(k)
marked_live = set()
queue = deque(roots)
while queue:
j = queue.popleft()
if j in marked_live:
continue
marked_live.add(j)
for k in backwards_graph[j]:
queue.append(k)
for j in explored:
cache[j] = j in marked_live
return cache[state]
def all_matching_strings_of_length(self, k):
"""Yields all matching strings whose length is ``k``, in ascending
lexicographic order."""
if k == 0:
if self.is_accepting(self.start):
yield b""
return
if not self.has_strings(self.start, k):
return
# This tracks a path through the DFA. We alternate between growing
# it until it has length ``k`` and is in an accepting state, then
# yielding that as a result, then modifying it so that the next
# time we do that it will yield the lexicographically next matching
# string.
path = bytearray()
# Tracks the states that are visited by following ``path`` from the
# starting point.
states = [self.start]
while True:
# First we build up our current best prefix to the lexicographically
# first string starting with it.
while len(path) < k:
state = states[-1]
for c, j in self.transitions(state):
if self.has_strings(j, k - len(path) - 1):
states.append(j)
path.append(c)
break
else:
raise NotImplementedError("Should be unreachable")
assert self.is_accepting(states[-1])
assert len(states) == len(path) + 1
yield bytes(path)
# Now we want to replace this string with the prefix that will
# cause us to extend to its lexicographic successor. This can
# be thought of as just repeatedly moving to the next lexicographic
# successor until we find a matching string, but we're able to
# use our length counts to jump over long sequences where there
# cannot be a match.
while True:
# As long as we are in this loop we are trying to move to
# the successor of the current string.
# If we've removed the entire prefix then we're done - no
# successor is possible.
if not path:
return
if path[-1] == 255:
# If our last element is maximal then the we have to "carry
# the one" - our lexicographic successor must be incremented
# earlier than this.
path.pop()
states.pop()
else:
# Otherwise increment by one.
path[-1] += 1
states[-1] = self.transition(states[-2], path[-1])
# If there are no strings of the right length starting from
# this prefix we need to keep going. Otherwise, this is
# the right place to be and we break out of our loop of
# trying to find the successor because it starts here.
if self.count_strings(states[-1], k - len(path)) > 0:
break
def all_matching_strings(self, min_length=0):
    """Yield every string this automaton accepts, in shortlex
    (length first, then lexicographic) ascending order."""
    # ``max_length`` may be infinite, so we cannot express this as a
    # ``range`` and instead count lengths up by hand.
    upper = self.max_length(self.start)
    current = min_length
    while current <= upper:
        yield from self.all_matching_strings_of_length(current)
        current += 1
def raw_transitions(self, i):
    """Yield ``(character, destination)`` for every character in the
    alphabet, including transitions that lead into dead states."""
    for character in self.alphabet:
        yield character, self.transition(i, character)
def canonicalise(self):
    """Return a canonical version of ``self`` as a ConcreteDFA.

    The DFA is not minimized, but nodes are sorted and relabelled
    and dead nodes are pruned, so two minimized DFAs for the same
    language will end up with identical canonical representatives.
    This is mildly important because it means that the output of
    L* should produce the same canonical DFA regardless of what
    order we happen to have run it in.
    """
    # Relabel every reachable state with its breadth-first discovery
    # index. Going through an explicit queue also copes with state
    # labels that are not integers.
    index_of = {}
    discovered = []
    accepting_indices = set()
    enqueued = set()
    pending = deque([self.start])
    while pending:
        current = pending.popleft()
        if current in index_of:
            continue
        new_label = len(discovered)
        if self.is_accepting(current):
            accepting_indices.add(new_label)
        discovered.append(current)
        index_of[current] = new_label
        for _, successor in self.transitions(current):
            if successor not in enqueued:
                enqueued.add(successor)
                pending.append(successor)
    relabelled = [
        {c: index_of[s] for c, s in self.transitions(state)} for state in discovered
    ]
    result = ConcreteDFA(relabelled, accepting_indices)
    # Relabelling must never change the language we match.
    assert self.equivalent(result)
    return result
def equivalent(self, other):
    """Checks whether this DFA and other match precisely the same
    language.

    Uses the classic algorithm of Hopcroft and Karp (more or less):
    Hopcroft, John E. A linear algorithm for testing equivalence
    of finite automata. Vol. 114. Defense Technical Information Center, 1971.
    """
    # Strategy: assume the two start states are equivalent and merge
    # them, then keep merging every pair of states reachable from an
    # already-merged pair by the same character. If we are ever forced
    # to merge an accepting state with a non-accepting one we have a
    # contradiction, so the original assumption was wrong and the
    # languages differ. If every forced merge succeeds, any string is
    # traced through equivalent state pairs in both automata and the
    # languages must be identical.
    #
    # Merges are tracked with a union/find structure whose keys are
    # ``(dfa, state)`` pairs, since the same state value may mean
    # different things in each automaton.
    parent = {}

    def find(s):
        chain = [s]
        while chain[-1] in parent and parent[chain[-1]] != chain[-1]:
            chain.append(parent[chain[-1]])
        root = chain[-1]
        # Path compression: point everything we walked through at the root.
        for node in chain:
            parent[node] = root
        return root

    def union(s, t):
        parent[find(s)] = find(t)

    alphabet = sorted(set(self.alphabet) | set(other.alphabet))
    pending = deque([(self.start, other.start)])
    while pending:
        left, right = pending.popleft()
        left_key = (self, left)
        right_key = (other, right)
        # Already merged - nothing new to learn from this pair.
        if find(left_key) == find(right_key):
            continue
        # Accepting/non-accepting mismatch: bad merge, languages differ.
        if self.is_accepting(left) != other.is_accepting(right):
            return False
        union(left_key, right_key)
        # Queue the logical consequences of this merge.
        for c in alphabet:
            pending.append((self.transition(left, c), other.transition(right, c)))
    return True
# Sentinel label for the canonical dead state: ``ConcreteDFA.transition``
# returns it when a character has no entry in the transition table, and it
# is absorbing (a DEAD state only ever transitions back to DEAD).
DEAD = "DEAD"
class ConcreteDFA(DFA):
    """A DFA stored as an explicit, fully materialised transition table,
    as opposed to one computed lazily."""

    def __init__(self, transitions, accepting, start=0):
        """
        * ``transitions`` is a list where transitions[i] represents the
          valid transitions out of state ``i``. Elements may be either dicts
          (in which case they map characters to other states) or lists. If they
          are a list they may contain tuples of length 2 or 3. A tuple ``(c, j)``
          indicates that this state transitions to state ``j`` given ``c``. A
          tuple ``(u, v, j)`` indicates this state transitions to state ``j``
          given any ``c`` with ``u <= c <= v``.
        * ``accepting`` is a set containing the integer labels of accepting
          states.
        * ``start`` is the integer label of the starting state.
        """
        super().__init__()
        self.__start = start
        self.__accepting = accepting
        self.__transitions = list(transitions)

    def __repr__(self):
        # Collapse runs of consecutive characters that lead to the same
        # state into ``(lo, hi, state)`` triples, so that the repr (which
        # is often pasted into source code) stays compact.
        compact = []
        for i in range(len(self.__transitions)):
            runs = []
            for c, j in self.transitions(i):
                if runs and j == runs[-1][2] and c == runs[-1][1] + 1:
                    runs[-1][1] = c
                else:
                    runs.append([c, c, j])
            compact.append([(u, j) if u == v else (u, v, j) for u, v, j in runs])
        suffix = "" if self.__start == 0 else f", start={self.__start!r}"
        return f"ConcreteDFA({compact!r}, {self.__accepting!r}{suffix})"

    @property
    def start(self):
        """Label of the starting state."""
        return self.__start

    def is_accepting(self, i):
        """Whether state ``i`` is an accepting state."""
        return i in self.__accepting

    def transition(self, state, char):
        """Returns the state that ``state`` transitions to on reading
        ``char`` from a string."""
        if state == DEAD:
            return DEAD

        table = self.__transitions[state]

        # Long list-based tables are flattened into dicts the first time
        # they are consulted, turning subsequent lookups from a linear
        # scan into a hash lookup.
        if not isinstance(table, dict) and len(table) >= 5:
            flattened = {}
            for entry in table:
                if len(entry) == 2:
                    flattened[entry[0]] = entry[1]
                else:
                    lo, hi, dest = entry
                    for c in range(lo, hi + 1):
                        flattened[c] = dest
            self.__transitions[state] = flattened
            table = flattened

        if isinstance(table, dict):
            # Missing entries mean this character kills the string.
            return table.get(char, DEAD)
        for entry in table:
            if len(entry) == 2:
                if entry[0] == char:
                    return entry[1]
            elif entry[0] <= char <= entry[1]:
                return entry[2]
        return DEAD

    def raw_transitions(self, i):
        if i == DEAD:
            return
        table = self.__transitions[i]
        if isinstance(table, dict):
            yield from sorted(table.items())
        else:
            for entry in table:
                if len(entry) == 2:
                    yield entry
                else:
                    lo, hi, dest = entry
                    for c in range(lo, hi + 1):
                        yield c, dest

View File

@@ -0,0 +1,498 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from bisect import bisect_right, insort
from collections import Counter
import attr
from hypothesis.errors import InvalidState
from hypothesis.internal.conjecture.dfa import DFA, cached
from hypothesis.internal.conjecture.junkdrawer import (
IntList,
NotFound,
SelfOrganisingList,
find_integer,
)
"""
This module contains an implementation of the L* algorithm
for learning a deterministic finite automaton based on an
unknown membership function and a series of examples of
strings that may or may not satisfy it.
The two relevant papers for understanding this are:
* Angluin, Dana. "Learning regular sets from queries and counterexamples."
Information and computation 75.2 (1987): 87-106.
* Rivest, Ronald L., and Robert E. Schapire. "Inference of finite automata
using homing sequences." Information and Computation 103.2 (1993): 299-347.
Note that we only use the material from section 4.5 "Improving Angluin's L*
algorithm" (page 318), and all of the rest of the material on homing
sequences can be skipped.
The former explains the core algorithm, the latter a modification
we use (which we have further modified) which allows it to
be implemented more efficiently.
Although we continue to call this L*, we in fact depart heavily from it to the
point where honestly this is an entirely different algorithm and we should come
up with a better name.
We have several major departures from the papers:
1. We learn the automaton lazily as we traverse it. This is particularly
valuable because if we make many corrections on the same string we only
have to learn the transitions that correspond to the string we are
correcting on.
2. We make use of our ``find_integer`` method rather than a binary search
as proposed in the Rivest and Schapire paper, as we expect that
usually most strings will be mispredicted near the beginning.
3. We try to learn a smaller alphabet of "interestingly distinct"
values. e.g. if all bytes larger than two result in an invalid
string, there is no point in distinguishing those bytes. In aid
of this we learn a single canonicalisation table which maps integers
to smaller integers that we currently think are equivalent, and learn
their inequivalence where necessary. This may require more learning
steps, as at each stage in the process we might learn either an
inequivalent pair of integers or a new experiment, but it may greatly
reduce the number of membership queries we have to make.
In addition, we have a totally different approach for mapping a string to its
canonical representative, which will be explained below inline. The general gist
is that our implementation is much more willing to make mistakes: It will often
create a DFA that is demonstrably wrong, based on information that it already
has, but where it is too expensive to discover that before it causes us to
make a mistake.
A note on performance: This code is not really fast enough for
us to ever want to run in production on large strings, and this
is somewhat intrinsic. We should only use it in testing or for
learning languages offline that we can record for later use.
"""
@attr.s(slots=True)
class DistinguishedState:
    """Relevant information for a state that we have witnessed as definitely
    distinct from ones we have previously seen so far."""

    # Index of this state in the learner's list of states
    index: int = attr.ib()
    # A byte string that witnesses this state (i.e. when starting from the
    # origin and following this string you will end up in this state).
    label: bytes = attr.ib()
    # A boolean as to whether this is an accepting state.
    accepting: bool = attr.ib()
    # Experiments needed to determine whether a string is in this state,
    # stored as a dict mapping experiment suffixes (bytes) to their expected
    # membership result. A string is only considered to lead to this
    # state if ``all(learner.member(s + experiment) == result for experiment,
    # result in self.experiments.items())``.
    experiments: dict = attr.ib()
    # A cache of transitions out of this state, mapping characters (ints) to
    # the indices of the states that they lead to.
    transitions: dict = attr.ib(factory=dict)
class LStar:
    """This class holds the state for learning a DFA. The current DFA can be
    accessed as the ``dfa`` member of this class. Such a DFA becomes invalid
    as soon as ``learn`` has been called, and should only be used until the
    next call to ``learn``.

    Note that many of the DFA methods are on this class, but it is not itself
    a DFA. The reason for this is that it stores mutable state which can cause
    the structure of the learned DFA to change in potentially arbitrary ways,
    making all cached properties become nonsense.
    """

    def __init__(self, member):
        # NOTE(review): ``experiments`` and ``__experiment_set`` are never
        # referenced elsewhere in this class; they appear vestigial, but
        # ``experiments`` is public so it is kept for external consumers.
        self.experiments = []
        self.__experiment_set = set()
        self.normalizer = IntegerNormalizer()
        # Cache of membership-query results, keyed by the queried string.
        self.__member_cache = {}
        self.__member = member
        self.__generation = 0
        # A list of all state objects that correspond to strings we have
        # seen and can demonstrate map to unique states.
        self.__states = [
            DistinguishedState(
                index=0,
                label=b"",
                accepting=self.member(b""),
                experiments={b"": self.member(b"")},
            )
        ]
        # When we're trying to figure out what state a string leads to we will
        # end up searching to find a suitable candidate. By putting states in
        # a self-organising list we ideally minimise the number of lookups.
        self.__self_organising_states = SelfOrganisingList(self.__states)
        self.start = 0
        self.__dfa_changed()

    def __dfa_changed(self):
        """Note that something has changed, updating the generation
        and resetting any cached state."""
        self.__generation += 1
        self.dfa = LearnedDFA(self)

    def is_accepting(self, i):
        """Equivalent to ``self.dfa.is_accepting(i)``"""
        return self.__states[i].accepting

    def label(self, i):
        """Returns the string label for state ``i``."""
        return self.__states[i].label

    def transition(self, i, c):
        """Equivalent to ``self.dfa.transition(i, c)``"""
        c = self.normalizer.normalize(c)
        state = self.__states[i]
        try:
            return state.transitions[c]
        except KeyError:
            pass

        # The state that we transition to when reading ``c`` is reached by
        # this string, because this state is reached by state.label. We thus
        # want our candidate for the transition to be some state with a label
        # equivalent to this string.
        #
        # We find such a state by looking for one such that all of its listed
        # experiments agree on the result for its state label and this string.
        string = state.label + bytes([c])

        # We keep track of some useful experiments for distinguishing this
        # string from other states, as this both allows us to more accurately
        # select the state to map to and, if necessary, create the new state
        # that this string corresponds to with a decent set of starting
        # experiments.
        accumulated = {}

        def equivalent(t):
            """Checks if ``string`` could possibly lead to state ``t``."""
            for e, expected in accumulated.items():
                if self.member(t.label + e) != expected:
                    return False
            for e, expected in t.experiments.items():
                result = self.member(string + e)
                if result != expected:
                    # We expect most experiments to return False so if we add
                    # only True ones to our collection of essential experiments
                    # we keep the size way down and select only ones that are
                    # likely to provide useful information in future.
                    if result:
                        accumulated[e] = result
                    return False
            return True

        try:
            destination = self.__self_organising_states.find(equivalent)
        except NotFound:
            # No existing state is consistent with this string, so it
            # witnesses a brand new state.
            index = len(self.__states)
            destination = DistinguishedState(
                index=index,
                label=string,
                experiments=accumulated,
                accepting=self.member(string),
            )
            self.__states.append(destination)
            self.__self_organising_states.add(destination)
        state.transitions[c] = destination.index
        return destination.index

    def member(self, s):
        """Check whether this string is a member of the language
        to be learned."""
        try:
            return self.__member_cache[s]
        except KeyError:
            result = self.__member(s)
            self.__member_cache[s] = result
            return result

    @property
    def generation(self):
        """Return an integer value that will be incremented
        every time the DFA we predict changes."""
        return self.__generation

    def learn(self, string):
        """Learn to give the correct answer on this string.
        That is, after this method completes we will have
        ``self.dfa.matches(s) == self.member(s)``.

        Note that we do not guarantee that this will remain
        true in the event that learn is called again with
        a different string. It is in principle possible that
        future learning will cause us to make a mistake on
        this string. However, repeatedly calling learn on
        each of a set of strings until the generation stops
        changing is guaranteed to terminate.
        """
        string = bytes(string)
        correct_outcome = self.member(string)

        # We don't want to check this inside the loop because it potentially
        # causes us to evaluate more of the states than we actually need to,
        # but if our model is mostly correct then this will be faster because
        # we only need to evaluate strings that are of the form
        # ``state + experiment``, which will generally be cached and/or needed
        # later.
        if self.dfa.matches(string) == correct_outcome:
            return

        # In the papers they assume that we only run this process
        # once, but this is silly - often when you've got a messy
        # string it will be wrong for many different reasons.
        #
        # Thus we iterate this to a fixed point where we repair
        # the DFA by repeatedly adding experiments until the DFA
        # agrees with the membership function on this string.

        # First we make sure that normalization is not the source of the
        # failure to match.
        while True:
            normalized = bytes(self.normalizer.normalize(c) for c in string)
            # We can correctly replace the string with its normalized version
            # so normalization is not the problem here.
            if self.member(normalized) == correct_outcome:
                string = normalized
                break
            alphabet = sorted(set(string), reverse=True)
            target = string
            for a in alphabet:

                def replace(b):
                    if a == b:
                        return target
                    return bytes(b if c == a else c for c in target)

                self.normalizer.distinguish(a, lambda x: self.member(replace(x)))
                target = replace(self.normalizer.normalize(a))
                assert self.member(target) == correct_outcome
            assert target != normalized
            self.__dfa_changed()

        if self.dfa.matches(string) == correct_outcome:
            return

        # Now we know normalization is correct we can attempt to determine if
        # any of our transitions are wrong.
        while True:
            dfa = self.dfa

            states = [dfa.start]

            def seems_right(n):
                """After reading n characters from s, do we seem to be
                in the right state?

                We determine this by replacing the first n characters
                of s with the label of the state we expect to be in.
                If we are in the right state, that will replace a substring
                with an equivalent one so must produce the same answer.
                """
                if n > len(string):
                    return False

                # Populate enough of the states list to know where we are.
                while n >= len(states):
                    states.append(dfa.transition(states[-1], string[len(states) - 1]))

                return self.member(dfa.label(states[n]) + string[n:]) == correct_outcome

            assert seems_right(0)

            n = find_integer(seems_right)

            # We got to the end without ever finding ourself in a bad
            # state, so we must correctly match this string.
            if n == len(string):
                assert dfa.matches(string) == correct_outcome
                break

            # Reading n characters does not put us in a bad state but
            # reading n + 1 does. This means that the remainder of
            # the string that we have not read yet is an experiment
            # that allows us to distinguish the state that we ended
            # up in from the state that we should have ended up in.
            source = states[n]
            character = string[n]
            wrong_destination = states[n + 1]

            # We've made an error in transitioning from ``source`` to
            # ``wrong_destination`` via ``character``. We now need to update
            # the DFA so that this transition no longer occurs. Note that we
            # do not guarantee that the transition is *correct* after this,
            # only that we don't make this particular error.
            assert self.transition(source, character) == wrong_destination

            labels_wrong_destination = self.dfa.label(wrong_destination)
            labels_correct_destination = self.dfa.label(source) + bytes([character])

            ex = string[n + 1 :]

            assert self.member(labels_wrong_destination + ex) != self.member(
                labels_correct_destination + ex
            )

            # Adding this experiment causes us to distinguish the wrong
            # destination from the correct one.
            self.__states[wrong_destination].experiments[ex] = self.member(
                labels_wrong_destination + ex
            )

            # We now clear the cached details that caused us to make this error
            # so that when we recalculate this transition we get to a
            # (hopefully now correct) different state.
            del self.__states[source].transitions[character]
            self.__dfa_changed()

            # We immediately recalculate the transition so that we can check
            # that it has changed as we expect it to have.
            new_destination = self.transition(source, string[n])
            assert new_destination != wrong_destination
class LearnedDFA(DFA):
    """A lazily calculated DFA view over an L* learner: states are labelled
    by some string that reaches them, and are distinguished by a membership
    test and a set of experiments."""

    def __init__(self, lstar):
        super().__init__()
        self.__lstar = lstar
        # Snapshot of the learner's generation; if the learner moves on,
        # this DFA is stale and must refuse to answer queries.
        self.__generation = lstar.generation

    def __check_changed(self):
        if self.__lstar.generation != self.__generation:
            raise InvalidState(
                "The underlying L* model has changed, so this DFA is no longer valid. "
                "If you want to preserve a previously learned DFA for posterity, call "
                "canonicalise() on it first."
            )

    def label(self, i):
        """The witness string for state ``i``."""
        self.__check_changed()
        return self.__lstar.label(i)

    @property
    def start(self):
        self.__check_changed()
        return self.__lstar.start

    def is_accepting(self, i):
        self.__check_changed()
        return self.__lstar.is_accepting(i)

    def transition(self, i, c):
        self.__check_changed()
        return self.__lstar.transition(i, c)

    @cached
    def successor_states(self, state):
        """Returns all of the distinct states that can be reached via one
        transition from ``state``, in the lexicographic order of the
        smallest character that reaches them."""
        known = set()
        ordered = []
        for char in self.__lstar.normalizer.representatives():
            destination = self.transition(state, char)
            if destination not in known:
                known.add(destination)
                ordered.append(destination)
        return tuple(ordered)
class IntegerNormalizer:
    """Replaces non-negative integers with a "canonical" value that is
    equivalent for all relevant purposes."""

    def __init__(self):
        # Sorted list of canonical values. Every integer is treated as
        # equivalent to the largest canonical value that is <= it.
        self.__values = IntList([0])
        self.__cache = {}

    def __repr__(self):
        return f"IntegerNormalizer({list(self.__values)!r})"

    def __copy__(self):
        duplicate = IntegerNormalizer()
        duplicate.__values = IntList(self.__values)
        return duplicate

    def representatives(self):
        """Yield each canonical value, in ascending order."""
        yield from self.__values

    def normalize(self, value):
        """Return the canonical integer considered equivalent
        to ``value``."""
        try:
            return self.__cache[value]
        except KeyError:
            pass
        position = bisect_right(self.__values, value) - 1
        # 0 is always canonical, so there is always a value <= ``value``.
        assert position >= 0
        return self.__cache.setdefault(value, self.__values[position])

    def distinguish(self, value, test):
        """Checks whether ``test`` gives the same answer for
        ``value`` and ``self.normalize(value)``. If it does
        not, updates the list of canonical values so that
        it does.

        Returns True if and only if this makes a change to
        the underlying canonical values."""
        representative = self.normalize(value)
        if representative == value:
            return False

        outcome = test(value)
        if test(representative) == outcome:
            return False

        # The canonical table is about to change, so every memoised
        # normalisation result is now suspect.
        self.__cache.clear()

        def can_drop_by(k):
            candidate = value - k
            if candidate <= representative:
                return False
            return test(candidate) == outcome

        # Push the new canonical value as low as possible while it still
        # agrees with ``test`` on ``value``.
        fresh = value - find_integer(can_drop_by)
        assert fresh not in self.__values
        insort(self.__values, fresh)
        assert self.normalize(value) == fresh
        return True

Some files were not shown because too many files have changed in this diff Show More