Rog.java
package space.sunqian.common.base.random;
import space.sunqian.annotations.Nonnull;
import space.sunqian.annotations.Nullable;
import space.sunqian.common.Check;
import space.sunqian.common.Fs;
import space.sunqian.common.base.exception.UnreachablePointException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
/**
* The Random Object Generator, to produce random objects. The usage is simple:
* <pre>{@code
* Rog<String> rog = Rog.newBuilder()
* .weight(10, "a")
* .weight(10, "b")
* .build();
* String randomString = rog.next();
* }</pre>
*
* @author sunqian
*/
public interface Rog<T> {
/**
 * Creates a fresh, empty {@link Builder} that can be configured with weights and an optional random number
 * generator in order to build a {@link Rog} instance.
 *
 * @param <T> the type of the generated object
 * @return a new, empty builder
 */
static <T> @Nonnull Builder<T> newBuilder() {
    return new Builder<T>();
}
/**
 * Returns the next random object.
 * <p>
 * For instances created by {@link Builder}, each call draws one value from the configured random number
 * generator and returns the result of the supplier whose weight range contains that value.
 *
 * @return the next random object
 */
T next();
/**
* Builder for generating a new {@link Rog} instance.
* <p>
* This Builder generates {@link Rog} by setting {@code rng} (random number generator) and weights. For example:
* <pre>{@code
* Rog<String> rog = Rog.newBuilder()
* .weight(20, "a") // 20% to generate "a"
* .weight(80, () -> "b") // 80% to generate "b"
* .rng(Rng.newRng())
* .build();
* }</pre>
* <p>
* The probability of hitting the object or supplier associated with each weight is given by:
* {@code weight / sum(weights)}. The {@code rng} provides the random long value, which is used to calculate the
* hitting probability. If the {@code rng} is not set, {@link Rng#newRng()} is used.
* <p>
* Note that the sum of all weights must not exceed {@link Long#MAX_VALUE}.
*
* @param <T> the type of the random objects
* @author sunqian
*/
class Builder<T> {
// Weighted entries in the order they were added; passed to the Rog implementation by build().
private final @Nonnull List<@Nonnull Weight<T>> weights = new ArrayList<>();
// Random number source set through rng(...); stays null until then, in which case build() uses Rng.newRng().
private @Nullable LongSupplier rng;
/**
 * Registers a fixed object under the given weight. The object is wrapped in a supplier that always returns it,
 * and the registration itself is delegated to {@link #weight(long, Supplier)}.
 *
 * @param weight the weight of the object, must not be negative
 * @param obj    the object produced when this weight is hit
 * @param <T1>   the type of the object
 * @return this builder
 * @throws IllegalArgumentException if the weight check fails (the weight is negative)
 */
public <T1> @Nonnull Builder<T1> weight(long weight, T obj) throws IllegalArgumentException {
    final Supplier<T> constant = () -> obj;
    return weight(weight, constant);
}
/**
 * Registers a supplier under the given weight. The supplier is invoked each time this weight is hit by the
 * built {@link Rog}.
 *
 * @param weight   the weight of the supplier, must not be negative
 * @param supplier the supplier producing the object when this weight is hit
 * @param <T1>     the type of the generated object
 * @return this builder
 * @throws IllegalArgumentException if the weight check fails (the weight is negative)
 */
public <T1> @Nonnull Builder<T1> weight(
    long weight, @Nonnull Supplier<T> supplier
) throws IllegalArgumentException {
    // Validate before touching the list so a rejected weight leaves the builder unchanged.
    Check.checkArgument(weight >= 0, "weight must be non-negative");
    final Weight<T> entry = new Weight<>(weight, supplier);
    weights.add(entry);
    final Builder<T1> self = Fs.as(this);
    return self;
}
/**
 * Configures the source of random long values used by the built {@link Rog}. When no source is configured,
 * {@link Rng#newRng()} is used at build time.
 *
 * @param rng  the random number generator
 * @param <T1> the type of the generated object
 * @return this builder
 */
public <T1> @Nonnull Builder<T1> rng(@Nonnull LongSupplier rng) {
    this.rng = rng;
    final Builder<T1> self = Fs.as(this);
    return self;
}
/**
 * Creates a new {@link Rog} from the weights, objects and suppliers registered so far. The configured random
 * number generator is used if present; otherwise one is obtained from {@link Rng#newRng()}.
 *
 * @param <T1> the type of the generated object
 * @return a new {@link Rog} instance backed by the registered weights
 */
public <T1> @Nonnull Rog<T1> build() {
    LongSupplier source = this.rng;
    if (source == null) {
        // No generator configured: fall back to the default one.
        source = Rng.newRng();
    }
    final Rog<T> rog = new RogImpl<>(source, weights);
    return Fs.as(rog);
}
/**
 * {@link Rog} implementation backed by an array of consecutive weight ranges.
 * <p>
 * Each registered weight owns the half-open range {@code [from, to)} on the axis {@code [0, totalWeight)}, in
 * registration order. {@link #next()} maps one random long onto that axis and binary-searches the range that
 * contains it. Entries with a weight of zero own an empty range and are therefore never selected.
 *
 * @param <T> the type of the generated objects
 */
private static final class RogImpl<T> implements Rog<T> {

