package io.bidmachine.ml;

import ai.catboost.CatBoostError;
import ai.catboost.CatBoostModel;
import io.bidmachine.config.data.SchemaConfig;
import io.bidmachine.config.exceptions.ConfigException;
import io.bidmachine.config.ml.ModelFeaturesConfig;
import io.bidmachine.config.ml.ModelFeaturesConfigReader;
import io.bidmachine.config.providers.ConfigProvider;
import io.bidmachine.config.providers.ZipConfigProvider;
import io.bidmachine.data.DataFrame;
import io.bidmachine.data.FeatureRecord;
import io.bidmachine.data.InMemoryDataFrameReader;
import io.bidmachine.mutators.LookupMutator;
import io.bidmachine.mutators.Mutator;
import io.bidmachine.mutators.MutatorException;
import io.bidmachine.mutators.TypeMutator;
import io.bidmachine.mutators.UppercaseMutatorForUS;
import io.bidmachine.utils.CheckedHashMap;
import io.bidmachine.utils.CollectionUtils;
import io.bidmachine.utils.PathUtils;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.NotImplementedException;

/* loaded from: input_file:io/bidmachine/ml/FloorPredictor.class */
public class FloorPredictor {
    private HashSet<String> inputAllFields;
    private Map<String, Integer> indexOfFloorByModel;
    private Map<String, CatBoostModel> catBoostModels;
    private Map<String, ModelFeaturesConfig> modelFeaturesConfigs;
    private List<Mutator> mutators;
    float[] cuts;

    public FloorPredictor(String str) throws CatBoostError, IOException, ConfigException {
        init(new ZipConfigProvider(str));
    }

    public FloorPredictor(ConfigProvider configProvider) throws CatBoostError, IOException, ConfigException {
        init(configProvider);
    }

    private void init(ConfigProvider configProvider) throws CatBoostError, IOException, ConfigException {
        Object[] fieldValues = InMemoryDataFrameReader.readCsv(configProvider, PathUtils.joinPathGeneric("other", "cuts.csv")).getFieldValues("flr");
        this.cuts = new float[fieldValues.length];
        for (int i = 0; i < fieldValues.length; i++) {
            this.cuts[i] = ((Float) fieldValues[i]).floatValue();
        }
        HashSet hashSet = new HashSet();
        HashSet hashSet2 = new HashSet();
        this.catBoostModels = new CheckedHashMap();
        this.modelFeaturesConfigs = new CheckedHashMap();
        Pattern compile = Pattern.compile("model_(.*)\\.cbm");
        this.indexOfFloorByModel = new CheckedHashMap();
        for (String str : configProvider.listFiles("models", "*.cbm")) {
            CatBoostModel loadModel = CatBoostModel.loadModel(configProvider.getInputStream(str));
            Matcher matcher = compile.matcher(str);
            if (!matcher.find()) {
                throw new IllegalArgumentException(String.format("model path %s is of bad format", str));
            }
            String group = matcher.group(1);
            this.catBoostModels.put(group, loadModel);
            ModelFeaturesConfig read = ModelFeaturesConfigReader.read(configProvider, PathUtils.joinPathGeneric("models", "info_" + group + ".json"));
            List<String> contVars = read.getContVars();
            if (!contVars.contains("flr")) {
                throw new IllegalArgumentException(String.format("info_%s.json should contain flr feature", group));
            }
            int i2 = 0;
            while (true) {
                if (i2 >= contVars.size()) {
                    break;
                }
                if ("flr".equals(contVars.get(i2))) {
                    this.indexOfFloorByModel.put(group, Integer.valueOf(i2));
                    break;
                }
                i2++;
            }
            hashSet.addAll(read.getCatVars());
            hashSet2.addAll(read.getContVars());
            this.modelFeaturesConfigs.put(group, read);
        }
        HashSet union = CollectionUtils.union(hashSet, hashSet2);
        if (union.size() != hashSet.size() + hashSet2.size()) {
            throw new IllegalArgumentException(String.format("Categorical and numerical fields intersect: %s", CollectionUtils.toStr(CollectionUtils.intersect(hashSet, hashSet2))));
        }
        HashMap hashMap = new HashMap();
        for (String str2 : configProvider.listFiles("lookups", "*.csv")) {
            hashMap.put(str2, InMemoryDataFrameReader.readCsv(configProvider, str2));
        }
        this.mutators = new ArrayList();
        this.mutators.add(new UppercaseMutatorForUS("uppercase_US_region"));
        HashSet hashSet3 = new HashSet();
        HashSet hashSet4 = new HashSet();
        for (Map.Entry entry : hashMap.entrySet()) {
            String str3 = (String) entry.getKey();
            DataFrame dataFrame = (DataFrame) entry.getValue();
            SchemaConfig schema = dataFrame.getSchema();
            String[] key = schema.getKey();
            HashSet intersect = CollectionUtils.intersect(union, Arrays.asList(schema.getValueFields()));
            if (!intersect.isEmpty()) {
                if (!union.containsAll(Arrays.asList(key))) {
                    HashSet hashSet5 = new HashSet(Arrays.asList(key));
                    hashSet5.removeAll(union);
                    hashSet3.addAll(hashSet5);
                }
                union.removeAll(intersect);
                hashSet4.addAll(intersect);
                this.mutators.add(new LookupMutator(str3, dataFrame, Arrays.asList(key), new ArrayList(intersect), null, true));
            }
        }
        raiseErrorOnIntersection(union, hashSet3, "inputFields and inputMissingKeyFields");
        raiseErrorOnIntersection(hashSet3, hashSet4, "inputMissingKeyFields and lookupedFields");
        raiseErrorOnIntersection(union, hashSet4, "inputFields and lookupedFields");
        this.inputAllFields = CollectionUtils.union(union, hashSet3);
        HashSet hashSet6 = new HashSet(hashSet2);
        hashSet6.remove("flr");
        this.mutators.add(new TypeMutator("fix_types", hashSet, hashSet6));
    }

