/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.parsimony;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.Patterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.parsimony.Parsimony;
import dr.evolution.parsimony.ParsimonyCriterion;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;

/**
 * Parsimony criterion that scores patterns on a tree with the Fitch algorithm and can
 * reconstruct one state per node and pattern.
 *
 * <p>For every internal node and pattern the state sets of the children are combined: when the
 * children share at least one state the node receives the intersection, otherwise it receives
 * the union and the score of that pattern is increased by one.
 *
 * <p>Results are cached for the most recently used tree, compared by reference. If that tree
 * instance is modified in place, call {@link #initialize(Tree)} to discard the cached results.
 * The class keeps mutable working state and performs no synchronization.
 */
public class FitchParsimony implements ParsimonyCriterion {
    /** Number of states per state set; one more than the data type's state count when gaps are states. */
    private final int stateCount;
    /** Whether a gap is treated as an additional state, stored in the last slot of each state set. */
    private final boolean gapsAreStates;
    /** State sets indexed as [node number][pattern][state]. */
    private boolean[][][] stateSets;
    /** Reconstructed states indexed as [node number][pattern]. */
    private int[][] states;
    /** The tree the cached tables below were built for, or null before first use. */
    private Tree tree = null;
    private final PatternList patterns;
    private boolean hasCalculatedSteps = false;
    private boolean hasReconstructedStates = false;
    /** Unweighted score per pattern; filled by {@link #getSiteScores(Tree)}. */
    private final double[] siteScores;

    /**
     * @return the pattern list this criterion was constructed with
     */
    public PatternList getPatterns() {
        return patterns;
    }

