001/**
002 * Copyright (C) 2006-2025 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.components.vault.client;
017
import static java.util.Optional.of;
import static java.util.Optional.ofNullable;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static javax.ws.rs.client.Entity.entity;
import static javax.ws.rs.core.MediaType.APPLICATION_JSON_TYPE;

import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.cache.Cache;
import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.context.Initialized;
import javax.enterprise.event.Observes;
import javax.inject.Inject;
import javax.json.bind.annotation.JsonbProperty;
import javax.servlet.ServletContext;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;

import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.talend.sdk.components.vault.configuration.Documentation;
import org.talend.sdk.components.vault.server.error.ErrorPayload;
import org.talend.sdk.components.vault.server.error.ErrorPayload.ErrorDictionary;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
079
080@Slf4j
081@Data
082@ApplicationScoped
083public class VaultClient {
084
085    @Inject
086    private VaultClientSetup setup;
087
088    @Inject
089    @VaultHttp
090    private WebTarget vault;
091
092    @Inject
093    @Documentation("The vault path to retrieve a token.")
094    @ConfigProperty(name = "talend.vault.cache.vault.auth.endpoint", defaultValue = "v1/auth/engines/login")
095    private String authEndpoint;
096
097    @Inject
098    @Documentation("The vault path to decrypt values. You can use the variable `{x-talend-tenant-id}` to replace by `x-talend-tenant-id` header value.")
099    @ConfigProperty(name = "talend.vault.cache.vault.decrypt.endpoint",
100            defaultValue = "v1/tenants-keyrings/decrypt/{x-talend-tenant-id}")
101    private String decryptEndpoint;
102
103    @Inject
104    @Documentation("The vault token to use to log in (will make roleId and secretId ignored). `-` means it is ignored.")
105    @ConfigProperty(name = "talend.vault.cache.vault.auth.token", defaultValue = "-")
106    private Supplier<String> token;
107
108    @Inject
109    @Documentation("The vault role identifier to use to log in (if token is not set). `-` means it is ignored.")
110    @ConfigProperty(name = "talend.vault.cache.vault.auth.roleId", defaultValue = "-")
111    private Supplier<String> role;
112
113    @Inject
114    @Documentation("The vault secret identifier to use to log in (if token is not set). `-` means it is ignored.")
115    @ConfigProperty(name = "talend.vault.cache.vault.auth.secretId", defaultValue = "-")
116    private Supplier<String> secret;
117
118    @Inject
119    @Documentation("How often (in ms) to refresh the vault token.")
120    @ConfigProperty(name = "talend.vault.cache.service.auth.refreshDelayMargin", defaultValue = "600000")
121    private Long refreshDelayMargin;
122
123    @Inject
124    @Documentation("How long (in ms) to wait before retrying a recoverable error or refresh the vault token in case of an authentication failure.")
125    @ConfigProperty(name = "talend.vault.cache.service.auth.refreshDelayOnFailure", defaultValue = "1000")
126    private Long refreshDelayOnFailure;
127
128    @Inject
129    @Documentation("How many times do we retry a recoverable operation in case of a failure.")
130    @ConfigProperty(name = "talend.vault.cache.service.auth.numberOfRetryOnFailure", defaultValue = "3")
131    private Integer numberOfRetryOnFailure;
132
133    @Inject
134    @Documentation("Status code sent when vault can't decipher some values.")
135    @ConfigProperty(name = "talend.vault.cache.service.auth.cantDecipherStatusCode", defaultValue = "422")
136    private Integer cantDecipherStatusCode;
137
138    @Inject
139    @Documentation("The regex to whitelist ciphered keys, others will be passthrough in the output without going to vault.")
140    @ConfigProperty(name = "talend.vault.cache.service.decipher.skip.regex", defaultValue = "vault\\:v[0-9]+\\:.*")
141    private String passthroughRegex;
142
143    @Inject
144    private Cache<String, DecryptedValue> cache;
145
146    @Inject
147    private Clock clock;
148
149    private final AtomicReference<Authentication> authToken = new AtomicReference<>();
150
151    private ScheduledExecutorService scheduledExecutorService;
152
153    private Pattern compiledPassthroughRegex;
154
155    private final Predicate<Throwable> shouldRetry = cause -> {
156        if (WebApplicationException.class.isInstance(cause)) {
157            final WebApplicationException wae = WebApplicationException.class.cast(cause);
158            final int status = wae.getResponse().getStatus();
159            if (Status.NOT_FOUND.getStatusCode() == status || status >= 500) {
160                return false;
161            }
162        }
163        return true;
164    };
165
166    @PostConstruct
167    private void init() {
168        compiledPassthroughRegex = Pattern.compile(passthroughRegex);
169    }
170
171    @PreDestroy
172    private void destroy() {
173        scheduledExecutorService.shutdownNow(); // we don't care anymore about these tasks
174        try {
175            scheduledExecutorService.awaitTermination(1L, MINUTES); // wait too much but enough for our goal
176        } catch (final InterruptedException e) {
177            Thread.currentThread().interrupt();
178        }
179    }
180
181    public void init(@Observes @Initialized(ApplicationScoped.class) final ServletContext init) {
182        scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() {
183
184            private final ThreadGroup group = ofNullable(System.getSecurityManager())
185                    .map(SecurityManager::getThreadGroup)
186                    .orElseGet(() -> Thread.currentThread().getThreadGroup());
187
188            @Override
189            public Thread newThread(final Runnable r) {
190                final Thread t = new Thread(group, r, "talend-vault-service-refresh", 0L);
191                if (t.isDaemon()) {
192                    t.setDaemon(false);
193                }
194                if (t.getPriority() != Thread.NORM_PRIORITY) {
195                    t.setPriority(Thread.NORM_PRIORITY);
196                }
197                return t;
198            }
199        });
200    }
201
202    @SneakyThrows
203    public Map<String, String> decrypt(final Map<String, String> values) {
204        return decrypt(values, null);
205    }
206
207    @SneakyThrows
208    public Map<String, String> decrypt(final Map<String, String> values, final String tenantId) {
209        if ("no-vault".equals(setup.getVaultUrl())) {
210            return values;
211        }
212        final List<String> cipheredKeys = values
213                .entrySet()
214                .stream()
215                .filter(entry -> compiledPassthroughRegex.matcher(entry.getValue()).matches())
216                .map(cyphered -> cyphered.getKey())
217                .collect(toList());
218        if (cipheredKeys.isEmpty()) {
219            return values;
220        }
221        Supplier<CompletableFuture<Map<String, String>>> attempt = () -> prepareRequest(values, cipheredKeys, tenantId);
222        return withRetries(attempt, shouldRetry).get();
223    }
224
225    private CompletableFuture<Map<String, String>> prepareRequest(final Map<String, String> values,
226            final List<String> cipheredKeys, final String tenantId) {
227        return get(cipheredKeys.stream().map(values::get).collect(toList()), clock.millis(), tenantId)
228                .thenApply(decrypted -> values
229                        .entrySet()
230                        .stream()
231                        .collect(toMap(Entry::getKey,
232                                e -> of(cipheredKeys.indexOf(e.getKey()))
233                                        .filter(idx -> idx >= 0)
234                                        .map(decrypted::get)
235                                        .map(DecryptedValue::getValue)
236                                        .orElseGet(() -> values.get(e.getKey())))));
237    }
238
239    private CompletableFuture<List<DecryptedValue>> get(final Collection<String> values, final long currentTime,
240            final String tenantId) {
241        final AtomicInteger index = new AtomicInteger();
242        final Collection<EntryWithIndex<String>> clearValues = values
243                .stream()
244                .map(it -> new EntryWithIndex<>(index.getAndIncrement(), it))
245                .filter(it -> it.entry != null && !compiledPassthroughRegex.matcher(it.entry).matches())
246                .collect(toList());
247        if (clearValues.isEmpty()) {
248            return doDecipher(values, currentTime, tenantId).toCompletableFuture();
249        }
250        if (clearValues.size() == values.size()) {
251            final long now = clock.millis();
252            return completedFuture(values.stream().map(it -> new DecryptedValue(it, now)).collect(toList()));
253        }
254        return doDecipher(values, currentTime, tenantId).thenApply(deciphered -> {
255            final long now = clock.millis();
256            clearValues.forEach(entry -> deciphered.add(entry.index, new DecryptedValue(entry.entry, now)));
257            return deciphered;
258        }).toCompletableFuture();
259    }
260
261    private CompletionStage<List<DecryptedValue>> doDecipher(final Collection<String> values, final long currentTime,
262            final String tenantId) {
263        final Map<String, Optional<DecryptedValue>> alreadyCached =
264                new HashSet<>(values).stream().collect(toMap(identity(), it -> ofNullable(cache.get(it))));
265        final Collection<String> missing = alreadyCached
266                .entrySet()
267                .stream()
268                .filter(it -> !it.getValue().isPresent())
269                .map(Map.Entry::getKey)
270                .collect(toList());
271        if (missing.isEmpty()) { // no remote call, yeah
272            return completedFuture(values.stream().map(alreadyCached::get).map(Optional::get).collect(toList()));
273        }
274        // do request
275        return getOrRequestAuth()
276                // prepare decrypt request to vault
277                .thenCompose(auth -> ofNullable(auth.getAuth()).map(Auth::getClientToken).map(clientToken -> {
278                    WebTarget path = vault.path(decryptEndpoint);
279                    if (decryptEndpoint.contains("x-talend-tenant-id")) {
280                        path = path
281                                .resolveTemplate("x-talend-tenant-id",
282                                        ofNullable(tenantId)
283                                                .orElseThrow(() -> new WebApplicationException(Response
284                                                        .status(Status.NOT_FOUND)
285                                                        .entity(new ErrorPayload(ErrorDictionary.BAD_FORMAT,
286                                                                "No header x-talend-tenant-id"))
287                                                        .build())));
288                    }
289                    return path
290                            .request(APPLICATION_JSON_TYPE)
291                            .header("X-Vault-Token", clientToken)
292                            .rx()
293                            .post(entity(new DecryptRequest(
294                                    missing.stream().map(it -> new DecryptInput(it, null, null)).collect(toList())),
295                                    APPLICATION_JSON_TYPE), DecryptResponse.class)
296                            .toCompletableFuture()
297                            // fetch decrypted values
298                            .thenApply(decrypted -> {
299                                final Collection<DecryptResult> results = decrypted.getData().getBatchResults();
300                                if (results.isEmpty()) {
301                                    throwError(cantDecipherStatusCode, "Decrypted values are empty");
302                                }
303                                final List<String> errors = results
304                                        .stream()
305                                        .map(DecryptResult::getError)
306                                        .filter(Objects::nonNull)
307                                        .collect(toList());
308                                if (!errors.isEmpty()) {
309                                    throwError(cantDecipherStatusCode, "Can't decipher properties: " + errors);
310                                }
311                                final Iterator<String> keyIterator = missing.iterator();
312                                final Map<String, DecryptedValue> decryptedResults = results
313                                        .stream()
314                                        .map(it -> new String(Base64.getDecoder().decode(it.getPlaintext()),
315                                                StandardCharsets.UTF_8))
316                                        .collect(toMap(it -> keyIterator.next(),
317                                                it -> new DecryptedValue(it, currentTime)));
318                                cache.putAll(decryptedResults);
319                                //
320                                return values
321                                        .stream()
322                                        .map(it -> decryptedResults
323                                                .getOrDefault(it, alreadyCached.get(it).orElse(null)))
324                                        .collect(toList());
325                            })
326                            // oops, smtg went wrong
327                            .exceptionally(e -> {
328                                final Throwable cause = e.getCause();
329                                String message = "";
330                                int status = cantDecipherStatusCode;
331                                if (WebApplicationException.class.isInstance(cause)) {
332                                    final WebApplicationException wae = WebApplicationException.class.cast(cause);
333                                    final Response response = wae.getResponse();
334                                    if (response != null) {
335                                        if (ErrorPayload.class.isInstance(response.getEntity())) { // internal error
336                                            throw wae;
337                                        } else {
338                                            try {
339                                                message = response.readEntity(String.class);
340                                            } catch (final Exception ignored) {
341                                                // no-op
342                                            }
343                                        }
344                                        status = response.getStatus();
345                                        if (status == Status.NOT_FOUND.getStatusCode() && message.isEmpty()) {
346                                            message = "Decryption failed: Endpoint not found, check your setup.";
347                                        }
348                                    }
349                                }
350                                if (message.isEmpty()) {
351                                    message = String.format("Decryption failed: %s", cause.getMessage());
352                                }
353                                log.error("{} ({}).", message, status);
354                                throw new WebApplicationException(message,
355                                        Response
356                                                .status(status)
357                                                .entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message))
358                                                .build());
359                            });
360                })
361                        .orElseThrow(() -> new WebApplicationException(Response
362                                .status(Response.Status.FORBIDDEN)
363                                .entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, "getOrRequestAuth failed"))
364                                .build())));
365    }
366
367    private CompletionStage<Authentication> getOrRequestAuth() {
368        return of(token.get()).filter(this::isReloadableConfigSet).map(value -> {
369            final Auth authInfo = new Auth();
370            authInfo.setClientToken(value);
371            authInfo.setLeaseDuration(Long.MAX_VALUE);
372            authInfo.setRenewable(false);
373            return completedFuture(new Authentication(authInfo, Long.MAX_VALUE));
374        }).orElseGet(() -> {
375            final String role = of(this.role.get()).filter(this::isReloadableConfigSet).orElse(null);
376            final String secret = of(this.secret.get()).filter(this::isReloadableConfigSet).orElse(null);
377            return ofNullable(authToken.get())
378                    .filter(auth -> (auth.getExpiresAt() - clock.millis()) <= refreshDelayMargin) // is expired
379                    .map(CompletableFuture::completedFuture)
380                    .orElseGet(() -> doAuth(role, secret).toCompletableFuture());
381        });
382    }
383
384    private CompletionStage<Authentication> doAuth(final String role, final String secret) {
385        log.info("Authenticating to vault");
386        return vault
387                .path(authEndpoint)
388                .request(APPLICATION_JSON_TYPE)
389                .rx()
390                .post(entity(new AuthRequest(role, secret), APPLICATION_JSON_TYPE), AuthResponse.class)
391                //
392                .thenApply(token -> {
393                    log.debug("Authenticated to vault");
394                    if (token.getAuth() == null || token.getAuth().getClientToken() == null) {
395                        throwError(500, "Vault didn't return a token");
396                    } else {
397                        log.info("Authenticated to vault");
398                    }
399                    final long validityMargin = TimeUnit.SECONDS.toMillis(token.getAuth().getLeaseDuration());
400                    final long nextRefresh = clock.millis() + validityMargin - refreshDelayMargin;
401                    final Authentication authentication = new Authentication(token.getAuth(), nextRefresh);
402                    authToken.set(authentication);
403                    if (!scheduledExecutorService.isShutdown() && token.getAuth().isRenewable()) {
404                        scheduledExecutorService.schedule(() -> doAuth(role, secret), nextRefresh, MILLISECONDS);
405                    }
406                    return authentication;
407                })
408                //
409                .exceptionally(e -> {
410                    final Throwable cause = e.getCause();
411                    if (WebApplicationException.class.isInstance(cause)) {
412                        final WebApplicationException wae = WebApplicationException.class.cast(cause);
413                        final Response response = wae.getResponse();
414                        String message = "";
415                        if (ErrorPayload.class.isInstance(wae.getResponse().getEntity())) {
416                            throw wae; // already logged and setup broken so just rethrow
417                        } else {
418                            try {
419                                message = response.readEntity(String.class);
420                            } catch (final Exception ignored) {
421                                // no-op
422                            }
423                            if (message.isEmpty()) {
424                                message = cause.getMessage();
425                            }
426                            throwError(response.getStatus(), message);
427                        }
428                    }
429                    throwError(cause);
430                    return null;
431                });
432    }
433
434    private <T> CompletableFuture<T> withRetries(final Supplier<CompletableFuture<T>> attempt,
435            final Predicate<Throwable> shouldRetry) {
436        Executor scheduler = r -> scheduledExecutorService.schedule(r, refreshDelayOnFailure, TimeUnit.MILLISECONDS);
437        CompletableFuture<T> firstAttempt = attempt.get();
438        return flatten(firstAttempt
439                .thenApply(CompletableFuture::completedFuture)
440                .exceptionally(throwable -> retryFuture(attempt, 1, throwable, shouldRetry, scheduler)));
441    }
442
443    private <T> CompletableFuture<T> retryFuture(final Supplier<CompletableFuture<T>> attempter,
444            final int attemptsSoFar, final Throwable throwable, final Predicate<Throwable> shouldRetry,
445            final Executor scheduler) {
446        int nextAttempt = attemptsSoFar + 1;
447        log
448                .info("[retryFuture] Retry failed operation ({}/{}). Reason: {}.", attemptsSoFar,
449                        numberOfRetryOnFailure, throwable.getMessage());
450        if (nextAttempt > numberOfRetryOnFailure || !shouldRetry.test(throwable.getCause())) {
451            log.info("[retryFuture] Stop retry failed operation (condition triggered).");
452            throwError(throwable.getCause());
453        }
454        return flatten(flatten(CompletableFuture.supplyAsync(attempter, scheduler))
455                .thenApply(CompletableFuture::completedFuture)
456                .exceptionally(
457                        nextThrowable -> retryFuture(attempter, nextAttempt, nextThrowable, shouldRetry, scheduler)));
458    }
459
460    private <T> CompletableFuture<T> flatten(final CompletableFuture<CompletableFuture<T>> completableCompletable) {
461        return completableCompletable.thenCompose(Function.identity());
462    }
463
464    private void throwError(final int status, final String message) {
465        throw new WebApplicationException(message,
466                Response.status(status).entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message)).build());
467    }
468
469    private void throwError(final Throwable cause) {
470        String message = "";
471        int status = cantDecipherStatusCode;
472        if (WebApplicationException.class.isInstance(cause)) {
473            final WebApplicationException wae = WebApplicationException.class.cast(cause);
474            final Response response = wae.getResponse();
475            status = response.getStatus();
476            if (response != null) {
477                if (ErrorPayload.class.isInstance(response.getEntity())) { // internal error
478                    throw wae;
479                } else {
480                    try {
481                        message = response.readEntity(String.class);
482                    } catch (final Exception ignored) {
483                        // no-op
484                    }
485                }
486            }
487        }
488        if (message.isEmpty()) {
489            message = cause.getMessage();
490        }
491        throw new WebApplicationException(message,
492                Response.status(status).entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message)).build());
493    }
494
495    // workaround while geronimo-config does not support generics of generics
496    // (1.2.1 in org.apache.geronimo.config.cdi.ConfigInjectionBean.create)
497    private boolean isReloadableConfigSet(final String value) {
498        return !"-".equals(value);
499    }
500
501    @Data
502    @NoArgsConstructor
503    @AllArgsConstructor
504    public static class DecryptInput {
505
506        private String ciphertext;
507
508        private String context; // only when derivation is activated
509
510        private String nonce; // only when convergent encryption is activated
511    }
512
513    @Data
514    @NoArgsConstructor
515    @AllArgsConstructor
516    public static class DecryptRequest {
517
518        @JsonbProperty("batch_input")
519        private Collection<DecryptInput> batchInput;
520    }
521
522    @Data
523    public static class DecryptResult {
524
525        private String plaintext;
526
527        private String context;
528
529        private String error;
530    }
531
532    @Data
533    public static class DecryptResponse {
534
535        private DecryptData data;
536    }
537
538    @Data
539    public static class DecryptData {
540
541        @JsonbProperty("batch_results")
542        private Collection<DecryptResult> batchResults;
543    }
544
545    @Data
546    @NoArgsConstructor
547    @AllArgsConstructor
548    public static class AuthRequest {
549
550        @JsonbProperty("role_id")
551        private String roleId;
552
553        @JsonbProperty("secret_id")
554        private String secretId;
555    }
556
557    @Data
558    public static class Auth {
559
560        private boolean renewable;
561
562        @JsonbProperty("lease_duration")
563        private long leaseDuration;
564
565        @JsonbProperty("client_token")
566        private String clientToken;
567    }
568
569    @Data
570    public static class AuthResponse {
571
572        private Auth auth;
573    }
574
575    @Data
576    private static class Authentication {
577
578        private final Auth auth;
579
580        private final long expiresAt;
581    }
582
583    @RequiredArgsConstructor
584    private static class EntryWithIndex<T> {
585
586        private final int index;
587
588        private final T entry;
589    }
590}