package aprove.InputModules.Programs.fp;

import aprove.Framework.Algebra.Orders.Utility.POLO.Interpretation;
import aprove.Framework.Algebra.Terms.FunctionApplication;
import aprove.Framework.Algebra.Terms.Position;
import aprove.Framework.Algebra.Terms.Substitution;
import aprove.Framework.Algebra.Terms.Term;
import aprove.Framework.Algebra.Terms.UnificationException;
import aprove.Framework.Rewriting.Rule;
import aprove.Framework.Syntax.ConstructorSymbol;
import aprove.Framework.Syntax.Symbol;
import aprove.Framework.Syntax.VariableSymbol;
import aprove.Framework.Typing.TypeContext;
import aprove.Framework.Typing.TypeTools;
import aprove.Framework.Utility.FreshVarGenerator;
import aprove.InputModules.Programs.Predef.IntegerPredef.IntegerTools;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.Vector;
import java.util.logging.Logger;

/* loaded from: input_file:aprove/InputModules/Programs/fp/PatternDisjunctor.class */
/**
 * Makes the left-hand sides of a set of rewrite rules pairwise non-overlapping.
 *
 * <p>Every rule is assigned an index reflecting its position in the input set (see
 * {@link RulesIndexed}); a smaller index means higher priority. Whenever the left-hand sides of two
 * rules unify and the rules do not belong to the same "if block" (conditional rules sharing the
 * same left-hand-side object), the rule with the larger index is removed and replaced by instances
 * of itself that are intended not to unify with the other rule's left-hand side. This is repeated
 * until no such pair remains.
 *
 * <p>The working state of a run lives in static fields without synchronization, so
 * {@link #makePatternsNonOverlapping} is not reentrant and must not be called concurrently.
 */
public class PatternDisjunctor {
    /** Groups of conditional rules whose left-hand sides are the identical object. */
    private static Vector<Vector<Rule>> ifBlocks;
    /** If block containing the first term of the last {@link #sameIfBlock} call, or null. */
    private static Vector<Rule> curRulesVec;
    /** If block containing the second term of the last {@link #sameIfBlock} call, or null. */
    private static Vector<Rule> conRulesVec;
    private static final Logger logger = Logger.getLogger("aprove.InputModules.Programs.fp.PatternDisjunctor");
    /** Whether integer-specific constructor restrictions are consulted (see getDisallowedSymbols). */
    private static boolean containsInts;

    /**
     * Maps each rule to its index. Indices start at 1 in input order; rules obtained by refining
     * a rule inherit that rule's index. When two rules overlap, the one with the smaller index is
     * kept unchanged.
     */
    public static class RulesIndexed extends LinkedHashMap<Rule, Integer> {
        /** Numbers the rules of {@code linkedHashSet} 1, 2, 3, ... in iteration order. */
        public RulesIndexed(LinkedHashSet<Rule> linkedHashSet) {
            int i = 1;
            for (Rule rule : linkedHashSet) {
                if (add(rule, Integer.valueOf(i))) {
                    i++;
                }
            }
        }

        /** Copies all (rule, index) pairs of {@code rulesIndexed}. */
        public RulesIndexed(RulesIndexed rulesIndexed) {
            for (Map.Entry<Rule, Integer> entry : rulesIndexed.entrySet()) {
                add(entry.getKey(), entry.getValue());
            }
        }

        /**
         * Associates {@code rule} with {@code num}, unless the rule is already mapped to a strictly
         * smaller index.
         *
         * @return {@code false} iff the rule kept an existing, smaller index
         */
        public boolean add(Rule rule, Integer num) {
            Integer existing = get(rule);
            if (existing != null && existing.intValue() < num.intValue()) {
                return false;
            }
            super.put(rule, num);
            return true;
        }

        /** Adds every rule of {@code set} with index {@code num} (see {@link #add}). */
        public void addAll(Set<Rule> set, Integer num) {
            for (Rule rule : set) {
                add(rule, num);
            }
        }

        /**
         * Returns the rules ordered by ascending index; rules sharing an index keep this map's
         * iteration order.
         */
        public LinkedHashSet<Rule> toLinkedSet() {
            TreeMap<Integer, List<Rule>> byIndex = new TreeMap<>();
            for (Map.Entry<Rule, Integer> entry : entrySet()) {
                List<Rule> bucket = byIndex.get(entry.getValue());
                if (bucket == null) {
                    bucket = new ArrayList<>();
                    byIndex.put(entry.getValue(), bucket);
                }
                bucket.add(entry.getKey());
            }
            LinkedHashSet<Rule> ordered = new LinkedHashSet<>();
            for (List<Rule> bucket : byIndex.values()) {
                ordered.addAll(bucket);
            }
            return ordered;
        }

        /** Returns the set of indices currently in use. */
        public Set<Integer> getNumbers() {
            return new HashSet<>(values());
        }
    }

