/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.milo.opcua.stack.transport.server.uasc;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.Timeout;
import java.io.IOException;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.milo.opcua.stack.core.Stack;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.UaSerializationException;
import org.eclipse.milo.opcua.stack.core.channel.ChannelParameters;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.ChunkDecoder;
import org.eclipse.milo.opcua.stack.core.channel.ChunkEncoder;
import org.eclipse.milo.opcua.stack.core.channel.EncodingLimits;
import org.eclipse.milo.opcua.stack.core.channel.ExceptionHandler;
import org.eclipse.milo.opcua.stack.core.channel.MessageAbortException;
import org.eclipse.milo.opcua.stack.core.channel.MessageDecodeException;
import org.eclipse.milo.opcua.stack.core.channel.MessageEncodeException;
import org.eclipse.milo.opcua.stack.core.channel.SecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.ServerSecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.HeaderDecoder;
import org.eclipse.milo.opcua.stack.core.channel.messages.ErrorMessage;
import org.eclipse.milo.opcua.stack.core.channel.messages.MessageType;
import org.eclipse.milo.opcua.stack.core.encoding.binary.OpcUaBinaryDecoder;
import org.eclipse.milo.opcua.stack.core.encoding.binary.OpcUaBinaryEncoder;
import org.eclipse.milo.opcua.stack.core.security.CertificateGroup;
import org.eclipse.milo.opcua.stack.core.security.CertificateManager;
import org.eclipse.milo.opcua.stack.core.security.CertificateValidator;
import org.eclipse.milo.opcua.stack.core.security.SecurityPolicy;
import org.eclipse.milo.opcua.stack.core.transport.TransportProfile;
import org.eclipse.milo.opcua.stack.core.types.UaMessageType;
import org.eclipse.milo.opcua.stack.core.types.builtin.ByteString;
import org.eclipse.milo.opcua.stack.core.types.builtin.DateTime;
import org.eclipse.milo.opcua.stack.core.types.builtin.StatusCode;
import org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.Unsigned;
import org.eclipse.milo.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import org.eclipse.milo.opcua.stack.core.types.structured.ChannelSecurityToken;
import org.eclipse.milo.opcua.stack.core.types.structured.EndpointDescription;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import org.eclipse.milo.opcua.stack.core.types.structured.ResponseHeader;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.DigestUtil;
import org.eclipse.milo.opcua.stack.core.util.EndpointUtil;
import org.eclipse.milo.opcua.stack.core.util.NonceUtil;
import org.eclipse.milo.opcua.stack.transport.server.ServerApplicationContext;
import org.eclipse.milo.opcua.stack.transport.server.uasc.UascServerConfig;
import org.eclipse.milo.opcua.stack.transport.server.uasc.UascServerHelloHandler;
import org.eclipse.milo.opcua.stack.transport.server.uasc.UascServerSymmetricHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side decoder for the asymmetric part of UA Secure Conversation: OpenSecureChannel
 * (OPN) and CloseSecureChannel (CLO) messages.
 *
 * <p>OPN chunks are collected until a final chunk arrives. The message is then decoded, a
 * {@link ServerSecureChannel} is issued or renewed, and an OpenSecureChannelResponse is written
 * back. When the first response has been encoded, a {@link UascServerSymmetricHandler} is
 * inserted into the pipeline in front of this handler.
 */
