package com.ibm.avatar.algebra.util.pmml;

import com.ibm.avatar.algebra.datamodel.AbstractTupleSchema;
import com.ibm.avatar.algebra.datamodel.FieldCopier;
import com.ibm.avatar.algebra.datamodel.FieldGetter;
import com.ibm.avatar.algebra.datamodel.FieldSetter;
import com.ibm.avatar.algebra.datamodel.FieldType;
import com.ibm.avatar.algebra.datamodel.TLIter;
import com.ibm.avatar.algebra.datamodel.Tuple;
import com.ibm.avatar.algebra.datamodel.TupleList;
import com.ibm.avatar.algebra.datamodel.TupleSchema;
import com.ibm.avatar.api.exceptions.TableUDFException;
import com.ibm.avatar.api.exceptions.TextAnalyticsException;
import com.ibm.avatar.api.udf.TableUDFBase;
import com.ibm.avatar.logging.Log;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.IOUtil;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Computable;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.manager.PMMLManager;

/* loaded from: input_file:com/ibm/avatar/algebra/util/pmml/ScoringTableFuncBase.class */
/**
 * Base class for table functions that score input records against a PMML model.
 *
 * <p>Subclasses supply the PMML document through {@link #getPMMLStream()}. During
 * {@link #validateSchema} the model is parsed and the accessors needed to move values between
 * tuples and the PMML evaluator are prepared; {@link #evalImpl(TupleList)} then evaluates the
 * model once per input tuple.
 */
public abstract class ScoringTableFuncBase extends TableUDFBase {
    /** Parsed PMML document; assigned by {@code initModel}. */
    protected PMML pmml;

    /** Evaluator obtained from {@link #pmml}; assigned by {@code initModel}. */
    protected Evaluator evaluator;

    /** Active model fields that are also present in the params record schema. */
    protected ArrayList<FieldName> activeFields;

    /** Getters for {@link #activeFields}; index {@code i} reads the value of {@code activeFields.get(i)}. */
    protected ArrayList<FieldGetter<?>> fieldGetters;

    /** Model output fields that are also present in the output schema. */
    protected ArrayList<FieldName> outputFields;

    /** Setters for {@link #outputFields}, index-aligned with that list. */
    protected ArrayList<FieldSetter<? extends Object>> outputFieldSetters;

    /** Model predicted fields that are also present in the output schema. */
    protected ArrayList<FieldName> predictedFields;

    /** Setters for {@link #predictedFields}, index-aligned with that list. */
    protected ArrayList<FieldSetter<? extends Object>> predictedFieldSetters;

    /**
     * Copies fields whose names appear in both the params record schema and the output schema;
     * stays {@code null} when the two schemas share no field names.
     */
    protected FieldCopier fieldCopier = null;

    /**
     * Opens the PMML document describing the model to score with.
     *
     * <p>The returned stream is closed by this class once the document has been read.
     *
     * @return a stream positioned at the start of the PMML document
     * @throws TableUDFException if the stream cannot be provided
     */
    protected abstract InputStream getPMMLStream() throws TableUDFException;

    /**
     * Loads the model and binds it to the supplied schemas.
     *
     * <p>Only {@code tupleSchema2} and {@code tupleSchema3} are used here; the remaining
     * parameters are ignored. NOTE(review): the roles of {@code tupleSchema}, {@code method} and
     * {@code z} are defined by {@link TableUDFBase} and are not visible in this file.
     *
     * @param tupleSchema ignored by this implementation
     * @param tupleSchema2 schema that must contain the model params argument
     * @param tupleSchema3 schema in which output and predicted fields are looked up
     * @param method ignored by this implementation
     * @param z ignored by this implementation
     * @throws TableUDFException if the params argument is not a record locator or the model
     *     cannot be loaded
     */
    @Override // com.ibm.avatar.api.udf.TableUDFBase, com.ibm.avatar.algebra.function.base.TableUDFBaseImpl
    public void validateSchema(TupleSchema tupleSchema, TupleSchema tupleSchema2, TupleSchema tupleSchema3, Method method, boolean z) throws TableUDFException {
        initModel(tupleSchema2, tupleSchema3);
    }

