/*
 * Copyright 2015, The OpenNMS Group
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 *     
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.opennms.newts.persistence.cassandra;


import static com.codahale.metrics.MetricRegistry.name;
import static com.datastax.driver.core.querybuilder.QueryBuilder.bindMarker;
import static com.datastax.driver.core.querybuilder.QueryBuilder.eq;
import static com.datastax.driver.core.querybuilder.QueryBuilder.gte;
import static com.datastax.driver.core.querybuilder.QueryBuilder.insertInto;
import static com.datastax.driver.core.querybuilder.QueryBuilder.lte;
import static com.datastax.driver.core.querybuilder.QueryBuilder.ttl;
import static com.datastax.driver.core.querybuilder.QueryBuilder.unloggedBatch;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;

import javax.inject.Inject;
import javax.inject.Named;

import org.opennms.newts.aggregate.IntervalGenerator;
import org.opennms.newts.aggregate.ResultProcessor;
import org.opennms.newts.api.Duration;
import org.opennms.newts.api.Measurement;
import org.opennms.newts.api.Resource;
import org.opennms.newts.api.Results;
import org.opennms.newts.api.Results.Row;
import org.opennms.newts.api.Context;
import org.opennms.newts.api.Sample;
import org.opennms.newts.api.SampleProcessorService;
import org.opennms.newts.api.SampleRepository;
import org.opennms.newts.api.Timestamp;
import org.opennms.newts.api.ValueType;
import org.opennms.newts.api.query.ResultDescriptor;
import org.opennms.newts.cassandra.CassandraSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.datastax.driver.core.BoundStatement;
import com.datastax.driver.core.PreparedStatement;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.querybuilder.Batch;
import com.datastax.driver.core.querybuilder.QueryBuilder;
import com.datastax.driver.core.querybuilder.Select;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;


/**
 * {@link SampleRepository} backed by Cassandra.
 *
 * <p>Samples are written to {@link SchemaConstants#T_SAMPLES}, keyed by context, partition (the
 * sample timestamp floored to the context's resource shard, in seconds) and resource. Reads issue
 * one asynchronous query per shard partition covered by the requested time range.
 */
public class CassandraSampleRepository implements SampleRepository {

    private static final Logger LOG = LoggerFactory.getLogger(CassandraSampleRepository.class);

    // Used to calculate the duration when the duration is not specified
    private static final int TARGET_NUMBER_OF_STEPS = 10;

    // Length of the query window (one day) used when a select does not specify a start time.
    private static final long DEFAULT_QUERY_WINDOW_SECONDS = 86400;

    private final CassandraSession m_session;
    private final int m_ttl;
    private final SampleProcessorService m_processorService;
    private final PreparedStatement m_selectStatement;

    private final Timer m_sampleSelectTimer;
    private final Timer m_measurementSelectTimer;
    private final Timer m_insertTimer;

    private final ContextConfigurations m_contextConfigurations;

    /**
     * Creates the repository and prepares the sample select statement.
     *
     * @param session Cassandra session used for all reads and writes; must not be null
     * @param ttl column time-to-live, in seconds, applied to inserted samples; must be non-negative
     * @param registry registry in which the select/insert timers are created; must not be null
     * @param processorService optional service that inserted samples are submitted to; may be null
     * @param contextConfigurations per-context settings (resource shard); must not be null
     * @throws NullPointerException if {@code session}, {@code registry} or {@code contextConfigurations} is null
     * @throws IllegalArgumentException if {@code ttl} is negative
     */
    @Inject
    public CassandraSampleRepository(CassandraSession session, @Named("samples.cassandra.time-to-live") int ttl, MetricRegistry registry, SampleProcessorService processorService, ContextConfigurations contextConfigurations) {

        m_session = checkNotNull(session, "session argument");
        checkArgument(ttl >= 0, "Negative Cassandra column TTL");

        m_ttl = ttl;

        checkNotNull(registry, "metric registry argument");
        m_processorService = processorService;

        m_contextConfigurations = checkNotNull(contextConfigurations, "contextConfigurations argument");

        // SELECT * FROM samples WHERE context = ? AND partition = ? AND resource = ?
        //   AND collected_at >= :start AND collected_at <= :end
        Select select = QueryBuilder.select().from(SchemaConstants.T_SAMPLES);
        select.where(eq(SchemaConstants.F_CONTEXT, bindMarker(SchemaConstants.F_CONTEXT)));
        select.where(eq(SchemaConstants.F_PARTITION, bindMarker(SchemaConstants.F_PARTITION)));
        select.where(eq(SchemaConstants.F_RESOURCE, bindMarker(SchemaConstants.F_RESOURCE)));

        select.where(gte(SchemaConstants.F_COLLECTED, bindMarker("start")));
        select.where(lte(SchemaConstants.F_COLLECTED, bindMarker("end")));

        m_selectStatement = m_session.prepare(select.toString());

        m_sampleSelectTimer = registry.timer(metricName("sample-select-timer"));
        m_measurementSelectTimer = registry.timer(metricName("measurement-select-timer"));
        m_insertTimer = registry.timer(metricName("insert-timer"));

    }

