/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.redshift.plugin;

import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.core.PGJDBCPropertyKey;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.ssl.NonValidatingFactory;
import com.amazon.support.ILogger;
import com.amazon.support.LogUtilities;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLRequest;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;
import org.apache.http.client.RedirectStrategy;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.socket.LayeredConnectionSocketFactory;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

/**
 * Base class for credential plugins that exchange a SAML assertion for temporary
 * AWS credentials through the STS {@code AssumeRoleWithSAML} call.
 *
 * <p>Subclasses provide the Base64-encoded assertion via {@link #getSamlAssertion()}.
 * This class decodes and parses it, selects a role, calls STS and stores the resulting
 * {@link CredentialsHolder} in a static cache shared by every instance.
 */
public abstract class SamlCredentialsProvider
implements IPlugin {
    protected static final String KEY_IDP_HOST = "idp_host";
    private static final String KEY_IDP_PORT = "idp_port";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_PREFERRED_ROLE = "preferred_role";
    private static final String KEY_SSL_INSECURE = "ssl_insecure";

    /** XPath selecting the text of every AttributeValue of the AWS "Role" SAML attribute. */
    private static final String ROLE_ATTRIBUTE_XPATH = "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()";
    /** Matches the ARN of an IAM SAML provider (the "principal" passed to STS). */
    private static final Pattern SAML_PROVIDER_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+");
    /** Matches the ARN of an IAM role. */
    private static final Pattern ROLE_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:role/\\S+");
    /** Matches a self-closed {@code <input ... />} tag; DOTALL lets the tag span several lines. */
    private static final Pattern INPUT_TAG_PATTERN = Pattern.compile("<input(.+?)/>", Pattern.DOTALL);
    /** Captures the double-quoted {@code action} attribute of a {@code <form>} tag. */
    private static final Pattern FORM_ACTION_PATTERN = Pattern.compile("<form.*?action=\"([^\"]+)\"");

    protected String m_userName;
    protected String m_password;
    protected String m_idpHost;
    protected int m_idpPort = 443;
    protected int m_duration;
    protected String m_preferredRole;
    protected boolean m_sslInsecure;
    protected String m_dbUser;
    protected String m_dbGroups;
    protected String m_dbGroupsFilter;
    protected Boolean m_forceLowercase;
    protected Boolean m_autoCreate;
    protected String m_region;
    protected ILogger m_log;

    // NOTE(review): this cache is static and a plain HashMap, so it is shared by all
    // provider instances without synchronisation - confirm callers never refresh concurrently.
    private static final Map<String, CredentialsHolder> m_cache = new HashMap<String, CredentialsHolder>();
    private static final Class<?> CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";

    /**
     * Context class loader installed for the duration of {@link #refresh()}.
     * It delegates to this class's own loader but substitutes {@link #CUSTOM_LOG_FACTORY_CLASS}
     * for any commons-logging {@code LogFactory} subtype, hides {@code commons-logging.properties},
     * and redirects the LogFactory service file to {@link #LOG_PROPERTIES_FILE_NAME}.
     */
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(SamlCredentialsProvider.class.getClassLoader()){

        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            Class<?> loaded = this.getParent().loadClass(name);
            if (LogFactory.class.isAssignableFrom(loaded)) {
                return CUSTOM_LOG_FACTORY_CLASS;
            }
            return loaded;
        }

        @Override
        public Enumeration<URL> getResources(String name) throws IOException {
            if ("commons-logging.properties".equals(name)) {
                return Collections.enumeration(Collections.<URL>emptyList());
            }
            return super.getResources(name);
        }

        @Override
        public URL getResource(String name) {
            if (SamlCredentialsProvider.LOG_PROPERTIES_FILE_PATH.equals(name)) {
                return SamlCredentialsProvider.class.getResource(SamlCredentialsProvider.LOG_PROPERTIES_FILE_NAME);
            }
            return super.getResource(name);
        }
    };

    /**
     * Obtains the SAML assertion from the identity provider.
     *
     * @return the assertion; {@link #refresh()} Base64-decodes it before parsing
     * @throws IOException if the assertion cannot be obtained
     */
    protected abstract String getSamlAssertion() throws IOException;

    /**
     * Stores one connection property on this provider. Keys are compared
     * case-insensitively and unrecognised keys are ignored.
     *
     * @param string the property key
     * @param string2 the property value
     * @throws NumberFormatException if {@code idp_port} or {@code duration} is not an integer
     */
    @Override
    public void addParameter(String string, String string2) {
        if (PGJDBCPropertyKey.USERNAME.equalsIgnoreCase(string) || PGJDBCPropertyKey.USERNAME_ALT.equalsIgnoreCase(string)) {
            this.m_userName = string2;
        } else if (PGJDBCPropertyKey.PASSWORD.equalsIgnoreCase(string) || PGJDBCPropertyKey.PASSWORD_ALT.equalsIgnoreCase(string)) {
            this.m_password = string2;
        } else if (KEY_IDP_HOST.equalsIgnoreCase(string)) {
            this.m_idpHost = string2;
        } else if (KEY_IDP_PORT.equalsIgnoreCase(string)) {
            this.m_idpPort = Integer.parseInt(string2);
        } else if (KEY_DURATION.equalsIgnoreCase(string)) {
            this.m_duration = Integer.parseInt(string2);
        } else if (KEY_PREFERRED_ROLE.equalsIgnoreCase(string)) {
            this.m_preferredRole = string2;
        } else if (KEY_SSL_INSECURE.equalsIgnoreCase(string)) {
            this.m_sslInsecure = Boolean.parseBoolean(string2);
        } else if (PGJDBCPropertyKey.DB_USER.equalsIgnoreCase(string)) {
            this.m_dbUser = string2;
        } else if (PGJDBCPropertyKey.DB_GROUPS.equalsIgnoreCase(string)) {
            this.m_dbGroups = string2;
        } else if (PGJDBCPropertyKey.DB_GROUPS_FILTER.equalsIgnoreCase(string) || PGJDBCPropertyKey.DB_GROUPS_FILTER_ALT.equalsIgnoreCase(string)) {
            this.m_dbGroupsFilter = string2;
        } else if (PGJDBCPropertyKey.FORCE_LOWERCASE.equalsIgnoreCase(string)) {
            this.m_forceLowercase = Boolean.valueOf(string2);
        } else if (PGJDBCPropertyKey.USER_AUTOCREATE.equalsIgnoreCase(string)) {
            this.m_autoCreate = Boolean.valueOf(string2);
        } else if (PGJDBCPropertyKey.AWS_REGION.equalsIgnoreCase(string)) {
            this.m_region = string2;
        }
    }

    /**
     * Sets the logger used for all diagnostic output of this provider.
     *
     * @param iLogger the logger to use
     */
    @Override
    public void setILogger(ILogger iLogger) {
        this.m_log = iLogger;
    }

    /**
     * Returns the cached credentials for this provider's settings, calling
     * {@link #refresh()} first when nothing is cached or the cached entry is expired.
     * If {@code m_dbUser} is set it is written into the returned holder's metadata.
     *
     * @return the credentials holder taken from the cache
     * @throws SdkClientException if the cache still has no entry after refreshing
     */
    public CredentialsHolder getCredentials() {
        LogUtilities.logDebug("\n{m_userName='" + this.m_userName
                + "', m_password='" + (this.m_password == null ? "null" : "**")
                + "', m_idpHost='" + this.m_idpHost
                + "', m_idpPort=" + this.m_idpPort
                + ", m_duration=" + this.m_duration
                + ", m_preferredRole='" + this.m_preferredRole
                + "', m_sslInsecure=" + this.m_sslInsecure
                + ", m_dbUser='" + this.m_dbUser
                + "', m_dbGroups='" + this.m_dbGroups
                + "', m_forceLowercase=" + this.m_forceLowercase
                + ", m_autoCreate=" + this.m_autoCreate
                + ", m_region='" + this.m_region
                + "'}", this.m_log);
        String cacheKey = this.getCacheKey();
        CredentialsHolder cached = m_cache.get(cacheKey);
        if (cached == null) {
            LogUtilities.logDebug("no credential", this.m_log);
            this.refresh();
        } else if (cached.isExpired()) {
            LogUtilities.logDebug("credentials expired", this.m_log);
            this.refresh();
        }
        CredentialsHolder credentials = m_cache.get(cacheKey);
        // Check for a missing entry BEFORE touching the holder; the previous order
        // dereferenced it first and raised a NullPointerException instead.
        if (credentials == null) {
            throw new SdkClientException("Unable to load AWS credentials from ADFS");
        }
        if (!StringUtils.isNullOrEmpty(this.m_dbUser)) {
            credentials.getThisMetadata().setDbUser(this.m_dbUser);
        }
        LogUtilities.logInfo(new Date() + ": Using entry for SamlCredentialsProvider.getCredentials cache with expiration " + credentials.getExpiration(), this.m_log);
        return credentials;
    }

    /**
     * Fetches a new SAML assertion, assumes a role with it through STS and stores the
     * resulting credentials (plus metadata read from the assertion) in the cache.
     *
     * <p>The role is {@code m_preferredRole} when set, otherwise the first entry the
     * role map's iterator yields. The thread's context class loader is replaced by
     * {@link #CONTEXT_CLASS_LOADER} while this runs and restored afterwards.
     *
     * @throws SdkClientException if the assertion contains no usable role, the preferred
     *         role is absent from it, or fetching/parsing the assertion fails
     */
    public void refresh() {
        LogUtilities.logDebug("start refresh", this.m_log);
        Thread currentThread = Thread.currentThread();
        ClassLoader originalLoader = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            String samlAssertion = this.getSamlAssertion();
            LogUtilities.logTrace(String.format("\nSAML assertion:\n%s", samlAssertion), this.m_log);
            byte[] decodedAssertion = Base64.decodeBase64(samlAssertion);
            Document document = SamlCredentialsProvider.parse(decodedAssertion);
            LogUtilities.logTrace(String.format("SAML decoded:\n%s", new String(decodedAssertion, StandardCharsets.UTF_8)), this.m_log);

            Map<String, String> principalByRole = this.extractRoles(document);
            if (principalByRole.isEmpty()) {
                throw new SdkClientException("No role found in SamlAssertion: " + samlAssertion);
            }

            String roleArn;
            String principalArn;
            if (this.m_preferredRole != null) {
                LogUtilities.logDebug("set Role from Preferred Role", this.m_log);
                roleArn = this.m_preferredRole;
                principalArn = principalByRole.get(this.m_preferredRole);
                if (principalArn == null) {
                    throw new SdkClientException("Preferred role not found in SamlAssertion: " + samlAssertion);
                }
            } else {
                LogUtilities.logDebug("get first Role from Attributes", this.m_log);
                Map.Entry<String, String> firstEntry = principalByRole.entrySet().iterator().next();
                roleArn = firstEntry.getKey();
                principalArn = firstEntry.getValue();
            }
            LogUtilities.logDebug(String.format("got [roleArn, principal] >> \"%s\" -- \"%s\"", roleArn, principalArn), this.m_log);

            AssumeRoleWithSAMLRequest request = new AssumeRoleWithSAMLRequest();
            request.setSAMLAssertion(samlAssertion);
            request.setRoleArn(roleArn);
            request.setPrincipalArn(principalArn);
            if (this.m_duration > 0) {
                request.setDurationSeconds(Integer.valueOf(this.m_duration));
            }

            // The STS client is built with anonymous credentials, as in the original code.
            AWSCredentialsProvider anonymousProvider = new AWSStaticCredentialsProvider((AWSCredentials)new AnonymousAWSCredentials());
            AWSSecurityTokenServiceClientBuilder stsBuilder = AWSSecurityTokenServiceClientBuilder.standard();
            stsBuilder.setRegion(this.m_region);
            AWSSecurityTokenService stsClient = stsBuilder.withCredentials(anonymousProvider).build();
            com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLResult assumeResult = stsClient.assumeRoleWithSAML(request);
            LogUtilities.logDebug("Role assumed:\n" + assumeResult, this.m_log);

            Credentials stsCredentials = assumeResult.getCredentials();
            Date expiration = stsCredentials.getExpiration();
            AWSCredentials sessionCredentials = new BasicSessionCredentials(stsCredentials.getAccessKeyId(), stsCredentials.getSecretAccessKey(), stsCredentials.getSessionToken());
            CredentialsHolder holder = CredentialsHolder.newInstance(sessionCredentials, expiration);
            holder.setMetadata(this.readMetadata(document));
            LogUtilities.logDebug("put credentials to cache. expire after: " + expiration, this.m_log);
            m_cache.put(this.getCacheKey(), holder);
        }
        catch (IOException | ParserConfigurationException | XPathExpressionException | SAXException exception) {
            throw new SdkClientException("SAML error: " + exception.getMessage(), (Throwable)exception);
        }
        finally {
            currentThread.setContextClassLoader(originalLoader);
        }
    }

    /**
     * Collects the role/provider ARN pairs listed in the assertion's AWS "Role" attribute.
     * Each attribute value is split on commas; values with fewer than two parts, or in
     * which either a role ARN or a SAML-provider ARN is missing, are skipped.
     *
     * @param document the parsed SAML assertion
     * @return map from role ARN to SAML provider ARN; empty when none were found
     * @throws XPathExpressionException if the role XPath cannot be evaluated
     */
    private Map<String, String> extractRoles(Document document) throws XPathExpressionException {
        XPath xPath = XPathFactory.newInstance().newXPath();
        NodeList roleNodes = (NodeList)xPath.compile(ROLE_ATTRIBUTE_XPATH).evaluate(document, XPathConstants.NODESET);
        // NOTE(review): HashMap iteration order is unspecified, so the "first" role chosen
        // by refresh() need not be the first one listed in the assertion.
        Map<String, String> principalByRole = new HashMap<String, String>();
        if (roleNodes == null) {
            return principalByRole;
        }
        for (int i = 0; i < roleNodes.getLength(); ++i) {
            String attributeValue = roleNodes.item(i).getNodeValue();
            LogUtilities.logDebug("SamlRoleAttribute >> " + attributeValue, this.m_log);
            String[] arns = attributeValue.split(",");
            if (arns.length < 2) {
                continue;
            }
            String providerArn = null;
            String roleArn = null;
            for (String arn : arns) {
                Matcher providerMatcher = SAML_PROVIDER_PATTERN.matcher(arn);
                if (providerMatcher.find()) {
                    providerArn = providerMatcher.group(0);
                    LogUtilities.logDebug("set provider >> " + providerArn, this.m_log);
                    continue;
                }
                Matcher roleMatcher = ROLE_PATTERN.matcher(arn);
                if (roleMatcher.find()) {
                    roleArn = roleMatcher.group(0);
                    LogUtilities.logDebug("set role >> " + roleArn, this.m_log);
                }
            }
            if (StringUtils.isNullOrEmpty(roleArn) || StringUtils.isNullOrEmpty(providerArn)) {
                continue;
            }
            principalByRole.put(roleArn, providerArn);
            LogUtilities.logDebug("added roles to list", this.m_log);
        }
        return principalByRole;
    }

    /**
     * Builds the cache key from user name, password, IdP host, IdP port, duration and
     * preferred role. Every component is length-prefixed so that different field
     * combinations can never produce the same key (plain concatenation could).
     *
     * <p>NOTE(review): as before, the key contains the password in plain text.
     *
     * @return the key identifying this provider's settings in the cache
     */
    private String getCacheKey() {
        StringBuilder key = new StringBuilder();
        SamlCredentialsProvider.appendKeyPart(key, this.m_userName);
        SamlCredentialsProvider.appendKeyPart(key, this.m_password);
        SamlCredentialsProvider.appendKeyPart(key, this.m_idpHost);
        SamlCredentialsProvider.appendKeyPart(key, this.m_idpPort);
        SamlCredentialsProvider.appendKeyPart(key, this.m_duration);
        SamlCredentialsProvider.appendKeyPart(key, this.m_preferredRole);
        return key.toString();
    }

    /** Appends {@code part} as {@code <length>:<text>;}, or {@code -;} when it is null. */
    private static void appendKeyPart(StringBuilder key, Object part) {
        if (part == null) {
            key.append("-;");
            return;
        }
        String text = part.toString();
        key.append(text.length()).append(':').append(text).append(';');
    }

    /**
     * Reads the Redshift-specific SAML attributes (AllowDbUserOverride, DbUser with
     * RoleSessionName as fallback, AutoCreate, DbGroups, ForceLowercase) into metadata.
     * Only the first value of each single-valued attribute is used; DbGroups values are
     * filtered through {@link #filterOutGroups(List)} and joined with commas.
     *
     * @param document the parsed SAML assertion
     * @return the metadata populated from the attributes that were present
     * @throws XPathExpressionException if an attribute XPath cannot be evaluated
     */
    private CredentialsHolder.IamMetadata readMetadata(Document document) throws XPathExpressionException {
        CredentialsHolder.IamMetadata metadata = new CredentialsHolder.IamMetadata();
        XPath xPath = XPathFactory.newInstance().newXPath();

        List<String> values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://redshift.amazon.com/SAML/Attributes/AllowDbUserOverride");
        if (!values.isEmpty()) {
            metadata.setAllowDbUserOverride(Boolean.valueOf(values.get(0)));
        }

        values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://redshift.amazon.com/SAML/Attributes/DbUser");
        if (values.isEmpty()) {
            // No explicit DbUser attribute: fall back to the role session name.
            values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://aws.amazon.com/SAML/Attributes/RoleSessionName");
        }
        if (!values.isEmpty()) {
            metadata.setSamlDbUser(values.get(0));
        }

        values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://redshift.amazon.com/SAML/Attributes/AutoCreate");
        if (!values.isEmpty()) {
            metadata.setAutoCreate(Boolean.valueOf(values.get(0)));
        }

        values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://redshift.amazon.com/SAML/Attributes/DbGroups");
        if (!values.isEmpty()) {
            List<String> groups = this.filterOutGroups(values);
            if (!groups.isEmpty()) {
                StringBuilder joined = new StringBuilder();
                for (String group : groups) {
                    if (joined.length() > 0) {
                        joined.append(',');
                    }
                    joined.append(group);
                }
                metadata.setDbGroups(joined.toString());
            }
        }

        values = SamlCredentialsProvider.GetSAMLAttributeValues(xPath, document, "https://redshift.amazon.com/SAML/Attributes/ForceLowercase");
        if (!values.isEmpty()) {
            metadata.setForceLowercase(Boolean.valueOf(values.get(0)));
        }

        LogUtilities.logDebug("got metadata from SAML:\n" + metadata, this.m_log);
        return metadata;
    }

    /**
     * Removes every group that fully matches the {@code m_dbGroupsFilter} regular expression.
     *
     * @param list the group names read from the assertion
     * @return a new list without the matching groups, or {@code list} itself when no filter is set
     */
    private List<String> filterOutGroups(List<String> list) {
        if (this.m_dbGroupsFilter == null) {
            return list;
        }
        Pattern filter = Pattern.compile(this.m_dbGroupsFilter);
        List<String> kept = new ArrayList<String>();
        for (String group : list) {
            LogUtilities.logDebug(String.format("Check group %s with regexp %s", group, this.m_dbGroupsFilter), this.m_log);
            if (filter.matcher(group).matches()) {
                continue;
            }
            LogUtilities.logDebug(String.format("Add %s to dbgroups", group), this.m_log);
            kept.add(group);
        }
        return kept;
    }

    /**
     * Parses XML bytes into a DOM document with DOCTYPE declarations, external entities,
     * XInclude and entity-reference expansion all disabled.
     *
     * @param byArray the XML document bytes
     * @return the parsed document
     * @throws IOException if the bytes cannot be read
     * @throws SAXException if the XML is malformed
     * @throws ParserConfigurationException if the parser rejects one of the features
     */
    private static Document parse(byte[] byArray) throws IOException, SAXException, ParserConfigurationException {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
        factory.setXIncludeAware(false);
        factory.setExpandEntityReferences(false);
        DocumentBuilder builder = factory.newDocumentBuilder();
        return builder.parse(new ByteArrayInputStream(byArray));
    }

    /**
     * Returns the text of every AttributeValue of the SAML attribute with the given name.
     *
     * @param xPath the XPath instance used for evaluation
     * @param document the parsed SAML assertion
     * @param string the value of the attribute's {@code Name}, inserted into the XPath as-is
     * @return the values in node-list order; an empty list when there are none
     * @throws XPathExpressionException if the expression cannot be compiled or evaluated
     */
    private static List<String> GetSAMLAttributeValues(XPath xPath, Document document, String string) throws XPathExpressionException {
        String expression = String.format("//Attribute[@Name='%s']/AttributeValue/text()", string);
        NodeList nodes = (NodeList)xPath.compile(expression).evaluate(document, XPathConstants.NODESET);
        if (nodes == null || nodes.getLength() == 0) {
            return Collections.emptyList();
        }
        List<String> values = new ArrayList<String>(nodes.getLength());
        for (int i = 0; i < nodes.getLength(); ++i) {
            Node node = nodes.item(i);
            values.add(node.getNodeValue());
        }
        return values;
    }

    /**
     * Creates an HTTP client with 60-second socket and connect timeouts, the "standard"
     * cookie spec and a lax redirect strategy. When {@code m_sslInsecure} is set, the
     * client uses a TLSv1.2 context built from {@link NonValidatingFactory} and a
     * no-op hostname verifier.
     *
     * @return a new client; the caller is responsible for closing it
     * @throws GeneralSecurityException if the SSL context cannot be created or initialised
     */
    protected CloseableHttpClient getHttpClient() throws GeneralSecurityException {
        RequestConfig requestConfig = RequestConfig.custom()
                .setSocketTimeout(60000)
                .setConnectTimeout(60000)
                .setExpectContinueEnabled(false)
                .setCookieSpec("standard")
                .build();
        HttpClientBuilder clientBuilder = HttpClients.custom()
                .setDefaultRequestConfig(requestConfig)
                .setRedirectStrategy((RedirectStrategy)new LaxRedirectStrategy());
        if (this.m_sslInsecure) {
            SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
            TrustManager[] trustManagers = new TrustManager[]{new NonValidatingFactory()};
            sslContext.init(null, trustManagers, null);
            SSLSocketFactory socketFactory = sslContext.getSocketFactory();
            SSLConnectionSocketFactory connectionFactory = new SSLConnectionSocketFactory(socketFactory, (HostnameVerifier)new NoopHostnameVerifier());
            clientBuilder.setSSLSocketFactory((LayeredConnectionSocketFactory)connectionFactory);
        }
        return clientBuilder.build();
    }

    /**
     * Extracts the self-closed {@code <input ... />} tags from an HTML page, keeping only
     * the first tag for each distinct (lower-cased) {@code name} and dropping tags whose
     * name is missing or empty.
     *
     * @param string the HTML text
     * @return the matching tags in document order
     */
    protected List<String> getInputTagsfromHTML(String string) {
        HashSet<String> seenNames = new HashSet<String>();
        List<String> tags = new ArrayList<String>();
        Matcher matcher = INPUT_TAG_PATTERN.matcher(string);
        while (matcher.find()) {
            String tag = matcher.group(0);
            String name = this.getValueByKey(tag, "name").toLowerCase();
            if (name.isEmpty() || !seenNames.add(name)) {
                continue;
            }
            tags.add(tag);
        }
        return tags;
    }

    /**
     * Finds the {@code action} attribute of the first {@code <form>} tag.
     *
     * @param string the HTML text
     * @return the entity-decoded action, or {@code null} when no form action is found
     */
    protected String getFormAction(String string) {
        Matcher matcher = FORM_ACTION_PATTERN.matcher(string);
        if (matcher.find()) {
            return this.escapeHtmlEntity(matcher.group(1));
        }
        return null;
    }

    /**
     * Reads the double-quoted value of the first {@code key="value"} pair in the text.
     *
     * @param string the text to search, typically a single HTML tag
     * @param string2 the attribute name, matched literally
     * @return the entity-decoded value, or an empty string when the key is not present
     */
    protected String getValueByKey(String string, String string2) {
        Pattern pattern = Pattern.compile("(" + Pattern.quote(string2) + ")\\s*=\\s*\"(.*?)\"");
        Matcher matcher = pattern.matcher(string);
        if (matcher.find()) {
            return this.escapeHtmlEntity(matcher.group(2));
        }
        return "";
    }

    /**
     * @param string an HTML input tag
     * @return {@code true} if the tag's {@code type} attribute is exactly {@code text}
     */
    protected boolean isText(String string) {
        return "text".equals(this.getValueByKey(string, "type"));
    }

    /**
     * @param string an HTML input tag
     * @return {@code true} if the tag's {@code type} attribute is exactly {@code password}
     */
    protected boolean isPassword(String string) {
        return "password".equals(this.getValueByKey(string, "type"));
    }

    /**
     * Despite its name, this DECODES HTML entities: {@code &amp;}, {@code &apos;},
     * {@code &quot;}, {@code &lt;} and {@code &gt;} are replaced by the characters they
     * stand for. Any other {@code &} is copied through unchanged.
     *
     * @param string the text to decode
     * @return the decoded text
     */
    protected String escapeHtmlEntity(String string) {
        int length = string.length();
        StringBuilder decoded = new StringBuilder(length);
        int pos = 0;
        while (pos < length) {
            char current = string.charAt(pos);
            if (current != '&') {
                decoded.append(current);
                ++pos;
            } else if (string.startsWith("&amp;", pos)) {
                decoded.append('&');
                pos += 5;
            } else if (string.startsWith("&apos;", pos)) {
                decoded.append('\'');
                pos += 6;
            } else if (string.startsWith("&quot;", pos)) {
                decoded.append('\"');
                pos += 6;
            } else if (string.startsWith("&lt;", pos)) {
                decoded.append('<');
                pos += 4;
            } else if (string.startsWith("&gt;", pos)) {
                decoded.append('>');
                pos += 4;
            } else {
                // Unknown entity or a bare ampersand: keep it verbatim.
                decoded.append(current);
                ++pos;
            }
        }
        return decoded.toString();
    }

    /**
     * Verifies that user name, password and IdP host have all been supplied.
     *
     * @throws IOException naming the first property that is null or empty
     */
    protected void checkRequiredParameters() throws IOException {
        if (StringUtils.isNullOrEmpty(this.m_userName)) {
            throw new IOException("Missing required property: " + PGJDBCPropertyKey.USERNAME_ALT);
        }
        if (StringUtils.isNullOrEmpty(this.m_password)) {
            throw new IOException("Missing required property: " + PGJDBCPropertyKey.PASSWORD_ALT);
        }
        if (StringUtils.isNullOrEmpty(this.m_idpHost)) {
            throw new IOException("Missing required property: idp_host");
        }
    }
}

