/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.tls;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;
import org.bouncycastle.tls.ClientHello;
import org.bouncycastle.tls.DTLSHandshakeRetransmit;
import org.bouncycastle.tls.DTLSReassembler;
import org.bouncycastle.tls.DTLSRecordLayer;
import org.bouncycastle.tls.DTLSRequest;
import org.bouncycastle.tls.DatagramSender;
import org.bouncycastle.tls.DeferredHash;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.Timeout;
import org.bouncycastle.tls.TlsContext;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsHandshakeHash;
import org.bouncycastle.tls.TlsTimeoutException;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.util.Integers;

class DTLSReliableHandshake {
    // NOTE(review): not referenced by name in this file; the literal 16 passed as
    // windowSize from implReceiveMessage to processRecord presumably originates here.
    private static final int MAX_RECEIVE_AHEAD = 16;
    // Size of the handshake fragment header produced by writeHandshakeFragment:
    // type (1) + length (3) + message_seq (2) + fragment_offset (3) + fragment_length (3).
    private static final int MESSAGE_HEADER_LENGTH = 12;
    // Matches the value (1000 ms) that resendMillis is set to whenever a resend timer is started.
    static final int INITIAL_RESEND_MILLIS = 1000;
    // Matches the cap (60000 ms) that backOff applies when doubling the resend interval.
    private static final int MAX_RESEND_MILLIS = 60000;
    // Record layer through which handshake fragments are sent and received.
    private DTLSRecordLayer recordLayer;
    // Overall handshake deadline, built from the constructor's timeoutMillis argument.
    private Timeout handshakeTimeout;
    // Running hash over handshake messages (see updateHandshakeMessagesDigest).
    private TlsHandshakeHash handshakeHash;
    // Maps message_seq (Integer) -> DTLSReassembler for the flight currently being received.
    private Hashtable currentInboundFlight = new Hashtable();
    // The inbound flight that preceded the current one, kept so that retransmissions of it
    // can be recognised in processRecord; null when there is none.
    private Hashtable previousInboundFlight = null;
    // Message objects of the flight most recently sent, retained for resendOutboundFlight.
    private Vector outboundFlight = new Vector();
    // Current resend interval in milliseconds; -1 while no resend timer is active.
    private int resendMillis = -1;
    // Resend timer; non-null only while this side is waiting to receive (set in
    // implReceiveMessage / the constructor, cleared in sendMessage).
    private Timeout resendTimeout = null;
    // message_seq assigned to the next message passed to sendMessage.
    private int next_send_seq = 0;
    // message_seq of the next inbound message that getPendingMessage will deliver.
    private int next_receive_seq = 0;

    /**
     * Tries to interpret a datagram as one complete, unfragmented ClientHello.
     *
     * <p>The handshake message is obtained from
     * {@code DTLSRecordLayer.receiveClientHelloRecord}. It is accepted only if its type
     * byte is 1 (the TLS HandshakeType code for client_hello), its declared body length
     * accounts for every byte after the 12-byte handshake header, its fragment offset is
     * zero and its fragment length equals the body length.
     *
     * @param data       buffer containing the received datagram
     * @param dataOff    offset of the datagram within {@code data}
     * @param dataLen    number of datagram bytes
     * @param dtlsOutput stream passed through to {@code ClientHello.parse}
     * @return a {@code DTLSRequest} holding the record sequence number, the raw handshake
     *         message and the parsed ClientHello, or {@code null} when any check fails
     * @throws IOException propagated from the record layer helper or the ClientHello parser
     */
    static DTLSRequest readClientRequest(byte[] data, int dataOff, int dataLen, OutputStream dtlsOutput) throws IOException {
        final byte[] handshake = DTLSRecordLayer.receiveClientHelloRecord(data, dataOff, dataLen);
        if (handshake == null || handshake.length < 12) {
            return null;
        }

        // 48-bit record sequence number, read 5 bytes into the datagram.
        final long recordSeq = TlsUtils.readUint48(data, dataOff + 5);

        // Handshake type must be 1 (client_hello).
        if (TlsUtils.readUint8(handshake, 0) != 1) {
            return null;
        }

        // The declared body length must cover exactly the bytes after the header.
        final int bodyLength = TlsUtils.readUint24(handshake, 1);
        if (handshake.length - 12 != bodyLength) {
            return null;
        }

        // Only a single fragment spanning the whole message is supported:
        // fragment_offset == 0 and fragment_length == length.
        final boolean wholeMessage = TlsUtils.readUint24(handshake, 6) == 0
            && TlsUtils.readUint24(handshake, 9) == bodyLength;
        if (!wholeMessage) {
            return null;
        }

        final ByteArrayInputStream bodyStream = new ByteArrayInputStream(handshake, 12, bodyLength);
        final ClientHello clientHello = ClientHello.parse(bodyStream, dtlsOutput);
        return new DTLSRequest(recordSeq, handshake, clientHello);
    }

