package quipu.maxent;

import cern.colt.function.DoubleFunction;
import cern.colt.function.IntDoubleProcedure;
import cern.colt.list.IntArrayList;
import cern.colt.map.OpenIntDoubleHashMap;
import cern.colt.map.OpenIntIntHashMap;
import java.io.BufferedWriter;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:quipu/maxent/GIS.class */
/**
 * Trains a conditional maximum-entropy model with Generalized Iterative Scaling (GIS) and
 * writes the resulting model to disk.
 *
 * <p>All training state lives in static fields, so this class is not thread-safe: at most one
 * training run may be in progress at a time. State that must not leak between runs
 * ({@code constant}, {@code correctionParam}) is reset at the start of every
 * {@link #trainModel(String, String, DataIndexer, int)} call.
 *
 * <p>Two files are written for a model named {@code NAME}:
 * <ul>
 *   <li>{@code NAME.mei.gz}: gzipped text containing the line {@code GIS}, the GIS constant,
 *       the correction parameter, the outcome count and labels, and the number of groups of
 *       predicates that compare equal. For each such group it holds the group size followed by
 *       the group head's {@code toString()}. It ends with the predicate count and the predicate
 *       labels in sorted order.</li>
 *   <li>{@code NAME.mep.gz}: gzipped binary written with {@link DataOutputStream#writeDouble}.
 *       It holds every predicate's parameter values, in the same sorted predicate order.</li>
 * </ul>
 */
public class GIS {
    /** Substituted for a zero observed correction-feature count so that its logarithm is finite. */
    private static final double NEAR_ZERO = 0.01d;

    // Sizes of the training data.
    private static int numTokens;
    private static int numPreds;
    private static int numOutcomes;

    // Current event (TID) and predicate (PID) indices, read by the procedure objects below.
    private static int TID;
    private static int PID;
    // Sum of the unnormalised outcome scores for event TID, used to normalise pabi[TID].
    private static double PABISUM;

    // contexts[t]: indices of the predicates active in event t.
    private static int[][] contexts;
    // numTimesEventsSeen[t]: how many times event t occurred in the training data.
    private static int[] numTimesEventsSeen;
    private static String[] outcomeLabels;
    private static String[] predLabels;
    // observedExpects[p]: outcome -> log of the observed count of (predicate p, outcome).
    private static OpenIntDoubleHashMap[] observedExpects;
    // params[p]: outcome -> current weight of the (predicate p, outcome) feature.
    private static OpenIntDoubleHashMap[] params;
    // modifiers[p]: outcome -> model-expected count of (predicate p, outcome) in this iteration.
    private static OpenIntDoubleHashMap[] modifiers;
    private static double constantInverse;
    // Log of the observed count of the correction feature.
    private static double cfObservedExpect;
    // Model-expected count of the correction feature in the current iteration.
    private static double CFMOD;
    // cfvals[t]: outcome -> value of the correction feature for event t.
    private static OpenIntIntHashMap[] cfvals;
    // pabi[t]: outcome -> current model probability of that outcome given event t.
    private static OpenIntDoubleHashMap[] pabi;
    // GIS constant: the largest number of active predicates in any event (at least 1).
    private static int constant = 1;
    // Weight of the correction feature.
    private static double correctionParam = 0.0d;

    /** Maps every value to zero (used to clear accumulators). */
    private static DoubleFunction backToZeros = new DoubleFunction() {
        public double apply(double d) {
            return 0.0d;
        }
    };

    /** Divides a score by {@code PABISUM}, turning pabi[TID] into a distribution. */
    private static DoubleFunction normalizePABI = new DoubleFunction() {
        public double apply(double d) {
            return d / GIS.PABISUM;
        }
    };

    /** Adds one predicate's weight for {@code outcome} to pabi[TID][outcome]. */
    private static IntDoubleProcedure addParamsToPABI = new IntDoubleProcedure() {
        public boolean apply(int outcome, double weight) {
            GIS.pabi[GIS.TID].put(outcome, GIS.pabi[GIS.TID].get(outcome) + weight);
            return true;
        }
    };

