TcpServerBuilder.java
package space.sunqian.common.net.tcp;
import space.sunqian.annotations.Nonnull;
import space.sunqian.annotations.Nullable;
import space.sunqian.common.Check;
import space.sunqian.common.Fs;
import space.sunqian.common.base.function.callable.VoidCallable;
import space.sunqian.common.collect.ListKit;
import space.sunqian.common.io.IOKit;
import space.sunqian.common.io.IOOperator;
import space.sunqian.common.io.communicate.AbstractChannelContext;
import space.sunqian.common.net.NetException;
import space.sunqian.common.net.NetServer;
import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadFactory;
/**
* Builder for building new instances of {@link TcpServer} by {@link ServerSocketChannel} and {@link SocketChannel}.
* <p>
* The server built by this builder requires a main thread and at least one worker thread. The main thread is
* responsible for accepting new clients, and the worker threads are responsible for handling connected clients. A
* client is always handled by the same worker thread, so there are no per-client thread-safety issues in the
* {@link TcpServerHandler}.
*
* @author sunqian
*/
public class TcpServerBuilder {
// Handler receiving the server events; defaults to TcpServerHandler.nullHandler().
private @Nonnull TcpServerHandler handler = TcpServerHandler.nullHandler();
// Number of worker threads; workerThreadNum(int) rejects values below 1.
private int workerThreadNum = 1;
// Factory for the main (accepting) thread; when null, new Thread(Runnable) is used.
private @Nullable ThreadFactory mainThreadFactory;
// Factory for the worker threads; when null, new Thread(Runnable) is used.
private @Nullable ThreadFactory workerThreadFactory;
// Socket options recorded by socketOption(...), applied to the server channel in insertion order when binding.
private final @Nonnull Map<SocketOption<?>, Object> socketOptions = new LinkedHashMap<>();
// Timeout in milliseconds passed to each worker's Selector.select(long).
private long selectTimeout = 0;
// Buffer size passed to IOOperator.get(int) for every client context.
private int bufSize = IOKit.bufferSize();
/**
* Sets the handler to handle server events. The default handler is {@link TcpServerHandler#nullHandler()}.
* <p>
* The handler is only stored here; it is passed to the server when one of the {@code bind} methods is invoked.
*
* @param handler the handler to handle server events
* @return this builder
*/
public @Nonnull TcpServerBuilder handler(@Nonnull TcpServerHandler handler) {
this.handler = handler;
return this;
}
/**
* Sets the factory used to create the main thread. The main thread is responsible for accepting new clients; each
* accepted client is then taken over by a worker thread.
* <p>
* If the factory is not configured, the server will use {@link Thread#Thread(Runnable)}.
*
* @param mainThreadFactory the main thread factory
* @return this builder
*/
public @Nonnull TcpServerBuilder mainThreadFactory(@Nonnull ThreadFactory mainThreadFactory) {
this.mainThreadFactory = mainThreadFactory;
return this;
}
/**
* Sets the factory used to create the worker threads. The worker threads handle the clients that the main thread
* has already accepted. The factory is invoked once for each worker when the server is built.
* <p>
* If the factory is not configured, the server will use {@link Thread#Thread(Runnable)}.
*
* @param workerThreadFactory the worker thread factory
* @return this builder
* @see #workerThreadNum(int)
*/
public @Nonnull TcpServerBuilder workerThreadFactory(@Nonnull ThreadFactory workerThreadFactory) {
this.workerThreadFactory = workerThreadFactory;
return this;
}
/**
* Sets the number of worker threads. The default is {@code 1}.
*
* @param workThreadNum the number of worker threads, must be {@code >= 1}
* @return this builder
* @throws IllegalArgumentException if the number is negative or {@code 0}
* @see #workerThreadFactory(ThreadFactory)
*/
public @Nonnull TcpServerBuilder workerThreadNum(int workThreadNum) throws IllegalArgumentException {
Check.checkArgument(workThreadNum >= 1, "workThreadNum must >= 1");
this.workerThreadNum = workThreadNum;
return this;
}
/**
* Sets the buffer size for advanced IO operations. Note that this is not the kernel network buffer size; it is the
* size handed to the {@link IOOperator} of each client context. The default is {@link IOKit#bufferSize()}.
*
* @param bufSize the buffer size for advanced IO operations, must be positive
* @return this builder
* @throws IllegalArgumentException if the buffer size is negative or {@code 0}
*/
public @Nonnull TcpServerBuilder bufferSize(int bufSize) throws IllegalArgumentException {
Check.checkArgument(bufSize > 0, "bufSize must be positive");
this.bufSize = bufSize;
return this;
}
/**
* Sets a socket option. This method can be invoked multiple times to set different socket options; setting the
* same option again replaces the previously recorded value.
* <p>
* The option is only recorded by this method. The recorded options are applied to the server's
* {@link ServerSocketChannel}, in the order they were first set, when one of the {@code bind} methods is invoked.
*
* @param <T>   the type of the socket option value
* @param name  the socket option
* @param value the value of the socket option, a value of {@code null} may be a valid value for some socket
*              options.
* @return this builder
* @throws NetException If an error occurs
* @see StandardSocketOptions
*/
public <T> @Nonnull TcpServerBuilder socketOption(@Nonnull SocketOption<T> name, T value) throws NetException {
// only stored here; applied to the server channel while building the server
socketOptions.put(name, value);
return this;
}
/**
* Sets the timeout, in milliseconds, for the {@link Selector#select(long)} calls of the worker threads. This
* timeout must be {@code >= 0} and affects the triggering interval of
* {@link TcpServerHandler#channelLoop(TcpContext)}. If it is {@code 0} (the default), a worker blocks until one of
* its channels is ready or it is woken up, so the interval may be large and the loop event may even never be
* triggered.
*
* @param selectTimeout the timeout for underlying {@link Selector#select(long)}, in milliseconds, must be
*                      {@code >= 0}
* @return this builder
* @throws IllegalArgumentException if the timeout is negative
*/
public @Nonnull TcpServerBuilder selectTimeout(long selectTimeout) throws IllegalArgumentException {
Check.checkArgument(selectTimeout >= 0, "selectTimeout must >= 0");
this.selectTimeout = selectTimeout;
return this;
}
/**
* Binds the server's socket to an automatically assigned address and configures the socket to listen for
* connections, then returns the new {@link TcpServer} instance. Equivalent to {@code bind(null)}.
*
* @return a new {@link TcpServer} instance
* @throws NetException If an error occurs
*/
public @Nonnull TcpServer bind() throws NetException {
return bind(null);
}
/**
* Binds the server's socket to the specified local address and configures the socket to listen for connections,
* then returns the new {@link TcpServer} instance. Equivalent to {@code bind(localAddress, 0)}, i.e. the default
* backlog is used.
*
* @param localAddress the local address the server is bound to, may be {@code null} to bind to the automatically
*                     assigned address
* @return a new {@link TcpServer} instance
* @throws NetException If an error occurs
*/
public @Nonnull TcpServer bind(@Nullable InetSocketAddress localAddress) throws NetException {
return bind(localAddress, 0);
}
/**
* Binds the server's socket to the specified local address and configures the socket to listen for connections,
* then returns the new {@link TcpServer} instance. The server's main thread has been started when this method
* returns.
* <p>
* The {@code backlog} is the maximum number of pending connections on the socket. If the {@code backlog} parameter
* has the value 0, or a negative value, then a default value is used.
* <p>
* The current configuration of this builder is passed to the new server; the builder can be reused afterwards.
*
* @param localAddress the local address the server is bound to, may be {@code null} to bind to the automatically
*                     assigned address
* @param backlog      the maximum number of pending connections
* @return a new {@link TcpServer} instance
* @throws NetException If an error occurs
*/
public @Nonnull TcpServer bind(@Nullable InetSocketAddress localAddress, int backlog) throws NetException {
// NOTE(review): failures of the constructor are presumably rethrown as NetException by Fs.uncheck — confirm
return Fs.uncheck(() -> new TcpServerImpl(
localAddress,
handler,
mainThreadFactory,
workerThreadFactory,
workerThreadNum,
socketOptions,
selectTimeout,
backlog,
bufSize
),
NetException::new
);
}
private static final class TcpServerImpl implements TcpServer, Runnable {
// the listening channel, registered with mainSelector for OP_ACCEPT
private final @Nonnull ServerSocketChannel server;
// selector used by the main thread to wait for incoming connections
private final @Nonnull Selector mainSelector;
// timeout in milliseconds for the workers' Selector.select(long)
private final long selectTimeout;
// the thread that runs this server's run() method and accepts clients
private final @Nonnull Thread mainThread;
// the workers that take over accepted clients; the array is filled in the constructor
private final @Nonnull WorkerImpl @Nonnull [] workers;
// handler receiving the server and channel events
private final @Nonnull TcpServerHandler handler;
// the address read from the server channel right after binding
private final @Nonnull InetSocketAddress localAddress;
// buffer size passed to every client context
private final int bufSize;
// set by close(); also makes doWork skip its work once it is true
private volatile boolean closed = false;
@SuppressWarnings("resource")
private TcpServerImpl(
@Nullable InetSocketAddress localAddress,
@Nonnull TcpServerHandler handler,
@Nullable ThreadFactory mainthreadFactory,
@Nullable ThreadFactory workerthreadFactory,
int workThreadNum,
Map<SocketOption<?>, Object> socketOptions,
long selectTimeout,
int backlog,
int bufSize
) throws Exception {
this.server = ServerSocketChannel.open();
this.mainSelector = Selector.open();
this.handler = handler;
this.mainThread = newThread(mainthreadFactory, this);
this.workers = new WorkerImpl[workThreadNum];
this.selectTimeout = selectTimeout;
this.bufSize = bufSize;
server.configureBlocking(false);
socketOptions.forEach((name, value) ->
Fs.uncheck(() -> server.setOption(Fs.as(name), value), NetException::new));
server.register(mainSelector, SelectionKey.OP_ACCEPT);
for (int i = 0; i < workThreadNum; i++) {
WorkerImpl worker = new WorkerImpl();
workers[i] = worker;
worker.thread = newThread(workerthreadFactory, worker);
}
server.bind(localAddress, backlog);
this.localAddress = (InetSocketAddress) server.getLocalAddress();
mainThread.start();
}
// Creates a thread for the runnable: through the factory when one is given, otherwise with new Thread(Runnable).
private @Nonnull Thread newThread(@Nullable ThreadFactory factory, @Nonnull Runnable runnable) {
    if (factory != null) {
        return factory.newThread(runnable);
    }
    return new Thread(runnable);
}
/**
* Waits for this server to terminate by joining the main thread first and then each worker thread in turn.
* NOTE(review): how an interruption of the waiting thread is reported depends on Fs.uncheck — confirm.
*/
@Override
public void await() throws NetException {
Fs.uncheck(mainThread::join);
for (WorkerImpl worker : workers) {
Fs.uncheck(worker.thread::join);
}
}
/**
 * Closes this server: closes the server channel and the main selector, interrupts the main thread, then interrupts
 * the worker threads and waits for them. Calling this method on an already closed server has no effect.
 * <p>
 * Every step is attempted even if an earlier one fails, so a failure to close the channel or the selector can
 * neither leave the threads running nor leave the server in a half-closed state that is retried on the next call.
 *
 * @throws NetException if closing the server channel or the main selector fails
 */
@Override
public synchronized void close() throws NetException {
    if (closed) {
        return;
    }
    try {
        Fs.uncheck(() -> {
                try {
                    server.close();
                } finally {
                    try {
                        mainSelector.close();
                    } finally {
                        // wakeup() is permitted on a closed selector
                        mainSelector.wakeup();
                        mainThread.interrupt();
                    }
                }
            },
            NetException::new
        );
    } finally {
        releaseWorkers();
        closed = true;
    }
}
// Returns the address that was read from the server channel right after it was bound.
@Override
public @Nonnull InetSocketAddress localAddress() throws NetException {
return localAddress;
}
// Returns the workers of this server as a list created by ListKit.list from the internal worker array.
@Override
public @Nonnull List<NetServer.@Nonnull Worker> workers() {
return ListKit.list(workers);
}
// Returns the closed flag, which is set by close().
@Override
public boolean isClosed() {
return closed;
}
/**
 * Body of the main thread: starts all worker threads, then keeps accepting clients until the main thread is
 * interrupted; afterwards releases the workers and closes the server channel and the main selector.
 */
@Override
public void run() {
    for (int i = 0; i < workers.length; i++) {
        workers[i].thread.start();
    }
    while (true) {
        if (mainThread.isInterrupted()) {
            break;
        }
        doWork(this::doMainWork, closed);
    }
    releaseWorkers();
    Fs.uncheck(() -> {
        server.close();
        mainSelector.close();
    }, NetException::new);
}
// One iteration of the main loop: blocks in select(), then handles and removes every selected key.
private void doMainWork() throws Exception {
    mainSelector.select();
    for (Iterator<SelectionKey> it = mainSelector.selectedKeys().iterator(); it.hasNext(); ) {
        SelectionKey acceptKey = it.next();
        // the selected-key set is not cleared by the selector, so each handled key is removed here
        it.remove();
        handleAccept(acceptKey, workers);
    }
}
/**
 * Accepts the pending connection of the selected key and hands it to the worker that currently has the fewest
 * clients, waking that worker's selector so it picks the client up.
 *
 * @param key     the selected key of the server channel
 * @param workers the workers to choose from
 * @throws Exception if accepting the connection fails
 */
@SuppressWarnings("resource")
private void handleAccept(SelectionKey key, WorkerImpl[] workers) throws Exception {
    ServerSocketChannel server = (ServerSocketChannel) key.channel();
    SocketChannel client = server.accept();
    if (client == null) {
        // the channel is non-blocking: accept() returns null when no connection is pending any more
        return;
    }
    WorkerImpl worker = workers[findWorker(workers)];
    worker.registerClient(client);
    worker.selector.wakeup();
}
// Returns the index of the worker with the smallest client set; the first one wins on ties.
private int findWorker(WorkerImpl[] workers) {
    int best = 0;
    int fewest = Integer.MAX_VALUE;
    int position = 0;
    for (WorkerImpl candidate : workers) {
        int count = candidate.clientSet.size();
        if (count < fewest) {
            fewest = count;
            best = position;
        }
        position++;
    }
    return best;
}
/**
 * Interrupts every worker thread and then waits for each of them to terminate.
 * <p>
 * The calling thread is never joined: when this method runs on a worker thread (for example because a handler
 * closes the server), joining itself could not succeed and would only clear the worker's own interrupt. If a wait
 * is interrupted, the remaining workers are still joined and the interrupt status of the calling thread is
 * restored before returning.
 */
private void releaseWorkers() {
    for (WorkerImpl worker : workers) {
        worker.thread.interrupt();
    }
    Thread current = Thread.currentThread();
    boolean interrupted = false;
    for (WorkerImpl worker : workers) {
        if (worker.thread == current) {
            continue;
        }
        try {
            worker.thread.join();
        } catch (InterruptedException e) {
            // join() cleared the flag; remember it so it can be restored below
            interrupted = true;
        }
    }
    if (interrupted) {
        current.interrupt();
    }
}
// Runs the callable unless the given closed flag is set; any exception is passed to the handler with no context.
private void doWork(VoidCallable callable, boolean closed) {
    if (!closed) {
        try {
            callable.call();
        } catch (Exception failure) {
            handler.exceptionCaught(null, failure);
        }
    }
}
private final class WorkerImpl implements Worker, Runnable {
// selector with which this worker's client channels are registered for reading
private final @Nonnull Selector selector;
// contexts of the clients handled by this worker; its size is also read by findWorker to pick a worker
private final @Nonnull Set<ContextImpl> clientSet = new HashSet<>();
// the thread this worker starts on; assigned by the server's constructor right after the worker is created
private Thread thread;
// cursor into the linked list of accepted channels: registerClient appends behind it, handleOpen advances it
private volatile @Nonnull AcceptedEvent acceptedEvent = new AcceptedEvent();
// Opens this worker's selector; a failure is handed to Fs.uncheck together with NetException::new.
private WorkerImpl() {
this.selector = Fs.uncheck(Selector::open, NetException::new);
}
/**
 * Queues an accepted client for this worker by appending a new event to the end of the event list. The worker
 * picks the client up in {@code handleOpen}.
 *
 * @param client the accepted client channel
 */
public void registerClient(SocketChannel client) {
    AcceptedEvent appended = new AcceptedEvent(client);
    AcceptedEvent tail = this.acceptedEvent;
    // walk to the last node of the list
    for (AcceptedEvent following = tail.next; following != null; following = tail.next) {
        tail = following;
    }
    tail.next = appended;
}
/**
 * Body of the worker thread: repeats the worker iteration until the thread is interrupted, then closes all of this
 * worker's clients and its selector.
 */
@Override
public void run() {
    for (Thread self = Thread.currentThread(); !self.isInterrupted(); ) {
        doWork(this::doWorkerWork, closed);
    }
    releaseClients();
    Fs.uncheck(selector::close, NetException::new);
}
// One iteration of the worker loop. The four steps run in this fixed order; an exception in one step skips the
// remaining steps of this iteration.
private void doWorkerWork() throws Exception {
// pick up newly accepted clients and register them for read events
handleOpen();
// wait for readable clients (up to selectTimeout) and dispatch read events
handleRead();
// trigger the loop event for every client of this worker
handleLoop();
// remove clients whose channel is no longer open
handleClose();
}
/**
 * Picks up the clients queued by {@code registerClient}: for each pending event a context is created, its channel
 * is registered for reading, the context is added to the client set and the open event is fired. Finally the
 * cursor is moved to the last event of the list.
 * <p>
 * An event is marked as consumed before its channel is set up. If the setup fails, the channel is closed and the
 * failure is propagated; the next invocation then continues with the events behind it instead of failing on the
 * same channel again.
 *
 * @throws Exception if a client cannot be set up
 */
private void handleOpen() throws Exception {
    AcceptedEvent event = this.acceptedEvent;
    while (true) {
        SocketChannel channel = event.channel;
        if (channel != null) {
            // consume the event first, so a channel that cannot be set up is never retried
            event.channel = null;
            ContextImpl context;
            try {
                context = new ContextImpl(channel, bufSize);
                registerRead(context);
            } catch (Exception e) {
                try {
                    channel.close();
                } catch (Exception closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
            // only fully registered clients are tracked
            clientSet.add(context);
            TcpKit.channelOpen(handler, context);
        }
        AcceptedEvent next = event.next;
        if (next == null) {
            this.acceptedEvent = event;
            break;
        }
        event = next;
    }
}
// Switches the context's channel to non-blocking mode and registers it with this worker's selector for OP_READ,
// attaching the context to the key so that handleRead can retrieve it.
@SuppressWarnings("resource")
private void registerRead(ContextImpl context) throws Exception {
SocketChannel channel = context.channel();
channel.configureBlocking(false);
channel.register(selector, SelectionKey.OP_READ, context);
}
// Waits up to selectTimeout for ready channels and fires the read event for the context attached to each
// selected key; returns without doing anything when select reports no ready keys.
private void handleRead() throws Exception {
    if (selector.select(selectTimeout) == 0) {
        return;
    }
    for (Iterator<SelectionKey> it = selector.selectedKeys().iterator(); it.hasNext(); ) {
        SelectionKey readyKey = it.next();
        it.remove();
        ContextImpl context = (ContextImpl) readyKey.attachment();
        TcpKit.channelRead(handler, context);
    }
}
// Fires the loop event for every client currently held by this worker.
private void handleLoop() {
    Iterator<ContextImpl> contexts = clientSet.iterator();
    while (contexts.hasNext()) {
        TcpKit.channelLoop(handler, contexts.next());
    }
}
// Closes and drops every context whose channel is no longer open.
@SuppressWarnings("resource")
private void handleClose() {
    clientSet.removeIf(context -> {
        if (context.channel().isOpen()) {
            return false;
        }
        context.close();
        return true;
    });
}
// Returns the current size of this worker's client set.
@Override
public int connectionNumber() {
return clientSet.size();
}
// Returns the thread that was assigned to this worker when the server was built.
@Override
public @Nonnull Thread thread() {
return thread;
}
// Closes the context of every client held by this worker; the contexts stay in the client set.
private void releaseClients() {
    clientSet.forEach(context -> context.close());
}
/**
 * Context of one connected client, bound to the worker that handles it. The client and server addresses are read
 * from the channel once, when the context is created.
 */
private final class ContextImpl extends AbstractChannelContext<SocketChannel> implements TcpContext {

    private final @Nonnull InetSocketAddress clientAddress;
    private final @Nonnull InetSocketAddress serverAddress;
    private final @Nonnull IOOperator ioOperator;
    // set by close() so that the close logic and the close event run at most once
    private volatile boolean closed = false;

    /**
     * Creates the context for an accepted channel.
     *
     * @param channel the accepted client channel
     * @param bufSize the buffer size passed to {@link IOOperator#get(int)}
     */
    private ContextImpl(@Nonnull SocketChannel channel, int bufSize) throws IllegalArgumentException {
        super(channel);
        this.clientAddress = (InetSocketAddress) Fs.uncheck(channel::getRemoteAddress, NetException::new);
        this.serverAddress = (InetSocketAddress) Fs.uncheck(channel::getLocalAddress, NetException::new);
        this.ioOperator = IOOperator.get(bufSize);
    }

    @Override
    public @Nonnull InetSocketAddress clientAddress() {
        return clientAddress;
    }

    @Override
    public @Nonnull InetSocketAddress serverAddress() {
        return serverAddress;
    }

    /**
     * Closes the client channel, cancels its registration with the worker's selector (if there still is one) and
     * fires the close event. Only the first invocation has an effect.
     * <p>
     * The context is marked as closed before the work is done: if closing or the close event fails, the failure
     * is reported once and later invocations return immediately instead of failing again.
     *
     * @throws NetException if closing the channel or handling the close event fails
     */
    @Override
    public synchronized void close() throws NetException {
        if (closed) {
            return;
        }
        closed = true;
        Fs.uncheck(() -> {
            SocketChannel channel = channel();
            channel.close();
            // keyFor returns null when the channel is not (or no longer) registered with the selector
            SelectionKey key = channel.keyFor(selector);
            if (key != null) {
                key.cancel();
            }
            TcpKit.channelClose(handler, this);
        }, NetException::new);
    }

    @Override
    protected @Nonnull IOOperator ioOperator() {
        return ioOperator;
    }
}
}
// Node of the singly linked list through which the main thread hands accepted channels over to a worker:
// registerClient appends nodes, handleOpen consumes them.
private static final class AcceptedEvent {
// the accepted channel; null for the initial placeholder node and after the worker has consumed the event
private volatile @Nullable SocketChannel channel;
// the next event of the list, or null if this is currently the last one
private volatile @Nullable AcceptedEvent next;
// Creates an event that carries a newly accepted channel.
private AcceptedEvent(@Nonnull SocketChannel channel) {
this.channel = channel;
}
// Creates the empty placeholder node every worker starts with.
private AcceptedEvent() {
this.channel = null;
}
}
}
}