/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.converters;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.OpDef;

/**
 * Registry of {@link DifferentialFunction} op instances used during graph import.
 *
 * <p>During construction this class:
 * <ul>
 *   <li>records, per op name, the non-static fields of each op class and its superclasses
 *       (for classes listed in {@link #classesWithConfig}, the fields of their {@code config}
 *       field's type are recorded instead), making each field accessible;</li>
 *   <li>loads the TensorFlow and ONNX op descriptors;</li>
 *   <li>computes the native custom ops that have no Java mapping ({@link #missingOps()});</li>
 *   <li>indexes op classes by native custom-op hash, keeping a per-name map for hashes that
 *       are shared by more than one op.</li>
 * </ul>
 *
 * <p>A single instance is created eagerly in the static initializer and returned by
 * {@link #getInstance()}.
 */
public class DifferentialFunctionClassHolder {
    private static final Logger log = LoggerFactory.getLogger(DifferentialFunctionClassHolder.class);

    /** Op instances keyed by ND4J op name. */
    private final Map<String, DifferentialFunction> nodeConverters = ImportClassMapping.getOpNameMapping();
    /** Op instances keyed by TensorFlow op name. */
    private final Map<String, DifferentialFunction> tensorFlowNames = ImportClassMapping.getTFOpMappingFunctions();
    /** Op instances keyed by ONNX op name. */
    private final Map<String, DifferentialFunction> onnxNames = ImportClassMapping.getOnnxOpMappingFunctions();
    /** Custom-op hash to op class (last one registered wins when a hash is shared). */
    private final Map<Long, Class<?>> customOpHashToClass = new HashMap<>();
    /** For hashes shared by several custom ops: hash to (op name to op class). */
    private final Map<Long, Map<String, Class<?>>> customOpHashToClasses = new HashMap<>();
    /** Sorted names of native custom ops with no entry in {@link #nodeConverters}. */
    private final List<String> missingOps = new ArrayList<>();
    private Map<String, OpDescriptor> onnxOpDescriptors;
    private Map<String, OpDef> tensorflowOpDescriptors;
    /** Op name to (field name to reflectively accessible field). */
    private final Map<String, Map<String, Field>> fieldsForFunction = new LinkedHashMap<>();

    /** Field names never recorded for any op class. */
    private static final Set<String> fieldNamesOpsIgnore = new LinkedHashSet<>(Arrays.asList(
            "extraArgs",
            "arrayInitialized",
            "log",
            "inputArguments",
            "outputArguments",
            "outputShapes",
            "outputVariables",
            "tArguments",
            "iArguments",
            "hash",
            "opName",
            "sameDiff",
            "ownName"));

    /** Op classes whose recorded fields come from their {@code config} field's type. */
    private static final Set<String> classesWithConfig = new LinkedHashSet<>(Arrays.asList(
            AvgPooling2D.class.getName(),
            Conv2D.class.getName(),
            Conv3D.class.getName(),
            FullConv3D.class.getName(),
            LocalResponseNormalization.class.getName(),
            MaxPooling2D.class.getName(),
            Pooling2D.class.getName(),
            Pooling3D.class.getName(),
            DepthwiseConv2D.class.getName(),
            DeConv2DTF.class.getName()));

    /** The hierarchy walk stops once the next superclass is one of these. */
    private static final Set<Class<?>> classesToIgnore = Collections.<Class<?>>singleton(Object.class);
    /** Per-class field names to skip; populated in the static initializer. */
    private static final Map<Class<?>, Set<String>> classFieldsToIgnore = new HashMap<>();

    private final int countTotalTfOps;
    private final int countTotalMappedOps;
    private static final DifferentialFunctionClassHolder INSTANCE;

    /**
     * Returns the fields recorded for the given function's op name.
     *
     * @param function op whose {@code opName()} is looked up
     * @return field name to field map, or {@code null} if no entry was recorded for that name
     */
    public Map<String, Field> getFieldsForFunction(DifferentialFunction function) {
        return this.fieldsForFunction.get(function.opName());
    }

