
/**
 * Copyright 2011-2012 Clint Combs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.regrest.http

import org.regrest.{RestRequest,RestResponse}
import org.apache.http.{HttpResponse,Header}
import org.apache.http.client.ResponseHandler
import org.apache.http.client.methods.{HttpGet,HttpPost,HttpPut,HttpDelete}
import org.apache.http.entity.StringEntity
import org.apache.http.impl.client.DefaultHttpClient
import com.weiglewilczek.slf4s.Logging
import scala.collection.JavaConversions._
import java.net.URI
import java.io.{BufferedReader,InputStreamReader}

/**
 * Sends a RestRequest with Apache HttpClient and captures the reply as a RestResponse.
 *
 * A fresh DefaultHttpClient is created for every request and shut down once the
 * response has been consumed.
 *
 * @param overrideHost when defined, replaces the host component of every request URL
 *                     handled by `processRequest` / `calculateURI`
 */
class RequestProcessor(val overrideHost:Option[String] = None) extends Logging {

	/**
	 * Dispatch a request to the handler for its HTTP method.
	 *
	 * @param req the request to send
	 * @return the request that was actually used (with the host overridden when
	 *         `overrideHost` is defined) paired with the response; for an unrecognised
	 *         method the error is logged and an empty request/response pair is returned
	 */
	def processRequest(req:RestRequest):(RestRequest,RestResponse) = {
		val resolvedReq = overrideHost match {
			case Some(_) =>
				RestRequest(req.httpVersion, calculateURI(req.url), req.method, req.headers, req.body)
			case None => req
		}

		resolvedReq.method match {
			case "GET" => (resolvedReq, processGet(resolvedReq))
			case "POST" => (resolvedReq, processPost(resolvedReq))
			case "PUT" => (resolvedReq, processPut(resolvedReq))
			case "DELETE" => (resolvedReq, processDelete(resolvedReq))
			case "OPTIONS" => (resolvedReq, processOptions(resolvedReq))
			case "HEAD" => (resolvedReq, processHead(resolvedReq))
			case m =>
				logger.error("unknown HTTP method: " + m)
				(RestRequest(), RestResponse())
		}
	}

	/**
	 * Override the host, if necessary.
	 *
	 * @param url the original request URL; it is parsed as a java.net.URI even when no
	 *            override is configured, so a malformed URL always raises
	 *            URISyntaxException
	 * @return the URL with its host replaced by `overrideHost`, or `url` unchanged when
	 *         no override is configured
	 */
	def calculateURI(url:String):String = {
		val uri = new URI(url)
		overrideHost match {
			case Some(host) =>
				new URI(uri.getScheme, uri.getUserInfo, host, uri.getPort,
						uri.getPath, uri.getQuery, uri.getFragment).toString
			case None => url
		}
	}

	/**
	 * Send a GET request.
	 *
	 * @param r the request; its method must be "GET" (enforced with `require`)
	 * @return the status line, headers and body of the response
	 */
	def processGet(r:RestRequest):RestResponse = {
		require(r.method == "GET")
		logger.info(r.method + " " + r.url)
		execute(withHeaders(new HttpGet(r.url), r))
	}

	/**
	 * Send a POST request, including the request body when one is present.
	 *
	 * @param r the request; its method must be "POST" (enforced with `require`)
	 * @return the status line, headers and body of the response
	 */
	def processPost(r:RestRequest):RestResponse = {
		require(r.method == "POST")
		logger.info(r.method + " " + r.url)
		execute(withBody(withHeaders(new HttpPost(r.url), r), r))
	}

	/**
	 * Send a PUT request, including the request body when one is present.
	 *
	 * @param r the request; its method must be "PUT" (enforced with `require`)
	 * @return the status line, headers and body of the response
	 */
	def processPut(r:RestRequest):RestResponse = {
		require(r.method == "PUT")
		logger.info(r.method + " " + r.url)
		execute(withBody(withHeaders(new HttpPut(r.url), r), r))
	}

	/**
	 * Send a DELETE request.
	 *
	 * @param r the request; its method must be "DELETE" (enforced with `require`)
	 * @return the status line, headers and body of the response
	 */
	def processDelete(r:RestRequest):RestResponse = {
		require(r.method == "DELETE")
		logger.info(r.method + " " + r.url)
		execute(withHeaders(new HttpDelete(r.url), r))
	}

	/**
	 * Placeholder for OPTIONS: the request is logged but nothing is sent.
	 *
	 * @param r the request; its method must be "OPTIONS" (enforced with `require`)
	 * @return an empty RestResponse
	 */
	def processOptions(r:RestRequest):RestResponse = {
		require(r.method == "OPTIONS")
		logger.info(r.method + " " + r.url)
		RestResponse()
	}

