package edu.mines.jtk.bench;

import edu.mines.jtk.util.Array;
import edu.mines.jtk.util.Check;
import edu.mines.jtk.util.Stopwatch;
import ij.macro.MacroConstants;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/* loaded from: input_file:lib/stitching/edu_mines_jtk.jar:edu/mines/jtk/bench/MtMatMulBench.class */
/**
 * Benchmarks single- and multi-threaded ways of computing the matrix product C = A*B.
 *
 * <p>Every method computes each element of C with {@link #computeColumn}, using the same
 * floating-point operation order. Their results are therefore expected to be bitwise identical,
 * and {@link #main} checks this after every round.
 */
public class MtMatMulBench {
    /** Number of threads used by the multi-threaded methods. */
    public static final int NTHREAD = 8;

    /** Number of rounds in which every method is benchmarked. */
    private static final int NROUND = 5;

    /** Minimum wall-clock time, in seconds, spent benchmarking each method per round. */
    private static final double SECONDS_PER_METHOD = 5.0;

    /**
     * Runs the benchmark and prints the rate, in mflops, achieved by each method.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        // A is m-by-n and B is n-by-m, so C = A*B is m-by-m. Array.randfloat(n1,n2) and
        // Array.zerofloat(n1,n2) are given the column count first; that is the only argument
        // order under which checkDimensions accepts these shapes.
        int m = 1001;
        int n = 1002;
        float[][] a = Array.randfloat(n, m);
        float[][] b = Array.randfloat(m, n);
        // One result matrix per method (mul1..mul4), so the results can be compared afterwards.
        float[][][] c = new float[4][][];
        for (int method = 0; method < c.length; ++method) {
            c[method] = Array.zerofloat(m, m);
        }
        // One product performs m*m*n multiply-add pairs, i.e. 2*m*m*n floating-point operations.
        double mflops = 2.0e-6 * m * m * n;
        Stopwatch sw = new Stopwatch();
        System.out.println("Methods:");
        System.out.println("mul1 = single-threaded");
        System.out.println("mul2 = multi-threaded (equal chunks)");
        System.out.println("mul3 = multi-threaded (atomic-integer)");
        System.out.println("mul4 = multi-threaded (thread-pool)");
        System.out.println("number of threads = " + NTHREAD);
        for (int round = 0; round < NROUND; ++round) {
            System.out.println();
            for (int method = 1; method <= c.length; ++method) {
                int rate = measureRate(method, a, b, c[method - 1], mflops, sw);
                System.out.println("mul" + method + ": rate=" + rate + " mflops");
            }
            // Every method must reproduce the single-threaded result exactly.
            for (int method = 1; method < c.length; ++method) {
                assertEquals(c[0], c[method]);
            }
        }
    }

    /**
     * Repeatedly computes C = A*B with the given method until at least
     * {@link #SECONDS_PER_METHOD} seconds have elapsed.
     *
     * @param method which multiply to run: 1 to 4, selecting mul1 to mul4.
     * @param mflopsPerProduct millions of floating-point operations in one product.
     * @param sw stopwatch used for timing; it is restarted and stopped by this method.
     * @return the achieved rate in mflops, truncated to an int.
     */
    private static int measureRate(
            int method, float[][] a, float[][] b, float[][] c, double mflopsPerProduct,
            Stopwatch sw) {
        sw.restart();
        int count = 0;
        while (sw.time() < SECONDS_PER_METHOD) {
            multiply(method, a, b, c);
            ++count;
        }
        sw.stop();
        return (int) (count * mflopsPerProduct / sw.time());
    }

    /** Dispatches to mul1..mul4 according to {@code method}. */
    private static void multiply(int method, float[][] a, float[][] b, float[][] c) {
        switch (method) {
            case 1:
                mul1(a, b, c);
                break;
            case 2:
                mul2(a, b, c);
                break;
            case 3:
                mul3(a, b, c);
                break;
            case 4:
                mul4(a, b, c);
                break;
            default:
                throw new IllegalArgumentException("unknown method " + method);
        }
    }

    /**
     * Computes column j of C = A*B and stores it into c.
     *
     * <p>Column j of B is first copied into a contiguous temporary array, so each dot product
     * reads two contiguous arrays. Each dot product first accumulates the leading
     * {@code nk % 4} terms and then the rest four at a time. The accumulation order is fixed,
     * so every caller gets bitwise-identical results.
     *
     * <p>Public only so the worker runnables can call it; not intended as API.
     *
     * @param j index of the column of B and C.
     * @param a left factor A; each row must have at least b.length elements.
     * @param b right factor B.
     * @param c output C; column j of every row is overwritten.
     */
    public static void computeColumn(int j, float[][] a, float[][] b, float[][] c) {
        int ni = c.length;
        int nk = b.length;
        float[] bj = new float[nk];
        for (int k = 0; k < nk; ++k) {
            bj[k] = b[k][j];
        }
        int nrem = nk % 4;
        for (int i = 0; i < ni; ++i) {
            float[] ai = a[i];
            float cij = 0.0f;
            for (int k = 0; k < nrem; ++k) {
                cij += ai[k] * bj[k];
            }
            for (int k = nrem; k < nk; k += 4) {
                // Left-to-right accumulation into cij; keep this exact order.
                cij = cij + ai[k] * bj[k] + ai[k + 1] * bj[k + 1]
                        + ai[k + 2] * bj[k + 2] + ai[k + 3] * bj[k + 3];
            }
            c[i][j] = cij;
        }
    }

