001/** 002 * Copyright (C) 2006-2020 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.component.server.configuration; 017 018import static java.util.Collections.emptyList; 019import static java.util.Collections.emptyMap; 020import static java.util.Collections.singletonList; 021import static java.util.Locale.ENGLISH; 022 023import java.io.IOException; 024import java.io.InputStream; 025import java.io.OutputStream; 026import java.nio.ByteBuffer; 027import java.nio.charset.StandardCharsets; 028import java.util.Collection; 029import java.util.Comparator; 030import java.util.HashMap; 031import java.util.List; 032import java.util.Map; 033import java.util.Set; 034import java.util.Spliterator; 035import java.util.Spliterators; 036import java.util.TreeMap; 037import java.util.logging.Logger; 038import java.util.stream.Stream; 039import java.util.stream.StreamSupport; 040 041import javax.enterprise.context.Dependent; 042import javax.enterprise.inject.Instance; 043import javax.inject.Inject; 044import javax.servlet.ServletConfig; 045import javax.servlet.ServletContext; 046import javax.servlet.ServletContextEvent; 047import javax.servlet.ServletContextListener; 048import javax.servlet.ServletException; 049import javax.servlet.annotation.WebListener; 050import javax.servlet.http.HttpServletRequest; 051import javax.servlet.http.HttpServletResponse; 052import javax.websocket.CloseReason; 
053import javax.websocket.DeploymentException; 054import javax.websocket.Endpoint; 055import javax.websocket.EndpointConfig; 056import javax.websocket.Session; 057import javax.websocket.server.ServerContainer; 058import javax.websocket.server.ServerEndpointConfig; 059import javax.ws.rs.ApplicationPath; 060import javax.ws.rs.core.Application; 061import javax.ws.rs.core.HttpHeaders; 062import javax.xml.namespace.QName; 063 064import org.apache.cxf.Bus; 065import org.apache.cxf.common.logging.LogUtils; 066import org.apache.cxf.continuations.Continuation; 067import org.apache.cxf.continuations.ContinuationCallback; 068import org.apache.cxf.continuations.ContinuationProvider; 069import org.apache.cxf.endpoint.ServerRegistry; 070import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean; 071import org.apache.cxf.message.ExchangeImpl; 072import org.apache.cxf.message.Message; 073import org.apache.cxf.message.MessageImpl; 074import org.apache.cxf.service.model.EndpointInfo; 075import org.apache.cxf.transport.AbstractDestination; 076import org.apache.cxf.transport.Conduit; 077import org.apache.cxf.transport.MessageObserver; 078import org.apache.cxf.transport.http.AbstractHTTPDestination; 079import org.apache.cxf.transport.http.ContinuationProviderFactory; 080import org.apache.cxf.transport.http.DestinationRegistry; 081import org.apache.cxf.transport.http.HTTPSession; 082import org.apache.cxf.transport.servlet.ServletController; 083import org.apache.cxf.transport.servlet.ServletDestination; 084import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet; 085import org.apache.cxf.transports.http.configuration.HTTPServerPolicy; 086import org.apache.cxf.ws.addressing.EndpointReferenceType; 087import org.talend.sdk.component.server.front.cxf.CxfExtractor; 088import org.talend.sdk.component.server.front.memory.InMemoryRequest; 089import org.talend.sdk.component.server.front.memory.InMemoryResponse; 090import 
org.talend.sdk.component.server.front.memory.MemoryInputStream;
import org.talend.sdk.component.server.front.memory.SimpleServletConfig;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;

/**
 * Ensures any JAX-RS command can use websockets.
 * <p>
 * On startup this listener deploys one WebSocket endpoint per JAX-RS operation declared by an
 * {@code org.talend.*} class, at {@code /websocket<version>/<lowercased http method><uri template>}, plus a
 * generic {@code /websocket<version>/bus} endpoint. Here {@code <version>} is the {@link ApplicationPath}
 * value (default {@code /api/v1}) with its first {@code /api} removed.
 * <p>
 * Each incoming frame is read as a small text protocol:
 *
 * <pre>
 * SEND
 * destination: /some/path?query
 * destinationMethod: GET
 * other-header: value
 *
 * payload^@
 * </pre>
 *
 * It is then dispatched in memory through a CXF {@link ServletController}. The response is sent back as
 * binary frames and terminated by the {@code ^@} marker. The header section starts with {@code MESSAGE},
 * followed by a {@code status} line and the response headers.
 */
@Slf4j
@Dependent
@WebListener
public class WebSocketBroadcastSetup implements ServletContextListener {

