package simdbitset;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Concurrency stress tests for {@link TreiberStack}.
 *
 * <p>Workers run through {@link #runConcurrently}. It releases all threads through a common start
 * gate, so they actually overlap on the stack. It also re-throws any failure raised inside a worker
 * on the test thread, where JUnit can see it.
 */
public final class TreiberStackTest {
  /** Concurrent pushes of disjoint value ranges must all end up on the stack exactly once. */
  @Test
  public void concurrentPush() throws InterruptedException {
    var threadCount = 4;
    var pushesPerThread = 100_000;
    var totalElements = threadCount * pushesPerThread;

    var stack = new TreiberStack<Integer>();
    var producers = new ArrayList<Runnable>(threadCount);
    for (var t = 0; t < threadCount; t++) {
      // Each producer pushes a disjoint range, so every value in [0, totalElements) is pushed once.
      var base = t * pushesPerThread;
      producers.add(() -> {
        for (var i = 0; i < pushesPerThread; i++) {
          stack.push(base + i);
        }
      });
    }
    runConcurrently(producers);

    // A plain count would miss a lost push that happens to be offset by a duplicated element.
    assertEachValueExactlyOnce(drain(stack), totalElements);
  }

  /** Concurrent pops must hand out every pushed element exactly once and leave the stack empty. */
  @Test
  public void concurrentPop() throws InterruptedException {
    var threadCount = 4;
    var pushesPerThread = 100_000;
    var totalElements = threadCount * pushesPerThread;

    var stack = new TreiberStack<Integer>();
    for (var i = 0; i < totalElements; i++) {
      stack.push(i);
    }

    // Each consumer records what it popped in its own list, so the bookkeeping adds no
    // synchronization between consumers. Thread.start/join make the lists safely visible here.
    var poppedPerThread = new ArrayList<List<Integer>>(threadCount);
    var consumers = new ArrayList<Runnable>(threadCount);
    for (var t = 0; t < threadCount; t++) {
      var popped = new ArrayList<Integer>();
      poppedPerThread.add(popped);
      consumers.add(() -> {
        Integer value;
        while ((value = stack.pop()) != null) {
          popped.add(value);
        }
      });
    }
    runConcurrently(consumers);

    // Every element should have been popped exactly once, by exactly one consumer.
    var allPopped = new ArrayList<Integer>(totalElements);
    for (var popped : poppedPerThread) {
      allPopped.addAll(popped);
    }
    assertEachValueExactlyOnce(allPopped, totalElements);
    assertNull(stack.pop(), "stack should be empty once every consumer has stopped");
  }

  /** Batches pushed with {@code pushAll} must never interleave with other concurrent batches. */
  @Test
  public void pushAllIsAtomic() throws InterruptedException {
    record Item(long batchId, int index) {}

    var iterationsPerThread = 100_000;
    var producerCount = 4;
    var batchCount = producerCount * iterationsPerThread;

    var stack = new TreiberStack<Item>();
    var producers = new ArrayList<Runnable>(producerCount);
    for (var t = 0; t < producerCount; t++) {
      var startId = ((long) t) * iterationsPerThread;
      producers.add(() -> {
        for (var i = 0; i < iterationsPerThread; i++) {
          var id = startId + i;
          stack.pushAll(List.of(
              new Item(id, 0),
              new Item(id, 1),
              new Item(id, 2)
          ));
        }
      });
    }
    runConcurrently(producers);

    var items = drain(stack);
    // Checking the size first catches lost items. It also guarantees that the triple-wise walk
    // below never runs past the end of the list.
    assertEquals(3 * batchCount, items.size(), "total number of items");

    var batchIds = new HashSet<Long>();
    for (var i = 0; i < items.size(); i += 3) {
      var batchId = items.get(i).batchId();
      // The test expects pushAll to leave the last list element on top. An atomic batch therefore
      // pops back contiguously and in reverse order.
      var expectedBatch =
          List.of(new Item(batchId, 2), new Item(batchId, 1), new Item(batchId, 0));
      assertEquals(expectedBatch, items.subList(i, i + 3),
          () -> "Batch " + batchId + " was not pushed atomically");
      assertTrue(batchIds.add(batchId),
          () -> "Batch " + batchId + " was popped more than once");
    }
  }

  /**
   * Runs every task on its own thread and waits for all of them to finish.
   *
   * <p>The threads block on a shared start gate until all of them have been started. This makes
   * the tasks overlap, rather than early tasks finishing before later threads even exist.
   *
   * <p>A failure thrown inside any task is captured and re-thrown on the calling thread. Otherwise
   * it would only be printed by the default uncaught-exception handler and would surface, at best,
   * as a confusing count mismatch.
   *
   * @param tasks the work to run, one thread per task
   * @throws InterruptedException if the calling thread is interrupted while joining
   * @throws AssertionError wrapping the first failure raised by any task
   */
  private static void runConcurrently(List<Runnable> tasks) throws InterruptedException {
    var startGate = new CountDownLatch(1);
    var firstFailure = new AtomicReference<Throwable>();
    var threads = new ArrayList<Thread>(tasks.size());
    for (var task : tasks) {
      var thread = new Thread(() -> {
        try {
          startGate.await();
          task.run();
        } catch (Throwable e) {
          // Test-harness boundary: record the failure so the test thread can report it.
          firstFailure.compareAndSet(null, e);
        }
      });
      thread.start();
      threads.add(thread);
    }
    startGate.countDown();
    for (var thread : threads) {
      thread.join();
    }
    var failure = firstFailure.get();
    if (failure != null) {
      throw new AssertionError("a worker thread failed", failure);
    }
  }

  /** Pops {@code stack} until {@code pop()} returns {@code null}; returns the items in pop order. */
  private static <T> List<T> drain(TreiberStack<T> stack) {
    var items = new ArrayList<T>();
    T item;
    while ((item = stack.pop()) != null) {
      items.add(item);
    }
    return items;
  }

  /**
   * Asserts that {@code values} holds every integer in {@code [0, total)} exactly once. That means
   * nothing lost, nothing duplicated, and nothing outside the pushed range.
   */
  private static void assertEachValueExactlyOnce(List<Integer> values, int total) {
    assertEquals(total, values.size(), "number of elements");
    var seen = new boolean[total];
    for (var value : values) {
      assertTrue(value >= 0 && value < total, () -> "unexpected element " + value);
      assertFalse(seen[value], () -> "element " + value + " appeared more than once");
      seen[value] = true;
    }
  }
}