package com.github.dbunit.junit5;

import com.github.dbunit.rules.api.connection.ConnectionHolder;
import com.github.dbunit.rules.api.dataset.DataSet;
import com.github.dbunit.rules.api.dataset.DataSetExecutor;
import com.github.dbunit.rules.api.dataset.ExpectedDataSet;
import com.github.dbunit.rules.api.leak.LeakHunter;
import com.github.dbunit.rules.configuration.DBUnitConfig;
import com.github.dbunit.rules.configuration.DataSetConfig;
import com.github.dbunit.rules.dataset.DataSetExecutorImpl;
import com.github.dbunit.rules.leak.LeakHunterException;
import com.github.dbunit.rules.leak.LeakHunterFactory;
import com.github.dbunit.rules.util.EntityManagerProvider;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Optional;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestExtensionContext;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/dbunit/junit5/DBUnitExtension.class */
public class DBUnitExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {
    private static final String EXECUTOR_STORE = "executor";
    private static final String DATASET_CONFIG_STORE = "datasetConfig";
    private static final String LEAK_STORE = "leakHunter";
    private static final String CONNECTION_BEFORE_STORE = "openConnectionsBefore";

    public void beforeTestExecution(TestExtensionContext testExtensionContext) throws Exception {
        if (shouldCreateDataSet(testExtensionContext)) {
            ConnectionHolder findTestConnection = findTestConnection(testExtensionContext);
            if (EntityManagerProvider.isEntityManagerActive()) {
                EntityManagerProvider.em().clear();
            }
            DataSet annotation = ((Method) testExtensionContext.getTestMethod().get()).getAnnotation(DataSet.class);
            if (annotation == null) {
                annotation = (DataSet) ((Class) testExtensionContext.getTestClass().get()).getAnnotation(DataSet.class);
            }
            if (annotation == null) {
                throw new RuntimeException("Could not find DataSet annotation for test " + ((Method) testExtensionContext.getTestMethod().get()).getName());
            }
            DBUnitConfig from = DBUnitConfig.from((Method) testExtensionContext.getTestMethod().get());
            DataSetConfig from2 = new DataSetConfig().from(annotation);
            DataSetExecutorImpl instance = DataSetExecutorImpl.instance(from2.getExecutorId(), findTestConnection);
            instance.setDBUnitConfig(from);
            ExtensionContext.Namespace executorNamespace = getExecutorNamespace(testExtensionContext);
            testExtensionContext.getStore(executorNamespace).put(EXECUTOR_STORE, instance);
            testExtensionContext.getStore(executorNamespace).put(DATASET_CONFIG_STORE, from2);
            if (from.isLeakHunter()) {
                LeakHunter from3 = LeakHunterFactory.from(findTestConnection.getConnection());
                testExtensionContext.getStore(executorNamespace).put(LEAK_STORE, from3);
                testExtensionContext.getStore(executorNamespace).put(CONNECTION_BEFORE_STORE, Integer.valueOf(from3.openConnections()));
            }
            try {
                instance.createDataSet(from2);
                if (from2.isTransactional() && EntityManagerProvider.isEntityManagerActive()) {
                    EntityManagerProvider.em().getTransaction().begin();
                }
            } catch (Exception e) {
                throw new RuntimeException(String.format("Could not create dataset for test method %s due to following error " + e.getMessage(), ((Method) testExtensionContext.getTestMethod().get()).getName()), e);
            }
        }
    }

    private boolean shouldCreateDataSet(TestExtensionContext testExtensionContext) {
        return ((Method) testExtensionContext.getTestMethod().get()).isAnnotationPresent(DataSet.class) || ((Class) testExtensionContext.getTestClass().get()).isAnnotationPresent(DataSet.class);
    }

    private boolean shouldCompareDataSet(TestExtensionContext testExtensionContext) {
        return ((Method) testExtensionContext.getTestMethod().get()).isAnnotationPresent(ExpectedDataSet.class) || ((Class) testExtensionContext.getTestClass().get()).isAnnotationPresent(ExpectedDataSet.class);
    }

