001/** 002 * Copyright (C) 2006-2025 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.components.vault.client; 017 018import static java.util.Optional.of; 019import static java.util.Optional.ofNullable; 020import static java.util.concurrent.CompletableFuture.completedFuture; 021import static java.util.concurrent.TimeUnit.MILLISECONDS; 022import static java.util.concurrent.TimeUnit.MINUTES; 023import static java.util.function.Function.identity; 024import static java.util.stream.Collectors.toList; 025import static java.util.stream.Collectors.toMap; 026import static javax.ws.rs.client.Entity.entity; 027import static javax.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; 028 029import java.nio.charset.StandardCharsets; 030import java.time.Clock; 031import java.util.Base64; 032import java.util.Collection; 033import java.util.HashSet; 034import java.util.Iterator; 035import java.util.List; 036import java.util.Map; 037import java.util.Map.Entry; 038import java.util.Objects; 039import java.util.Optional; 040import java.util.concurrent.CompletableFuture; 041import java.util.concurrent.CompletionStage; 042import java.util.concurrent.Executor; 043import java.util.concurrent.Executors; 044import java.util.concurrent.ScheduledExecutorService; 045import java.util.concurrent.ThreadFactory; 046import java.util.concurrent.TimeUnit; 047import java.util.concurrent.atomic.AtomicInteger; 048import java.util.concurrent.atomic.AtomicReference; 049import java.util.function.Function; 050import java.util.function.Predicate; 051import java.util.function.Supplier; 052import java.util.regex.Pattern; 053 054import javax.annotation.PostConstruct; 055import javax.annotation.PreDestroy; 056import javax.cache.Cache; 057import javax.enterprise.context.ApplicationScoped; 058import javax.enterprise.context.Initialized; 059import javax.enterprise.event.Observes; 060import javax.inject.Inject; 061import javax.json.bind.annotation.JsonbProperty; 062import javax.servlet.ServletContext; 063import javax.ws.rs.WebApplicationException; 064import javax.ws.rs.client.WebTarget; 065import javax.ws.rs.core.Response; 066import javax.ws.rs.core.Response.Status; 067 068import org.eclipse.microprofile.config.inject.ConfigProperty; 069import org.talend.sdk.components.vault.configuration.Documentation; 070import org.talend.sdk.components.vault.server.error.ErrorPayload; 071import org.talend.sdk.components.vault.server.error.ErrorPayload.ErrorDictionary; 072 073import lombok.AllArgsConstructor; 074import lombok.Data; 075import lombok.NoArgsConstructor; 076import lombok.RequiredArgsConstructor; 077import lombok.SneakyThrows; 078import lombok.extern.slf4j.Slf4j; 079 080@Slf4j 081@Data 082@ApplicationScoped 083public class VaultClient { 084 085 @Inject 086 private VaultClientSetup setup; 087 088 @Inject 089 @VaultHttp 090 private WebTarget vault; 091 092 @Inject 093 @Documentation("The vault path to retrieve a token.") 094 @ConfigProperty(name = "talend.vault.cache.vault.auth.endpoint", defaultValue = "v1/auth/engines/login") 095 private String authEndpoint; 096 097 @Inject 098 @Documentation("The vault path to decrypt values. You can use the variable `{x-talend-tenant-id}` to replace by `x-talend-tenant-id` header value.") 099 @ConfigProperty(name = "talend.vault.cache.vault.decrypt.endpoint", 100 defaultValue = "v1/tenants-keyrings/decrypt/{x-talend-tenant-id}") 101 private String decryptEndpoint; 102 103 @Inject 104 @Documentation("The vault token to use to log in (will make roleId and secretId ignored). `-` means it is ignored.") 105 @ConfigProperty(name = "talend.vault.cache.vault.auth.token", defaultValue = "-") 106 private Supplier<String> token; 107 108 @Inject 109 @Documentation("The vault role identifier to use to log in (if token is not set). `-` means it is ignored.") 110 @ConfigProperty(name = "talend.vault.cache.vault.auth.roleId", defaultValue = "-") 111 private Supplier<String> role; 112 113 @Inject 114 @Documentation("The vault secret identifier to use to log in (if token is not set). `-` means it is ignored.") 115 @ConfigProperty(name = "talend.vault.cache.vault.auth.secretId", defaultValue = "-") 116 private Supplier<String> secret; 117 118 @Inject 119 @Documentation("How often (in ms) to refresh the vault token.") 120 @ConfigProperty(name = "talend.vault.cache.service.auth.refreshDelayMargin", defaultValue = "600000") 121 private Long refreshDelayMargin; 122 123 @Inject 124 @Documentation("How long (in ms) to wait before retrying a recoverable error or refresh the vault token in case of an authentication failure.") 125 @ConfigProperty(name = "talend.vault.cache.service.auth.refreshDelayOnFailure", defaultValue = "1000") 126 private Long refreshDelayOnFailure; 127 128 @Inject 129 @Documentation("How many times do we retry a recoverable operation in case of a failure.") 130 @ConfigProperty(name = "talend.vault.cache.service.auth.numberOfRetryOnFailure", defaultValue = "3") 131 private Integer numberOfRetryOnFailure; 132 133 @Inject 134 @Documentation("Status code sent when vault can't decipher some values.") 135 @ConfigProperty(name = "talend.vault.cache.service.auth.cantDecipherStatusCode", defaultValue = "422") 136 private Integer cantDecipherStatusCode; 137 138 @Inject 139 @Documentation("The regex to whitelist ciphered keys, others will be passthrough in the output without going to vault.") 140 @ConfigProperty(name = "talend.vault.cache.service.decipher.skip.regex", defaultValue = "vault\\:v[0-9]+\\:.*") 141 private String passthroughRegex; 142 143 @Inject 144 private Cache<String, DecryptedValue> cache; 145 146 @Inject 147 private Clock clock; 148 149 private final AtomicReference<Authentication> authToken = new AtomicReference<>(); 150 151 private ScheduledExecutorService scheduledExecutorService; 152 153 private Pattern compiledPassthroughRegex; 154 155 private final Predicate<Throwable> shouldRetry = cause -> { 156 if (WebApplicationException.class.isInstance(cause)) { 157 final WebApplicationException wae = WebApplicationException.class.cast(cause); 158 final int status = wae.getResponse().getStatus(); 159 if (Status.NOT_FOUND.getStatusCode() == status || status >= 500) { 160 return false; 161 } 162 } 163 return true; 164 }; 165 166 @PostConstruct 167 private void init() { 168 compiledPassthroughRegex = Pattern.compile(passthroughRegex); 169 } 170 171 @PreDestroy 172 private void destroy() { 173 scheduledExecutorService.shutdownNow(); // we don't care anymore about these tasks 174 try { 175 scheduledExecutorService.awaitTermination(1L, MINUTES); // wait too much but enough for our goal 176 } catch (final InterruptedException e) { 177 Thread.currentThread().interrupt(); 178 } 179 } 180 181 public void init(@Observes @Initialized(ApplicationScoped.class) final ServletContext init) { 182 scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() { 183 184 private final ThreadGroup group = ofNullable(System.getSecurityManager()) 185 .map(SecurityManager::getThreadGroup) 186 .orElseGet(() -> Thread.currentThread().getThreadGroup()); 187 188 @Override 189 public Thread newThread(final Runnable r) { 190 final Thread t = new Thread(group, r, "talend-vault-service-refresh", 0L); 191 if (t.isDaemon()) { 192 t.setDaemon(false); 193 } 194 if (t.getPriority() != Thread.NORM_PRIORITY) { 195 t.setPriority(Thread.NORM_PRIORITY); 196 } 197 return t; 198 } 199 }); 200 } 201 202 @SneakyThrows 203 public Map<String, String> decrypt(final Map<String, String> values) { 204 return decrypt(values, null); 205 } 206 207 @SneakyThrows 208 public Map<String, String> decrypt(final Map<String, String> values, final String tenantId) { 209 if ("no-vault".equals(setup.getVaultUrl())) { 210 return values; 211 } 212 final List<String> cipheredKeys = values 213 .entrySet() 214 .stream() 215 .filter(entry -> compiledPassthroughRegex.matcher(entry.getValue()).matches()) 216 .map(cyphered -> cyphered.getKey()) 217 .collect(toList()); 218 if (cipheredKeys.isEmpty()) { 219 return values; 220 } 221 Supplier<CompletableFuture<Map<String, String>>> attempt = () -> prepareRequest(values, cipheredKeys, tenantId); 222 return withRetries(attempt, shouldRetry).get(); 223 } 224 225 private CompletableFuture<Map<String, String>> prepareRequest(final Map<String, String> values, 226 final List<String> cipheredKeys, final String tenantId) { 227 return get(cipheredKeys.stream().map(values::get).collect(toList()), clock.millis(), tenantId) 228 .thenApply(decrypted -> values 229 .entrySet() 230 .stream() 231 .collect(toMap(Entry::getKey, 232 e -> of(cipheredKeys.indexOf(e.getKey())) 233 .filter(idx -> idx >= 0) 234 .map(decrypted::get) 235 .map(DecryptedValue::getValue) 236 .orElseGet(() -> values.get(e.getKey()))))); 237 } 238 239 private CompletableFuture<List<DecryptedValue>> get(final Collection<String> values, final long currentTime, 240 final String tenantId) { 241 final AtomicInteger index = new AtomicInteger(); 242 final Collection<EntryWithIndex<String>> clearValues = values 243 .stream() 244 .map(it -> new EntryWithIndex<>(index.getAndIncrement(), it)) 245 .filter(it -> it.entry != null && !compiledPassthroughRegex.matcher(it.entry).matches()) 246 .collect(toList()); 247 if (clearValues.isEmpty()) { 248 return doDecipher(values, currentTime, tenantId).toCompletableFuture(); 249 } 250 if (clearValues.size() == values.size()) { 251 final long now = clock.millis(); 252 return completedFuture(values.stream().map(it -> new DecryptedValue(it, now)).collect(toList())); 253 } 254 return doDecipher(values, currentTime, tenantId).thenApply(deciphered -> { 255 final long now = clock.millis(); 256 clearValues.forEach(entry -> deciphered.add(entry.index, new DecryptedValue(entry.entry, now))); 257 return deciphered; 258 }).toCompletableFuture(); 259 } 260 261 private CompletionStage<List<DecryptedValue>> doDecipher(final Collection<String> values, final long currentTime, 262 final String tenantId) { 263 final Map<String, Optional<DecryptedValue>> alreadyCached = 264 new HashSet<>(values).stream().collect(toMap(identity(), it -> ofNullable(cache.get(it)))); 265 final Collection<String> missing = alreadyCached 266 .entrySet() 267 .stream() 268 .filter(it -> !it.getValue().isPresent()) 269 .map(Map.Entry::getKey) 270 .collect(toList()); 271 if (missing.isEmpty()) { // no remote call, yeah 272 return completedFuture(values.stream().map(alreadyCached::get).map(Optional::get).collect(toList())); 273 } 274 // do request 275 return getOrRequestAuth() 276 // prepare decrypt request to vault 277 .thenCompose(auth -> ofNullable(auth.getAuth()).map(Auth::getClientToken).map(clientToken -> { 278 WebTarget path = vault.path(decryptEndpoint); 279 if (decryptEndpoint.contains("x-talend-tenant-id")) { 280 path = path 281 .resolveTemplate("x-talend-tenant-id", 282 ofNullable(tenantId) 283 .orElseThrow(() -> new WebApplicationException(Response 284 .status(Status.NOT_FOUND) 285 .entity(new ErrorPayload(ErrorDictionary.BAD_FORMAT, 286 "No header x-talend-tenant-id")) 287 .build()))); 288 } 289 return path 290 .request(APPLICATION_JSON_TYPE) 291 .header("X-Vault-Token", clientToken) 292 .rx() 293 .post(entity(new DecryptRequest( 294 missing.stream().map(it -> new DecryptInput(it, null, null)).collect(toList())), 295 APPLICATION_JSON_TYPE), DecryptResponse.class) 296 .toCompletableFuture() 297 // fetch decrypted values 298 .thenApply(decrypted -> { 299 final Collection<DecryptResult> results = decrypted.getData().getBatchResults(); 300 if (results.isEmpty()) { 301 throwError(cantDecipherStatusCode, "Decrypted values are empty"); 302 } 303 final List<String> errors = results 304 .stream() 305 .map(DecryptResult::getError) 306 .filter(Objects::nonNull) 307 .collect(toList()); 308 if (!errors.isEmpty()) { 309 throwError(cantDecipherStatusCode, "Can't decipher properties: " + errors); 310 } 311 final Iterator<String> keyIterator = missing.iterator(); 312 final Map<String, DecryptedValue> decryptedResults = results 313 .stream() 314 .map(it -> new String(Base64.getDecoder().decode(it.getPlaintext()), 315 StandardCharsets.UTF_8)) 316 .collect(toMap(it -> keyIterator.next(), 317 it -> new DecryptedValue(it, currentTime))); 318 cache.putAll(decryptedResults); 319 // 320 return values 321 .stream() 322 .map(it -> decryptedResults 323 .getOrDefault(it, alreadyCached.get(it).orElse(null))) 324 .collect(toList()); 325 }) 326 // oops, smtg went wrong 327 .exceptionally(e -> { 328 final Throwable cause = e.getCause(); 329 String message = ""; 330 int status = cantDecipherStatusCode; 331 if (WebApplicationException.class.isInstance(cause)) { 332 final WebApplicationException wae = WebApplicationException.class.cast(cause); 333 final Response response = wae.getResponse(); 334 if (response != null) { 335 if (ErrorPayload.class.isInstance(response.getEntity())) { // internal error 336 throw wae; 337 } else { 338 try { 339 message = response.readEntity(String.class); 340 } catch (final Exception ignored) { 341 // no-op 342 } 343 } 344 status = response.getStatus(); 345 if (status == Status.NOT_FOUND.getStatusCode() && message.isEmpty()) { 346 message = "Decryption failed: Endpoint not found, check your setup."; 347 } 348 } 349 } 350 if (message.isEmpty()) { 351 message = String.format("Decryption failed: %s", cause.getMessage()); 352 } 353 log.error("{} ({}).", message, status); 354 throw new WebApplicationException(message, 355 Response 356 .status(status) 357 .entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message)) 358 .build()); 359 }); 360 }) 361 .orElseThrow(() -> new WebApplicationException(Response 362 .status(Response.Status.FORBIDDEN) 363 .entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, "getOrRequestAuth failed")) 364 .build()))); 365 } 366 367 private CompletionStage<Authentication> getOrRequestAuth() { 368 return of(token.get()).filter(this::isReloadableConfigSet).map(value -> { 369 final Auth authInfo = new Auth(); 370 authInfo.setClientToken(value); 371 authInfo.setLeaseDuration(Long.MAX_VALUE); 372 authInfo.setRenewable(false); 373 return completedFuture(new Authentication(authInfo, Long.MAX_VALUE)); 374 }).orElseGet(() -> { 375 final String role = of(this.role.get()).filter(this::isReloadableConfigSet).orElse(null); 376 final String secret = of(this.secret.get()).filter(this::isReloadableConfigSet).orElse(null); 377 return ofNullable(authToken.get()) 378 .filter(auth -> (auth.getExpiresAt() - clock.millis()) <= refreshDelayMargin) // is expired 379 .map(CompletableFuture::completedFuture) 380 .orElseGet(() -> doAuth(role, secret).toCompletableFuture()); 381 }); 382 } 383 384 private CompletionStage<Authentication> doAuth(final String role, final String secret) { 385 log.info("Authenticating to vault"); 386 return vault 387 .path(authEndpoint) 388 .request(APPLICATION_JSON_TYPE) 389 .rx() 390 .post(entity(new AuthRequest(role, secret), APPLICATION_JSON_TYPE), AuthResponse.class) 391 // 392 .thenApply(token -> { 393 log.debug("Authenticated to vault"); 394 if (token.getAuth() == null || token.getAuth().getClientToken() == null) { 395 throwError(500, "Vault didn't return a token"); 396 } else { 397 log.info("Authenticated to vault"); 398 } 399 final long validityMargin = TimeUnit.SECONDS.toMillis(token.getAuth().getLeaseDuration()); 400 final long nextRefresh = clock.millis() + validityMargin - refreshDelayMargin; 401 final Authentication authentication = new Authentication(token.getAuth(), nextRefresh); 402 authToken.set(authentication); 403 if (!scheduledExecutorService.isShutdown() && token.getAuth().isRenewable()) { 404 scheduledExecutorService.schedule(() -> doAuth(role, secret), nextRefresh, MILLISECONDS); 405 } 406 return authentication; 407 }) 408 // 409 .exceptionally(e -> { 410 final Throwable cause = e.getCause(); 411 if (WebApplicationException.class.isInstance(cause)) { 412 final WebApplicationException wae = WebApplicationException.class.cast(cause); 413 final Response response = wae.getResponse(); 414 String message = ""; 415 if (ErrorPayload.class.isInstance(wae.getResponse().getEntity())) { 416 throw wae; // already logged and setup broken so just rethrow 417 } else { 418 try { 419 message = response.readEntity(String.class); 420 } catch (final Exception ignored) { 421 // no-op 422 } 423 if (message.isEmpty()) { 424 message = cause.getMessage(); 425 } 426 throwError(response.getStatus(), message); 427 } 428 } 429 throwError(cause); 430 return null; 431 }); 432 } 433 434 private <T> CompletableFuture<T> withRetries(final Supplier<CompletableFuture<T>> attempt, 435 final Predicate<Throwable> shouldRetry) { 436 Executor scheduler = r -> scheduledExecutorService.schedule(r, refreshDelayOnFailure, TimeUnit.MILLISECONDS); 437 CompletableFuture<T> firstAttempt = attempt.get(); 438 return flatten(firstAttempt 439 .thenApply(CompletableFuture::completedFuture) 440 .exceptionally(throwable -> retryFuture(attempt, 1, throwable, shouldRetry, scheduler))); 441 } 442 443 private <T> CompletableFuture<T> retryFuture(final Supplier<CompletableFuture<T>> attempter, 444 final int attemptsSoFar, final Throwable throwable, final Predicate<Throwable> shouldRetry, 445 final Executor scheduler) { 446 int nextAttempt = attemptsSoFar + 1; 447 log 448 .info("[retryFuture] Retry failed operation ({}/{}). Reason: {}.", attemptsSoFar, 449 numberOfRetryOnFailure, throwable.getMessage()); 450 if (nextAttempt > numberOfRetryOnFailure || !shouldRetry.test(throwable.getCause())) { 451 log.info("[retryFuture] Stop retry failed operation (condition triggered)."); 452 throwError(throwable.getCause()); 453 } 454 return flatten(flatten(CompletableFuture.supplyAsync(attempter, scheduler)) 455 .thenApply(CompletableFuture::completedFuture) 456 .exceptionally( 457 nextThrowable -> retryFuture(attempter, nextAttempt, nextThrowable, shouldRetry, scheduler))); 458 } 459 460 private <T> CompletableFuture<T> flatten(final CompletableFuture<CompletableFuture<T>> completableCompletable) { 461 return completableCompletable.thenCompose(Function.identity()); 462 } 463 464 private void throwError(final int status, final String message) { 465 throw new WebApplicationException(message, 466 Response.status(status).entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message)).build()); 467 } 468 469 private void throwError(final Throwable cause) { 470 String message = ""; 471 int status = cantDecipherStatusCode; 472 if (WebApplicationException.class.isInstance(cause)) { 473 final WebApplicationException wae = WebApplicationException.class.cast(cause); 474 final Response response = wae.getResponse(); 475 status = response.getStatus(); 476 if (response != null) { 477 if (ErrorPayload.class.isInstance(response.getEntity())) { // internal error 478 throw wae; 479 } else { 480 try { 481 message = response.readEntity(String.class); 482 } catch (final Exception ignored) { 483 // no-op 484 } 485 } 486 } 487 } 488 if (message.isEmpty()) { 489 message = cause.getMessage(); 490 } 491 throw new WebApplicationException(message, 492 Response.status(status).entity(new ErrorPayload(ErrorDictionary.UNEXPECTED, message)).build()); 493 } 494 495 // workaround while geronimo-config does not support generics of generics 496 // (1.2.1 in org.apache.geronimo.config.cdi.ConfigInjectionBean.create) 497 private boolean isReloadableConfigSet(final String value) { 498 return !"-".equals(value); 499 } 500 501 @Data 502 @NoArgsConstructor 503 @AllArgsConstructor 504 public static class DecryptInput { 505 506 private String ciphertext; 507 508 private String context; // only when derivation is activated 509 510 private String nonce; // only when convergent encryption is activated 511 } 512 513 @Data 514 @NoArgsConstructor 515 @AllArgsConstructor 516 public static class DecryptRequest { 517 518 @JsonbProperty("batch_input") 519 private Collection<DecryptInput> batchInput; 520 } 521 522 @Data 523 public static class DecryptResult { 524 525 private String plaintext; 526 527 private String context; 528 529 private String error; 530 } 531 532 @Data 533 public static class DecryptResponse { 534 535 private DecryptData data; 536 } 537 538 @Data 539 public static class DecryptData { 540 541 @JsonbProperty("batch_results") 542 private Collection<DecryptResult> batchResults; 543 } 544 545 @Data 546 @NoArgsConstructor 547 @AllArgsConstructor 548 public static class AuthRequest { 549 550 @JsonbProperty("role_id") 551 private String roleId; 552 553 @JsonbProperty("secret_id") 554 private String secretId; 555 } 556 557 @Data 558 public static class Auth { 559 560 private boolean renewable; 561 562 @JsonbProperty("lease_duration") 563 private long leaseDuration; 564 565 @JsonbProperty("client_token") 566 private String clientToken; 567 } 568 569 @Data 570 public static class AuthResponse { 571 572 private Auth auth; 573 } 574 575 @Data 576 private static class Authentication { 577 578 private final Auth auth; 579 580 private final long expiresAt; 581 } 582 583 @RequiredArgsConstructor 584 private static class EntryWithIndex<T> { 585 586 private final int index; 587 588 private final T entry; 589 } 590}