/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.scram.internals;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import javax.security.sasl.SaslException;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.scram.internals.ScramMessages;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests parsing and round-tripping of the four SCRAM handshake messages defined in
 * {@link ScramMessages}: client-first, server-first, client-final and server-final.
 *
 * <p>Each message type has a "valid" test (construct directly, parse from a string,
 * re-parse from {@code toBytes()}, and parse with reserved/extension attributes) and an
 * "invalid" test (every malformed string must be rejected with a {@link SaslException}).
 */
public class ScramMessagesTest {
    /** Extension attribute strings that the message parsers are expected to accept. */
    private static final String[] VALID_EXTENSIONS = new String[]{"ext=val1", "anotherext=name1=value1 name2=another test value \"'!$[]()", "first=val1,second=name1 = value ,third=123"};
    /** Extension attribute strings that the message parsers are expected to reject. */
    private static final String[] INVALID_EXTENSIONS = new String[]{"ext1=value", "ext", "ext=value1,value2", "ext=,", "ext =value"};
    /** Reserved ({@code m=...}) attribute strings that the message parsers are expected to accept. */
    private static final String[] VALID_RESERVED = new String[]{"m=reserved-value", "m=name1=value1 name2=another test value \"'!$[]()"};
    /** Reserved ({@code m=...}) attribute strings that the message parsers are expected to reject. */
    private static final String[] INVALID_RESERVED = new String[]{"m", "m=name,value", "m=,"};

    /** Source of random nonces and random bytes; also used to unescape SASL names. */
    private ScramFormatter formatter;

