package org.unitils.jbehave.modules;

import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.openqa.selenium.WebDriver;
import org.openqa.selenium.WebDriverException;
import org.unitils.selenium.BrowserChoice;
import org.unitils.selenium.WebDriverFactory;
import org.unitils.selenium.WebDriverModule;
import org.unitils.selenium.annotation.TestWebDriver;
import org.unitils.selenium.annotation.WebPage;
import org.unitils.util.AnnotationUtils;
import org.unitils.util.ReflectionUtils;


/**
 * JBehave flavour of the {@link WebDriverModule}: all steps of a scenario share the same {@link WebDriver},
 * and that webdriver is closed after each scenario.
 *
 * @author Willemijn Wouters
 *
 * @since 1.0.0
 *
 */
public class JBehaveWebdriverModule extends WebDriverModule {

    private static final Log LOGGER = LogFactory.getLog(JBehaveWebdriverModule.class);

    /** The driver shared by the steps; created lazily in {@link #initWebDriver(Object)}, cleared in {@link #killWebDriver(Object)}. */
    private WebDriver webdriver;

    /** Page objects already created, keyed by the type of the {@link WebPage} field; reset when the driver is killed. */
    private Map<Class<?>, Object> map;

    /**
     * Initialise the module.
     *
     * @param configuration : The {@link org.unitils.core.Unitils} configuration.
     * @see org.unitils.selenium.WebDriverModule#init(java.util.Properties)
     */
    @Override
    public void init(Properties configuration) {
        super.init(configuration);
        if (map == null) {
            map = new HashMap<Class<?>, Object>();
        }
    }