    /**
     * Creates a criterion for the given patterns.
     *
     * @param patternList the patterns to score; must not be null
     * @param bl          if true, a gap is treated as an extra state (appended after the data
     *                    type's own states); if false, tip state sets are taken from the data
     *                    type unchanged
     * @throws IllegalArgumentException if {@code patternList} is null
     */
    public FitchParsimony(PatternList patternList, boolean bl) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        this.gapsAreStates = bl;
        int baseStateCount = patternList.getDataType().getStateCount();
        this.stateCount = bl ? baseStateCount + 1 : baseStateCount;
        this.patterns = patternList;
        this.siteScores = new double[patternList.getPatternCount()];
    }

    /**
     * Returns the unweighted score of every pattern on the given tree.
     *
     * <p>The returned array is the internal buffer, not a copy: it is overwritten by later
     * calls and must not be modified by the caller.
     *
     * @param tree the tree to score; must not be null
     * @return the per-pattern scores, indexed by pattern
     * @throws IllegalArgumentException if {@code tree} is null
     */
    @Override
    public double[] getSiteScores(Tree tree) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        // A different tree instance (or first use) invalidates everything cached so far.
        if (this.tree != tree) {
            initialize(tree);
        }
        if (!hasCalculatedSteps) {
            for (int i = 0; i < siteScores.length; ++i) {
                siteScores[i] = 0.0;
            }
            calculateSteps(tree, tree.getRoot(), patterns);
            hasCalculatedSteps = true;
        }
        return siteScores;
    }

    /**
     * Returns the total score of the tree: the sum over all patterns of the pattern's score
     * multiplied by the pattern's weight.
     *
     * @param tree the tree to score; must not be null
     * @return the weighted total score
     * @throws IllegalArgumentException if {@code tree} is null
     */
    @Override
    public double getScore(Tree tree) {
        getSiteScores(tree);
        double total = 0.0;
        int patternCount = patterns.getPatternCount();
        for (int i = 0; i < patternCount; ++i) {
            total += siteScores[i] * patterns.getPatternWeight(i);
        }
        return total;
    }

    /**
     * Returns the reconstructed state of every pattern at the given node.
     *
     * <p>The returned array is a row of the internal table, not a copy.
     *
     * @param tree    the tree to reconstruct on; must be non-null and binary
     * @param nodeRef the node whose states are requested; must not be null
     * @return the reconstructed states at {@code nodeRef}, indexed by pattern
     * @throws IllegalArgumentException if an argument is null or the tree is not binary
     */
    @Override
    public int[] getStates(Tree tree, NodeRef nodeRef) {
        // Validate up front so bad arguments fail with a clear message before any work is done.
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        if (nodeRef == null) {
            throw new IllegalArgumentException("The node cannot be null");
        }
        if (!TreeUtils.isBinary(tree)) {
            throw new IllegalArgumentException("The Fitch algorithm can only reconstruct ancestral states on binary trees");
        }
        getSiteScores(tree);
        if (!hasReconstructedStates) {
            reconstructStates(tree, tree.getRoot(), null);
            hasReconstructedStates = true;
        }
        return states[nodeRef.getNumber()];
    }

    /**
     * Binds this criterion to a tree: discards cached results, allocates the per-node tables
     * and fills in the state sets of the external nodes from the patterns.
     *
     * <p>The new tables are assembled in local variables and only published once they are
     * complete, so a failure leaves the previously cached tree and results untouched.
     *
     * @param tree the tree to bind to; must not be null
     * @throws IllegalArgumentException if {@code tree} is null, or if a tip's taxon resolves to
     *                                  a negative index in the patterns
     */
    public void initialize(Tree tree) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        final int nodeCount = tree.getNodeCount();
        final int patternCount = patterns.getPatternCount();
        // With no patterns there is nothing to look up, so the tips are not touched at all.
        final int tipCount = patternCount > 0 ? tree.getExternalNodeCount() : 0;

        // Resolve each tip once instead of once per pattern; neither value depends on the pattern.
        final int[] tipNumbers = new int[tipCount];
        final int[] taxonIndices = new int[tipCount];
        for (int j = 0; j < tipCount; ++j) {
            NodeRef tip = tree.getExternalNode(j);
            int taxonIndex = patterns.getTaxonIndex(tree.getNodeTaxon(tip).getId());
            if (taxonIndex < 0) {
                // A negative index can never address a pattern array; report which taxon is at fault.
                throw new IllegalArgumentException("The taxon " + tree.getNodeTaxon(tip).getId()
                        + " of the tree has no valid index in the patterns");
            }
            // Tip state sets are stored under the node number, the same key that
            // calculateSteps and reconstructStates use to read them back.
            tipNumbers[j] = tip.getNumber();
            taxonIndices[j] = taxonIndex;
        }

        final boolean[][][] newStateSets = new boolean[nodeCount][patternCount][];
        for (int i = 0; i < patternCount; ++i) {
            int[] pattern = patterns.getPattern(i);
            for (int j = 0; j < tipCount; ++j) {
                newStateSets[tipNumbers[j]][i] = tipStateSet(pattern[taxonIndices[j]]);
            }
        }

        this.stateSets = newStateSets;
        this.states = new int[nodeCount][patternCount];
        this.tree = tree;
        this.hasCalculatedSteps = false;
        this.hasReconstructedStates = false;
    }

    /**
     * Builds the state set of a tip from its observed state code.
     *
     * <p>When gaps are not states the data type's own state set is returned as is. Otherwise a
     * set of {@code stateCount} entries is created in which a gap sets only the last entry and
     * any other state copies the data type's state set into the leading entries.
     */
    private boolean[] tipStateSet(int state) {
        if (!gapsAreStates) {
            return patterns.getDataType().getStateSet(state);
        }
        boolean[] set = new boolean[stateCount];
        if (patterns.getDataType().isGapState(state)) {
            set[stateCount - 1] = true;
        } else {
            boolean[] base = patterns.getDataType().getStateSet(state);
            System.arraycopy(base, 0, set, 0, base.length);
        }
        return set;
    }

    /**
     * Computes the state sets of {@code node} and everything below it (children first) and
     * adds the resulting steps to {@code siteScores}. External nodes are left as initialized.
     */
    private void calculateSteps(Tree tree, NodeRef node, PatternList patternList) {
        if (tree.isExternal(node)) {
            return;
        }
        final int childCount = tree.getChildCount(node);
        final int[] childNumbers = new int[childCount];
        for (int c = 0; c < childCount; ++c) {
            NodeRef child = tree.getChild(node, c);
            calculateSteps(tree, child, patternList);
            childNumbers[c] = child.getNumber();
        }

        final int nodeNumber = node.getNumber();
        final int patternCount = patternList.getPatternCount();
        for (int p = 0; p < patternCount; ++p) {
            boolean[] unionSet = stateSets[childNumbers[0]][p];
            boolean[] commonSet = unionSet;
            for (int c = 1; c < childCount; ++c) {
                boolean[] childSet = stateSets[childNumbers[c]][p];
                unionSet = union(unionSet, childSet);
                commonSet = intersection(commonSet, childSet);
            }
            if (firstIndex(commonSet) >= 0) {
                // The children agree on at least one state: no step is needed here.
                stateSets[nodeNumber][p] = commonSet;
            } else {
                // No shared state: keep every candidate and count one step for this pattern.
                stateSets[nodeNumber][p] = unionSet;
                siteScores[p] += 1.0;
            }
        }
    }

    /**
     * Assigns a state to {@code node} for every pattern and then recurses into its children.
     *
     * <p>The parent's state is reused whenever the node's state set contains it; otherwise
     * (and always at the root, where {@code parentStates} is null) the lowest-numbered state
     * of the node's set is taken.
     */
    private void reconstructStates(Tree tree, NodeRef node, int[] parentStates) {
        final int nodeNumber = node.getNumber();
        final int[] nodeStates = states[nodeNumber];
        final int patternCount = patterns.getPatternCount();
        for (int p = 0; p < patternCount; ++p) {
            boolean[] set = stateSets[nodeNumber][p];
            if (parentStates != null && set[parentStates[p]]) {
                nodeStates[p] = parentStates[p];
            } else {
                nodeStates[p] = firstIndex(set);
            }
        }
        final int childCount = tree.getChildCount(node);
        for (int c = 0; c < childCount; ++c) {
            reconstructStates(tree, tree.getChild(node, c), nodeStates);
        }
    }

    /** Returns a new array holding the element-wise OR of the two sets (length of {@code a}). */
    private static boolean[] union(boolean[] a, boolean[] b) {
        boolean[] result = new boolean[a.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = a[i] || b[i];
        }
        return result;
    }

    /** Returns a new array holding the element-wise AND of the two sets (length of {@code a}). */
    private static boolean[] intersection(boolean[] a, boolean[] b) {
        boolean[] result = new boolean[a.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = a[i] && b[i];
        }
        return result;
    }

    /** Returns the index of the first true entry, or -1 if the set has none. */
    private static int firstIndex(boolean[] set) {
        for (int i = 0; i < set.length; ++i) {
            if (set[i]) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Small demonstration: builds a five-tip tree with a single pattern, then prints the score
     * and the reconstructed internal states from this class followed by the corresponding
     * values from the static {@link Parsimony} methods.
     *
     * @param stringArray ignored
     */
    public static void main(String[] stringArray) {
        FlexibleNode tip1 = new FlexibleNode(new Taxon("tip1"));
        FlexibleNode tip2 = new FlexibleNode(new Taxon("tip2"));
        FlexibleNode tip3 = new FlexibleNode(new Taxon("tip3"));
        FlexibleNode tip4 = new FlexibleNode(new Taxon("tip4"));
        FlexibleNode tip5 = new FlexibleNode(new Taxon("tip5"));

        // Topology: ((tip1,tip2)node1,(tip3,(tip4,tip5)node2)node3)root
        FlexibleNode node1 = new FlexibleNode();
        node1.addChild(tip1);
        node1.addChild(tip2);
        FlexibleNode node2 = new FlexibleNode();
        node2.addChild(tip4);
        node2.addChild(tip5);
        FlexibleNode node3 = new FlexibleNode();
        node3.addChild(tip3);
        node3.addChild(node2);
        FlexibleNode root = new FlexibleNode();
        root.addChild(node1);
        root.addChild(node3);

        FlexibleTree flexibleTree = new FlexibleTree(root);
        Patterns patterns = new Patterns(Nucleotides.INSTANCE, flexibleTree);
        patterns.addPattern(new int[]{1, 0, 1, 2, 2});

        FitchParsimony fitchParsimony = new FitchParsimony(patterns, false);
        System.out.println("No. Steps = " + fitchParsimony.getScore(flexibleTree));
        System.out.println(" state(node1) = " + fitchParsimony.getStates(flexibleTree, node1)[0]);
        System.out.println(" state(node2) = " + fitchParsimony.getStates(flexibleTree, node2)[0]);
        System.out.println(" state(node3) = " + fitchParsimony.getStates(flexibleTree, node3)[0]);
        System.out.println(" state(root) = " + fitchParsimony.getStates(flexibleTree, root)[0]);

        System.out.println("\nParsimony static methods:");
        System.out.println("No. Steps = " + Parsimony.getParsimonySteps(flexibleTree, patterns));
        Parsimony.reconstructParsimonyStates(flexibleTree, patterns);
        System.out.println(" state(node1) = " + flexibleTree.getNodeAttribute(node1, "rstate1"));
        System.out.println(" state(node2) = " + flexibleTree.getNodeAttribute(node2, "rstate1"));
        System.out.println(" state(node3) = " + flexibleTree.getNodeAttribute(node3, "rstate1"));
        System.out.println(" state(root) = " + flexibleTree.getNodeAttribute(root, "rstate1"));
    }
}