    /**
     * Selects samples for a resource and aggregates them into measurements per {@code descriptor}.
     *
     * <p>{@code end} defaults to now, and {@code start} defaults to one day before {@code end}. When
     * {@code resolution} is absent, the step is derived from the range and the descriptor's interval
     * (see {@link #computeStep}).
     *
     * @throws IllegalArgumentException if both bounds are present and {@code start} is after {@code end}
     */
    @Override
    public Results<Measurement> select(Context context, Resource resource, Optional<Timestamp> start, Optional<Timestamp> end, ResultDescriptor descriptor, Optional<Duration> resolution) {

        Timer.Context timer = m_measurementSelectTimer.time();

        // Everything runs inside try/finally so the timer is stopped even if validation or the query throws.
        try {
            validateSelect(start, end);

            Timestamp upper = end.isPresent() ? end.get() : Timestamp.now();
            Timestamp lower = start.isPresent() ? start.get() : upper.minus(Duration.seconds(DEFAULT_QUERY_WINDOW_SECONDS));
            Duration step = resolution.isPresent() ? resolution.get() : computeStep(lower, upper, descriptor.getInterval());

            // The database read starts one step before 'lower'; presumably so ResultProcessor has samples
            // preceding the first interval — confirm against ResultProcessor.
            Timestamp queryStart = lower.minus(step);

            LOG.debug("Querying database for resource {}, from {} to {}", resource, queryStart, upper);

            DriverAdapter driverAdapter = new DriverAdapter(cassandraSelect(context, resource, queryStart, upper),
                    descriptor.getSourceNames());
            Results<Measurement> results = new ResultProcessor(resource, lower, upper, descriptor, step).process(driverAdapter);

            LOG.debug("{} results returned from database", driverAdapter.getResultCount());

            return results;
        }
        finally {
            timer.stop();
        }

    }

    /**
     * Selects the raw samples stored for a resource.
     *
     * <p>{@code end} defaults to now, and {@code start} defaults to one day before {@code end}.
     *
     * @throws IllegalArgumentException if both bounds are present and {@code start} is after {@code end}
     */
    @Override
    public Results<Sample> select(Context context, Resource resource, Optional<Timestamp> start, Optional<Timestamp> end) {

        Timer.Context timer = m_sampleSelectTimer.time();

        // Everything runs inside try/finally so the timer is stopped even if validation or the query throws.
        try {
            validateSelect(start, end);

            Timestamp upper = end.isPresent() ? end.get() : Timestamp.now();
            Timestamp lower = start.isPresent() ? start.get() : upper.minus(Duration.seconds(DEFAULT_QUERY_WINDOW_SECONDS));

            LOG.debug("Querying database for resource {}, from {} to {}", resource, lower, upper);

            Results<Sample> samples = new Results<Sample>();
            DriverAdapter driverAdapter = new DriverAdapter(cassandraSelect(context, resource, lower, upper));

            for (Row<Sample> row : driverAdapter) {
                samples.addRow(row);
            }

            LOG.debug("{} results returned from database", driverAdapter.getResultCount());

            return samples;
        }
        finally {
            timer.stop();
        }
    }

    /**
     * Inserts samples using the configured TTL as-is (no age adjustment).
     */
    @Override
    public void insert(Collection<Sample> samples) {
        insert(samples, false);
    }

