package io.bidmachine.ml;

import ai.catboost.CatBoostError;
import ai.catboost.CatBoostModel;
import io.bidmachine.Const;
import io.bidmachine.config.data.SchemaConfig;
import io.bidmachine.config.exceptions.ConfigException;
import io.bidmachine.config.ml.ModelFeaturesConfig;
import io.bidmachine.config.ml.ModelFeaturesConfigReader;
import io.bidmachine.config.providers.ConfigProvider;
import io.bidmachine.config.providers.SevenZConfigProvider;
import io.bidmachine.config.providers.ZipConfigProvider;
import io.bidmachine.data.DataFrame;
import io.bidmachine.data.FeatureRecord;
import io.bidmachine.data.InMemoryDataFrameReader;
import io.bidmachine.mutators.CommonMutator;
import io.bidmachine.mutators.EcpmMutator;
import io.bidmachine.mutators.LookupMutator;
import io.bidmachine.mutators.Mutator;
import io.bidmachine.mutators.MutatorException;
import io.bidmachine.mutators.TypeMutator;
import io.bidmachine.utils.CheckedHashMap;
import io.bidmachine.utils.CollectionUtils;
import io.bidmachine.utils.PathUtils;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.NotImplementedException;

/* loaded from: input_file:io/bidmachine/ml/FloorPredictor.class */
public class FloorPredictor {
    private HashSet<String> inputAllFields;
    private Map<String, Integer> indexOfFloorByModel;
    private Map<String, CatBoostModel> catBoostModels;
    private Map<String, ModelFeaturesConfig> modelFeaturesConfigs;
    private List<Mutator> mutators;
    float[] cuts;
    private MLParams mlParams;

    public FloorPredictor(String str) throws FloorPredictorException {
        this(str, MLParamsBuilder.getDefault());
    }

    public FloorPredictor(ConfigProvider configProvider) throws FloorPredictorException {
        this(configProvider, MLParamsBuilder.getDefault());
    }

    public FloorPredictor(String str, String str2) throws FloorPredictorException {
        this(str, MLParamsBuilder.get(str2));
    }

    public FloorPredictor(ConfigProvider configProvider, String str) throws FloorPredictorException {
        this(configProvider, MLParamsBuilder.get(str));
    }

    public FloorPredictor(String str, MLParams mLParams) throws FloorPredictorException {
        if (str.endsWith(".7z")) {
            init(new SevenZConfigProvider(str), mLParams);
        } else if (str.endsWith(".zip")) {
            init(new ZipConfigProvider(str), mLParams);
        }
        throw new IllegalArgumentException("BundlePath must be either 7z or zip archive");
    }

    public FloorPredictor(ConfigProvider configProvider, MLParams mLParams) throws FloorPredictorException {
        init(configProvider, mLParams);
    }