    /**
     * Builds a HelloVerifyRequest handshake message carrying {@code cookie} and hands it to
     * {@code DTLSRecordLayer.sendHelloVerifyRequestRecord}.
     *
     * <p>The message is a 12-byte handshake header followed by a 2-byte version
     * ({@code ProtocolVersion.DTLSv10}) and the cookie as a one-byte-length-prefixed opaque
     * value. The message_seq and fragment_offset header fields are never written and so
     * remain zero.
     *
     * @param sender    destination for the resulting record
     * @param recordSeq record sequence number forwarded to the record layer helper
     * @param cookie    cookie bytes; the length is validated with {@code TlsUtils.checkUint8}
     * @throws IOException propagated from the validation or send calls
     */
    static void sendHelloVerifyRequest(DatagramSender sender, long recordSeq, byte[] cookie) throws IOException {
        TlsUtils.checkUint8(cookie.length);

        // Body = version (2 bytes) + cookie length prefix (1 byte) + cookie.
        final int bodyLength = 2 + 1 + cookie.length;
        final byte[] hvr = new byte[12 + bodyLength];

        // Handshake header: type 3 (hello_verify_request), then total and fragment
        // length, both equal because the message goes out as a single fragment.
        TlsUtils.writeUint8((short)3, hvr, 0);
        TlsUtils.writeUint24(bodyLength, hvr, 1);
        TlsUtils.writeUint24(bodyLength, hvr, 9);

        // Body.
        TlsUtils.writeVersion(ProtocolVersion.DTLSv10, hvr, 12);
        TlsUtils.writeOpaque8(cookie, hvr, 14);

        DTLSRecordLayer.sendHelloVerifyRequestRecord(sender, recordSeq, hvr);
    }

    /**
     * Creates the reliable-handshake helper on top of a DTLS record layer.
     *
     * <p>When {@code request} is non-null, the handshake state is primed as though the
     * ClientHello it carries had already been received: the resend timer is started, the
     * record layer is reset with the request's record sequence number, a reassembler is
     * registered under the request's message sequence number, the sequence counters are
     * advanced and the raw message is fed into the handshake hash.
     *
     * @param context       context used to create the handshake hash
     * @param transport     record layer used for all I/O
     * @param timeoutMillis overall handshake timeout, passed to {@code Timeout.forWaitMillis}
     * @param request       previously read client request, or {@code null} if there is none
     */
    DTLSReliableHandshake(TlsContext context, DTLSRecordLayer transport, int timeoutMillis, DTLSRequest request) {
        recordLayer = transport;
        handshakeHash = new DeferredHash(context);
        handshakeTimeout = Timeout.forWaitMillis(timeoutMillis);

        if (request == null) {
            return;
        }

        // Start the resend timer at its initial interval.
        resendMillis = 1000;
        resendTimeout = new Timeout(resendMillis);

        final long recordSeq = request.getRecordSeq();
        final int messageSeq = request.getMessageSeq();
        final byte[] clientHello = request.getMessage();

        recordLayer.resetAfterHelloVerifyRequestServer(recordSeq);

        // Register a reassembler for the ClientHello (handshake type 1) sized to its body.
        final DTLSReassembler helloReassembler = new DTLSReassembler(1, clientHello.length - 12);
        currentInboundFlight.put(Integers.valueOf(messageSeq), helloReassembler);

        // NOTE(review): presumably send sequence 0 was consumed by the HelloVerifyRequest.
        next_send_seq = 1;
        next_receive_seq = messageSeq + 1;

        handshakeHash.update(clientHello, 0, clientHello.length);
    }

