package org.thryft.waf.server.controllers;

import java.io.IOException;
import java.net.InetAddress;
import java.util.HashSet;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.thryft.waf.lib.logging.LoggingUtils;

import com.google.inject.Singleton;

/**
 * Servlet filter that only lets a request through when its remote address is one of this
 * host's own addresses.
 *
 * <p>{@link #init(FilterConfig)} collects the textual IP address of
 * {@link InetAddress#getLocalHost()} plus every address that {@code "localhost"} resolves to.
 * {@link #doFilter(ServletRequest, ServletResponse, FilterChain)} forwards a request only when
 * {@link ServletRequest#getRemoteAddr()} is an exact string match for one of those addresses.
 * Any other request is logged and answered with 403 Forbidden.
 */
@Singleton
public class LocalhostFilter implements Filter {
    @Override
    public void destroy() {
    }

    /**
     * Passes the request down the chain if it comes from a local address. Otherwise the
     * request is rejected with 403 Forbidden.
     *
     * <p>The request and response are cast to their HTTP subtypes on the rejection path, so
     * this filter assumes it is only installed for HTTP traffic.
     *
     * @param request the incoming request; its remote address decides whether it is forwarded
     * @param response the response; receives the 403 error when the request is rejected
     * @param chain the remaining filter chain, invoked only for local requests
     * @throws IOException if forwarding or sending the error fails
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
            throws IOException, ServletException {
        // Read once so the check, the log line and the error message all agree.
        final String remoteAddress = request.getRemoteAddr();
        if (localAddresses.contains(remoteAddress)) {
            chain.doFilter(request, response);
            return;
        }

        logger.warn(logMarker, "denying request for {} from {}", ((HttpServletRequest) request).getRequestURI(),
                remoteAddress);
        // NOTE(review): this message echoes the server's local addresses back to a non-local
        // client; consider omitting them if disclosing internal IPs is a concern.
        ((HttpServletResponse) response).sendError(HttpServletResponse.SC_FORBIDDEN,
                String.format("only localhost (%s) access allowed, not remote address %s",
                        StringUtils.join(localAddresses, ','), remoteAddress));
    }

    /**
     * Populates the set of addresses treated as local.
     *
     * @param filterConfig unused
     * @throws ServletException if the local host or {@code "localhost"} cannot be resolved.
     *         The underlying lookup exception is attached as the cause.
     */
    @Override
    public void init(final FilterConfig filterConfig) throws ServletException {
        try {
            localAddresses.add(InetAddress.getLocalHost().getHostAddress());
            for (final InetAddress inetAddress : InetAddress.getAllByName("localhost")) {
                localAddresses.add(inetAddress.getHostAddress());
            }
        } catch (final IOException e) {
            // Keep the original failure as the cause so the lookup error is diagnosable.
            throw new ServletException("Unable to lookup local addresses", e);
        }
    }

    /** Textual IP addresses accepted as local; filled in by {@link #init(FilterConfig)}. */
    private final Set<String> localAddresses = new HashSet<String>();
    private static final Logger logger = LoggerFactory.getLogger(LocalhostFilter.class);
    private static final Marker logMarker = LoggingUtils.getMarker(LocalhostFilter.class);
}
