/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evomodel.substmodel.BaseSubstitutionModel;
import dr.evomodel.substmodel.ComplexSubstitutionModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;

/**
 * Static helpers for differentiating substitution-model transition probabilities with respect to
 * model parameters.
 *
 * <p>All flat {@code double[]} matrices handled here are square, of size {@code stateCount *
 * stateCount}, and indexed as {@code row * stateCount + column}. The eigenvector arrays obtained
 * from {@link EigenDecomposition} are interpreted with that same layout.
 */
public class DifferentiableSubstitutionModelUtil {

    /**
     * Magnitude below which this class treats a value as zero. It is used when cleaning
     * eigenbasis entries and when locating the zero eigenvalue.
     */
    static final double threshold = 1.0E-10;

    /** Debug switch for commutability checks; not referenced anywhere in this class. */
    private static final boolean CHECK_COMMUTABILITY = false;

    /**
     * Largest element-wise relative difference between AB and BA for which
     * {@link #checkCommutability} still reports that A and B commute.
     */
    private static final double COMMUTABILITY_CHECK_THRESHOLD = 0.01;

    /**
     * First-order approximation of the differential mass: every entry of the differential matrix
     * is multiplied by {@code time}.
     *
     * @param time               scalar multiplier (branch length / time)
     * @param differentialMatrix matrix to scale; not modified
     * @return a new array of length {@code differentialMatrix.getDim()}, holding the scaled
     *         entries in the matrix's flat index order
     */
    static double[] getApproximateDifferentialMassMatrix(double time, WrappedMatrix differentialMatrix) {
        final int length = differentialMatrix.getDim();
        final double[] result = new double[length];
        for (int k = 0; k < length; ++k) {
            result[k] = time * differentialMatrix.get(k);
        }
        return result;
    }