    /** Computes C = A*B in the calling thread, one column at a time. */
    private static void mul1(float[][] a, float[][] b, float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        for (int j = 0; j < nj; ++j) {
            computeColumn(j, a, b, c);
        }
    }

    /** Computes C = A*B with NTHREAD threads, each given a contiguous, equal-sized column range. */
    private static void mul2(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        // At least ceil(nj/NTHREAD), so NTHREAD chunks always cover all nj columns;
        // trailing chunks may be empty.
        int chunk = 1 + nj / NTHREAD;
        Thread[] threads = new Thread[NTHREAD];
        for (int ithread = 0; ithread < NTHREAD; ++ithread) {
            final int jbegin = ithread * chunk;
            final int jend = Math.min(jbegin + chunk, nj);
            threads[ithread] = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int j = jbegin; j < jend; ++j) {
                        computeColumn(j, a, b, c);
                    }
                }
            });
        }
        startAndJoin(threads);
    }

    /**
     * Computes C = A*B with NTHREAD threads. Each thread repeatedly claims the next uncomputed
     * column from a shared atomic counter.
     */
    private static void mul3(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        final int nj = c[0].length;
        final AtomicInteger nextColumn = new AtomicInteger();
        Thread[] threads = new Thread[NTHREAD];
        for (int ithread = 0; ithread < threads.length; ++ithread) {
            threads[ithread] = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int j = nextColumn.getAndIncrement();
                            j < nj;
                            j = nextColumn.getAndIncrement()) {
                        computeColumn(j, a, b, c);
                    }
                }
            });
        }
        startAndJoin(threads);
    }

    /**
     * Computes C = A*B by submitting one task per column to a fixed pool of NTHREAD threads and
     * waiting for every task to complete.
     *
     * @throws RuntimeException if a column task fails or the calling thread is interrupted.
     */
    private static void mul4(final float[][] a, final float[][] b, final float[][] c) {
        checkDimensions(a, b, c);
        int nj = c[0].length;
        ExecutorService pool = Executors.newFixedThreadPool(NTHREAD);
        try {
            ExecutorCompletionService<Void> ecs = new ExecutorCompletionService<Void>(pool);
            for (int j = 0; j < nj; ++j) {
                final int jcol = j;
                ecs.submit(new Runnable() {
                    @Override
                    public void run() {
                        computeColumn(jcol, a, b, c);
                    }
                }, null);
            }
            for (int j = 0; j < nj; ++j) {
                // get() rethrows any exception raised by the task instead of dropping it.
                ecs.take().get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException("column computation failed", e.getCause());
        } finally {
            // On success every task has finished, so this only discards queued work on failure.
            pool.shutdownNow();
        }
    }

    /**
     * Starts all threads, then waits for all of them to finish.
     *
     * @throws RuntimeException if the calling thread is interrupted while waiting; its
     *     interrupt status is restored first.
     */
    private static void startAndJoin(Thread[] threads) {
        for (Thread thread : threads) {
            thread.start();
        }
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Checks that two matrices have the same dimensions and exactly equal elements.
     * Exact float comparison is intended: all methods use the same arithmetic order.
     */
    private static void assertEquals(float[][] a, float[][] b) {
        Check.state(a.length == b.length, "same dimensions");
        Check.state(a[0].length == b[0].length, "same dimensions");
        int ni = a.length;
        int nj = a[0].length;
        for (int i = 0; i < ni; ++i) {
            for (int j = 0; j < nj; ++j) {
                Check.state(a[i][j] == b[i][j], "same elements");
            }
        }
    }

    /** Checks that A, B and C have compatible dimensions for C = A*B. */
    private static void checkDimensions(float[][] a, float[][] b, float[][] c) {
        int nrowA = a.length;
        int ncolA = a[0].length;
        int nrowB = b.length;
        int ncolB = b[0].length;
        int nrowC = c.length;
        int ncolC = c[0].length;
        Check.argument(ncolA == nrowB, "number of columns in A = number of rows in B");
        Check.argument(nrowA == nrowC, "number of rows in A = number of rows in C");
        Check.argument(ncolB == ncolC, "number of columns in B = number of columns in C");
    }
}
