/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SameDiff
extends SDBaseOps {
    // Declared without an initializer here; presumably assigned in a static initializer outside this view.
    private static final Logger log;
    // Variable name -> graph-level metadata for that variable.
    private final Map<String, Variable> variables = new LinkedHashMap<String, Variable>();
    // Op own-name -> op wrapper (holds the DifferentialFunction plus its input/output variable names).
    private final Map<String, SameDiffOp> ops = new LinkedHashMap<String, SameDiffOp>();
    // Thread id -> inference session used by that thread.
    private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<Long, InferenceSession>();
    // Variable name -> array, for CONSTANT-type variables.
    private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<String, DeviceLocalNDArray>();
    // Variable name -> array, for VARIABLE-type variables.
    private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<String, DeviceLocalNDArray>();
    // Thread id -> (placeholder name -> array supplied by that thread).
    private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<Long, Map<String, INDArray>>();
    private final List<String> lossVariables = new ArrayList<String>();
    private TrainingConfig trainingConfig;
    private boolean initializedTraining;
    private INDArray updaterState;
    private Map<String, INDArray> updaterViews;
    private Map<String, GradientUpdater> updaterMap;
    // Function own-name -> base name.
    private Map<String, String> baseNameForFunctionInstanceId;
    private DifferentialFunctionFactory functionFactory;
    // Variable name -> shape; consulted when no array exists for the variable.
    @Deprecated
    private Map<String, long[]> variableNameToShape;
    @Deprecated
    private Map<String, SDVariable> forwardVarForGrad;
    private int variableId = 0;
    // Op namespaces bound to this instance; also exposed through the same-named accessor methods.
    public final SDMath math = new SDMath(this);
    public final SDRandom random = new SDRandom(this);
    public final SDNN nn = new SDNN(this);
    public final SDCNN cnn = new SDCNN(this);
    public final SDRNN rnn = new SDRNN(this);
    public final SDLoss loss = new SDLoss(this);
    // Function own-name -> names of properties still to be resolved.
    private Map<String, List<String>> propertiesToResolve;
    // Function own-name -> (property name -> property value).
    private Map<String, Map<String, Object>> propertiesForFunction;
    @Deprecated
    private Map<String, long[]> placeHolderOriginalShapes;
    private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
    // Sub-function name -> nested SameDiff instance.
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Set<String> placeHolderFunctions;
    // NOTE(review): not initialized in the visible code; presumably set in a static initializer - confirm.
    private static Cloner cloner;
    // Op name -> reflective method used by invoke(Op, SDVariable, SDVariable).
    private static Map<String, Method> opMethods;
    // (function own-name, field name) -> variable name.
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
    private boolean debugMode;
    private Map<int[], Op> opsForResult;
    private boolean resolvedVariables = false;
    boolean logExecution = true;
    private SameDiff parent;
    private SameDiff child;
    public static final String TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME = "trainingConfig.json";
    public static final String SAMEDIFF_FILE_ENTRY_NAME = "samediff.fb";

    /**
     * Accessor for the math op namespace bound to this instance.
     *
     * @return the value of the {@code math} field
     */
    public SDMath math() {
        return math;
    }

    /**
     * Accessor for the random op namespace bound to this instance.
     *
     * @return the value of the {@code random} field
     */
    public SDRandom random() {
        return random;
    }

    /**
     * Accessor for the neural-network op namespace bound to this instance.
     *
     * @return the value of the {@code nn} field
     */
    public SDNN nn() {
        return nn;
    }

    /**
     * Accessor for the CNN op namespace bound to this instance.
     *
     * @return the value of the {@code cnn} field
     */
    public SDCNN cnn() {
        return cnn;
    }

    /**
     * Accessor for the RNN op namespace bound to this instance.
     *
     * @return the value of the {@code rnn} field
     */
    public SDRNN rnn() {
        return rnn;
    }

    /**
     * Accessor for the loss op namespace bound to this instance.
     *
     * @return the value of the {@code loss} field
     */
    public SDLoss loss() {
        return loss;
    }

    /**
     * Builds a new {@link Cloner} with fast cloners registered for the backend's
     * NDArray class and for each data-buffer class the buffer factory reports
     * (int, long, half, float, double) plus {@link CompressedDataBuffer}.
     *
     * @return the configured cloner
     */
    public static Cloner newCloner() {
        Cloner result = new Cloner();
        INDArrayFastCloner arrayCloner = new INDArrayFastCloner();
        result.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), (IFastCloner)arrayCloner);

        DataBufferFastCloner bufferCloner = new DataBufferFastCloner();
        DataBufferFactory factory = Nd4j.getDataBufferFactory();
        // Same registration order as before; null entries are ignored by doReg.
        Class<?>[] bufferClasses = new Class<?>[]{
                factory.intBufferClass(),
                factory.longBufferClass(),
                factory.halfBufferClass(),
                factory.floatBufferClass(),
                factory.doubleBufferClass(),
                CompressedDataBuffer.class
        };
        for (Class<?> bufferClass : bufferClasses) {
            SameDiff.doReg(result, bufferCloner, bufferClass);
        }
        return result;
    }

    /**
     * Registers {@code fc} as the fast cloner for class {@code c}; does nothing when {@code c} is null.
     *
     * @param cl cloner to register with
     * @param fc fast cloner to register
     * @param c  class to register the fast cloner for, may be null
     */
    private static void doReg(Cloner cl, IFastCloner fc, Class<?> c) {
        if (c == null) {
            return;
        }
        cl.registerFastCloner(c, fc);
    }

    /**
     * Renames a variable throughout this graph.
     * <p>
     * Updates, in order: the variable itself and its entry in {@code variables}; every
     * occurrence of the old name in the input/output name lists of all ops; the
     * shape and forward-gradient maps; and the x/y/z vertex ids of any {@link BaseOp}
     * that consumes or produces the variable.
     *
     * @param varName  current name of the variable
     * @param withName new name for the variable
     * @throws IllegalStateException if no variable called {@code varName} is registered
     *                               (previously this surfaced as a bare NullPointerException)
     */
    public void updateVariableName(String varName, String withName) {
        Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists", varName);

        SDVariable oldVarNameRef = this.getVariable(varName);
        Variable v = this.variables.remove(varName);
        oldVarNameRef.setVarName(withName);
        v.setName(withName);
        this.variables.put(withName, v);

        // Patch the name lists held by every op wrapper.
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            List<String> outputsOfOp = sameDiffOp.getOutputsOfOp();
            if (outputsOfOp != null) {
                Collections.replaceAll(outputsOfOp, varName, withName);
            }
            List<String> inputsToOp = sameDiffOp.getInputsToOp();
            if (inputsToOp != null) {
                Collections.replaceAll(inputsToOp, varName, withName);
            }
        }

        if (this.variableNameToShape.containsKey(varName)) {
            this.variableNameToShape.put(withName, this.variableNameToShape.remove(varName));
        }
        if (this.forwardVarForGrad.containsKey(varName)) {
            this.forwardVarForGrad.put(withName, this.forwardVarForGrad.remove(varName));
        }

        // Ops consuming the variable.
        if (v.getInputsForOp() != null) {
            for (String opName : v.getInputsForOp()) {
                this.renameBaseOpVertexIds(opName, varName, withName);
            }
        }
        // Op producing the variable.
        if (v.getOutputOfOp() != null) {
            this.renameBaseOpVertexIds(v.getOutputOfOp(), varName, withName);
        }
    }

    /**
     * If the op registered under {@code opName} is a {@link BaseOp}, replaces each of its
     * x/y/z vertex ids that equals {@code oldName} with {@code newName}.
     */
    private void renameBaseOpVertexIds(String opName, String oldName, String newName) {
        DifferentialFunction func = this.ops.get(opName).getOp();
        if (!(func instanceof BaseOp)) {
            return;
        }
        BaseOp baseOp = (BaseOp)func;
        if (oldName.equals(baseOp.getXVertexId())) {
            baseOp.setXVertexId(newName);
        }
        if (oldName.equals(baseOp.getYVertexId())) {
            baseOp.setYVertexId(newName);
        }
        if (oldName.equals(baseOp.getZVertexId())) {
            baseOp.setZVertexId(newName);
        }
    }

    /**
     * Clears the debug-mode flag.
     *
     * @return this instance, for chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Sets the debug-mode flag.
     *
     * @return this instance, for chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * Returns the function factory created for this instance.
     *
     * @return the value of the {@code functionFactory} field
     */
    @Override
    public DifferentialFunctionFactory f() {
        return functionFactory;
    }

    /**
     * Copies this graph's variables and ops into {@code sameDiff}.
     * <p>
     * Each variable is deep-cloned (without cloning its owning SameDiff) and added to the
     * target; non-ARRAY variables that already have an array get it associated with the new
     * variable. Each op is then cloned, wired to the target by name, and registered there.
     *
     * @param sameDiff the target instance to copy into
     * @return the last variable in {@code sameDiff.variables()} after copying
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        // NOTE(review): populated below but never read in this method.
        HashMap<Integer, Integer> thisVertexIdToNew = new HashMap<Integer, Integer>();
        int idx = 1;
        for (SDVariable var : this.variables()) {
            SDVariable clone = (SDVariable)cloner.deepCloneDontCloneInstances((Object)var, new Object[]{var.getSameDiff()});
            SDVariable newVar = sameDiff.var(clone);
            if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) {
                sameDiff.associateArrayWithVariable(var.getArr(), newVar);
            }
            thisVertexIdToNew.put(idx, idx);
            clone.setSameDiff(sameDiff);
            ++idx;
        }
        // NOTE(review): populated below but never read in this method.
        LinkedHashMap<String, DifferentialFunction> newFunctions = new LinkedHashMap<String, DifferentialFunction>();
        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction function = op.getOp();
            if (function instanceof SDVariable) continue;
            DifferentialFunction clone = (DifferentialFunction)cloner.deepCloneDontCloneInstances((Object)function, new Object[]{function.getSameDiff()});
            clone.setSameDiff(sameDiff);
            clone.setOwnName(function.getOwnName());
            // NOTE(review): this registers the ORIGINAL function, and only when the id already
            // exists in the target - the condition looks inverted; confirm intent.
            if (sameDiff.functionExists(function.getOwnName())) {
                sameDiff.putFunctionForId(function.getOwnName(), function);
            }
            newFunctions.put(function.getOwnName(), clone);
            SDVariable[] argsForFunction = function.args();
            SDVariable[] outputsForFunction = function.outputVariables();
            // Inputs are attached via the clone, outputs via the original function; both share the same own-name.
            sameDiff.addArgsFor(argsForFunction, clone);
            sameDiff.addOutgoingFor(outputsForFunction, function);
            for (SDVariable arg : clone.args()) {
                arg.setSameDiff(sameDiff);
            }
            for (SDVariable output : clone.outputVariables()) {
                output.setSameDiff(sameDiff);
            }
            // NOTE(review): stores this graph's own SameDiffOp wrapper (not a copy) in the target - confirm sharing is intended.
            sameDiff.ops.put(function.getOwnName(), op);
        }
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Checks whether an op is registered under the given id.
     *
     * @param id op own-name to look up
     * @return true if {@code ops} contains the id
     */
    public boolean functionExists(String id) {
        return ops.containsKey(id);
    }

    /**
     * Looks up the function whose output is the named variable.
     *
     * @param varName name of a variable registered in this graph
     * @return the producing function, or null if the variable records no producing op
     */
    public DifferentialFunction functionOutputFor(String varName) {
        String producerName = variables.get(varName).getOutputOfOp();
        return producerName == null ? null : ops.get(producerName).getOp();
    }

    /**
     * Returns the function registered under the given id.
     *
     * @param id op own-name, must not be null
     * @return the function held by the registered op
     * @throws NullPointerException      if {@code id} is null
     * @throws ND4JIllegalStateException if no op is registered under {@code id}
     */
    public DifferentialFunction getFunctionById(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        if (ops.containsKey(id)) {
            return ops.get(id).getOp();
        }
        throw new ND4JIllegalStateException("No function with id " + id + " found!");
    }

    /**
     * Registers {@code function} under {@code id} if nothing is registered there yet.
     * An id that is already registered with a non-null op is left untouched.
     *
     * @param id       op own-name to register under
     * @param function function to register
     * @throws ND4JIllegalStateException if the id is already registered with a null op,
     *                                   or if {@code function} is an {@link SDVariable}
     */
    public void putFunctionForId(String id, DifferentialFunction function) {
        boolean alreadyRegistered = ops.containsKey(id);
        // NOTE(review): the "already exists" error fires only when the existing entry has a null op - confirm intent.
        if (alreadyRegistered && ops.get(id).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
        if (function instanceof SDVariable) {
            throw new ND4JIllegalStateException("Function must not be a variable!");
        }
        if (alreadyRegistered) {
            return;
        }
        ops.put(id, SameDiffOp.builder().name(id).op(function).build());
    }

    /**
     * Returns the names of the input variables recorded for the given function.
     *
     * @param function function whose own-name is registered in this graph
     * @return the input variable names, or null if none are recorded
     * @throws ND4JIllegalStateException if the function's own-name is not registered
     */
    public String[] getInputsForFunction(DifferentialFunction function) {
        String ownName = function.getOwnName();
        if (!ops.containsKey(ownName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + ownName);
        }
        List<String> inputNames = ops.get(ownName).getInputsToOp();
        if (inputNames == null) {
            return null;
        }
        return inputNames.toArray(new String[0]);
    }

    /**
     * Returns the names of the output variables recorded for the given function.
     *
     * @param function function whose own-name is registered in this graph
     * @return the output variable names, or null if none are recorded
     * @throws ND4JIllegalStateException if the function's own-name is not registered
     */
    public String[] getOutputsForFunction(DifferentialFunction function) {
        String ownName = function.getOwnName();
        if (!ops.containsKey(ownName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + ownName);
        }
        List<String> outputNames = ops.get(ownName).getOutputsOfOp();
        if (outputNames == null) {
            return null;
        }
        return outputNames.toArray(new String[0]);
    }

    /**
     * Resolves the output variables of the given function.
     *
     * @param function function whose own-name is registered in this graph
     * @return one entry per recorded output name, as returned by {@code getVariable};
     *         entries are not null-checked
     * @throws ND4JIllegalStateException if the function is not registered or has no outputs recorded
     */
    public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
        String[] outputs = this.getOutputsForFunction(function);
        if (outputs == null) {
            // The message previously said "inputs", copied from the input-variable variant.
            throw new ND4JIllegalStateException("No outputs found for function " + function);
        }
        SDVariable[] vars = new SDVariable[outputs.length];
        for (int i = 0; i < outputs.length; ++i) {
            vars[i] = this.getVariable(outputs[i]);
        }
        return vars;
    }

    /**
     * Resolves the input variables of the given function.
     *
     * @param function function whose own-name is registered in this graph
     * @return one non-null variable per recorded input name
     * @throws ND4JIllegalStateException if the function is not registered, has no inputs
     *                                   recorded, or an input name resolves to null
     */
    public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
        String[] inputNames = getInputsForFunction(function);
        if (inputNames == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }
        SDVariable[] resolved = new SDVariable[inputNames.length];
        int position = 0;
        for (String inputName : inputNames) {
            SDVariable found = getVariable(inputName);
            resolved[position] = found;
            if (found == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + position);
            }
            position++;
        }
        return resolved;
    }

    /**
     * Stores {@code arr} as the array of the named variable. Constants and VARIABLE-type
     * variables are stored in their shared maps; placeholders are stored per calling thread.
     *
     * @param varName name of an existing variable, must not be null
     * @param arr     array to store, must not be null
     * @throws NullPointerException          if either argument is null
     * @throws IllegalStateException         if no such variable exists
     * @throws UnsupportedOperationException if the variable is of any other type
     */
    public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        if (arr == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", (Object)varName);

        SDVariable target = getVariable(varName);
        if (target.isConstant()) {
            constantArrays.put(varName, new DeviceLocalNDArray(arr));
            return;
        }
        if (target.getVariableType() == VariableType.VARIABLE) {
            variablesArrays.put(varName, new DeviceLocalNDArray(arr));
            return;
        }
        if (!target.isPlaceHolder()) {
            throw new UnsupportedOperationException("Cannot set variable of type " + target.getVariableType() + " using this method");
        }
        long threadId = Thread.currentThread().getId();
        Map<String, INDArray> forThread = placeholdersPerThread.get(threadId);
        if (forThread == null) {
            forThread = new HashMap<String, INDArray>();
            placeholdersPerThread.put(threadId, forThread);
        }
        forThread.put(varName, arr);
    }

    /**
     * Returns the shape of the named variable: the shape of its array when one exists,
     * otherwise whatever is stored in {@code variableNameToShape} (possibly null).
     *
     * @param varName variable name
     * @return the shape, or null if neither an array nor a stored shape exists
     */
    public long[] getShapeForVarName(String varName) {
        return arrayAlreadyExistsForVarName(varName)
                ? getVariable(varName).getArr().shape()
                : variableNameToShape.get(varName);
    }

    /**
     * Returns a shape descriptor for the named variable: the array's own descriptor when
     * the variable has an array, otherwise one built from the stored shape and
     * {@code Nd4j.dataType()}.
     *
     * @param varName variable name
     * @return the shape descriptor
     */
    public LongShapeDescriptor getShapeDescriptorForVarName(String varName) {
        INDArray current = getVariable(varName).getArr();
        if (current == null) {
            return LongShapeDescriptor.fromShape(variableNameToShape.get(varName), Nd4j.dataType());
        }
        return current.shapeDescriptor();
    }

    /**
     * Records the shape for a variable that has no shape recorded yet.
     *
     * @param varName variable name
     * @param shape   shape to record, must not be null
     * @throws ND4JIllegalStateException if {@code shape} is null or a shape is already recorded
     */
    @Deprecated
    public void putShapeForVarName(String varName, long[] shape) {
        if (shape == null) {
            throw new ND4JIllegalStateException("Shape must not be null!");
        }
        boolean alreadyRecorded = variableNameToShape.containsKey(varName);
        if (alreadyRecorded) {
            throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
        }
        variableNameToShape.put(varName, shape);
    }

    /**
     * Records the shape from a descriptor and then sets the variable's data type to the
     * descriptor's data type.
     *
     * @param varName variable name
     * @param shape   descriptor supplying the shape and data type
     */
    public void putShapeForVarName(String varName, LongShapeDescriptor shape) {
        // Look the variable up before recording, matching the original call order.
        SDVariable target = getVariable(varName);
        putShapeForVarName(varName, shape.getShape());
        target.setDataType(shape.dataType());
    }

    /**
     * Records the shape for a variable if none is recorded yet. An existing shape is left
     * unchanged, and {@code clearArrayOnShapeMismatch} is not used by this implementation.
     *
     * @param varName                   variable name
     * @param shape                     shape to record, must not be null
     * @param clearArrayOnShapeMismatch currently ignored
     */
    @Deprecated
    public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) {
        Preconditions.checkNotNull((Object)shape, "Cannot put null shape for variable: %s", (Object)varName);
        if (variableNameToShape.containsKey(varName)) {
            return;
        }
        putShapeForVarName(varName, shape);
    }

    /**
     * Checks whether a shape is known for the named variable, either recorded explicitly
     * or implied by an existing array.
     *
     * @param varName variable name
     * @return true if a shape is recorded or an array exists
     */
    public boolean shapeAlreadyExistsForVarName(String varName) {
        if (variableNameToShape.containsKey(varName)) {
            return true;
        }
        return arrayAlreadyExistsForVarName(varName);
    }

    /**
     * Checks whether an array is currently stored for the named variable. ARRAY and
     * PLACEHOLDER variables are checked against the calling thread's session / placeholder map.
     *
     * @param varName variable name
     * @return true if an array is available for the variable
     * @throws RuntimeException if the variable type is not one of the four handled types
     */
    public boolean arrayAlreadyExistsForVarName(String varName) {
        VariableType type = getVariable(varName).getVariableType();
        long threadId = Thread.currentThread().getId();
        switch (type) {
            case VARIABLE:
                return variablesArrays.containsKey(varName);
            case CONSTANT:
                return constantArrays.containsKey(varName);
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                return session != null && session.contains(varName, "main", 0, null);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> forThread = placeholdersPerThread.get(threadId);
                return forThread != null && forThread.containsKey(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Returns the array currently associated with the named variable.
     * <p>
     * VARIABLE: the stored array; if none is stored, {@code storeAndAllocateNewArray()} is
     * invoked on the variable first. CONSTANT: the stored array or null. ARRAY: the value
     * held by the calling thread's session, or null without a session. PLACEHOLDER: the
     * value set by the calling thread, or null.
     *
     * @param varName name of an existing variable, must not be null
     * @return the array, or null where described above
     * @throws NullPointerException  if {@code varName} is null
     * @throws IllegalStateException if no such variable exists
     * @throws RuntimeException      if the variable type is not one of the four handled types
     */
    public INDArray getArrForVarName(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", (Object)varName);

        SDVariable target = variables.get(varName).getVariable();
        VariableType type = target.getVariableType();
        switch (type) {
            case VARIABLE: {
                if (!variablesArrays.containsKey(varName)) {
                    target.storeAndAllocateNewArray();
                }
                return (INDArray)variablesArrays.get(varName).get();
            }
            case CONSTANT: {
                DeviceLocalNDArray constant = constantArrays.get(varName);
                return constant == null ? null : (INDArray)constant.get();
            }
            case ARRAY: {
                InferenceSession session = sessions.get(Thread.currentThread().getId());
                return session == null ? null : (INDArray)session.get(varName, "main", 0, null, false);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> forThread = placeholdersPerThread.get(Thread.currentThread().getId());
                return forThread == null ? null : forThread.get(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Name-based overload: resolves the variable by name and delegates to
     * {@code associateArrayWithVariable(INDArray, SDVariable)}.
     *
     * @param arr      array to associate
     * @param variable name of an existing variable, must not be null
     * @throws NullPointerException  if {@code variable} is null
     * @throws IllegalStateException if no such variable exists in this instance
     */
    public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        boolean known = variables.containsKey(variable);
        Preconditions.checkState(known, "Cannot associate array with variable \"%s\": variable \"%s\" does not exist in this SameDiff instance", (Object)variable, (Object)variable);
        SDVariable resolved = getVariable(variable);
        associateArrayWithVariable(arr, resolved);
    }

    /**
     * Associates {@code arr} with {@code variable}, storing it according to the variable type.
     * <p>
     * The array is cast to the variable's data type if they differ, detached if attached,
     * and duplicated if it is a view. For VARIABLE-type variables, an array instance that
     * is already stored for another variable is duplicated so the two do not share storage.
     * The array is then propagated to any same-named variable in the registered sub-function
     * instances.
     *
     * @param arr      array to associate, must not be null
     * @param variable target variable, must not be null
     * @throws ND4JIllegalArgumentException if either argument is null
     * @throws IllegalStateException        if the data types still differ after casting, or
     *                                      the variable type is unknown
     */
    public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
        if (variable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (arr == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }
        if (variable.dataType() != arr.dataType()) {
            arr = arr.castTo(variable.dataType());
        }
        Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", (Object)variable.getVarName(), (Object)variable.dataType(), (Object)arr.dataType());
        if (this.sessions.get(Thread.currentThread().getId()) == null) {
            this.sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
        }
        boolean duped = false;
        if (arr.isAttached()) {
            arr = arr.detach();
            duped = true;
        }
        if (arr.isView()) {
            arr = arr.dup();
            duped = true;
        }
        if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
            // Identity comparison on purpose: only copy when the very same instance is already stored.
            for (DeviceLocalNDArray deviceLocalNDArray : this.variablesArrays.values()) {
                if (deviceLocalNDArray.get() != arr) continue;
                arr = arr.dup();
                break;
            }
        }
        switch (variable.getVariableType()) {
            case VARIABLE: {
                this.variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
                break;
            }
            case CONSTANT: {
                this.constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
                break;
            }
            case ARRAY: {
                InferenceSession session = this.sessions.get(Thread.currentThread().getId());
                AbstractSession.VarId varId = session.newVarId(variable.getVarName(), "main", 0, null);
                session.getNodeOutputs().put(varId, arr);
                // Previously fell through into PLACEHOLDER, which also stored ARRAY values in the placeholder map.
                break;
            }
            case PLACEHOLDER: {
                long tid = Thread.currentThread().getId();
                if (!this.placeholdersPerThread.containsKey(tid)) {
                    this.placeholdersPerThread.put(tid, new HashMap<String, INDArray>());
                }
                this.placeholdersPerThread.get(tid).put(variable.getVarName(), arr);
                break;
            }
            default: {
                throw new IllegalStateException("Unknown variable type: " + variable.getVariableType());
            }
        }
        if (this.sameDiffFunctionInstances != null && this.sameDiffFunctionInstances.size() > 0) {
            for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
                SDVariable v = sd.getVariable(variable.getVarName());
                if (v == null) continue;
                sd.associateArrayWithVariable(arr, v);
            }
        }
    }

    /**
     * Registers a nested SameDiff instance under the given name. Re-registering the same
     * instance under the same name is allowed; replacing it with a different one is not.
     *
     * @param name      sub-function name
     * @param nameSpace nested instance to register
     * @throws ND4JIllegalStateException if a different instance is already registered under {@code name}
     */
    public void putSubFunction(String name, SameDiff nameSpace) {
        boolean taken = sameDiffFunctionInstances.containsKey(name);
        if (taken && sameDiffFunctionInstances.get(name) != nameSpace) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        sameDiffFunctionInstances.put(name, nameSpace);
    }

    /**
     * Builds a new insertion-ordered map from each variable's name to its {@link SDVariable}.
     *
     * @return a fresh map; changes to it do not affect this graph's variable registry
     */
    public Map<String, SDVariable> variableMap() {
        LinkedHashMap<String, SDVariable> byName = new LinkedHashMap<String, SDVariable>();
        for (Variable entry : variables.values()) {
            String name = entry.getName();
            SDVariable sdVar = entry.getVariable();
            byName.put(name, sdVar);
        }
        return byName;
    }

    /**
     * Reflectively invokes the op method registered under {@code op.opName()} on this instance.
     * The two-argument form is used when both {@code x} and {@code y} are non-null; otherwise
     * the method is invoked with {@code x} only.
     *
     * @param op op whose name selects the method in {@code opMethods}
     * @param x  first argument
     * @param y  second argument, may be null
     * @return the variable returned by the invoked method
     * @throws ND4JIllegalStateException if no method is registered for the op name, or the
     *                                   reflective call fails (the failure is attached as the cause)
     */
    @Deprecated
    public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
        String opName = op.opName();
        if (!opMethods.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal method opName " + opName);
        }
        try {
            Method method = opMethods.get(opName);
            if (x != null && y != null) {
                return (SDVariable)method.invoke((Object)this, x, y);
            }
            return (SDVariable)method.invoke((Object)this, x);
        }
        catch (Exception e) {
            // Keep the same exception type and message callers already see, but retain the
            // underlying failure instead of discarding it.
            throw new ND4JIllegalStateException("Illegal method opName " + opName, e);
        }
    }

    /**
     * Returns the names of the registered sub-function instances.
     *
     * @return the key set of {@code sameDiffFunctionInstances} (a live view, not a copy)
     */
    public Collection<String> definedFunctionNames() {
        return sameDiffFunctionInstances.keySet();
    }

    /**
     * Single-argument convenience form: delegates to
     * {@code invoke(Op, SDVariable, SDVariable)} with a null second argument.
     *
     * @param op op to invoke
     * @param x  the only argument
     * @return the variable returned by the invoked method
     */
    public SDVariable invoke(Op op, SDVariable x) {
        return invoke(op, x, null);
    }

    /**
     * Creates an empty instance: builds the function factory and initializes every
     * bookkeeping collection that is not initialized at its declaration.
     */
    private SameDiff() {
        functionFactory = new DifferentialFunctionFactory(this);

        // Sub-function definitions and instances.
        sameDiffFunctionDefinitionMap = new LinkedHashMap<String, SameDiffFunctionDefinition>();
        sameDiffFunctionInstances = new LinkedHashMap<String, SameDiff>();

        // Variable- and op-level bookkeeping.
        forwardVarForGrad = new LinkedHashMap<String, SDVariable>();
        opsForResult = new IntArrayKeyMap();
        variableNameToShape = new LinkedHashMap<String, long[]>();
        placeHolderOriginalShapes = new LinkedHashMap<String, long[]>();
        placeHolderFunctions = new LinkedHashSet<String>();

        // Per-function naming and property resolution.
        baseNameForFunctionInstanceId = new LinkedHashMap<String, String>();
        propertiesToResolve = new LinkedHashMap<String, List<String>>();
        propertiesForFunction = new LinkedHashMap<String, Map<String, Object>>();
        fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Appends {@code arrayName} to the list of properties to resolve for the given function,
     * creating the list on first use.
     *
     * @param forFunction function the property belongs to (keyed by its own-name)
     * @param arrayName   property name to add
     */
    public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
        String ownName = forFunction.getOwnName();
        if (propertiesToResolve.containsKey(ownName)) {
            propertiesToResolve.get(ownName).add(arrayName);
            return;
        }
        List<String> fresh = new ArrayList<String>();
        fresh.add(arrayName);
        propertiesToResolve.put(ownName, fresh);
    }

    /**
     * Returns the property names still to be resolved for the given function.
     *
     * @param function function to look up (keyed by its own-name)
     * @return the stored list, or an immutable empty list if nothing is recorded
     */
    public List<String> propertiesToResolveForFunction(DifferentialFunction function) {
        String ownName = function.getOwnName();
        if (propertiesToResolve.containsKey(ownName)) {
            return propertiesToResolve.get(ownName);
        }
        return Collections.emptyList();
    }

    /**
     * Records a property value for the given function. A property may be set only once.
     *
     * @param functionFor   function the property belongs to (keyed by its own-name)
     * @param propertyName  property name
     * @param propertyValue property value
     * @throws ND4JIllegalStateException if the property is already recorded for the function
     */
    private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
        String ownName = functionFor.getOwnName();
        if (propertiesForFunction.containsKey(ownName)) {
            Map<String, Object> existing = propertiesForFunction.get(ownName);
            if (existing.containsKey(propertyName)) {
                throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
            }
            existing.put(propertyName, propertyValue);
            return;
        }
        Map<String, Object> created = new LinkedHashMap<String, Object>();
        created.put(propertyName, propertyValue);
        propertiesForFunction.put(ownName, created);
    }

    /**
     * Maps a field of the given function to a variable name.
     *
     * @param function  function owning the field (keyed by its own-name)
     * @param fieldName field name
     * @param varName   variable name the field resolves to
     */
    public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) {
        String ownName = function.getOwnName();
        fieldVariableResolutionMapping.put(ownName, fieldName, varName);
    }

    /**
     * Looks up the variable name mapped to a field of the given function.
     *
     * @param function  function owning the field (keyed by its own-name)
     * @param fieldName field name
     * @return the mapped variable name, or null if no mapping exists
     */
    public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) {
        String ownName = function.getOwnName();
        return fieldVariableResolutionMapping.get(ownName, fieldName);
    }

    /**
     * Associates a base name with a function, keyed by the function's own name.
     *
     * @param baseName base name to store
     * @param function function whose own name is used as the key
     */
    public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
        this.baseNameForFunctionInstanceId.put(function.getOwnName(), baseName);
    }

    /**
     * Returns the base name previously stored for the function's own name.
     *
     * @param function function whose own name is used as the lookup key
     * @return the value stored under that key in {@code baseNameForFunctionInstanceId}
     */
    public String getBaseNameForFunction(DifferentialFunction function) {
        return this.baseNameForFunctionInstanceId.get(function.getOwnName());
    }

    /**
     * Attaches the given variable to this SameDiff instance if it is currently
     * attached to a different one, then returns the same object.
     *
     * @param function variable to attach; validated with {@code Preconditions.checkNotNull}
     * @param <X>      concrete SDVariable subtype, returned unchanged
     * @return the argument that was passed in
     */
    public <X extends SDVariable> X setupFunction(X function) {
        Preconditions.checkNotNull(function, (String)"Passed in function must not be null!");
        boolean attachedElsewhere = function instanceof SDVariable && function.getSameDiff() != this;
        if (attachedElsewhere) {
            function.setSameDiff(this);
        }
        return function;
    }

    /**
     * Convenience overload: extracts the variable names and delegates to
     * {@code addOutgoingFor(String[], DifferentialFunction)}.
     *
     * @param variables output variables of the function
     * @param function  function producing those variables
     */
    public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
        String[] names = new String[variables.length];
        int idx = 0;
        for (SDVariable variable : variables) {
            names[idx++] = variable.getVarName();
        }
        this.addOutgoingFor(names, function);
    }

    /**
     * Declares the output variables of a function. The names are stored on the op
     * registered under the function's own name, and each named variable is marked
     * as being the output of that op.
     *
     * @param varNames output variable names; neither the array nor its elements may be null
     * @param function function producing those variables; its own name must be set and
     *                 an op must already be registered under it
     * @throws ND4JIllegalStateException if the own name is null, outputs were already
     *                                   declared for the op, or any name is null
     */
    public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
        String opName = function.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        List<String> alreadyDeclared = this.ops.get(opName).getOutputsOfOp();
        if (alreadyDeclared != null && !alreadyDeclared.isEmpty()) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
        }
        if (varNames == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String candidate : varNames) {
            if (candidate == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
        }
        this.ops.get(opName).setOutputsOfOp(Arrays.asList(varNames));
        for (String outputName : varNames) {
            this.variables.get(outputName).setOutputOfOp(opName);
        }
    }

    /**
     * Declares the input variables of a function. Registers an op for the function
     * if none exists yet, stores the input names on it, and adds the function's own
     * name to each variable's list of consuming ops (without duplicates). If any
     * input is a placeholder, the function is also recorded in {@code placeHolderFunctions}.
     *
     * @param variables input variable names
     * @param function  function consuming those variables; its own name must be set
     * @throws ND4JIllegalStateException if the function's own name is null
     */
    public void addArgsFor(String[] variables, DifferentialFunction function) {
        String opName = function.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        for (String inputName : variables) {
            if (this.isPlaceHolder(inputName)) {
                this.placeHolderFunctions.add(opName);
            }
        }
        if (!this.ops.containsKey(opName)) {
            SameDiffOp created = SameDiffOp.builder().name(opName).op(function).build();
            this.ops.put(opName, created);
        }
        this.ops.get(opName).setInputsToOp(Arrays.asList(variables));
        for (String inputName : variables) {
            List<String> consumers = this.variables.get(inputName).getInputsForOp();
            if (consumers == null) {
                consumers = new ArrayList<String>();
                this.variables.get(inputName).setInputsForOp(consumers);
            }
            if (!consumers.contains(opName)) {
                consumers.add(opName);
            }
        }
    }

    /**
     * Convenience overload: extracts the variable names and delegates to
     * {@code addArgsFor(String[], DifferentialFunction)}.
     *
     * @param variables input variables; no element may be null
     * @param function  function consuming those variables
     * @throws ND4JIllegalStateException if an element of {@code variables} is null
     */
    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
        int count = variables.length;
        String[] names = new String[count];
        for (int pos = 0; pos < count; ++pos) {
            SDVariable current = variables[pos];
            if (current == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + pos);
            }
            names[pos] = current.getVarName();
        }
        this.addArgsFor(names, function);
    }

    /**
     * Returns the function whose op produces the named variable.
     *
     * @param variableName name of a variable that must exist in this graph
     * @return the producing function, or null when the variable is not the output of any op
     */
    public DifferentialFunction getVariableOutputFunction(String variableName) {
        Preconditions.checkState((boolean)this.variables.containsKey(variableName), (String)"No variable with name \"%s\" found in graph", (Object)variableName);
        String producerName = this.variables.get(variableName).getOutputOfOp();
        return producerName == null ? null : this.ops.get(producerName).getOp();
    }

    /**
     * Tells whether any inputs have been declared for the function's op.
     *
     * @param function function whose op (looked up by own name) is inspected
     * @return true if the op's input list is non-null and non-empty
     */
    public boolean hasArgs(DifferentialFunction function) {
        List<String> declaredInputs = this.ops.get(function.getOwnName()).getInputsToOp();
        if (declaredInputs == null) {
            return false;
        }
        return !declaredInputs.isEmpty();
    }

    /**
     * Returns the functions of all registered ops as a new array, in the
     * iteration order of the {@code ops} map's values.
     *
     * @return array with one function per registered op
     */
    public DifferentialFunction[] functions() {
        DifferentialFunction[] result = new DifferentialFunction[this.ops.size()];
        int idx = 0;
        for (SameDiffOp registered : this.ops.values()) {
            result[idx++] = registered.getOp();
        }
        return result;
    }

    /**
     * Hash code consistent with {@code equals(Object)}.
     *
     * The value is derived only from the variables map, which is one of the fields
     * {@code equals} compares, so two instances that are equal always hash equally.
     * The previous implementation also mixed in {@code super.hashCode()}, which made
     * value-equal instances hash differently whenever the inherited hash is
     * identity-based.
     *
     * @return hash of the variables map, or 0 when that map is null
     */
    public int hashCode() {
        return this.variables != null ? this.variables.hashCode() : 0;
    }

    /**
     * Creates a new SameDiff that shares the function instances of the original and
     * receives all entries of the original's variables map. A new function factory
     * bound to the new instance is installed.
     *
     * @param originalSameDiff instance to copy from
     * @return the newly built instance
     */
    public static SameDiff create(SameDiff originalSameDiff) {
        SameDiff created = SameDiff.builder()
                .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
                .build();
        created.variables.putAll(originalSameDiff.variables);
        created.functionFactory = new DifferentialFunctionFactory(created);
        return created;
    }

    /**
     * Value equality: two instances of exactly the same class are equal when their
     * variables, function definition map and function instances are all equal
     * (null fields only match null fields).
     *
     * @param o object to compare against
     * @return true if {@code o} is equal to this instance
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != this.getClass()) {
            return false;
        }
        SameDiff other = (SameDiff)o;
        boolean sameVariables = this.variables == null
                ? other.variables == null
                : this.variables.equals(other.variables);
        if (!sameVariables) {
            return false;
        }
        boolean sameDefinitions = this.sameDiffFunctionDefinitionMap == null
                ? other.sameDiffFunctionDefinitionMap == null
                : this.sameDiffFunctionDefinitionMap.equals(other.sameDiffFunctionDefinitionMap);
        if (!sameDefinitions) {
            return false;
        }
        if (this.sameDiffFunctionInstances == null) {
            return other.sameDiffFunctionInstances == null;
        }
        return this.sameDiffFunctionInstances.equals(other.sameDiffFunctionInstances);
    }

    /**
     * Factory method: creates a new SameDiff instance via the no-argument constructor.
     *
     * @return a new instance
     */
    public static SameDiff create() {
        return new SameDiff();
    }

    /**
     * Produces a deep clone of this instance using the cloner from
     * {@code newCloner()}. The clone's {@code sessions} are cleared before it is returned.
     *
     * @return the deep-cloned instance
     */
    public SameDiff dup() {
        SameDiff copy = (SameDiff)SameDiff.newCloner().deepClone((Object)this);
        copy.sessions.clear();
        return copy;
    }

    /**
     * Sums the element counts (product of shape dimensions) of all variables.
     * Variables whose shape is null contribute nothing.
     *
     * @return total number of elements over all variables with a known shape
     */
    public long numElements() {
        long total = 0L;
        for (SDVariable current : this.variables()) {
            long[] dims = current.getShape();
            if (dims != null) {
                total += (long)ArrayUtil.prod((long[])dims);
            }
        }
        return total;
    }

    /**
     * Lists the names of all variables for which {@code isPlaceHolder} returns true.
     *
     * @return new list of placeholder variable names, in the key order of the variables map
     */
    public List<String> inputs() {
        ArrayList<String> placeholderNames = new ArrayList<String>();
        for (String candidate : this.variables.keySet()) {
            if (this.isPlaceHolder(candidate)) {
                placeholderNames.add(candidate);
            }
        }
        return placeholderNames;
    }

    /**
     * Lists the names of variables treated as graph outputs. A variable qualifies
     * when it is not a constant or placeholder, is not consumed as an input by any
     * op, has no control dependencies registered (for ops or for variables), and is
     * not produced by an {@code Assert} or {@code Switch} op.
     *
     * @return new list of output variable names
     */
    public List<String> outputs() {
        ArrayList<String> result = new ArrayList<String>();
        for (Variable candidate : this.variables.values()) {
            SDVariable sdVar = candidate.getVariable();
            if (sdVar.isConstant() || sdVar.isPlaceHolder()) {
                continue;
            }
            List<String> consumers = candidate.getInputsForOp();
            if (consumers != null && !consumers.isEmpty()) {
                continue;
            }
            if (candidate.getControlDepsForOp() != null && !candidate.getControlDepsForOp().isEmpty()) {
                continue;
            }
            if (candidate.getControlDepsForVar() != null && !candidate.getControlDepsForVar().isEmpty()) {
                continue;
            }
            String producerName = candidate.getOutputOfOp();
            if (producerName != null) {
                DifferentialFunction producer = this.ops.get(producerName).getOp();
                if (producer instanceof Assert || producer instanceof Switch) {
                    continue;
                }
            }
            result.add(candidate.getName());
        }
        return result;
    }

    /**
     * Returns all variables as a new list, copied from the values of {@code variableMap()}.
     *
     * @return new list holding every SDVariable in the variable map
     */
    public List<SDVariable> variables() {
        return new ArrayList<SDVariable>(this.variableMap().values());
    }

    /**
     * Returns the names of the variables currently marked as losses.
     *
     * @return unmodifiable view of the loss variable names
     */
    public List<String> getLossVariables() {
        return Collections.unmodifiableList(this.lossVariables);
    }

    /**
     * Replaces the current loss variables with the given names. The existing list is
     * cleared first, then each name is added through {@code addLossVariable} (which
     * validates it). Finally the entry keyed "grad" is removed from the function
     * instances (presumably so a stale gradient function is not reused; confirm).
     *
     * @param lossVariableNames names of the variables to mark as losses
     */
    public void setLossVariables(String ... lossVariableNames) {
        this.lossVariables.clear();
        for (int pos = 0; pos < lossVariableNames.length; ++pos) {
            this.addLossVariable(lossVariableNames[pos]);
        }
        this.sameDiffFunctionInstances.remove("grad");
    }

    /**
     * Marks a variable as a loss. The variable must exist, have a floating point
     * data type and be of type {@code VariableType.ARRAY}. Adding a name that is
     * already marked is a no-op.
     *
     * @param variableName name of the variable to mark; must not be null
     * @throws NullPointerException if {@code variableName} is null
     */
    public void addLossVariable(@NonNull String variableName) {
        if (variableName == null) {
            throw new NullPointerException("variableName is marked @NonNull but is null");
        }
        Preconditions.checkState((boolean)this.hasVariable(variableName), (String)"No variable with name \"%s\" exists", (Object)variableName);
        SDVariable target = this.getVariable(variableName);
        Preconditions.checkState((boolean)target.dataType().isFPType(), (String)"Only floating point type variables can be marked as losses to be minimized. SDVariable \"%s\" has datatype %s", (Object)variableName, (Object)target.dataType());
        boolean isArrayType = target.getVariableType() == VariableType.ARRAY;
        Preconditions.checkState(isArrayType, (String)"Only ARRAY type SDVariables can be marked as losses to be minimized. SDVariable \"%s\" has variable type %s", (Object)variableName, (Object)((Object)target.getVariableType()));
        if (this.lossVariables.contains(variableName)) {
            return;
        }
        this.lossVariables.add(variableName);
    }

    /**
     * Sets the training configuration used by the fit/evaluate/output methods.
     *
     * @param trainingConfig configuration to store
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        this.trainingConfig = trainingConfig;
    }

    /**
     * Trains on a single DataSet: it is converted to a MultiDataSet, wrapped in a
     * singleton iterator and fitted for one epoch without incrementing the epoch count.
     *
     * @param dataSet data to train on
     */
    public void fit(DataSet dataSet) {
        this.fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false);
    }

    /**
     * Trains on a single MultiDataSet: it is wrapped in a singleton iterator and
     * fitted for one epoch without incrementing the epoch count.
     *
     * @param dataSet data to train on
     */
    public void fit(MultiDataSet dataSet) {
        this.fit(new SingletonMultiDataSetIterator(dataSet), 1, false);
    }

    /**
     * Trains for the given number of epochs on a DataSetIterator (adapted to a
     * MultiDataSetIterator), incrementing the epoch count after each epoch.
     *
     * @param iter      training data
     * @param numEpochs number of epochs to train for
     */
    public void fit(DataSetIterator iter, int numEpochs) {
        this.fit(new MultiDataSetIteratorAdapter(iter), numEpochs, true);
    }

    /**
     * Trains for the given number of epochs on a MultiDataSetIterator, incrementing
     * the epoch count after each epoch.
     *
     * @param iter      training data
     * @param numEpochs number of epochs to train for
     */
    public void fit(MultiDataSetIterator iter, int numEpochs) {
        this.fit(iter, numEpochs, true);
    }

    /**
     * Core training loop. For every batch of every epoch: builds the placeholder map
     * from the batch, runs the backward pass, then for each trainable parameter applies
     * pre-updater regularization, the gradient updater, post-updater regularization,
     * and finally adds or subtracts the resulting gradient array to/from the parameter.
     *
     * The method is synchronized on this instance.
     *
     * @param iter                training data; must not be null and must support reset when
     *                            {@code numEpochs > 1}
     * @param numEpochs           number of passes over the iterator; must be positive
     * @param incrementEpochCount whether the training config's epoch count is incremented
     *                            after each epoch
     */
    protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount) {
        Preconditions.checkNotNull((Object)iter, (String)"Iterator must not be null");
        Preconditions.checkState((numEpochs > 0 ? 1 : 0) != 0, (String)"Number of training epochs must be a positive number. Got: %s", (int)numEpochs);
        Preconditions.checkState((this.trainingConfig != null ? 1 : 0) != 0, (String)"No training configuration has been set. A training configuration must be set before training. Use setTrainingConfig(TrainingConfig)");
        Preconditions.checkState((numEpochs == 1 || iter.resetSupported() ? 1 : 0) != 0, (String)"Cannot train for multiple epochs on an iterator that does not support resetting");
        // An already-exhausted iterator is rewound up front when it supports reset.
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        // The feature/label count validation below runs only for the first batch seen.
        boolean performedValidation = false;
        for (int i = 0; i < numEpochs; ++i) {
            while (iter.hasNext()) {
                Map<String, INDArray> placeholders;
                MultiDataSet ds = (MultiDataSet)iter.next();
                if (!performedValidation) {
                    Preconditions.checkState((this.trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays() ? 1 : 0) != 0, (String)"The number of dataset feature mapping variables set in the training configuration (%s) must match the number of dataset feature arrays (%s)", (int)this.trainingConfig.getDataSetFeatureMapping().size(), (int)ds.numFeatureArrays());
                    List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
                    // A null label mapping is treated as zero labels.
                    int lblSize = labelMapping == null ? 0 : labelMapping.size();
                    Preconditions.checkState((lblSize == ds.numLabelsArrays() ? 1 : 0) != 0, (String)"The number of dataset label mapping variables set in the training configuration (%s) must match the number of dataset label arrays (%s)", (int)lblSize, (int)ds.numLabelsArrays());
                    performedValidation = true;
                }
                Preconditions.checkState(((placeholders = this.toPlaceholderMap(ds)).size() > 0 ? 1 : 0) != 0, (String)"No placeholder variables were set for training");
                this.resolveVariablesWith(placeholders);
                this.execBackwards(placeholders);
                // Updater state is set up lazily, after the first backward pass.
                if (!this.initializedTraining) {
                    this.initializeTraining();
                }
                int iteration = this.trainingConfig.getIterationCount();
                int e = this.trainingConfig.getEpochCount();
                for (String s : this.trainingConfig.getTrainableParams()) {
                    double lr;
                    INDArray param = this.variables.get(s).getVariable().getArr();
                    SDVariable gradVar = this.variables.get(s).getVariable().getGradient();
                    // Parameters without a gradient variable are skipped.
                    if (gradVar == null) continue;
                    INDArray grad = gradVar.getArr();
                    List<Regularization> r = this.trainingConfig.getRegularization();
                    int iterCount = this.trainingConfig.getIterationCount();
                    int epochCount = this.trainingConfig.getEpochCount();
                    // Learning rate defaults to 1.0 when the updater has none.
                    // NOTE(review): 'd' is never read; looks like a decompiler artifact.
                    double d = lr = this.trainingConfig.getUpdater().hasLearningRate() ? this.trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
                    if (r != null && r.size() > 0) {
                        for (Regularization reg : r) {
                            if (reg.applyStep() != Regularization.ApplyStep.BEFORE_UPDATER) continue;
                            reg.apply(param, grad, lr, iterCount, epochCount);
                        }
                    }
                    // The updater is given the gradient reshaped to [1, length] without copying.
                    INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1L, grad.length()}, grad.ordering() == 'f');
                    Preconditions.checkState((reshapedView != null ? 1 : 0) != 0, (String)"Error reshaping array for parameter \"%s\": array is a view?", (Object)s);
                    GradientUpdater u = this.updaterMap.get(s);
                    try {
                        u.applyUpdater(reshapedView, iteration, e);
                    }
                    catch (Throwable t) {
                        // NOTE(review): if updaterMap has no entry for s, 'u' is null and
                        // u.getClass() below throws a second NPE that hides the original failure.
                        throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + s + "\": either parameter size is inconsistent between iterations, or \"" + s + "\" should not be a trainable parameter?", t);
                    }
                    if (r != null && r.size() > 0) {
                        for (Regularization reg : r) {
                            if (reg.applyStep() != Regularization.ApplyStep.POST_UPDATER) continue;
                            reg.apply(param, grad, lr, iterCount, epochCount);
                        }
                    }
                    // Minimizing subtracts the gradient array from the parameter in place; otherwise it is added.
                    if (this.trainingConfig.isMinimize()) {
                        param.subi(grad);
                        continue;
                    }
                    param.addi(grad);
                }
                // The iteration count advances once per batch.
                this.trainingConfig.incrementIterationCount();
            }
            // Rewind between epochs, but not after the final one.
            if (i < numEpochs - 1) {
                iter.reset();
            }
            if (!incrementEpochCount) continue;
            this.trainingConfig.incrementEpochCount();
        }
    }

    /**
     * Computes the total regularization score: the sum, over every trainable
     * parameter and every configured regularization, of that regularization's score
     * for the parameter's array at the current iteration and epoch counts.
     *
     * If no trainable parameters are set yet, {@code initializeTraining()} is invoked first.
     *
     * @return the summed score, or 0.0 when no regularization is configured
     */
    public double calcRegularizationScore() {
        Preconditions.checkState(this.trainingConfig != null, (String)"No training configuration has been set. A training configuration must be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
        List<Regularization> regularizations = this.trainingConfig.getRegularization();
        if (regularizations == null || regularizations.isEmpty()) {
            return 0.0;
        }
        boolean noParamsYet = this.trainingConfig.getTrainableParams() == null
                || this.trainingConfig.getTrainableParams().isEmpty();
        if (noParamsYet) {
            this.initializeTraining();
        }
        double total = 0.0;
        for (String paramName : this.trainingConfig.getTrainableParams()) {
            for (Regularization regularization : regularizations) {
                INDArray values = this.getVariable(paramName).getArr();
                total += regularization.score(values, this.trainingConfig.getIterationCount(), this.trainingConfig.getEpochCount());
            }
        }
        return total;
    }

    /**
     * One-time training setup; does nothing if training was already initialized.
     *
     * Steps:
     * 1. If the training config lists no trainable parameters, infer them: every
     *    variable that is not the output of an op, not a placeholder, not a constant
     *    and not named in any of the dataset feature/label/mask mappings.
     * 2. Validate that each trainable parameter exists and has an array, summing
     *    their lengths; the data type of the first array is used for the updater state.
     * 3. Allocate a single updater state array (outside of workspaces) if the
     *    updater needs state, and hand each parameter its own view into it.
     *
     * @throws ND4JIllegalStateException if no training config has been set
     */
    protected void initializeTraining() {
        if (this.initializedTraining) {
            return;
        }
        if (this.trainingConfig == null) {
            throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
        }
        if (this.trainingConfig.getTrainableParams() == null || this.trainingConfig.getTrainableParams().size() == 0) {
            ArrayList<String> trainVarList = new ArrayList<String>();
            for (Variable entry : this.variables.values()) {
                SDVariable v = entry.getVariable();
                String n = v.getVarName();
                // Skip anything that is computed, fed in, constant, or mapped to dataset arrays.
                if (this.variables.get(n).getOutputOfOp() != null
                        || this.isPlaceHolder(n)
                        || this.variables.get(n).getVariable().isConstant()
                        || this.trainingConfig.getDataSetFeatureMapping() != null && this.trainingConfig.getDataSetFeatureMapping().contains(n)
                        || this.trainingConfig.getDataSetLabelMapping() != null && this.trainingConfig.getDataSetLabelMapping().contains(n)
                        || this.trainingConfig.getDataSetFeatureMaskMapping() != null && this.trainingConfig.getDataSetFeatureMaskMapping().contains(n)
                        || this.trainingConfig.getDataSetLabelMaskMapping() != null && this.trainingConfig.getDataSetLabelMaskMapping().contains(n)) {
                    continue;
                }
                trainVarList.add(n);
            }
            this.trainingConfig.setTrainableParams(trainVarList);
            log.info("Inferred trainable variables: {}", trainVarList);
        }
        long numTrainableParams = 0L;
        DataType dt = null;
        for (String s : this.trainingConfig.getTrainableParams()) {
            // Look the entry up null-safely so that an unknown parameter name reaches the
            // descriptive check below instead of failing with a bare NullPointerException.
            Variable entry = this.variables.get(s);
            SDVariable v = entry == null ? null : entry.getVariable();
            Preconditions.checkState(v != null, (String)"No variable found for trainable parameter name \"%s\"", (Object)s);
            INDArray arr = v.getArr();
            Preconditions.checkState(arr != null, (String)"No array found for trainable parameter \"%s\"", (Object)s);
            numTrainableParams += arr.length();
            if (dt == null) {
                dt = arr.dataType();
            }
        }
        long updaterStateSize = this.trainingConfig.getUpdater().stateSize(numTrainableParams);
        if (updaterStateSize > 0L) {
            // Allocate the state array while scoped out of any active workspace.
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.updaterState = Nd4j.createUninitialized(dt, 1L, updaterStateSize);
            }
        }
        long viewSoFar = 0L;
        this.updaterViews = new HashMap<String, INDArray>();
        this.updaterMap = new HashMap<String, GradientUpdater>();
        for (String s : this.trainingConfig.getTrainableParams()) {
            long thisSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length());
            // Each parameter gets a contiguous slice [viewSoFar, viewSoFar + thisSize) of the state row.
            INDArray view = updaterStateSize == 0L || thisSize == 0L ? null : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(viewSoFar, viewSoFar + thisSize));
            this.updaterViews.put(s, view);
            this.updaterMap.put(s, this.trainingConfig.getUpdater().instantiate(view, true));
            viewSoFar += thisSize;
        }
        this.initializedTraining = true;
    }

    /**
     * Builds the placeholder name to array map for one batch, using the mappings in
     * the training config. Feature and label names map to the feature/label array at
     * the same position; mask names map to the mask array at the same position, with
     * null mask names skipped (their position is still consumed).
     *
     * @param ds batch providing features, labels and mask arrays
     * @return new map from placeholder name to the corresponding array
     */
    private Map<String, INDArray> toPlaceholderMap(MultiDataSet ds) {
        HashMap<String, INDArray> result = new HashMap<String, INDArray>();

        int featureIdx = 0;
        for (String featureName : this.trainingConfig.getDataSetFeatureMapping()) {
            result.put(featureName, ds.getFeatures(featureIdx));
            ++featureIdx;
        }

        if (this.trainingConfig.getDataSetLabelMapping() != null) {
            int labelIdx = 0;
            for (String labelName : this.trainingConfig.getDataSetLabelMapping()) {
                result.put(labelName, ds.getLabels(labelIdx));
                ++labelIdx;
            }
        }

        if (this.trainingConfig.getDataSetFeatureMaskMapping() != null && this.trainingConfig.getDataSetFeatureMaskMapping().size() > 0) {
            int maskIdx = 0;
            for (String maskName : this.trainingConfig.getDataSetFeatureMaskMapping()) {
                if (maskName != null) {
                    result.put(maskName, ds.getFeaturesMaskArray(maskIdx));
                }
                ++maskIdx;
            }
        }

        if (this.trainingConfig.getDataSetLabelMaskMapping() != null && this.trainingConfig.getDataSetLabelMaskMapping().size() > 0) {
            int maskIdx = 0;
            for (String maskName : this.trainingConfig.getDataSetLabelMaskMapping()) {
                if (maskName != null) {
                    result.put(maskName, ds.getLabelsMaskArray(maskIdx));
                }
                ++maskIdx;
            }
        }
        return result;
    }

    /**
     * Evaluates a single output variable against label array 0 of each batch, using
     * one or more evaluations. Delegates to the map-based
     * {@code evaluate(MultiDataSetIterator, Map, Map)}.
     *
     * @param iterator       data to evaluate on
     * @param outputVariable name of the variable whose values are the predictions
     * @param evaluations    evaluations to update; at least one is required
     */
    public void evaluate(DataSetIterator iterator, String outputVariable, IEvaluation ... evaluations) {
        Preconditions.checkArgument(evaluations != null && evaluations.length > 0, (String)"No evaluations were passed to the evaluate method");
        MultiDataSetIteratorAdapter adapted = new MultiDataSetIteratorAdapter(iterator);
        Map<String, List<IEvaluation>> evalsByVariable = Collections.singletonMap(outputVariable, Arrays.asList(evaluations));
        Map<String, Integer> labelIndexByVariable = Collections.singletonMap(outputVariable, 0);
        this.evaluate(adapted, evalsByVariable, labelIndexByVariable);
    }

    /**
     * Evaluates several output variables, one evaluation each, all against label
     * array 0 of each batch. Delegates to the map-based
     * {@code evaluate(MultiDataSetIterator, Map, Map)}.
     *
     * @param iterator      data to evaluate on
     * @param variableEvals evaluation to use per output variable name
     */
    public void evaluate(DataSetIterator iterator, Map<String, IEvaluation> variableEvals) {
        HashMap<String, Integer> labelIndexByVariable = new HashMap<String, Integer>();
        HashMap<String, List<IEvaluation>> evalsByVariable = new HashMap<String, List<IEvaluation>>();
        for (Map.Entry<String, IEvaluation> entry : variableEvals.entrySet()) {
            String variableName = entry.getKey();
            labelIndexByVariable.put(variableName, 0);
            evalsByVariable.put(variableName, Collections.singletonList(entry.getValue()));
        }
        this.evaluate(new MultiDataSetIteratorAdapter(iterator), evalsByVariable, labelIndexByVariable);
    }

    /**
     * Evaluates several output variables, each with a list of evaluations, all
     * against label array 0 of each batch. Delegates to the map-based
     * {@code evaluate(MultiDataSetIterator, Map, Map)}.
     *
     * @param iterator      data to evaluate on
     * @param variableEvals evaluations to use per output variable name
     */
    public void evaluateMultiple(DataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals) {
        HashMap<String, Integer> labelIndexByVariable = new HashMap<String, Integer>();
        for (String variableName : variableEvals.keySet()) {
            labelIndexByVariable.put(variableName, 0);
        }
        MultiDataSetIteratorAdapter adapted = new MultiDataSetIteratorAdapter(iterator);
        this.evaluate(adapted, variableEvals, labelIndexByVariable);
    }

    /**
     * Evaluates a single output variable against the label array at
     * {@code labelIndex}, using one or more evaluations. Delegates to the map-based
     * {@code evaluate(MultiDataSetIterator, Map, Map)}.
     *
     * @param iterator       data to evaluate on
     * @param outputVariable name of the variable whose values are the predictions
     * @param labelIndex     index of the label array to compare against
     * @param evaluations    evaluations to update; at least one is required
     */
    public void evaluate(MultiDataSetIterator iterator, String outputVariable, int labelIndex, IEvaluation ... evaluations) {
        Preconditions.checkArgument(evaluations != null && evaluations.length > 0, (String)"No evaluations were passed to the evaluate method");
        Map<String, List<IEvaluation>> evalsByVariable = Collections.singletonMap(outputVariable, Arrays.asList(evaluations));
        Map<String, Integer> labelIndexByVariable = Collections.singletonMap(outputVariable, labelIndex);
        this.evaluate(iterator, evalsByVariable, labelIndexByVariable);
    }

    /**
     * Runs the graph over every batch of the iterator and feeds predictions and
     * labels to the supplied evaluations. For each evaluated variable, the
     * prediction is that variable's computed array and the label is the batch's
     * label array at the index given by {@code predictionLabelMapping}.
     *
     * A training config must be set (it provides the placeholder mappings), and both
     * maps must have identical key sets. An exhausted iterator is reset first when it
     * supports resetting.
     *
     * @param iterator               data to evaluate on
     * @param variableEvals          evaluations per output variable name
     * @param predictionLabelMapping label array index per output variable name
     */
    public void evaluate(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping) {
        Preconditions.checkState(this.trainingConfig != null, (String)"Training config has not been set");
        Preconditions.checkState((boolean)variableEvals.keySet().equals(predictionLabelMapping.keySet()), (String)"Keysets for variable evaluations and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }
        ArrayList<String> requested = new ArrayList<String>(variableEvals.keySet());
        while (iterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet)iterator.next();
            Map<String, INDArray> computed = this.exec(this.toPlaceholderMap(batch), requested);
            for (Map.Entry<String, List<IEvaluation>> entry : variableEvals.entrySet()) {
                String variableName = entry.getKey();
                INDArray prediction = computed.get(variableName);
                for (IEvaluation evaluation : entry.getValue()) {
                    INDArray label = batch.getLabels(predictionLabelMapping.get(variableName));
                    evaluation.eval(label, prediction);
                }
            }
        }
    }

    /**
     * Computes the requested outputs for a single DataSet. The data set is converted
     * to a MultiDataSet, wrapped in a singleton iterator, and the first (only) result
     * of the iterator-based overload is returned.
     *
     * @param dataSet data to run the graph on
     * @param outputs names of the variables to compute
     * @return map from variable name to computed array
     */
    public Map<String, INDArray> output(DataSet dataSet, String ... outputs) {
        return this.output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
    }

    /**
     * Computes the requested outputs for every batch of a DataSetIterator by adapting
     * it to a MultiDataSetIterator and delegating.
     *
     * @param iterator data to run the graph on
     * @param outputs  names of the variables to compute
     * @return one result map per batch
     */
    public List<Map<String, INDArray>> output(DataSetIterator iterator, String ... outputs) {
        return this.output(new MultiDataSetIteratorAdapter(iterator), outputs);
    }

    /**
     * Computes the requested outputs for every batch of the iterator. A training
     * config must be set (it provides the placeholder mappings). When
     * {@code outputs} is null, the variables returned by {@code outputs()} are
     * requested instead. An exhausted iterator is reset first when it supports resetting.
     *
     * @param iterator data to run the graph on
     * @param outputs  names of the variables to compute, or null for all graph outputs
     * @return one result map (variable name to array) per batch, in iteration order
     */
    public List<Map<String, INDArray>> output(MultiDataSetIterator iterator, String ... outputs) {
        Preconditions.checkState(this.trainingConfig != null, (String)"Training config has not been set");
        List<String> requested;
        if (outputs == null) {
            requested = this.outputs();
        } else {
            requested = Arrays.asList(outputs);
        }
        ArrayList<Map<String, INDArray>> results = new ArrayList<Map<String, INDArray>>();
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }
        while (iterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet)iterator.next();
            Map<String, INDArray> feed = this.toPlaceholderMap(batch);
            results.add(this.exec(feed, requested));
        }
        return results;
    }

    /**
     * Creates a variable initialized to ones, using the default floating point data type.
     *
     * @param name  variable name
     * @param shape variable shape (int dimensions)
     * @return the created variable
     */
    public SDVariable one(String name, int ... shape) {
        return this.one(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a variable initialized to ones, using the default floating point data type.
     *
     * @param name  variable name
     * @param shape variable shape (long dimensions)
     * @return the created variable
     */
    public SDVariable one(String name, long ... shape) {
        return this.one(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a variable of the given data type initialized with a constant init
     * scheme (constructed with 'f' and 1.0). The int shape is converted to long.
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape (int dimensions)
     * @return the created variable
     */
    public SDVariable one(String name, DataType dataType, int ... shape) {
        return this.var(name, new ConstantInitScheme('f', 1.0), dataType, ArrayUtil.toLongArray((int[])shape));
    }

    /**
     * Creates a variable of the given data type initialized with a constant init
     * scheme (constructed with 'f' and 1.0).
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape (long dimensions)
     * @return the created variable
     */
    public SDVariable one(String name, DataType dataType, long ... shape) {
        return this.var(name, new ConstantInitScheme('f', 1.0), dataType, shape);
    }

    /**
     * Creates a zero-initialized variable using the default floating point data type.
     *
     * @param name  variable name
     * @param shape variable shape (long dimensions)
     * @return the created variable
     */
    public SDVariable zero(String name, long ... shape) {
        return this.zero(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a zero-initialized variable using the default floating point data type.
     *
     * @param name  variable name
     * @param shape variable shape (int dimensions)
     * @return the created variable
     */
    public SDVariable zero(String name, int ... shape) {
        return this.zero(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a variable of the given data type using a {@code ZeroInitScheme}.
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape (long dimensions)
     * @return the created variable
     */
    public SDVariable zero(String name, DataType dataType, long ... shape) {
        return this.var(name, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a variable of the given data type using a {@code ZeroInitScheme}.
     * The int shape is converted to long.
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape (int dimensions)
     * @return the created variable
     */
    public SDVariable zero(String name, DataType dataType, int ... shape) {
        return this.var(name, new ZeroInitScheme(), dataType, ArrayUtil.toLongArray((int[])shape));
    }

    /**
     * Creates a constant variable with an automatically generated name.
     *
     * @param constant array holding the constant's value; must not be null
     * @return the created constant variable
     * @throws NullPointerException if {@code constant} is null
     */
    public SDVariable constant(@NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        return this.constant(this.getNewVarName(), constant);
    }

    /**
     * Creates a constant variable backed by the given array. The variable takes its
     * shape and data type from the array, is registered in the variables map, and the
     * array is stored (wrapped in a {@code DeviceLocalNDArray}) in the constant arrays map.
     *
     * @param name     variable name; a new name is generated when null or empty.
     *                 No variable with this name may already exist.
     * @param constant array holding the constant's value; must not be null
     * @return the created constant variable
     * @throws NullPointerException if {@code constant} is null
     */
    public SDVariable constant(String name, @NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        Preconditions.checkState(!this.variables.containsKey(name), (String)"Variable with name \"%s\" already exists", (Object)name);
        String resolvedName = name;
        if (resolvedName == null || resolvedName.isEmpty()) {
            resolvedName = this.getNewVarName();
        }
        SDVariable created = new SDVariable(resolvedName, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
        Variable entry = Variable.builder().name(resolvedName).variable(created).build();
        this.variables.put(resolvedName, entry);
        this.constantArrays.put(resolvedName, new DeviceLocalNDArray(constant));
        return created;
    }

    /**
     * Deprecated overload: creates a constant from an SDVariable value and shape,
     * passing a null name to the named overload.
     *
     * @param value value variable
     * @param shape shape of the constant
     * @return the resulting variable
     */
    @Deprecated
    public SDVariable constant(SDVariable value, long ... shape) {
        return this.constant(null, value, shape);
    }

    /**
     * Deprecated overload: builds a constant through the function factory
     * ({@code f().constant}) and then applies the given name via
     * {@code updateVariableNameAndReference}.
     *
     * @param name  name to give the result
     * @param value value variable
     * @param shape shape of the constant
     * @return the resulting variable
     */
    @Deprecated
    public SDVariable constant(String name, SDVariable value, long ... shape) {
        SDVariable ret = this.f().constant(value, shape);
        return this.updateVariableNameAndReference(ret, name);
    }

    /**
     * Creates a placeholder variable and registers it in the variables map under
     * the given name. No check for an existing variable of that name is made here,
     * so an existing map entry is replaced.
     *
     * @param name     placeholder name
     * @param dataType data type of the placeholder
     * @param shape    placeholder shape
     * @return the created placeholder variable
     */
    public SDVariable placeHolder(String name, DataType dataType, long ... shape) {
        SDVariable placeholder = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType, null);
        Variable entry = Variable.builder().name(name).variable(placeholder).build();
        this.variables.put(name, entry);
        return placeholder;
    }

    /**
     * Creates a trainable variable ({@code VariableType.VARIABLE}) with the given
     * initialization scheme, data type and shape.
     *
     * @param name             variable name; must not be null
     * @param weightInitScheme initialization scheme; must not be null
     * @param dataType         data type; must not be null
     * @param shape            variable shape; must not be null
     * @return the created variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked @NonNull but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked @NonNull but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        return this.var(name, VariableType.VARIABLE, weightInitScheme, dataType, shape);
    }

    /**
     * Creates a variable of the given type and adds it to this graph. For
     * placeholders, the original placeholder shape and the shape for the variable
     * name are recorded as well.
     *
     * @param name             variable name; must not be null. An empty name is
     *                         replaced by a generated one.
     * @param variableType     type of variable to create; must not be null
     * @param weightInitScheme initialization scheme passed to the variable
     * @param dataType         data type of the variable
     * @param shape            variable shape
     * @return the created variable
     * @throws NullPointerException     if {@code name} or {@code variableType} is null
     * @throws IllegalArgumentException if a variable with this name already exists and has an array
     */
    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (variableType == null) {
            throw new NullPointerException("variableType is marked @NonNull but is null");
        }
        boolean clashes = this.variables.containsKey(name) && this.variables.get(name).getVariable().getArr() != null;
        if (clashes) {
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
        }
        String resolvedName = name;
        if (resolvedName == null || resolvedName.isEmpty()) {
            resolvedName = this.getNewVarName();
        }
        SDVariable created = new SDVariable(resolvedName, variableType, this, shape, dataType, weightInitScheme);
        this.addVariable(created);
        if (variableType == VariableType.PLACEHOLDER) {
            this.setOriginalPlaceHolderShape(resolvedName, shape);
            this.putShapeForVarName(resolvedName, shape);
        }
        return created;
    }

    /**
     * Creates a variable whose data type and shape are taken from a shape descriptor.
     *
     * @param name             variable name; must not be null
     * @param shape            descriptor providing data type and shape; must not be null
     * @param weightInitScheme initialization scheme
     * @return the created variable
     * @throws NullPointerException if {@code name} or {@code shape} is null
     */
    public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        return this.var(name, weightInitScheme, shape.dataType(), shape.getShape());
    }

    /**
     * Creates a variable with the given data type and shape. A shape that
     * {@code Shape.isPlaceholderShape} recognises as a placeholder shape yields a
     * placeholder; otherwise a zero-initialized variable is created.
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape; must not be null
     * @return the created variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, long ... shape) {
        // Check the array itself. The previous code passed a boxed "shape != null" flag,
        // which is never null, so the check could not fire.
        Preconditions.checkNotNull((Object)shape, (String)"Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(shape)) {
            return this.placeHolder(name, dataType, shape);
        }
        return this.var(name, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a zero-initialized variable whose data type and shape come from a
     * shape descriptor.
     *
     * @param name      variable name
     * @param shapeDesc descriptor providing data type and shape; must not be null
     * @return the created variable
     */
    public SDVariable var(String name, LongShapeDescriptor shapeDesc) {
        // Check the descriptor itself. The previous code passed a boxed "shapeDesc != null"
        // flag, which is never null, so the check could not fire.
        Preconditions.checkNotNull((Object)shapeDesc, (String)"Invalid shape: shape may not be null");
        return this.var(name, shapeDesc, new ZeroInitScheme());
    }

    /**
     * Creates a variable with the default floating point data type and the given shape.
     *
     * @param name  variable name
     * @param shape variable shape (int dimensions)
     * @return the created variable
     */
    public SDVariable var(String name, int ... shape) {
        return this.var(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a variable with the default floating point data type and the given shape.
     *
     * @param name  variable name
     * @param shape variable shape (long dimensions)
     * @return the created variable
     */
    public SDVariable var(String name, long ... shape) {
        return this.var(name, Nd4j.defaultFloatingPointType(), shape);
    }

    /**
     * Creates a variable with the given data type and int shape. A shape that
     * {@code Shape.isPlaceholderShape} recognises as a placeholder shape yields a
     * placeholder; otherwise a zero-initialized variable is created. The shape is
     * converted to long dimensions in both cases.
     *
     * @param name     variable name
     * @param dataType data type of the variable
     * @param shape    variable shape; must not be null
     * @return the created variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, int ... shape) {
        Preconditions.checkNotNull((Object)shape, (String)"Invalid shape: shape may not be null");
        long[] longShape = ArrayUtil.toLongArray((int[])shape);
        return Shape.isPlaceholderShape(shape)
                ? this.placeHolder(name, dataType, longShape)
                : this.var(name, new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Adds an equivalent of the given variable to this graph. If a variable with the
     * same name already exists here and has an array, that existing variable is
     * returned. Otherwise a new variable is created according to the source's type:
     * VARIABLE (initialized from the source's array), ARRAY (no init scheme),
     * CONSTANT (via {@code constant}) or PLACEHOLDER (via {@code placeHolder}).
     *
     * @param v variable to mirror; must not be null and must have a non-empty name
     * @return the existing or newly created variable in this graph
     * @throws NullPointerException     if {@code v} is null
     * @throws IllegalArgumentException if the variable has no name
     * @throws RuntimeException         if the variable type is not one of the four handled types
     */
    public SDVariable var(@NonNull SDVariable v) {
        if (v == null) {
            throw new NullPointerException("v is marked @NonNull but is null");
        }
        String sourceName = v.getVarName();
        if (this.variables.containsKey(sourceName) && this.variables.get(sourceName).getVariable().getArr() != null) {
            return this.variables.get(sourceName).getVariable();
        }
        if (sourceName == null || sourceName.isEmpty()) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        VariableType vt = v.getVariableType();
        switch (vt) {
            case VARIABLE:
            case ARRAY: {
                // Only VARIABLE carries its current array over as the init scheme; ARRAY gets none.
                NDArraySupplierInitScheme initScheme = vt == VariableType.VARIABLE ? new NDArraySupplierInitScheme(v.getArr()) : null;
                SDVariable mirrored = new SDVariable(sourceName, vt, this, v.getShape(), v.dataType(), initScheme);
                return this.addVariable(mirrored);
            }
            case CONSTANT: {
                return this.constant(sourceName, v.getArr());
            }
            case PLACEHOLDER: {
                return this.placeHolder(sourceName, v.dataType(), v.placeholderShape());
            }
        }
        throw new RuntimeException("Unknown/not supported variable type: " + (Object)((Object)vt));
    }

    /**
     * Produces a name of the form {@code sd_var_<id>} that is not yet a key in
     * {@code variables}, advancing {@code variableId} until a free name is found.
     *
     * @return an unused variable name
     */
    private String getNewVarName() {
        String candidate = "sd_var_" + this.variableId;
        while (this.variables.containsKey(candidate)) {
            this.variableId++;
            candidate = "sd_var_" + this.variableId;
        }
        return candidate;
    }

    /**
     * Creates a variable with an automatically generated name.
     *
     * @param dataType data type of the variable
     * @param shape    shape as int dimensions
     * @return the variable created by {@code var(String, DataType, int...)}
     */
    public SDVariable var(DataType dataType, int ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, dataType, shape);
    }

    /**
     * Creates a variable with an automatically generated name.
     *
     * @param dataType data type of the variable
     * @param shape    shape as long dimensions
     * @return the variable created by {@code var(String, DataType, long...)}
     */
    public SDVariable var(DataType dataType, long ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, dataType, shape);
    }

    /**
     * Creates a variable with an automatically generated name and the given initialisation scheme.
     *
     * @param weightInitScheme initialisation scheme forwarded to the delegate overload
     * @param dataType         data type of the variable
     * @param shape            shape as long dimensions
     * @return the variable created by {@code var(String, WeightInitScheme, DataType, long...)}
     */
    public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, weightInitScheme, dataType, shape);
    }

    /**
     * Creates a variable backed by the given array, using an automatically generated name.
     *
     * @param arr array to associate with the new variable
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable var(INDArray arr) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, arr);
    }

    /**
     * Creates a VARIABLE-type {@code SDVariable} backed by the given array.
     * <p>
     * The array is detached if attached, duplicated if it is a view, and otherwise duplicated
     * when the same array instance is already held in {@code variablesArrays}. A null or empty
     * name is replaced by a generated one.
     *
     * @param name variable name; may be null or empty
     * @param arr  array to use; must not be null and must have a floating point data type
     * @return the newly registered variable
     * @throws NullPointerException     if {@code arr} is null
     * @throws IllegalArgumentException if a variable with this name already exists and has an array
     */
    public SDVariable var(String name, @NonNull INDArray arr) {
        if (arr == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        boolean nameTaken = this.variables.containsKey(name) && this.variables.get(name).getVariable().getArr() != null;
        if (nameTaken) {
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
        }
        Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type: provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\nFor non floating point types, these should be created as placeholders or constants instead.", (Object)arr.dataType());
        if (name == null || name.length() < 1) {
            name = this.getNewVarName();
        }

        // Track whether we already hold a private copy of the array.
        boolean copied = false;
        if (arr.isAttached()) {
            arr = arr.detach();
            copied = true;
        }
        if (arr.isView()) {
            arr = arr.dup();
            copied = true;
        }
        if (!copied) {
            // Same instance already backing another variable: take a copy instead of sharing it.
            for (DeviceLocalNDArray held : this.variablesArrays.values()) {
                if (held.get() == arr) {
                    arr = arr.dup();
                    break;
                }
            }
        }

        SDVariable created = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr));
        this.associateArrayWithVariable(arr, created);
        if (ArrayUtil.prod((long[])arr.shape()) == 1) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                created.setScalarValue(Nd4j.scalar(arr.getDouble(0L)));
            }
        }
        this.addVariable(created);
        if (this.getShapeForVarName(name) == null) {
            this.putShapeForVarName(name, arr.shape());
        }
        return created;
    }

    /**
     * Converts a single variable to a constant by delegating to
     * {@code convertToConstants(List)} with a one-element list.
     *
     * @param variable variable to convert; must not be null
     * @return the same {@code variable} instance that was passed in
     * @throws NullPointerException if {@code variable} is null
     */
    public SDVariable convertToConstant(@NonNull SDVariable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        List<SDVariable> single = Collections.singletonList(variable);
        this.convertToConstants(single);
        return variable;
    }

    /**
     * Converts the given variables to CONSTANT type in place.
     * <p>
     * Returns immediately when the list is empty or every entry is already a constant.
     * Otherwise the cached sessions are cleared, the "grad" function instance is removed,
     * each variable's array is moved from {@code variablesArrays} to {@code constantArrays},
     * and, if a training configuration is present, the trainable parameter list and the
     * updater state/views are rebuilt without the converted variables.
     *
     * @param variables variables to convert; ARRAY-type entries are rejected by a state check
     */
    public void convertToConstants(List<SDVariable> variables) {
        if (variables.size() == 0) {
            return;
        }
        // Validate: ARRAY-type variables cannot be converted; detect the "nothing to do" case.
        boolean allConst = true;
        for (SDVariable variable : variables) {
            if (variable.getVariableType() == VariableType.CONSTANT) continue;
            allConst = false;
            Preconditions.checkState((variable.getVariableType() != VariableType.ARRAY ? 1 : 0) != 0, (String)"Cannot convert variable of type ARRAY to a constant: %s", (Object)variable);
        }
        if (allConst) {
            return;
        }
        // Drop cached sessions and the gradient function before changing variable types.
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");
        // Note: this loop also visits entries that are already CONSTANT (only the all-constant case returns early).
        for (SDVariable variable : variables) {
            String n = variable.getVarName();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object)arr, (String)"Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object)variable);
            this.constantArrays.put(n, new DeviceLocalNDArray(arr));
            this.variablesArrays.remove(n);
            if (!this.placeholdersPerThread.isEmpty()) {
                // Remove any per-thread placeholder value stored under this name.
                for (Map<String, INDArray> m : this.placeholdersPerThread.values()) {
                    m.remove(n);
                }
            }
            variable.setVariableType(VariableType.CONSTANT);
        }
        if (this.trainingConfig != null) {
            // Work out which converted variables were trainable parameters.
            HashSet<String> toRemove = new HashSet<String>();
            boolean anyTrainableParmsModified = false;
            List<String> origTrainableParams = this.trainingConfig.getTrainableParams();
            for (SDVariable sDVariable : variables) {
                toRemove.add(sDVariable.getVarName());
                if (anyTrainableParmsModified || !origTrainableParams.contains(sDVariable.getVarName())) continue;
                anyTrainableParmsModified = true;
            }
            if (anyTrainableParmsModified) {
                // Replace the trainable parameter list with one that omits the converted names.
                ArrayList<String> newTrainableParams = new ArrayList<String>();
                for (String s : origTrainableParams) {
                    if (toRemove.contains(s)) continue;
                    newTrainableParams.add(s);
                }
                this.trainingConfig.setTrainableParams(newTrainableParams);
            }
            if (this.initializedTraining) {
                // Collect the state views of the parameters that remain trainable.
                // NOTE(review): views are stored as null when their state size is 0 (see below), so a null
                // could be added here and passed to concat - confirm this cannot happen in practice.
                ArrayList<INDArray> newUpdaterState = new ArrayList<INDArray>();
                for (String s : origTrainableParams) {
                    INDArray stateArr = this.updaterViews.get(s);
                    if (toRemove.contains(s)) continue;
                    newUpdaterState.add(stateArr);
                }
                // NOTE(review): concatenation is along dimension 0 here, whereas the views below are sliced
                // along dimension 1 and convertToVariables concatenates along dimension 1 - verify the intended axis.
                this.updaterState = newUpdaterState.isEmpty() ? null : Nd4j.concat(0, newUpdaterState.toArray(new INDArray[newUpdaterState.size()]));
                // Rebuild per-parameter views and updaters over the new state array; l is the running column offset.
                long l = 0L;
                this.updaterViews = new HashMap<String, INDArray>();
                this.updaterMap = new HashMap<String, GradientUpdater>();
                for (String s : this.trainingConfig.getTrainableParams()) {
                    long thisSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length());
                    INDArray view = this.updaterState == null || thisSize == 0L ? null : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(l, l + thisSize));
                    this.updaterViews.put(s, view);
                    this.updaterMap.put(s, this.trainingConfig.getUpdater().instantiate(view, false));
                    l += thisSize;
                }
            }
        }
    }

    /**
     * Converts a single floating point constant to a variable by delegating to
     * {@code convertToVariables(List)} with a one-element list.
     *
     * @param constant the SDVariable to convert; must not be null and must be floating point
     * @return the same {@code constant} instance that was passed in
     * @throws NullPointerException if {@code constant} is null
     */
    public SDVariable convertToVariable(@NonNull SDVariable constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        Preconditions.checkState(constant.dataType().isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", (Object)constant.getVarName(), (Object)constant.dataType());
        List<SDVariable> single = Collections.singletonList(constant);
        this.convertToVariables(single);
        return constant;
    }

    /**
     * Converts the given SDVariables to VARIABLE type in place.
     * <p>
     * Returns immediately when the list is empty or every entry is already a VARIABLE.
     * Otherwise the cached sessions are cleared, the "grad" function instance is removed,
     * each array is moved from {@code constantArrays} to {@code variablesArrays}, and, if a
     * training configuration is present, the converted names are appended to the trainable
     * parameters and the updater state is extended to cover them.
     *
     * @param constants SDVariables to convert; must not be null, ARRAY-type entries are rejected
     * @throws NullPointerException if {@code constants} is null
     */
    public void convertToVariables(@NonNull List<SDVariable> constants) {
        if (constants == null) {
            throw new NullPointerException("constants is marked @NonNull but is null");
        }
        if (constants.size() == 0) {
            return;
        }
        // Validate: ARRAY-type variables cannot be converted; detect the "nothing to do" case.
        boolean allVariables = true;
        for (SDVariable variable : constants) {
            if (variable.getVariableType() != VariableType.VARIABLE) {
                allVariables = false;
            }
            Preconditions.checkState((variable.getVariableType() != VariableType.ARRAY ? 1 : 0) != 0, (String)"Cannot convert variable of type ARRAY to a variable: %s", (Object)variable);
        }
        if (allVariables) {
            return;
        }
        // Drop cached sessions and the gradient function before changing variable types.
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");
        for (SDVariable variable : constants) {
            String n = variable.getVarName();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object)arr, (String)"Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object)variable);
            this.variablesArrays.put(n, new DeviceLocalNDArray(arr));
            this.constantArrays.remove(n);
            if (!this.placeholdersPerThread.isEmpty()) {
                // Remove any per-thread placeholder value stored under this name.
                for (Map<String, INDArray> m : this.placeholdersPerThread.values()) {
                    m.remove(n);
                }
            }
            variable.setVariableType(VariableType.VARIABLE);
        }
        if (this.trainingConfig != null) {
            ArrayList<String> newTrainableParams = new ArrayList<String>(this.trainingConfig.getTrainableParams());
            ArrayList<String> convertedToVars = new ArrayList<String>();
            for (SDVariable v : constants) {
                newTrainableParams.add(v.getVarName());
                convertedToVars.add(v.getVarName());
            }
            this.trainingConfig.setTrainableParams(newTrainableParams);
            if (this.initializedTraining) {
                // Total updater state required by the newly trainable parameters.
                long extraStateSize = 0L;
                for (String convertedName : convertedToVars) {
                    INDArray arr = this.getVariable(convertedName).getArr();
                    extraStateSize += this.trainingConfig.getUpdater().stateSize(arr.length());
                }
                if (extraStateSize > 0L) {
                    // updaterState may be null (convertToConstants sets it to null when no state remains),
                    // so it must not be dereferenced before the null check. With no existing state to copy
                    // the type from, fall back to the default floating point type.
                    // NOTE(review): confirm this fallback matches the type used when training is initialised.
                    DataType stateType = this.updaterState == null ? Nd4j.defaultFloatingPointType() : this.updaterState.dataType();
                    INDArray newState = Nd4j.createUninitialized(stateType, 1L, extraStateSize);
                    this.updaterState = this.updaterState == null ? newState : Nd4j.concat(1, this.updaterState, newState);
                    // Rebuild per-parameter views and updaters; offset is the running column offset.
                    long offset = 0L;
                    this.updaterViews = new HashMap<String, INDArray>();
                    this.updaterMap = new HashMap<String, GradientUpdater>();
                    for (String s : this.trainingConfig.getTrainableParams()) {
                        long thisSize = this.trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length());
                        INDArray view = this.updaterState == null || thisSize == 0L ? null : this.updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(offset, offset + thisSize));
                        this.updaterViews.put(s, view);
                        // Only the newly converted parameters get their state initialised.
                        boolean init = convertedToVars.contains(s);
                        this.updaterMap.put(s, this.trainingConfig.getUpdater().instantiate(view, init));
                        offset += thisSize;
                    }
                }
            }
        }
    }

    /**
     * Removes the named variable from the recorded inputs of the given function. Nothing
     * happens when none of the function's arguments has that name.
     *
     * @param varName  name of the argument to remove
     * @param function function whose op inputs should be updated
     */
    public void removeArgFromFunction(String varName, DifferentialFunction function) {
        SDVariable[] fnArgs = function.args();
        boolean present = false;
        for (SDVariable fnArg : fnArgs) {
            if (fnArg.getVarName().equals(varName)) {
                present = true;
                break;
            }
        }
        if (!present) {
            return;
        }
        SameDiffOp opMeta = this.ops.get(function.getOwnName());
        List<String> currentInputs = opMeta.getInputsToOp();
        ArrayList<String> remaining = new ArrayList<String>(fnArgs.length - 1);
        for (int idx = 0; idx < fnArgs.length; ++idx) {
            String inputName = currentInputs.get(idx);
            if (!inputName.equals(varName)) {
                remaining.add(inputName);
            }
        }
        this.ops.get(function.getOwnName()).setInputsToOp(remaining);
    }

    /**
     * Looks up a variable by name.
     *
     * @param name variable name
     * @return the variable, or null when no entry is registered under that name
     */
    public SDVariable getVariable(String name) {
        Variable meta = this.variables.get(name);
        if (meta == null) {
            return null;
        }
        return meta.getVariable();
    }

    /**
     * Reports whether a variable is registered under the given name.
     *
     * @param name variable name
     * @return true if {@code variables} contains the name as a key
     */
    public boolean hasVariable(String name) {
        boolean registered = this.variables.containsKey(name);
        return registered;
    }

    /**
     * Returns the gradient variable associated with the named variable.
     * <p>
     * The gradient recorded on this instance is preferred; otherwise the "grad" function
     * instance, if defined, is consulted.
     *
     * @param varName name of an existing floating point variable
     * @return the gradient variable, or null if none has been recorded
     */
    public SDVariable getGradForVariable(String varName) {
        Preconditions.checkState((boolean)this.variables.containsKey(varName), (String)"No variable with name \"%s\" exists", (Object)varName);
        SDVariable v = this.getVariable(varName);
        // Argument order follows the placeholders: datatype first, then the quoted variable name
        // (the two were previously swapped, producing a misleading message).
        Preconditions.checkState((boolean)v.dataType().isFPType(), (String)"Cannot get gradient of %s variable \"%s\": only floating point variables have gradients", (Object)v.dataType(), (Object)varName);
        if (this.variables.containsKey(varName) && this.variables.get(varName).getGradient() != null) {
            return this.variables.get(varName).getGradient();
        }
        if (this.sameDiffFunctionInstances.containsKey("grad") && this.sameDiffFunctionInstances.get((Object)"grad").variables.containsKey(varName)) {
            return this.sameDiffFunctionInstances.get((Object)"grad").variables.get(varName).getGradient();
        }
        return null;
    }

    /**
     * Reports whether a gradient is available for the named variable. Non floating point
     * variables and constants never have one.
     *
     * @param varName name of an existing variable
     * @return true if the variable is floating point, not constant, and has a gradient variable
     */
    public boolean variableHasGradient(String varName) {
        Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists", (Object)varName);
        SDVariable target = this.getVariable(varName);
        boolean eligible = target.dataType().isFPType() && !target.isConstant();
        return eligible && this.getGradForVariable(varName) != null;
    }

    /**
     * Records the gradient variable for an existing variable.
     *
     * @param variableName name of an existing variable
     * @param variable     gradient variable to record; must not be null
     * @throws ND4JIllegalStateException if {@code variable} is null
     */
    public void setGradientForVariableName(String variableName, SDVariable variable) {
        Preconditions.checkState(this.variables.containsKey(variableName), "No variable exists with name \"%s\"", (Object)variableName);
        if (variable != null) {
            this.variables.get(variableName).setGradient(variable);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName);
    }

    /**
     * Stores the forward variable under the given name in {@code forwardVarForGrad}.
     *
     * @param varName         key to store under
     * @param forwardVariable forward variable to associate with the name
     */
    public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) {
        final SDVariable forward = forwardVariable;
        this.forwardVarForGrad.put(varName, forward);
    }

    /**
     * Returns the gradient variable for the named variable, as held by the "grad" function instance.
     *
     * @param varName name of the variable whose gradient is requested
     * @return the gradient variable, as returned by the grad function's {@code getGradForVariable}
     * @throws IllegalStateException if the "grad" function has not been defined yet, or if it
     *                               contains no variable with the given name
     */
    public SDVariable grad(String varName) {
        if (!this.sameDiffFunctionInstances.containsKey("grad")) {
            throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
        }
        SameDiff grad = this.getFunction("grad");
        SDVariable var = grad.getVariable(varName);
        // Fail with a descriptive message instead of a bare NullPointerException on var.getVarName().
        Preconditions.checkState((boolean)(var != null), (String)"No variable with name \"%s\" exists", (Object)varName);
        return grad.getGradForVariable(var.getVarName());
    }

    /**
     * Creates a variable from a double scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable scalar(String name, double value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable from a float scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable scalar(String name, float value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable from an int scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable scalar(String name, int value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable from a long scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable scalar(String name, long value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable from a scalar of the given data type. The scalar array is created
     * inside the scope returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name     variable name
     * @param dataType data type of the scalar
     * @param value    scalar value
     * @return the variable created by {@code var(String, INDArray)}
     */
    public SDVariable scalar(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(dataType, value));
        }
    }

    /**
     * Creates a constant from a double value, passing a null name to the named overload.
     *
     * @param value constant value
     * @return the constant created by {@code constant(String, double)}
     */
    public SDVariable constant(double value) {
        final String unnamed = null;
        return this.constant(unnamed, value);
    }

    /**
     * Creates a constant from a double scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  constant name
     * @param value constant value
     * @return the constant created by {@code constant(String, INDArray)}
     */
    public SDVariable constant(String name, double value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a constant from a float value, passing a null name to the named overload.
     *
     * @param value constant value
     * @return the constant created by {@code constant(String, float)}
     */
    public SDVariable constant(float value) {
        final String unnamed = null;
        return this.constant(unnamed, value);
    }

    /**
     * Creates a constant from a float scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  constant name
     * @param value constant value
     * @return the constant created by {@code constant(String, INDArray)}
     */
    public SDVariable constant(String name, float value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a constant from an int value, passing a null name to the named overload.
     *
     * @param value constant value
     * @return the constant created by {@code constant(String, int)}
     */
    public SDVariable constant(int value) {
        final String unnamed = null;
        return this.constant(unnamed, value);
    }

    /**
     * Creates a constant from an int scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  constant name
     * @param value constant value
     * @return the constant created by {@code constant(String, INDArray)}
     */
    public SDVariable constant(String name, int value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a constant from a long value, passing a null name to the named overload.
     *
     * @param value constant value
     * @return the constant created by {@code constant(String, long)}
     */
    public SDVariable constant(long value) {
        final String unnamed = null;
        return this.constant(unnamed, value);
    }

    /**
     * Creates a constant from a long scalar. The scalar array is created inside the scope
     * returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name  constant name
     * @param value constant value
     * @return the constant created by {@code constant(String, INDArray)}
     */
    public SDVariable constant(String name, long value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a constant from a scalar of the given data type. The scalar array is created
     * inside the scope returned by {@code scopeOutOfWorkspaces()}.
     *
     * @param name     constant name
     * @param dataType data type of the scalar
     * @param value    constant value
     * @return the constant created by {@code constant(String, INDArray)}
     */
    public SDVariable constant(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(dataType, value));
        }
    }

    /**
     * Registers the given variable under its name, wrapping it in a new {@code Variable} entry.
     * Re-adding a variable that equals the one already registered under the same name is
     * allowed and replaces the existing entry.
     *
     * @param variable variable to register; must belong to this SameDiff instance
     * @return the same {@code variable} instance
     * @throws IllegalArgumentException if a different variable is already registered under the name
     */
    public SDVariable addVariable(SDVariable variable) {
        // Single ownership check; the original repeated the identical condition a second time
        // further down, where it could never fail.
        Preconditions.checkState((boolean)(variable.getSameDiff() == this), (String)"Samediff instance must be the same.");
        String varName = variable.getVarName();
        if (this.variables.containsKey(varName) && !this.variables.get(varName).getVariable().equals(variable)) {
            throw new IllegalArgumentException("Variable already found with variable opName " + varName);
        }
        this.variables.put(varName, Variable.builder().name(varName).variable(variable).build());
        return variable;
    }

    /**
     * Generates a variable name that is not yet in use.
     * <p>
     * For {@code argIndex == 0} and an unused {@code baseName} the base name is returned as is.
     * Otherwise candidates have the form {@code baseName[_count][:argIndex]}, where the
     * {@code _count} part is added (starting at 1) only after the plain candidate is found to
     * be taken, and {@code :argIndex} is appended only for positive indices.
     *
     * @param baseName base name to derive from
     * @param argIndex output index; positive values add a {@code :argIndex} suffix
     * @return a name for which {@code getVariable} currently returns null
     */
    @Override
    public String generateNewVarName(String baseName, int argIndex) {
        if (!this.variables.containsKey(baseName) && argIndex == 0) {
            return baseName;
        }
        // Suffix is loop-invariant; the original rebuilt it per iteration and carried a ternary
        // on count that was always true, plus a post-loop check that could never fail.
        String indexSuffix = argIndex > 0 ? ":" + argIndex : "";
        String name = baseName + indexSuffix;
        int count = 0;
        while (this.getVariable(name) != null) {
            name = baseName + "_" + ++count + indexSuffix;
        }
        return name;
    }

    /**
     * Creates (or looks up) the output variables for the given op and marks the op as their creator.
     * <p>
     * Three paths exist:
     * <ul>
     *   <li>output shapes unknown and the op is a {@code CustomOp}: the number of outputs comes from
     *       {@code getNumOutputs()} or the op descriptor, and ARRAY variables without shape are created;</li>
     *   <li>output shapes empty and the op is a {@code BaseOp}: a single ARRAY variable without shape is created;</li>
     *   <li>otherwise: one ARRAY variable per calculated output shape.</li>
     * </ul>
     *
     * @param function op whose outputs are generated
     * @param baseName base name for the outputs; when null (or empty with a registered base name) the
     *                 name from {@code getBaseNameForFunction} is used, falling back to {@code function.opName()}
     * @param isImport when true, output data types are not calculated or validated
     * @return the output variables, indexed by output position
     * @throws ND4UnresolvedOutputVariables if the number of outputs of a custom op cannot be determined
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) {
        List<LongShapeDescriptor> outputShape;
        // Note: && binds tighter than ||, so this is (baseName == null) || (isEmpty && hasBaseName).
        if (baseName == null || baseName.isEmpty() && this.getBaseNameForFunction(function) != null) {
            baseName = this.getBaseNameForFunction(function);
        }
        if (baseName == null) {
            baseName = function.opName();
        }
        // Output data types are derived from the data types of the op's recorded inputs (skipped on import).
        List<DataType> outputDataTypes = null;
        if (!isImport) {
            ArrayList<DataType> inputDataTypes = new ArrayList<DataType>();
            List<String> fnInputs = this.ops.get(function.getOwnName()).getInputsToOp();
            if (fnInputs != null) {
                for (String var : fnInputs) {
                    inputDataTypes.add(this.variables.get(var).getVariable().dataType());
                }
            }
            outputDataTypes = function.calculateOutputDataTypes(inputDataTypes);
        }
        if ((outputShape = function.calculateOutputShape()) == null || outputShape.isEmpty()) {
            if (function instanceof CustomOp) {
                CustomOp customOp = (CustomOp)((Object)function);
                int num_outputs = function.getNumOutputs();
                if (num_outputs <= 0) {
                    // Fall back to the descriptor when the op does not report its output count.
                    CustomOpDescriptor descriptor = customOp.getDescriptor();
                    if (descriptor != null) {
                        num_outputs = descriptor.getNumOutputs();
                    }
                    if (num_outputs <= 0) {
                        throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override getNumOutputs() to specify number of outputs if required");
                    }
                }
                // NOTE(review): 'ordering' is computed but not used afterwards in this branch.
                int ordering = 99;
                SDVariable[] args = function.args();
                if (args != null && args.length > 0 && args[0].getArr() != null) {
                    ordering = function.args()[0].getArr().ordering();
                }
                SDVariable[] ret = new SDVariable[num_outputs];
                Preconditions.checkState((isImport || num_outputs == 0 || outputDataTypes != null && outputDataTypes.size() == num_outputs ? 1 : 0) != 0, (String)"Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)", (Object)(outputDataTypes == null ? null : Integer.valueOf(outputDataTypes.size())), (Object)num_outputs, outputDataTypes, (Object)function.getClass().getSimpleName());
                for (int i = 0; i < ret.length; ++i) {
                    SDVariable var;
                    // Output 0 is looked up under the base name, later outputs under "baseName:i".
                    SDVariable sDVariable = var = i == 0 ? this.getVariable(baseName) : this.getVariable(baseName + ":" + i);
                    if (var == null) {
                        DataType dataType = isImport ? null : outputDataTypes.get(i);
                        var = this.var(this.generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[])null);
                    }
                    var.setOutputIndex(i);
                    var.setCreator(function);
                    ret[i] = var;
                }
                if (this.getOutputsForFunction(function) == null) {
                    this.addOutgoingFor(ret, function);
                }
                return ret;
            }
            // NOTE(review): outputShape may be null here (this block is entered when it is null or empty),
            // in which case outputShape.isEmpty() would throw - confirm whether BaseOp can return null.
            if (function instanceof BaseOp && outputShape.isEmpty()) {
                DataType dataType;
                SDVariable[] ret = new SDVariable[1];
                SDVariable checkGet = this.getVariable(baseName);
                // NOTE(review): 'ordering' is computed but not used afterwards in this branch.
                int ordering = 99;
                SDVariable[] args = function.args();
                if (args != null && args.length > 0 && function.args()[0].getArr() != null) {
                    ordering = function.args()[0].getArr().ordering();
                }
                // NOTE(review): outputDataTypes is null when isImport is true, so this would throw on import.
                if (checkGet == null) {
                    dataType = outputDataTypes.get(0);
                    checkGet = this.var(baseName, VariableType.ARRAY, null, dataType, (long[])null);
                }
                // NOTE(review): duplicate of the block above; only runs again if var(...) returned null.
                if (checkGet == null) {
                    dataType = outputDataTypes.get(0);
                    checkGet = this.var(baseName, VariableType.ARRAY, null, dataType, (long[])null);
                }
                checkGet.setOutputIndex(0);
                checkGet.setCreator(function);
                ret[0] = checkGet;
                if (this.getOutputsForFunction(function) == null) {
                    this.addOutgoingFor(ret, function);
                }
                return ret;
            }
        }
        // Shapes are known from here on (NOTE(review): a null outputShape for an op that is neither
        // CustomOp nor BaseOp would reach this point and throw on outputShape.size()).
        if (!isImport) {
            // The data type from shape calculation must agree with the separately calculated data type.
            for (int i = 0; i < outputShape.size(); ++i) {
                DataType shapeDataType = outputShape.get(i).dataType();
                DataType calcType = outputDataTypes.get(i);
                Preconditions.checkState((calcType == shapeDataType ? 1 : 0) != 0, (String)"Calculated output data types do not match for shape calculation vs. datatype calculation: %s vs %s for op %s output %s", (Object)shapeDataType, (Object)calcType, (Object)function.getClass().getName(), (Object)i);
            }
        }
        // Default ordering is 'c' unless the first argument already has an array.
        char ordering = 'c';
        if (function.args() != null && function.args().length > 0 && function.args()[0].getArr() != null) {
            ordering = function.args()[0].getArr().ordering();
        }
        SDVariable[] ret = new SDVariable[outputShape.size()];
        String ownName = function.getOwnName();
        String rootName = baseName;
        for (int i = 0; i < ret.length; ++i) {
            LongShapeDescriptor shape = outputShape.get(i);
            // Output 0 uses the root name, later outputs use "rootName:i".
            baseName = rootName + (i > 0 ? ":" + i : "");
            SDVariable checkGet = this.getVariable(baseName);
            if (checkGet == null) {
                checkGet = this.var(baseName, VariableType.ARRAY, null, shape.dataType(), shape.getShape());
            } else if (shape != null && !this.shapeAlreadyExistsForVarName(checkGet.getVarName())) {
                // Existing variable without a recorded shape: record the calculated one.
                this.putShapeForVarName(checkGet.getVarName(), shape);
            } else if (shape == null || this.shapeAlreadyExistsForVarName(checkGet.getVarName())) {
                // empty if block
            }
            // NOTE(review): only reachable if var(...) above returned null; note it would append ":i" a second time.
            if (checkGet == null) {
                DataType dataType = DataType.FLOAT;
                checkGet = this.var(baseName + (i > 0 ? ":" + i : ""), new ZeroInitScheme(ordering), dataType, shape.getShape());
            }
            checkGet.setOutputIndex(i);
            checkGet.setCreator(function);
            ret[i] = checkGet;
        }
        return ret;
    }

    /**
     * Generates output variables for the op, using the op name as base name and
     * {@code isImport = false}.
     *
     * @param function op whose outputs are generated
     * @return the output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
        String opBaseName = function.opName();
        return this.generateOutputVariableForOp(function, opBaseName, false);
    }

    /**
     * Looks up a defined function instance by name.
     *
     * @param functionName function name
     * @return the value stored in {@code sameDiffFunctionInstances} for that name (the map's lookup result)
     */
    public SameDiff getFunction(String functionName) {
        SameDiff defined = this.sameDiffFunctionInstances.get(functionName);
        return defined;
    }

    /**
     * Builds a {@code While} op with this instance as parent and a block name of the form
     * {@code while-<random UUID>}.
     *
     * @param sameDiffConditional predicate of the loop
     * @param conditionBody       definition of the condition
     * @param loopBody            definition of the loop body
     * @param inputVars           input variables
     * @return the built {@code While} instance
     */
    public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition loopBody, SDVariable[] inputVars) {
        return While.builder()
                .inputVars(inputVars)
                .condition(conditionBody)
                .predicate(sameDiffConditional)
                .trueBody(loopBody)
                .parent(this)
                .blockName("while-" + UUID.randomUUID().toString())
                .build();
    }

    /**
     * Builds an {@code If} op with this instance as parent and a block name of the form
     * {@code if-<random UUID>}.
     *
     * @param conditional   predicate of the branch
     * @param conditionBody definition of the condition
     * @param trueBody      definition executed for the true branch
     * @param falseBody     definition executed for the false branch
     * @param inputVars     input variables
     * @return the built {@code If} instance
     */
    public If ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody, SDVariable[] inputVars) {
        return If.builder()
                .conditionBody(conditionBody)
                .falseBody(falseBody)
                .trueBody(trueBody)
                .predicate(conditional)
                .inputVars(inputVars)
                .parent(this)
                .blockName("if-" + UUID.randomUUID().toString())
                .build();
    }

    /**
     * Creates a {@code TensorArray} bound to this instance.
     *
     * @param dataType data type passed to the {@code TensorArray} constructor
     * @return the new tensor array
     */
    public TensorArray tensorArray(DataType dataType) {
        TensorArray tensorArr = new TensorArray(this, dataType);
        // Called for its effect only; the returned array is not needed here.
        tensorArr.outputVariables();
        return tensorArr;
    }

    /**
     * Invokes the graph of a previously defined function on another SameDiff instance.
     *
     * @param functionName name of the defined function
     * @param with         instance the function graph is invoked on
     * @return the result of {@code invokeGraphOn}
     */
    public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
        SameDiff target = this.sameDiffFunctionInstances.get(functionName);
        return target.invokeGraphOn(with);
    }

    /**
     * Defines a named function from variable inputs, unless a function with that name already
     * exists. A fresh child SameDiff instance is created, the given variables are mirrored
     * into it, and the definition is applied to it. The {@code child} field is reset to null
     * before returning.
     *
     * @param function           name to register the function under
     * @param functionDefinition definition invoked with the child instance and mirrored variables
     * @param variables          variables to mirror into the child instance
     * @return the function instance registered under {@code function}
     */
    public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) {
        boolean alreadyDefined = this.sameDiffFunctionInstances.containsKey(function);
        if (!alreadyDefined) {
            SameDiff sub = SameDiff.create();
            this.child = sub;
            sub.parent = this;
            SDVariable[] mirrored = new SDVariable[variables.length];
            int idx = 0;
            for (SDVariable outerVar : variables) {
                mirrored[idx++] = sub.var(outerVar);
            }
            functionDefinition.define(sub, null, mirrored);
            this.sameDiffFunctionInstances.put(function, sub);
        }
        this.child = null;
        return this.sameDiffFunctionInstances.get(function);
    }

    /**
     * Defines a named function with an empty input map.
     *
     * @param function           name to register the function under
     * @param functionDefinition definition to apply
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) {
        LinkedHashMap<String, INDArray> noInputs = new LinkedHashMap<String, INDArray>();
        this.defineFunction(function, functionDefinition, noInputs);
    }

    /**
     * Defines a named function from array inputs. Does nothing when a function with that name
     * is already registered.
     *
     * @param function           name to register the function under
     * @param functionDefinition definition invoked with a fresh SameDiff instance and the inputs
     * @param inputs             input arrays passed to the definition
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String, INDArray> inputs) {
        if (this.sameDiffFunctionInstances.containsKey(function)) {
            return;
        }
        SameDiff sub = SameDiff.create();
        functionDefinition.define(sub, inputs, null);
        this.sameDiffFunctionInstances.put(function, sub);
    }

    /**
     * Executes the graph and returns its single output, using the placeholders stored for the
     * current thread.
     *
     * @return the array of the only output variable
     * @deprecated marked deprecated in the original source; only valid for single-output graphs
     */
    @Deprecated
    public INDArray execAndEndResult() {
        List<String> outNames = this.outputs();
        Preconditions.checkState(outNames.size() == 1, "Method can only be used with SameDiff instances with a single output");
        long threadId = Thread.currentThread().getId();
        Map<String, INDArray> threadPlaceholders = this.placeholdersPerThread.get(threadId);
        String onlyOutput = outNames.get(0);
        return this.execSingle(threadPlaceholders, onlyOutput);
    }

    /**
     * Runs the backward pass for the gradients of all VARIABLE-type variables that have a
     * gradient variable. Creates the "grad" function first if it does not exist. When no such
     * gradients exist, a warning is logged and nothing is executed.
     *
     * @param placeholders placeholder values for execution
     */
    public void execBackwards(Map<String, INDArray> placeholders) {
        if (this.getFunction("grad") == null) {
            this.createGradFunction();
        }
        HashSet<String> gradNames = new HashSet<String>();
        for (Variable entry : this.variables.values()) {
            if (entry.getVariable().getVariableType() != VariableType.VARIABLE) {
                continue;
            }
            SDVariable gradVar = entry.getVariable().gradient();
            if (gradVar != null) {
                gradNames.add(gradVar.getVarName());
            }
        }
        if (gradNames.isEmpty()) {
            log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\nIf gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead");
            return;
        }
        ArrayList<String> gradNameList = new ArrayList<String>(gradNames);
        this.execBackwards(placeholders, gradNameList);
    }

    /**
     * Varargs convenience overload of {@code execBackwards(Map, List)}.
     *
     * @param placeholders          placeholder values for execution
     * @param variableGradNamesList names of the gradient variables to calculate
     */
    public void execBackwards(Map<String, INDArray> placeholders, String ... variableGradNamesList) {
        List<String> gradNames = Arrays.asList(variableGradNamesList);
        this.execBackwards(placeholders, gradNames);
    }

    /**
     * Runs the backward pass for the named gradient variables by executing the "grad" function,
     * creating it first if it does not exist. When the list is empty a warning is logged and
     * nothing is executed.
     *
     * @param placeholders          placeholder values for execution
     * @param variableGradNamesList names of the gradient variables to calculate
     */
    public void execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) {
        if (this.getFunction("grad") == null) {
            this.createGradFunction();
        }
        log.trace("About to execute backward function");
        if (!variableGradNamesList.isEmpty()) {
            SameDiff gradFn = this.sameDiffFunctionInstances.get("grad");
            gradFn.exec(placeholders, variableGradNamesList);
            return;
        }
        log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)");
    }

    /**
     * Defines the "grad" function for this graph: the graph is invoked onto a fresh SameDiff
     * instance and every op that the loss variables depend on (and that lies on a path to a
     * non-constant, non-placeholder floating point variable) is differentiated via
     * {@code DifferentialFunction.diff(...)}.
     * <p>
     * If no loss variables have been set, they are taken from the training configuration, or,
     * failing that, inferred when the graph has exactly one output.
     *
     * @throws IllegalStateException presumably (via {@code Preconditions.checkState}) when no loss
     *                               variables can be determined, when a loss variable does not exist
     *                               or is not floating point, or when no differentiable variables
     *                               remain; thrown directly when a required gradient was not produced
     */
    public void createGradFunction() {
        // Step 0: determine the loss variables if the user has not set any explicitly.
        if (this.lossVariables.isEmpty()) {
            if (this.trainingConfig != null && this.trainingConfig.getLossVariables() != null && !this.trainingConfig.getLossVariables().isEmpty()) {
                this.lossVariables.addAll(this.trainingConfig.getLossVariables());
            } else {
                List<String> outputs = this.outputs();
                if (outputs.size() == 1) {
                    String outName = outputs.get(0);
                    String opName = this.variables.get(outName).getOutputOfOp();
                    // The info message is suppressed when the single output is produced by an ExternalErrorsFunction op.
                    if (opName == null || !(this.ops.get(opName).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", (Object)outputs.get(0));
                    }
                    this.lossVariables.add(outputs.get(0));
                }
            }
        }
        Preconditions.checkState((!this.lossVariables.isEmpty() ? 1 : 0) != 0, (String)"Cannot create gradient function: No loss variables (variables to minimize) have been specified. Loss variables are the variables that represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()");
        if (log.isTraceEnabled()) {
            log.trace("Defining function \"grad\"");
        }
        final SameDiff outer = this;
        this.defineFunction("grad", new SameDiffFunctionDefinition(){

            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                List<String> inputsToOp;
                ArrayList allFunctions;
                if (SameDiff.this.debugMode) {
                    sameDiff.enableDebugMode();
                }
                // Replay the outer graph onto the function's SameDiff instance (presumably a copy of ops/variables - confirm in invokeGraphOn).
                outer.invokeGraphOn(sameDiff);
                if (SameDiff.this.debugMode) {
                    Preconditions.checkState((boolean)sameDiff.ops.keySet().equals(SameDiff.this.ops.keySet()), (String)"ops keysets not equal");
                }
                if ((allFunctions = new ArrayList(sameDiff.ops.values())).isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                }
                // Point every op, and its argument/output variables, at the function's SameDiff instance.
                for (SameDiffOp op : allFunctions) {
                    SDVariable[] outputs;
                    SDVariable[] args;
                    DifferentialFunction func = op.getOp();
                    if (func instanceof SDVariable) continue;
                    for (SDVariable arg : args = func.args()) {
                        arg.setSameDiff(sameDiff);
                    }
                    for (SDVariable output : outputs = func.outputVariables()) {
                        output.setSameDiff(sameDiff);
                    }
                    func.setSameDiff(sameDiff);
                }
                // Step 1: validate each loss variable, reduce it with sum(), and seed the gradient of
                // the summed variable with a scalar 1.0 (cast to the loss datatype when they differ).
                ArrayList<Object> finalOutputs = new ArrayList<Object>(SameDiff.this.lossVariables.size());
                SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0f));
                for (String s : SameDiff.this.lossVariables) {
                    Preconditions.checkNotNull((Object)s, (String)"Encountered null value in loss variables. Null loss variables are not allowed. Use SameDiff.setLossVariables with non-null array names to fix");
                    Preconditions.checkState((boolean)SameDiff.this.variables.containsKey(s), (String)"Specified loss function variable \"%s\" does not exist", (Object)s);
                    Object v = ((Variable)SameDiff.this.variables.get(s)).getVariable();
                    Preconditions.checkState((boolean)((SDVariable)v).dataType().isFPType(), (String)"Specified loss function variable \"%s\" is not a floatingpoint variable (datatype: %s). Only floating point variables may be used as loss function variable", (Object)s, (Object)((SDVariable)v).dataType());
                    // sum over an empty dimension array - presumably a full reduction to a scalar; confirm against SDVariable.sum
                    v = ((SDVariable)v).sum(new int[0]);
                    if (((SDVariable)v).dataType() == initialGrad.dataType()) {
                        sameDiff.setGradientForVariableName(((SDVariable)v).getVarName(), initialGrad);
                    } else {
                        sameDiff.setGradientForVariableName(((SDVariable)v).getVarName(), initialGrad.castTo(((SDVariable)v).dataType()));
                    }
                    if (finalOutputs.contains(v)) {
                        log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", (Object)s);
                        continue;
                    }
                    finalOutputs.add(v);
                }
                if (log.isTraceEnabled()) {
                    String s;
                    Object[] initialOutputsStr = ((SameDiffOp)allFunctions.get(allFunctions.size() - 1)).getOp().outputVariablesNames();
                    s = initialOutputsStr == null ? "null" : Arrays.toString(initialOutputsStr);
                    log.trace("Defining backward function: initial outputs {}", (Object)s);
                }
                // Step 2: walk backwards from the loss variables through op inputs, collecting every
                // floating point variable the losses depend on.
                HashSet<String> allFpVarsConnectedToLoss = new HashSet<String>();
                LinkedList<String> toProcess = new LinkedList<String>();
                for (String s : SameDiff.this.lossVariables) {
                    if (toProcess.contains(s)) continue;
                    toProcess.add(s);
                }
                while (!toProcess.isEmpty()) {
                    Variable v;
                    String next = (String)toProcess.remove();
                    if (allFpVarsConnectedToLoss.contains(next) || !(v = (Variable)SameDiff.this.variables.get(next)).getVariable().dataType().isFPType()) continue;
                    allFpVarsConnectedToLoss.add(v.getName());
                    if (v.getOutputOfOp() == null) continue;
                    String opName = v.getOutputOfOp();
                    SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                    List<String> opInputs = op.getInputsToOp();
                    if (opInputs == null) continue;
                    for (String s : opInputs) {
                        Variable inputVar = (Variable)SameDiff.this.variables.get(s);
                        if (!inputVar.getVariable().dataType().isFPType()) continue;
                        toProcess.add(s);
                    }
                }
                // Step 3: prune to the minimal subgraph. Initial "leaves" to remove are ARRAY variables
                // whose producing op has no inputs in the connected set, plus CONSTANT and PLACEHOLDER variables.
                HashSet minimalSubgraphVars = new HashSet(allFpVarsConnectedToLoss);
                LinkedList<String> leafFPVars = new LinkedList<String>();
                for (String s : allFpVarsConnectedToLoss) {
                    Variable v = (Variable)SameDiff.this.variables.get(s);
                    if (v.getVariable().getVariableType() == VariableType.ARRAY) {
                        String opName = v.getOutputOfOp();
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        inputsToOp = op.getInputsToOp();
                        boolean anyInputsInSubgraph = false;
                        if (inputsToOp != null) {
                            for (String string : inputsToOp) {
                                if (!allFpVarsConnectedToLoss.contains(string)) continue;
                                anyInputsInSubgraph = true;
                                break;
                            }
                        }
                        if (!anyInputsInSubgraph) {
                            leafFPVars.add(s);
                        }
                    }
                    if (v.getVariable().getVariableType() != VariableType.CONSTANT && v.getVariable().getVariableType() != VariableType.PLACEHOLDER) continue;
                    leafFPVars.add(s);
                }
                // Remove leaves iteratively: once none of an op's inputs remain in the subgraph,
                // that op's outputs become leaves as well.
                while (!leafFPVars.isEmpty()) {
                    String nextLeaf = (String)leafFPVars.remove();
                    Variable v = (Variable)SameDiff.this.variables.get(nextLeaf);
                    minimalSubgraphVars.remove(nextLeaf);
                    List<String> inputsTo = v.getInputsForOp();
                    if (inputsTo == null || inputsTo.isEmpty()) continue;
                    for (String opName : inputsTo) {
                        List<String> list;
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        List<String> inputsToOp2 = op.getInputsToOp();
                        boolean anyPresent = false;
                        for (String string : inputsToOp2) {
                            if (!minimalSubgraphVars.contains(string)) continue;
                            anyPresent = true;
                            break;
                        }
                        if (anyPresent || (list = op.getOutputsOfOp()) == null) continue;
                        for (String s3 : list) {
                            if (leafFPVars.contains(s3)) continue;
                            leafFPVars.add(s3);
                        }
                    }
                }
                Preconditions.checkState((!minimalSubgraphVars.isEmpty() ? 1 : 0) != 0, (String)"Cannot differentiate graph relative to the specified loss function variables %s: graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", (Object)SameDiff.this.lossVariables);
                // Step 4: the differentiation queue starts with the ops that produce the (summed) loss outputs.
                LinkedList<String> availableForDiff = new LinkedList<String>();
                for (Object lossVar : finalOutputs) {
                    String opName;
                    Variable v = (Variable)sameDiff.variables.get(((SDVariable)lossVar).getVarName());
                    if (v.getOutputOfOp() == null) continue;
                    opName = v.getOutputOfOp();
                    availableForDiff.add(opName);
                }
                // For each subgraph variable, record the consuming ops that have at least one output in
                // the subgraph; these must be differentiated before the variable's gradient is complete.
                HashMap prerequisites = new HashMap();
                for (String var : minimalSubgraphVars) {
                    Variable variable = (Variable)SameDiff.this.variables.get(var);
                    List<String> inputsForOp = variable.getInputsForOp();
                    if (inputsForOp == null) continue;
                    ArrayList<String> req = new ArrayList<String>();
                    for (String string : inputsForOp) {
                        SameDiffOp sameDiffOp = (SameDiffOp)SameDiff.this.ops.get(string);
                        List<String> opOutputs = sameDiffOp.getOutputsOfOp();
                        boolean anyOpOutputsRequired = false;
                        if (opOutputs != null) {
                            for (String s : opOutputs) {
                                if (!minimalSubgraphVars.contains(s)) continue;
                                anyOpOutputsRequired = true;
                                break;
                            }
                        }
                        if (!anyOpOutputsRequired) continue;
                        req.add(string);
                    }
                    prerequisites.put(variable.getName(), req);
                }
                // Step 5: differentiate ops from the queue until it is empty.
                HashSet<String> differentiatedOps = new HashSet<String>();
                while (!availableForDiff.isEmpty()) {
                    List<Object> outputsOfOp;
                    String dfName = (String)availableForDiff.remove();
                    DifferentialFunction df = ((SameDiffOp)sameDiff.ops.get(dfName)).getOp();
                    if (df instanceof GradientBackwardsMarker) {
                        SameDiffOp op = (SameDiffOp)sameDiff.ops.get(df.getOwnName());
                        inputsToOp = op.getInputsToOp();
                        outputsOfOp = Collections.emptyList();
                    } else {
                        inputsToOp = ((SameDiffOp)sameDiff.ops.get(df.getOwnName())).getInputsToOp();
                        outputsOfOp = ((SameDiffOp)sameDiff.ops.get(df.getOwnName())).getOutputsOfOp();
                    }
                    // One gradient per op output: the existing gradient, null for non floating point
                    // outputs, or a zerosLike placeholder when no gradient has been set.
                    ArrayList<SDVariable> grads = new ArrayList<SDVariable>();
                    for (String string : outputsOfOp) {
                        SDVariable g;
                        SDVariable v = sameDiff.getVariable(string);
                        SDVariable sDVariable = g = v.hasGradient() ? v.gradient() : null;
                        if (g == null) {
                            if (!v.dataType().isFPType()) {
                                grads.add(null);
                                continue;
                            }
                            SDVariable gTemp = sameDiff.zerosLike(v);
                            grads.add(gTemp);
                            continue;
                        }
                        grads.add(g);
                    }
                    // The returned list is not used here; diff(...) is called for its effect on the graph.
                    List<SDVariable> list = df.diff(grads);
                    differentiatedOps.add(df.getOwnName());
                    // Enqueue the producer of each input once it is required and all of its relevant
                    // output gradients (and their prerequisite ops) are available.
                    for (String s : inputsToOp) {
                        Variable v = (Variable)sameDiff.variables.get(s);
                        String opName = v.getOutputOfOp();
                        if (opName == null || differentiatedOps.contains(opName)) continue;
                        boolean isRequiredOp = false;
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        if (op.getInputsToOp() != null) {
                            List<String> opInputs = op.getInputsToOp();
                            boolean anyInputsRequired = false;
                            for (String s2 : opInputs) {
                                if (!minimalSubgraphVars.contains(s2)) continue;
                                anyInputsRequired = true;
                                break;
                            }
                            if (anyInputsRequired && !differentiatedOps.contains(op.getName())) {
                                isRequiredOp = true;
                            }
                        }
                        if (!isRequiredOp) continue;
                        boolean allAvailable = true;
                        SameDiffOp o = (SameDiffOp)sameDiff.ops.get(opName);
                        for (String opOutput : o.getOutputsOfOp()) {
                            Variable outVar = (Variable)SameDiff.this.variables.get(opOutput);
                            if (!outVar.getVariable().dataType().isFPType() || !minimalSubgraphVars.contains(outVar.getName())) continue;
                            if (outVar.getVariable().gradient() == null) {
                                allAvailable = false;
                                break;
                            }
                            List prereqs = (List)prerequisites.get(outVar.getName());
                            if (prereqs == null || (allAvailable &= differentiatedOps.containsAll(prereqs))) continue;
                            break;
                        }
                        if (!allAvailable || availableForDiff.contains(o.getOp().getOwnName())) continue;
                        availableForDiff.add(o.getOp().getOwnName());
                    }
                }
                // Step 6: sanity check - every non-loss variable in the minimal subgraph must now have a gradient.
                for (String s : minimalSubgraphVars) {
                    SDVariable g;
                    if (SameDiff.this.lossVariables.contains(s) || (g = ((Variable)SameDiff.this.variables.get(s)).getVariable().gradient()) != null) continue;
                    throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + s + "\" was calculated");
                }
                // The returned array holds a single newly created "grad" variable; the gradients themselves
                // are attached to the variables of the function's SameDiff instance.
                return new SDVariable[]{sameDiff.var("grad", DataType.FLOAT, 1)};
            }
        });
        // Re-point this instance's ops and variables back at this SameDiff after defining the function.
        this.associateSameDiffWithOpsAndVariables();
    }

    /**
     * Records the original shape for a placeholder variable.
     *
     * @param variableName name of the placeholder variable
     * @param shape        the shape to record; must not be null
     * @throws ND4JIllegalStateException if the variable is not a placeholder, the shape is null,
     *                                   or a different shape has already been recorded for it
     */
    public void setOriginalPlaceHolderShape(String variableName, long[] shape) {
        if (!this.isPlaceHolder(variableName)) {
            throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
        }
        if (shape == null) {
            throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed");
        }
        // Re-registering is only allowed with an identical shape.
        if (this.placeHolderOriginalShapes.containsKey(variableName)) {
            long[] registered = this.placeHolderOriginalShapes.get(variableName);
            if (!Arrays.equals(registered, shape)) {
                throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName);
            }
        }
        this.placeHolderOriginalShapes.put(variableName, shape);
    }

    /**
     * Looks up the original shape recorded for the given placeholder.
     *
     * @param varName name of the placeholder variable
     * @return the recorded shape, or whatever the underlying map returns when none is recorded
     */
    @Deprecated
    public long[] getOriginalShapeForPlaceHolder(String varName) {
        long[] recordedShape = this.placeHolderOriginalShapes.get(varName);
        return recordedShape;
    }

    /**
     * Checks whether the named variable is a placeholder.
     *
     * @param varName name of a variable that must exist in this instance
     * @return the result of {@code isPlaceHolder()} on the variable
     */
    public boolean isPlaceHolder(String varName) {
        boolean known = this.variables.containsKey(varName);
        Preconditions.checkState((boolean)known, (String)"No variable present in SameDiff instance with name \"%s\"", (Object)varName);
        SDVariable sdVar = this.variables.get(varName).getVariable();
        return sdVar.isPlaceHolder();
    }

    /**
     * Associates the given arrays with the placeholder variables of the same name.
     * <p>
     * A first pass checks that every name refers to a known variable and, for placeholders, that
     * the rank of the supplied array matches the placeholder shape. A second pass rejects
     * non-placeholder variables, checks the array against the recorded original placeholder shape
     * and then stores the array for the variable. Finally the instance is flagged as resolved.
     *
     * @param arrays map of placeholder name to the array to use for it
     * @throws ND4JIllegalStateException if a name is unknown, refers to a non-placeholder variable,
     *                                   or the array does not match the recorded placeholder shape
     */
    public void resolveVariablesWith(Map<String, INDArray> arrays) {
        // Pass 1: existence and rank validation.
        for (Map.Entry<String, INDArray> e : arrays.entrySet()) {
            long[] newShape;
            SDVariable varForName = this.getVariable(e.getKey());
            if (varForName == null) {
                throw new ND4JIllegalStateException("No variable name found for " + e.getKey());
            }
            if (varForName.getVariableType() != VariableType.PLACEHOLDER) continue;
            long[] shape = varForName.placeholderShape();
            Preconditions.checkState((shape.length == (newShape = e.getValue().shape()).length ? 1 : 0) != 0, (String)"Placeholder shape not compatible (mismatched rank): placeholder \"%s\" shape %s, got incompatible shape %s", (Object)e.getKey(), (Object)shape, (Object)newShape);
        }
        // Pass 2: placeholder-only check, shape check against the recorded original shape, then association.
        for (Map.Entry<String, INDArray> entry : arrays.entrySet()) {
            if (!this.variables.get(entry.getKey()).getVariable().isPlaceHolder()) {
                throw new ND4JIllegalStateException("Illegal variable " + entry.getKey() + " passed in. Variable found not to be a place holder variable");
            }
            long[] specifiedShape = this.getOriginalShapeForPlaceHolder(entry.getKey());
            if (!Shape.isPlaceholderShape(specifiedShape) && !Shape.shapeEquals(specifiedShape, entry.getValue().shape())) {
                throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(specifiedShape) + " but array shape was " + Arrays.toString(entry.getValue().shape()));
            }
            this.associateArrayWithVariable(entry.getValue(), this.getVariable(entry.getKey()));
            this.setArrayForVariable(entry.getKey(), entry.getValue());
        }
        this.resolvedVariables = true;
    }

    /**
     * Renames the given variable and updates the references kept under its old name.
     * <p>
     * When no new name is given and the variable's current name is already registered, a new name
     * is generated from the current one. If there is still no new name, or it equals the current
     * name, the variable is returned untouched.
     *
     * @param varToUpdate the variable to rename; must not be null
     * @param newVarName  the requested name, or null to have one generated
     * @return the same variable instance that was passed in
     * @throws NullPointerException  if {@code varToUpdate} is null
     * @throws IllegalStateException if {@code newVarName} already belongs to a different variable
     */
    @Override
    public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
        if (varToUpdate == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        String targetName = newVarName;
        if (targetName != null) {
            // An explicit name may only be reused by the very same variable instance.
            if (this.variables.containsKey(targetName) && this.variables.get(targetName).getVariable() != varToUpdate) {
                throw new IllegalStateException("Variable name \"" + targetName + "\" already exists for a different SDVariable");
            }
        } else if (this.variables.containsKey(varToUpdate.getVarName())) {
            targetName = this.generateNewVarName(varToUpdate.getVarName(), 0);
        }
        String previousName = varToUpdate.getVarName();
        boolean nothingToRename = targetName == null || previousName.equals(targetName);
        if (nothingToRename) {
            return varToUpdate;
        }
        varToUpdate.setVarName(targetName);
        this.updateVariableName(previousName, targetName);
        return varToUpdate;
    }

    /**
     * Returns this SameDiff instance.
     *
     * @return {@code this}
     */
    @Override
    protected SameDiff sd() {
        return this;
    }

    /**
     * Applies {@link #updateVariableNameAndReference(SDVariable, String)} to each variable in turn.
     *
     * @param variablesToUpdate the variables to rename
     * @param newVariableNames  the new names, index-aligned with {@code variablesToUpdate}; when
     *                          null, a null name is passed for every variable
     * @return a new array holding the result of each individual update, in order
     */
    @Override
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames) {
        SDVariable[] renamed = new SDVariable[variablesToUpdate.length];
        int position = 0;
        for (SDVariable current : variablesToUpdate) {
            String requestedName = null;
            if (newVariableNames != null) {
                requestedName = newVariableNames[position];
            }
            renamed[position] = this.updateVariableNameAndReference(current, requestedName);
            ++position;
        }
        return renamed;
    }

    /**
     * Sets this instance as the owning SameDiff on every variable in the variable map and on every
     * op, including each op's argument and output variables when those arrays are non-null.
     */
    protected void associateSameDiffWithOpsAndVariables() {
        for (SDVariable sdVar : this.variableMap().values()) {
            sdVar.setSameDiff(this);
        }
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction function = sameDiffOp.getOp();
            function.setSameDiff(this);

            SDVariable[] inputs = function.args();
            if (inputs != null) {
                for (int i = 0; i < inputs.length; ++i) {
                    inputs[i].setSameDiff(this);
                }
            }

            SDVariable[] results = function.outputVariables();
            if (results != null) {
                for (int i = 0; i < results.length; ++i) {
                    results[i].setSameDiff(this);
                }
            }
        }
    }

    /**
     * Executes the graph requesting every variable known to this instance as an output.
     *
     * @param placeholders placeholder arrays passed through to {@code exec}
     * @return the map returned by {@code exec} for all variable names
     */
    public Map<String, INDArray> execAll(Map<String, INDArray> placeholders) {
        String[] everyName = new String[this.variables.size()];
        int next = 0;
        for (Variable variable : this.variables.values()) {
            everyName[next++] = variable.getName();
        }
        return this.exec(placeholders, everyName);
    }

    /**
     * Executes the graph for a single output variable.
     *
     * @param placeholders placeholder arrays passed through to {@code exec}
     * @param output       name of the variable to calculate
     * @return the entry stored under {@code output} in the execution result
     */
    public INDArray execSingle(Map<String, INDArray> placeholders, String output) {
        Map<String, INDArray> results = this.exec(placeholders, output);
        return results.get(output);
    }

    /**
     * List-based convenience overload: converts the output names to an array and delegates to the
     * varargs {@code exec}.
     *
     * @param placeholders placeholder arrays, passed through unchanged
     * @param outputs      names of the variables to calculate
     * @return the map returned by the varargs overload
     */
    public Map<String, INDArray> exec(Map<String, INDArray> placeholders, List<String> outputs) {
        String[] outputNames = outputs.toArray(new String[outputs.size()]);
        return this.exec(placeholders, outputNames);
    }

    /**
     * Executes the graph for the requested outputs using the InferenceSession registered for the
     * calling thread, creating that session on first use.
     * <p>
     * When {@code placeholders} is null and the graph reports inputs, the placeholders stored for
     * the current thread are used instead. Every input name reported by {@code inputs()} must be
     * present in the placeholder map.
     *
     * @param placeholders placeholder arrays, or null to fall back to the per-thread placeholders
     * @param outputs      names of the variables to calculate; at least one is required
     * @return the map produced by the session for the requested outputs
     */
    public Map<String, INDArray> exec(Map<String, INDArray> placeholders, String ... outputs) {
        boolean outputsGiven = outputs != null && outputs.length > 0;
        Preconditions.checkState(outputsGiven, (String)"No outputs were specified");

        long threadId = Thread.currentThread().getId();
        if (!this.sessions.containsKey(threadId)) {
            log.info("Creating new InferenceSession for thread {}", (Object)threadId);
            this.sessions.put(threadId, new InferenceSession(this));
        }

        List<String> phNames = this.inputs();
        Map<String, INDArray> phValues = placeholders;
        if (phValues == null && phNames != null) {
            phValues = this.placeholdersPerThread.get(threadId);
        }
        if (phNames != null && !phNames.isEmpty()) {
            Preconditions.checkNotNull(phValues, (String)"No placeholders were provided. Network has placeholders: %s", phNames);
            for (String phName : phNames) {
                Preconditions.checkState((boolean)phValues.containsKey(phName), (String)"No placeholder variable was provided for variable \"%s\". Cannot execute without all placeholders set", (Object)phName);
            }
        }

        InferenceSession session = this.sessions.get(threadId);
        return session.output(Arrays.asList(outputs), phValues);
    }

    /**
     * Writes a FlatNode carrying the given name into the buffer, using fixed values for every
     * other field.
     *
     * @param name          the name to store; its string offset is passed for both the second and
     *                      third arguments of {@code FlatNode.createFlatNode}
     * @param scope         must not be null; not otherwise used here
     * @param bufferBuilder the builder to write into; must not be null
     * @return the offset returned by {@code FlatNode.createFlatNode}
     * @throws NullPointerException if {@code scope} or {@code bufferBuilder} is null
     */
    protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
        if (scope == null) {
            throw new NullPointerException("scope is marked @NonNull but is null");
        }
        if (bufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked @NonNull but is null");
        }
        int nameOffset = bufferBuilder.createString((CharSequence)name);
        return FlatNode.createFlatNode(bufferBuilder, nameOffset, nameOffset, (byte)119, 10L, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Splits a variable name of the form {@code "name:index"} into its name and output index.
     * <p>
     * A name without any colon is returned with index 0. Otherwise the text after the last colon
     * is parsed as the index and everything before it (inner colons included) is the name, e.g.
     * {@code "a:b:2"} yields {@code ("a:b", 2)}. Trailing colons are ignored, following
     * {@link String#split(String)}.
     *
     * @param varName the variable name to parse; must not be null
     * @return a pair of (name, output index)
     * @throws NullPointerException     if {@code varName} is null
     * @throws IllegalArgumentException if {@code varName} consists only of colons
     * @throws NumberFormatException    if the text after the last colon is not an integer
     */
    public static Pair<String, Integer> parseVariable(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        if (!varName.contains(":")) {
            return Pair.pairOf(varName, 0);
        }
        String[] split = varName.split(":");
        // split() drops trailing empty tokens, so a name made only of colons yields an empty array.
        if (split.length == 0) {
            throw new IllegalArgumentException("Cannot parse variable name \"" + varName + "\": it contains neither a name nor an output index");
        }
        String indexToken = split[split.length - 1];
        Integer index;
        try {
            index = Integer.valueOf(indexToken);
        } catch (NumberFormatException e) {
            // Rethrow the same exception type, but say which variable name was malformed.
            NumberFormatException detailed = new NumberFormatException("Cannot parse output index \"" + indexToken + "\" from variable name \"" + varName + "\"");
            detailed.initCause(e);
            throw detailed;
        }
        // Re-join everything before the index token, restoring the colons removed by split().
        StringBuilder builder = new StringBuilder();
        for (int e = 0; e < split.length - 1; ++e) {
            if (e > 0) {
                builder.append(":");
            }
            builder.append(split[e]);
        }
        return Pair.pairOf(builder.toString(), index);
    }

    /**
     * Serializes a single op into the FlatBuffers builder as a FlatNode.
     * <p>
     * The node's floating point, integer and boolean arguments, its reduction dimensions, its
     * input (node id, output index) pairs, output ids, output names, output datatypes and
     * optional scalar are written to the buffer before the node itself. The names of the node's
     * outputs are registered in {@code reverseMap} under the node's id if not already present.
     *
     * @param node          the op to serialize; must not be null
     * @param bufferBuilder the builder to write into; must not be null
     * @param variables     list used to look up the index of each output variable
     * @param reverseMap    variable name to node id; read for inputs, updated for this node's outputs
     * @param forwardMap    updated (together with {@code reverseMap}) for not-yet-seen inputs whose
     *                      name contains "NextIteration"
     * @param framesMap     frame name to frame id, populated on demand for {@code Enter} ops
     * @param idCounter     source of new node and frame ids
     * @param id            the id to use for this node, or null to take the next id from {@code idCounter}
     * @return the offset returned by {@code FlatNode.createFlatNode}
     * @throws NullPointerException      if {@code node} or {@code bufferBuilder} is null
     * @throws ND4JIllegalStateException if an input variable is unknown, or resolving the output
     *                                   variables fails for a reason other than unresolved outputs
     */
    protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBufferBuilder bufferBuilder, List<SDVariable> variables, Map<String, Integer> reverseMap, Map<String, Integer> forwardMap, Map<String, Integer> framesMap, AtomicInteger idCounter, Integer id) {
        if (node == null) {
            throw new NullPointerException("node is marked @NonNull but is null");
        }
        if (bufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked @NonNull but is null");
        }
        String opName = node.opName();
        long hash = FlatBuffersMapper.getOpNum(node.opName(), node.opType());

        // Floating point arguments: tArgs for custom ops, otherwise the op's extra args as doubles.
        double[] extras;
        if (node.opType() == Op.Type.CUSTOM) {
            CustomOp customOp = (CustomOp)((Object)node);
            extras = customOp.tArgs();
        } else {
            extras = node.getExtraArgs() != null ? new double[node.getExtraArgs().length] : new double[]{};
            for (int e = 0; e < extras.length; ++e) {
                extras[e] = ((Number)node.getExtraArgs()[e]).doubleValue();
            }
        }

        // Integer and boolean arguments.
        boolean[] boolArgs = null;
        long[] extraBits;
        if (node.opType() == Op.Type.CUSTOM) {
            DynamicCustomOp dynamicCustomOp = (DynamicCustomOp)node;
            extraBits = dynamicCustomOp.iArgs();
            boolArgs = dynamicCustomOp.bArgs();
        } else if (node instanceof Enter) {
            // Enter ops carry the id of their frame as the single integer argument.
            String frameName = ((Enter)node).getFrameName();
            if (!framesMap.containsKey(frameName)) {
                framesMap.put(frameName, idCounter.incrementAndGet());
            }
            extraBits = new long[]{framesMap.get(frameName).intValue()};
        } else {
            extraBits = new long[]{};
        }
        if (node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_LONG) {
            boolArgs = new boolean[]{((ReduceOp)((Object)node)).isKeepDims(), true};
        } else if (node.opType() == Op.Type.INDEXREDUCE) {
            boolArgs = new boolean[]{((IndexAccumulation)((Object)node)).isKeepDims(), true};
        }

        // Output ids: index of each output variable in the supplied list; empty when unresolved.
        int[] outputIds;
        try {
            SDVariable[] outputVertexId = node.outputVariables();
            outputIds = new int[outputVertexId.length];
            for (int i = 0; i < outputIds.length; ++i) {
                outputIds[i] = variables.indexOf(outputVertexId[i]);
            }
        }
        catch (ND4UnresolvedOutputVariables e) {
            outputIds = new int[]{};
        }
        catch (Exception e) {
            throw new ND4JIllegalStateException(e);
        }

        // Inputs: one (node id, output index) pair per argument.
        ArrayList<Integer> inPaired = new ArrayList<Integer>();
        for (SDVariable input : node.args()) {
            int outIdx;
            String varName = input.getVarName();
            if (this.variables.get(varName).getOutputOfOp() != null) {
                DifferentialFunction df = this.ops.get(this.variables.get(varName).getOutputOfOp()).getOp();
                outIdx = this.ops.get(df.getOwnName()).getOutputsOfOp().indexOf(varName);
            } else {
                outIdx = 0;
            }
            if (!reverseMap.containsKey(varName)) {
                if (varName.contains("NextIteration")) {
                    // Inputs named "NextIteration" may be referenced before they are serialized; assign them an id now.
                    int fwdNodeId = idCounter.incrementAndGet();
                    forwardMap.put(varName, fwdNodeId);
                    reverseMap.put(varName, fwdNodeId);
                } else {
                    throw new ND4JIllegalStateException("Unknown variable used in input: [" + varName + "]");
                }
            }
            int nodeId = reverseMap.get(varName);
            inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx));
        }

        log.trace("Own Name: {}", (Object)node.getOwnName());
        int ownId = id != null ? id.intValue() : idCounter.incrementAndGet();
        for (String outName : node.outputVariablesNames()) {
            if (reverseMap.containsKey(outName)) continue;
            reverseMap.put(outName, ownId);
        }

        // Dimensions are only written for reduction-type ops.
        int[] dims;
        if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) {
            dims = node.getDimensions();
            if (dims == null) {
                dims = new int[]{};
            }
        } else {
            dims = new int[]{};
        }

        Map<String, Object> fnProps = node.propertiesForFunction();
        int[] flatProperties = FlatBuffersMapper.mapFunctionPropertiesToFlatProperties(bufferBuilder, fnProps);
        int propIdx = FlatNode.createPropertiesVector(bufferBuilder, flatProperties);
        int nodesIn = FlatNode.createInputVector(bufferBuilder, new int[0]);
        int nodesInPaired = FlatNode.createInputPairedVector(bufferBuilder, Ints.toArray(inPaired));
        int nodesOut = FlatNode.createOutputVector(bufferBuilder, outputIds);
        int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
        int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
        int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[]{});
        int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
        int fname = bufferBuilder.createString((CharSequence)node.getOwnName());
        int scopeName = bufferBuilder.createString((CharSequence)"");

        int scalar = 0;
        if (node instanceof ScalarOp) {
            INDArray scalarArr = ((ScalarOp)((Object)node)).scalar();
            if (scalarArr != null) {
                scalar = scalarArr.toFlatArray(bufferBuilder);
            }
        }
        if (node.opType() == null) {
            log.warn("Null-op node: {}", (Object)node);
        }

        // Output names and datatypes. The op may have no recorded outputs (null list); treat that
        // as zero outputs for both vectors instead of dereferencing null.
        List<String> outVarNames = node.getSameDiff().ops.get(node.getOwnName()).getOutputsOfOp();
        int numOutputs = outVarNames == null ? 0 : outVarNames.size();
        int[] outVarNamesStringsOffsets = new int[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            outVarNamesStringsOffsets[i] = bufferBuilder.createString((CharSequence)outVarNames.get(i));
        }
        int outVarNamesOffset = FlatNode.createOutputNamesVector(bufferBuilder, outVarNamesStringsOffsets);
        int opNameOffset = bufferBuilder.createString((CharSequence)opName);
        byte[] outTypes = new byte[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            SDVariable v = this.getVariable(outVarNames.get(i));
            outTypes[i] = FlatBuffersMapper.getDataTypeAsByte(v.dataType());
        }
        int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes);

        return FlatNode.createFlatNode(bufferBuilder, ownId, fname, FlatBuffersMapper.getFlatOpType(node.opType()), hash, propIdx, nodesIn, nodesInPaired, nodesOut, extraz, integerArgs, bArgs, dimensions, -1, 0, scopeName, outVarNamesOffset, opNameOffset, outTypesOffset, scalar);
    }

    /**
     * Serializes this graph to FlatBuffers using graph id 0.
     *
     * @param configuration the executor configuration to use; must not be null
     * @return the buffer produced by {@code asFlatBuffers(long, ExecutorConfiguration)}
     * @throws NullPointerException if {@code configuration} is null
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration) {
        if (configuration != null) {
            return this.asFlatBuffers(0L, configuration);
        }
        throw new NullPointerException("configuration is marked @NonNull but is null");
    }

    /**
     * Serializes this graph into a FlatBuffers {@code FlatGraph}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>commit pending work on the executioner;</li>
     *   <li>emit one {@code FlatVariable} per entry of {@link #variables()}. A variable that is the output
     *       of an op shares that op's id and records its position among the op's outputs; any other
     *       variable gets a fresh id and output number 0;</li>
     *   <li>emit one {@code FlatNode} per op in {@code ops};</li>
     *   <li>for each entry of {@code sameDiffFunctionInstances} whose key is not "grad" (case-insensitive):
     *       emit a node for the sub-graph, a variable for each of its variables that has an array, and a
     *       node for each of its ops;</li>
     *   <li>write placeholder names and loss variable names, finish the buffer, and store the assigned ids
     *       back onto this instance's variables.</li>
     * </ol>
     *
     * <p>NOTE(review): the decompiler reported "Removed try catching itself - possible behaviour change"
     * for this method (most likely the monitor-exit handler of the synchronized block) - confirm against
     * the original sources if exact exception behaviour matters.
     *
     * @param graphId       id written into the serialized graph
     * @param configuration executor configuration whose flat form is embedded in the graph; must not be null
     * @return the builder's data buffer containing the finished graph
     * @throws NullPointerException if {@code configuration} is null
     */
    public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(1024);
        AtomicInteger idCounter = new AtomicInteger(0);
        ArrayList<Integer> flatVariables = new ArrayList<Integer>();
        // Never populated in this method; serialized as an empty vector below.
        ArrayList<Integer> flatOffsets = new ArrayList<Integer>();
        ArrayList<Integer> flatNodes = new ArrayList<Integer>();
        ArrayList<SDVariable> variableList = new ArrayList<SDVariable>(this.variables());
        LinkedHashMap<String, Integer> reverseMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> forwardMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> framesMap = new LinkedHashMap<String, Integer>();
        int idx = 0;
        // Identity-based: two distinct op instances must never share an id even if they compare equal.
        IdentityHashMap<DifferentialFunction, Integer> idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();

        // --- Variables of this graph ---
        for (SDVariable variable : this.variables()) {
            INDArray arr = variable.getArr();
            log.trace("Exporting variable: [{}]", (Object) variable.getVarName());
            String varName = variable.getVarName();
            int varIdx;
            int outputNum;
            if (this.variables.get(varName).getOutputOfOp() != null) {
                DifferentialFunction df = this.ops.get(this.variables.get(varName).getOutputOfOp()).getOp();
                Integer knownIdx = idxForOps.get(df);
                if (knownIdx == null) {
                    varIdx = idCounter.incrementAndGet();
                    idxForOps.put(df, varIdx);
                } else {
                    varIdx = knownIdx;
                }
                Object[] outNames = df.outputVariablesNames();
                outputNum = ArrayUtils.indexOf(outNames, (Object) varName);
                Preconditions.checkState(outputNum >= 0, "Variable name \"%s\" not found in list of outputs: %s", (Object) varName, (Object) outNames);
            } else {
                varIdx = idCounter.incrementAndGet();
                outputNum = 0;
            }
            reverseMap.put(variable.getVarName(), varIdx);
            log.trace("Adding [{}] as [{}]", (Object) variable.getVarName(), (Object) varIdx);

            // The order of builder calls below is kept exactly as in the original.
            int shape = 0;
            int name = bufferBuilder.createString((CharSequence) variable.getVarName());
            int array = arr == null ? 0 : arr.toFlatArray(bufferBuilder);
            int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum);
            byte varType = (byte) variable.getVariableType().ordinal();
            if (variable.getVariableType() == VariableType.PLACEHOLDER) {
                long[] shp = variable.getShape();
                shape = FlatVariable.createShapeVector(bufferBuilder, shp);
            }
            int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType);
            flatVariables.add(flatVariable);
        }

        // --- Ops of this graph ---
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction func = sameDiffOp.getOp();
            // May be null when none of the op's outputs appeared in the variable loop above.
            Integer fnId = idxForOps.get(func);
            flatNodes.add(this.asFlatNode(func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
        }

        // --- Defined sub-graphs (everything except "grad") ---
        for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
            if (entry.getKey().equalsIgnoreCase("grad")) {
                continue;
            }
            SameDiff subGraph = entry.getValue();
            flatNodes.add(this.asFlatNode(entry.getKey(), subGraph, bufferBuilder));
            ArrayList<SDVariable> currVarList = new ArrayList<SDVariable>(subGraph.variables());
            for (SDVariable node : subGraph.variables()) {
                INDArray arr = node.getArr();
                if (arr == null) {
                    continue;
                }
                int name = bufferBuilder.createString((CharSequence) node.getVarName());
                int array = arr.toFlatArray(bufferBuilder);
                // Sub-graph variables are numbered by the local counter, not by idCounter.
                int id = IntPair.createIntPair(bufferBuilder, ++idx, 0);
                Pair<String, Integer> pair = SameDiff.parseVariable(node.getVarName());
                reverseMap.put(pair.getFirst(), idx);
                log.trace("Adding [{}] as [{}]", (Object) pair.getFirst(), (Object) idx);
                byte varType = (byte) node.getVariableType().ordinal();
                int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType);
                flatVariables.add(flatVariable);
            }
            for (SameDiffOp op : subGraph.ops.values()) {
                DifferentialFunction func = op.getOp();
                flatNodes.add(this.asFlatNode(func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null));
            }
        }

        int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets));
        int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
        int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));

        // Placeholder names: single pass; strings are created in iteration order.
        List<Integer> placeholderNameOffsets = new ArrayList<Integer>();
        for (SDVariable v : this.variables()) {
            if (v.isPlaceHolder()) {
                placeholderNameOffsets.add(bufferBuilder.createString((CharSequence) v.getVarName()));
            }
        }
        int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, Ints.toArray(placeholderNameOffsets));

        List<String> lossVars = this.getLossVariables();
        int[] lossVarOffsets = new int[lossVars == null ? 0 : lossVars.size()];
        for (int i = 0; i < lossVarOffsets.length; ++i) {
            lossVarOffsets[i] = bufferBuilder.createString((CharSequence) lossVars.get(i));
        }
        int lossVarOffset = FlatGraph.createLossVariablesVector(bufferBuilder, lossVarOffsets);

        int fg = FlatGraph.createFlatGraph(bufferBuilder, graphId, variablesOffset, nodesOffset, outputsOffset, configuration.getFlatConfiguration(bufferBuilder), placeholdersOffset, lossVarOffset);
        bufferBuilder.finish(fg);

        synchronized (this) {
            for (Map.Entry<String, Integer> e : reverseMap.entrySet()) {
                // reverseMap also contains names taken from sub-graphs; those need not exist in this graph.
                Variable target = this.variables.get(e.getKey());
                if (target != null) {
                    target.setVariableIndex(e.getValue());
                }
            }
        }
        return bufferBuilder.dataBuffer();
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers()} and returns the root {@code FlatGraph}
     * read from the resulting buffer.
     *
     * @return the FlatGraph root of the serialized buffer
     */
    public FlatGraph asFlatGraph() {
        ByteBuffer serialized = this.asFlatBuffers();
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers(long, ExecutorConfiguration)} and returns the
     * root {@code FlatGraph} read from the resulting buffer.
     *
     * @param graphId       id passed through to the serializer
     * @param configuration executor configuration passed through to the serializer
     * @return the FlatGraph root of the serialized buffer
     */
    public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration) {
        ByteBuffer serialized = this.asFlatBuffers(graphId, configuration);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph with a default executor configuration: output mode
     * {@code VARIABLE_SPACE}, execution mode {@code SEQUENTIAL}, profiling disabled and timing
     * gathering enabled.
     *
     * @return the buffer returned by {@link #asFlatBuffers(ExecutorConfiguration)}
     */
    public ByteBuffer asFlatBuffers() {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return this.asFlatBuffers(defaults);
    }

    /**
     * Writes this graph together with its current training configuration to the given stream.
     *
     * @param outputStream destination stream
     * @throws IllegalStateException if no training configuration is set on this instance
     * @throws IOException           if writing fails
     */
    public void saveWithTrainingConfig(OutputStream outputStream) throws IOException {
        TrainingConfig current = this.trainingConfig;
        if (current != null) {
            this.saveWithTrainingConfig(current, outputStream);
            return;
        }
        throw new IllegalStateException("No training configuration found!");
    }

    /**
     * Writes this graph together with its current training configuration to the given file.
     * The configuration check happens before the file is opened.
     *
     * @param outputFile destination file
     * @throws IllegalStateException if no training configuration is set on this instance
     * @throws IOException           if the file cannot be opened or written
     */
    public void saveWithTrainingConfig(File outputFile) throws IOException {
        boolean hasConfig = this.trainingConfig != null;
        if (!hasConfig) {
            throw new IllegalStateException("No training configuration found!");
        }
        try (BufferedOutputStream fileOut = new BufferedOutputStream(new FileOutputStream(outputFile))) {
            this.saveWithTrainingConfig(this.trainingConfig, fileOut);
            // The zip writer shields the underlying stream from close, so flush explicitly.
            fileOut.flush();
        }
    }

    /**
     * Writes a zip archive with two entries to {@code outputStream}: the given training
     * configuration as JSON ({@code TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME}) and this graph in
     * FlatBuffers form ({@code SAMEDIFF_FILE_ENTRY_NAME}).
     *
     * <p>The JSON is encoded as UTF-8 explicitly rather than with the platform default charset, so
     * archives are portable between machines. The zip stream is always closed (finishing the
     * archive and releasing the deflater), also when writing fails part-way. The caller's stream is
     * shielded from that close and is neither closed nor flushed here.
     *
     * @param trainingConfig configuration to store alongside the graph
     * @param outputStream   destination stream; remains open after this call
     * @throws IOException if serialization or writing fails
     */
    public void saveWithTrainingConfig(TrainingConfig trainingConfig, OutputStream outputStream) throws IOException {
        ObjectMapper objectMapper = ObjectMapperHolder.getJsonMapper();
        String configJson = objectMapper.writeValueAsString((Object) trainingConfig);
        try (ZipOutputStream zipfile = new ZipOutputStream((OutputStream) new CloseShieldOutputStream(outputStream))) {
            zipfile.putNextEntry(new ZipEntry(TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME));
            zipfile.write(configJson.getBytes(java.nio.charset.StandardCharsets.UTF_8));

            // putNextEntry implicitly closes the previous entry.
            zipfile.putNextEntry(new ZipEntry(SAMEDIFF_FILE_ENTRY_NAME));
            ByteBuffer fb = this.asFlatBuffers();
            int offset = fb.position();
            byte[] array = fb.array();
            // The serialized graph occupies the backing array from the buffer position to the end.
            zipfile.write(array, offset, array.length - offset);
        }
    }

    /**
     * Restores a graph and its training configuration from a zip archive containing the entries
     * {@code TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME} (JSON) and {@code SAMEDIFF_FILE_ENTRY_NAME}
     * (FlatBuffers).
     *
     * <p>Both entries are read fully into memory and the zip file is closed before the training
     * configuration is applied and training is initialized on the restored instance.
     *
     * @param file zip archive to read
     * @return the restored instance with its training configuration set and training initialized
     * @throws IOException if the file cannot be read or one of the two required entries is missing
     */
    public static SameDiff restoreFromTrainingConfigZip(File file) throws IOException {
        TrainingConfig trainingConfig;
        SameDiff ret;
        try (ZipFile zipFile = new ZipFile(file)) {
            ZipEntry config = zipFile.getEntry(TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME);
            if (config == null) {
                throw new IOException("Zip file " + file + " does not contain entry " + TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME);
            }
            try (InputStream stream = zipFile.getInputStream(config)) {
                byte[] read = IOUtils.toByteArray((InputStream) stream);
                trainingConfig = (TrainingConfig) ObjectMapperHolder.getJsonMapper().readValue(read, TrainingConfig.class);
            }

            ZipEntry sameDiffFile = zipFile.getEntry(SAMEDIFF_FILE_ENTRY_NAME);
            if (sameDiffFile == null) {
                throw new IOException("Zip file " + file + " does not contain entry " + SAMEDIFF_FILE_ENTRY_NAME);
            }
            try (InputStream stream = zipFile.getInputStream(sameDiffFile)) {
                byte[] read = IOUtils.toByteArray((InputStream) stream);
                ret = SameDiff.fromFlatBuffers(ByteBuffer.wrap(read));
            }
        }
        ret.setTrainingConfig(trainingConfig);
        ret.initializeTraining();
        return ret;
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers()} and writes the bytes to {@code file}.
     * The written range starts at the buffer's position and runs to the end of its backing array.
     *
     * @param file destination file; must not be null
     * @throws NullPointerException if {@code file} is null
     * @throws IOException          if the file cannot be opened or written
     */
    public void asFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        ByteBuffer serialized = this.asFlatBuffers();
        int start = serialized.position();
        byte[] backing = serialized.array();
        int count = backing.length - start;
        try (FileOutputStream rawOut = new FileOutputStream(file);
             BufferedOutputStream bufferedOut = new BufferedOutputStream(rawOut);
             DataOutputStream sink = new DataOutputStream(bufferedOut)) {
            sink.write(backing, start, count);
        }
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers(ExecutorConfiguration)} and writes the bytes
     * to {@code file}. The written range starts at the buffer's position and runs to the end of its
     * backing array.
     *
     * @param file          destination file; must not be null
     * @param configuration executor configuration passed to the serializer; must not be null
     * @throws NullPointerException if {@code file} or {@code configuration} is null
     * @throws IOException          if the file cannot be opened or written
     */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration) throws IOException {
        // File is checked first so the exception message matches the first offending argument.
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        ByteBuffer serialized = this.asFlatBuffers(configuration);
        int start = serialized.position();
        byte[] backing = serialized.array();
        int count = backing.length - start;
        try (FileOutputStream rawOut = new FileOutputStream(file);
             BufferedOutputStream bufferedOut = new BufferedOutputStream(rawOut);
             DataOutputStream sink = new DataOutputStream(bufferedOut)) {
            sink.write(backing, start, count);
        }
    }

    /**
     * Reads the whole file into memory and deserializes it with {@link #fromFlatBuffers(ByteBuffer)}.
     *
     * @param file file containing a serialized graph; must not be null
     * @return the deserialized instance
     * @throws NullPointerException if {@code file} is null
     * @throws IOException          if the file cannot be read or deserialization fails
     */
    public static SameDiff fromFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        byte[] content;
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(file))) {
            content = IOUtils.toByteArray((InputStream) in);
        }
        return SameDiff.fromFlatBuffers(ByteBuffer.wrap(content));
    }

    /**
     * Rebuilds a graph from a serialized FlatBuffers {@code FlatGraph}.
     *
     * <p>Variables are restored first: each gets its name, shape, data type and variable type, plus its
     * array when one is stored and the variable type is not {@code ARRAY}. Each variable is indexed by
     * its (node id, output number) pair and by name. Ops are then restored in serialized order: inputs
     * are resolved through the (node id, output number) index, and outputs are taken either from the
     * variables that share the op's name (when their count matches the op's output count) or from the
     * node's output-name list, creating variables that do not exist yet. Finally the loss variables are
     * registered.
     *
     * @param bbIn buffer positioned at a serialized FlatGraph
     * @return a new instance populated from the buffer
     * @throws IllegalStateException if an op references an input (node id, output number) pair for which
     *                               no variable has been restored
     * @throws IOException           declared by the original signature
     */
    public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException {
        FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn);
        int numOps = fg.nodesLength();
        int numVars = fg.variablesLength();
        List<FlatNode> ops = new ArrayList<FlatNode>(numOps);
        for (int i = 0; i < numOps; ++i) {
            ops.add(fg.nodes(i));
        }
        List<FlatVariable> vars = new ArrayList<FlatVariable>(numVars);
        for (int i = 0; i < numVars; ++i) {
            vars.add(fg.variables(i));
        }
        SameDiff sd = SameDiff.create();

        Map<Pair<Integer, Integer>, SDVariable> variablesByNodeAndOutNum = new HashMap<Pair<Integer, Integer>, SDVariable>();
        Map<String, List<SDVariable>> variablesByName = new HashMap<String, List<SDVariable>>();

        // --- Variables ---
        for (FlatVariable v : vars) {
            int shapeLength = v.shapeLength();
            long[] shape = new long[shapeLength];
            for (int i = 0; i < shapeLength; ++i) {
                shape[i] = v.shape(i);
            }
            String n = v.name();
            DataType dtype = FlatBuffersMapper.getDataTypeFromByte(v.dtype());
            VariableType vt = VariableType.values()[v.variabletype()];
            SDVariable restored = new SDVariable(n, vt, sd, shape, dtype, null);
            sd.variables.put(n, Variable.builder().name(n).variable(restored).build());
            sd.variableNameToShape.put(n, shape);

            FlatArray fa = v.ndarray();
            if (fa != null && vt != VariableType.ARRAY) {
                INDArray arr;
                // Create the array outside of any active workspace.
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    arr = Nd4j.createFromFlatArray(fa);
                }
                sd.setArrayForVariable(n, arr);
            }

            IntPair id = v.id();
            variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(id.first(), id.second()), restored);
            List<SDVariable> sameName = variablesByName.get(n);
            if (sameName == null) {
                sameName = new ArrayList<SDVariable>();
                variablesByName.put(n, sameName);
            }
            sameName.add(restored);
        }

        // --- Ops ---
        for (FlatNode fn : ops) {
            DifferentialFunction df = FlatBuffersMapper.fromFlatNode(fn);
            String name = fn.name();
            df.setSameDiff(sd);
            df.setOwnName(name);
            if (sd.ops.containsKey(name)) {
                sd.ops.get(name).setOp(df);
            } else {
                sd.ops.put(name, SameDiffOp.builder().name(name).op(df).build());
            }
            int opId = fn.id();

            // Resolve inputs via their (producing node id, output number) pairs.
            int numInputs = fn.inputPairedLength();
            String[] inputNames = new String[numInputs];
            for (int i = 0; i < numInputs; ++i) {
                IntPair in = fn.inputPaired(i);
                int nodeId = in.first();
                int nodeOutNum = in.second();
                SDVariable varIn = variablesByNodeAndOutNum.get(new Pair<Integer, Integer>(nodeId, nodeOutNum));
                if (varIn == null) {
                    // Previously this fell through to a bare NullPointerException on the next line.
                    throw new IllegalStateException("Could not resolve input " + i + " of op \"" + name
                            + "\": no variable found for node id " + nodeId + ", output number " + nodeOutNum);
                }
                inputNames[i] = varIn.getVarName();
            }
            sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames));
            for (String inName : inputNames) {
                Variable v = sd.getVariables().get(inName);
                if (v.getInputsForOp() == null) {
                    v.setInputsForOp(new ArrayList<String>());
                }
                if (!v.getInputsForOp().contains(df.getOwnName())) {
                    v.getInputsForOp().add(df.getOwnName());
                }
            }

            // Resolve outputs.
            List<SDVariable> varsForOp = variablesByName.get(name);
            int numOutputs = df.getNumOutputs();
            if (numOutputs <= 0) {
                numOutputs = fn.outputLength();
            }
            String[] varNames;
            if (varsForOp != null && varsForOp.size() == numOutputs) {
                // Variables sharing the op's name are taken as its outputs, in restored order.
                varNames = new String[varsForOp.size()];
                for (int i = 0; i < varNames.length; ++i) {
                    varNames[i] = varsForOp.get(i).getVarName();
                    sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName());
                }
                sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames));
            } else {
                // Otherwise use the output names stored on the node, creating missing variables.
                int outputNamesLength = fn.outputNamesLength();
                varNames = new String[outputNamesLength];
                for (int i = 0; i < outputNamesLength; ++i) {
                    String n = fn.outputNames(i);
                    varNames[i] = n;
                    if (!sd.variables.containsKey(n)) {
                        SDVariable created = new SDVariable(n, VariableType.VARIABLE, sd, null, null, null);
                        sd.variables.put(n, Variable.builder().name(n).variable(created).build());
                        variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(opId, i), created);
                    }
                    sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName());
                }
                sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames));
            }

            // Make every output reachable by (op id, output index) for ops restored later.
            for (int i = 0; i < varNames.length; ++i) {
                Pair<Integer, Integer> p = new Pair<Integer, Integer>(opId, i);
                if (!variablesByNodeAndOutNum.containsKey(p)) {
                    variablesByNodeAndOutNum.put(p, sd.getVariable(varNames[i]));
                }
            }
        }

        // --- Loss variables ---
        for (int i = 0; i < fg.lossVariablesLength(); ++i) {
            sd.addLossVariable(fg.lossVariables(i));
        }
        return sd;
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers()} and renders the result as text: first one
     * line per variable (id, name, shape info and up to 50 values), then one line per node (id, name,
     * op type, op number or custom op name, paired inputs). Each node is also logged at info level.
     *
     * @return the textual dump of the serialized graph
     */
    public String asFlatPrint() {
        StringBuilder out = new StringBuilder();
        FlatGraph graph = FlatGraph.getRootAsFlatGraph(this.asFlatBuffers());

        out.append("\nExternal variables:\n\n");
        int variableCount = graph.variablesLength();
        for (int v = 0; v < variableCount; ++v) {
            FlatVariable flatVar = graph.variables(v);
            INDArray values = null;
            // Materialize the stored array (if any) outside of any active workspace.
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                FlatArray stored = flatVar.ndarray();
                if (stored != null) {
                    values = Nd4j.createFromFlatArray(stored);
                }
            }
            out.append(flatVar.id().first()).append(":<").append(flatVar.name()).append("> ");
            if (values != null) {
                out.append(Arrays.toString(values.shapeInfoDataBuffer().asInt())).append("; Values: ");
                if (values.data() == null) {
                    out.append("<empty array>");
                } else if (values.dataType() == DataType.UTF8) {
                    out.append("<string array>");
                } else if (values.length() < 50L) {
                    out.append(Arrays.toString(values.data().asFloat()).replace(" ", ""));
                } else {
                    // Long arrays are truncated to their first 50 elements.
                    out.append("[");
                    out.append(values.data().getFloat(0L));
                    for (int i = 1; i < 50; ++i) {
                        out.append(",");
                        out.append(values.data().getFloat((long) i));
                    }
                    out.append("]");
                }
                out.append(";\n");
            } else {
                out.append("<no array>").append("; Values: ").append("<no array>").append(";\n");
            }
        }

        Map<String, CustomOpDescriptor> customOps = Nd4j.getExecutioner().getCustomOperations();
        out.append("\nOps sequence:\n\n");
        int nodeCount = graph.nodesLength();
        for (int n = 0; n < nodeCount; ++n) {
            FlatNode node = graph.nodes(n);
            log.info("{}:<{}>", (Object) node.id(), (Object) node.name());
            out.append(node.id()).append(":<").append(node.name()).append("> ").append((Object) FlatBuffersMapper.getTypeFromByte(node.opType()));
            if (FlatBuffersMapper.getTypeFromByte(node.opType()) == Op.Type.CUSTOM) {
                // Custom ops are identified by hash; when several descriptors match, the last one wins.
                String opName = "unknown";
                for (Map.Entry<String, CustomOpDescriptor> candidate : customOps.entrySet()) {
                    if (candidate.getValue().getHash() == node.opNum()) {
                        opName = candidate.getKey();
                    }
                }
                out.append(": ").append(opName);
            } else {
                out.append(": ").append(node.opNum());
            }
            out.append("; Inputs: {");
            int inputCount = node.inputPairedLength();
            for (int i = 0; i < inputCount; ++i) {
                IntPair in = node.inputPaired(i);
                out.append("[").append(in.first()).append(":").append(in.second()).append("]");
                if (i < inputCount - 1) {
                    out.append(", ");
                }
            }
            out.append("};");
            out.append(" OpNum: {").append(node.opNum()).append("};");
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Builds a human-readable, column-aligned report of this graph with four sections: overall counts,
     * a table of variables (shape, variable type, data type, producing function, consuming functions),
     * a table of functions (name, op class, inputs, outputs), and, when any exist, a table of
     * SameDiff-defined function instances.
     *
     * @return the formatted report
     */
    public String summary() {
        Map<String, SDVariable> varMap = this.variableMap();
        DifferentialFunction[] functions = this.functions();

        int withArrays = 0;
        for (String varName : varMap.keySet()) {
            if (this.getArrForVarName(varName) != null) {
                ++withArrays;
            }
        }

        StringBuilder sb = new StringBuilder();
        String format = "%-25s%-20s";
        sb.append("--- Summary ---\n");
        sb.append(String.format(format, "Variables:", varMap.size()));
        sb.append(" (").append(withArrays).append(" with arrays)").append("\n");
        sb.append(String.format(format, "Functions:", functions.length)).append("\n");
        sb.append(String.format(format, "SameDiff Function Defs:", this.sameDiffFunctionInstances.size())).append("\n");
        sb.append("Loss function variables: ").append(this.getLossVariables()).append("\n\n");

        // ---- Variables table: first pass finds producers and column widths ----
        sb.append("--- Variables ---\n");
        HashMap<String, String> producerByVar = new HashMap<String, String>();
        int producerColWidth = 22;
        int nameColWidth = 8;
        for (String varName : varMap.keySet()) {
            String producerOpName = null;
            for (SameDiffOp op : this.ops.values()) {
                List<String> outputs = op.getOutputsOfOp();
                if (outputs != null && outputs.contains(varName)) {
                    producerOpName = op.getName();
                    break;
                }
            }
            String producerLabel;
            if (producerOpName != null) {
                DifferentialFunction producer = this.getFunctionById(producerOpName);
                producerLabel = producer.getOwnName() + "(" + producer.opName() + ")";
            } else {
                producerLabel = "<none>";
            }
            producerByVar.put(varName, producerLabel);
            producerColWidth = Math.max(producerColWidth, producerLabel.length());
            nameColWidth = Math.max(nameColWidth, varName.length());
        }
        // Two characters of padding between columns.
        nameColWidth += 2;
        producerColWidth += 2;
        format = "%-" + nameColWidth + "s%-20s%-20s%-20s%-" + producerColWidth + "s%-20s";
        sb.append(String.format(format, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-", "- Output Of Function -", "- Inputs To Functions -")).append("\n");
        for (String varName : varMap.keySet()) {
            INDArray arr = this.getArrForVarName(varName);
            String arrayShape = arr == null ? "-" : Arrays.toString(arr.shape());
            String varType = this.getVariable(varName).getVariableType().toString();
            String dtype = this.getVariable(varName).dataType().toString();
            List<String> consumers = this.variables.get(varName).getInputsForOp();
            String consumersStr = consumers == null ? "" : consumers.toString();
            String producerLabel = (String) producerByVar.get(varName);
            sb.append(String.format(format, varName, arrayShape, varType, dtype, producerLabel, consumersStr)).append("\n");
        }

        // ---- Functions table: first pass renders inputs/outputs and finds column widths ----
        sb.append("\n\n--- Functions ---\n");
        ArrayList<String> renderedInputs = new ArrayList<String>();
        ArrayList<String> renderedOutputs = new ArrayList<String>();
        int inColWidth = 10;
        int outColWidth = 11;
        int fnNameColWidth = 17;
        int classColWidth = 10;
        for (DifferentialFunction df : functions) {
            Object[] argNames = df.argNames();
            Object[] outNames = df.outputVariablesNames();
            String argStr = Arrays.toString(argNames);
            String outStr = Arrays.toString(outNames);
            inColWidth = Math.max(inColWidth, argStr.length());
            outColWidth = Math.max(outColWidth, outStr.length());
            renderedInputs.add(argStr);
            renderedOutputs.add(outStr);
            String label = df.getOwnName() == null ? df.opName() : df.getOwnName();
            fnNameColWidth = Math.max(fnNameColWidth, label.length());
            classColWidth = Math.max(classColWidth, df.getClass().getSimpleName().length());
        }
        fnNameColWidth += 2;
        classColWidth += 2;
        inColWidth += 2;
        outColWidth += 2;
        format = "%-5s%-" + fnNameColWidth + "s%-" + classColWidth + "s%-" + inColWidth + "s%-" + outColWidth + "s";
        sb.append(String.format(format, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n");
        int row = 0;
        for (DifferentialFunction df : functions) {
            String label = df.getOwnName() == null ? df.opName() : df.getOwnName();
            sb.append(String.format(format, String.valueOf(row), label, df.getClass().getSimpleName(), renderedInputs.get(row), renderedOutputs.get(row))).append("\n");
            ++row;
        }

        // ---- SameDiff-defined function instances ----
        if (!this.sameDiffFunctionInstances.isEmpty()) {
            sb.append("\n\n--- SameDiff Defined Functions ---\n");
            format = "%-20s%-15s%-15s%-15s";
            sb.append(String.format(format, "- Name -", "- Variables -", "- Functions -", "- Fn Defs -")).append("\n");
            for (Map.Entry<String, SameDiff> e : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = e.getValue();
                int vars = sd.variableMap().size();
                int fns = sd.functions() == null ? 0 : sd.functions().length;
                int defFns = sd.definedFunctionNames().size();
                sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Computes data types for all variables of this graph by running a {@code DataTypesSession} over
     * every variable name, supplying the declared data type of each placeholder as its input.
     *
     * @return the map returned by the session's {@code output} call for all variable names
     * @throws RuntimeException (via {@code Preconditions.checkNotNull}) if a placeholder has no data
     *                          type - NOTE(review): exact exception type depends on Preconditions
     */
    public Map<String, DataType> calculateOutputDataTypes() {
        ArrayList<String> everyVariable = new ArrayList<String>(this.variables.keySet());
        DataTypesSession session = new DataTypesSession(this);
        HashMap<String, DataType> placeholderTypes = new HashMap<String, DataType>();
        for (Variable v : this.variables.values()) {
            if (v.getVariable().isPlaceHolder()) {
                DataType declared = v.getVariable().dataType();
                Preconditions.checkNotNull((Object) declared, (String) "Placeholder variable %s has null datatype", (Object) v.getName());
                placeholderTypes.put(v.getName(), declared);
            }
        }
        return session.output(everyVariable, placeholderTypes);
    }

    /**
     * Creates a new, empty {@link SameDiffBuilder}.
     *
     * @return a fresh builder instance
     */
    public static SameDiffBuilder builder() {
        SameDiffBuilder fresh = new SameDiffBuilder();
        return fresh;
    }

    /**
     * All-fields constructor: stores every argument in the field of the same name.
     * No argument is validated, copied or wrapped; the references are kept as given.
     */
    public SameDiff(TrainingConfig trainingConfig, boolean initializedTraining, INDArray updaterState, Map<String, INDArray> updaterViews, Map<String, GradientUpdater> updaterMap, Map<String, String> baseNameForFunctionInstanceId, DifferentialFunctionFactory functionFactory, Map<String, long[]> variableNameToShape, Map<String, SDVariable> forwardVarForGrad, int variableId, Map<String, List<String>> propertiesToResolve, Map<String, Map<String, Object>> propertiesForFunction, Map<String, long[]> placeHolderOriginalShapes, Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap, Map<String, SameDiff> sameDiffFunctionInstances, Set<String> placeHolderFunctions, Table<String, String, String> fieldVariableResolutionMapping, AtomicBoolean wasRegistered, boolean debugMode, Map<int[], Op> opsForResult, boolean resolvedVariables, boolean logExecution, SameDiff parent, SameDiff child) {
        this.trainingConfig = trainingConfig;
        this.initializedTraining = initializedTraining;
        this.updaterState = updaterState;
        this.updaterViews = updaterViews;
        this.updaterMap = updaterMap;
        this.baseNameForFunctionInstanceId = baseNameForFunctionInstanceId;
        this.functionFactory = functionFactory;
        this.variableNameToShape = variableNameToShape;
        this.forwardVarForGrad = forwardVarForGrad;
        this.variableId = variableId;
        this.propertiesToResolve = propertiesToResolve;
        this.propertiesForFunction = propertiesForFunction;
        this.placeHolderOriginalShapes = placeHolderOriginalShapes;
        this.sameDiffFunctionDefinitionMap = sameDiffFunctionDefinitionMap;
        this.sameDiffFunctionInstances = sameDiffFunctionInstances;
        this.placeHolderFunctions = placeHolderFunctions;
        this.fieldVariableResolutionMapping = fieldVariableResolutionMapping;
        this.wasRegistered = wasRegistered;
        this.debugMode = debugMode;
        this.opsForResult = opsForResult;
        this.resolvedVariables = resolvedVariables;
        this.logExecution = logExecution;
        this.parent = parent;
        this.child = child;
    }

    /**
     * @return the {@code variables} map of this instance (the live map, not a copy)
     */
    public Map<String, Variable> getVariables() {
        return variables;
    }

    /**
     * @return the {@code ops} map of this instance (the live map, not a copy)
     */
    public Map<String, SameDiffOp> getOps() {
        return ops;
    }

    /**
     * @return the {@code sessions} map of this instance (the live map, not a copy)
     */
    public Map<Long, InferenceSession> getSessions() {
        return sessions;
    }

    /**
     * @return the training configuration currently set on this instance; may be null
     */
    public TrainingConfig getTrainingConfig() {
        return trainingConfig;
    }

    /**
     * @return the current value of the {@code initializedTraining} flag
     */
    public boolean isInitializedTraining() {
        return initializedTraining;
    }

    /**
     * @return the {@code updaterState} array held by this instance
     */
    public INDArray getUpdaterState() {
        return updaterState;
    }

    /**
     * @return the {@code updaterViews} map held by this instance (the live map, not a copy)
     */
    public Map<String, INDArray> getUpdaterViews() {
        return updaterViews;
    }

    /**
     * @return the {@code updaterMap} held by this instance (the live map, not a copy)
     */
    public Map<String, GradientUpdater> getUpdaterMap() {
        return updaterMap;
    }

    /**
     * @return the current value of the {@code debugMode} flag
     */
    public boolean isDebugMode() {
        return debugMode;
    }

    /**
     * @return the current value of the {@code logExecution} flag
     */
    public boolean isLogExecution() {
        return logExecution;
    }

    /**
     * Sets the {@code logExecution} flag.
     *
     * @param logExecution new value of the flag
     */
    public void setLogExecution(boolean logExecution) {
        this.logExecution = logExecution;
    }

    /**
     * @return the {@code parent} instance reference; may be null
     */
    public SameDiff getParent() {
        return parent;
    }

    /**
     * @return the {@code child} instance reference; may be null
     */
    public SameDiff getChild() {
        return child;
    }

    // Class initialization: set up the logger and the cloner, then index every method declared on
    // SameDiff whose return type is exactly SDVariable by its name. For overloaded names the method
    // encountered last in getDeclaredMethods() order is the one kept.
    static {
        log = LoggerFactory.getLogger(SameDiff.class);
        cloner = SameDiff.newCloner();
        opMethods = new HashMap<String, Method>();
        for (Method candidate : SameDiff.class.getDeclaredMethods()) {
            if (SDVariable.class.equals(candidate.getReturnType())) {
                opMethods.put(candidate.getName(), candidate);
            }
        }
    }

    /**
     * Fluent builder for {@link SameDiff}.
     *
     * <p>Each setter stores its argument in the field of the same name and returns
     * this builder so calls can be chained. {@link #build()} forwards every stored
     * value to the {@code SameDiff} constructor. Values that were never set keep
     * their Java defaults ({@code null}, {@code false} or {@code 0}). No validation
     * or copying of the supplied values is performed by the builder itself.
     */
    public static class SameDiffBuilder {
        private TrainingConfig trainingConfig;
        private boolean initializedTraining;
        private INDArray updaterState;
        private Map<String, INDArray> updaterViews;
        private Map<String, GradientUpdater> updaterMap;
        private Map<String, String> baseNameForFunctionInstanceId;
        private DifferentialFunctionFactory functionFactory;
        private Map<String, long[]> variableNameToShape;
        private Map<String, SDVariable> forwardVarForGrad;
        private int variableId;
        private Map<String, List<String>> propertiesToResolve;
        private Map<String, Map<String, Object>> propertiesForFunction;
        private Map<String, long[]> placeHolderOriginalShapes;
        private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
        private Map<String, SameDiff> sameDiffFunctionInstances;
        private Set<String> placeHolderFunctions;
        private Table<String, String, String> fieldVariableResolutionMapping;
        private AtomicBoolean wasRegistered;
        private boolean debugMode;
        private Map<int[], Op> opsForResult;
        private boolean resolvedVariables;
        private boolean logExecution;
        private SameDiff parent;
        private SameDiff child;

        /** Package-private constructor; creates a builder with every value unset. */
        SameDiffBuilder() {
        }

        /**
         * @param trainingConfig value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder trainingConfig(final TrainingConfig trainingConfig) {
            this.trainingConfig = trainingConfig;
            return this;
        }

        /**
         * @param initializedTraining value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder initializedTraining(final boolean initializedTraining) {
            this.initializedTraining = initializedTraining;
            return this;
        }

        /**
         * @param updaterState value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder updaterState(final INDArray updaterState) {
            this.updaterState = updaterState;
            return this;
        }

        /**
         * @param updaterViews value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder updaterViews(final Map<String, INDArray> updaterViews) {
            this.updaterViews = updaterViews;
            return this;
        }

        /**
         * @param updaterMap value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder updaterMap(final Map<String, GradientUpdater> updaterMap) {
            this.updaterMap = updaterMap;
            return this;
        }

        /**
         * @param baseNameForFunctionInstanceId value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder baseNameForFunctionInstanceId(final Map<String, String> baseNameForFunctionInstanceId) {
            this.baseNameForFunctionInstanceId = baseNameForFunctionInstanceId;
            return this;
        }

        /**
         * @param functionFactory value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder functionFactory(final DifferentialFunctionFactory functionFactory) {
            this.functionFactory = functionFactory;
            return this;
        }

        /**
         * Marked deprecated in the original; the reason is not recorded here.
         *
         * @param variableNameToShape value to pass to the built instance (stored by reference)
         * @return this builder
         */
        @Deprecated
        public SameDiffBuilder variableNameToShape(final Map<String, long[]> variableNameToShape) {
            this.variableNameToShape = variableNameToShape;
            return this;
        }

        /**
         * Marked deprecated in the original; the reason is not recorded here.
         *
         * @param forwardVarForGrad value to pass to the built instance (stored by reference)
         * @return this builder
         */
        @Deprecated
        public SameDiffBuilder forwardVarForGrad(final Map<String, SDVariable> forwardVarForGrad) {
            this.forwardVarForGrad = forwardVarForGrad;
            return this;
        }

        /**
         * @param variableId value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder variableId(final int variableId) {
            this.variableId = variableId;
            return this;
        }

        /**
         * @param propertiesToResolve value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder propertiesToResolve(final Map<String, List<String>> propertiesToResolve) {
            this.propertiesToResolve = propertiesToResolve;
            return this;
        }

        /**
         * @param propertiesForFunction value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder propertiesForFunction(final Map<String, Map<String, Object>> propertiesForFunction) {
            this.propertiesForFunction = propertiesForFunction;
            return this;
        }

        /**
         * Marked deprecated in the original; the reason is not recorded here.
         *
         * @param placeHolderOriginalShapes value to pass to the built instance (stored by reference)
         * @return this builder
         */
        @Deprecated
        public SameDiffBuilder placeHolderOriginalShapes(final Map<String, long[]> placeHolderOriginalShapes) {
            this.placeHolderOriginalShapes = placeHolderOriginalShapes;
            return this;
        }

        /**
         * @param sameDiffFunctionDefinitionMap value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder sameDiffFunctionDefinitionMap(final Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap) {
            this.sameDiffFunctionDefinitionMap = sameDiffFunctionDefinitionMap;
            return this;
        }

        /**
         * @param sameDiffFunctionInstances value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder sameDiffFunctionInstances(final Map<String, SameDiff> sameDiffFunctionInstances) {
            this.sameDiffFunctionInstances = sameDiffFunctionInstances;
            return this;
        }

        /**
         * @param placeHolderFunctions value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder placeHolderFunctions(final Set<String> placeHolderFunctions) {
            this.placeHolderFunctions = placeHolderFunctions;
            return this;
        }

        /**
         * @param fieldVariableResolutionMapping value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder fieldVariableResolutionMapping(final Table<String, String, String> fieldVariableResolutionMapping) {
            this.fieldVariableResolutionMapping = fieldVariableResolutionMapping;
            return this;
        }

        /**
         * @param wasRegistered value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder wasRegistered(final AtomicBoolean wasRegistered) {
            this.wasRegistered = wasRegistered;
            return this;
        }

        /**
         * @param debugMode value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder debugMode(final boolean debugMode) {
            this.debugMode = debugMode;
            return this;
        }

        /**
         * @param opsForResult value to pass to the built instance (stored by reference)
         * @return this builder
         */
        public SameDiffBuilder opsForResult(final Map<int[], Op> opsForResult) {
            this.opsForResult = opsForResult;
            return this;
        }

        /**
         * @param resolvedVariables value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder resolvedVariables(final boolean resolvedVariables) {
            this.resolvedVariables = resolvedVariables;
            return this;
        }

        /**
         * @param logExecution value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder logExecution(final boolean logExecution) {
            this.logExecution = logExecution;
            return this;
        }

        /**
         * @param parent value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder parent(final SameDiff parent) {
            this.parent = parent;
            return this;
        }

        /**
         * @param child value to pass to the built instance
         * @return this builder
         */
        public SameDiffBuilder child(final SameDiff child) {
            this.child = child;
            return this;
        }

        /**
         * Creates a new {@link SameDiff} from the values currently stored in this builder.
         *
         * <p>The arguments are passed positionally, in the same order as the fields
         * are declared above.
         *
         * @return a newly constructed {@code SameDiff}
         */
        public SameDiff build() {
            return new SameDiff(
                    this.trainingConfig,
                    this.initializedTraining,
                    this.updaterState,
                    this.updaterViews,
                    this.updaterMap,
                    this.baseNameForFunctionInstanceId,
                    this.functionFactory,
                    this.variableNameToShape,
                    this.forwardVarForGrad,
                    this.variableId,
                    this.propertiesToResolve,
                    this.propertiesForFunction,
                    this.placeHolderOriginalShapes,
                    this.sameDiffFunctionDefinitionMap,
                    this.sameDiffFunctionInstances,
                    this.placeHolderFunctions,
                    this.fieldVariableResolutionMapping,
                    this.wasRegistered,
                    this.debugMode,
                    this.opsForResult,
                    this.resolvedVariables,
                    this.logExecution,
                    this.parent,
                    this.child);
        }

        /**
         * Renders every stored value as {@code name=value}, comma-separated, inside
         * {@code SameDiff.SameDiffBuilder(...)}.
         *
         * @return a textual dump of the builder's current state
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append("SameDiff.SameDiffBuilder(trainingConfig=").append(this.trainingConfig);
            text.append(", initializedTraining=").append(this.initializedTraining);
            text.append(", updaterState=").append(this.updaterState);
            text.append(", updaterViews=").append(this.updaterViews);
            text.append(", updaterMap=").append(this.updaterMap);
            text.append(", baseNameForFunctionInstanceId=").append(this.baseNameForFunctionInstanceId);
            text.append(", functionFactory=").append(this.functionFactory);
            text.append(", variableNameToShape=").append(this.variableNameToShape);
            text.append(", forwardVarForGrad=").append(this.forwardVarForGrad);
            text.append(", variableId=").append(this.variableId);
            text.append(", propertiesToResolve=").append(this.propertiesToResolve);
            text.append(", propertiesForFunction=").append(this.propertiesForFunction);
            text.append(", placeHolderOriginalShapes=").append(this.placeHolderOriginalShapes);
            text.append(", sameDiffFunctionDefinitionMap=").append(this.sameDiffFunctionDefinitionMap);
            text.append(", sameDiffFunctionInstances=").append(this.sameDiffFunctionInstances);
            text.append(", placeHolderFunctions=").append(this.placeHolderFunctions);
            text.append(", fieldVariableResolutionMapping=").append(this.fieldVariableResolutionMapping);
            text.append(", wasRegistered=").append(this.wasRegistered);
            text.append(", debugMode=").append(this.debugMode);
            text.append(", opsForResult=").append(this.opsForResult);
            text.append(", resolvedVariables=").append(this.resolvedVariables);
            text.append(", logExecution=").append(this.logExecution);
            text.append(", parent=").append(this.parent);
            text.append(", child=").append(this.child);
            text.append(")");
            return text.toString();
        }
    }
}

