package natlab.tame.simplification;

import ast.ASTNode;
import ast.AssignStmt;
import ast.Expr;
import ast.Function;
import ast.FunctionList;
import ast.LambdaExpr;
import ast.List;
import ast.Name;
import ast.NameExpr;
import ast.ParameterizedExpr;
import com.google.common.base.Predicates;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import natlab.toolkits.analysis.varorfun.VFPreorderAnalysis;
import natlab.toolkits.filehandling.FunctionOrScriptQuery;
import natlab.toolkits.rewrite.TempFactory;
import natlab.toolkits.rewrite.TempFunctionBuilderHelper;
import natlab.toolkits.rewrite.TransformedNode;
import natlab.toolkits.rewrite.simplification.AbstractSimplification;
import natlab.utils.NodeFinder;

/* loaded from: input_file:natlab/tame/simplification/LambdaSimplification.class */
/**
 * Simplification that lifts anonymous functions (lambdas) into synthesized named functions.
 *
 * <p>A lambda is rewritten into a lambda with the same parameters whose body is a call to a new
 * function. That function's name is generated from the prefix {@code "lambda_"}. Its inputs are
 * the variables the lambda captures, followed by the lambda's own parameters. It has a single
 * output, which is assigned the original lambda body. For example, {@code @(x) x + y} (with
 * {@code y} a variable) becomes {@code @(x) lambda_N(y, x)}.
 *
 * <p>Lambdas whose body is already a call to a non-variable with only plain-name arguments are
 * left untouched. The synthesized functions are appended, in creation order, to the
 * {@link FunctionList} this pass visits.
 */
public class LambdaSimplification extends AbstractSimplification {
    FunctionOrScriptQuery query;
    VFPreorderAnalysis kind;
    /** Synthesized functions by name; a LinkedHashMap so they are emitted in creation order. */
    HashMap<String, Function> tempFunctions;

    public LambdaSimplification(ASTNode<?> aSTNode, VFPreorderAnalysis vFPreorderAnalysis) {
        super(aSTNode, vFPreorderAnalysis);
        this.tempFunctions = new LinkedHashMap<>();
        this.query = vFPreorderAnalysis.getQuery();
        this.kind = vFPreorderAnalysis;
    }

    /**
     * Returns a snapshot of the functions synthesized so far.
     *
     * @return a new map, keyed by function name and in creation order; later changes to this
     *     pass do not affect it
     */
    public Map<String, Function> getTempFunctions() {
        return new LinkedHashMap<>(this.tempFunctions);
    }

    @Override
    public Set<Class<? extends AbstractSimplification>> getDependencies() {
        return Collections.emptySet();
    }

    /**
     * Lifts {@code lambdaExpr} into a synthesized function, unless its body is already a direct
     * call with plain-name arguments.
     */
    @Override
    public void caseLambdaExpr(LambdaExpr lambdaExpr) {
        if (isDirectCallWithNameArgs(lambdaExpr.getBody())) {
            return;
        }
        Set<String> nestedParams = rewriteNestedLambdas(lambdaExpr);
        Set<String> params = paramNames(lambdaExpr);
        // Read the body only after rewriteChildren (inside rewriteNestedLambdas), which may
        // replace children of lambdaExpr.
        Expr body = lambdaExpr.getBody();
        Set<String> captured = capturedVariables(body, params, nestedParams);

        // Captured variables come first, then the lambda's own parameters; the call site passes
        // the same names in the same order.
        List<Name> inputParams = new List<>();
        List<Expr> callArgs = new List<>();
        for (String name : captured) {
            addArgument(name, inputParams, callArgs);
        }
        for (String name : params) {
            addArgument(name, inputParams, callArgs);
        }

        String functionName = TempFunctionBuilderHelper.getFreshFunctionName(
                lambdaExpr, this.query, "lambda_", this.tempFunctions.keySet());
        Function function = new Function();
        TempFactory result = TempFactory.genFreshTempFactory();
        function.setName(functionName);
        List<Name> outputParams = new List<>();
        outputParams.add(result.genName());
        function.setOutputParamList(outputParams);
        // Single statement: <result> = <original lambda body>.
        @SuppressWarnings({"rawtypes", "unchecked"})
        List stmts = new List().add(new AssignStmt(result.genNameExpr(), body));
        function.setStmtList(stmts);
        function.setInputParamList(inputParams);
        this.tempFunctions.put(functionName, function);

        ParameterizedExpr call =
                new ParameterizedExpr(new NameExpr(new Name(functionName)), callArgs);
        this.newNode = new TransformedNode(new LambdaExpr(lambdaExpr.getInputParamList(), call));
    }

    /** Returns true if {@code body} calls a non-variable and every argument is a plain name. */
    private boolean isDirectCallWithNameArgs(Expr body) {
        if (!(body instanceof ParameterizedExpr)) {
            return false;
        }
        ParameterizedExpr call = (ParameterizedExpr) body;
        return !isVar(call)
                && Iterables.all(call.getArgList(), Predicates.instanceOf(NameExpr.class));
    }

    /**
     * Rewrites lambdas nested in the body of {@code lambdaExpr} (if any). Returns the parameter
     * names of the first nested lambda, which are then not treated as captured by the outer
     * lambda.
     */
    private Set<String> rewriteNestedLambdas(LambdaExpr lambdaExpr) {
        Set<String> nestedParams = new HashSet<>();
        FluentIterable<LambdaExpr> nested =
                NodeFinder.find(LambdaExpr.class, lambdaExpr.getBody());
        if (Iterables.isEmpty(nested)) {
            return nestedParams;
        }
        rewriteChildren(lambdaExpr);
        // NOTE(review): only the first nested lambda's parameters are excluded; parameters of
        // sibling nested lambdas could be picked up as captured variables. Confirm intent.
        LambdaExpr first = Iterables.getFirst(nested, null);
        for (Name param : first.getInputParamList()) {
            nestedParams.add(param.getID());
        }
        return nestedParams;
    }

    /** Returns the lambda's parameter names in declaration order, without duplicates. */
    private static Set<String> paramNames(LambdaExpr lambdaExpr) {
        Set<String> names = new LinkedHashSet<>();
        for (Name param : lambdaExpr.getInputParamList()) {
            names.add(param.getID());
        }
        return names;
    }

    /**
     * Returns, in order of first appearance, the names in {@code body} that are variables and
     * are bound neither by the lambda itself nor by the excluded nested-lambda parameters.
     */
    private Set<String> capturedVariables(
            Expr body, Set<String> params, Set<String> nestedParams) {
        Set<String> captured = new LinkedHashSet<>();
        for (NameExpr nameExpr : body.getAllNameExpressions()) {
            String id = nameExpr.getName().getID();
            // isVar is consulted last, only for names not already known to be bound.
            if (!params.contains(id) && !nestedParams.contains(id) && isVar((Expr) nameExpr)) {
                captured.add(id);
            }
        }
        return captured;
    }

    /** Adds {@code name} both as a formal parameter and as the matching call argument. */
    private static void addArgument(String name, List<Name> inputParams, List<Expr> callArgs) {
        inputParams.add(new Name(name));
        callArgs.add(new NameExpr(new Name(name)));
    }

    /** Rewrites the functions, then appends every synthesized function in creation order. */
    @Override
    public void caseFunctionList(FunctionList functionList) {
        rewriteChildren(functionList);
        for (Function tempFunction : this.tempFunctions.values()) {
            functionList.getFunctionList().add(tempFunction);
        }
        this.newNode = new TransformedNode(functionList);
    }
}
