# testing/fixtures.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
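"""Base test classes and fixtures used by the SQLAlchemy test suite."""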

import contextlib
import itertools
import re
import sys

import sqlalchemy as sa
from . import assertions
from . import config
from . import schema
from .assertions import eq_
from .assertions import ne_
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin  # noqa
from .util import adict
from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
from ..orm import registry
from ..orm.decl_api import DeclarativeMeta
from ..schema import sort_tables_and_constraints
from ..sql import visitors
from ..sql.elements import ClauseElement


@config.mark_base_test_class()
class TestBase(object):
    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()

    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()

    # If present, test class is only runnable for the *single* specified
    # dialect.  If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None

    # A sequence of no-arg callables; if any return True, the entire
    # test case is skipped.
    __skip_if__ = None

    # if True, the testing reaper will not attempt to touch connection
    # state after a test is completed and before the outer teardown
    # starts
    __leave_connections_for_teardown__ = False

    def assert_(self, val, msg=None):
        assert val, msg

    @config.fixture()
    def nocache(self):
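        """Disable the compiled SQL cache on ``config.db`` for the duration
        of the test.

        """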
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache

    @config.fixture()
    def connection_no_trans(self):
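        """Provide a :class:`.Connection` from the configured engine with no
        transaction begun; the connection is closed when the test completes.

        """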
        eng = getattr(self, "bind", None) or config.db

        with eng.connect() as conn:
            yield conn

    @config.fixture()
    def connection(self):
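        """Provide a :class:`.Connection` inside a transaction that is
        rolled back when the test completes.

        Illustrative use in a hypothetical test::

            def test_insert(self, connection):
                connection.execute(some_table.insert(), {"id": 1})

        ``some_table`` above is a placeholder for a table defined by the
        test.
        """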
        global _connection_fixture_connection

        eng = getattr(self, "bind", None) or config.db

        conn = eng.connect()
        trans = conn.begin()

        _connection_fixture_connection = conn
        yield conn

        _connection_fixture_connection = None

        if trans.is_active:
            trans.rollback()
        # trans would not be active here if the test is still using the
        # legacy @provide_metadata decorator, as that decorator closes
        # all connections.
        conn.close()

    @config.fixture()
    def close_result_when_finished(self):
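        """Provide a callable that registers a result object to be closed,
        and optionally fully consumed, when the test completes.

        Illustrative use in a hypothetical test::

            def test_stream(self, connection, close_result_when_finished):
                result = connection.execute(some_statement)
                close_result_when_finished(result, consume=True)

        ``some_statement`` above is a placeholder.
        """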
        to_close = []
        to_consume = []

        def go(result, consume=False):
            to_close.append(result)
            if consume:
                to_consume.append(result)

        yield go
        for r in to_consume:
            try:
                r.all()
            except Exception:
                pass
        for r in to_close:
            try:
                r.close()
            except Exception:
                pass

    @config.fixture()
    def registry(self, metadata):
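        """Provide an ORM :class:`.registry` associated with the per-test
        ``metadata`` fixture; the registry is disposed when the test
        completes.

        """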
        reg = registry(metadata=metadata)
        yield reg
        reg.dispose()

    @config.fixture
    def decl_base(self, registry):
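        """Provide a declarative base generated from the per-test
        ``registry`` fixture.

        """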
        return registry.generate_base()

    @config.fixture()
    def future_connection(self, future_engine, connection):
        # integrate the future_engine and connection fixtures so
        # that users of the "connection" fixture will receive the
        # "future" connection
        yield connection

    @config.fixture()
    def future_engine(self):
        eng = getattr(self, "bind", None) or config.db
        with _push_future_engine(eng):
            yield

    @config.fixture()
    def testing_engine(self):
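        """Provide a factory for ad-hoc testing engines; engines created
        through it are disposed when the fixture is torn down.

        Illustrative use in a hypothetical test::

            def test_engine_option(self, testing_engine):
                eng = testing_engine(options={"echo": True})
                with eng.connect() as conn:
                    ...

        ``echo`` above is just an example engine option.
        """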
        from . import engines

        def gen_testing_engine(
            url=None,
            options=None,
            future=None,
            asyncio=False,
            transfer_staticpool=False,
        ):
            if options is None:
                options = {}
            options["scope"] = "fixture"
            return engines.testing_engine(
                url=url,
                options=options,
                future=future,
                asyncio=asyncio,
                transfer_staticpool=transfer_staticpool,
            )

        yield gen_testing_engine

        engines.testing_reaper._drop_testing_engines("fixture")

    @config.fixture()
    def async_testing_engine(self, testing_engine):
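        """Variant of the ``testing_engine`` fixture that sets
        ``asyncio=True`` on each engine it creates.

        """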
        def go(**kw):
            kw["asyncio"] = True
            return testing_engine(**kw)

        return go

    @config.fixture
    def fixture_session(self):
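        """Provide a :class:`.Session` against ``config.db``, tracked so
        that it is closed as part of test teardown.

        """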
        return fixture_session()

    @config.fixture()
    def metadata(self, request):
        """Provide bound MetaData for a single test, dropping afterwards."""

        from ..sql import schema

        metadata = schema.MetaData()
        request.instance.metadata = metadata
        yield metadata
        del request.instance.metadata

        if (
            _connection_fixture_connection
            and _connection_fixture_connection.in_transaction()
        ):
            trans = _connection_fixture_connection.get_transaction()
            trans.rollback()
            with _connection_fixture_connection.begin():
                drop_all_tables_from_metadata(
                    metadata, _connection_fixture_connection
                )
        else:
            drop_all_tables_from_metadata(metadata, config.db)

    @config.fixture(
        params=[
            (rollback, second_operation, begin_nested)
            for rollback in (True, False)
            for second_operation in ("none", "execute", "begin")
            for begin_nested in (True, False)
        ]
    )
    def trans_ctx_manager_fixture(self, request, metadata):
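        """Provide a parametrized ``run_test`` callable that exercises
        transaction context-manager behavior for a Core or ORM "subject",
        such as an Engine, Connection, or Session.

        The parameters cover rollback vs. commit, an optional operation
        attempted after the block has closed, and nested (SAVEPOINT)
        transactions.

        """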
        rollback, second_operation, begin_nested = request.param

        from sqlalchemy import Table, Column, Integer, func, select
        from . import eq_

        t = Table("test", metadata, Column("data", Integer))
        eng = getattr(self, "bind", None) or config.db

        t.create(eng)

        def run_test(subject, trans_on_subject, execute_on_subject):
            with subject.begin() as trans:

                if begin_nested:
                    if not config.requirements.savepoints.enabled:
                        config.skip_test("savepoints not enabled")
                    if execute_on_subject:
                        nested_trans = subject.begin_nested()
                    else:
                        nested_trans = trans.begin_nested()

                    with nested_trans:
                        if execute_on_subject:
                            subject.execute(t.insert(), {"data": 10})
                        else:
                            trans.execute(t.insert(), {"data": 10})

                        # for nested trans, we always commit/rollback on the
                        # "nested trans" object itself.
                        # only Session(future=False) will affect savepoint
                        # transaction for session.commit/rollback

                        if rollback:
                            nested_trans.rollback()
                        else:
                            nested_trans.commit()

                        if second_operation != "none":
                            with assertions.expect_raises_message(
                                sa.exc.InvalidRequestError,
                                "Can't operate on closed transaction "
                                "inside context "
                                "manager.  Please complete the context "
                                "manager "
                                "before emitting further commands.",
                            ):
                                if second_operation == "execute":
                                    if execute_on_subject:
                                        subject.execute(
                                            t.insert(), {"data": 12}
                                        )
                                    else:
                                        trans.execute(t.insert(), {"data": 12})
                                elif second_operation == "begin":
                                    if execute_on_subject:
                                        subject.begin_nested()
                                    else:
                                        trans.begin_nested()

                    # outside the nested trans block, but still inside the
                    # transaction block, we can run SQL, and it will be
                    # committed
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 14})
                    else:
                        trans.execute(t.insert(), {"data": 14})

                else:
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 10})
                    else:
                        trans.execute(t.insert(), {"data": 10})

                    if trans_on_subject:
                        if rollback:
                            subject.rollback()
                        else:
                            subject.commit()
                    else:
                        if rollback:
                            trans.rollback()
                        else:
                            trans.commit()

                    if second_operation != "none":
                        with assertions.expect_raises_message(
                            sa.exc.InvalidRequestError,
                            "Can't operate on closed transaction inside "
                            "context "
                            "manager.  Please complete the context manager "
                            "before emitting further commands.",
                        ):
                            if second_operation == "execute":
                                if execute_on_subject:
                                    subject.execute(t.insert(), {"data": 12})
                                else:
                                    trans.execute(t.insert(), {"data": 12})
                            elif second_operation == "begin":
                                if hasattr(trans, "begin"):
                                    trans.begin()
                                else:
                                    subject.begin()
                            elif second_operation == "begin_nested":
                                if execute_on_subject:
                                    subject.begin_nested()
                                else:
                                    trans.begin_nested()

            expected_committed = 0
            if begin_nested:
                # begin_nested variant, we inserted a row after the nested
                # block
                expected_committed += 1
            if not rollback:
                # not rollback variant, our row inserted in the target
                # block itself would be committed
                expected_committed += 1

            if execute_on_subject:
                eq_(
                    subject.scalar(select(func.count()).select_from(t)),
                    expected_committed,
                )
            else:
                with subject.connect() as conn:
                    eq_(
                        conn.scalar(select(func.count()).select_from(t)),
                        expected_committed,
                    )

        return run_test


_connection_fixture_connection = None


@contextlib.contextmanager
def _push_future_engine(engine):
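    """Push a "future"-style facade for the given engine onto the current
    test configuration, popping it when the block exits.

    """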

    from ..future.engine import Engine
    from sqlalchemy import testing

    facade = Engine._future_facade(engine)
    config._current.push_engine(facade, testing)

    yield facade

    config._current.pop(testing)


class FutureEngineMixin(object):
    @config.fixture(autouse=True, scope="class")
    def _push_future_engine(self):
        eng = getattr(self, "bind", None) or config.db
        with _push_future_engine(eng):
            yield


class TablesTest(TestBase):
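    """A test base that creates tables for the test class and optionally
    loads fixture rows, driven by the ``run_*`` class-level flags below.

    An illustrative (hypothetical) subclass, assuming the usual
    ``Table`` / ``Column`` imports::

        class FooTest(TablesTest):
            @classmethod
            def define_tables(cls, metadata):
                Table(
                    "foo",
                    metadata,
                    Column("id", Integer, primary_key=True),
                )

            @classmethod
            def insert_data(cls, connection):
                connection.execute(cls.tables.foo.insert(), {"id": 1})

    """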

    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

        yield

        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        if cls.run_define_tables == "each":
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other = adict()
        cls.tables = adict()
        cls.sequences = adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == "once":
            cls._load_fixtures()
            with cls.bind.begin() as conn:
                cls.insert_data(conn)

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == "once":
            cls.define_tables(cls._tables_metadata)
            if cls.run_create_tables == "once":
                cls._tables_metadata.create_all(cls.bind)
            cls.tables.update(cls._tables_metadata.tables)
            cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        if self.run_define_tables == "each":
            self.define_tables(self._tables_metadata)
            if self.run_create_tables == "each":
                self._tables_metadata.create_all(self.bind)
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)
        elif self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == "each":
            self._load_fixtures()
            with self.bind.begin() as conn:
                self.insert_data(conn)

    def _teardown_each_tables(self):
        if self.run_define_tables == "each":
            self.tables.clear()
            if self.run_create_tables == "each":
                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
            self._tables_metadata.clear()
        elif self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)

        savepoints = getattr(config.requirements, "savepoints", False)
        if savepoints:
            savepoints = savepoints.enabled

        # no need to run deletes if tables are recreated on setup
        if (
            self.run_define_tables != "each"
            and self.run_create_tables != "each"
            and self.run_deletes == "each"
        ):
            with self.bind.begin() as conn:
                for table in reversed(
                    [
                        t
                        for (t, fks) in sort_tables_and_constraints(
                            self._tables_metadata.tables.values()
                        )
                        if t is not None
                    ]
                ):
                    try:
                        if savepoints:
                            with conn.begin_nested():
                                conn.execute(table.delete())
                        else:
                            conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        util.print_(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr,
                        )

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)

        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)

        cls._tables_metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        if hasattr(bind, "dispose"):
            bind.dispose()
        elif hasattr(bind, "close"):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        pass

    @classmethod
    def fixtures(cls):
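        """Return a dictionary of fixture rows to be inserted by
        ``_load_fixtures()``.

        Keys are table names or Table objects; values are lists where the
        first element is a tuple of column names and the remaining
        elements are tuples of row values, e.g.::

            return {
                "foo": [
                    ("id", "data"),
                    (1, "d1"),
                    (2, "d2"),
                ]
            }

        """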
        return {}

    @classmethod
    def insert_data(cls, connection):
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            if isinstance(table, util.string_types):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )


class NoCache(object):
    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache


class RemovesEvents(object):
    @util.memoized_property
    def _event_fns(self):
        return set()

    def event_listen(self, target, name, fn, **kw):
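        """Establish an event listener that will be removed automatically
        when the test completes.

        """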
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)


_fixture_sessions = set()


def fixture_session(**kw):
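    """Create a :class:`.Session` bound to ``config.db`` (or the given
    ``bind``), tracked so that it can be closed en masse during test
    teardown.

    """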
    kw.setdefault("autoflush", True)
    kw.setdefault("expire_on_commit", True)

    bind = kw.pop("bind", config.db)

    sess = sa.orm.Session(bind, **kw)
    _fixture_sessions.add(sess)
    return sess


def _close_all_sessions():
    # will close all still-referenced sessions
    sa.orm.session.close_all_sessions()
    _fixture_sessions.clear()


def stop_test_class_inside_fixtures(cls):
    _close_all_sessions()
    sa.orm.clear_mappers()


def after_test():
    if _fixture_sessions:
        _close_all_sessions()


class ORMTest(TestBase):
    pass


class MappedTest(TablesTest, assertions.AssertsExecutionResults):
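    """A ``TablesTest`` variant that additionally manages mapped classes,
    collected into the ``classes`` registry via the ``setup_classes()``
    and ``setup_mappers()`` hooks.

    """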
    # 'once', 'each', None
    run_setup_classes = "once"

    # 'once', 'each', None
    run_setup_mappers = "each"

    classes = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        if cls.classes is None:
            cls.classes = adict()

        cls._setup_once_tables()
        cls._setup_once_classes()
        cls._setup_once_mappers()
        cls._setup_once_inserts()

        yield

        cls._teardown_once_class()
        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_classes()
        self._setup_each_mappers()
        self._setup_each_inserts()

        yield

        sa.orm.session.close_all_sessions()
        self._teardown_each_mappers()
        self._teardown_each_classes()
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_class(cls):
        cls.classes.clear()

    @classmethod
    def _setup_once_classes(cls):
        if cls.run_setup_classes == "once":
            cls._with_register_classes(cls.setup_classes)

    @classmethod
    def _setup_once_mappers(cls):
        if cls.run_setup_mappers == "once":
            cls.mapper_registry, cls.mapper = cls._generate_registry()
            cls._with_register_classes(cls.setup_mappers)

    def _setup_each_mappers(self):
        if self.run_setup_mappers != "once":
            (
                self.__class__.mapper_registry,
                self.__class__.mapper,
            ) = self._generate_registry()

        if self.run_setup_mappers == "each":
            self._with_register_classes(self.setup_mappers)

    def _setup_each_classes(self):
        if self.run_setup_classes == "each":
            self._with_register_classes(self.setup_classes)

    @classmethod
    def _generate_registry(cls):
        decl = registry(metadata=cls._tables_metadata)
        return decl, decl.map_imperatively

    @classmethod
    def _with_register_classes(cls, fn):
        """Run a setup method, framing the operation with a Base class
        that will catch new subclasses to be established within
        the "classes" registry.

        """
        cls_registry = cls.classes

        assert cls_registry is not None

        class FindFixture(type):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                type.__init__(cls, classname, bases, dict_)

        class _Base(util.with_metaclass(FindFixture, object)):
            pass

        class Basic(BasicEntity, _Base):
            pass

        class Comparable(ComparableEntity, _Base):
            pass

        cls.Basic = Basic
        cls.Comparable = Comparable
        fn()

    def _teardown_each_mappers(self):
        # some tests create mappers in the test bodies and define
        # setup_mappers as None; clear mappers in any case
        if self.run_setup_mappers != "once":
            sa.orm.clear_mappers()

    def _teardown_each_classes(self):
        if self.run_setup_classes != "once":
            self.classes.clear()

    @classmethod
    def setup_classes(cls):
        pass

    @classmethod
    def setup_mappers(cls):
        pass


class DeclarativeMappedTest(MappedTest):
    run_setup_classes = "once"
    run_setup_mappers = "once"

    @classmethod
    def _setup_once_tables(cls):
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes

        class FindFixtureDeclarative(DeclarativeMeta):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                DeclarativeMeta.__init__(cls, classname, bases, dict_)

        class DeclarativeBasic(object):
            __table_cls__ = schema.Table

        _DeclBase = declarative_base(
            metadata=cls._tables_metadata,
            metaclass=FindFixtureDeclarative,
            cls=DeclarativeBasic,
        )

        cls.DeclarativeBasic = _DeclBase

        # sets up cls.Basic which is helpful for things like composite
        # classes
        super(DeclarativeMappedTest, cls)._with_register_classes(fn)

        if cls._tables_metadata.tables and cls.run_create_tables:
            cls._tables_metadata.create_all(config.db)


class ComputedReflectionFixtureTest(TablesTest):
    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        return self.regexp.sub("", text).lower()

    @classmethod
    def define_tables(cls, metadata):
        from .. import Integer
        from .. import testing
        from ..schema import Column
        from ..schema import Computed
        from ..schema import Table

        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        t = Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
        )

        if testing.requires.schemas.enabled:
            t2 = Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                schema=config.test_schema,
            )

        if testing.requires.computed_columns_virtual.enabled:
            t.append_column(
                Column(
                    "computed_virtual",
                    Integer,
                    Computed("normal + 2", persisted=False),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_virtual",
                        Integer,
                        Computed("normal / 2", persisted=False),
                    )
                )
        if testing.requires.computed_columns_stored.enabled:
            t.append_column(
                Column(
                    "computed_stored",
                    Integer,
                    Computed("normal - 42", persisted=True),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_stored",
                        Integer,
                        Computed("normal * 42", persisted=True),
                    )
                )


class CacheKeyFixture(object):
    def _compare_equal(self, a, b, compare_values):
        a_key = a._generate_cache_key()
        b_key = b._generate_cache_key()

        if a_key is None:
            assert a._annotations.get("nocache")

            assert b_key is None
        else:

            eq_(a_key.key, b_key.key)
            eq_(hash(a_key.key), hash(b_key.key))

            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
                assert a_param.compare(b_param, compare_values=compare_values)
        return a_key, b_key

    def _run_cache_key_fixture(self, fixture, compare_values):
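        """Assert cache-key behavior for the constructs returned by
        ``fixture()``.

        The fixture is invoked twice; elements at the same index must
        generate equal cache keys (or no key at all, when annotated
        "nocache"), while elements at different indexes are expected to
        generate unequal keys unless they differ only in bound parameter
        values.

        """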
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            if a == b:
                a_key, b_key = self._compare_equal(
                    case_a[a], case_b[b], compare_values
                )
                if a_key is None:
                    continue
            else:
                a_key = case_a[a]._generate_cache_key()
                b_key = case_b[b]._generate_cache_key()

                if a_key is None or b_key is None:
                    if a_key is None:
                        assert case_a[a]._annotations.get("nocache")
                    if b_key is None:
                        assert case_b[b]._annotations.get("nocache")
                    continue

                if a_key.key == b_key.key:
                    for a_param, b_param in zip(
                        a_key.bindparams, b_key.bindparams
                    ):
                        if not a_param.compare(
                            b_param, compare_values=compare_values
                        ):
                            break
                    else:
                        # this fails unconditionally since we could not
                        # find bound parameter values that differed.
                        # Usually we intended to get two distinct keys here
                        # so the failure will be more descriptive using the
                        # ne_() assertion.
                        ne_(a_key.key, b_key.key)
                else:
                    ne_(a_key.key, b_key.key)

            # ClauseElement-specific test to ensure the cache key
            # collected all the bound parameters that aren't marked
            # as "literal execute"
            if isinstance(case_a[a], ClauseElement) and isinstance(
                case_b[b], ClauseElement
            ):
                assert_a_params = []
                assert_b_params = []

                for elem in visitors.iterate(case_a[a]):
                    if elem.__visit_name__ == "bindparam":
                        assert_a_params.append(elem)

                for elem in visitors.iterate(case_b[b]):
                    if elem.__visit_name__ == "bindparam":
                        assert_b_params.append(elem)

                # note we're asserting the order of the params as well as
                # if there are dupes or not.  ordering has to be
                # deterministic and matches what a traversal would provide.
                eq_(
                    sorted(a_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_a_params), key=lambda b: b.key
                    ),
                )
                eq_(
                    sorted(b_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_b_params), key=lambda b: b.key
                    ),
                )

    def _run_cache_key_equal_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            self._compare_equal(case_a[a], case_b[b], compare_values)