public class UascServerAsymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /** Channel attribute that holds the endpoint selected when a secure channel is issued. */
    static final AttributeKey<EndpointDescription> ENDPOINT_KEY = AttributeKey.valueOf("endpoint");

    /** OPC UA status code Bad_SecurityChecksFailed. */
    private static final long BAD_SECURITY_CHECKS_FAILED = 0x80130000L;
    /** OPC UA status code Bad_TcpMessageTypeInvalid. */
    private static final long BAD_TCP_MESSAGE_TYPE_INVALID = 0x807E0000L;
    /** OPC UA status code Bad_TcpSecureChannelUnknown. */
    private static final long BAD_TCP_SECURE_CHANNEL_UNKNOWN = 0x807F0000L;
    /** OPC UA status code Bad_TcpMessageTooLarge. */
    private static final long BAD_TCP_MESSAGE_TOO_LARGE = 0x80800000L;
    /** OPC UA status code Bad_ResponseTooLarge. */
    private static final long BAD_RESPONSE_TOO_LARGE = 0x80B90000L;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    private ServerSecureChannel secureChannel;
    private Timeout secureChannelTimeout;
    private boolean symmetricHandlerAdded = false;

    /** Retained OPN chunks of the message currently being reassembled. */
    private List<ByteBuf> chunkBuffers = new ArrayList<>();

    /** Security header of the first chunk of the current message; later chunks must match it. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    private final OpcUaBinaryEncoder binaryEncoder;
    private final OpcUaBinaryDecoder binaryDecoder;
    private final ChunkEncoder chunkEncoder;
    private final ChunkDecoder chunkDecoder;
    private final int maxChunkCount;
    private final int maxChunkSize;
    private final UascServerConfig config;
    private final ServerApplicationContext application;
    private final TransportProfile transportProfile;
    private final ChannelParameters channelParameters;

    UascServerAsymmetricHandler(
        UascServerConfig config,
        ServerApplicationContext application,
        TransportProfile transportProfile,
        ChannelParameters channelParameters
    ) {
        this.config = config;
        this.application = application;
        this.transportProfile = transportProfile;
        this.channelParameters = channelParameters;
        this.binaryEncoder = new OpcUaBinaryEncoder(application.getEncodingContext());
        this.binaryDecoder = new OpcUaBinaryDecoder(application.getEncodingContext());
        this.chunkEncoder = new ChunkEncoder(channelParameters);
        this.chunkDecoder = new ChunkDecoder(channelParameters, application.getEncodingContext().getEncodingLimits());
        this.maxChunkCount = channelParameters.getLocalMaxChunkCount();
        this.maxChunkSize = channelParameters.getLocalReceiveBufferSize();
    }

    /** Cancels any pending secure channel lifetime timeout before propagating the event. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        cancelSecureChannelTimeout();
        super.channelInactive(ctx);
    }

    /**
     * Releases any partially reassembled chunks.
     *
     * <p>For an {@link IOException} the channel is closed. For any other cause an error message
     * is sent via {@link ExceptionHandler#sendErrorMessage}.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        releaseChunkBuffers();

        if (cause instanceof IOException) {
            ctx.close();
            logger.debug("[remote={}] IOException caught; channel closed", ctx.channel().remoteAddress(), cause);
        } else {
            ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

            if (cause instanceof UaException) {
                logger.debug("[remote={}] UaException caught; sent {}", ctx.channel().remoteAddress(), errorMessage, cause);
            } else {
                logger.error("[remote={}] Exception caught; sent {}", ctx.channel().remoteAddress(), errorMessage, cause);
            }
        }
    }

    /**
     * Waits for a complete message and dispatches it by message type.
     *
     * @throws UaException if the message type is neither OpenSecureChannel nor CloseSecureChannel,
     *     or if processing the OPN chunk fails.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        // Wait for the 8-byte header: 3-byte message type, 1-byte chunk type, 4-byte message size.
        if (buffer.readableBytes() < 8) {
            return;
        }

        int messageLength = getMessageLength(buffer, maxChunkSize);
        if (buffer.readableBytes() < messageLength) {
            return;
        }

        MessageType messageType = MessageType.fromMediumInt(buffer.getMediumLE(buffer.readerIndex()));

        switch (messageType) {
            case OpenSecureChannel:
                onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
                break;

            case CloseSecureChannel:
                logger.debug("Received CloseSecureChannelRequest");
                buffer.skipBytes(messageLength);
                cancelSecureChannelTimeout();
                ctx.close();
                break;

            default:
                throw new UaException(BAD_TCP_MESSAGE_TYPE_INVALID, "unexpected MessageType: " + messageType);
        }
    }

    /**
     * Processes a single OPN chunk.
     *
     * <ul>
     *   <li>An abort chunk ('A') discards the partially received message.</li>
     *   <li>Any other chunk is validated and retained.</li>
     *   <li>A final chunk ('F') additionally triggers decoding of the whole message.</li>
     * </ul>
     *
     * @param buffer a slice containing exactly one chunk.
     * @throws UaException if the security header, channel id, endpoint, or size limits are invalid.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        buffer.skipBytes(3); // message type
        char chunkType = (char) buffer.readByte();

        if (chunkType == 'A') {
            releaseChunkBuffers();
            headerRef.set(null);
            return;
        }

        buffer.skipBytes(4); // message size
        long secureChannelId = buffer.readUnsignedIntLE();

        AsymmetricSecurityHeader header = AsymmetricSecurityHeader.decode(
            buffer,
            application.getEncodingContext().getEncodingLimits()
        );

        // The first chunk of a message sets the header; every later chunk must carry an equal one.
        if (!headerRef.compareAndSet(null, header) && !header.equals(headerRef.get())) {
            throw new UaException(BAD_SECURITY_CHECKS_FAILED, "subsequent AsymmetricSecurityHeader did not match");
        }

        // A non-zero id must refer to the channel already established on this connection.
        if (secureChannelId != 0L && (secureChannel == null || secureChannelId != secureChannel.getChannelId())) {
            throw new UaException(BAD_TCP_SECURE_CHANNEL_UNKNOWN, "unknown secure channel id: " + secureChannelId);
        }

        if (secureChannel == null) {
            // Assign only after the channel is fully configured and the certificates are validated,
            // so a failed validation never leaves a half-initialized channel behind.
            secureChannel = createSecureChannel(header);
        }

        String endpointUrl = getEndpointUrl(ctx);
        boolean anyEndpointMatches = application.getEndpointDescriptions().stream().anyMatch(
            e -> matchesTransportPathAndThumbprint(e, endpointUrl, header)
                && (matchesSecurityPolicy(e) || isSecurityPolicyNone())
        );
        if (!anyEndpointMatches) {
            String message = String.format(
                "no matching endpoint found: transportProfile=%s, endpointUrl=%s, securityPolicy=%s",
                transportProfile, endpointUrl, secureChannel.getSecurityPolicy());
            throw new UaException(BAD_SECURITY_CHECKS_FAILED, message);
        }

        int chunkSize = buffer.readerIndex(0).readableBytes();
        if (chunkSize > maxChunkSize) {
            throw new UaException(BAD_TCP_MESSAGE_TOO_LARGE, String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        chunkBuffers.add(buffer.retain());

        if (maxChunkCount > 0 && chunkBuffers.size() > maxChunkCount) {
            throw new UaException(BAD_TCP_MESSAGE_TOO_LARGE, String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            // Hand the collected chunks off and start a fresh message.
            List<ByteBuf> buffersToDecode = chunkBuffers;
            chunkBuffers = new ArrayList<>();
            headerRef.set(null);

            onFinalChunk(ctx, secureChannelId, header, buffersToDecode);
        }
    }

    /**
     * Decodes the reassembled OPN message and sends the response.
     *
     * <p>Decode and response failures are logged and handled here. They are not propagated.
     */
    private void onFinalChunk(
        ChannelHandlerContext ctx,
        long secureChannelId,
        AsymmetricSecurityHeader header,
        List<ByteBuf> buffersToDecode
    ) throws UaException {

        ByteBuf message;
        long requestId;

        try {
            ChunkDecoder.DecodedMessage decodedMessage = chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode);
            message = decodedMessage.getMessage();
            requestId = decodedMessage.getRequestId();
        } catch (MessageAbortException e) {
            logger.warn("Received message abort chunk; error={}, reason={}", e.getStatusCode(), e.getMessage());
            return;
        } catch (MessageDecodeException e) {
            logger.error("Error decoding asymmetric message", e);
            // Close after a random delay of up to 1s.
            // NOTE(review): presumably to avoid a timing side channel on decode failures; confirm.
            long delayMillis = ThreadLocalRandom.current().nextInt(1000);
            ctx.executor().schedule(() -> ctx.close(), delayMillis, TimeUnit.MILLISECONDS);
            return;
        }

        try {
            OpenSecureChannelRequest request =
                (OpenSecureChannelRequest) binaryDecoder.setBuffer(message).decodeMessage(null);

            logger.debug("Received OpenSecureChannelRequest ({}, id={}).", request.getRequestType(), secureChannelId);

            if (request.getRequestType() == SecurityTokenRequestType.Renew && secureChannelId == 0L) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "secure channel renewal for secureChannelId=0");
            }

            sendOpenSecureChannelResponse(ctx, requestId, header, request);
        } catch (Throwable t) {
            logger.error("Error decoding OpenSecureChannelRequest", t);
            ctx.close();
        } finally {
            message.release();
            buffersToDecode.clear();
        }
    }

    /**
     * Builds a new secure channel from the first OPN security header.
     *
     * <p>For any policy other than {@link SecurityPolicy#None}, the method looks up the local
     * certificate chain and key pair by receiver thumbprint and validates the sender's
     * certificate chain.
     *
     * @throws UaException if the policy URI is unknown, no certificate or certificate group
     *     matches the thumbprint, or the remote certificate chain fails validation.
     */
    private ServerSecureChannel createSecureChannel(AsymmetricSecurityHeader header) throws UaException {
        ServerSecureChannel channel = new ServerSecureChannel();
        channel.setChannelId(application.getNextSecureChannelId().longValue());

        SecurityPolicy securityPolicy = SecurityPolicy.fromUri(header.getSecurityPolicyUri());
        channel.setSecurityPolicy(securityPolicy);

        if (securityPolicy != SecurityPolicy.None) {
            ByteString thumbprint = header.getReceiverThumbprint();
            CertificateManager certificateManager = application.getCertificateManager();

            Optional<X509Certificate[]> localCertificateChain = certificateManager.getCertificateChain(thumbprint);
            Optional<KeyPair> keyPair = certificateManager.getKeyPair(thumbprint);

            if (!localCertificateChain.isPresent() || !keyPair.isPresent()) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "no certificate for provided thumbprint");
            }

            channel.setRemoteCertificate(header.getSenderCertificate().bytesOrEmpty());

            CertificateGroup certificateGroup = certificateManager.getCertificateGroup(thumbprint)
                .orElseThrow(() -> new UaException(BAD_SECURITY_CHECKS_FAILED, "no certificate group for provided thumbprint"));

            CertificateValidator certificateValidator = certificateGroup.getCertificateValidator();
            certificateValidator.validateCertificateChain(channel.getRemoteCertificateChain(), null, null);

            X509Certificate[] chain = localCertificateChain.get();
            channel.setLocalCertificate(chain[0]);
            channel.setLocalCertificateChain(chain);
            channel.setKeyPair(keyPair.get());
        }

        return channel;
    }

    /**
     * Encodes and writes the OpenSecureChannelResponse for {@code request}.
     *
     * <p>The symmetric handler is installed before the first write. Encoding and serialization
     * failures are fired down the pipeline. A {@link UaException} from
     * {@link #openSecureChannel} closes the channel.
     */
    private void sendOpenSecureChannelResponse(
        ChannelHandlerContext ctx,
        long requestId,
        AsymmetricSecurityHeader header,
        OpenSecureChannelRequest request
    ) {

        ByteBuf messageBuffer = BufferUtil.pooledBuffer();

        try {
            OpenSecureChannelResponse response = openSecureChannel(ctx, header, request);

            binaryEncoder.setBuffer(messageBuffer);
            binaryEncoder.encodeMessage(null, (UaMessageType) response);

            checkMessageSize(messageBuffer);

            ChunkEncoder.EncodedMessage encodedMessage = chunkEncoder.encodeAsymmetric(
                secureChannel,
                requestId,
                messageBuffer,
                MessageType.OpenSecureChannel
            );

            if (!symmetricHandlerAdded) {
                UascServerSymmetricHandler symmetricHandler = new UascServerSymmetricHandler(
                    config,
                    application,
                    transportProfile,
                    channelParameters,
                    chunkEncoder,
                    chunkDecoder,
                    secureChannel
                );

                ctx.pipeline().addBefore(ctx.name(), null, symmetricHandler);
                symmetricHandlerAdded = true;
            }

            CompositeByteBuf chunkComposite = BufferUtil.compositeBuffer();
            for (ByteBuf chunk : encodedMessage.getMessageChunks()) {
                // true: advance the composite's writerIndex by the chunk's readable bytes
                chunkComposite.addComponent(true, chunk);
            }

            ctx.writeAndFlush(chunkComposite, ctx.voidPromise());

            logger.debug("Sent OpenSecureChannelResponse.");
        } catch (MessageEncodeException e) {
            logger.error("Error encoding OpenSecureChannelResponse: {}", e.getMessage(), e);
            ctx.fireExceptionCaught(e);
        } catch (UaSerializationException e) {
            logger.error("Error serializing OpenSecureChannelResponse: {}", e.getMessage(), e);
            ctx.fireExceptionCaught(e);
        } catch (UaException e) {
            logger.error("Error installing security token: {}", e.getStatusCode(), e);
            ctx.close();
        } finally {
            messageBuffer.release();
        }
    }

    /**
     * Issues or renews the channel's security token and builds the response.
     *
     * <p>The method does the following:
     * <ul>
     *   <li>On Issue, it selects and stores the matching endpoint. On Renew, the requested
     *       security mode must not change.</li>
     *   <li>It clamps the requested lifetime to the configured min/max.</li>
     *   <li>It derives new keys when symmetric signing is enabled.</li>
     *   <li>It keeps the previous keys and token as the "old" secrets.</li>
     *   <li>It (re)schedules the lifetime timeout that closes the channel.</li>
     * </ul>
     *
     * @throws UaException if no endpoint matches, the security mode changes on renewal, or
     *     nonce validation/generation fails.
     */
    private OpenSecureChannelResponse openSecureChannel(
        ChannelHandlerContext ctx,
        AsymmetricSecurityHeader header,
        OpenSecureChannelRequest request
    ) throws UaException {

        SecurityTokenRequestType requestType = request.getRequestType();

        if (requestType == SecurityTokenRequestType.Issue) {
            secureChannel.setMessageSecurityMode(request.getSecurityMode());

            EndpointDescription endpoint = selectEndpoint(ctx, header, request);
            ctx.channel().attr(ENDPOINT_KEY).set(endpoint);
        } else if (requestType == SecurityTokenRequestType.Renew
            && secureChannel.getMessageSecurityMode() != request.getSecurityMode()) {

            throw new UaException(BAD_SECURITY_CHECKS_FAILED, "secure channel renewal requested a different MessageSecurityMode.");
        }

        // Clamp the requested lifetime to [minimum, maximum].
        long channelLifetime = Math.max(
            Math.min(request.getRequestedLifetime().longValue(), config.getMaximumSecureChannelLifetime().longValue()),
            config.getMinimumSecureChannelLifetime().longValue()
        );

        ChannelSecurityToken newToken = new ChannelSecurityToken(
            Unsigned.uint((long) secureChannel.getChannelId()),
            Unsigned.uint((long) application.getNextSecureChannelTokenId()),
            DateTime.now(),
            Unsigned.uint(channelLifetime)
        );

        ChannelSecurity.SecurityKeys newKeys = null;

        if (secureChannel.isSymmetricSigningEnabled()) {
            ByteString remoteNonce = request.getClientNonce();
            NonceUtil.validateNonce(remoteNonce, secureChannel.getSecurityPolicy());

            ByteString localNonce = NonceUtil.generateNonce(secureChannel.getSecurityPolicy());

            secureChannel.setLocalNonce(localNonce);
            secureChannel.setRemoteNonce(remoteNonce);

            newKeys = ChannelSecurity.generateKeyPair(
                secureChannel,
                secureChannel.getRemoteNonce(),
                secureChannel.getLocalNonce()
            );
        }

        ChannelSecurity oldSecrets = secureChannel.getChannelSecurity();
        ChannelSecurity.SecurityKeys oldKeys = oldSecrets != null ? oldSecrets.getCurrentKeys() : null;
        ChannelSecurityToken oldToken = oldSecrets != null ? oldSecrets.getCurrentToken() : null;

        secureChannel.setChannelSecurity(new ChannelSecurity(newKeys, newToken, oldKeys, oldToken));

        // Replace the pending timeout only if it could still be cancelled. If it has already
        // fired, the channel is being closed.
        if (secureChannelTimeout == null || secureChannelTimeout.cancel()) {
            secureChannelTimeout = Stack.sharedWheelTimer().newTimeout(timeout -> {
                logger.debug(
                    "SecureChannel renewal timed out after {}ms. id={}, channel={}",
                    channelLifetime, secureChannel.getChannelId(), ctx.channel());
                ctx.close();
            }, channelLifetime, TimeUnit.MILLISECONDS);
        }

        ResponseHeader responseHeader = new ResponseHeader(
            DateTime.now(),
            request.getRequestHeader().getRequestHandle(),
            StatusCode.GOOD,
            null,
            null,
            null
        );

        return new OpenSecureChannelResponse(
            responseHeader,
            Unsigned.uint(0L),
            newToken,
            secureChannel.getLocalNonce()
        );
    }

    /**
     * Returns the first endpoint matching the transport profile, the endpoint URL path, the
     * receiver thumbprint (if one is present), and either the channel's security policy plus the
     * requested security mode, or a None policy.
     *
     * @throws UaException if no endpoint matches.
     */
    private EndpointDescription selectEndpoint(
        ChannelHandlerContext ctx,
        AsymmetricSecurityHeader header,
        OpenSecureChannelRequest request
    ) throws UaException {

        String endpointUrl = getEndpointUrl(ctx);

        return application.getEndpointDescriptions().stream()
            .filter(e -> matchesTransportPathAndThumbprint(e, endpointUrl, header)
                && (matchesSecurityPolicy(e) && Objects.equals(e.getSecurityMode(), request.getSecurityMode())
                    || isSecurityPolicyNone()))
            .findFirst()
            .orElseThrow(() -> {
                String message = String.format(
                    "no matching endpoint found: transportProfile=%s, endpointUrl=%s, thumbprint=%s, securityPolicy=%s, securityMode=%s",
                    transportProfile, endpointUrl, header.getReceiverThumbprint(),
                    secureChannel.getSecurityPolicy(), request.getSecurityMode());
                return new UaException(BAD_SECURITY_CHECKS_FAILED, message);
            });
    }

    /**
     * Returns true if {@code endpoint} matches this handler's transport profile and the path of
     * {@code endpointUrl}. If the header carries a receiver thumbprint, the SHA-1 of the
     * endpoint's server certificate must also equal that thumbprint.
     */
    private boolean matchesTransportPathAndThumbprint(
        EndpointDescription endpoint,
        String endpointUrl,
        AsymmetricSecurityHeader header
    ) {

        if (!Objects.equals(endpoint.getTransportProfileUri(), transportProfile.getUri())) {
            return false;
        }
        if (!Objects.equals(EndpointUtil.getPath(endpoint.getEndpointUrl()), EndpointUtil.getPath(endpointUrl))) {
            return false;
        }

        ByteString thumbprint = header.getReceiverThumbprint();
        return thumbprint.isNullOrEmpty()
            || Arrays.equals(DigestUtil.sha1(endpoint.getServerCertificate().bytesOrEmpty()), thumbprint.bytesOrEmpty());
    }

    /** Returns true if {@code endpoint}'s security policy URI equals the channel's policy URI. */
    private boolean matchesSecurityPolicy(EndpointDescription endpoint) {
        return Objects.equals(endpoint.getSecurityPolicyUri(), secureChannel.getSecurityPolicy().getUri());
    }

    /** Returns true if the channel uses {@link SecurityPolicy#None}. */
    private boolean isSecurityPolicyNone() {
        return secureChannel.getSecurityPolicy() == SecurityPolicy.None;
    }

    /** Reads the endpoint URL that {@link UascServerHelloHandler} stored on the channel. */
    private static String getEndpointUrl(ChannelHandlerContext ctx) {
        return ctx.channel().attr(UascServerHelloHandler.ENDPOINT_URL_KEY).get();
    }

    /** Releases (without throwing on already-released buffers) and forgets all pending chunks. */
    private void releaseChunkBuffers() {
        chunkBuffers.forEach(ReferenceCountUtil::safeRelease);
        chunkBuffers.clear();
    }

    /** Cancels and clears the pending secure channel lifetime timeout, if any. */
    private void cancelSecureChannelTimeout() {
        if (secureChannelTimeout != null) {
            secureChannelTimeout.cancel();
            secureChannelTimeout = null;
        }
    }

    /**
     * Rejects a response larger than the peer's advertised maximum message size. A maximum of 0
     * or less means no limit.
     *
     * @throws UaSerializationException with Bad_ResponseTooLarge if the limit is exceeded.
     */
    private void checkMessageSize(ByteBuf messageBuffer) throws UaSerializationException {
        int messageSize = messageBuffer.readableBytes();
        int remoteMaxMessageSize = channelParameters.getRemoteMaxMessageSize();

        if (remoteMaxMessageSize > 0 && messageSize > remoteMaxMessageSize) {
            throw new UaSerializationException(
                BAD_RESPONSE_TOO_LARGE,
                "response exceeds remote max message size: " + messageSize + " > " + remoteMaxMessageSize);
        }
    }
}

