/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn.  If not, see <http://www.gnu.org/licenses/>.
 */
package sklearn.linear_model;

import java.util.ArrayList;
import java.util.List;

import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.LoggerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that encode linear coefficients (plus an intercept) as PMML
 * {@link RegressionModel} / {@link RegressionTable} elements.
 */
public class RegressionModelUtil {

	private RegressionModelUtil(){
	}

	/**
	 * Encodes a regression model that consists of a single regression table.
	 *
	 * @param coefficients One coefficient per feature of {@code schema}, in the same order as {@link Schema#getFeatures()}.
	 * @param intercept The intercept term of the regression table.
	 * @param schema The schema that supplies the features and the mining schema.
	 *
	 * @return A {@link RegressionModel} with mining function {@link MiningFunction#REGRESSION} holding exactly one regression table.
	 *
	 * @throws IllegalArgumentException If the number of coefficients differs from the number of features,
	 * or if a feature with a non-zero coefficient is neither a {@link ContinuousFeature} nor a {@link BinaryFeature}.
	 *
	 * @see #encodeRegressionTable(List, Number, Schema)
	 */
	static
	public RegressionModel encodeRegressionModel(List<? extends Number> coefficients, Number intercept, Schema schema){
		RegressionTable regressionTable = encodeRegressionTable(coefficients, intercept, schema);

		// NOTE(review): the third constructor argument is passed as null here; presumably it is the
		// initial list of regression tables, which is then populated via addRegressionTables - confirm.
		RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema), null)
			.addRegressionTables(regressionTable);

		return regressionModel;
	}

	/**
	 * Encodes a regression table by pairing each coefficient with the feature at the same position.
	 *
	 * <p>
	 * Features whose coefficient is zero contribute nothing to the prediction, so no predictor is emitted for them.
	 * Their names are reported at INFO level instead.
	 * Continuous features become {@link NumericPredictor}s, and binary features become {@link CategoricalPredictor}s
	 * keyed on the feature's value.
	 * </p>
	 *
	 * @param coefficients One coefficient per feature of {@code schema}, in the same order as {@link Schema#getFeatures()}.
	 * @param intercept The intercept term of the regression table.
	 * @param schema The schema that supplies the features.
	 *
	 * @return The encoded regression table.
	 *
	 * @throws IllegalArgumentException If the number of coefficients differs from the number of features,
	 * or if a feature with a non-zero coefficient is neither a {@link ContinuousFeature} nor a {@link BinaryFeature}.
	 */
	static
	public RegressionTable encodeRegressionTable(List<? extends Number> coefficients, Number intercept, Schema schema){
		List<Feature> features = schema.getFeatures();

		if(coefficients.size() != features.size()){
			throw new IllegalArgumentException("Expected " + features.size() + " coefficient(s) (one per feature), got " + coefficients.size());
		}

		RegressionTable regressionTable = new RegressionTable(ValueUtil.asDouble(intercept));

		List<Feature> unusedFeatures = new ArrayList<>();

		for(int i = 0; i < coefficients.size(); i++){
			Number coefficient = coefficients.get(i);
			Feature feature = features.get(i);

			// A zero-weight term has no effect on the linear combination, so it is left out of the table
			if(ValueUtil.isZero(coefficient)){
				unusedFeatures.add(feature);

				continue;
			} // End if

			if(feature instanceof ContinuousFeature){
				ContinuousFeature continuousFeature = (ContinuousFeature)feature;

				NumericPredictor numericPredictor = new NumericPredictor(continuousFeature.getName(), ValueUtil.asDouble(coefficient));

				regressionTable.addNumericPredictors(numericPredictor);
			} else

			if(feature instanceof BinaryFeature){
				BinaryFeature binaryFeature = (BinaryFeature)feature;

				// The coefficient applies only when the field takes this particular value
				CategoricalPredictor categoricalPredictor = new CategoricalPredictor(binaryFeature.getName(), binaryFeature.getValue(), ValueUtil.asDouble(coefficient));

				regressionTable.addCategoricalPredictors(categoricalPredictor);
			} else

			{
				// Guard against null so that building the error message cannot itself fail
				String type = (feature != null ? feature.getClass().getName() : null);

				throw new IllegalArgumentException("Feature at index " + i + " has unsupported type " + type + " (expected a continuous or binary feature)");
			}
		}

		if(!unusedFeatures.isEmpty()){
			logger.info("Skipped {} feature(s): {}", unusedFeatures.size(), LoggerUtil.formatNameList(unusedFeatures));
		}

		return regressionTable;
	}

	/**
	 * Encodes a regression table that consists of an intercept and exactly one numeric predictor.
	 *
	 * @param numericPredictor The sole predictor of the table.
	 * @param intercept The intercept term of the regression table.
	 *
	 * @return The encoded regression table.
	 */
	static
	public RegressionTable encodeRegressionTable(NumericPredictor numericPredictor, Number intercept){
		RegressionTable regressionTable = new RegressionTable(ValueUtil.asDouble(intercept))
			.addNumericPredictors(numericPredictor);

		return regressionTable;
	}

	private static final Logger logger = LoggerFactory.getLogger(RegressionModelUtil.class);
}