    private void raiseErrorOnIntersection(Set<String> set, Set<String> set2, String str) {
        HashSet intersect = CollectionUtils.intersect(set, set2);
        if (!intersect.isEmpty()) {
            throw new IllegalArgumentException(String.format("%s are intersecting, common fields: %s", str, CollectionUtils.toStr(intersect)));
        }
    }

    private void raiseErrorOnBadInput(Map<String, Object> map) {
        Set<String> keySet = map.keySet();
        if (!keySet.equals(this.inputAllFields)) {
            throw new IllegalArgumentException(String.format("Wrong input. Missing keys: %s, redundant keys: %s", CollectionUtils.minus(this.inputAllFields, keySet), CollectionUtils.minus(keySet, this.inputAllFields)));
        }
    }

    public double predictBestFloor(Map<String, Object> map) throws MutatorException, CatBoostError {
        return predictBestFloor(map, FloorSelectionStrategy.SMART_WIN);
    }

    public double predictBestFloor(Map<String, Object> map, FloorSelectionStrategy floorSelectionStrategy) throws MutatorException, CatBoostError {
        if (map.containsKey("flr")) {
            throw new IllegalArgumentException("flr must not be in the keys when calling predictBestFloor()");
        }
        FeatureRecord featureRecord = new FeatureRecord(map);
        Iterator<Mutator> it = this.mutators.iterator();
        while (it.hasNext()) {
            featureRecord = it.next().mutate(featureRecord);
        }
        featureRecord.put("flr", Float.valueOf(0.0f));
        double[] sigmoidArray = sigmoidArray(predictArrayByModel(featureRecord, "has_good_bids"));
        double[] sigmoidArray2 = sigmoidArray(predictArrayByModel(featureRecord, "is_spend_nurl"));
        double[] predictArrayByModel = predictArrayByModel(featureRecord, "spend_nurl");
        switch (floorSelectionStrategy) {
            case SMART_WIN:
                return Postprocessing.calcSmartWin(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel);
            case SMART_SPEND:
                return Postprocessing.calcSmartSpend(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel);
            case PLAIN_WIN:
                return Postprocessing.calcPlainWin(this.cuts, sigmoidArray, sigmoidArray2);
            case PLAIN_SPEND:
                return Postprocessing.calcPlainSpend(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel);
            default:
                throw new NotImplementedException(String.format("%s is not implemented yet", floorSelectionStrategy));
        }
    }

    /* JADX WARN: Type inference failed for: r0v16, types: [float[], float[][]] */
    /* JADX WARN: Type inference failed for: r0v25, types: [java.lang.Object[], java.lang.String[], java.lang.String[][]] */
    private double[] predictArrayByModel(FeatureRecord featureRecord, String str) throws CatBoostError {
        CatBoostModel catBoostModel = this.catBoostModels.get(str);
        ModelFeaturesConfig modelFeaturesConfig = this.modelFeaturesConfigs.get(str);
        float[] floatValues = featureRecord.getFloatValues(modelFeaturesConfig.getContVars());
        String[] stringValues = featureRecord.getStringValues(modelFeaturesConfig.getCatVars());
        int length = this.cuts.length;
        ?? r0 = new float[length];
        int intValue = this.indexOfFloorByModel.get(str).intValue();
        for (int i = 0; i < length; i++) {
            float[] copyOf = Arrays.copyOf(floatValues, floatValues.length);
            copyOf[intValue] = this.cuts[i];
            r0[i] = copyOf;
        }
        ?? r02 = new String[length];
        Arrays.fill((Object[]) r02, stringValues);
        return catBoostModel.predict((float[][]) r0, (String[][]) r02).copyRowMajorPredictions();
    }

    public static double[] sigmoidArray(double[] dArr) {
        double[] dArr2 = new double[dArr.length];
        for (int i = 0; i < dArr2.length; i++) {
            dArr2[i] = 1.0d / (1.0d + Math.exp(-dArr[i]));
        }
        return dArr2;
    }

    public HashSet<String> getInputAllFields() {
        return this.inputAllFields;
    }
}