    /**
     * Returns the receive-side and retransmission state to a fresh condition.
     *
     * <p>Both inbound flights and the outbound flight are discarded, the resend timer is
     * cleared, the next expected inbound message sequence becomes 1 and the handshake hash
     * is reset. {@code next_send_seq} is deliberately left unchanged by this method.
     */
    void resetAfterHelloVerifyRequestClient() {
        // Drop everything received or queued so far.
        previousInboundFlight = null;
        currentInboundFlight = new Hashtable();
        outboundFlight = new Vector();

        // No resend timer is active until the next receive starts one.
        resendTimeout = null;
        resendMillis = -1;

        next_receive_seq = 1;

        handshakeHash.reset();
    }

    /**
     * @return the handshake hash object that this instance feeds with handshake messages
     */
    TlsHandshakeHash getHandshakeHash() {
        return handshakeHash;
    }

    /**
     * Tells the handshake hash to stop tracking, by delegating to
     * {@code TlsHandshakeHash.stopTracking()}.
     */
    void prepareToFinish() {
        handshakeHash.stopTracking();
    }

    /**
     * Sends one handshake message and records it for possible retransmission.
     *
     * <p>If a resend timer is active (i.e. this side was receiving), the message is the
     * first of a new outbound flight: the inbound flight is checked, the timer is cleared
     * and the previous outbound flight is discarded. The message is then given the next
     * send sequence number, appended to the outbound flight, written to the record layer
     * and added to the handshake digest.
     *
     * @param msg_type handshake message type
     * @param body     message body; its length is validated with {@code TlsUtils.checkUint24}
     * @throws IOException propagated from validation, writing or digest update
     */
    void sendMessage(short msg_type, byte[] body) throws IOException {
        TlsUtils.checkUint24(body.length);

        final boolean startingNewFlight = (resendTimeout != null);
        if (startingNewFlight) {
            checkInboundFlight();

            outboundFlight.removeAllElements();
            resendTimeout = null;
            resendMillis = -1;
        }

        final int seq = next_send_seq;
        next_send_seq = seq + 1;

        final Message outgoing = new Message(seq, msg_type, body);
        outboundFlight.addElement(outgoing);

        writeMessage(outgoing);
        updateHandshakeMessagesDigest(outgoing);
    }

    /**
     * Waits for the next in-order handshake message of any type and adds it to the
     * handshake digest before returning it.
     *
     * @return the received message
     * @throws IOException propagated from receiving or from the digest update
     */
    Message receiveMessage() throws IOException {
        final Message received = implReceiveMessage();
        updateHandshakeMessagesDigest(received);
        return received;
    }

    /**
     * Waits for the next in-order handshake message, requires it to be of the given type,
     * adds it to the handshake digest and returns its body.
     *
     * @param msg_type the handshake type the caller expects
     * @return the body of the received message
     * @throws IOException a {@code TlsFatalAlert} with description 10 (unexpected_message)
     *         if the type differs; otherwise anything propagated from receiving or hashing
     */
    byte[] receiveMessageBody(short msg_type) throws IOException {
        final Message received = implReceiveMessage();
        if (msg_type == received.getType()) {
            updateHandshakeMessagesDigest(received);
            return received.getBody();
        }
        throw new TlsFatalAlert(10);
    }

    /**
     * Waits for the next in-order handshake message and requires it to be of the given
     * type, but does NOT add it to the handshake digest; the caller is expected to call
     * {@link #updateHandshakeMessagesDigest(Message)} itself if the message should be hashed.
     *
     * @param msg_type the handshake type the caller expects
     * @return the received message
     * @throws IOException a {@code TlsFatalAlert} with description 10 (unexpected_message)
     *         if the type differs; otherwise anything propagated from receiving
     */
    Message receiveMessageDelayedDigest(short msg_type) throws IOException {
        final Message received = implReceiveMessage();
        if (msg_type == received.getType()) {
            return received;
        }
        throw new TlsFatalAlert(10);
    }