    /**
     * Rewrites {@code rules} so that no two left-hand sides from different if blocks unify.
     *
     * @param rules the rules in priority order (earlier rules take precedence)
     * @param typeContext type information used to enumerate the constructors of argument types
     * @param containsIntegers whether integer symbols are subject to the extra restrictions from
     *     {@code IntegerTools.getDisallowedSymbols}
     * @return {@code rules} itself if it is {@code null} or has at most one element; otherwise a
     *     new set ordered by rule index
     */
    public static LinkedHashSet<Rule> makePatternsNonOverlapping(LinkedHashSet<Rule> rules, TypeContext typeContext, boolean containsIntegers) {
        if (rules == null || rules.size() <= 1) {
            return rules;
        }
        containsInts = containsIntegers;
        try {
            RulesIndexed rulesIndexed = new RulesIndexed(rules);
            buildIfBlocks(rules);
            // Each resolved conflict mutates rulesIndexed, so the scan restarts until a full
            // pass finds nothing (the decompiled loop never left this state and spun forever).
            while (resolveFirstConflict(rulesIndexed, typeContext)) {
                // keep scanning
            }
            return rulesIndexed.toLinkedSet();
        } finally {
            // Do not keep this run's rules reachable from static fields after returning.
            ifBlocks = null;
            curRulesVec = null;
            conRulesVec = null;
        }
    }

    /**
     * Scans the rules in map order and resolves the first conflict: a rule whose left-hand side
     * unifies with that of an earlier-scanned rule outside its if block.
     *
     * @return whether a conflict was found (and resolved)
     */
    private static boolean resolveFirstConflict(RulesIndexed rulesIndexed, TypeContext typeContext) {
        List<Rule> scanned = new ArrayList<>();
        for (Rule rule : rulesIndexed.keySet()) {
            Term left = rule.getLeft();
            for (Rule earlier : scanned) {
                if (earlier.getLeft() == left) {
                    continue; // rules sharing the very same lhs object are never compared
                }
                // Copy of `earlier` with variables renamed by a generator seeded with `left`
                // (presumably making the two left-hand sides variable-disjoint).
                Rule renamed = earlier.replaceVariables(new FreshVarGenerator(left));
                // sameIfBlock also records curRulesVec/conRulesVec for processConflict.
                if (renamed.getLeft().isUnifiable(left) && !sameIfBlock(left, earlier.getLeft())) {
                    int earlierIndex = rulesIndexed.get(earlier).intValue();
                    processConflict(rulesIndexed, rule, earlier, renamed, earlierIndex, typeContext);
                    return true; // rulesIndexed changed: the key iterator must not advance again
                }
            }
            scanned.add(rule);
        }
        return false;
    }

    /** Groups the conditional rules of {@code set} into {@link #ifBlocks} by identical lhs object. */
    private static void buildIfBlocks(Set<Rule> set) {
        ifBlocks = new Vector<>();
        for (Rule rule : set) {
            if (rule.getConds().size() == 0) {
                continue;
            }
            Vector<Rule> block = findIfBlock(rule.getLeft());
            if (block != null) {
                block.add(rule);
            } else {
                ifBlocks.add(new Vector<>(Arrays.asList(rule)));
            }
        }
    }

    /** Returns the first if block containing a rule whose lhs is {@code left} itself, or null. */
    private static Vector<Rule> findIfBlock(Term left) {
        for (Vector<Rule> block : ifBlocks) {
            for (Rule member : block) {
                if (member.getLeft() == left) {
                    return block;
                }
            }
        }
        return null;
    }

