ServerSocketChannel实现多Selector高并发server

时间:2023-03-10 01:52:45
ServerSocketChannel实现多Selector高并发server

参考 HBase 的 RpcServer,编写了一个简洁版的多 Selector server。通过这个例子,可以对 NIO 的用法以及 Selector 如何选择事件有更深入的认识。

client端发送消息:格式为"内容长度(4字节) + 内容",200个线程同时发送

server端接收消息:解析内容长度和内容,返回2MB测试数据给客户端

Server端:一个accept selector,多个read selector,一个write selector

package com.ai.nio;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingQueue; /**
* Created by wangkai8 on 17/1/5.
*/
/**
 * A small multi-selector NIO server.
 *
 * <p>Threads and their roles:
 * <ul>
 *   <li>one {@link Listener} owning the accept selector,</li>
 *   <li>several {@link Reader}s, each owning a read selector,</li>
 *   <li>several {@link Handler}s that take parsed requests from a queue and build responses,</li>
 *   <li>one {@link Responder} owning the write selector; it finishes responses that a handler
 *       could not write in one go.</li>
 * </ul>
 *
 * <p>Wire format of a request: a 4-byte big-endian length followed by that many payload bytes.
 * The response built by {@link Handler#process(Call)} has the same framing.
 */
public class Server {

    public static final Log LOG = LogFactory.getLog(Server.class);

    /** Port used by the no-argument {@link #start()}. */
    private static final int DEFAULT_PORT = 10000;

    /** Largest slice handed to a single channel read/write call. */
    private static final int NIO_BUFFER_LIMIT = 64 * 1024;

    /**
     * Largest request body the server is willing to buffer. The length prefix comes straight
     * from the peer, so it must be bounded before it is used as an allocation size.
     */
    private static final int MAX_REQUEST_LENGTH = 64 * 1024 * 1024;

    /** Fully read requests waiting for a handler. */
    private BlockingQueue<Call> queue = new LinkedBlockingQueue<Call>();

    /** Calls whose response was only partially written and must be finished by the responder. */
    private Queue<Call> responseCalls = new ConcurrentLinkedQueue<Call>();

    volatile boolean running = true;

    private Responder responder = null;

    /** Number of handler threads started by {@link #startHandler()}. */
    private int handler = 10;

    /**
     * Accepts new connections and hands each one to a reader, chosen round-robin.
     */
    class Listener extends Thread {

        Selector selector;
        Reader[] readers;
        int robin;
        int readNum;

        /**
         * Binds the server socket and creates and starts the reader threads.
         *
         * @param port TCP port to listen on
         * @throws IOException if the channel or a selector cannot be opened, or the bind fails
         */
        Listener(int port) throws IOException {
            ServerSocketChannel serverChannel = ServerSocketChannel.open();
            serverChannel.configureBlocking(false);
            serverChannel.socket().bind(new InetSocketAddress(port), 150);
            selector = Selector.open();
            serverChannel.register(selector, SelectionKey.OP_ACCEPT);
            readNum = 10;
            readers = new Reader[readNum];
            for (int i = 0; i < readNum; i++) {
                readers[i] = new Reader(i);
                readers[i].start();
            }
        }