    private void init(ConfigProvider configProvider, MLParams mLParams) throws FloorPredictorException {
        this.mlParams = mLParams;
        try {
            Object[] fieldValues = InMemoryDataFrameReader.readCsv(configProvider, PathUtils.joinPathGeneric("other", "cuts.csv")).getFieldValues(Const.FLR);
            this.cuts = new float[fieldValues.length];
            for (int i = 0; i < fieldValues.length; i++) {
                this.cuts[i] = ((Float) fieldValues[i]).floatValue();
            }
            HashSet hashSet = new HashSet();
            HashSet hashSet2 = new HashSet();
            this.catBoostModels = new CheckedHashMap();
            this.modelFeaturesConfigs = new CheckedHashMap();
            Pattern compile = Pattern.compile("model_(.*)\\.cbm");
            this.indexOfFloorByModel = new CheckedHashMap();
            for (String str : configProvider.listFiles("models", "*.cbm")) {
                try {
                    CatBoostModel loadModel = CatBoostModel.loadModel(configProvider.getInputStream(str));
                    Matcher matcher = compile.matcher(str);
                    if (!matcher.find()) {
                        throw new IllegalArgumentException(String.format("model path %s is of bad format", str));
                    }
                    String group = matcher.group(1);
                    this.catBoostModels.put(group, loadModel);
                    String str2 = "info_" + group + ".json";
                    try {
                        ModelFeaturesConfig read = ModelFeaturesConfigReader.read(configProvider, PathUtils.joinPathGeneric("models", str2));
                        List<String> contVars = read.getContVars();
                        if (!contVars.contains(Const.FLR)) {
                            throw new IllegalArgumentException(String.format("info_%s.json should contain flr feature", group));
                        }
                        int i2 = 0;
                        while (true) {
                            if (i2 >= contVars.size()) {
                                break;
                            }
                            if (Const.FLR.equals(contVars.get(i2))) {
                                this.indexOfFloorByModel.put(group, Integer.valueOf(i2));
                                break;
                            }
                            i2++;
                        }
                        hashSet.addAll(read.getCatVars());
                        hashSet2.addAll(read.getContVars());
                        this.modelFeaturesConfigs.put(group, read);
                    } catch (ConfigException e) {
                        throw new FloorPredictorException(String.format("Info file %s is bad", str2), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, e);
                    }
                } catch (CatBoostError e2) {
                    throw new FloorPredictorException(String.format("Catboost cannot load model from %s", str), FloorPredictorProblemType.CATBOOST_FAILED_TO_LOAD, e2);
                } catch (IOException e3) {
                    throw new FloorPredictorException(String.format("bad %s: no resource %s inside", configProvider, str), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, e3);
                }
            }
            HashSet union = CollectionUtils.union(hashSet, hashSet2);
            if (union.size() != hashSet.size() + hashSet2.size()) {
                throw new IllegalArgumentException(String.format("Categorical and numerical fields intersect: %s", CollectionUtils.toStr(CollectionUtils.intersect(hashSet, hashSet2))));
            }
            HashMap hashMap = new HashMap();
            for (String str3 : configProvider.listFiles("lookups", "*.csv")) {
                try {
                    hashMap.put(str3, InMemoryDataFrameReader.readCsv(configProvider, str3));
                } catch (ConfigException e4) {
                    throw new FloorPredictorException(String.format("Lookup file %s is bad", str3), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, e4);
                } catch (IOException e5) {
                    throw new FloorPredictorException(String.format("bad %s: lookup file %s is missing", configProvider, str3), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, e5);
                }
            }
            this.mutators = new ArrayList();
            this.mutators.add(new CommonMutator("common"));
            this.mutators.add(new EcpmMutator("ecpm"));
            HashSet hashSet3 = new HashSet();
            HashSet hashSet4 = new HashSet();
            for (Map.Entry entry : hashMap.entrySet()) {
                String str4 = (String) entry.getKey();
                DataFrame dataFrame = (DataFrame) entry.getValue();
                SchemaConfig schema = dataFrame.getSchema();
                String[] key = schema.getKey();
                HashSet intersect = CollectionUtils.intersect(union, Arrays.asList(schema.getValueFields()));
                if (!intersect.isEmpty()) {
                    if (!union.containsAll(Arrays.asList(key))) {
                        HashSet hashSet5 = new HashSet(Arrays.asList(key));
                        hashSet5.removeAll(union);
                        hashSet3.addAll(hashSet5);
                    }
                    union.removeAll(intersect);
                    hashSet4.addAll(intersect);
                    this.mutators.add(new LookupMutator(str4, dataFrame, Arrays.asList(key), new ArrayList(intersect), null, true));
                }
            }
            raiseErrorOnIntersection(union, hashSet3, "inputFields and inputMissingKeyFields");
            raiseErrorOnIntersection(hashSet3, hashSet4, "inputMissingKeyFields and lookupedFields");
            raiseErrorOnIntersection(union, hashSet4, "inputFields and lookupedFields");
            this.inputAllFields = CollectionUtils.union(union, hashSet3);
            for (String str5 : EcpmMutator.ECPM_AD_TYPES) {
                boolean z = false;
                String[] strArr = EcpmMutator.ECPM_METRICS;
                int length = strArr.length;
                int i3 = 0;
                while (true) {
                    if (i3 >= length) {
                        break;
                    }
                    if (this.inputAllFields.contains(strArr[i3] + "_" + str5)) {
                        z = true;
                        break;
                    }
                    i3++;
                }
                if (z) {
                    for (String str6 : EcpmMutator.ECPM_METRICS) {
                        this.inputAllFields.add(str6 + "_" + str5);
                    }
                }
            }
            for (String[] strArr2 : CommonMutator.BID_DENSITY_SUM_FLOAT) {
                if (this.inputAllFields.contains(strArr2[0])) {
                    this.inputAllFields.add(strArr2[1]);
                    this.inputAllFields.add(strArr2[2]);
                }
            }
            for (String[] strArr3 : CommonMutator.BID_DENSITY_SUM_INT) {
                if (this.inputAllFields.contains(strArr3[0])) {
                    this.inputAllFields.add(strArr3[1]);
                    this.inputAllFields.add(strArr3[2]);
                }
            }
            for (String[] strArr4 : CommonMutator.BID_DENSITY_DIV) {
                if (this.inputAllFields.contains(strArr4[0])) {
                    this.inputAllFields.add(strArr4[1]);
                    this.inputAllFields.add(strArr4[2]);
                }
            }
            HashSet hashSet6 = new HashSet(hashSet2);
            hashSet6.remove(Const.FLR);
            this.mutators.add(new TypeMutator("fix_types", hashSet, hashSet6));
        } catch (ConfigException e6) {
            throw new FloorPredictorException(String.format("bad %s: corrupted other/cuts.csv", configProvider), FloorPredictorProblemType.BAD_BUNDLE_BAD_RESORCE, e6);
        } catch (IOException e7) {
            throw new FloorPredictorException(String.format("bad %s: no resource other/cuts.csv inside", configProvider), FloorPredictorProblemType.BAD_BUNDLE_SOMETHING_MISSING, e7);
        }
    }