    /**
     * Parses the PMML document, creates the evaluator and builds the getter/setter/copier
     * accessors used by {@link #evalImpl(TupleList)}.
     *
     * @param argSchema schema holding the {@code PMMLUtil.MODEL_PARAMS_ARG_NAME} argument
     * @param outputSchema schema whose fields receive model results
     * @throws TableUDFException if the params argument is not a record locator, if a nested call
     *     reports a {@code TableUDFException}, or if any other failure occurs while loading the
     *     model
     */
    private void initModel(TupleSchema argSchema, TupleSchema outputSchema) throws TableUDFException {
        FieldType paramsType = argSchema.getFieldTypeByName(PMMLUtil.MODEL_PARAMS_ARG_NAME);
        if (!paramsType.getIsLocator()) {
            throw new TableUDFException("Params argument of scoring function is not a record locator.  Ensure that the function is declarated properly in the 'create function' statement.", new Object[0]);
        }
        AbstractTupleSchema recordSchema = paramsType.getRecordSchema();
        try {
            // Close the model stream as soon as the document has been unmarshalled so the
            // underlying resource is not leaked.
            try (InputStream pmmlStream = getPMMLStream()) {
                this.pmml = IOUtil.unmarshal(pmmlStream);
            }
            this.evaluator = new PMMLManager(this.pmml).getModelManager((String) null, ModelEvaluatorFactory.getInstance());
            Log.debug("%s: Active fields are %s", this, this.evaluator.getActiveFields());
            Log.debug("%s: Output fields are %s", this, this.evaluator.getOutputFields());
            Log.debug("%s: Predicted fields are %s", this, this.evaluator.getPredictedFields());
            Log.debug("%s: Group fields are %s", this, this.evaluator.getGroupFields());

            // Inputs: only active fields that the params record actually provides are bound.
            this.activeFields = new ArrayList<>();
            this.fieldGetters = new ArrayList<>();
            for (FieldName active : this.evaluator.getActiveFields()) {
                DataType pmmlType = this.evaluator.getDataField(active).getDataType();
                String name = active.getValue();
                if (recordSchema.containsField(name)) {
                    FieldType inputType = recordSchema.getFieldTypeByName(name);
                    validateInputConversion(name, inputType, pmmlType);
                    this.activeFields.add(active);
                    this.fieldGetters.add(recordSchema.genericGetter(name, inputType));
                }
            }

            // Outputs: only model output fields that the output schema declares are bound.
            this.outputFields = new ArrayList<>();
            this.outputFieldSetters = new ArrayList<>();
            for (FieldName output : this.evaluator.getOutputFields()) {
                if (outputSchema.containsField(output.getValue())) {
                    this.outputFields.add(output);
                    this.outputFieldSetters.add(makeFieldSetter(output, false, outputSchema));
                }
            }

            // Predictions: same rule as outputs, but the type comes from the data dictionary.
            this.predictedFields = new ArrayList<>();
            this.predictedFieldSetters = new ArrayList<>();
            for (FieldName predicted : this.evaluator.getPredictedFields()) {
                if (outputSchema.containsField(predicted.getValue())) {
                    this.predictedFields.add(predicted);
                    this.predictedFieldSetters.add(makeFieldSetter(predicted, true, outputSchema));
                }
            }

            // Fields named identically in the params record and the output schema are copied
            // straight through from input tuple to output tuple.
            ArrayList<String> sharedNames = new ArrayList<>();
            for (String name : recordSchema.getFieldNames()) {
                if (outputSchema.containsField(name)) {
                    sharedNames.add(name);
                }
            }
            if (!sharedNames.isEmpty()) {
                String[] names = sharedNames.toArray(new String[0]);
                this.fieldCopier = outputSchema.fieldCopier(recordSchema, names, names);
            }
        } catch (TableUDFException e) {
            // Already a UDF-level error with its own message (for example from
            // getPMMLStream()); do not relabel it as a PMML parse failure.
            throw e;
        } catch (Exception e) {
            throw new TableUDFException(e, "Error parsing PMML representation of model.  Ensure that the model file is in valid PMML format (PMML 4.1 or earlier).", new Object[0]);
        }
    }