    /**
     * Feeds a handshake message into the handshake hash as if it had been sent as a single
     * unfragmented fragment: a 12-byte header (type, length, message_seq, fragment_offset
     * 0, fragment_length equal to length) followed by the body.
     *
     * <p>Messages of type 0, 3 or 24 (hello_request, hello_verify_request and key_update in
     * the TLS HandshakeType numbering) are skipped and leave the hash untouched.
     *
     * @param message the message to hash
     * @throws IOException propagated from the {@code TlsUtils} write helpers
     */
    void updateHandshakeMessagesDigest(Message message) throws IOException {
        final short type = message.getType();

        // These message types are not part of the hashed transcript.
        if (type == 0 || type == 3 || type == 24) {
            return;
        }

        final byte[] body = message.getBody();

        // Synthesize the header of an unfragmented handshake message.
        final byte[] header = new byte[12];
        TlsUtils.writeUint8(type, header, 0);
        TlsUtils.writeUint24(body.length, header, 1);
        TlsUtils.writeUint16(message.getSeq(), header, 4);
        TlsUtils.writeUint24(0, header, 6);
        TlsUtils.writeUint24(body.length, header, 9);

        handshakeHash.update(header, 0, header.length);
        handshakeHash.update(body, 0, body.length);
    }

    /**
     * Completes the handshake and notifies the record layer via
     * {@code handshakeSuccessful}.
     *
     * <p>If no resend timer is active (this side sent the last flight), the current inbound
     * flight is rotated into {@code previousInboundFlight} and a retransmit handler is
     * handed to the record layer. The handler routes later handshake records back into
     * {@code processRecord} with a window size of 0, so only fragments belonging to the
     * previous flight are considered. Otherwise the inbound flight is checked and
     * {@code null} is passed as the handler.
     */
    void finish() {
        DTLSHandshakeRetransmit retransmit = null;

        if (resendTimeout == null) {
            // We were the last to send: keep the peer's last flight around so that a
            // retransmission of it can trigger a resend of ours.
            prepareInboundFlight(null);

            if (previousInboundFlight != null) {
                retransmit = new DTLSHandshakeRetransmit() {
                    public void receivedHandshakeRecord(int epoch, byte[] buf, int off, int len) throws IOException {
                        DTLSReliableHandshake.this.processRecord(0, epoch, buf, off, len);
                    }
                };
            }
        } else {
            checkInboundFlight();
        }

        recordLayer.handshakeSuccessful(retransmit);
    }

    /**
     * Doubles a retransmission interval, capping the result at 60000 ms.
     *
     * <p>The comparison is made against half the cap before multiplying, so very large
     * inputs (above 2^30) return the cap instead of overflowing {@code int} and yielding a
     * negative interval. Inputs below half the cap, including non-positive ones, are simply
     * doubled exactly as before.
     *
     * @param timeoutMillis the current interval in milliseconds
     * @return {@code timeoutMillis * 2}, or 60000 if that would exceed the cap
     */
    static int backOff(int timeoutMillis) {
        // Same value as the class constant MAX_RESEND_MILLIS.
        final int maxResendMillis = 60000;

        // Compare before doubling: timeoutMillis * 2 wraps around for large inputs.
        if (timeoutMillis >= maxResendMillis / 2) {
            return maxResendMillis;
        }
        return timeoutMillis * 2;
    }

    /**
     * Walks the message sequence numbers of the current inbound flight.
     *
     * <p>As written this method has no observable effect: entries whose sequence number is
     * at or beyond {@code next_receive_seq} (i.e. received but never delivered) are
     * identified, but nothing is done about them.
     */
    private void checkInboundFlight() {
        for (Enumeration keys = currentInboundFlight.keys(); keys.hasMoreElements();) {
            int seq = ((Integer)keys.nextElement()).intValue();
            if (seq >= next_receive_seq) {
                // Placeholder: an undelivered message is left in the flight; no action is taken.
            }
        }
    }