	/**
	 * Placeholder for HEAD: the request is logged but nothing is sent.
	 *
	 * @param r the request; its method must be "HEAD" (enforced with `require`)
	 * @return an empty RestResponse
	 */
	def processHead(r:RestRequest):RestResponse = {
		require(r.method == "HEAD")
		logger.info(r.method + " " + r.url)
		RestResponse()
	}

	/** Copy every header of `r` onto `request` and hand the request back. */
	private def withHeaders[T <: org.apache.http.client.methods.HttpUriRequest](request:T, r:RestRequest):T = {
		r.headers.foreach { header => request.setHeader(header._1, header._2) }
		request
	}

	/**
	 * Attach the body of `r`, if any, to `request` and hand the request back.
	 * Must run after `withHeaders` so the body's Content-Type replaces any
	 * Content-Type given in the plain headers.
	 */
	private def withBody[T <: org.apache.http.client.methods.HttpEntityEnclosingRequestBase](request:T, r:RestRequest):T = {
		r.body.foreach { b =>
			if (b.contentType.length > 0) {
				// The else branch matters: an `if` without `else` evaluates to Unit, which
				// would be concatenated into the header value as the text "()".
				val charsetSuffix = if (b.charset.length > 0) "; " + b.charset else ""
				request.setHeader("Content-Type", b.contentType + charsetSuffix)
			}
			// NOTE(review): the text is encoded with StringEntity's default charset, not
			// b.charset - confirm whether the declared charset should drive the encoding.
			request.setEntity(new StringEntity(b.text))
		}
		request
	}

	/**
	 * Execute `request` on a fresh client and convert the outcome to a RestResponse.
	 * The client's connection manager is shut down even when execution fails.
	 */
	private def execute(request:org.apache.http.client.methods.HttpUriRequest):RestResponse = {
		val rp = new SendRequestResponseParser
		val client = new DefaultHttpClient()
		val str =
			try client.execute(request, rp)
			finally client.getConnectionManager.shutdown()

		RestResponse(
			protocolVersion = rp.protocolVersion,
			statusCode = rp.statusCode,
			reasonPhrase = rp.reasonPhrase,
			headerMap = rp.headerMap,
			body = str)
	}
}

/**
 * Apache HTTP client response handler that logs the status line, records the status
 * line and headers of the response on this instance, and returns the response body.
 *
 * The recorded fields hold their initial values until `handleResponse` has run.
 */
class SendRequestResponseParser extends ResponseHandler[String] with Logging {

	/** Protocol version of the last handled response, e.g. as rendered by the status line. */
	var protocolVersion = ""
	/** Status code of the last handled response; 200 until a response has been handled. */
	var statusCode = 200
	/** Reason phrase of the last handled response. */
	var reasonPhrase = ""
	/** Headers of the last handled response; for a repeated header name the last value wins. */
	var headerMap = Map[String, String]()

	/**
	 * Record the status line and headers of `res` and read its body.
	 *
	 * @param res the response delivered by the HTTP client
	 * @return the body text with line terminators removed (lines are concatenated
	 *         directly), or the empty string when the response carries no entity
	 */
	def handleResponse(res:HttpResponse):String = {
		val status = res.getStatusLine
		logger.info(status.toString)

		reasonPhrase = status.getReasonPhrase
		statusCode = status.getStatusCode
		protocolVersion = status.getProtocolVersion.toString

		// Rebuilt from scratch so a reused handler never reports headers of an earlier response.
		headerMap = res.getAllHeaders.map(h => h.getName -> h.getValue).toMap

		// Body-less responses have a null entity.
		Option(res.getEntity).map(entity => readBody(entity)).getOrElse("")
	}

	/** Read the whole entity content line by line, always closing the reader. */
	private def readBody(entity:org.apache.http.HttpEntity):String = {
		val charset = charsetOf(entity)
		val reader = new BufferedReader(new InputStreamReader(entity.getContent, charset))
		try {
			// readLine strips line terminators and mkString adds no separator, so the
			// lines are joined back-to-back.
			Iterator.continually(reader.readLine()).takeWhile(_ != null).mkString
		} finally {
			reader.close()
		}
	}

	/**
	 * Charset declared by the entity's Content-Type, falling back to the platform
	 * default when none is declared or the declared one cannot be used.
	 */
	private def charsetOf(entity:org.apache.http.HttpEntity):java.nio.charset.Charset =
		try {
			Option(org.apache.http.util.EntityUtils.getContentCharSet(entity)) match {
				case Some(name) => java.nio.charset.Charset.forName(name)
				case None => java.nio.charset.Charset.defaultCharset()
			}
		} catch {
			// Charset.forName signals illegal/unsupported names with IllegalArgumentException
			// subclasses; a malformed Content-Type header surfaces as ParseException.
			case e @ (_:IllegalArgumentException | _:org.apache.http.ParseException) =>
				logger.warn("unusable response charset, using platform default: " + e.getMessage)
				java.nio.charset.Charset.defaultCharset()
		}
}