    /**
     * Creates the setter that writes one model result into a tuple of the given schema.
     *
     * @param fieldName model field whose result will be written
     * @param z {@code true} to take the PMML data type from the evaluator's data field
     *     (used for predicted fields), {@code false} to take it from the output field
     * @param tupleSchema schema that contains a field named like {@code fieldName}
     * @return a text setter when the schema field is text, otherwise a generic setter
     * @throws TableUDFException declared for the output-conversion validation hook
     */
    private FieldSetter<? extends Object> makeFieldSetter(FieldName fieldName, boolean z, TupleSchema tupleSchema) throws TableUDFException {
        String name = fieldName.getValue();
        DataType pmmlType;
        if (z) {
            pmmlType = this.evaluator.getDataField(fieldName).getDataType();
        } else {
            pmmlType = this.evaluator.getOutputField(fieldName).getDataType();
        }
        FieldType targetType = tupleSchema.getFieldTypeByName(name);
        validateOutputConversion(name, pmmlType, targetType);
        if (targetType.getIsText()) {
            return tupleSchema.textSetter(name);
        }
        return tupleSchema.genericSetter(name, targetType);
    }

    /**
     * Hook for checking that a tuple field type can feed a PMML input type.
     * Currently performs no checks.
     */
    private void validateInputConversion(String str, FieldType fieldType, DataType dataType) throws TableUDFException {
    }

    /**
     * Hook for checking that a PMML result type can be stored in a tuple field type.
     * Currently performs no checks.
     */
    private void validateOutputConversion(String str, DataType dataType, FieldType fieldType) throws TableUDFException {
    }

    /**
     * Evaluates the model once for every tuple in {@code tupleList}.
     *
     * <p>For each input tuple the bound active fields are read and prepared, the evaluator is
     * run, and a new tuple is filled with the bound output and predicted results plus any
     * fields handled by {@link #fieldCopier}.
     *
     * @param tupleList input tuples to score
     * @return a new list holding one result tuple per input tuple, in iteration order
     * @throws TextAnalyticsException if the evaluator rejects an input value while preparing it
     */
    protected TupleList evalImpl(TupleList tupleList) throws TextAnalyticsException {
        TupleList scored = new TupleList(getReturnTupleSchema());
        TLIter it = tupleList.iterator();
        while (it.hasNext()) {
            Tuple input = it.next();

            // LinkedHashMap keeps the arguments in active-field order.
            Map<FieldName, Object> arguments = new LinkedHashMap<>();
            for (int i = 0; i < this.activeFields.size(); i++) {
                FieldName fieldName = this.activeFields.get(i);
                Object val = this.fieldGetters.get(i).getVal(input);
                try {
                    arguments.put(fieldName, this.evaluator.prepare(fieldName, val));
                } catch (Exception e) {
                    throw new TextAnalyticsException(e, "Invalid argument value %s for field %s. PMML error type is %s.  Ensure that the value satisfies the constraints specified in the PMML file.", val, fieldName, e.toString());
                }
            }

            Map<FieldName, ?> results = this.evaluator.evaluate(arguments);
            Tuple output = getReturnTupleSchema().createTup();
            for (int i = 0; i < this.outputFields.size(); i++) {
                this.outputFieldSetters.get(i).setVal(output, resultToObject(results.get(this.outputFields.get(i))));
            }
            for (int i = 0; i < this.predictedFields.size(); i++) {
                this.predictedFieldSetters.get(i).setVal(output, resultToObject(results.get(this.predictedFields.get(i))));
            }
            if (this.fieldCopier != null) {
                this.fieldCopier.copyVals(input, output);
            }
            scored.add(output);
        }
        return scored;
    }

    /**
     * Unwraps an evaluator result: a {@link Computable} is replaced by its
     * {@code getResult()} value, anything else (including {@code null}) is returned unchanged.
     */
    private Object resultToObject(Object obj) {
        if (obj instanceof Computable) {
            return ((Computable) obj).getResult();
        }
        return obj;
    }
}