        @Override
        public void run() {
            while (running) {
                try {
                    selector.select();
                    Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        if (key.isValid() && key.isAcceptable()) {
                            doAccept(key);
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        /**
         * Accepts every pending connection on the key's channel and registers each with a reader.
         *
         * @param selectionKey key of the listening channel
         * @throws IOException if configuring or registering an accepted channel fails; the
         *                     affected channel is closed before the exception is rethrown
         */
        public void doAccept(SelectionKey selectionKey) throws IOException {
            ServerSocketChannel serverSocketChannel = (ServerSocketChannel) selectionKey.channel();
            SocketChannel socketChannel;
            while ((socketChannel = serverSocketChannel.accept()) != null) {
                try {
                    socketChannel.configureBlocking(false);
                    socketChannel.socket().setTcpNoDelay(true);
                    socketChannel.socket().setKeepAlive(true);
                } catch (IOException e) {
                    socketChannel.close();
                    throw e;
                }
                Reader reader = getReader();
                try {
                    // startAdd() wakes the reader out of select() and parks it until finishAdd().
                    reader.startAdd();
                    SelectionKey readKey = reader.registerChannel(socketChannel);
                    readKey.attach(new Connection(socketChannel));
                } catch (IOException e) {
                    // Registration failed: nobody else references the channel, so close it here.
                    socketChannel.close();
                    throw e;
                } finally {
                    reader.finishAdd();
                }
            }
        }

        /**
         * Returns the next reader in round-robin order. Only called from the listener thread.
         */
        public Reader getReader() {
            if (robin == Integer.MAX_VALUE) {
                robin = 0;
            }
            return readers[(robin++) % readNum];
        }
    }

    /**
     * Owns one read selector and reads requests from the connections registered with it.
     */
    class Reader extends Thread {

        Selector readSelector;

        /** Set by the listener thread while it registers a channel; read by this thread. */
        volatile boolean adding;

        Reader(int i) throws IOException {
            setName("Reader-" + i);
            this.readSelector = Selector.open();
            LOG.info("Starting Reader-" + i + "...");
        }

        @Override
        public void run() {
            while (running) {
                try {
                    readSelector.select();
                    // Stay out of select() while the listener is registering a channel.
                    // The flag is tested while holding the monitor so that the notify() in
                    // finishAdd() cannot slip in between the test and the wait().
                    synchronized (this) {
                        while (adding) {
                            this.wait(1000);
                        }
                    }
                    Iterator<SelectionKey> it = readSelector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        try {
                            if (key.isValid() && key.isReadable()) {
                                doRead(key);
                            }
                        } catch (CancelledKeyException e) {
                            // A handler thread may close the connection between the two checks.
                            if (LOG.isTraceEnabled()) {
                                LOG.trace("key cancelled while reading", e);
                            }
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                } catch (InterruptedException e) {
                    LOG.error("", e);
                }
            }
        }

        /**
         * Reads from the connection attached to the key and closes it on end-of-stream or error.
         *
         * @param selectionKey a readable key; keys without an attachment are ignored
         */
        public void doRead(SelectionKey selectionKey) {
            Connection c = (Connection) selectionKey.attachment();
            if (c == null) {
                return;
            }
            int n;
            try {
                n = c.readAndProcess();
            } catch (IOException e) {
                LOG.error("", e);
                n = -1;
            } catch (Exception e) {
                LOG.error("", e);
                n = -1;
            }
            if (n == -1) {
                c.close();
            }
        }

        /**
         * Registers the channel with this reader's selector for read events.
         *
         * @return the new selection key
         * @throws IOException if the channel is closed
         */
        public SelectionKey registerChannel(SocketChannel channel) throws IOException {
            return channel.register(readSelector, SelectionKey.OP_READ);
        }

        /** Announces a pending registration and kicks the reader out of {@code select()}. */
        public void startAdd() {
            adding = true;
            readSelector.wakeup();
        }

        /** Ends the registration window opened by {@link #startAdd()} and releases the reader. */
        public synchronized void finishAdd() {
            adding = false;
            this.notify();
        }
    }

    /**
     * Per-client read state: a 4-byte length buffer plus a body buffer sized from it.
     */
    class Connection {

        private SocketChannel channel;
        private ByteBuffer dataBufferLength;
        private ByteBuffer dataBuffer;
        /** True once the length prefix of the current request has been read completely. */
        private boolean skipHeader;

        public Connection(SocketChannel channel) {
            this.channel = channel;
            this.dataBufferLength = ByteBuffer.allocate(4);
        }

        /**
         * Reads as much of the current request as is available and, once the request is
         * complete, queues it for the handlers.
         *
         * @return the result of the last channel read: bytes read, 0 if nothing was available,
         *         or -1 at end-of-stream
         * @throws IOException on a read failure or when the declared length is negative or
         *                     larger than {@code MAX_REQUEST_LENGTH}
         */
        public int readAndProcess() throws IOException {
            int count;
            if (!skipHeader) {
                count = channelRead(channel, dataBufferLength);
                if (count < 0 || dataBufferLength.remaining() > 0) {
                    return count;
                }
            }
            skipHeader = true;

            if (dataBuffer == null) {
                dataBufferLength.flip();
                int dataLength = dataBufferLength.getInt();
                if (dataLength < 0 || dataLength > MAX_REQUEST_LENGTH) {
                    throw new IOException("invalid request length: " + dataLength);
                }
                dataBuffer = ByteBuffer.allocate(dataLength);
            }

            count = channelRead(channel, dataBuffer);
            if (count >= 0 && dataBuffer.remaining() == 0) {
                process();
            }
            return count;
        }

        /**
         * Wraps the completed body in a {@link Call}, queues it, and resets the read state.
         *
         * <p>The reset matters: without it the next read event would read into the already
         * full body buffer, get 0 back, and queue the same request again. End-of-stream would
         * then never surface as -1.
         */
        public void process() {
            dataBuffer.flip();
            // The buffer was allocated with exactly the body length, so its backing array is
            // the payload. Ownership moves to the Call because the buffer is dropped below.
            byte[] data = dataBuffer.array();
            resetForNextRequest();
            Call call = new Call(this, data, responder);
            try {
                queue.put(call);
            } catch (InterruptedException e) {
                LOG.error("", e);
            }
        }

        /** Prepares the connection to read the length prefix of the next request. */
        private void resetForNextRequest() {
            dataBuffer = null;
            dataBufferLength.clear();
            skipHeader = false;
        }

        /** Closes the channel; a failure to close is ignored. */
        public void close() {
            if (channel != null) {
                try {
                    channel.close();
                } catch (IOException ignored) {
                    // best-effort close, nothing useful can be done here
                }
            }
        }
    }

    /**
     * Owns the write selector and finishes responses that could not be written in one call.
     *
     * <p>NOTE: a key carries a single {@link Call} as attachment, so only one unfinished
     * response per connection is tracked at a time.
     */
    class Responder extends Thread {

        Selector writeSelector;

        public Responder() throws IOException {
            writeSelector = Selector.open();
        }

        @Override
        public void run() {
            while (running) {
                try {
                    registWriters();
                    int n = writeSelector.select(1000);
                    if (n == 0) {
                        continue;
                    }
                    Iterator<SelectionKey> it = writeSelector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        // Failures are handled per key so one broken connection does not make
                        // the loop skip the remaining ready keys.
                        try {
                            if (key.isValid() && key.isWritable()) {
                                doAsyncWrite(key);
                            }
                        } catch (CancelledKeyException e) {
                            if (LOG.isTraceEnabled()) {
                                LOG.trace("the client went away", e);
                            }
                        } catch (IOException e) {
                            LOG.error("", e);
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        /**
         * Drains the pending-response queue and arms each call's channel for write events.
         */
        public void registWriters() throws IOException {
            Call call;
            while ((call = responseCalls.poll()) != null) {
                SelectionKey key = call.conn.channel.keyFor(writeSelector);
                try {
                    if (key == null) {
                        try {
                            call.conn.channel.register(writeSelector, SelectionKey.OP_WRITE, call);
                        } catch (ClosedChannelException e) {
                            //the client went away
                            if (LOG.isTraceEnabled())
                                LOG.trace("the client went away", e);
                        }
                    } else {
                        // The key still carries the previous call; swap in the new one before
                        // re-arming, otherwise the already-drained old response is "written".
                        key.attach(call);
                        key.interestOps(SelectionKey.OP_WRITE);
                    }
                } catch (CancelledKeyException e) {
                    if (LOG.isTraceEnabled())
                        LOG.trace("the client went away", e);
                }
            }
        }

        /**
         * Queues the call for asynchronous completion and wakes the responder thread.
         */
        public void registerForWrite(Call call) throws IOException {
            responseCalls.add(call);
            writeSelector.wakeup();
        }

        /**
         * Continues writing the response attached to a writable key. Write interest is
         * dropped once the response is complete. If the write fails the connection is closed
         * (by {@link #processResponse(Call)}) and the exception is rethrown.
         */
        private void doAsyncWrite(SelectionKey key) throws IOException {
            Call call = (Call) key.attachment();
            if (call.conn.channel != key.channel()) {
                throw new IOException("bad channel");
            }
            if (processResponse(call)) {
                try {
                    key.interestOps(0);
                } catch (CancelledKeyException e) {
                    LOG.warn("Exception while changing ops : " + e);
                }
            }
        }

        /**
         * Tries to write the response directly from the calling thread and falls back to the
         * responder thread when the socket cannot take all of it.
         */
        private void doResponse(Call call) throws IOException {
            //if data not fully send, then register the channel for async writer
            if (!processResponse(call)) {
                registerForWrite(call);
            }
        }

        /**
         * Writes as much of the response as the channel accepts.
         *
         * @return {@code true} when the response has been written completely
         * @throws IOException if the write fails; the connection is closed first
         */
        private boolean processResponse(Call call) throws IOException {
            boolean error = true;
            try {
                int numBytes = channelWrite(call.conn.channel, call.response);
                if (numBytes < 0) {
                    throw new IOException("error socket write");
                }
                error = false;
            } finally {
                if (error) {
                    call.conn.close();
                }
            }
            if (!call.response.hasRemaining()) {
                call.done = true;
                return true;
            }
            return false;
        }
    }

    /**
     * Worker thread: takes requests off the queue and sends the test response.
     */
    class Handler extends Thread {

        public Handler(int i) {
            setName("handler-" + i);
            LOG.info("Starting Handler-" + i + "...");
        }

        @Override
        public void run() {
            while (running) {
                try {
                    Call call = queue.take();
                    process(call);
                } catch (InterruptedException e) {
                    LOG.error("", e);
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        /**
         * Logs the request and answers with a length-prefixed 2 MB block of zero bytes.
         */
        public void process(Call call) throws IOException {
            byte[] request = call.request;
            String message = new String(request);
            LOG.info("received mseesage: " + message);

            //each channel write 2MB data for test
            int dataLength = 2 * 1024 * 1024;
            ByteBuffer buffer = ByteBuffer.allocate(4 + dataLength);

            buffer.putInt(dataLength);
            writeDataForTest(buffer);
            buffer.flip();

            call.response = buffer;
            responder.doResponse(call);
        }
    }

    /**
     * Appends {@code buffer.limit() - 4} zero bytes at the buffer's current position.
     */
    public void writeDataForTest(ByteBuffer buffer) {
        int n = buffer.limit() - 4;
        for (int i = 0; i < n; i++) {
            buffer.put((byte) 0);
        }
    }

    /**
     * One request together with the connection it arrived on and, later, its response.
     */
    class Call {
        Connection conn;
        byte[] request;
        Responder responder;
        ByteBuffer response;
        boolean done;

        public Call(Connection conn, byte[] request, Responder responder) {
            this.conn = conn;
            this.request = request;
            this.responder = responder;
        }
    }

    /**
     * Reads into the buffer, slicing the transfer when more than {@code NIO_BUFFER_LIMIT}
     * bytes of space remain.
     *
     * @return the value returned by the underlying read(s); see {@link #channleIO}
     */
    public int channelRead(ReadableByteChannel channel, ByteBuffer buffer) throws IOException {
        return buffer.remaining() <= NIO_BUFFER_LIMIT ? channel.read(buffer) : channleIO(channel, null, buffer);
    }

    /**
     * Writes from the buffer, slicing the transfer when more than {@code NIO_BUFFER_LIMIT}
     * bytes remain.
     *
     * @return the value returned by the underlying write(s); see {@link #channleIO}
     */
    public int channelWrite(WritableByteChannel channel, ByteBuffer buffer) throws IOException {
        return buffer.remaining() <= NIO_BUFFER_LIMIT ? channel.write(buffer) : channleIO(null, channel, buffer);
    }

    /**
     * Transfers the buffer in slices of at most {@code NIO_BUFFER_LIMIT} bytes. Exactly one of
     * the two channels is expected to be non-null: with {@code readCh == null} the buffer is
     * written to {@code writeCh}, otherwise it is filled from {@code readCh}.
     *
     * <p>The loop stops as soon as a call transfers less than the slice it was offered. The
     * buffer's original limit is always restored.
     *
     * @return the total number of bytes transferred, or, if none were, the result of the last
     *         channel call (0 or -1)
     */
    public int channleIO(ReadableByteChannel readCh, WritableByteChannel writeCh, ByteBuffer buffer) throws IOException {
        int initRemaining = buffer.remaining();
        int originalLimit = buffer.limit();

        int ret = 0;
        try {
            while (buffer.remaining() > 0) {
                int ioSize = Math.min(buffer.remaining(), NIO_BUFFER_LIMIT);
                buffer.limit(buffer.position() + ioSize);
                ret = readCh == null ? writeCh.write(buffer) : readCh.read(buffer);
                if (ret < ioSize) {
                    break;
                }
            }
        } finally {
            buffer.limit(originalLimit);
        }

        int byteRead = initRemaining - buffer.remaining();
        return byteRead > 0 ? byteRead : ret;
    }

    /** Starts the configured number of handler threads. */
    public void startHandler() {
        for (int i = 0; i < handler; i++) {
            new Handler(i).start();
        }
    }

    /**
     * Starts the server on the default port (10000).
     *
     * @throws IOException if the port cannot be bound or a selector cannot be opened
     */
    public void start() throws IOException {
        start(DEFAULT_PORT);
    }

    /**
     * Starts the server on the given port.
     *
     * <p>The listener is constructed first, so a bind failure surfaces before the responder or
     * handler threads exist. Its accept loop is started last, so no request can be read
     * before the responder and handlers are in place.
     *
     * @param port TCP port to listen on
     * @throws IOException if the port cannot be bound or a selector cannot be opened
     */
    public void start(int port) throws IOException {
        Listener listener = new Listener(port);
        responder = new Responder();
        responder.start();
        startHandler();
        listener.start();
        LOG.info("server startup! ");
    }

    public static void main(String[] args) throws IOException {
        Server server = new Server();
        server.start();
    }
}

Client端:

package com.ai.nio;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import javax.net.SocketFactory;
import java.io.*;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel; /**
* Created by wangkai8 on 17/1/6.
*/
/**
 * Blocking test client: connects to localhost:10000, sends one length-prefixed message and
 * reads one length-prefixed reply.
 */
public class Client {

    public static final Log LOG = LogFactory.getLog(Client.class);

    Socket socket;
    OutputStream out;
    InputStream in;

    /**
     * Opens the connection and its streams.
     *
     * @throws IOException if the socket cannot be configured or connected; the socket is
     *                     closed before the exception is rethrown
     */
    public Client() throws IOException {
        socket = SocketFactory.getDefault().createSocket();
        try {
            socket.setTcpNoDelay(true);
            socket.setKeepAlive(true);
            InetSocketAddress server = new InetSocketAddress("localhost", 10000);
            socket.connect(server, 10000);
            out = socket.getOutputStream();
            in = socket.getInputStream();
        } catch (IOException e) {
            // Do not leak the descriptor when the connection could not be set up.
            try {
                socket.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Sends the message as a 4-byte length followed by the message bytes (platform charset).
     *
     * @param message text to send
     * @throws IOException if writing to the socket fails
     */
    public void send(String message) throws IOException {
        byte[] data = message.getBytes();
        DataOutputStream dos = new DataOutputStream(out);
        dos.writeInt(data.length);
        dos.write(data);
        // Flushing the wrapper also flushes the underlying socket stream.
        dos.flush();
    }

    /**
     * Starts 200 threads; each opens its own connection, sends one message and reads the reply.
     */
    public static void main(String[] args) throws IOException {
        int n = 200;
        for (int i = 0; i < n; i++) {
            new Thread() {
                // Connected on the main thread while the Thread object is being constructed.
                Client client = new Client();

                @Override
                public void run() {
                    try {
                        client.send(getName() + "_xiaomiemie");

                        DataInputStream inputStream = new DataInputStream(client.in);
                        int dataLength = inputStream.readInt();
                        byte[] data = new byte[dataLength];
                        inputStream.readFully(data);
                        LOG.info("receive from server: dataLength=" + data.length);
                    } catch (IOException e) {
                        LOG.error("", e);
                    } catch (Exception e) {
                        LOG.error("", e);
                    } finally {
                        // Close on every path, not only after a successful exchange.
                        try {
                            client.socket.close();
                        } catch (IOException e) {
                            LOG.error("", e);
                        }
                    }
                }
            }.start();
        }
    }
}

转载请标注原文地址:http://www.cnblogs.com/yueweimian/p/6262211.html