    /**
     * Injects the shared webdriver into every {@link TestWebDriver} field of the step.
     * The driver is only created when the step actually has such a field.
     *
     * @param testObject : the step.
     * @see org.unitils.selenium.WebDriverModule#initWebDriver(java.lang.Object)
     */
    @Override
    public void initWebDriver(Object testObject) {
        Set<Field> fields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), TestWebDriver.class);
        if (fields.isEmpty()) {
            return;
        }
        if (webdriver == null) {
            webdriver = createWebdriver();
        }
        for (Field field : fields) {
            ReflectionUtils.setFieldValue(testObject, field, webdriver);
        }
    }

    /**
     * Initialises all the {@link WebPage} fields in the step. Page objects are cached per field type, so
     * different steps get the same page instance until the driver is killed.
     *
     * @param testObject : the step.
     * @see org.unitils.selenium.WebDriverModule#initElements(java.lang.Object)
     */
    @Override
    public void initElements(Object testObject) {
        Set<Field> pageFields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), WebPage.class);
        Set<Field> driverFields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), TestWebDriver.class);
        if (pageFields.isEmpty() || driverFields.isEmpty()) {
            return;
        }

        WebDriver stepDriver = ReflectionUtils.getFieldValue(testObject, driverFields.iterator().next());
        if (stepDriver == null) {
            // No driver injected: do not build (and cache) page objects bound to a null driver.
            return;
        }
        if (map == null) {
            map = new HashMap<Class<?>, Object>();
        }

        for (Field field : pageFields) {
            Class<?> pageType = field.getType();
            Object page;
            if (map.containsKey(pageType)) {
                page = map.get(pageType);
            } else {
                page = getElement(stepDriver, pageType);
                map.put(pageType, page);
            }
            ReflectionUtils.setFieldValue(testObject, field, page);
        }
    }

    /**
     * Kills the webdriver and clears the page cache.
     * The field is cleared before shutting down, so a failing close/quit never leaves a dead driver behind
     * for the next scenario.
     *
     * @param testObject : the step
     * @see org.unitils.selenium.WebDriverModule#killWebDriver(java.lang.Object)
     */
    @Override
    protected void killWebDriver(Object testObject) {
        map = new HashMap<Class<?>, Object>();
        if (webdriver == null) {
            return;
        }
        WebDriver driver = webdriver;
        webdriver = null;
        try {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("closing a driver that is on page : " + currentUrlOrUnknown(driver));
            }
            try {
                driver.close();
            } finally {
                // quit() must run even when close() fails, otherwise the browser process leaks.
                driver.quit();
            }
        } finally {
            nastyDoubleCheck(driver);
            nastyDoubleCheck(driver);
        }
    }

    /**
     * Returns the current url of the driver, or a placeholder when the driver can no longer be queried.
     * This is used only for logging and must never prevent the driver from being shut down.
     */
    private static String currentUrlOrUnknown(WebDriver driver) {
        try {
            return driver.getCurrentUrl();
        } catch (WebDriverException e) {
            return "<unknown: " + e.getMessage() + ">";
        }
    }

    /**
     * Checks if the driver is killed: after a short pause, it tries to talk to the driver and closes/quits it
     * again if it still responds. A {@link WebDriverException} (the driver is already gone) is ignored.
     *
     * @param driver : the {@link WebDriver} that should be killed.
     */
    protected void nastyDoubleCheck(WebDriver driver) {
        try {
            Thread.sleep(500);
            driver.getTitle();
            driver.close();
            driver.quit();
        } catch (WebDriverException ignored) {
            // expected when the driver has already been shut down
        } catch (InterruptedException e) {
            // preserve the interrupt status for the caller instead of swallowing it
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Creates the correct {@link WebDriver} from the module configuration. If a proxy url is configured,
     * the proxy-aware factory method is used. All cookies are deleted on the new driver.
     *
     * @return {@link WebDriver}
     */
    protected WebDriver createWebdriver() {
        BrowserChoice browserChoice = getBrowserChoice();
        String absoluteDownloadPath = getAbsoluteDownloadPath(getDownloadPath());
        String fileType = getFileType();
        String proxyUrl = getProxyUrl();

        WebDriver driver;
        if (StringUtils.isEmpty(proxyUrl)) {
            driver = WebDriverFactory.createDriver(browserChoice, absoluteDownloadPath, fileType);
        } else {
            driver = WebDriverFactory.createDriver(browserChoice, proxyUrl, absoluteDownloadPath, fileType);
        }
        driver.manage().deleteAllCookies();
        return driver;
    }

    /**
     * getter for the proxyUrl.
     *
     * @return {@link String}
     */
    protected String getProxyUrl() {
        return (String) getFieldValue("proxyUrl");
    }

    /**
     * getter for the downloadPath.
     *
     * @return {@link String}
     */
    protected String getDownloadPath() {
        return (String) getFieldValue("downloadPath");
    }

    /**
     * getter for the browserChoice.
     *
     * @return {@link BrowserChoice}
     */
    protected BrowserChoice getBrowserChoice() {
        return (BrowserChoice) getFieldValue("browserChoice");
    }

    /**
     * Reads a specific (non-static) field declared in the {@link WebDriverModule} via reflection.
     *
     * @param nameField : the name of the field of the {@link WebDriverModule}.
     * @return {@link Object} the value of that field on this module.
     * @throws IllegalStateException if {@link WebDriverModule} has no such field.
     */
    public Object getFieldValue(String nameField) {
        Field field = ReflectionUtils.getFieldWithName(WebDriverModule.class, nameField, false);
        if (field == null) {
            throw new IllegalStateException("No field '" + nameField + "' found in " + WebDriverModule.class.getName());
        }
        return ReflectionUtils.getFieldValue(this, field);
    }

    /**
     * getter for the fileType.
     *
     * @return {@link String}
     */
    public String getFileType() {
        return (String) getFieldValue("fileType");
    }

    /**
     * getter for the webdriver.
     *
     * @return {@link WebDriver}, or {@code null} when no driver has been created yet or it has been killed.
     */
    protected WebDriver getWebdriver() {
        return webdriver;
    }
}