    private final @Nonnull LongSupplier rng;
    private final @Nonnull WeightNode<T>[] nodes;
    private final long totalWeight;

    /**
     * Lays the given weights out as consecutive ranges starting at zero.
     *
     * @param rng     the source of random long values
     * @param weights the weighted suppliers, in registration order
     * @throws ArithmeticException if the sum of the weights overflows a {@code long}
     */
    RogImpl(
        @Nonnull LongSupplier rng,
        @Nonnull List<@Nonnull Weight<T>> weights
    ) {
        this.rng = rng;
        // Generic arrays cannot be created directly; the raw array is only ever filled with WeightNode<T>.
        @SuppressWarnings("unchecked")
        WeightNode<T>[] ranges = new WeightNode[weights.size()];
        long total = 0;
        int index = 0;
        for (Weight<T> weight : weights) {
            long from = total;
            // Fail fast on overflow: a silently wrapped (negative) total would corrupt every following range.
            total = Math.addExact(total, weight.weight);
            ranges[index++] = new WeightNode<>(weight.supplier, from, total);
        }
        this.nodes = ranges;
        this.totalWeight = total;
    }

    /**
     * Draws one value from the random number generator and returns the object supplied by the entry whose
     * weight range contains it.
     *
     * @return the next random object
     * @throws ArithmeticException if the total weight is zero (no weights, or only zero weights)
     */
    @Override
    public T next() {
        // Reduce first, then take the absolute value: the remainder's magnitude is always below totalWeight,
        // so Math.abs cannot overflow. Math.abs on the raw value would stay negative for Long.MIN_VALUE and
        // produce a score outside every range.
        long score = Math.abs(rng.getAsLong() % totalWeight);
        WeightNode<T> node = getWeight(score);
        return node.supplier.get();
    }

    /**
     * Returns the node whose range contains the given score.
     *
     * @param score a value in {@code [0, totalWeight)}
     * @return the node owning the score
     * @throws UnreachablePointException if no range contains the score
     */
    private @Nonnull WeightNode<T> getWeight(long score) {
        int index = binarySearch(score);
        if (index < 0) {
            throw new UnreachablePointException("Weight not found by score: " + score + ".");
        }
        return nodes[index];
    }

    /**
     * Binary-searches the ranges for the one containing the given score.
     *
     * @param score the score to look up
     * @return the index of the owning node, or {@code -1} if no range contains the score
     */
    private int binarySearch(long score) {
        int left = 0;
        int right = nodes.length - 1;
        while (left <= right) {
            // Unsigned shift keeps the midpoint correct even if left + right overflows int.
            int mid = (left + right) >>> 1;
            int order = compare(score, nodes[mid]);
            if (order == 0) {
                return mid;
            }
            if (order > 0) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        return -1;
    }

    /**
     * Positions a score relative to a node's half-open range {@code [from, to)}.
     *
     * @param next the score
     * @param node the node to compare against
     * @return {@code -1} if the score lies before the range, {@code 1} if it lies at or after the range's end,
     * {@code 0} if the range contains it
     */
    private int compare(long next, WeightNode<T> node) {
        if (next < node.from) {
            return -1;
        }
        if (next >= node.to) {
            return 1;
        }
        return 0;
    }
}
/**
 * Immutable pair of a weight and the supplier registered under it.
 *
 * @param <T> the type of the supplied objects
 */
private static final class Weight<T> {

    /** The weight as passed to the builder. */
    private final long weight;

    /** The supplier invoked when this weight is hit. */
    private final @Nonnull Supplier<T> supplier;

    private Weight(long weightValue, @Nonnull Supplier<T> valueSupplier) {
        this.weight = weightValue;
        this.supplier = valueSupplier;
    }
}
/**
 * Immutable range entry: a supplier together with the bounds {@code from} and {@code to} of the score range it
 * is selected for. The lookup treats the range as half-open, i.e. {@code from <= score < to}.
 *
 * @param <T> the type of the supplied objects
 */
private static final class WeightNode<T> {

    /** The supplier invoked when a score falls into this range. */
    private final @Nonnull Supplier<T> supplier;

    /** Lower bound of the range (inclusive). */
    private final long from;

    /** Upper bound of the range (exclusive). */
    private final long to;

    private WeightNode(@Nonnull Supplier<T> valueSupplier, long lower, long upper) {
        this.supplier = valueSupplier;
        this.from = lower;
        this.to = upper;
    }
}
}
}