// from https://android.googlesource.com/platform/external/nanohttpd/+/42ff2a9/websocket/src/main/java/fi/iki/elonen/ import java.nio.charset.*; import java.nio.*; sbool webSocket_debug = true; sclass WebSocketException extends IOException { private WebSocketFrame.CloseCode code; private String reason; public WebSocketException(Exception cause) { this(WebSocketFrame.CloseCode.InternalServerError, cause.toString(), cause); } public WebSocketException(WebSocketFrame.CloseCode code, String reason) { this(code, reason, null); } public WebSocketException(WebSocketFrame.CloseCode code, String reason, Exception cause) { super(code + ": " + reason, cause); this.code = code; this.reason = reason; } public WebSocketFrame.CloseCode getCode() { return code; } public String getReason() { return reason; } } sclass WebSocketFrame { private OpCode opCode; private boolean fin; private byte[] maskingKey; private byte[] payload; private transient int _payloadLength; private transient String _payloadString; private WebSocketFrame(OpCode opCode, boolean fin) { setOpCode(opCode); setFin(fin); } public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload, byte[] maskingKey) { this(opCode, fin); setMaskingKey(maskingKey); setBinaryPayload(payload); } public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload) { this(opCode, fin, payload, null); } public WebSocketFrame(OpCode opCode, boolean fin, String payload, byte[] maskingKey) throws CharacterCodingException { this(opCode, fin); setMaskingKey(maskingKey); setTextPayload(payload); } public WebSocketFrame(OpCode opCode, boolean fin, String payload) throws CharacterCodingException { this(opCode, fin, payload, null); } public WebSocketFrame(WebSocketFrame clone) { setOpCode(clone.getOpCode()); setFin(clone.isFin()); setBinaryPayload(clone.getBinaryPayload()); setMaskingKey(clone.getMaskingKey()); } public WebSocketFrame(OpCode opCode, List fragments) throws WebSocketException { setOpCode(opCode); setFin(true); long _payloadLength = 0; for (WebSocketFrame inter : fragments) { _payloadLength += inter.getBinaryPayload().length; } if (webSocket_debug) print("Payload length with " + nFragments(fragments) + ": " + _payloadLength); if (_payloadLength < 0 || _payloadLength > Integer.MAX_VALUE) { throw new WebSocketException(WebSocketFrame.CloseCode.MessageTooBig, "Max frame length has been exceeded."); } this._payloadLength = (int) _payloadLength; byte[] payload = new byte[this._payloadLength]; int offset = 0; for (WebSocketFrame inter : fragments) { System.arraycopy(inter.getBinaryPayload(), 0, payload, offset, inter.getBinaryPayload().length); offset += inter.getBinaryPayload().length; } setBinaryPayload(payload); } // --------------------------------GETTERS--------------------------------- public OpCode getOpCode() { return opCode; } public void setOpCode(OpCode opcode) { this.opCode = opcode; } public boolean isFin() { return fin; } public void setFin(boolean fin) { this.fin = fin; } public boolean isMasked() { return maskingKey != null && maskingKey.length == 4; } public byte[] getMaskingKey() { return maskingKey; } public void setMaskingKey(byte[] maskingKey) { if (maskingKey != null && maskingKey.length != 4) { throw new IllegalArgumentException("MaskingKey " + Arrays.toString(maskingKey) + " hasn't length 4"); } this.maskingKey = maskingKey; } public void setUnmasked() { setMaskingKey(null); } public byte[] getBinaryPayload() { return payload; } public void setBinaryPayload(byte[] payload) { this.payload = payload; this._payloadLength = payload.length; this._payloadString = null; } public String getTextPayload() { if (_payloadString == null) { try { _payloadString = binary2Text(getBinaryPayload()); } catch (CharacterCodingException e) { throw new RuntimeException("Undetected CharacterCodingException", e); } } return _payloadString; } public void setTextPayload(String payload) throws CharacterCodingException { this.payload = text2Binary(payload); //this._payloadLength = payload.length(); // buggy! this._payloadLength = this.payload.length; if (webSocket_debug) print("payload length: " + _payloadLength + ", string length: " + l(payload)); // XXX this._payloadString = payload; } // --------------------------------SERIALIZATION--------------------------- public static WebSocketFrame read(InputStream in) throws IOException { byte head = (byte) checkedRead(in.read()); boolean fin = ((head & 0x80) != 0); OpCode opCode = OpCode.find((byte) (head & 0x0F)); if ((head & 0x70) != 0) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "The reserved bits (" + Integer.toBinaryString(head & 0x70) + ") must be 0."); } if (opCode == null) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Received frame with reserved/unknown opcode " + (head & 0x0F) + "."); } else if (opCode.isControlFrame() && !fin) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Fragmented control frame."); } WebSocketFrame frame = new WebSocketFrame(opCode, fin); frame.readPayloadInfo(in); frame.readPayload(in); if (frame.getOpCode() == WebSocketFrame.OpCode.Close) { return new WebSocketFrame.CloseFrame(frame); } else { return frame; } } private static int checkedRead(int read) throws IOException { if (read < 0) { throw new EOFException(); } //System.out.println(Integer.toBinaryString(read) + "/" + read + "/" + Integer.toHexString(read)); return read; } private void readPayloadInfo(InputStream in) throws IOException { byte b = (byte) checkedRead(in.read()); boolean masked = ((b & 0x80) != 0); _payloadLength = (byte) (0x7F & b); if (_payloadLength == 126) { // checkedRead must return int for this to work _payloadLength = (checkedRead(in.read()) << 8 | checkedRead(in.read())) & 0xFFFF; if (_payloadLength < 126) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Invalid data frame 2byte length. (not using minimal length encoding)"); } } else if (_payloadLength == 127) { long _payloadLength = ((long) checkedRead(in.read())) << 56 | ((long) checkedRead(in.read())) << 48 | ((long) checkedRead(in.read())) << 40 | ((long) checkedRead(in.read())) << 32 | checkedRead(in.read()) << 24 | checkedRead(in.read()) << 16 | checkedRead(in.read()) << 8 | checkedRead(in.read()); if (_payloadLength < 65536) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Invalid data frame 4byte length. (not using minimal length encoding)"); } if (_payloadLength < 0 || _payloadLength > Integer.MAX_VALUE) { throw new WebSocketException(WebSocketFrame.CloseCode.MessageTooBig, "Max frame length has been exceeded."); } this._payloadLength = (int) _payloadLength; } if (opCode.isControlFrame()) { if (_payloadLength > 125) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Control frame with payload length > 125 bytes."); } if (opCode == OpCode.Close && _payloadLength == 1) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Received close frame with payload len 1."); } } if (masked) { maskingKey = new byte[4]; int read = 0; while (read < maskingKey.length) { read += checkedRead(in.read(maskingKey, read, maskingKey.length - read)); } } } private void readPayload(InputStream in) throws IOException { payload = new byte[_payloadLength]; int read = 0; while (read < _payloadLength) { read += checkedRead(in.read(payload, read, _payloadLength - read)); } if (isMasked()) { for (int i = 0; i < payload.length; i++) { payload[i] ^= maskingKey[i % 4]; } } //Test for Unicode errors if (getOpCode() == WebSocketFrame.OpCode.Text) { _payloadString = binary2Text(getBinaryPayload()); } } public void write(OutputStream out) throws IOException { byte header = 0; if (fin) { header |= 0x80; } header |= opCode.getValue() & 0x0F; out.write(header); _payloadLength = getBinaryPayload().length; if (_payloadLength <= 125) { if (webSocket_debug) print("Sending short payload: " + _payloadLength); out.write(isMasked() ? 0x80 | (byte) _payloadLength : (byte) _payloadLength); } else if (_payloadLength <= 0xFFFF) { out.write(isMasked() ? 0xFE : 126); out.write(_payloadLength >>> 8); out.write(_payloadLength); } else { out.write(isMasked() ? 0xFF : 127); out.write(_payloadLength >>> 56 & 0); //integer only contains 31 bit out.write(_payloadLength >>> 48 & 0); out.write(_payloadLength >>> 40 & 0); out.write(_payloadLength >>> 32 & 0); out.write(_payloadLength >>> 24); out.write(_payloadLength >>> 16); out.write(_payloadLength >>> 8); out.write(_payloadLength); } if (isMasked()) { out.write(maskingKey); for (int i = 0; i < _payloadLength; i++) { out.write(getBinaryPayload()[i] ^ maskingKey[i % 4]); } } else { out.write(getBinaryPayload()); } out.flush(); } // --------------------------------ENCODING-------------------------------- public static final Charset TEXT_CHARSET = Charset.forName("UTF-8"); public static final CharsetDecoder TEXT_DECODER = TEXT_CHARSET.newDecoder(); public static final CharsetEncoder TEXT_ENCODER = TEXT_CHARSET.newEncoder(); public static String binary2Text(byte[] payload) throws CharacterCodingException { return TEXT_DECODER.decode(ByteBuffer.wrap(payload)).toString(); } public static String binary2Text(byte[] payload, int offset, int length) throws CharacterCodingException { return TEXT_DECODER.decode(ByteBuffer.wrap(payload, offset, length)).toString(); } public static byte[] text2Binary(String payload) throws CharacterCodingException { return TEXT_ENCODER.encode(CharBuffer.wrap(payload)).array(); } @Override public String toString() { final StringBuilder sb = new StringBuilder("WS["); sb.append(getOpCode()); sb.append(", ").append(isFin() ? "fin" : "inter"); sb.append(", ").append(isMasked() ? "masked" : "unmasked"); sb.append(", ").append(payloadToString()); sb.append(']'); return sb.toString(); } protected String payloadToString() { if (payload == null) return "null"; else { final StringBuilder sb = new StringBuilder(); sb.append('[').append(payload.length).append("b] "); if (getOpCode() == WebSocketFrame.OpCode.Text) { String text = getTextPayload(); if (text.length() > 100) sb.append(text.substring(0, 100)).append("..."); else sb.append(text); } else { sb.append("0x"); for (int i = 0; i < Math.min(payload.length, 50); ++i) sb.append(Integer.toHexString((int) payload[i] & 0xFF)); if (payload.length > 50) sb.append("..."); } return sb.toString(); } } // --------------------------------CONSTANTS------------------------------- public static enum OpCode { Continuation(0), Text(1), Binary(2), Close(8), Ping(9), Pong(10); private final byte code; private OpCode(int code) { this.code = (byte) code; } public byte getValue() { return code; } public boolean isControlFrame() { return this == Close || this == Ping || this == Pong; } public static OpCode find(byte value) { for (OpCode opcode : values()) { if (opcode.getValue() == value) { return opcode; } } return null; } } public static enum CloseCode { NormalClosure(1000), GoingAway(1001), ProtocolError(1002), UnsupportedData(1003), NoStatusRcvd(1005), AbnormalClosure(1006), InvalidFramePayloadData(1007), PolicyViolation(1008), MessageTooBig(1009), MandatoryExt(1010), InternalServerError(1011), TLSHandshake(1015); private final int code; private CloseCode(int code) { this.code = code; } public int getValue() { return code; } public static WebSocketFrame.CloseCode find(int value) { for (WebSocketFrame.CloseCode code : values()) { if (code.getValue() == value) { return code; } } return null; } } // ------------------------------------------------------------------------ public static class CloseFrame extends WebSocketFrame { private CloseCode _closeCode; private String _closeReason; private CloseFrame(WebSocketFrame wrap) throws CharacterCodingException { super(wrap); assert wrap.getOpCode() == OpCode.Close; if (wrap.getBinaryPayload().length >= 2) { _closeCode = CloseCode.find((wrap.getBinaryPayload()[0] & 0xFF) << 8 | (wrap.getBinaryPayload()[1] & 0xFF)); _closeReason = binary2Text(getBinaryPayload(), 2, getBinaryPayload().length - 2); } } public CloseFrame(CloseCode code, String closeReason) throws CharacterCodingException { super(OpCode.Close, true, generatePayload(code, closeReason)); } private static byte[] generatePayload(CloseCode code, String closeReason) throws CharacterCodingException { if (code != null) { byte[] reasonBytes = text2Binary(closeReason); byte[] payload = new byte[reasonBytes.length + 2]; payload[0] = (byte) ((code.getValue() >> 8) & 0xFF); payload[1] = (byte) ((code.getValue()) & 0xFF); System.arraycopy(reasonBytes, 0, payload, 2, reasonBytes.length); return payload; } else { return new byte[0]; } } protected String payloadToString() { return (_closeCode != null ? _closeCode : "UnknownCloseCode[" + _closeCode + "]") + (_closeReason != null && !_closeReason.isEmpty() ? ": " + _closeReason : ""); } public CloseCode getCloseCode() { return _closeCode; } public String getCloseReason() { return _closeReason; } } } sclass WebSocket implements AutoCloseable { protected final InputStream in; protected /*final*/ OutputStream out; protected WebSocketFrame.OpCode continuousOpCode = null; protected List continuousFrames = new LinkedList(); protected State state = State.UNCONNECTED; public static enum State { UNCONNECTED, CONNECTING, OPEN, CLOSING, CLOSED } protected final NanoHTTPD.IHTTPSession handshakeRequest; //protected final NanoHTTPD.Response handshakeResponse = new NanoHTTPD.Response(NanoHTTPD.Status.SWITCH_PROTOCOL, null, (InputStream) null) { protected final NanoHTTPD.Response handshakeResponse = new NanoHTTPD.Response(NanoHTTPD.Status.SWITCH_PROTOCOL, null, (InputStream) null, -1) { @Override protected void send(OutputStream out) { WebSocket.this.out = out; state = State.CONNECTING; super.send(out); state = State.OPEN; onOpen(); readWebsocket(); } }; public WebSocket(NanoHTTPD.IHTTPSession handshakeRequest) { this.handshakeRequest = handshakeRequest; this.in = handshakeRequest.getInputStream(); handshakeResponse.addHeader(NanoWebSocketServer.HEADER_UPGRADE, NanoWebSocketServer.HEADER_UPGRADE_VALUE); handshakeResponse.addHeader(NanoWebSocketServer.HEADER_CONNECTION, NanoWebSocketServer.HEADER_CONNECTION_VALUE); } // --------------------------------IO-------------------------------------- protected void readWebsocket() { try { while (state == State.OPEN) { handleWebsocketFrame(WebSocketFrame.read(in)); } } catch (CharacterCodingException e) { onException(e); doClose(WebSocketFrame.CloseCode.InvalidFramePayloadData, e.toString(), false); } catch (IOException e) { onException(e); if (e instanceof WebSocketException) { doClose(((WebSocketException) e).getCode(), ((WebSocketException) e).getReason(), false); } } finally { doClose(WebSocketFrame.CloseCode.InternalServerError, "Handler terminated without closing the connection.", false); } } protected void handleWebsocketFrame(WebSocketFrame frame) throws IOException { if (frame.getOpCode() == WebSocketFrame.OpCode.Close) { handleCloseFrame(frame); } else if (frame.getOpCode() == WebSocketFrame.OpCode.Ping) { sendFrame(new WebSocketFrame(WebSocketFrame.OpCode.Pong, true, frame.getBinaryPayload())); } else if (frame.getOpCode() == WebSocketFrame.OpCode.Pong) { onPong(frame); } else if (!frame.isFin() || frame.getOpCode() == WebSocketFrame.OpCode.Continuation) { handleFrameFragment(frame); } else if (continuousOpCode != null) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Continuous frame sequence not completed."); } else if (frame.getOpCode() == WebSocketFrame.OpCode.Text || frame.getOpCode() == WebSocketFrame.OpCode.Binary) { onMessage(frame); } else { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Non control or continuous frame expected."); } } protected void handleCloseFrame(WebSocketFrame frame) throws IOException { WebSocketFrame.CloseCode code = WebSocketFrame.CloseCode.NormalClosure; String reason = ""; if (frame instanceof WebSocketFrame.CloseFrame) { code = ((WebSocketFrame.CloseFrame) frame).getCloseCode(); reason = ((WebSocketFrame.CloseFrame) frame).getCloseReason(); } if (state == State.CLOSING) { //Answer for my requested close doClose(code, reason, false); } else { //Answer close request from other endpoint and close self State oldState = state; state = State.CLOSING; if (oldState == State.OPEN) try { sendFrame(new WebSocketFrame.CloseFrame(code, reason)); } catch {} doClose(code, reason, true); } } protected void handleFrameFragment(WebSocketFrame frame) throws IOException { if (frame.getOpCode() != WebSocketFrame.OpCode.Continuation) { //First if (continuousOpCode != null) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Previous continuous frame sequence not completed."); } continuousOpCode = frame.getOpCode(); continuousFrames.clear(); continuousFrames.add(frame); } else if (frame.isFin()) { //Last if (continuousOpCode == null) { throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Continuous frame sequence was not started."); } onMessage(new WebSocketFrame(continuousOpCode, continuousFrames)); continuousOpCode = null; continuousFrames.clear(); } else if (continuousOpCode == null) { //Unexpected throw new WebSocketException(WebSocketFrame.CloseCode.ProtocolError, "Continuous frame sequence was not started."); } else { //Intermediate continuousFrames.add(frame); } } public synchronized void sendFrame(WebSocketFrame frame) throws IOException { frame.write(out); } // --------------------------------Close----------------------------------- // deliberate closing by server, e.g. because of module reload public void close() { doClose(WebSocketFrame.CloseCode.GoingAway, "Internal closing", false); } protected void doClose(WebSocketFrame.CloseCode code, String reason, boolean initiatedByRemote) { if (state == State.CLOSED) { return; } if (in != null) { try { in.close(); } catch (IOException e) { e.printStackTrace(); } } if (out != null) { try { out.close(); } catch (IOException e) { e.printStackTrace(); } } state = State.CLOSED; onClose(code, reason, initiatedByRemote); } // --------------------------------Listener-------------------------------- protected void onPong(WebSocketFrame pongFrame) { print("WebSocket pong"); } swappable void onMessage(WebSocketFrame messageFrame) { print("WebSocket msg: " + messageFrame.getTextPayload()); } swappable void onOpen() {} protected void onClose(WebSocketFrame.CloseCode code, String reason, boolean initiatedByRemote) { print("WebSocket close"); onClose(); } swappable void onClose() {} protected void onException(IOException e) { printStackTrace(e); } // --------------------------------Public Facade--------------------------- public void ping(byte[] payload) throws IOException { sendFrame(new WebSocketFrame(WebSocketFrame.OpCode.Ping, true, payload)); } public void send(byte[] payload) throws IOException { sendFrame(new WebSocketFrame(WebSocketFrame.OpCode.Binary, true, payload)); } public void send(String payload) throws IOException { sendFrame(new WebSocketFrame(WebSocketFrame.OpCode.Text, true, payload)); } public void close(WebSocketFrame.CloseCode code, String reason) throws IOException { State oldState = state; state = State.CLOSING; if (oldState == State.OPEN) { sendFrame(new WebSocketFrame.CloseFrame(code, reason)); } else { doClose(code, reason, false); } } // --------------------------------Getters--------------------------------- public NanoHTTPD.IHTTPSession getHandshakeRequest() { return handshakeRequest; } public NanoHTTPD.Response getHandshakeResponse() { return handshakeResponse; } // convenience methods S getUri() { ret getHandshakeRequest().getUri(); } SS getParms() { ret getHandshakeRequest().getParms(); } } sinterface WebSocketFactory { WebSocket openWebSocket(NanoHTTPD.IHTTPSession handshake); } sclass NanoWebSocketServer extends NanoHTTPD implements WebSocketFactory { public static final String HEADER_UPGRADE = "upgrade"; public static final String HEADER_UPGRADE_VALUE = "websocket"; public static final String HEADER_CONNECTION = "connection"; public static final String HEADER_CONNECTION_VALUE = "Upgrade"; public static final String HEADER_WEBSOCKET_VERSION = "sec-websocket-version"; public static final String HEADER_WEBSOCKET_VERSION_VALUE = "13"; public static final String HEADER_WEBSOCKET_KEY = "sec-websocket-key"; public static final String HEADER_WEBSOCKET_ACCEPT = "sec-websocket-accept"; public static final String HEADER_WEBSOCKET_PROTOCOL = "sec-websocket-protocol"; public final static String WEBSOCKET_KEY_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; WebSocketFactory webSocketFactory; public NanoWebSocketServer(int port) { super(port); webSocketFactory = null; } public NanoWebSocketServer(String hostname, int port) { super(hostname, port); webSocketFactory = null; } public NanoWebSocketServer(int port, WebSocketFactory webSocketFactory) { super(port); this.webSocketFactory = webSocketFactory; } public NanoWebSocketServer(String hostname, int port,WebSocketFactory webSocketFactory) { super(hostname, port); this.webSocketFactory = webSocketFactory; } @Override public Response serve(final IHTTPSession session) { Map headers = session.getHeaders(); if (isWebsocketRequested(session)) { if (!HEADER_UPGRADE_VALUE.equalsIgnoreCase(headers.get(HEADER_UPGRADE)) || !isWebSocketConnectionHeader(session.getHeaders())) { return newFixedLengthResponse(Status.BAD_REQUEST, NanoHTTPD.MIME_PLAINTEXT, "Invalid Websocket handshake"); } if (!HEADER_WEBSOCKET_VERSION_VALUE.equalsIgnoreCase(headers.get(HEADER_WEBSOCKET_VERSION))) { return newFixedLengthResponse(Status.BAD_REQUEST, NanoHTTPD.MIME_PLAINTEXT, "Invalid Websocket-Version " + headers.get(HEADER_WEBSOCKET_VERSION)); } if (!headers.containsKey(HEADER_WEBSOCKET_KEY)) { return newFixedLengthResponse(Status.BAD_REQUEST, NanoHTTPD.MIME_PLAINTEXT, "Missing Websocket-Key"); } WebSocket webSocket = openWebSocket(session); try { webSocket.getHandshakeResponse().addHeader(HEADER_WEBSOCKET_ACCEPT, makeAcceptKey(headers.get(HEADER_WEBSOCKET_KEY))); } catch (NoSuchAlgorithmException e) { return newFixedLengthResponse(Status.INTERNAL_ERROR, NanoHTTPD.MIME_PLAINTEXT, "The SHA-1 Algorithm required for websockets is not available on the server."); } if (headers.containsKey(HEADER_WEBSOCKET_PROTOCOL)) { webSocket.getHandshakeResponse().addHeader(HEADER_WEBSOCKET_PROTOCOL, headers.get(HEADER_WEBSOCKET_PROTOCOL).split(",")[0]); } return webSocket.getHandshakeResponse(); } else { return super.serve(session); } } public WebSocket openWebSocket(IHTTPSession handshake) { if (webSocketFactory == null) { throw new Error("You must either override this method or supply a WebSocketFactory in the constructor"); } return webSocketFactory.openWebSocket(handshake); } protected boolean isWebsocketRequested(IHTTPSession session) { Map headers = session.getHeaders(); String upgrade = headers.get(HEADER_UPGRADE); boolean isCorrectConnection = isWebSocketConnectionHeader(headers); boolean isUpgrade = HEADER_UPGRADE_VALUE.equalsIgnoreCase(upgrade); return (isUpgrade && isCorrectConnection); } private boolean isWebSocketConnectionHeader(Map headers) { String connection = headers.get(HEADER_CONNECTION); return (connection != null && connection.toLowerCase().contains(HEADER_CONNECTION_VALUE.toLowerCase())); } public static String makeAcceptKey(String key) throws NoSuchAlgorithmException { MessageDigest md = MessageDigest.getInstance("SHA-1"); String text = key + WEBSOCKET_KEY_MAGIC; md.update(text.getBytes(), 0, text.length()); byte[] sha1hash = md.digest(); return base64encode(sha1hash); } }