    /**
     * Exact differential mass computed in the eigenbasis of the rate matrix.
     *
     * <p>Steps:
     * <ol>
     *   <li>X = V^{-1} D V, where D is {@code differentialMatrix} and V holds the eigenvectors.</li>
     *   <li>Entries of X with magnitude below {@link #threshold} are set to zero.</li>
     *   <li>Each entry X_ij is scaled by {@code time} when i == j or the eigenvalues are equal.
     *       Otherwise it is scaled by (1 - exp((λ_j - λ_i) t)) / (λ_i - λ_j).</li>
     *   <li>The result is mapped back with V X V^{-1}.</li>
     * </ol>
     *
     * <p>NOTE(review): assuming D = dQ/dθ and the decomposition is of Q, this is the eigenbasis
     * form of P(t)^{-1} dP(t)/dθ. Confirm against callers.
     *
     * <p><b>Side effect:</b> {@code differentialMatrix} is overwritten in place with the final
     * matrix.
     *
     * @param time               time t in the exponent
     * @param differentialMatrix square matrix D; overwritten
     * @param eigenDecomposition eigen decomposition whose first {@code stateCount} eigenvalues
     *                           and {@code stateCount x stateCount} (inverse) eigenvectors are used
     * @return a new flat copy of the resulting {@code stateCount x stateCount} matrix
     */
    static double[] getExactDifferentialMassMatrix(double time, WrappedMatrix differentialMatrix,
                                                   EigenDecomposition eigenDecomposition) {
        final int stateCount = differentialMatrix.getMajorDim();
        final double[] eigenValues = eigenDecomposition.getEigenValues();
        final WrappedMatrix.Raw eigenVectors =
                new WrappedMatrix.Raw(eigenDecomposition.getEigenVectors(), 0, stateCount, stateCount);
        final WrappedMatrix.Raw inverseEigenVectors =
                new WrappedMatrix.Raw(eigenDecomposition.getInverseEigenVectors(), 0, stateCount, stateCount);

        // Move into the eigenbasis: D <- V^{-1} D V.
        getTripleMatrixMultiplication(stateCount, inverseEigenVectors, differentialMatrix, eigenVectors);
        setZeros(differentialMatrix);

        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                final double entry = differentialMatrix.get(i, j);
                final double scaled;
                if (i == j || eigenValues[i] == eigenValues[j]) {
                    // Limit of the divided difference when the eigenvalues coincide.
                    scaled = entry * time;
                } else if (entry == 0.0) {
                    // Skip the exp() for entries already cleaned to zero.
                    scaled = 0.0;
                } else {
                    scaled = entry * (1.0 - Math.exp((eigenValues[j] - eigenValues[i]) * time))
                            / (eigenValues[i] - eigenValues[j]);
                }
                differentialMatrix.set(i, j, scaled);
            }
        }

        // Back to the original basis: D <- V D V^{-1}.
        getTripleMatrixMultiplication(stateCount, eigenVectors, differentialMatrix, inverseEigenVectors);

        final int length = stateCount * stateCount;
        final double[] result = new double[length];
        for (int k = 0; k < length; ++k) {
            result[k] = differentialMatrix.get(k);
        }
        return result;
    }

    /**
     * Affine-corrected approximation of the differential mass.
     *
     * <p>Let A = {@code time} * D, computed by {@link #getApproximateDifferentialMassMatrix}, and
     * let P = QQ^+ be built from every eigenpair except the zero eigenvalue. The method returns
     * A - (I - P) A P.
     *
     * <p>If no eigenvalue is (numerically) zero, no index is excluded. P then equals V V^{-1}, and
     * the correction is numerically negligible.
     *
     * @param time               scalar multiplier
     * @param differentialMatrix square matrix D; not modified
     * @param eigenDecomposition eigen decomposition supplying eigenvalues and (inverse) eigenvectors
     * @return a new flat {@code stateCount x stateCount} array
     */
    static double[] getAffineDifferentialMassMatrix(double time, WrappedMatrix differentialMatrix,
                                                    EigenDecomposition eigenDecomposition) {
        final double[] result = getApproximateDifferentialMassMatrix(time, differentialMatrix);
        final int stateCount = differentialMatrix.getMajorDim();
        assert stateCount == differentialMatrix.getMinorDim();

        final int zeroIndex = findZeroEigenvalueIndex(eigenDecomposition.getEigenValues(), stateCount);
        final double[] qQPlus = getQQPlus(eigenDecomposition.getEigenVectors(),
                eigenDecomposition.getInverseEigenVectors(), zeroIndex, stateCount);
        final double[] oneMinusQQPlus = getOneMinusQPlusQ(qQPlus, stateCount);

        // correction = (I - P) * A * P
        final double[] leftProduct = new double[stateCount * stateCount];
        multiply(leftProduct, oneMinusQQPlus, result, 1.0, stateCount);
        final double[] correction = new double[stateCount * stateCount];
        multiply(correction, leftProduct, qQPlus, 1.0, stateCount);

        for (int k = 0; k < result.length; ++k) {
            result[k] -= correction[k];
        }
        return result;
    }

    /**
     * Computes {@code out = scale * left * right} for flat row-major square matrices.
     *
     * @param out        destination, length {@code dim * dim}; must not alias the inputs
     * @param left       left operand
     * @param right      right operand
     * @param scale      scalar applied to every entry of the product
     * @param dim        matrix dimension
     */
    private static void multiply(double[] out, double[] left, double[] right, double scale, int dim) {
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                double sum = 0.0;
                for (int k = 0; k < dim; ++k) {
                    sum += left[i * dim + k] * right[k * dim + j];
                }
                out[i * dim + j] = scale * sum;
            }
        }
    }

    /** Replaces every entry whose magnitude is below {@link #threshold} with exactly 0.0. */
    private static void setZeros(WrappedMatrix matrix) {
        final int rows = matrix.getMajorDim();
        final int columns = matrix.getMinorDim();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < columns; ++j) {
                if (Math.abs(matrix.get(i, j)) < threshold) {
                    matrix.set(i, j, 0.0);
                }
            }
        }
    }

    /**
     * In-place triple product: {@code middle <- left * middle * right}, all {@code dim x dim}.
     *
     * @param dim    matrix dimension
     * @param left   left factor (read only)
     * @param middle middle factor; overwritten with the result
     * @param right  right factor (read only)
     */
    private static void getTripleMatrixMultiplication(int dim, ReadableMatrix left, WrappedMatrix middle,
                                                      ReadableMatrix right) {
        // Temporary buffer for middle * right, so middle can be overwritten safely afterwards.
        final double[][] middleTimesRight = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                double sum = 0.0;
                for (int k = 0; k < dim; ++k) {
                    sum += middle.get(i, k) * right.get(k, j);
                }
                middleTimesRight[i][j] = sum;
            }
        }
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                double sum = 0.0;
                for (int k = 0; k < dim; ++k) {
                    sum += left.get(i, k) * middleTimesRight[k][j];
                }
                middle.set(i, j, sum);
            }
        }
    }

    /**
     * Builds the derivative of the infinitesimal (rate) matrix with respect to {@code wrt}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Assemble a matrix from the model's differential rates and frequencies.</li>
     *   <li>Pass it through {@code makeValid}.</li>
     *   <li>Subtract {@code Q_ij * g}, where Q is the model's infinitesimal matrix and g is the
     *       model's weighted normalization gradient.</li>
     * </ol>
     *
     * @param wrt               parameter to differentiate with respect to
     * @param substitutionModel model; must implement {@link DifferentiableSubstitutionModel}
     * @return the {@code stateCount x stateCount} derivative matrix
     * @throws UnsupportedOperationException if the model is not differentiable
     */
    public static WrappedMatrix getInfinitesimalDifferentialMatrix(
            DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt,
            BaseSubstitutionModel substitutionModel) {
        if (!(substitutionModel instanceof DifferentiableSubstitutionModel)) {
            throw new UnsupportedOperationException("Not supported!");
        }
        final DifferentiableSubstitutionModel differentiableModel =
                (DifferentiableSubstitutionModel) substitutionModel;

        // Value returned by setupMatrix(); it is forwarded to setupDifferentialRates.
        // NOTE(review): presumably the rate normalization factor — confirm.
        final double setupValue = substitutionModel.setupMatrix();
        final int stateCount = substitutionModel.getDataType().getStateCount();
        final int rateCount = substitutionModel.getRateCount(stateCount);

        final double[] infinitesimalMatrix = new double[stateCount * stateCount];
        substitutionModel.getInfinitesimalMatrix(infinitesimalMatrix);

        final double[] differentialRates = new double[rateCount];
        differentiableModel.setupDifferentialRates(wrt, differentialRates, setupValue);

        final double[] differentialFrequencies = new double[stateCount];
        differentiableModel.setupDifferentialFrequency(wrt, differentialFrequencies);

        final double[][] differentialMatrix = new double[stateCount][stateCount];
        setupQDerivative(substitutionModel, differentialRates, differentialFrequencies, differentialMatrix);
        // NOTE(review): makeValid presumably fills the diagonal so rows sum to zero — confirm.
        substitutionModel.makeValid(differentialMatrix, stateCount);

        final double normalizationGradient = differentiableModel.getWeightedNormalizationGradient(
                wrt, differentialMatrix, differentialFrequencies);
        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                differentialMatrix[i][j] -= infinitesimalMatrix[i * stateCount + j] * normalizationGradient;
            }
        }
        return new WrappedMatrix.ArrayOfArray(differentialMatrix);
    }

    /**
     * Fills the off-diagonal entries of {@code differentialMatrix} from rates and frequencies.
     *
     * <p>For a {@link ComplexSubstitutionModel}, the layout is:
     * <ul>
     *   <li>Rates are consumed first for the upper triangle, row by row.</li>
     *   <li>They are then consumed for the lower triangle, column by column.</li>
     *   <li>Entry (row, col) receives {@code rate * frequencies[col]}.</li>
     *   <li>The diagonal is left untouched.</li>
     * </ul>
     * Any other model delegates to its own {@code setupQMatrix}.
     */
    private static void setupQDerivative(BaseSubstitutionModel substitutionModel, double[] differentialRates,
                                         double[] differentialFrequencies, double[][] differentialMatrix) {
        if (!(substitutionModel instanceof ComplexSubstitutionModel)) {
            substitutionModel.setupQMatrix(differentialRates, differentialFrequencies, differentialMatrix);
            return;
        }
        final int stateCount = differentialFrequencies.length;
        int rateIndex = 0;
        // Upper triangle, row-major.
        for (int row = 0; row < stateCount; ++row) {
            for (int col = row + 1; col < stateCount; ++col) {
                differentialMatrix[row][col] = differentialRates[rateIndex++] * differentialFrequencies[col];
            }
        }
        // Lower triangle, column-major (the transposed ordering of the upper triangle).
        for (int col = 0; col < stateCount; ++col) {
            for (int row = col + 1; row < stateCount; ++row) {
                differentialMatrix[row][col] = differentialRates[rateIndex++] * differentialFrequencies[col];
            }
        }
    }

    /**
     * Tests whether two matrices approximately commute.
     *
     * <p>For each entry, the relative difference |2(AB - BA) / (AB + BA)| is compared against
     * {@link #COMMUTABILITY_CHECK_THRESHOLD}. Entries where the relative difference is NaN
     * (both products zero) are treated as commuting.
     *
     * @param first  matrix A
     * @param second matrix B
     * @return {@code true} if every entry is within the threshold. Returns {@code false} if any
     *         entry exceeds it, or if AB and BA are not both defined with the same size.
     */
    public static boolean checkCommutability(WrappedMatrix first, WrappedMatrix second) {
        final WrappedMatrix firstSecond = product(first, second);
        final WrappedMatrix secondFirst = product(second, first);
        if (firstSecond == null || secondFirst == null || firstSecond.getDim() != secondFirst.getDim()) {
            // Non-conformable (or non-square) matrices cannot commute.
            return false;
        }
        for (int k = 0; k < firstSecond.getDim(); ++k) {
            final double ab = firstSecond.get(k);
            final double ba = secondFirst.get(k);
            if (Math.abs(2.0 * (ab - ba) / (ab + ba)) > COMMUTABILITY_CHECK_THRESHOLD) {
                return false;
            }
        }
        return true;
    }

    /**
     * Dense matrix product {@code left * right}.
     *
     * @return a new matrix, or {@code null} when the inner dimensions do not match
     */
    private static WrappedMatrix product(WrappedMatrix left, WrappedMatrix right) {
        final int rows = left.getMajorDim();
        final int columns = right.getMinorDim();
        final int inner = left.getMinorDim();
        if (inner != right.getMajorDim()) {
            return null;
        }
        final WrappedMatrix.Raw result = new WrappedMatrix.Raw(new double[rows * columns], 0, rows, columns);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < columns; ++j) {
                double sum = 0.0;
                for (int k = 0; k < inner; ++k) {
                    sum += left.get(i, k) * right.get(k, j);
                }
                result.set(i, j, sum);
            }
        }
        return result;
    }

    /**
     * Locates the zero eigenvalue among the first {@code stateCount} entries.
     *
     * <p>An exactly-zero eigenvalue is preferred; the first one found is returned. Otherwise the
     * index with the smallest magnitude strictly below {@link #threshold} is returned. This
     * fallback covers eigen solvers that produce a tiny non-zero value such as -1e-16 instead of
     * an exact 0.
     *
     * @return the index, or -1 when no eigenvalue is within {@link #threshold} of zero
     */
    private static int findZeroEigenvalueIndex(double[] eigenValues, int stateCount) {
        int bestIndex = -1;
        double bestMagnitude = threshold;
        for (int i = 0; i < stateCount; ++i) {
            final double magnitude = Math.abs(eigenValues[i]);
            if (magnitude == 0.0) {
                return i;
            }
            if (magnitude < bestMagnitude) {
                bestMagnitude = magnitude;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /** Returns I - M for a flat row-major {@code dim x dim} matrix M. */
    private static double[] getOneMinusQPlusQ(double[] matrix, int dim) {
        final double[] result = new double[dim * dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                final int index = index12(i, j, dim);
                result[index] = (i == j ? 1.0 : 0.0) - matrix[index];
            }
        }
        return result;
    }

    /**
     * Computes sum over k != {@code excludedIndex} of V[:, k] V^{-1}[k, :].
     *
     * <p>Passing -1, or any out-of-range index, excludes nothing.
     *
     * @param eigenVectors        flat row-major V
     * @param inverseEigenVectors flat row-major V^{-1}
     * @param excludedIndex       eigenpair to leave out (typically the zero eigenvalue)
     * @param stateCount          matrix dimension
     * @return a new flat {@code stateCount x stateCount} array
     */
    public static double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVectors,
                                     int excludedIndex, int stateCount) {
        final boolean[] include = new boolean[stateCount];
        for (int k = 0; k < stateCount; ++k) {
            include[k] = k != excludedIndex;
        }
        return sumOfSelectedOuterProducts(eigenVectors, inverseEigenVectors, include, stateCount);
    }

    /**
     * Computes the sum over all eigenpairs whose eigenvalue is not exactly 0.0 of
     * V[:, k] V^{-1}[k, :].
     *
     * @param eigenVectors        flat row-major V
     * @param inverseEigenVectors flat row-major V^{-1}
     * @param eigenValues         eigenvalues; only the first {@code stateCount} are inspected
     * @param stateCount          matrix dimension
     * @return a new flat {@code stateCount x stateCount} array
     */
    public static double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVectors,
                                     double[] eigenValues, int stateCount) {
        final boolean[] include = new boolean[stateCount];
        for (int k = 0; k < stateCount; ++k) {
            include[k] = eigenValues[k] != 0.0;
        }
        return sumOfSelectedOuterProducts(eigenVectors, inverseEigenVectors, include, stateCount);
    }

    /** Shared kernel of both {@code getQQPlus} overloads: sum of the selected rank-one terms. */
    private static double[] sumOfSelectedOuterProducts(double[] eigenVectors, double[] inverseEigenVectors,
                                                       boolean[] include, int stateCount) {
        final double[] result = new double[stateCount * stateCount];
        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                double sum = 0.0;
                for (int k = 0; k < stateCount; ++k) {
                    if (include[k]) {
                        sum += eigenVectors[i * stateCount + k] * inverseEigenVectors[k * stateCount + j];
                    }
                }
                result[i * stateCount + j] = sum;
            }
        }
        return result;
    }

    /** Row-major flat index of (row, column) in a matrix with {@code dim} columns. */
    private static int index12(int row, int column, int dim) {
        return row * dim + column;
    }

    /** Column-major flat index of (row, column) in a matrix with {@code dim} rows; unused in this class. */
    private static int index21(int row, int column, int dim) {
        return column * dim + row;
    }
}