    private void raiseErrorOnIntersection(Set<String> set, Set<String> set2, String str) {
        HashSet intersect = CollectionUtils.intersect(set, set2);
        if (!intersect.isEmpty()) {
            throw new IllegalArgumentException(String.format("%s are intersecting, common fields: %s", str, CollectionUtils.toStr(intersect)));
        }
    }

    private void raiseOnMissingFieldsInFeatureRecord(FeatureRecord featureRecord) throws FloorPredictorException {
        HashSet minus = CollectionUtils.minus(getInputAllFields(), featureRecord.getKeys());
        if (!minus.isEmpty()) {
            throw new FloorPredictorException(String.format("Wrong input. Missing keys: %s", minus), FloorPredictorProblemType.FEATURE_MISSING);
        }
    }

    public double predictBestFloor(Map<String, Object> map) throws FloorPredictorException {
        return predictBestFloor(map, this.mlParams);
    }

    public double predictBestFloor(Map<String, Object> map, String str) throws FloorPredictorException {
        return predictBestFloor(map, MLParamsBuilder.get(str));
    }

    public double predictBestFloor(Map<String, Object> map, MLParams mLParams) throws FloorPredictorException {
        if (map.containsKey(Const.FLR)) {
            throw new IllegalArgumentException("flr must not be in the keys when calling predictBestFloor()");
        }
        FeatureRecord featureRecord = new FeatureRecord(map);
        for (Mutator mutator : this.mutators) {
            try {
                featureRecord = mutator.mutate(featureRecord);
            } catch (MutatorException e) {
                throw new FloorPredictorException(String.format("Mutator %s failed", mutator.getId()), FloorPredictorProblemType.MUTATOR_FAILURE, e);
            }
        }
        if (featureRecord.isStopped()) {
            throw new FloorPredictorException(String.format("One of mutators stopped: " + featureRecord.getStopInfo(), new Object[0]), FloorPredictorProblemType.MUTATOR_STOPPED);
        }
        featureRecord.put(Const.FLR, Float.valueOf(0.0f));
        raiseOnMissingFieldsInFeatureRecord(featureRecord);
        double[] sigmoidArray = sigmoidArray(predictArrayByModel(featureRecord, Const.HAS_GOOD_BIDS));
        double[] sigmoidArray2 = sigmoidArray(predictArrayByModel(featureRecord, Const.IS_SPEND_NURL));
        double[] predictArrayByModel = predictArrayByModel(featureRecord, Const.SPEND_NURL);
        FloorSelectionStrategy floorSelectionStrategy = mLParams.floorSelectionStrategy();
        PostprocessingParams postprocessingParams = mLParams.postprocessingParams();
        switch (floorSelectionStrategy) {
            case SMART_WIN:
                return Postprocessing.calcSmartWin(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel, postprocessingParams);
            case SMART_SPEND:
                return Postprocessing.calcSmartSpend(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel, postprocessingParams);
            case REACH_WIN_PROB:
                return Postprocessing.calcReachWinProb(this.cuts, sigmoidArray2, postprocessingParams.medProbThreshold());
            case PLAIN_WIN:
                return Postprocessing.calcPlainWin(this.cuts, sigmoidArray, sigmoidArray2);
            case PLAIN_SPEND:
                return Postprocessing.calcPlainSpend(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel);
            case MIDDLE_SMART_WIN_SMART_SPEND:
                return Postprocessing.calcMiddleSmartSpendSmartWin(this.cuts, sigmoidArray, sigmoidArray2, predictArrayByModel, postprocessingParams);
            default:
                throw new NotImplementedException(String.format("%s is not implemented yet", floorSelectionStrategy));
        }
    }

