package org.unitils.jbehave.modules;

import java.lang.reflect.Field;
import java.util.Set;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.openqa.selenium.WebDriver;
import org.openqa.selenium.WebDriverException;
import org.unitils.selenium.BrowserChoice;
import org.unitils.selenium.WebDriverFactory;
import org.unitils.selenium.WebDriverModule;
import org.unitils.selenium.annotation.TestWebDriver;
import org.unitils.selenium.annotation.WebPage;
import org.unitils.util.AnnotationUtils;
import org.unitils.util.ReflectionUtils;


/**
 * JBehave flavour of the {@link WebDriverModule}: the different steps of a scenario use the same {@link WebDriver},
 * which is closed again after each scenario.
 * 
 * @author Willemijn Wouters
 * 
 * @since 1.0.0
 * 
 */
public class JBehaveWebdriverModule extends WebDriverModule {
    private static final Log LOGGER = LogFactory.getLog(JBehaveWebdriverModule.class);

    /** Pause before each extra attempt to shut down a driver that should already be gone. */
    private static final long DOUBLE_CHECK_DELAY_MILLIS = 500L;

    /** Driver shared between steps; created lazily and reset to {@code null} by {@link #killWebDriver(Object)}. */
    private WebDriver webdriver;

    /**
     * Injects the shared {@link WebDriver} into every {@link TestWebDriver}-annotated field of the test object,
     * creating the driver first if none is cached yet. Does nothing when the test object has no such field.
     *
     * @param testObject the test/steps instance to inject into
     * @see org.unitils.selenium.WebDriverModule#initWebDriver(java.lang.Object)
     */
    @Override
    public void initWebDriver(Object testObject) {
        Set<Field> driverFields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), TestWebDriver.class);
        if (driverFields.isEmpty()) {
            return;
        }
        if (webdriver == null) {
            webdriver = createWebdriver();
        }
        for (Field driverField : driverFields) {
            ReflectionUtils.setFieldValue(testObject, driverField, webdriver);
        }
    }

    /**
     * Fills every {@link WebPage}-annotated field of the test object with a page built from the driver found in
     * one of its {@link TestWebDriver}-annotated fields. Does nothing when there is no such field or when that
     * field is still {@code null}.
     *
     * @param testObject the test/steps instance whose page fields are initialised
     * @see org.unitils.selenium.WebDriverModule#initElements(java.lang.Object)
     */
    @Override
    public void initElements(Object testObject) {
        Set<Field> driverFields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), TestWebDriver.class);
        if (driverFields.isEmpty()) {
            return;
        }
        // Only the first field returned by the set's iterator is used as the driver source.
        WebDriver driver = ReflectionUtils.getFieldValue(testObject, driverFields.iterator().next());
        if (driver == null) {
            return;
        }
        Set<Field> pageFields = AnnotationUtils.getFieldsAnnotatedWith(testObject.getClass(), WebPage.class);
        for (Field pageField : pageFields) {
            ReflectionUtils.setFieldValue(testObject, pageField, getElement(driver, pageField.getType()));
        }
    }

    /**
     * Closes and quits the cached driver, if any, and forgets it so the next call to
     * {@link #initWebDriver(Object)} creates a fresh one. The field is cleared even when shutting the driver
     * down fails; the failure itself still propagates.
     *
     * @param testObject unused here
     * @see org.unitils.selenium.WebDriverModule#killWebDriver(java.lang.Object)
     */
    @Override
    protected void killWebDriver(Object testObject) {
        WebDriver driver = webdriver;
        if (driver == null) {
            return;
        }
        try {
            if (LOGGER.isDebugEnabled()) {
                // getCurrentUrl() is a call to the browser, so only make it when the message is actually logged.
                LOGGER.debug("closing a driver that is on page : " + driver.getCurrentUrl());
            }
            try {
                driver.close();
            } finally {
                // Quit even if closing the window failed, so the browser process is not left behind.
                driver.quit();
            }
            nastyDoubleCheck(driver);
            nastyDoubleCheck(driver);
        } finally {
            // Never keep a (possibly dead) driver cached for the next scenario.
            webdriver = null;
        }
    }

    /**
     * Best-effort extra shutdown of a driver that should already have been quit. Waits briefly, probes the driver
     * and, if it still answers, closes and quits it again. A {@link WebDriverException} (the expected outcome
     * for a driver that is really gone) is ignored. If the thread is interrupted, the interrupt status is restored
     * and the check is abandoned.
     *
     * @param driver the driver to double check
     */
    protected void nastyDoubleCheck(WebDriver driver) {
        try {
            Thread.sleep(DOUBLE_CHECK_DELAY_MILLIS);
            driver.getTitle();
            driver.close();
            driver.quit();
        } catch (WebDriverException ignored) {
            // expected: the driver no longer responds, which is what we want
        } catch (InterruptedException e) {
            // preserve the interrupt for callers further up the stack
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Creates a new driver from the module's configured browser choice, download path, file type and (optional)
     * proxy URL, and clears its cookies.
     *
     * @return a freshly created driver with all cookies deleted
     */
    protected WebDriver createWebdriver() {
        BrowserChoice browserChoice = getBrowserChoice();
        String downloadPath = getDownloadPath();
        String fileType = getFileType();
        String proxyUrl = getProxyUrl();

        WebDriver driver;
        if (StringUtils.isEmpty(proxyUrl)) {
            driver = WebDriverFactory.createDriver(browserChoice, getAbsoluteDownloadPath(downloadPath), fileType);
        } else {
            driver = WebDriverFactory.createDriver(browserChoice, proxyUrl, getAbsoluteDownloadPath(downloadPath), fileType);
        }
        driver.manage().deleteAllCookies();
        return driver;
    }

    /**
     * @return the value of the {@code proxyUrl} field declared in {@link WebDriverModule}
     */
    protected String getProxyUrl() {
        return (String) getFieldValue("proxyUrl");
    }

    /**
     * @return the value of the {@code downloadPath} field declared in {@link WebDriverModule}
     */
    protected String getDownloadPath() {
        return (String) getFieldValue("downloadPath");
    }

    /**
     * @return the value of the {@code browserChoice} field declared in {@link WebDriverModule}
     */
    protected BrowserChoice getBrowserChoice() {
        return (BrowserChoice) getFieldValue("browserChoice");
    }

    /**
     * Reads a non-static field declared in {@link WebDriverModule} (or one of its superclasses) reflectively
     * from this instance.
     *
     * @param nameField name of the field to read
     * @return the current value of the field on this instance
     * @throws IllegalStateException if no non-static field with that name exists
     */
    public Object getFieldValue(String nameField) {
        Field field = ReflectionUtils.getFieldWithName(WebDriverModule.class, nameField, false);
        if (field == null) {
            throw new IllegalStateException("No instance field '" + nameField + "' found in "
                + WebDriverModule.class.getName());
        }
        return ReflectionUtils.getFieldValue(this, field);
    }

    /**
     * @return the value of the {@code fileType} field declared in {@link WebDriverModule}
     */
    public String getFileType() {
        return (String) getFieldValue("fileType");
    }

    /**
     * @return the currently cached webdriver, or {@code null} if none has been created since the last kill
     */
    protected WebDriver getWebdriver() {
        return webdriver;
    }
}