    /**
     * Returns the next in-order inbound message if it has been fully reassembled.
     *
     * <p>On success {@code previousInboundFlight} is cleared and {@code next_receive_seq}
     * is advanced by one.
     *
     * @return the completed message for {@code next_receive_seq}, or {@code null} if no
     *         reassembler exists for that sequence number or its body is not yet complete
     * @throws IOException declared for callers; not thrown by the visible code
     */
    private Message getPendingMessage() throws IOException {
        final DTLSReassembler reassembler =
            (DTLSReassembler)currentInboundFlight.get(Integers.valueOf(next_receive_seq));
        if (reassembler == null) {
            return null;
        }

        final byte[] body = reassembler.getBodyIfComplete();
        if (body == null) {
            return null;
        }

        // Progress in the current flight means the previous one is no longer needed.
        previousInboundFlight = null;

        final int seq = next_receive_seq++;
        return new Message(seq, reassembler.getMsgType(), body);
    }

    /**
     * Blocks until the next in-order handshake message is available.
     *
     * <p>On the first call after sending, a resend timer is started at 1000 ms and a new
     * empty inbound flight is installed. The loop then checks, in order, that the record
     * layer is not closed, whether a complete message is pending, and whether the overall
     * handshake timeout has expired. Otherwise it waits on the record layer for at most the
     * remaining handshake time, further constrained by the resend timer, and never for less
     * than 1 ms. A negative result from {@code receive} triggers a resend of the outbound
     * flight; any other result is processed as a handshake record with a receive-ahead
     * window of 16.
     *
     * @return the next in-order message
     * @throws IOException a {@code TlsFatalAlert} with description 90 (user_canceled) if
     *         the record layer is closed, a {@code TlsTimeoutException} if the handshake
     *         timeout expires, or anything propagated from the record layer
     */
    private Message implReceiveMessage() throws IOException {
        long now = System.currentTimeMillis();

        if (resendTimeout == null) {
            // Switching from sending to receiving: arm the resend timer, start a new flight.
            resendMillis = 1000;
            resendTimeout = new Timeout(resendMillis, now);
            prepareInboundFlight(new Hashtable());
        }

        byte[] datagram = null;

        for (;;) {
            if (recordLayer.isClosed()) {
                throw new TlsFatalAlert(90);
            }

            final Message ready = getPendingMessage();
            if (ready != null) {
                return ready;
            }

            if (Timeout.hasExpired(handshakeTimeout, now)) {
                throw new TlsTimeoutException("Handshake timed out");
            }

            // Wait no longer than either timer allows, but at least 1 ms.
            int waitMillis = Timeout.getWaitMillis(handshakeTimeout, now);
            waitMillis = Timeout.constrainWaitMillis(waitMillis, resendTimeout, now);
            waitMillis = Math.max(1, waitMillis);

            // (Re)allocate the receive buffer only when the limit has outgrown it.
            final int receiveLimit = recordLayer.getReceiveLimit();
            if (datagram == null || datagram.length < receiveLimit) {
                datagram = new byte[receiveLimit];
            }

            final int received = recordLayer.receive(datagram, 0, receiveLimit, waitMillis);
            if (received >= 0) {
                processRecord(16, recordLayer.getReadEpoch(), datagram, 0, received);
            } else {
                resendOutboundFlight();
            }

            now = System.currentTimeMillis();
        }
    }

    /**
     * Rotates the inbound flights: every reassembler of the current flight is reset, the
     * current flight becomes the previous flight, and {@code nextFlight} becomes current.
     *
     * @param nextFlight the table to use for the next inbound flight; {@code finish()}
     *                   passes {@code null}
     */
    private void prepareInboundFlight(Hashtable nextFlight) {
        final Hashtable completedFlight = currentInboundFlight;
        resetAll(completedFlight);
        previousInboundFlight = completedFlight;
        currentInboundFlight = nextFlight;
    }