    /** End-of-message marker (the two characters '^' and '@') closing requests and responses. */
    private static final String EOM = "^@";

    @Inject
    private Bus bus;

    @Inject
    private CxfExtractor cxf;

    @Inject
    private Instance<Application> applications;

    /**
     * Deploys the WebSocket endpoints on the container's {@link ServerContainer}.
     *
     * @param sce the servlet context event, used to look up the {@link ServerContainer} attribute.
     * @throws IllegalStateException if the container rejects one of the endpoints.
     */
    @Override
    public void contextInitialized(final ServletContextEvent sce) {
        final ServerContainer container =
                ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName()));

        // JAX-RS model of the first CXF server registered on the bus
        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
                .cast(bus
                        .getExtension(ServerRegistry.class)
                        .getServers()
                        .iterator()
                        .next()
                        .getEndpoint()
                        .get(JAXRSServiceFactoryBean.class.getName()));

        final String appBase = StreamSupport
                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
                .map(ApplicationPath::value)
                .findFirst()
                .map(s -> !s.startsWith("/") ? "/" + s : s)
                .orElse("/api/v1");
        final String version = appBase.replaceFirst("/api", "");

        final DestinationRegistry registry = cxf.getRegistry();
        final ServletContext servletContext = sce.getServletContext();

        // the controller needs the registry and the registry needs the controller (continuation resume),
        // hence the late assignment
        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
        final ServletController controller = new ServletController(webSocketRegistry,
                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
                new ServiceListGeneratorServlet(registry, bus));
        webSocketRegistry.controller = controller;

        Stream
                .concat(factory
                        .getClassResourceInfo()
                        .stream()
                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
                        .filter(ori -> ori.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend."))
                        .map(ori -> {
                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
                                    + ori.getURITemplate().getValue();
                            return ServerEndpointConfig.Builder
                                    .create(Endpoint.class,
                                            "/websocket" + version + "/"
                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
                                    .configurator(new ServerEndpointConfig.Configurator() {

                                        @Override
                                        @SuppressWarnings("unchecked")
                                        public <T> T getEndpointInstance(final Class<T> clazz)
                                                throws InstantiationException {
                                            // default headers, overridable by the header lines of each frame.
                                            // NOTE(review): Content-Type is seeded from the produced type and
                                            // Accept from the consumed one, which looks inverted compared to
                                            // JAX-RS matching (@Consumes <-> Content-Type); kept as-is since
                                            // endpoints may rely on it - confirm intent.
                                            final Map<String, List<String>> headers = new HashMap<>();
                                            if (!ori.getProduceTypes().isEmpty()) {
                                                headers
                                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
                                                                ori.getProduceTypes().iterator().next().toString()));
                                            }
                                            if (!ori.getConsumeTypes().isEmpty()) {
                                                headers
                                                        .put(HttpHeaders.ACCEPT, singletonList(
                                                                ori.getConsumeTypes().iterator().next().toString()));
                                            }
                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
                                                    ori.getHttpMethod(), uri, headers);
                                        }
                                    })
                                    .build();
                        }),
                        Stream
                                .of(ServerEndpointConfig.Builder
                                        .create(Endpoint.class, "/websocket" + version + "/bus")
                                        .configurator(new ServerEndpointConfig.Configurator() {

                                            @Override
                                            @SuppressWarnings("unchecked")
                                            public <T> T getEndpointInstance(final Class<T> clazz)
                                                    throws InstantiationException {
                                                // generic endpoint: frames are expected to carry a destination
                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext,
                                                        "GET", "/", emptyMap());
                                            }
                                        })
                                        .build()))
                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
                .forEach(config -> {
                    log.info("Deploying WebSocket(path={})", config.getPath());
                    try {
                        container.addEndpoint(config);
                    } catch (final DeploymentException e) {
                        throw new IllegalStateException(e);
                    }
                });
    }

    /**
     * WebSocket endpoint turning each binary frame into an in-memory HTTP request dispatched to the CXF
     * {@link ServletController}. The response is streamed back on the same session.
     * <p>
     * {@code defaultMethod} and {@code defaultUri} are used when a frame carries no
     * {@code destinationMethod} or {@code destination} header. {@code baseHeaders} seed the request
     * headers before the frame's own header lines are applied.
     */
    @Data
    @EqualsAndHashCode(callSuper = false)
    private static class JAXRSEndpoint extends Endpoint {

        private final String appBase;

        private final ServletController controller;

        private final ServletContext context;

        private final String defaultMethod;

        private final String defaultUri;

        private final Map<String, List<String>> baseHeaders;

        /** Registers the frame handler which parses and dispatches every incoming message. */
        @Override
        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
            log.debug("Opened session {}", session.getId());
            session.addMessageHandler(InputStream.class, message -> {
                final Map<String, List<String>> headers = readHeaders(message);
                final String uri = firstValue(headers, "destination", defaultUri);
                final String method = firstValue(headers, "destinationMethod", defaultMethod);

                // split "path?query"; a '?' in first position is not treated as a separator
                final String queryString;
                final String path;
                final int query = uri.indexOf('?');
                if (query > 0) {
                    queryString = uri.substring(query + 1);
                    path = uri.substring(0, query);
                } else {
                    queryString = null;
                    path = uri;
                }

                try {
                    // the remaining bytes of the frame (up to the EOM marker) are the request payload
                    final InMemoryRequest request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
                            appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
                            session::getUserPrincipal, controller);
                    final InMemoryResponse response = new InMemoryResponse(session::isOpen, () -> {
                        if (session.getBasicRemote().getBatchingAllowed()) {
                            try {
                                session.getBasicRemote().flushBatch();
                            } catch (final IOException e) {
                                throw new IllegalStateException(e);
                            }
                        }
                    }, bytes -> {
                        try {
                            session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
                        } catch (final IOException e) {
                            throw new IllegalStateException(e);
                        }
                    }, (status, responseHeaders) -> {
                        final StringBuilder top = new StringBuilder("MESSAGE\r\n");
                        top.append("status: ").append(status).append("\r\n");
                        responseHeaders
                                .forEach((k, v) -> top
                                        .append(k)
                                        .append(": ")
                                        .append(String.join(",", v))
                                        .append("\r\n"));
                        top.append("\r\n");// empty line, means the next bytes are the payload
                        return top.toString();
                    }) {

                        @Override
                        protected void onClose(final OutputStream stream) throws IOException {
                            // terminate the response with the end-of-message marker
                            stream.write(EOM.getBytes(StandardCharsets.UTF_8));
                        }
                    };
                    request.setResponse(response);
                    controller.invoke(request, response);
                } catch (final ServletException e) {
                    throw new IllegalArgumentException(e);
                }
            });
        }

        @Override
        public void onClose(final Session session, final CloseReason closeReason) {
            log.debug("Closed session {}", session.getId());
        }

        @Override
        public void onError(final Session session, final Throwable throwable) {
            log.warn("Error for session {}", session.getId(), throwable);
        }

        /**
         * Reads the frame preamble. The first line must be {@code SEND} (any case). It is followed by
         * {@code name: value} lines, ended by an empty line or by a line ending with the end-of-message
         * marker.
         *
         * @param message the frame stream, positioned at its start.
         * @return {@code baseHeaders} overridden by the frame headers. Lookup is case-insensitive, and a
         *         line without ':' maps to an empty value list.
         * @throws IllegalArgumentException if the first line is not {@code SEND}.
         * @throws IllegalStateException if the stream can't be read.
         */
        private Map<String, List<String>> readHeaders(final InputStream message) {
            final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            headers.putAll(baseHeaders);

            final StringBuilder buffer = new StringBuilder(128);
            try {
                if (!"SEND".equalsIgnoreCase(readLine(buffer, message))) {
                    throw new IllegalArgumentException("not a message");
                }

                String line;
                while ((line = readLine(buffer, message)) != null) {
                    final boolean done = line.endsWith(EOM);
                    if (done) {
                        line = line.substring(0, line.length() - EOM.length());
                    }
                    if (!line.isEmpty()) {
                        final int separator = line.indexOf(':');
                        if (separator < 0) {
                            headers.put(line.trim(), emptyList());
                        } else {
                            headers
                                    .put(line.substring(0, separator).trim(),
                                            singletonList(line.substring(separator + 1).trim()));
                        }
                    }
                    if (done) {
                        break;
                    }
                }
            } catch (final IOException ioe) {
                throw new IllegalStateException(ioe);
            }
            return headers;
        }

        /**
         * @return the first value of header {@code name}, or {@code defaultValue} when it is absent or has
         *         no value.
         */
        private static String firstValue(final Map<String, List<String>> headers, final String name,
                final String defaultValue) {
            final List<String> values = headers.get(name);
            if (values == null || values.isEmpty()) {
                return defaultValue;
            }
            return values.iterator().next();
        }

        /**
         * Reads one line, dropping '\r' and stopping at '\n' or end of stream. Each byte becomes one char,
         * so non-ASCII bytes are decoded as ISO-8859-1, not UTF-8.
         *
         * @param buffer reusable scratch buffer, reset before returning a line.
         * @param in the stream to read from.
         * @return the line, or {@code null} when it is empty (blank line or end of stream).
         */
        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
            int c;
            while ((c = in.read()) != -1) {
                if (c == '\n') {
                    break;
                } else if (c != '\r') {
                    buffer.append((char) c);
                }
            }

            if (buffer.length() == 0) {
                return null;
            }
            final String string = buffer.toString();
            buffer.setLength(0);
            return string;
        }
    }

    /**
     * Payload stream for one frame. It reports end of stream at the first {@code ^@} marker, or at the
     * underlying end of stream.
     */
    private static class WebSocketInputStream extends MemoryInputStream {

        // byte read ahead while checking for the marker; Integer.MAX_VALUE means "none pending"
        private int previous = Integer.MAX_VALUE;

        private WebSocketInputStream(final InputStream delegate) {
            super(delegate);
        }

        @Override
        public int read() throws IOException {
            if (finished) {
                return -1;
            }
            final int current;
            if (previous != Integer.MAX_VALUE) {
                // replay the look-ahead byte. It must be saved before resetting the sentinel; the previous
                // code returned the sentinel itself (Integer.MAX_VALUE) instead of the byte
                current = previous;
                previous = Integer.MAX_VALUE;
            } else {
                current = delegate.read();
            }
            if (current == '^') {
                // replayed bytes are checked too, so "^^@" is still detected as the end of the message
                final int next = delegate.read();
                if (next == '@') {
                    finished = true;
                    return -1;
                }
                previous = next;
            }
            if (current < 0) {
                finished = true;
            }
            return current;
        }
    }

    /**
     * Read-only {@link DestinationRegistry} view over the CXF registry. Destination lookups return the
     * delegate's destinations wrapped in a {@link WebSocketDestination}; mutations are rejected.
     */
    private static class WebSocketRegistry implements DestinationRegistry {

        private final DestinationRegistry delegate;

        // assigned once the controller using this registry has been created
        private ServletController controller;

        private WebSocketRegistry(final DestinationRegistry registry) {
            this.delegate = registry;
        }

        @Override
        public void addDestination(final AbstractHTTPDestination destination) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void removeDestination(final String path) {
            throw new UnsupportedOperationException();
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path) {
            return wrap(delegate.getDestinationForPath(path));
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
            return wrap(delegate.getDestinationForPath(path, tryDecoding));
        }

        @Override
        public AbstractHTTPDestination checkRestfulRequest(final String address) {
            return wrap(delegate.checkRestfulRequest(address));
        }

        @Override
        public Collection<AbstractHTTPDestination> getDestinations() {
            return delegate.getDestinations();
        }

        @Override
        public AbstractDestination[] getSortedDestinations() {
            return delegate.getSortedDestinations();
        }

        @Override
        public Set<String> getDestinationsPaths() {
            return delegate.getDestinationsPaths();
        }

        /** Wraps {@code destination} (null-safe) in a {@link WebSocketDestination}. */
        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
            try {
                return destination == null ? null : new WebSocketDestination(destination, this);
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Destination wrapper that pre-builds the CXF message with a WebSocket-aware continuation provider, so
     * that {@code @Suspended} resources work. Everything else is delegated; shutdown and observer
     * replacement are rejected.
     */
    private static class WebSocketDestination extends AbstractHTTPDestination {

        static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

        private final AbstractHTTPDestination delegate;

        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
                throws IOException {
            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
            this.delegate = delegate;
            this.cproviderFactory = new WebSocketContinuationFactory(registry);
        }

        @Override
        public EndpointReferenceType getAddress() {
            return delegate.getAddress();
        }

        @Override
        public Conduit getBackChannel(final Message inMessage) throws IOException {
            return delegate.getBackChannel(inMessage);
        }

        @Override
        public EndpointInfo getEndpointInfo() {
            return delegate.getEndpointInfo();
        }

        @Override
        public void shutdown() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setMessageObserver(final MessageObserver observer) {
            throw new UnsupportedOperationException();
        }

        @Override
        public MessageObserver getMessageObserver() {
            return delegate.getMessageObserver();
        }

        @Override
        protected Logger getLogger() {
            return LOG;
        }

        @Override
        public Bus getBus() {
            return delegate.getBus();
        }

        @Override
        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
                final HttpServletResponse resp) throws IOException {
            // eager create the message to ensure we set our continuation for @Suspended
            Message inMessage = retrieveFromContinuation(req);
            if (inMessage == null) {
                inMessage = new MessageImpl();

                final ExchangeImpl exchange = new ExchangeImpl();
                exchange.setInMessage(inMessage);
                setupMessage(inMessage, config, context, req, resp);

                exchange.setSession(new HTTPSession(req));
                MessageImpl.class.cast(inMessage).setDestination(this);
            }

            delegate.invoke(config, context, req, resp);
        }

        @Override
        public void finalizeConfig() {
            delegate.finalizeConfig();
        }

        @Override
        public String getBeanName() {
            return delegate.getBeanName();
        }

        @Override
        public EndpointReferenceType getAddressWithId(final String id) {
            return delegate.getAddressWithId(id);
        }

        @Override
        public String getId(final Map<String, Object> context) {
            return delegate.getId(context);
        }

        @Override
        public String getContextMatchStrategy() {
            return delegate.getContextMatchStrategy();
        }

        @Override
        public boolean isFixedParameterOrder() {
            return delegate.isFixedParameterOrder();
        }

        @Override
        public boolean isMultiplexWithAddress() {
            return delegate.isMultiplexWithAddress();
        }

        @Override
        public HTTPServerPolicy getServer() {
            return delegate.getServer();
        }

        @Override
        public void assertMessage(final Message message) {
            delegate.assertMessage(message);
        }

        @Override
        public boolean canAssert(final QName type) {
            return delegate.canAssert(type);
        }

        @Override
        public String getPath() {
            return delegate.getPath();
        }
    }

    /** Creates {@link WebSocketContinuation}s and retrieves the message they attached to the request. */
    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {

        private final WebSocketRegistry registry;

        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
            this.registry = registry;
        }

        @Override
        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
                final HttpServletResponse resp) {
            return new WebSocketContinuation(inMessage, req, resp, registry);
        }

        @Override
        public Message retrieveFromContinuation(final HttpServletRequest req) {
            // WebSocketContinuation stores the message under this attribute. The former private key was never
            // written anywhere, so retrieval always returned null
            return Message.class.cast(req.getAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE));
        }
    }

    /**
     * Continuation for suspended JAX-RS invocations over websockets. {@link #suspend(long)} suspends the
     * interceptor chain. {@link #resume()} re-dispatches the same request/response through the controller.
     * {@link #complete()} notifies the callback, if any, and closes the response writer.
     */
    private static class WebSocketContinuation implements ContinuationProvider, Continuation {

        private final Message message;

        private final HttpServletRequest request;

        private final HttpServletResponse response;

        private final WebSocketRegistry registry;

        private final ContinuationCallback callback;

        private Object object;

        private boolean resumed;

        private boolean pending;

        private boolean isNew;

        private WebSocketContinuation(final Message message, final HttpServletRequest request,
                final HttpServletResponse response, final WebSocketRegistry registry) {
            this.message = message;
            this.request = request;
            this.response = response;
            this.registry = registry;
            // lets a later dispatch of this request find the suspended message again
            this.request
                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
                            message.getExchange().getInMessage());
            this.callback = message.getExchange().get(ContinuationCallback.class);
        }

        @Override
        public Continuation getContinuation() {
            return this;
        }

        @Override
        public void complete() {
            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
            if (callback != null) {
                final Exception ex = message.getExchange().get(Exception.class);
                if (ex == null) {
                    callback.onComplete();
                } else {
                    callback.onError(ex);
                }
            }
            try {
                response.getWriter().close();
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }

        @Override
        public boolean suspend(final long timeout) {
            isNew = false;
            resumed = false;
            pending = true;
            message.getExchange().getInMessage().getInterceptorChain().suspend();
            return true;
        }

        @Override
        public void resume() {
            resumed = true;
            try {
                registry.controller.invoke(request, response);
            } catch (final ServletException e) {
                throw new IllegalStateException(e);
            }
        }

        @Override
        public void reset() {
            pending = false;
            resumed = false;
            isNew = false;
            object = null;
        }

        @Override
        public boolean isNew() {
            return isNew;
        }

        @Override
        public boolean isPending() {
            return pending;
        }

        @Override
        public boolean isResumed() {
            return resumed;
        }

        @Override
        public boolean isTimeout() {
            return false;
        }

        @Override
        public Object getObject() {
            return object;
        }

        @Override
        public void setObject(final Object o) {
            object = o;
        }

        @Override
        public boolean isReadyForWrite() {
            return true;
        }
    }
}