/*
 * Decompiled with CFR 0.152.
 */
package org.apache.plc4x.java.opcua.context;

import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import org.apache.plc4x.java.opcua.context.Conversation;
import org.apache.plc4x.java.opcua.protocol.chunk.Chunk;
import org.apache.plc4x.java.opcua.protocol.chunk.PayloadConverter;
import org.apache.plc4x.java.opcua.readwrite.ChunkType;
import org.apache.plc4x.java.opcua.readwrite.MessagePDU;
import org.apache.plc4x.java.opcua.security.SecurityPolicy;
import org.apache.plc4x.java.spi.generation.ByteOrder;
import org.apache.plc4x.java.spi.generation.SerializationException;
import org.apache.plc4x.java.spi.generation.WithWriterArgs;
import org.apache.plc4x.java.spi.generation.WriteBuffer;
import org.apache.plc4x.java.spi.generation.WriteBufferByteBased;

/**
 * Base class that turns a serialized {@link MessagePDU} into secure-channel chunks and back.
 *
 * <p>{@link #encodeMessage} splits one message into chunks of at most {@code chunk.getMaxBodySize()}
 * body bytes. When the {@link Chunk} settings require it, it also adds padding, a signature and
 * encryption to each chunk. {@link #decodeMessage} reverses this for a single received chunk.
 * Subclasses supply the cryptographic primitives ({@link #sign}, {@link #verify}, {@link #encrypt},
 * {@link #decrypt}).
 */
abstract class BaseEncryptionHandler {
    /**
     * Size in bytes of the fixed message header. The chunk type character is rewritten at offset 3
     * and the 32-bit total length at offset 4.
     */
    protected static final int SECURE_MESSAGE_HEADER_SIZE = 12;
    /** Size in bytes of the sequence header. Its first 32 bits hold the sequence number. */
    protected static final int SEQUENCE_HEADER_SIZE = 8;
    protected final Conversation conversation;
    protected final SecurityPolicy securityPolicy;

    public BaseEncryptionHandler(Conversation conversation, SecurityPolicy securityPolicy) {
        this.conversation = conversation;
        this.securityPolicy = securityPolicy;
    }