    /**
     * Parses the handshake fragments contained in one record and feeds each into the
     * matching reassembler.
     *
     * <p>Parsing stops at the first fragment that is truncated, whose offset plus length
     * exceeds the declared message length, or whose epoch is not the expected one (1 for
     * message type 20, i.e. finished; 0 for everything else). Fragments at or beyond
     * {@code next_receive_seq + windowSize} are ignored. Fragments from
     * {@code next_receive_seq} onward go to the current flight, creating a reassembler on
     * demand. Older fragments are offered to the previous flight if a reassembler exists
     * for them there. If that happened and every message of the previous flight is now
     * complete, the peer is taken to have retransmitted it: the outbound flight is resent
     * and the previous flight's reassemblers are reset.
     *
     * @param windowSize how far ahead of {@code next_receive_seq} a sequence number may be
     * @param epoch      epoch of the record the data came from
     * @param buf        buffer holding the record payload
     * @param off        offset of the payload in {@code buf}
     * @param len        payload length in bytes
     * @throws IOException propagated from parsing helpers or from resending
     */
    private void processRecord(int windowSize, int epoch, byte[] buf, int off, int len) throws IOException {
        boolean previousFlightTouched = false;

        while (len >= 12) {
            final int fragmentLength = TlsUtils.readUint24(buf, off + 9);
            final int fragmentTotal = fragmentLength + 12;
            if (len < fragmentTotal) {
                // Truncated fragment: stop here.
                break;
            }

            final int bodyLength = TlsUtils.readUint24(buf, off + 1);
            final int fragmentOffset = TlsUtils.readUint24(buf, off + 6);
            if (fragmentOffset + fragmentLength > bodyLength) {
                // Fragment reaches past the declared message length: ignore the rest.
                break;
            }

            final short msgType = TlsUtils.readUint8(buf, off + 0);
            final int expectedEpoch = (msgType == 20) ? 1 : 0;
            if (epoch != expectedEpoch) {
                break;
            }

            final int messageSeq = TlsUtils.readUint16(buf, off + 4);
            if (messageSeq >= next_receive_seq + windowSize) {
                // Too far ahead of what we are waiting for: skip this fragment.
            } else if (messageSeq >= next_receive_seq) {
                final Integer key = Integers.valueOf(messageSeq);
                DTLSReassembler target = (DTLSReassembler)currentInboundFlight.get(key);
                if (target == null) {
                    target = new DTLSReassembler(msgType, bodyLength);
                    currentInboundFlight.put(key, target);
                }
                target.contributeFragment(msgType, bodyLength, buf, off + 12, fragmentOffset, fragmentLength);
            } else if (previousInboundFlight != null) {
                // Possibly a retransmission of a message from the peer's previous flight.
                final DTLSReassembler earlier =
                    (DTLSReassembler)previousInboundFlight.get(Integers.valueOf(messageSeq));
                if (earlier != null) {
                    earlier.contributeFragment(msgType, bodyLength, buf, off + 12, fragmentOffset, fragmentLength);
                    previousFlightTouched = true;
                }
            }

            off += fragmentTotal;
            len -= fragmentTotal;
        }

        if (previousFlightTouched && checkAll(previousInboundFlight)) {
            resendOutboundFlight();
            resetAll(previousInboundFlight);
        }
    }

    /**
     * Retransmits every message of the outbound flight in order, after asking the record
     * layer to reset its write epoch, then doubles the resend interval (via
     * {@link #backOff(int)}) and restarts the resend timer with it.
     *
     * @throws IOException propagated from writing the messages
     */
    private void resendOutboundFlight() throws IOException {
        recordLayer.resetWriteEpoch();

        final Enumeration flight = outboundFlight.elements();
        while (flight.hasMoreElements()) {
            writeMessage((Message)flight.nextElement());
        }

        resendMillis = backOff(resendMillis);
        resendTimeout = new Timeout(resendMillis);
    }

    /**
     * Writes a handshake message to the record layer, splitting the body into as many
     * fragments as the record layer's send limit requires.
     *
     * <p>At least one fragment is always written, so a message with an empty body is sent
     * as a single zero-length fragment.
     *
     * @param message the message to send
     * @throws IOException a {@code TlsFatalAlert} with description 80 (internal_error) if
     *         the send limit leaves no room for body bytes after the 12-byte header;
     *         otherwise anything propagated from writing a fragment
     */
    private void writeMessage(Message message) throws IOException {
        // Room left for body bytes once the 12-byte fragment header is accounted for.
        final int maxFragment = recordLayer.getSendLimit() - 12;
        if (maxFragment < 1) {
            throw new TlsFatalAlert(80);
        }

        final int total = message.getBody().length;
        int offset = 0;
        do {
            final int chunk = Math.min(total - offset, maxFragment);
            writeHandshakeFragment(message, offset, chunk);
            offset += chunk;
        } while (offset < total);
    }