    /** Creates a fresh formatter for {@link ScramMechanism#SCRAM_SHA_256} before each test. */
    @Before
    public void setUp() throws Exception {
        formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256);
    }

    /** Verifies that well-formed client-first messages are accepted and expose the expected fields. */
    @Test
    public void validClientFirstMessage() throws SaslException {
        String nonce = formatter.secureRandomString();

        // Built directly from its components.
        ScramMessages.ClientFirstMessage m = new ScramMessages.ClientFirstMessage("someuser", nonce, Collections.emptyMap());
        checkClientFirstMessage(m, "someuser", nonce, "");

        // Parsed from its string form, then re-parsed from its own serialized bytes.
        String str = String.format("n,,n=testuser,r=%s", nonce);
        m = createScramMessage(ScramMessages.ClientFirstMessage.class, str);
        checkClientFirstMessage(m, "testuser", nonce, "");
        m = new ScramMessages.ClientFirstMessage(m.toBytes());
        checkClientFirstMessage(m, "testuser", nonce, "");

        // "=2C" stays escaped in saslName() and is turned into ',' by formatter.username().
        str = String.format("n,,n=test=2Cuser,r=%s", nonce);
        m = createScramMessage(ScramMessages.ClientFirstMessage.class, str);
        checkClientFirstMessage(m, "test=2Cuser", nonce, "");
        Assert.assertEquals("test,user", formatter.username(m.saslName()));

        // "=3D" stays escaped in saslName() and is turned into '=' by formatter.username().
        str = String.format("n,,n=test=3Duser,r=%s", nonce);
        m = createScramMessage(ScramMessages.ClientFirstMessage.class, str);
        checkClientFirstMessage(m, "test=3Duser", nonce, "");
        Assert.assertEquals("test=user", formatter.username(m.saslName()));

        // Optional authorization id.
        str = String.format("n,a=testauthzid,n=testuser,r=%s", nonce);
        checkClientFirstMessage(createScramMessage(ScramMessages.ClientFirstMessage.class, str), "testuser", nonce, "testauthzid");

        // Optional reserved attribute placed before the user name.
        for (String reserved : VALID_RESERVED) {
            str = String.format("n,,%s,n=testuser,r=%s", reserved, nonce);
            checkClientFirstMessage(createScramMessage(ScramMessages.ClientFirstMessage.class, str), "testuser", nonce, "");
        }

        // Optional extensions placed after the nonce.
        for (String extension : VALID_EXTENSIONS) {
            str = String.format("n,,n=testuser,r=%s,%s", nonce, extension);
            checkClientFirstMessage(createScramMessage(ScramMessages.ClientFirstMessage.class, str), "testuser", nonce, "");
        }

        // The "tokenauth=true" extension must be reflected by extensions().tokenAuthenticated().
        str = String.format("n,,n=testuser,r=%s,%s", nonce, "tokenauth=true");
        m = createScramMessage(ScramMessages.ClientFirstMessage.class, str);
        Assert.assertTrue("Token authentication not set from extensions", m.extensions().tokenAuthenticated());
    }

    /** Verifies that malformed client-first messages are rejected. */
    @Test
    public void invalidClientFirstMessage() throws SaslException {
        String nonce = formatter.secureRandomString();

        // Unexpected attribute in the position where the authorization id would go.
        String invalid = String.format("n,x=something,n=testuser,r=%s", nonce);
        checkInvalidScramMessage(ScramMessages.ClientFirstMessage.class, invalid);

        for (String reserved : INVALID_RESERVED) {
            invalid = String.format("n,,%s,n=testuser,r=%s", reserved, nonce);
            checkInvalidScramMessage(ScramMessages.ClientFirstMessage.class, invalid);
        }

        for (String extension : INVALID_EXTENSIONS) {
            invalid = String.format("n,,n=testuser,r=%s,%s", nonce, extension);
            checkInvalidScramMessage(ScramMessages.ClientFirstMessage.class, invalid);
        }
    }

    /** Verifies that well-formed server-first messages are accepted and expose the expected fields. */
    @Test
    public void validServerFirstMessage() throws SaslException {
        String clientNonce = formatter.secureRandomString();
        String serverNonce = formatter.secureRandomString();
        // The message's nonce is expected to be the client nonce followed by the server nonce.
        String nonce = clientNonce + serverNonce;
        String salt = randomBytesAsString();

        // Built directly from its components.
        ScramMessages.ServerFirstMessage m = new ScramMessages.ServerFirstMessage(clientNonce, serverNonce, toBytes(salt), 8192);
        checkServerFirstMessage(m, nonce, salt, 8192);

        // Parsed from its string form, then re-parsed from its own serialized bytes.
        String str = String.format("r=%s,s=%s,i=4096", nonce, salt);
        m = createScramMessage(ScramMessages.ServerFirstMessage.class, str);
        checkServerFirstMessage(m, nonce, salt, 4096);
        m = new ScramMessages.ServerFirstMessage(m.toBytes());
        checkServerFirstMessage(m, nonce, salt, 4096);

        // Optional reserved attribute at the start of the message.
        for (String reserved : VALID_RESERVED) {
            str = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt);
            checkServerFirstMessage(createScramMessage(ScramMessages.ServerFirstMessage.class, str), nonce, salt, 4096);
        }

        // Optional extensions at the end of the message.
        for (String extension : VALID_EXTENSIONS) {
            str = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension);
            checkServerFirstMessage(createScramMessage(ScramMessages.ServerFirstMessage.class, str), nonce, salt, 4096);
        }
    }

    /** Verifies that malformed server-first messages are rejected. */
    @Test
    public void invalidServerFirstMessage() throws SaslException {
        String nonce = formatter.secureRandomString();
        String salt = randomBytesAsString();

        // Iteration count of zero.
        String invalid = String.format("r=%s,s=%s,i=0", nonce, salt);
        checkInvalidScramMessage(ScramMessages.ServerFirstMessage.class, invalid);

        // Salt value "=123" in place of a Base64 string.
        invalid = String.format("r=%s,s=%s,i=4096", nonce, "=123");
        checkInvalidScramMessage(ScramMessages.ServerFirstMessage.class, invalid);

        // Stray token between the nonce and the salt.
        invalid = String.format("r=%s,invalid,s=%s,i=4096", nonce, salt);
        checkInvalidScramMessage(ScramMessages.ServerFirstMessage.class, invalid);

        for (String reserved : INVALID_RESERVED) {
            invalid = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt);
            checkInvalidScramMessage(ScramMessages.ServerFirstMessage.class, invalid);
        }

        for (String extension : INVALID_EXTENSIONS) {
            invalid = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension);
            checkInvalidScramMessage(ScramMessages.ServerFirstMessage.class, invalid);
        }
    }

    /** Verifies that well-formed client-final messages are accepted and expose the expected fields. */
    @Test
    public void validClientFinalMessage() throws SaslException {
        String nonce = formatter.secureRandomString();
        String channelBinding = randomBytesAsString();
        String proof = randomBytesAsString();

        // Built directly: the proof is absent until it is set explicitly.
        ScramMessages.ClientFinalMessage m = new ScramMessages.ClientFinalMessage(toBytes(channelBinding), nonce);
        Assert.assertNull("Invalid proof", m.proof());
        m.proof(toBytes(proof));
        checkClientFinalMessage(m, channelBinding, nonce, proof);

        // Parsed from its string form, then re-parsed from its own serialized bytes.
        String str = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof);
        m = createScramMessage(ScramMessages.ClientFinalMessage.class, str);
        checkClientFinalMessage(m, channelBinding, nonce, proof);
        m = new ScramMessages.ClientFinalMessage(m.toBytes());
        checkClientFinalMessage(m, channelBinding, nonce, proof);

        // Optional extensions placed between the nonce and the proof.
        for (String extension : VALID_EXTENSIONS) {
            str = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof);
            checkClientFinalMessage(createScramMessage(ScramMessages.ClientFinalMessage.class, str), channelBinding, nonce, proof);
        }
    }

    /** Verifies that malformed client-final messages are rejected. */
    @Test
    public void invalidClientFinalMessage() throws SaslException {
        String nonce = formatter.secureRandomString();
        String channelBinding = randomBytesAsString();
        String proof = randomBytesAsString();

        // Channel binding "ab" in place of a Base64 string. This must be parsed as a
        // client-FINAL message; parsing it as a client-first message would be rejected
        // for unrelated reasons and would not exercise the client-final parser at all.
        String invalid = String.format("c=ab,r=%s,p=%s", nonce, proof);
        checkInvalidScramMessage(ScramMessages.ClientFinalMessage.class, invalid);

        // Proof "123" in place of a Base64 string.
        invalid = String.format("c=%s,r=%s,p=123", channelBinding, nonce);
        checkInvalidScramMessage(ScramMessages.ClientFinalMessage.class, invalid);

        for (String extension : INVALID_EXTENSIONS) {
            invalid = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof);
            checkInvalidScramMessage(ScramMessages.ClientFinalMessage.class, invalid);
        }
    }

    /** Verifies that well-formed server-final messages (error or signature) are accepted. */
    @Test
    public void validServerFinalMessage() throws SaslException {
        String serverSignature = randomBytesAsString();

        // Built directly with an error and no signature.
        ScramMessages.ServerFinalMessage m = new ScramMessages.ServerFinalMessage("unknown-user", null);
        checkServerFinalMessage(m, "unknown-user", null);

        // Built directly with a signature and no error.
        m = new ScramMessages.ServerFinalMessage(null, toBytes(serverSignature));
        checkServerFinalMessage(m, null, serverSignature);

        // Signature form: parsed from a string, then re-parsed from its own serialized bytes.
        String str = String.format("v=%s", serverSignature);
        m = createScramMessage(ScramMessages.ServerFinalMessage.class, str);
        checkServerFinalMessage(m, null, serverSignature);
        m = new ScramMessages.ServerFinalMessage(m.toBytes());
        checkServerFinalMessage(m, null, serverSignature);

        // Error form: parsed from a string, then re-parsed from its own serialized bytes.
        str = "e=other-error";
        m = createScramMessage(ScramMessages.ServerFinalMessage.class, str);
        checkServerFinalMessage(m, "other-error", null);
        m = new ScramMessages.ServerFinalMessage(m.toBytes());
        checkServerFinalMessage(m, "other-error", null);

        // Optional extensions after the signature.
        for (String extension : VALID_EXTENSIONS) {
            str = String.format("v=%s,%s", serverSignature, extension);
            checkServerFinalMessage(createScramMessage(ScramMessages.ServerFinalMessage.class, str), null, serverSignature);
        }
    }

    /** Verifies that malformed server-final messages are rejected. */
    @Test
    public void invalidServerFinalMessage() throws SaslException {
        String serverSignature = randomBytesAsString();

        // Error value containing a comma.
        String invalid = "e=error1,error2";
        checkInvalidScramMessage(ScramMessages.ServerFinalMessage.class, invalid);

        // Signature value "1=23" in place of a Base64 string.
        invalid = "v=1=23";
        checkInvalidScramMessage(ScramMessages.ServerFinalMessage.class, invalid);

        // Invalid extensions after either a signature or an error.
        for (String extension : INVALID_EXTENSIONS) {
            invalid = String.format("v=%s,%s", serverSignature, extension);
            checkInvalidScramMessage(ScramMessages.ServerFinalMessage.class, invalid);
            invalid = String.format("e=unknown-user,%s", extension);
            checkInvalidScramMessage(ScramMessages.ServerFinalMessage.class, invalid);
        }
    }

    /** Returns the Base64 encoding of random bytes obtained from the formatter. */
    private String randomBytesAsString() {
        return Base64.getEncoder().encodeToString(formatter.secureRandomBytes());
    }

    /** Decodes a Base64 string into its raw bytes. */
    private byte[] toBytes(String base64Str) {
        return Base64.getDecoder().decode(base64Str);
    }

    /** Asserts the SASL name, nonce and authorization id of a client-first message. */
    private void checkClientFirstMessage(ScramMessages.ClientFirstMessage message, String saslName, String nonce, String authzid) {
        Assert.assertEquals(saslName, message.saslName());
        Assert.assertEquals(nonce, message.nonce());
        Assert.assertEquals(authzid, message.authorizationId());
    }

    /**
     * Asserts the nonce, salt and iteration count of a server-first message.
     *
     * @param salt expected salt as a Base64 string; it is decoded before comparison
     */
    private void checkServerFirstMessage(ScramMessages.ServerFirstMessage message, String nonce, String salt, int iterations) {
        Assert.assertEquals(nonce, message.nonce());
        Assert.assertArrayEquals(toBytes(salt), message.salt());
        Assert.assertEquals(iterations, message.iterations());
    }

    /**
     * Asserts the channel binding, nonce and proof of a client-final message.
     *
     * @param channelBinding expected channel binding as a Base64 string; decoded before comparison
     * @param proof          expected proof as a Base64 string; decoded before comparison
     */
    private void checkClientFinalMessage(ScramMessages.ClientFinalMessage message, String channelBinding, String nonce, String proof) {
        Assert.assertArrayEquals(toBytes(channelBinding), message.channelBinding());
        Assert.assertEquals(nonce, message.nonce());
        Assert.assertArrayEquals(toBytes(proof), message.proof());
    }

    /**
     * Asserts the error and server signature of a server-final message.
     *
     * @param error           expected error, or {@code null} if none is expected
     * @param serverSignature expected signature as a Base64 string, or {@code null} if the
     *                        message is expected to carry no signature
     */
    private void checkServerFinalMessage(ScramMessages.ServerFinalMessage message, String error, String serverSignature) {
        Assert.assertEquals(error, message.error());
        if (serverSignature == null) {
            Assert.assertNull("Unexpected server signature", message.serverSignature());
        } else {
            Assert.assertArrayEquals(toBytes(serverSignature), message.serverSignature());
        }
    }

    /**
     * Parses {@code message} (encoded as UTF-8) into the SCRAM message type named by {@code clazz}.
     *
     * @param clazz   one of the four concrete message classes nested in {@link ScramMessages}
     * @param message textual form of the message
     * @return the parsed message
     * @throws SaslException            if the message constructor rejects the input
     * @throws IllegalArgumentException if {@code clazz} is not one of the four known message types
     */
    @SuppressWarnings("unchecked") // each cast is guarded by the class-token comparison preceding it
    private <T extends ScramMessages.AbstractScramMessage> T createScramMessage(Class<T> clazz, String message) throws SaslException {
        byte[] bytes = message.getBytes(StandardCharsets.UTF_8);
        if (clazz == ScramMessages.ClientFirstMessage.class) {
            return (T) new ScramMessages.ClientFirstMessage(bytes);
        }
        if (clazz == ScramMessages.ServerFirstMessage.class) {
            return (T) new ScramMessages.ServerFirstMessage(bytes);
        }
        if (clazz == ScramMessages.ClientFinalMessage.class) {
            return (T) new ScramMessages.ClientFinalMessage(bytes);
        }
        if (clazz == ScramMessages.ServerFinalMessage.class) {
            return (T) new ScramMessages.ServerFinalMessage(bytes);
        }
        throw new IllegalArgumentException("Unknown message type: " + clazz);
    }

    /**
     * Asserts that parsing {@code message} as the given message type fails with a
     * {@link SaslException}. If parsing succeeds, the test fails; the resulting
     * {@link AssertionError} is not a {@code SaslException} and therefore propagates.
     */
    private <T extends ScramMessages.AbstractScramMessage> void checkInvalidScramMessage(Class<T> clazz, String message) {
        try {
            createScramMessage(clazz, message);
            Assert.fail("Exception not throws for invalid message of type " + clazz + " : " + message);
        } catch (SaslException expected) {
            // Expected: the malformed message was rejected by the parser.
        }
    }
}