    /* JADX WARN: Type inference failed for: r0v19, types: [float[], float[][]] */
    /* JADX WARN: Type inference failed for: r0v28, types: [java.lang.Object[], java.lang.String[], java.lang.String[][]] */
    private double[] predictArrayByModel(FeatureRecord featureRecord, String str) throws FloorPredictorException {
        CatBoostModel catBoostModel = this.catBoostModels.get(str);
        ModelFeaturesConfig modelFeaturesConfig = this.modelFeaturesConfigs.get(str);
        try {
            float[] floatValues = featureRecord.getFloatValues(modelFeaturesConfig.getContVars());
            try {
                String[] stringValues = featureRecord.getStringValues(modelFeaturesConfig.getCatVars());
                int length = this.cuts.length;
                ?? r0 = new float[length];
                int intValue = this.indexOfFloorByModel.get(str).intValue();
                for (int i = 0; i < length; i++) {
                    float[] copyOf = Arrays.copyOf(floatValues, floatValues.length);
                    copyOf[intValue] = this.cuts[i];
                    r0[i] = copyOf;
                }
                ?? r02 = new String[length];
                Arrays.fill((Object[]) r02, stringValues);
                try {
                    double[] copyRowMajorPredictions = catBoostModel.predict((float[][]) r0, (String[][]) r02).copyRowMajorPredictions();
                    if (Const.SPEND_NURL.equals(str)) {
                        for (int i2 = 0; i2 < copyRowMajorPredictions.length; i2++) {
                            copyRowMajorPredictions[i2] = Math.exp(copyRowMajorPredictions[i2]);
                        }
                    }
                    return copyRowMajorPredictions;
                } catch (CatBoostError e) {
                    throw new FloorPredictorException(String.format("Failed to do CatBoost prediction for model %s", str), FloorPredictorProblemType.CATBOOST_FAILED_TO_PREDICT, e);
                }
            } catch (IllegalArgumentException e2) {
                throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", str), FloorPredictorProblemType.FEATURE_MISSING, e2);
            }
        } catch (ClassCastException e3) {
            throw new FloorPredictorException(String.format("Model %s, input has a feature which can't be converted to float32", str), FloorPredictorProblemType.FEATURE_NUMERICAL_NOT_FLOAT32, e3);
        } catch (IllegalArgumentException e4) {
            throw new FloorPredictorException(String.format("Model %s requires feature, but it is missing in the input", str), FloorPredictorProblemType.FEATURE_MISSING, e4);
        }
    }

    public static double[] sigmoidArray(double[] dArr) {
        double[] dArr2 = new double[dArr.length];
        for (int i = 0; i < dArr2.length; i++) {
            dArr2[i] = 1.0d / (1.0d + Math.exp(-dArr[i]));
        }
        return dArr2;
    }

    public HashSet<String> getInputAllFields() {
        return this.inputAllFields;
    }

    public List<String> getFieldsToBeSubmittedByBackend() {
        HashSet hashSet = new HashSet(this.inputAllFields);
        Iterator<Mutator> it = this.mutators.iterator();
        while (it.hasNext()) {
            hashSet.addAll(it.next().infoInputFeatures());
        }
        Iterator<Mutator> it2 = this.mutators.iterator();
        while (it2.hasNext()) {
            hashSet.removeAll(it2.next().infoCalcFeatures());
        }
        ArrayList arrayList = new ArrayList(hashSet);
        Collections.sort(arrayList);
        return arrayList;
    }

    public void setMLParams(MLParams mLParams) {
        this.mlParams = mLParams;
    }

    public void setMLParams(String str) {
        this.mlParams = MLParamsBuilder.get(str);
    }
}