    /**
     * Serializes one fragment of a handshake message and sends it through the record layer.
     *
     * <p>Layout: type (1 byte), total body length (3), message_seq (2), fragment_offset
     * (3), fragment_length (3), followed by the selected slice of the body.
     *
     * @param message         the message being fragmented
     * @param fragment_offset offset of this fragment within the message body
     * @param fragment_length number of body bytes in this fragment
     * @throws IOException propagated from serialization or sending
     */
    private void writeHandshakeFragment(Message message, int fragment_offset, int fragment_length) throws IOException {
        final byte[] body = message.getBody();
        final RecordLayerBuffer out = new RecordLayerBuffer(12 + fragment_length);

        // Fragment header.
        TlsUtils.writeUint8(message.getType(), (OutputStream)out);
        TlsUtils.writeUint24(body.length, out);
        TlsUtils.writeUint16(message.getSeq(), out);
        TlsUtils.writeUint24(fragment_offset, out);
        TlsUtils.writeUint24(fragment_length, out);

        // Fragment payload.
        out.write(body, fragment_offset, fragment_length);

        out.sendToRecordLayer(recordLayer);
    }

    /**
     * Reports whether every reassembler in a flight holds a complete message body.
     *
     * @param inboundFlight table whose values are {@code DTLSReassembler} instances
     * @return {@code false} as soon as one reassembler returns {@code null} from
     *         {@code getBodyIfComplete()}; {@code true} otherwise (including for an empty
     *         table)
     */
    private static boolean checkAll(Hashtable inboundFlight) {
        for (Enumeration values = inboundFlight.elements(); values.hasMoreElements();) {
            final DTLSReassembler reassembler = (DTLSReassembler)values.nextElement();
            if (reassembler.getBodyIfComplete() == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Calls {@code reset()} on every reassembler stored in a flight.
     *
     * @param inboundFlight table whose values are {@code DTLSReassembler} instances; must
     *                      not be {@code null}
     */
    private static void resetAll(Hashtable inboundFlight) {
        for (Enumeration values = inboundFlight.elements(); values.hasMoreElements();) {
            final DTLSReassembler reassembler = (DTLSReassembler)values.nextElement();
            reassembler.reset();
        }
    }

    /**
     * A complete handshake message: its message sequence number, its handshake type and
     * its body. The body array is stored and returned by reference, without copying.
     */
    static class Message {
        private final int message_seq;
        private final short msg_type;
        private final byte[] body;

        /** Instances are created only by the enclosing class. */
        private Message(int seq, short type, byte[] payload) {
            message_seq = seq;
            msg_type = type;
            body = payload;
        }

        /** @return the message sequence number */
        public int getSeq() {
            return message_seq;
        }

        /** @return the handshake message type */
        public short getType() {
            return msg_type;
        }

        /** @return the message body (the internal array, not a copy) */
        public byte[] getBody() {
            return body;
        }
    }

    /**
     * A {@link ByteArrayOutputStream} that can pass its internal buffer straight to a
     * record layer without first copying it via {@code toByteArray()}.
     */
    static class RecordLayerBuffer extends ByteArrayOutputStream {
        /**
         * @param size initial capacity of the underlying buffer
         */
        RecordLayerBuffer(int size) {
            super(size);
        }

        /**
         * Sends the bytes written so far to the record layer, then drops the internal
         * buffer, so the instance must not be written to again afterwards.
         *
         * @param recordLayer destination for the buffered bytes
         * @throws IOException propagated from {@code recordLayer.send}
         */
        void sendToRecordLayer(DTLSRecordLayer recordLayer) throws IOException {
            final int length = count;
            recordLayer.send(buf, 0, length);
            buf = null;
        }
    }
}