    /**
     * Writes the samples to Cassandra in a single unlogged batch, then submits them to the
     * processor service, if one is configured.
     *
     * @param samples samples to write
     * @param calculateTimeToLive when true, each sample's TTL is reduced by its age (now minus its
     *        timestamp), and samples whose remaining TTL is not positive are skipped
     */
    @Override
    public void insert(Collection<Sample> samples, boolean calculateTimeToLive) {

        Timer.Context timer = m_insertTimer.time();

        // Everything runs inside try/finally so the timer is stopped even if building the batch throws.
        try {
            Timestamp now = Timestamp.now();

            Batch batch = unloggedBatch();
            int statementCount = 0;

            for (Sample m : samples) {
                int ttl = m_ttl;
                if (calculateTimeToLive) {
                    // Shorten the TTL by the sample's age so it expires m_ttl seconds after its timestamp.
                    // Computed as a long so a very old sample cannot overflow the int subtraction.
                    long remaining = m_ttl - now.minus(m.getTimestamp()).asSeconds();
                    if (remaining <= 0) {
                        LOG.debug("Skipping expired sample: {}", m);
                        continue;
                    }
                    ttl = (int) remaining;
                }

                Duration resourceShard = m_contextConfigurations.getResourceShard(m.getContext());
                batch.add(
                        insertInto(SchemaConstants.T_SAMPLES)
                            .value(SchemaConstants.F_CONTEXT, m.getContext().getId())
                            .value(SchemaConstants.F_PARTITION, m.getTimestamp().stepFloor(resourceShard).asSeconds())
                            .value(SchemaConstants.F_RESOURCE, m.getResource().getId())
                            .value(SchemaConstants.F_COLLECTED, m.getTimestamp().asMillis())
                            .value(SchemaConstants.F_METRIC_NAME, m.getName())
                            .value(SchemaConstants.F_VALUE, ValueType.decompose(m.getValue()))
                            .value(SchemaConstants.F_ATTRIBUTES, m.getAttributes())
                            .using(ttl(ttl))
                );
                statementCount++;
            }

            // Avoid a pointless round trip when there is nothing to write (empty input or all samples expired).
            if (statementCount > 0) {
                m_session.execute(batch);
            }

            if (m_processorService != null) {
                m_processorService.submit(samples);
            }
        }
        finally {
            timer.stop();
        }

    }

    /**
     * Chooses a step when the caller did not request a resolution.
     *
     * <p>The range is split into {@link #TARGET_NUMBER_OF_STEPS} slices. When {@code interval} is at
     * least as large as one slice, twice the interval is used. Otherwise the slice width is rounded
     * up to the next multiple of {@code interval}.
     */
    private static Duration computeStep(Timestamp lower, Timestamp upper, Duration interval) {
        long stepMillis = upper.minus(lower).asMillis() / TARGET_NUMBER_OF_STEPS;

        // Every step must be a multiple of the interval
        long intervalMillis = interval.asMillis();

        if (intervalMillis >= stepMillis) {
            return interval.times(2);
        }

        // Round stepMillis up to the closest multiple of intervalMillis
        long remainderMillis = stepMillis % intervalMillis;
        if (remainderMillis != 0) {
            stepMillis = stepMillis + intervalMillis - remainderMillis;
        }

        return Duration.millis(stepMillis);
    }

    /**
     * Issues one asynchronous select per shard partition between {@code start} and {@code end}
     * (both floored to the context's resource shard). Each query binds the full
     * [{@code start}, {@code end}] collection-time range.
     *
     * @return an iterator over the rows of all partition queries
     */
    private Iterator<com.datastax.driver.core.Row> cassandraSelect(Context context, Resource resource,
            Timestamp start, Timestamp end) {

        List<Future<ResultSet>> futures = Lists.newArrayList();

        Duration resourceShard = m_contextConfigurations.getResourceShard(context);
        Timestamp lower = start.stepFloor(resourceShard);
        Timestamp upper = end.stepFloor(resourceShard);

        for (Timestamp partition : new IntervalGenerator(lower, upper, resourceShard)) {
            BoundStatement bindStatement = m_selectStatement.bind();
            bindStatement.setString(SchemaConstants.F_CONTEXT, context.getId());
            bindStatement.setInt(SchemaConstants.F_PARTITION, (int) partition.asSeconds());
            bindStatement.setString(SchemaConstants.F_RESOURCE, resource.getId());
            bindStatement.setDate("start", start.asDate());
            bindStatement.setDate("end", end.asDate());

            futures.add(m_session.executeAsync(bindStatement));
        }

        return new ConcurrentResultWrapper(futures);
    }

    /**
     * Rejects a range whose explicit start lies after its explicit end. Equal bounds are allowed,
     * as is a missing bound on either side.
     *
     * @throws IllegalArgumentException if both bounds are present and {@code start > end}
     */
    private void validateSelect(Optional<Timestamp> start, Optional<Timestamp> end) {
        if ((start.isPresent() && end.isPresent()) && start.get().gt(end.get())) {
            throw new IllegalArgumentException("start time must be less than end time");
        }
    }


    /** Returns the metric name {@code repository.<suffix>}, as built by {@code MetricRegistry.name}. */
    private String metricName(String suffix) {
        return name("repository", suffix);
    }

}