    /**
     * Tests whether one if block contains rules whose left-hand sides are (by identity) both
     * {@code term} and {@code term2}. As a side effect, records the block containing {@code term}
     * in {@link #curRulesVec} and the one containing {@code term2} in {@link #conRulesVec}
     * (null when none).
     */
    private static boolean sameIfBlock(Term term, Term term2) {
        curRulesVec = null;
        conRulesVec = null;
        for (Vector<Rule> block : ifBlocks) {
            boolean hasTerm = false;
            boolean hasTerm2 = false;
            for (Rule member : block) {
                if (member.getLeft() == term) {
                    hasTerm = true;
                    curRulesVec = block;
                }
                if (member.getLeft() == term2) {
                    hasTerm2 = true;
                    conRulesVec = block;
                }
                if (hasTerm && hasTerm2) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Adds {@code set} to the if block recorded in {@link #curRulesVec}, if there is one. */
    private static void updateIfBlockRules(Set<Rule> set) {
        if (curRulesVec != null) {
            curRulesVec.addAll(set);
        }
    }

    /**
     * Resolves an overlap between {@code currentRule} and an earlier rule. The rule with the larger
     * index is removed and replaced by its refinements w.r.t. the other left-hand side; the
     * refinements inherit the removed rule's index.
     *
     * @param currentRule a key of {@code rulesIndexed}
     * @param conflictingRule the conflicting rule as stored in {@code rulesIndexed}
     * @param renamedConflictingRule {@code conflictingRule} with variables renamed apart from
     *     {@code currentRule}'s lhs
     * @param conflictingIndex the index of {@code conflictingRule}
     */
    private static void processConflict(RulesIndexed rulesIndexed, Rule currentRule, Rule conflictingRule, Rule renamedConflictingRule, int conflictingIndex, TypeContext typeContext) {
        int currentIndex = rulesIndexed.get(currentRule).intValue();
        Rule refinedKey = currentRule;       // key removed from rulesIndexed
        Rule refined = currentRule;          // rule that gets split
        Rule other = renamedConflictingRule; // rule whose lhs is excluded
        int refinedIndex = currentIndex;
        int otherIndex = conflictingIndex;
        if (conflictingIndex > currentIndex) {
            logger.finer("switching: ruleIndex " + currentIndex + " with conflictingRuleIndex " + conflictingIndex + "\n");
            // Real swap (the decompiled assignments collapsed both sides onto currentRule).
            // The renamed copy is refined, but the original key is what must be removed.
            refinedKey = conflictingRule;
            refined = renamedConflictingRule;
            other = currentRule;
            refinedIndex = conflictingIndex;
            otherIndex = currentIndex;
            curRulesVec = conRulesVec;
        }
        rulesIndexed.remove(refinedKey);
        Term refinedLeft = refined.getLeft();
        try {
            Term otherLeft = other.getLeft();
            Substitution mgu = refinedLeft.unifies(otherLeft);
            logger.finer("Conflict (mgu=" + mgu + "):\n");
            logger.finer("\t" + refined + "/" + refinedIndex + " vs. " + other + "/" + otherIndex + " varren: " + mgu.isVariableRenaming() + "\n");
            LinkedHashSet<Rule> replacements = resolveOverlap(refined, otherLeft, typeContext);
            rulesIndexed.addAll(replacements, Integer.valueOf(refinedIndex));
            updateIfBlockRules(replacements);
            logger.finer("\tnew rules:\n");
            for (Rule replacement : replacements) {
                logger.finer("\t\t" + replacement + "\n");
            }
        } catch (UnificationException e) {
            throw new RuntimeException("Tried to resolve conflict on non-unifiable left hand sides.", e);
        }
    }

    /**
     * Splits {@code rule} into instances that are intended not to unify with {@code term}. For each
     * variable of the rule's lhs bound by the mgu to a non-variable term, the instances for the
     * other constructors at that position are emitted, and the rule is then narrowed by that
     * binding before the next variable is processed. The fully narrowed rule (the overlap itself)
     * is not part of the result.
     */
    private static LinkedHashSet<Rule> resolveOverlap(Rule rule, Term term, TypeContext typeContext) {
        Term left = rule.getLeft();
        try {
            Substitution mgu = left.unifies(term);
            LinkedHashSet<Rule> result = new LinkedHashSet<>();
            FreshVarGenerator freshVars = new FreshVarGenerator(left);
            for (VariableSymbol var : mgu.getDomain()) {
                Term binding = mgu.get(var);
                if (binding.isVariable()) {
                    continue;
                }
                Set<Position> positions = getPositionsWithSymbol(left, var);
                if (positions.isEmpty()) {
                    continue; // bound variable belongs to `term`, not to this rule's lhs
                }
                Position position = positions.iterator().next();
                Symbol parent = left.getSubterm(position.pred()).getSymbol();
                Term argType = TypeTools.getFunctionArgAt(typeContext.getSingleTypeOf(parent).getTypeMatrix(), position.getLast());
                for (Term alternative : computeNonOverlappingTerms(binding, argType, getDisallowedSymbols(parent, typeContext), freshVars, typeContext)) {
                    Substitution instantiate = Substitution.create();
                    instantiate.put(var, alternative);
                    result.add(rule.deepcopy().apply(instantiate));
                }
                Substitution narrow = Substitution.create();
                narrow.put(var, binding);
                rule = rule.deepcopy().apply(narrow);
                left = rule.getLeft();
            }
            return result;
        } catch (UnificationException e) {
            throw new RuntimeException("resolveOverlap was called with non-unifiable terms " + left + " and " + term, e);
        }
    }

    /**
     * Computes constructor terms of type {@code term2} that do not unify with {@code term}.
     * Returns an empty set when {@code term} is a variable. Constructors contained in {@code set}
     * are never used.
     *
     * <p>NOTE(review): for the constructor of {@code term} itself, only tuples drawn from each
     * argument's complement (or a fresh variable for variable arguments) are produced. Instances
     * that match {@code term} at some argument positions but not at others (e.g. {@code P(A,B)}
     * for {@code term = P(A,A)}) are not generated, so the result may not cover every
     * non-overlapping instance. Confirm whether this is acceptable for the callers.
     */
    private static Set<Term> computeNonOverlappingTerms(Term term, Term term2, Set<Symbol> set, FreshVarGenerator freshVarGenerator, TypeContext typeContext) {
        Set<Term> result = new LinkedHashSet<>();
        if (term.isVariable()) {
            return result;
        }
        Set<ConstructorSymbol> allowed = new LinkedHashSet<>();
        for (Symbol symbol : typeContext.getTypeDefOf((ConstructorSymbol) term2.getSymbol()).getDeclaredSymbols()) {
            if (!set.contains(symbol)) {
                allowed.add((ConstructorSymbol) symbol);
            }
        }
        ConstructorSymbol matched = (ConstructorSymbol) term.getSymbol();
        for (ConstructorSymbol ctor : allowed) {
            if (ctor.equals(matched)) {
                addRefinedInstances(result, ctor, term, getDisallowedSymbols(matched, typeContext), freshVarGenerator, typeContext);
            } else {
                // A different constructor never unifies with `term`: all-fresh arguments suffice.
                result.add(FunctionApplication.create(ctor, freshArguments(ctor, freshVarGenerator)));
            }
        }
        return result;
    }

    /**
     * Adds to {@code result} the instances {@code ctor(t_1, ..., t_n)} where each {@code t_i}
     * ranges over the non-overlapping terms of the i-th argument of {@code term} (or a fresh
     * variable when there are none), skipping the all-variable tuple, which would overlap.
     */
    private static void addRefinedInstances(Set<Term> result, ConstructorSymbol ctor, Term term, Set<Symbol> argDisallowed, FreshVarGenerator freshVarGenerator, TypeContext typeContext) {
        Term typeMatrix = typeContext.getSingleTypeOf(ctor).getTypeMatrix();
        // Cartesian product over argument alternatives; the decompiled version appended every
        // alternative to every tuple, producing argument lists longer than the arity.
        List<List<Term>> tuples = new ArrayList<>();
        tuples.add(new ArrayList<Term>());
        for (int i = 0; i < ctor.getArity(); i++) {
            Set<Term> alternatives = computeNonOverlappingTerms(term.getArgument(i), TypeTools.getFunctionArgAt(typeMatrix, i), argDisallowed, freshVarGenerator, typeContext);
            if (alternatives.isEmpty()) {
                alternatives.add(freshVarGenerator.getFreshVariable(Interpretation.VARIABLE_PREFIX + (i + 1), ctor.getArgSort(i), false));
            }
            List<List<Term>> extended = new ArrayList<>();
            for (List<Term> prefix : tuples) {
                for (Term alternative : alternatives) {
                    List<Term> tuple = new ArrayList<>(prefix);
                    tuple.add(alternative);
                    extended.add(tuple);
                }
            }
            tuples = extended;
        }
        for (List<Term> args : tuples) {
            if (!allVariables(args)) {
                result.add(FunctionApplication.create(ctor, args));
            }
        }
    }

    /** Builds {@code ctor}'s argument list from fresh variables of the matching sorts. */
    private static List<Term> freshArguments(ConstructorSymbol ctor, FreshVarGenerator freshVarGenerator) {
        List<Term> args = new ArrayList<>(ctor.getArity());
        for (int i = 0; i < ctor.getArity(); i++) {
            args.add(freshVarGenerator.getFreshVariable(Interpretation.VARIABLE_PREFIX + (i + 1), ctor.getArgSort(i), false));
        }
        return args;
    }

    /** Returns whether every term in {@code terms} is a variable (true for an empty list). */
    private static boolean allVariables(List<Term> terms) {
        for (Term t : terms) {
            if (!t.isVariable()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the constructors that must not be generated below {@code symbol}. The set is
     * non-empty only when integers are enabled and {@code symbol} is an integer symbol.
     */
    private static Set<Symbol> getDisallowedSymbols(Symbol symbol, TypeContext typeContext) {
        if (containsInts && IntegerTools.isIntSymbol(symbol, typeContext)) {
            // Copy into a Set<Symbol> so the element type matches regardless of IntegerTools' return type.
            return new HashSet<Symbol>(IntegerTools.getDisallowedSymbols(symbol, typeContext));
        }
        return new HashSet<Symbol>();
    }

    /** Returns the positions of {@code term} whose subterm has root symbol {@code symbol}, in term order. */
    private static Set<Position> getPositionsWithSymbol(Term term, Symbol symbol) {
        Set<Position> positions = new LinkedHashSet<>();
        for (Position position : term.getPositions()) {
            if (term.getSubterm(position).getSymbol().equals(symbol)) {
                positions.add(position);
            }
        }
        return positions;
    }
}