    /**
     * Looks up a TensorFlow op definition by name.
     *
     * @throws ND4JIllegalStateException if no descriptor exists for {@code name}
     */
    public OpDef getOpDefByTensorflowName(String name) {
        if (!this.tensorflowOpDescriptors.containsKey(name)) {
            throw new ND4JIllegalStateException("No op found with name " + name);
        }
        return this.tensorflowOpDescriptors.get(name);
    }

    /**
     * Looks up an ONNX op descriptor by name.
     *
     * @throws ND4JIllegalStateException if no descriptor exists for {@code name}
     */
    public OpDescriptor getOpDescriptorForOnnx(String name) {
        if (!this.onnxOpDescriptors.containsKey(name)) {
            throw new ND4JIllegalStateException("No op found with name " + name);
        }
        return this.onnxOpDescriptors.get(name);
    }

    /** Returns the op mapped to a TensorFlow op name, or {@code null} if none. */
    public DifferentialFunction getOpWithTensorflowName(String tensorflowName) {
        return this.tensorFlowNames.get(tensorflowName);
    }

    /** Returns the op mapped to an ONNX op name, or {@code null} if none. */
    public DifferentialFunction getOpWithOnnxName(String onnxName) {
        return this.onnxNames.get(onnxName);
    }

    private DifferentialFunctionClassHolder() {
        indexOpFields();
        try {
            this.tensorflowOpDescriptors = TensorflowDescriptorParser.opDescs();
            this.onnxOpDescriptors = OnnxDescriptorParser.onnxOpDescriptors();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        Map<String, CustomOpDescriptor> descriptors = Nd4j.getExecutioner().getCustomOperations();
        computeMissingOps(descriptors);
        this.countTotalTfOps = this.tensorflowOpDescriptors.size();
        this.countTotalMappedOps = this.nodeConverters.size();
        indexCustomOpClasses(descriptors);
    }

    /** Records the accessible fields of every mapped op class into {@link #fieldsForFunction}. */
    private void indexOpFields() {
        for (DifferentialFunction df : ImportClassMapping.getOpNameMapping().values()) {
            try {
                Map<String, Field> fieldNames = new LinkedHashMap<>();
                Class<?> current = df.getClass();
                while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) {
                    if (classesWithConfig.contains(current.getName())) {
                        // getDeclaredField never returns null; a missing "config" field raises
                        // NoSuchFieldException, which the catch below wraps in a RuntimeException.
                        Field configField = current.getDeclaredField("config");
                        collectFields(configField.getType(), current, fieldNames);
                    } else {
                        collectFields(current, current, fieldNames);
                    }
                    current = current.getSuperclass();
                }
                this.fieldsForFunction.put(df.opName(), fieldNames);
            } catch (NoOpNameFoundException e) {
                log.trace("Skipping function  {}", df.getClass());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Adds every non-static, non-ignored field declared directly on {@code source} to
     * {@code fieldNames}, making each one accessible.
     *
     * @param source     class whose declared fields are scanned (an op class or its config type)
     * @param opClass    op class currently being walked; selects the per-class ignore list
     * @param fieldNames accumulator keyed by field name
     * @throws IllegalStateException if a field name is already present in {@code fieldNames}
     */
    private static void collectFields(Class<?> source, Class<?> opClass, Map<String, Field> fieldNames) {
        Set<String> ignoredForClass = classFieldsToIgnore.get(opClass);
        for (Field field : source.getDeclaredFields()) {
            String name = field.getName();
            if (Modifier.isStatic(field.getModifiers())
                    || fieldNamesOpsIgnore.contains(name)
                    || (ignoredForClass != null && ignoredForClass.contains(name))) {
                continue;
            }
            field.setAccessible(true);
            Field existing = fieldNames.get(name);
            if (existing != null) {
                throw new IllegalStateException("Field with name " + name + " exists for multiple classes: "
                        + existing.getDeclaringClass().getName() + " and " + field.getDeclaringClass().getName());
            }
            fieldNames.put(name, field);
        }
    }

    /** Fills {@link #missingOps} with the sorted native op names that have no Java mapping. */
    private void computeMissingOps(Map<String, CustomOpDescriptor> descriptors) {
        // Copy so the executioner's own map is never mutated.
        Set<String> unmapped = new HashSet<>(descriptors.keySet());
        unmapped.removeAll(this.nodeConverters.keySet());
        this.missingOps.addAll(unmapped);
        Collections.sort(this.missingOps);
    }

    /**
     * Builds the hash to class index. When several custom ops share a hash, a per-name map is
     * also built so that {@link #customOpClassForHashAndName} can disambiguate them.
     */
    private void indexCustomOpClasses(Map<String, CustomOpDescriptor> descriptors) {
        Set<Long> multiClassHashes = new HashSet<>();
        for (Map.Entry<String, CustomOpDescriptor> e : descriptors.entrySet()) {
            DifferentialFunction df = this.getInstance(e.getKey());
            if (!(df instanceof CustomOp)) {
                continue;
            }
            long h = e.getValue().getHash();
            if (this.customOpHashToClass.containsKey(h)) {
                multiClassHashes.add(h);
            }
            this.customOpHashToClass.put(h, df.getClass());
        }
        for (Map.Entry<String, CustomOpDescriptor> e : descriptors.entrySet()) {
            long h = e.getValue().getHash();
            if (!multiClassHashes.contains(h)) {
                continue;
            }
            Map<String, Class<?>> byName = this.customOpHashToClasses.computeIfAbsent(h, k -> new HashMap<>());
            DifferentialFunction df = this.getInstance(e.getKey());
            if (df != null) {
                byName.put(e.getKey(), df.getClass());
            }
        }
    }

    /** Returns the ONNX op names that have a descriptor but no mapped op. */
    public Set<String> missingOnnxOps() {
        Set<String> copy = new HashSet<>(this.onnxOpDescriptors.keySet());
        copy.removeAll(this.onnxNames.keySet());
        return copy;
    }

    /** Returns the TensorFlow op names that have a descriptor but no mapped op. */
    public Set<String> missingTensorflowOps() {
        Set<String> copy = new HashSet<>(this.tensorflowOpDescriptors.keySet());
        copy.removeAll(this.tensorFlowNames.keySet());
        return copy;
    }

    /** Returns the sorted native custom-op names with no Java mapping (the internal list). */
    public List<String> missingOps() {
        return this.missingOps;
    }

    /** Returns whether an op is registered under the given ND4J op name. */
    public boolean hasName(String name) {
        return this.nodeConverters.containsKey(name);
    }

    /** Returns the registered ND4J op names (a live key-set view of the internal map). */
    public Set<String> opNames() {
        return this.nodeConverters.keySet();
    }

    /** Returns the op registered under the given ND4J op name, or {@code null} if none. */
    public DifferentialFunction getInstance(String name) {
        return this.nodeConverters.get(name);
    }

    /**
     * Resolves the op class for a native custom-op hash, using {@code name} to disambiguate
     * hashes shared by several ops.
     *
     * @return the op class; {@code null} if the hash is shared but {@code name} is not among its ops
     * @throws IllegalStateException if the hash is unknown
     */
    public Class<?> customOpClassForHashAndName(long customOpHash, String name) {
        if (this.customOpHashToClasses.containsKey(customOpHash)) {
            return this.customOpHashToClasses.get(customOpHash).get(name);
        }
        if (this.customOpHashToClass.containsKey(customOpHash)) {
            return this.customOpHashToClass.get(customOpHash);
        }
        throw new IllegalStateException("No op known for hash: " + customOpHash);
    }

    /** Returns the eagerly created singleton instance. */
    public static DifferentialFunctionClassHolder getInstance() {
        return INSTANCE;
    }

    /** Returns a read-only view of the TensorFlow op name to op mapping. */
    public Map<String, DifferentialFunction> getTensorFlowNames() {
        return Collections.unmodifiableMap(this.tensorFlowNames);
    }

    /** Returns the number of TensorFlow op descriptors loaded. */
    public int getCountTotalTfOps() {
        return this.countTotalTfOps;
    }

    /** Returns the number of ops registered by ND4J op name. */
    public int getCountTotalMappedOps() {
        return this.countTotalMappedOps;
    }

    static {
        // Must run before the constructor, which consults classFieldsToIgnore.
        classFieldsToIgnore.put(BaseOp.class, new HashSet<>(Arrays.asList(
                "x", "y", "z", "n", "numProcessed", "xVertexId", "yVertexId", "zVertexId", "extraArgz")));
        INSTANCE = new DifferentialFunctionClassHolder();
    }
}