    public void afterTestExecution(TestExtensionContext testExtensionContext) throws Exception {
        DBUnitConfig from = DBUnitConfig.from((Method) testExtensionContext.getTestMethod().get());
        ExtensionContext.Namespace executorNamespace = getExecutorNamespace(testExtensionContext);
        if (from != null) {
            try {
                if (from.isLeakHunter()) {
                    LeakHunter leakHunter = (LeakHunter) testExtensionContext.getStore(executorNamespace).get(LEAK_STORE, LeakHunter.class);
                    int intValue = ((Integer) testExtensionContext.getStore(executorNamespace).get(CONNECTION_BEFORE_STORE, Integer.class)).intValue();
                    int openConnections = leakHunter.openConnections();
                    if (openConnections > intValue) {
                        throw new LeakHunterException(((Method) testExtensionContext.getTestMethod().get()).getName(), openConnections - intValue);
                    }
                }
            } catch (Throwable th) {
                DataSetConfig dataSetConfig = (DataSetConfig) testExtensionContext.getStore(executorNamespace).get(DATASET_CONFIG_STORE, DataSetConfig.class);
                if (dataSetConfig == null) {
                    return;
                }
                DataSetExecutor dataSetExecutor = (DataSetExecutor) testExtensionContext.getStore(executorNamespace).get(EXECUTOR_STORE, DataSetExecutor.class);
                if (dataSetConfig.getExecuteStatementsAfter() != null && dataSetConfig.getExecuteStatementsAfter().length > 0) {
                    for (int i = 0; i < dataSetConfig.getExecuteScriptsAfter().length; i++) {
                        try {
                            dataSetExecutor.executeScript(dataSetConfig.getExecuteScriptsAfter()[i]);
                        } catch (Exception e) {
                            if (e instanceof DatabaseUnitException) {
                                throw e;
                            }
                            LoggerFactory.getLogger(getClass().getName()).error(((Method) testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsAfter:" + e.getMessage(), e);
                        }
                    }
                }
                if (dataSetConfig.getExecuteScriptsAfter() != null && dataSetConfig.getExecuteScriptsAfter().length > 0) {
                    for (int i2 = 0; i2 < dataSetConfig.getExecuteScriptsAfter().length; i2++) {
                        try {
                            dataSetExecutor.executeScript(dataSetConfig.getExecuteScriptsAfter()[i2]);
                        } catch (Exception e2) {
                            if (e2 instanceof DatabaseUnitException) {
                                throw e2;
                            }
                            LoggerFactory.getLogger(getClass().getName()).error(((Method) testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsAfter:" + e2.getMessage(), e2);
                        }
                    }
                }
                if (dataSetConfig.isCleanAfter()) {
                    dataSetExecutor.clearDatabase(dataSetConfig);
                }
                throw th;
            }
        }
        if (shouldCompareDataSet(testExtensionContext)) {
            ExpectedDataSet annotation = ((Method) testExtensionContext.getTestMethod().get()).getAnnotation(ExpectedDataSet.class);
            if (annotation == null) {
                annotation = (ExpectedDataSet) ((Class) testExtensionContext.getTestClass().get()).getAnnotation(ExpectedDataSet.class);
            }
            if (annotation != null) {
                ExtensionContext.Namespace executorNamespace2 = getExecutorNamespace(testExtensionContext);
                DataSetExecutor dataSetExecutor2 = (DataSetExecutor) testExtensionContext.getStore(executorNamespace2).get(EXECUTOR_STORE, DataSetExecutor.class);
                if (((DataSetConfig) testExtensionContext.getStore(executorNamespace2).get(DATASET_CONFIG_STORE, DataSetConfig.class)).isTransactional() && EntityManagerProvider.isEntityManagerActive()) {
                    EntityManagerProvider.em().getTransaction().commit();
                }
                dataSetExecutor2.compareCurrentDataSetWith(new DataSetConfig(annotation.value()).disableConstraints(true), annotation.ignoreCols());
            }
        }
        DataSetConfig dataSetConfig2 = (DataSetConfig) testExtensionContext.getStore(executorNamespace).get(DATASET_CONFIG_STORE, DataSetConfig.class);
        if (dataSetConfig2 == null) {
            return;
        }
        DataSetExecutor dataSetExecutor3 = (DataSetExecutor) testExtensionContext.getStore(executorNamespace).get(EXECUTOR_STORE, DataSetExecutor.class);
        if (dataSetConfig2.getExecuteStatementsAfter() != null && dataSetConfig2.getExecuteStatementsAfter().length > 0) {
            for (int i3 = 0; i3 < dataSetConfig2.getExecuteScriptsAfter().length; i3++) {
                try {
                    dataSetExecutor3.executeScript(dataSetConfig2.getExecuteScriptsAfter()[i3]);
                } catch (Exception e3) {
                    if (e3 instanceof DatabaseUnitException) {
                        throw e3;
                    }
                    LoggerFactory.getLogger(getClass().getName()).error(((Method) testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsAfter:" + e3.getMessage(), e3);
                }
            }
        }
        if (dataSetConfig2.getExecuteScriptsAfter() != null && dataSetConfig2.getExecuteScriptsAfter().length > 0) {
            for (int i4 = 0; i4 < dataSetConfig2.getExecuteScriptsAfter().length; i4++) {
                try {
                    dataSetExecutor3.executeScript(dataSetConfig2.getExecuteScriptsAfter()[i4]);
                } catch (Exception e4) {
                    if (e4 instanceof DatabaseUnitException) {
                        throw e4;
                    }
                    LoggerFactory.getLogger(getClass().getName()).error(((Method) testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsAfter:" + e4.getMessage(), e4);
                }
            }
        }
        if (dataSetConfig2.isCleanAfter()) {
            dataSetExecutor3.clearDatabase(dataSetConfig2);
        }
    }

    private ExtensionContext.Namespace getExecutorNamespace(TestExtensionContext testExtensionContext) {
        return ExtensionContext.Namespace.create(new Object[]{"DBUnitExtension-" + testExtensionContext.getTestClass().get()});
    }

    private ConnectionHolder findTestConnection(TestExtensionContext testExtensionContext) {
        Class cls = (Class) testExtensionContext.getTestClass().get();
        try {
            Optional findFirst = Arrays.stream(cls.getDeclaredFields()).filter(field -> {
                return field.getType() == ConnectionHolder.class;
            }).findFirst();
            if (findFirst.isPresent()) {
                Field field2 = (Field) findFirst.get();
                if (!field2.isAccessible()) {
                    field2.setAccessible(true);
                }
                ConnectionHolder connectionHolder = (ConnectionHolder) ConnectionHolder.class.cast(field2.get(testExtensionContext.getTestInstance()));
                if (connectionHolder == null || connectionHolder.getConnection() == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
            Optional findFirst2 = Arrays.stream(cls.getDeclaredMethods()).filter(method -> {
                return method.getReturnType() == ConnectionHolder.class;
            }).findFirst();
            if (!findFirst2.isPresent()) {
                return null;
            }
            Method method2 = (Method) findFirst2.get();
            if (!method2.isAccessible()) {
                method2.setAccessible(true);
            }
            ConnectionHolder connectionHolder2 = (ConnectionHolder) ConnectionHolder.class.cast(method2.invoke(testExtensionContext.getTestInstance(), new Object[0]));
            if (connectionHolder2 == null || connectionHolder2.getConnection() == null) {
                throw new RuntimeException("ConnectionHolder not initialized correctly");
            }
            return connectionHolder2;
        } catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + cls, e);
        }
    }
}