    /** Adds the weighted correction feature, exponentiates, and accumulates PABISUM. */
    private static IntDoubleProcedure addCorrectionToPABIandExponentiate = new IntDoubleProcedure() {
        public boolean apply(int outcome, double score) {
            double exp = Math.exp(score + (GIS.correctionParam * GIS.cfvals[GIS.TID].get(outcome)));
            GIS.PABISUM += exp;
            GIS.pabi[GIS.TID].put(outcome, exp);
            return true;
        }
    };

    /** Adds event TID's expected count for (PID, outcome) to modifiers[PID][outcome]. */
    private static IntDoubleProcedure updateModifiers = new IntDoubleProcedure() {
        public boolean apply(int outcome, double current) {
            GIS.modifiers[GIS.PID].put(outcome,
                    current + (GIS.pabi[GIS.TID].get(outcome) * GIS.numTimesEventsSeen[GIS.TID]));
            return true;
        }
    };

    /** GIS update: param += (1/C) * (log observed - log expected). */
    private static IntDoubleProcedure updateParams = new IntDoubleProcedure() {
        public boolean apply(int outcome, double current) {
            GIS.params[GIS.PID].put(outcome, current + (GIS.constantInverse
                    * (GIS.observedExpects[GIS.PID].get(outcome) - Math.log(GIS.modifiers[GIS.PID].get(outcome)))));
            return true;
        }
    };

    /** Adds event TID's expected correction-feature count for {@code outcome} to CFMOD. */
    private static IntDoubleProcedure updateCorrectionFeatureModifier = new IntDoubleProcedure() {
        public boolean apply(int outcome, double prob) {
            GIS.CFMOD += prob * GIS.cfvals[GIS.TID].get(outcome) * GIS.numTimesEventsSeen[GIS.TID];
            return true;
        }
    };

    /**
     * Trains a model and writes it to {@code str + ".mei.gz"} and {@code str + ".mep.gz"}.
     *
     * @param str model name (used with an empty prefix)
     * @param dataIndexer indexed training events
     * @param i number of GIS iterations
     */
    public static void trainModel(String str, DataIndexer dataIndexer, int i) {
        trainModel("", str, dataIndexer, i);
    }

