/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.ops;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.ops.SDOps;
import org.nd4j.autodiff.samediff.ops.SDValidation;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.factory.Nd4j;

public class SDNN
extends SDOps {
    public SDNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, double epsilon, int ... axis) {
        return this.batchNorm(null, input, mean, variance, gamma, beta, true, true, epsilon, axis);
    }

    public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon, int ... axis) {
        SDValidation.validateFloatingPoint("batchNorm", "input", input);
        SDValidation.validateFloatingPoint("batchNorm", "mean", mean);
        SDValidation.validateFloatingPoint("batchNorm", "variance", variance);
        SDValidation.validateFloatingPoint("batchNorm", "gamma", gamma);
        SDValidation.validateFloatingPoint("batchNorm", "beta", beta);
        SDVariable res = this.f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon, axis);
        return this.updateVariableNameAndReference(res, name);
    }

    public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, double epsilon, int ... axis) {
        return this.batchNorm(name, input, mean, variance, gamma, beta, true, true, epsilon, axis);
    }

    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
        return this.biasAdd(null, input, bias);
    }

    public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
        SDValidation.validateFloatingPoint("biasAdd", "input", input);
        SDValidation.validateFloatingPoint("biasAdd", "bias", bias);
        SDVariable ret = this.f().biasAdd(input, bias);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable dropout(SDVariable input, double inputRetainProbability) {
        return this.dropout(null, input, inputRetainProbability);
    }

    public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) {
        SDValidation.validateFloatingPoint("dropout", input);
        SDVariable res = this.f().dropout(input, inputRetainProbability);
        return this.updateVariableNameAndReference(res, name);
    }

    public SDVariable elu(SDVariable x) {
        return this.elu(null, x);
    }

    public SDVariable elu(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("elu", x);
        SDVariable result = this.f().elu(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable eluDerivative(SDVariable x) {
        return this.eluDerivative(null, x);
    }

    public SDVariable eluDerivative(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("eluDerivative", x);
        SDVariable result = this.f().eluDerivative(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable gelu(SDVariable x) {
        return this.gelu(null, x);
    }

    public SDVariable gelu(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("gelu", x);
        SDVariable ret = this.f().gelu(x, false);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable hardSigmoid(SDVariable in) {
        return this.hardSigmoid(null, in);
    }

    public SDVariable hardSigmoid(String name, SDVariable in) {
        SDValidation.validateFloatingPoint("hard sigmoid", in);
        SDVariable ret = this.f().hardSigmoid(in);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable hardTanh(SDVariable in) {
        return this.hardTanh(null, in);
    }

    public SDVariable hardTanh(String name, SDVariable in) {
        SDValidation.validateFloatingPoint("hard Tanh", in);
        SDVariable result = this.f().hardTanh(in);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable hardTanhDerivative(SDVariable x) {
        return this.hardTanhDerivative(null, x);
    }

    public SDVariable hardTanhDerivative(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("hard Tanh derivative", x);
        SDVariable result = this.f().hardTanhDerivative(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable leakyRelu(SDVariable x, double alpha) {
        return this.leakyRelu(null, x, alpha);
    }

    public SDVariable leakyRelu(String name, SDVariable x, double alpha) {
        SDValidation.validateFloatingPoint("leaky ReLU", x);
        SDVariable result = this.f().leakyRelu(x, alpha);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) {
        SDValidation.validateFloatingPoint("leaky ReLU derivative", x);
        SDVariable result = this.f().leakyReluDerivative(x, alpha);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) {
        return this.linear(null, input, weights, bias);
    }

    public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        SDValidation.validateFloatingPoint("linear", "input", input);
        SDValidation.validateFloatingPoint("linear", "weights", weights);
        SDValidation.validateFloatingPoint("linear", "bias", bias);
        SDVariable res = this.f().xwPlusB(input, weights, bias);
        return this.updateVariableNameAndReference(res, name);
    }

    public SDVariable logSigmoid(SDVariable x) {
        return this.logSigmoid(null, x);
    }

    public SDVariable logSigmoid(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("log sigmoid", x);
        SDVariable ret = this.f().logSigmoid(x);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable logSoftmax(SDVariable x) {
        return this.logSoftmax(null, x);
    }

    public SDVariable logSoftmax(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("log softmax", x);
        SDVariable ret = this.f().logSoftmax(x);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable relu(SDVariable x, double cutoff) {
        return this.relu(null, x, cutoff);
    }

    public SDVariable relu(String name, SDVariable x, double cutoff) {
        SDValidation.validateFloatingPoint("ReLU", x);
        SDVariable result = this.f().relu(x, cutoff);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable relu6(SDVariable x, double cutoff) {
        return this.relu6(null, x, cutoff);
    }

    public SDVariable relu6(String name, SDVariable x, double cutoff) {
        SDValidation.validateFloatingPoint("ReLU6", x);
        SDVariable result = this.f().relu6(x, cutoff);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
        return this.reluLayer(null, input, weights, bias);
    }

    public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        SDValidation.validateFloatingPoint("reluLayer", "input", input);
        SDValidation.validateFloatingPoint("reluLayer", "weights", weights);
        SDValidation.validateFloatingPoint("reluLayer", "bias", bias);
        SDVariable res = this.f().reluLayer(input, weights, bias);
        return this.updateVariableNameAndReference(res, name);
    }

    public SDVariable selu(SDVariable x) {
        return this.selu(null, x);
    }

    public SDVariable selu(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("selu", x);
        SDVariable ret = this.f().selu(x);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable sigmoid(SDVariable x) {
        return this.sigmoid(null, x);
    }

    public SDVariable sigmoid(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("sigmoid", x);
        SDVariable result = this.f().sigmoid(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) {
        return this.sigmoidDerivative(null, x, wrt);
    }

    public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) {
        SDValidation.validateFloatingPoint("sigmoidDerivative", x);
        SDVariable result = this.f().sigmoidDerivative(x, wrt);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable softmax(SDVariable x) {
        return this.softmax(null, x);
    }

    public SDVariable softmax(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("softmax", x);
        SDVariable result = this.f().softmax(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) {
        return this.softmaxDerivative(name, x, wrt, null);
    }

    public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) {
        SDValidation.validateFloatingPoint("softmaxDerivative", x);
        SDVariable result = this.f().softmaxDerivative(x, wrt, dimension);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable softplus(SDVariable x) {
        return this.softplus(null, x);
    }

    public SDVariable softplus(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("softplus", x);
        SDVariable result = this.f().softplus(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable softsign(SDVariable x) {
        return this.softsign(null, x);
    }

    public SDVariable softsign(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("softsign", x);
        SDVariable result = this.f().softsign(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable softsignDerivative(SDVariable x) {
        return this.softsignDerivative(null, x);
    }

    public SDVariable softsignDerivative(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("softsignDerivative", x);
        SDVariable result = this.f().softsignDerivative(x);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable swish(SDVariable x) {
        return this.swish(null, x);
    }

    public SDVariable swish(String name, SDVariable x) {
        SDValidation.validateFloatingPoint("swish", x);
        SDVariable ret = this.f().swish(x);
        return this.updateVariableNameAndReference(ret, name);
    }

    public SDVariable tanh(String name, SDVariable x) {
        return this.sd.math().tanh(name, x);
    }

    public SDVariable tanh(SDVariable x) {
        return this.sd.math().tanh(x);
    }

    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int ... dimensions) {
        return this.layerNorm(null, input, gain, bias, dimensions);
    }

    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int ... dimensions) {
        SDValidation.validateFloatingPoint("layerNorm", "input", input);
        SDValidation.validateFloatingPoint("layerNorm", "gain", gain);
        SDValidation.validateFloatingPoint("layerNorm", "bias", bias);
        SDVariable result = this.f().layerNorm(input, gain, bias, dimensions);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable layerNorm(SDVariable input, SDVariable gain, int ... dimensions) {
        return this.layerNorm((String)null, input, gain, dimensions);
    }

    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int ... dimensions) {
        SDValidation.validateFloatingPoint("layerNorm", "input", input);
        SDValidation.validateFloatingPoint("layerNorm", "gain", gain);
        SDVariable result = this.f().layerNorm(input, gain, dimensions);
        return this.updateVariableNameAndReference(result, name);
    }

    public SDVariable pad(SDVariable input, int[][] padding, double constant) {
        return this.pad(input, this.sd.constant(Nd4j.createFromArray(padding)), constant);
    }

    public SDVariable pad(SDVariable input, SDVariable padding, double constant) {
        return this.pad(null, input, padding, Pad.Mode.CONSTANT, constant);
    }

    public SDVariable pad(String outputName, SDVariable input, SDVariable padding, Pad.Mode mode, double constant) {
        SDVariable out = this.f().pad(input, padding, mode, constant);
        return this.updateVariableNameAndReference(out, outputName);
    }

    public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) {
        return this.dotProductAttention(null, queries, keys, values, mask, scaled);
    }

    public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) {
        SDVariable result = this.f().dotProductAttention(queries, keys, values, mask, scaled);
        return this.updateVariableNameAndReference(result, name);
    }

    public List<SDVariable> dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) {
        return this.dotProductAttention(null, queries, keys, values, mask, scaled, withWeights);
    }

    public List<SDVariable> dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) {
        List<SDVariable> result = this.f().dotProductAttention(queries, keys, values, mask, scaled, withWeights);
        if (withWeights) {
            return Collections.singletonList(this.updateVariableNameAndReference(result.get(0), name));
        }
        return Arrays.asList(this.updateVariableNameAndReference(result.get(0), name), this.updateVariableNameAndReference(result.get(1), name + ":weights"));
    }

    public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) {
        return this.multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled);
    }

    public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) {
        SDVariable result = this.f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled);
        return this.updateVariableNameAndReference(result, name);
    }

    public List<SDVariable> multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) {
        return this.multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights);
    }

    public List<SDVariable> multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) {
        List<SDVariable> result = this.f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights);
        if (withWeights) {
            return Collections.singletonList(this.updateVariableNameAndReference(result.get(0), name));
        }
        return Arrays.asList(this.updateVariableNameAndReference(result.get(0), name), this.updateVariableNameAndReference(result.get(1), name + ":weights"));
    }
}