    /**
     * Splits {@code message} into one or more chunks, padding/signing/encrypting each as configured.
     *
     * <p>The message header, security header and sequence header of the original message are copied
     * into every chunk. The total length is then patched, and the chunk type is set to CONTINUE for
     * all but the last chunk. Every chunk after the first gets a fresh sequence number from
     * {@code sequenceSupplier}.
     *
     * @param chunk            chunk layout and security settings
     * @param message          the complete message to split
     * @param sequenceSupplier source of sequence numbers for the second and later chunks
     * @return the encoded chunks in order (empty if the message has no body after the headers)
     * @throws RuntimeException wrapping any serialization or cryptographic failure
     */
    public final List<MessagePDU> encodeMessage(Chunk chunk, MessagePDU message, Supplier<Integer> sequenceSupplier) {
        try {
            ByteBuffer messageBuffer = ByteBuffer.wrap(PayloadConverter.toStream(message));
            byte[] messageHeader = new byte[SECURE_MESSAGE_HEADER_SIZE];
            messageBuffer.get(messageHeader);
            byte[] securityHeader = new byte[chunk.getSecurityHeaderSize()];
            messageBuffer.get(securityHeader);
            byte[] sequenceHeader = new byte[SEQUENCE_HEADER_SIZE];
            messageBuffer.get(sequenceHeader);
            ByteBuffer bodyBuffer = messageBuffer.slice();

            List<MessagePDU> messages = new ArrayList<>();
            boolean first = true;
            while (bodyBuffer.hasRemaining()) {
                int bodySize = Math.min(bodyBuffer.remaining(), chunk.getMaxBodySize());
                int paddingSize = computePaddingSize(chunk, bodySize);
                int plainTextContentSize = SEQUENCE_HEADER_SIZE + bodySize + chunk.getSignatureSize()
                    + paddingSize + chunk.getPaddingOverhead();
                if (chunk.isEncrypted()) {
                    assert plainTextContentSize % chunk.getPlainTextBlockSize() == 0;
                }
                int blockCount = plainTextContentSize / chunk.getPlainTextBlockSize();
                int chunkSize = SECURE_MESSAGE_HEADER_SIZE + chunk.getSecurityHeaderSize()
                    + blockCount * chunk.getCipherTextBlockSize();

                WriteBufferByteBased chunkBuffer = new WriteBufferByteBased(chunkSize, ByteOrder.LITTLE_ENDIAN);
                chunkBuffer.writeByteArray("messageHeader", messageHeader);
                chunkBuffer.writeByteArray("securityHeader", securityHeader);
                chunkBuffer.writeByteArray("sequenceHeader", sequenceHeader);
                updateFrameSize(chunkBuffer, chunkSize);
                ChunkType chunkType = bodyBuffer.remaining() > bodySize ? ChunkType.CONTINUE : ChunkType.FINAL;
                updateFrame(first, chunkBuffer, chunk, chunkType, sequenceSupplier);
                first = false;

                byte[] chunkContents = new byte[bodySize];
                bodyBuffer.get(chunkContents);
                chunkBuffer.writeByteArray("payload", chunkContents);
                if (chunk.isEncrypted()) {
                    writePadding(chunkBuffer, chunk, paddingSize);
                }
                if (chunk.isSigned()) {
                    // The signature covers everything written so far: headers, body and padding.
                    byte[] signatureData = sign(chunkBuffer.getBytes(0, chunkBuffer.getPos()));
                    chunkBuffer.writeByteArray("signature", signatureData);
                }
                if (chunk.isEncrypted()) {
                    encrypt(chunkBuffer, chunk.getSecurityHeaderSize(), chunk.getPlainTextBlockSize(),
                        chunk.getCipherTextBlockSize(), blockCount);
                }
                messages.add(PayloadConverter.pduFromStream(chunkBuffer.getBytes(), message.getResponse()));
            }
            return messages;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes how many padding bytes are needed so the plaintext of an encrypted chunk is a whole
     * number of plaintext blocks. Returns 0 for unencrypted chunks.
     */
    private static int computePaddingSize(Chunk chunk, int bodySize) {
        if (!chunk.isEncrypted()) {
            return 0;
        }
        int plainTextSize = SEQUENCE_HEADER_SIZE + bodySize + chunk.getPaddingOverhead() + chunk.getSignatureSize();
        int gap = plainTextSize % chunk.getPlainTextBlockSize();
        return gap > 0 ? chunk.getPlainTextBlockSize() - gap : 0;
    }

    /**
     * Writes {@code paddingSize + paddingOverhead} bytes at the current position, each holding the
     * low byte of {@code paddingSize}. When the overhead is larger than one byte, the last byte
     * written is then replaced with the most significant byte of {@code paddingSize}.
     */
    private static void writePadding(WriteBufferByteBased chunkBuffer, Chunk chunk, int paddingSize)
        throws SerializationException {
        int paddingBytes = paddingSize + chunk.getPaddingOverhead();
        for (int index = 0; index < paddingBytes; index++) {
            chunkBuffer.writeByte("padding", (byte) paddingSize);
        }
        if (chunk.getPaddingOverhead() > 1) {
            // Step back relative to the current position. An absolute body-relative offset would
            // ignore the headers and overwrite payload bytes, and would also misplace the signature.
            chunkBuffer.setPos(chunkBuffer.getPos() - 1);
            chunkBuffer.writeByte("paddingMSB", (byte) ((paddingSize >> 8) & 0xFF));
        }
    }

    /**
     * Decrypts and/or verifies a single received chunk and strips its security overhead.
     *
     * <p>Chunks that are neither encrypted nor signed are returned unchanged. Otherwise the chunk is
     * re-serialized, decrypted/verified by the subclass, and its padding is validated. The total
     * length field is then reduced by the padding, signature and encryption overhead before the
     * chunk is parsed again.
     *
     * @param chunk   chunk layout and security settings
     * @param message the received chunk
     * @return the decoded chunk
     * @throws RuntimeException wrapping any failure. Malformed padding is reported as a wrapped
     *                          {@link IllegalArgumentException}.
     */
    public final MessagePDU decodeMessage(Chunk chunk, MessagePDU message) {
        try {
            if (!chunk.isEncrypted() && !chunk.isSigned()) {
                return message;
            }
            int messageLength = message.getLengthInBytes();
            WriteBufferByteBased chunkBuffer = new WriteBufferByteBased(messageLength, ByteOrder.LITTLE_ENDIAN);
            message.serialize(chunkBuffer);

            int payloadStart = SECURE_MESSAGE_HEADER_SIZE + chunk.getSecurityHeaderSize();
            int bodySize = messageLength - payloadStart;
            if (chunk.isEncrypted()) {
                // decrypt() reports the size of the plaintext starting at payloadStart.
                bodySize = decrypt(chunkBuffer, chunk, messageLength);
            }
            if (chunk.isSigned()) {
                verify(chunkBuffer, chunk, messageLength);
            }

            int encryptionOverhead = getEncryptionOverhead(chunk, messageLength);
            int paddingSize = getPaddingSize(chunkBuffer, chunk, messageLength);
            // payloadEnd is where the padding region begins (just after the message body).
            int payloadEnd = payloadStart + bodySize - paddingSize - chunk.getSignatureSize() - chunk.getPaddingOverhead();
            int expectedPaddingSize = messageLength - payloadEnd - chunk.getSignatureSize()
                - encryptionOverhead - chunk.getPaddingOverhead();
            if (paddingSize != expectedPaddingSize) {
                throw new IllegalArgumentException("Malformed data detected - expected padding size do not match");
            }

            if (chunk.isEncrypted()) {
                byte paddingByte = (byte) (paddingSize & 0xFF);
                byte[] paddingBytes = chunkBuffer.getBytes(payloadEnd, payloadEnd + expectedPaddingSize);
                for (int index = 0; index < paddingBytes.length; index++) {
                    if (paddingBytes[index] != paddingByte) {
                        throw new IllegalArgumentException("Malformed padding byte at index " + index);
                    }
                }
            }

            int overhead = paddingSize + chunk.getSignatureSize() + chunk.getPaddingOverhead() + encryptionOverhead;
            updateFrameSize(chunkBuffer, messageLength - overhead);
            return PayloadConverter.pduFromStream(chunkBuffer.getBytes(), message.getResponse());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Patches the chunk type (for non-FINAL chunks) and, for every chunk but the first, the sequence
     * number. Leaves the write position just past the sequence header.
     */
    private void updateFrame(boolean first, WriteBufferByteBased messageBuffer, Chunk chunk, ChunkType chunkType,
        Supplier<Integer> sequenceSupplier) throws SerializationException {
        int payloadStart = SECURE_MESSAGE_HEADER_SIZE + chunk.getSecurityHeaderSize();
        if (chunkType != ChunkType.FINAL) {
            messageBuffer.setPos(3);
            messageBuffer.writeString("chunkType", 8, chunkType.getValue());
        }
        if (!first) {
            messageBuffer.setPos(payloadStart);
            messageBuffer.writeUnsignedLong("sequenceId", 32, sequenceSupplier.get().longValue());
        }
        messageBuffer.setPos(payloadStart + SEQUENCE_HEADER_SIZE);
    }

    /** Overwrites the 32-bit total length at offset 4, restoring the write position afterwards. */
    private void updateFrameSize(WriteBufferByteBased messageBuffer, long frameSize) throws SerializationException {
        int position = messageBuffer.getPos();
        try {
            messageBuffer.setPos(4);
            messageBuffer.writeUnsignedLong("totalLength", 32, frameSize);
        } finally {
            messageBuffer.setPos(position);
        }
    }

    /**
     * Returns how many bytes encryption added to the chunk: (cipher block size - plain block size)
     * per cipher block after the headers. Returns 0 for unencrypted chunks.
     */
    private int getEncryptionOverhead(Chunk chunk, int messageLength) {
        if (!chunk.isEncrypted()) {
            return 0;
        }
        int bodyStart = SECURE_MESSAGE_HEADER_SIZE + chunk.getSecurityHeaderSize();
        int bodySize = messageLength - bodyStart;
        int blockCount = bodySize / chunk.getCipherTextBlockSize();
        return chunk.getCipherTextBlockSize() * blockCount - chunk.getPlainTextBlockSize() * blockCount;
    }

    /**
     * Reads the padding size from the last {@code paddingOverhead} bytes before the signature of a
     * decrypted chunk. The first of these bytes is the low byte. A second byte, when present, is the
     * most significant byte. Bytes are read as unsigned, so sizes of 128 and above decode correctly.
     * Returns 0 for unencrypted chunks.
     */
    private int getPaddingSize(WriteBufferByteBased chunkBuffer, Chunk chunk, int messageLength) {
        if (!chunk.isEncrypted()) {
            return 0;
        }
        int paddingEnd = messageLength - chunk.getSignatureSize() - getEncryptionOverhead(chunk, messageLength)
            - chunk.getPaddingOverhead();
        byte[] padding = chunkBuffer.getBytes(paddingEnd, paddingEnd + chunk.getPaddingOverhead());
        int lowByte = padding[0] & 0xFF;
        if (padding.length > 1) {
            return (padding[1] & 0xFF) << 8 | lowByte;
        }
        return lowByte;
    }

    /**
     * Verifies the signature of a received chunk held in {@code chunkBuffer}.
     * NOTE(review): presumably throws on a signature mismatch; confirm in subclasses.
     */
    protected abstract void verify(WriteBufferByteBased chunkBuffer, Chunk chunk, int messageLength) throws Exception;

    /**
     * Decrypts a received chunk in place. Returns the plaintext size counted from the end of the
     * security header; {@link #decodeMessage} uses this value as its body size.
     */
    protected abstract int decrypt(WriteBufferByteBased chunkBuffer, Chunk chunk, int messageLength) throws Exception;

    /**
     * Encrypts an outgoing chunk in place. Called after padding and signature have been written.
     *
     * @param chunkBuffer         the fully assembled plaintext chunk
     * @param securityHeaderSize  size of the security header that follows the message header
     * @param plainTextBlockSize  plaintext block size of the cipher
     * @param cipherTextBlockSize ciphertext block size of the cipher
     * @param blockCount          number of plaintext blocks to encrypt
     */
    protected abstract void encrypt(WriteBufferByteBased chunkBuffer, int securityHeaderSize, int plainTextBlockSize,
        int cipherTextBlockSize, int blockCount) throws Exception;

    /** Signs {@code data}, which is every byte of the chunk written before the signature. */
    protected abstract byte[] sign(byte[] data) throws GeneralSecurityException;
}