    /**
     * Trains a model on the indexed events and writes it to disk under the name
     * {@code str + str2}.
     *
     * <p>An {@link IOException} while writing is reported via {@code printStackTrace()} and is
     * not propagated, because this signature does not declare it.
     *
     * @param str prefix prepended to the model name (for example a directory path)
     * @param str2 model name
     * @param dataIndexer indexed training events
     * @param i number of GIS iterations
     */
    public static void trainModel(String str, String str2, DataIndexer dataIndexer, int i) {
        // Reset state that would otherwise carry over from an earlier run in this JVM.
        constant = 1;
        correctionParam = 0.0d;

        System.out.println("Incorporating indexed data for training...  ");
        loadData(dataIndexer);
        System.out.println("...done.");

        System.out.print("Computing correction feature matrix... ");
        computeCorrectionFeatures();
        System.out.println("done.");

        System.out.println("Computing model parameters...");
        findParameters(i);

        String modelName = str + str2;
        System.out.println("Writing model to disk: " + modelName);
        try {
            writeModel(modelName);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Copies the indexed events into the static fields, computes the GIS constant, and
     * initialises params, modifiers and observed expectations for every observed
     * (predicate, outcome) pair.
     */
    private static void loadData(DataIndexer dataIndexer) {
        contexts = dataIndexer.contexts;
        numTimesEventsSeen = dataIndexer.numTimesEventsSeen;
        numTokens = contexts.length;
        for (TID = 0; TID < numTokens; TID++) {
            if (contexts[TID].length > constant) {
                constant = contexts[TID].length;
            }
        }
        constantInverse = 1.0d / constant;

        outcomeLabels = dataIndexer.outcomeLabels;
        numOutcomes = outcomeLabels.length;
        predLabels = dataIndexer.predLabels;
        numPreds = predLabels.length;
        System.out.println("\tNumber of Event Tokens: " + numTokens);
        System.out.println("\t    Number of Outcomes: " + numOutcomes);
        System.out.println("\t  Number of Predicates: " + numPreds);

        // predCount[p][o]: weighted number of events with predicate p and outcome o.
        int[][] predCount = new int[numPreds][numOutcomes];
        for (int t = 0; t < numTokens; t++) {
            int outcome = dataIndexer.outcomeList[t];
            int[] context = contexts[t];
            for (int j = 0; j < context.length; j++) {
                predCount[context[j]][outcome] += numTimesEventsSeen[t];
            }
        }

        params = new OpenIntDoubleHashMap[numPreds];
        modifiers = new OpenIntDoubleHashMap[numPreds];
        observedExpects = new OpenIntDoubleHashMap[numPreds];
        for (int p = 0; p < numPreds; p++) {
            params[p] = new OpenIntDoubleHashMap();
            modifiers[p] = new OpenIntDoubleHashMap();
            observedExpects[p] = new OpenIntDoubleHashMap();
            // Only pairs seen in training get a feature.
            for (int o = 0; o < numOutcomes; o++) {
                if (predCount[p][o] > 0) {
                    params[p].put(o, 0.0d);
                    modifiers[p].put(o, 0.0d);
                    observedExpects[p].put(o, Math.log(predCount[p][o]));
                }
            }
            params[p].trimToSize();
            modifiers[p].trimToSize();
            observedExpects[p].trimToSize();
        }
    }

    /**
     * Builds the per-event correction-feature values and the pabi maps, and computes the
     * observed expectation of the correction feature.
     */
    private static void computeCorrectionFeatures() {
        cfvals = new OpenIntIntHashMap[numTokens];
        pabi = new OpenIntDoubleHashMap[numTokens];
        for (int t = 0; t < numTokens; t++) {
            OpenIntIntHashMap cf = new OpenIntIntHashMap();
            OpenIntDoubleHashMap probs = new OpenIntDoubleHashMap();
            // For each outcome, count the event's active predicates that have a feature for it.
            int[] context = contexts[t];
            for (int j = 0; j < context.length; j++) {
                IntArrayList outcomes = params[context[j]].keys();
                for (int k = 0; k < outcomes.size(); k++) {
                    int outcome = outcomes.get(k);
                    if (cf.containsKey(outcome)) {
                        cf.put(outcome, cf.get(outcome) + 1);
                    } else {
                        cf.put(outcome, 1);
                        probs.put(outcome, 0.0d);
                    }
                }
            }
            // The correction feature tops each count up to the GIS constant.
            IntArrayList outcomes = cf.keys();
            for (int k = 0; k < outcomes.size(); k++) {
                int outcome = outcomes.get(k);
                cf.put(outcome, constant - cf.get(outcome));
            }
            cf.trimToSize();
            probs.trimToSize();
            cfvals[t] = cf;
            pabi[t] = probs;
        }

        // long: the weighted sum can exceed Integer.MAX_VALUE on large corpora.
        long cfvalSum = 0L;
        for (int t = 0; t < numTokens; t++) {
            cfvalSum += (long) (constant - contexts[t].length) * numTimesEventsSeen[t];
        }
        // A zero count (every event already has `constant` predicates, or there are no events)
        // would give log(0) = -Infinity. That makes correctionParam infinite, and multiplying it
        // by a zero feature value then yields NaN probabilities.
        cfObservedExpect = Math.log(cfvalSum == 0L ? NEAR_ZERO : (double) cfvalSum);
    }

    /**
     * Runs {@code i} GIS iterations, printing a right-aligned iteration number before each.
     * It then releases the training-only state that is not needed for writing the model.
     */
    private static void findParameters(int i) {
        System.out.println("Performing " + i + " iterations.");
        for (int iter = 1; iter <= i; iter++) {
            if (iter < 10) {
                System.out.print("  " + iter + ":  ");
            } else if (iter < 100) {
                System.out.print(" " + iter + ":  ");
            } else {
                System.out.print(iter + ":  ");
            }
            nextIteration();
        }
        observedExpects = null;
        pabi = null;
        modifiers = null;
        cfvals = null;
        numTimesEventsSeen = null;
        contexts = null;
    }

    /** Performs one GIS update of every feature weight and of the correction parameter. */
    private static void nextIteration() {
        // Pass 1: model distribution for each event, and the correction feature's expectation.
        CFMOD = 0.0d;
        for (TID = 0; TID < numTokens; TID++) {
            pabi[TID].assign(backToZeros);
            for (int j = 0; j < contexts[TID].length; j++) {
                params[contexts[TID][j]].forEachPair(addParamsToPABI);
            }
            PABISUM = 0.0d;
            pabi[TID].forEachPair(addCorrectionToPABIandExponentiate);
            if (PABISUM > 0.0d) {
                pabi[TID].assign(normalizePABI);
            }
            pabi[TID].forEachPair(updateCorrectionFeatureModifier);
        }
        System.out.print(".");

        // Pass 2: model expectations of each (predicate, outcome) feature.
        for (TID = 0; TID < numTokens; TID++) {
            for (int j = 0; j < contexts[TID].length; j++) {
                PID = contexts[TID][j];
                modifiers[PID].forEachPair(updateModifiers);
            }
        }
        System.out.print(".");

        // Pass 3: apply the updates and clear the accumulators for the next iteration.
        for (PID = 0; PID < numPreds; PID++) {
            params[PID].forEachPair(updateParams);
            modifiers[PID].assign(backToZeros);
        }
        // If the correction feature never fired, log(0) would make correctionParam infinite.
        if (CFMOD > 0.0d) {
            correctionParam += constantInverse * (cfObservedExpect - Math.log(CFMOD));
        }
        System.out.println(".");
    }

    /**
     * Opens {@code path} for writing through a gzip stream. The underlying file is closed if
     * the gzip header cannot be written.
     */
    private static GZIPOutputStream openGzip(String path) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(path);
        try {
            return new GZIPOutputStream(fileOut);
        } catch (IOException e) {
            fileOut.close();
            throw e;
        }
    }

    /**
     * Writes the trained model to {@code str + ".mei.gz"} (text) and {@code str + ".mep.gz"}
     * (binary parameters). Both streams are closed even if a write fails.
     */
    private static void writeModel(String str) throws IOException {
        ComparablePredicate[] sorted = new ComparablePredicate[numPreds];
        for (int p = 0; p < numPreds; p++) {
            IntArrayList outcomes = params[p].keys();
            outcomes.sort();
            int size = outcomes.size();
            int[] outcomeIds = new int[size];
            double[] values = new double[size];
            for (int k = 0; k < size; k++) {
                outcomeIds[k] = outcomes.get(k);
                values[k] = params[p].get(outcomeIds[k]);
            }
            sorted[p] = new ComparablePredicate(predLabels[p], outcomeIds, values);
        }
        Arrays.sort(sorted);

        // Group consecutive predicates that compare equal to the head of the current group.
        ArrayList groups = new ArrayList();
        if (sorted.length > 0) {
            ComparablePredicate head = sorted[0];
            ArrayList group = new ArrayList();
            for (int j = 0; j < sorted.length; j++) {
                if (head.compareTo(sorted[j]) != 0) {
                    groups.add(group);
                    group = new ArrayList();
                    head = sorted[j];
                }
                group.add(sorted[j]);
            }
            groups.add(group);
        }

        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(openGzip(str + ".mei.gz")));
        try {
            writer.write("GIS");
            writer.newLine();
            writer.write(Integer.toString(constant));
            writer.newLine();
            writer.write(Double.toString(correctionParam));
            writer.newLine();
            writer.write(Integer.toString(numOutcomes));
            writer.newLine();
            for (int o = 0; o < numOutcomes; o++) {
                writer.write(outcomeLabels[o]);
                writer.newLine();
            }
            writer.write(Integer.toString(groups.size()));
            writer.newLine();
            for (int g = 0; g < groups.size(); g++) {
                ArrayList group = (ArrayList) groups.get(g);
                writer.write(Integer.toString(group.size()) + ((ComparablePredicate) group.get(0)).toString());
                writer.newLine();
            }
            writer.write(Integer.toString(numPreds));
            writer.newLine();
            for (int p = 0; p < sorted.length; p++) {
                writer.write(sorted[p].name);
                writer.newLine();
            }
        } finally {
            // close() flushes the buffer and finishes the gzip stream.
            writer.close();
        }

        DataOutputStream dataOut = new DataOutputStream(openGzip(str + ".mep.gz"));
        try {
            for (int p = 0; p < sorted.length; p++) {
                double[] values = sorted[p].params;
                for (int k = 0; k < values.length; k++) {
                    dataOut.writeDouble(values[k]);
                }
            }
        } finally {
            dataOut.close();
        }
    }
}
