/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.performance;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.performance.primitives.AveragingTransactionsHolder;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Singleton that collects memory-transfer bandwidth samples per device.
 *
 * <p>Setup:
 * <ul>
 *   <li>On construction, one {@link AveragingTransactionsHolder} is created per device for both
 *       the {@code bandwidth} and {@code operations} maps.</li>
 *   <li>The device count comes from {@code Nd4j.getAffinityManager().getNumberOfDevices()}.</li>
 *   <li>Device ids are therefore {@code 0 .. numberOfDevices - 1}.</li>
 *   <li>The maps are populated only in the constructor and are never structurally modified
 *       afterwards; only the holders they contain are mutated.</li>
 * </ul>
 */
public class PerformanceTracker {
    private static final Logger log = LoggerFactory.getLogger(PerformanceTracker.class);
    private static final PerformanceTracker INSTANCE = new PerformanceTracker();

    /** Per-device bandwidth samples, keyed by device id. */
    private final Map<Integer, AveragingTransactionsHolder> bandwidth = new HashMap<Integer, AveragingTransactionsHolder>();
    /** Per-device operation holders, keyed by device id (not written to within this class). */
    private final Map<Integer, AveragingTransactionsHolder> operations = new HashMap<Integer, AveragingTransactionsHolder>();

    private PerformanceTracker() {
        int nd = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int e = 0; e < nd; ++e) {
            this.bandwidth.put(e, new AveragingTransactionsHolder());
            this.operations.put(e, new AveragingTransactionsHolder());
        }
    }

    /**
     * Returns the process-wide tracker instance.
     *
     * @return the singleton {@code PerformanceTracker}
     */
    public static PerformanceTracker getInstance() {
        return INSTANCE;
    }

    /**
     * Records a memory transaction with direction {@link MemcpyDirection#HOST_TO_HOST}.
     *
     * @param deviceId       device the transfer is attributed to; must be a known device id
     * @param timeSpentNanos duration of the transfer, in nanoseconds
     * @param numberOfBytes  number of bytes transferred
     * @return the computed bandwidth in bytes per microsecond, or 0 if the duration was not positive
     */
    public long addMemoryTransaction(int deviceId, long timeSpentNanos, long numberOfBytes) {
        return this.addMemoryTransaction(deviceId, timeSpentNanos, numberOfBytes, MemcpyDirection.HOST_TO_HOST);
    }

    /**
     * Computes the bandwidth of a single memory transaction and records it when positive.
     *
     * <p>Bandwidth is computed as {@code numberOfBytes / (timeSpentNanos / 1000)}, i.e. bytes per
     * microsecond. That is numerically equal to decimal megabytes per second.
     *
     * @param deviceId       device the transfer is attributed to; must be a key created in the
     *                       constructor, otherwise {@code bandwidth.get(deviceId)} is null and a
     *                       {@link NullPointerException} is thrown when a sample is recorded
     * @param timeSpentNanos duration of the transfer, in nanoseconds
     * @param numberOfBytes  number of bytes transferred
     * @param direction      direction of the copy; must not be null
     * @return the computed bandwidth (bytes per microsecond, truncated to {@code long}). Returns 0
     *         without recording anything when {@code timeSpentNanos <= 0}.
     * @throws NullPointerException if {@code direction} is null
     */
    public long addMemoryTransaction(int deviceId, long timeSpentNanos, long numberOfBytes, @NonNull MemcpyDirection direction) {
        if (direction == null) {
            throw new NullPointerException("direction is marked @NonNull but is null");
        }
        // A zero duration (e.g. a copy shorter than the System.nanoTime() resolution) would make the
        // division below produce Infinity. (long) Infinity == Long.MAX_VALUE, which would pass the
        // bw > 0 check and poison the running average. A non-positive duration carries no usable
        // bandwidth information, so such samples are ignored.
        if (timeSpentNanos <= 0L) {
            return 0L;
        }
        long bw = (long) ((double) numberOfBytes / ((double) timeSpentNanos / 1000.0));
        if (bw > 0L) {
            this.bandwidth.get(deviceId).addValue(direction, bw);
        }
        return bw;
    }

    /**
     * Resets every per-device holder tracked by this instance, in both the bandwidth and the
     * operations maps.
     */
    public void clear() {
        for (AveragingTransactionsHolder holder : this.bandwidth.values()) {
            holder.clear();
        }
        for (AveragingTransactionsHolder holder : this.operations.values()) {
            holder.clear();
        }
    }

    /**
     * Starts timing a transaction if bandwidth profiling is enabled.
     *
     * @return {@link System#nanoTime()} when the executioner's profiling mode is
     *         {@code BANDWIDTH}, otherwise 0
     */
    public long helperStartTransaction() {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH) {
            return System.nanoTime();
        }
        return 0L;
    }

    /**
     * Completes a transaction started with {@link #helperStartTransaction()}.
     *
     * <p>When profiling mode is {@code BANDWIDTH}, the elapsed time since the start timestamp is
     * recorded via {@link #addMemoryTransaction(int, long, long, MemcpyDirection)}.
     *
     * @param deviceId       device the transfer is attributed to
     * @param timeSpentNanos despite its name, the START timestamp returned by
     *                       {@link #helperStartTransaction()}; elapsed time is computed as
     *                       {@code System.nanoTime() - timeSpentNanos}
     * @param numberOfBytes  number of bytes transferred
     * @param direction      direction of the copy; must not be null
     * @throws NullPointerException if {@code direction} is null
     */
    public void helperRegisterTransaction(int deviceId, long timeSpentNanos, long numberOfBytes, @NonNull MemcpyDirection direction) {
        if (direction == null) {
            throw new NullPointerException("direction is marked @NonNull but is null");
        }
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH) {
            this.addMemoryTransaction(deviceId, System.nanoTime() - timeSpentNanos, numberOfBytes, direction);
        }
    }

    /**
     * Builds a snapshot of the average bandwidth per device and per copy direction.
     *
     * @return a new mutable map from device id to a map of direction → average bandwidth.
     *         Every {@link MemcpyDirection} value is present for each device, in enum declaration
     *         order.
     */
    public Map<Integer, Map<MemcpyDirection, Long>> getCurrentBandwidth() {
        Map<Integer, Map<MemcpyDirection, Long>> result = new HashMap<Integer, Map<MemcpyDirection, Long>>();
        for (Map.Entry<Integer, AveragingTransactionsHolder> entry : this.bandwidth.entrySet()) {
            AveragingTransactionsHolder holder = entry.getValue();
            // EnumMap: typed (no raw HashMap / unchecked warning) and iterates in enum order.
            Map<MemcpyDirection, Long> perDirection = new EnumMap<MemcpyDirection, Long>(MemcpyDirection.class);
            for (MemcpyDirection direction : MemcpyDirection.values()) {
                perDirection.put(direction, holder.getAverageValue(direction));
            }
            result.put(entry.getKey(), perDirection);
        }
        return result;
    }
}

