001/**
002 * Copyright (C) 2006-2020 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.server.configuration;
017
018import static java.util.Collections.emptyList;
019import static java.util.Collections.emptyMap;
020import static java.util.Collections.singletonList;
021import static java.util.Locale.ENGLISH;
022
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.OutputStream;
026import java.nio.ByteBuffer;
027import java.nio.charset.StandardCharsets;
028import java.util.Collection;
029import java.util.Comparator;
030import java.util.HashMap;
031import java.util.List;
032import java.util.Map;
033import java.util.Set;
034import java.util.Spliterator;
035import java.util.Spliterators;
036import java.util.TreeMap;
037import java.util.logging.Logger;
038import java.util.stream.Stream;
039import java.util.stream.StreamSupport;
040
041import javax.enterprise.context.Dependent;
042import javax.enterprise.inject.Instance;
043import javax.inject.Inject;
044import javax.servlet.ServletConfig;
045import javax.servlet.ServletContext;
046import javax.servlet.ServletContextEvent;
047import javax.servlet.ServletContextListener;
048import javax.servlet.ServletException;
049import javax.servlet.annotation.WebListener;
050import javax.servlet.http.HttpServletRequest;
051import javax.servlet.http.HttpServletResponse;
052import javax.websocket.CloseReason;
053import javax.websocket.DeploymentException;
054import javax.websocket.Endpoint;
055import javax.websocket.EndpointConfig;
056import javax.websocket.Session;
057import javax.websocket.server.ServerContainer;
058import javax.websocket.server.ServerEndpointConfig;
059import javax.ws.rs.ApplicationPath;
060import javax.ws.rs.core.Application;
061import javax.ws.rs.core.HttpHeaders;
062import javax.xml.namespace.QName;
063
064import org.apache.cxf.Bus;
065import org.apache.cxf.common.logging.LogUtils;
066import org.apache.cxf.continuations.Continuation;
067import org.apache.cxf.continuations.ContinuationCallback;
068import org.apache.cxf.continuations.ContinuationProvider;
069import org.apache.cxf.endpoint.ServerRegistry;
070import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;
071import org.apache.cxf.message.ExchangeImpl;
072import org.apache.cxf.message.Message;
073import org.apache.cxf.message.MessageImpl;
074import org.apache.cxf.service.model.EndpointInfo;
075import org.apache.cxf.transport.AbstractDestination;
076import org.apache.cxf.transport.Conduit;
077import org.apache.cxf.transport.MessageObserver;
078import org.apache.cxf.transport.http.AbstractHTTPDestination;
079import org.apache.cxf.transport.http.ContinuationProviderFactory;
080import org.apache.cxf.transport.http.DestinationRegistry;
081import org.apache.cxf.transport.http.HTTPSession;
082import org.apache.cxf.transport.servlet.ServletController;
083import org.apache.cxf.transport.servlet.ServletDestination;
084import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
085import org.apache.cxf.transports.http.configuration.HTTPServerPolicy;
086import org.apache.cxf.ws.addressing.EndpointReferenceType;
087import org.talend.sdk.component.server.front.cxf.CxfExtractor;
088import org.talend.sdk.component.server.front.memory.InMemoryRequest;
089import org.talend.sdk.component.server.front.memory.InMemoryResponse;
090import org.talend.sdk.component.server.front.memory.MemoryInputStream;
091import org.talend.sdk.component.server.front.memory.SimpleServletConfig;
092
093import lombok.Data;
094import lombok.EqualsAndHashCode;
095import lombok.extern.slf4j.Slf4j;
096
097// ensure any JAX-RS command can use websockets
098@Slf4j
099@Dependent
100@WebListener
101public class WebSocketBroadcastSetup implements ServletContextListener {
102
    // Marker terminating a message: it is written when a response stream is closed and it is stripped from the
    // last header line of an incoming message; the payload stream stops when it meets '^' followed by '@'.
    private static final String EOM = "^@";

    // CXF bus, used to look up the deployed JAX-RS server and to create the service list servlet.
    @Inject
    private Bus bus;

    // Gives access to the CXF destination registry wrapped by the WebSocket transport.
    @Inject
    private CxfExtractor cxf;

    // JAX-RS applications, scanned for an @ApplicationPath to compute the base path of the endpoints.
    @Inject
    private Instance<Application> applications;
113
114    @Override
115    public void contextInitialized(final ServletContextEvent sce) {
116        final ServerContainer container =
117                ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName()));
118
119        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
120                .cast(bus
121                        .getExtension(ServerRegistry.class)
122                        .getServers()
123                        .iterator()
124                        .next()
125                        .getEndpoint()
126                        .get(JAXRSServiceFactoryBean.class.getName()));
127
128        final String appBase = StreamSupport
129                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
130                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
131                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
132                .map(ApplicationPath::value)
133                .findFirst()
134                .map(s -> !s.startsWith("/") ? "/" + s : s)
135                .orElse("/api/v1");
136        final String version = appBase.replaceFirst("/api", "");
137
138        final DestinationRegistry registry = cxf.getRegistry();
139        final ServletContext servletContext = sce.getServletContext();
140
141        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
142        final ServletController controller = new ServletController(webSocketRegistry,
143                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
144                new ServiceListGeneratorServlet(registry, bus));
145        webSocketRegistry.controller = controller;
146
147        Stream
148                .concat(factory
149                        .getClassResourceInfo()
150                        .stream()
151                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
152                        .filter(cri -> cri.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend."))
153                        .map(ori -> {
154                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
155                                    + ori.getURITemplate().getValue();
156                            return ServerEndpointConfig.Builder
157                                    .create(Endpoint.class,
158                                            "/websocket" + version + "/"
159                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
160                                    .configurator(new ServerEndpointConfig.Configurator() {
161
162                                        @Override
163                                        public <T> T getEndpointInstance(final Class<T> clazz)
164                                                throws InstantiationException {
165                                            final Map<String, List<String>> headers = new HashMap<>();
166                                            if (!ori.getProduceTypes().isEmpty()) {
167                                                headers
168                                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
169                                                                ori.getProduceTypes().iterator().next().toString()));
170                                            }
171                                            if (!ori.getConsumeTypes().isEmpty()) {
172                                                headers
173                                                        .put(HttpHeaders.ACCEPT, singletonList(
174                                                                ori.getConsumeTypes().iterator().next().toString()));
175                                            }
176                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
177                                                    ori.getHttpMethod(), uri, headers);
178                                        }
179                                    })
180                                    .build();
181                        }),
182                        Stream
183                                .of(ServerEndpointConfig.Builder
184                                        .create(Endpoint.class, "/websocket" + version + "/bus")
185                                        .configurator(new ServerEndpointConfig.Configurator() {
186
187                                            @Override
188                                            public <T> T getEndpointInstance(final Class<T> clazz)
189                                                    throws InstantiationException {
190
191                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext, "GET",
192                                                        "/", emptyMap());
193                                            }
194                                        })
195                                        .build()))
196                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
197                .peek(e -> log.info("Deploying WebSocket(path={})", e.getPath()))
198                .forEach(config -> {
199                    try {
200                        container.addEndpoint(config);
201                    } catch (final DeploymentException e) {
202                        throw new IllegalStateException(e);
203                    }
204                });
205    }
206
207    @Data
208    @EqualsAndHashCode(callSuper = false)
209    private static class JAXRSEndpoint extends Endpoint {
210
211        private final String appBase;
212
213        private final ServletController controller;
214
215        private final ServletContext context;
216
217        private final String defaultMethod;
218
219        private final String defaultUri;
220
221        private final Map<String, List<String>> baseHeaders;
222
223        @Override
224        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
225            log.debug("Opened session {}", session.getId());
226            session.addMessageHandler(InputStream.class, message -> {
227                final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
228                headers.putAll(baseHeaders);
229
230                final StringBuilder buffer = new StringBuilder(128);
231                try { // read headers from the message
232                    if (!"SEND".equalsIgnoreCase(readLine(buffer, message))) {
233                        throw new IllegalArgumentException("not a message");
234                    }
235
236                    String line;
237                    int del;
238                    while ((line = readLine(buffer, message)) != null) {
239                        final boolean done = line.endsWith(EOM);
240                        if (done) {
241                            line = line.substring(0, line.length() - EOM.length());
242                        }
243                        if (!line.isEmpty()) {
244                            del = line.indexOf(':');
245                            if (del < 0) {
246                                headers.put(line.trim(), emptyList());
247                            } else {
248                                headers
249                                        .put(line.substring(0, del).trim(),
250                                                singletonList(line.substring(del + 1).trim()));
251                            }
252                        }
253                        if (done) {
254                            break;
255                        }
256                    }
257                } catch (final IOException ioe) {
258                    throw new IllegalStateException(ioe);
259                }
260
261                final List<String> uris = headers.get("destination");
262                final String uri;
263                if (uris == null || uris.isEmpty()) {
264                    uri = defaultUri;
265                } else {
266                    uri = uris.iterator().next();
267                }
268
269                final List<String> methods = headers.get("destinationMethod");
270                final String method;
271                if (methods == null || methods.isEmpty()) {
272                    method = defaultMethod;
273                } else {
274                    method = methods.iterator().next();
275                }
276
277                final String queryString;
278                final String path;
279                final int query = uri.indexOf('?');
280                if (query > 0) {
281                    queryString = uri.substring(query + 1);
282                    path = uri.substring(0, query);
283                } else {
284                    queryString = null;
285                    path = uri;
286                }
287
288                try {
289                    final InMemoryRequest request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
290                            appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
291                            session::getUserPrincipal, controller);
292                    final InMemoryResponse response = new InMemoryResponse(session::isOpen, () -> {
293                        if (session.getBasicRemote().getBatchingAllowed()) {
294                            try {
295                                session.getBasicRemote().flushBatch();
296                            } catch (final IOException e) {
297                                throw new IllegalStateException(e);
298                            }
299                        }
300                    }, bytes -> {
301                        try {
302                            session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
303                        } catch (final IOException e) {
304                            throw new IllegalStateException(e);
305                        }
306                    }, (status, responseHeaders) -> {
307                        final StringBuilder top = new StringBuilder("MESSAGE\r\n");
308                        top.append("status: ").append(status).append("\r\n");
309                        responseHeaders
310                                .forEach((k,
311                                        v) -> top.append(k).append(": ").append(String.join(",", v)).append("\r\n"));
312                        top.append("\r\n");// empty line, means the next bytes are the payload
313                        return top.toString();
314                    }) {
315
316                        @Override
317                        protected void onClose(final OutputStream stream) throws IOException {
318                            stream.write(EOM.getBytes(StandardCharsets.UTF_8));
319                        }
320                    };
321                    request.setResponse(response);
322                    controller.invoke(request, response);
323                } catch (final ServletException e) {
324                    throw new IllegalArgumentException(e);
325                }
326            });
327        }
328
329        @Override
330        public void onClose(final Session session, final CloseReason closeReason) {
331            log.debug("Closed session {}", session.getId());
332        }
333
334        @Override
335        public void onError(final Session session, final Throwable throwable) {
336            log.warn("Error for session {}", session.getId(), throwable);
337        }
338
339        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
340            int c;
341            while ((c = in.read()) != -1) {
342                if (c == '\n') {
343                    break;
344                } else if (c != '\r') {
345                    buffer.append((char) c);
346                }
347            }
348
349            if (buffer.length() == 0) {
350                return null;
351            }
352            final String string = buffer.toString();
353            buffer.setLength(0);
354            return string;
355        }
356    }
357
358    private static class WebSocketInputStream extends MemoryInputStream {
359
360        private int previous = Integer.MAX_VALUE;
361
362        private WebSocketInputStream(final InputStream delegate) {
363            super(delegate);
364        }
365
366        @Override
367        public int read() throws IOException {
368            if (finished) {
369                return -1;
370            }
371            if (previous != Integer.MAX_VALUE) {
372                previous = Integer.MAX_VALUE;
373                return previous;
374            }
375            final int read = delegate.read();
376            if (read == '^') {
377                previous = delegate.read();
378                if (previous == '@') {
379                    finished = true;
380                    return -1;
381                }
382            }
383            if (read < 0) {
384                finished = true;
385            }
386            return read;
387        }
388    }
389
390    private static class WebSocketRegistry implements DestinationRegistry {
391
392        private final DestinationRegistry delegate;
393
394        private ServletController controller;
395
396        private WebSocketRegistry(final DestinationRegistry registry) {
397            this.delegate = registry;
398        }
399
400        @Override
401        public void addDestination(final AbstractHTTPDestination destination) {
402            throw new UnsupportedOperationException();
403        }
404
405        @Override
406        public void removeDestination(final String path) {
407            throw new UnsupportedOperationException();
408        }
409
410        @Override
411        public AbstractHTTPDestination getDestinationForPath(final String path) {
412            return wrap(delegate.getDestinationForPath(path));
413        }
414
415        @Override
416        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
417            return wrap(delegate.getDestinationForPath(path, tryDecoding));
418        }
419
420        @Override
421        public AbstractHTTPDestination checkRestfulRequest(final String address) {
422            return wrap(delegate.checkRestfulRequest(address));
423        }
424
425        @Override
426        public Collection<AbstractHTTPDestination> getDestinations() {
427            return delegate.getDestinations();
428        }
429
430        @Override
431        public AbstractDestination[] getSortedDestinations() {
432            return delegate.getSortedDestinations();
433        }
434
435        @Override
436        public Set<String> getDestinationsPaths() {
437            return delegate.getDestinationsPaths();
438        }
439
440        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
441            try {
442                return destination == null ? null : new WebSocketDestination(destination, this);
443            } catch (final IOException e) {
444                throw new IllegalStateException(e);
445            }
446        }
447    }
448
449    private static class WebSocketDestination extends AbstractHTTPDestination {
450
451        static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);
452
453        private final AbstractHTTPDestination delegate;
454
455        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
456                throws IOException {
457            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
458            this.delegate = delegate;
459            this.cproviderFactory = new WebSocketContinuationFactory(registry);
460        }
461
462        @Override
463        public EndpointReferenceType getAddress() {
464            return delegate.getAddress();
465        }
466
467        @Override
468        public Conduit getBackChannel(final Message inMessage) throws IOException {
469            return delegate.getBackChannel(inMessage);
470        }
471
472        @Override
473        public EndpointInfo getEndpointInfo() {
474            return delegate.getEndpointInfo();
475        }
476
477        @Override
478        public void shutdown() {
479            throw new UnsupportedOperationException();
480        }
481
482        @Override
483        public void setMessageObserver(final MessageObserver observer) {
484            throw new UnsupportedOperationException();
485        }
486
487        @Override
488        public MessageObserver getMessageObserver() {
489            return delegate.getMessageObserver();
490        }
491
492        @Override
493        protected Logger getLogger() {
494            return LOG;
495        }
496
497        @Override
498        public Bus getBus() {
499            return delegate.getBus();
500        }
501
502        @Override
503        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
504                final HttpServletResponse resp) throws IOException {
505            // eager create the message to ensure we set our continuation for @Suspended
506            Message inMessage = retrieveFromContinuation(req);
507            if (inMessage == null) {
508                inMessage = new MessageImpl();
509
510                final ExchangeImpl exchange = new ExchangeImpl();
511                exchange.setInMessage(inMessage);
512                setupMessage(inMessage, config, context, req, resp);
513
514                exchange.setSession(new HTTPSession(req));
515                MessageImpl.class.cast(inMessage).setDestination(this);
516            }
517
518            delegate.invoke(config, context, req, resp);
519        }
520
521        @Override
522        public void finalizeConfig() {
523            delegate.finalizeConfig();
524        }
525
526        @Override
527        public String getBeanName() {
528            return delegate.getBeanName();
529        }
530
531        @Override
532        public EndpointReferenceType getAddressWithId(final String id) {
533            return delegate.getAddressWithId(id);
534        }
535
536        @Override
537        public String getId(final Map<String, Object> context) {
538            return delegate.getId(context);
539        }
540
541        @Override
542        public String getContextMatchStrategy() {
543            return delegate.getContextMatchStrategy();
544        }
545
546        @Override
547        public boolean isFixedParameterOrder() {
548            return delegate.isFixedParameterOrder();
549        }
550
551        @Override
552        public boolean isMultiplexWithAddress() {
553            return delegate.isMultiplexWithAddress();
554        }
555
556        @Override
557        public HTTPServerPolicy getServer() {
558            return delegate.getServer();
559        }
560
561        @Override
562        public void assertMessage(final Message message) {
563            delegate.assertMessage(message);
564        }
565
566        @Override
567        public boolean canAssert(final QName type) {
568            return delegate.canAssert(type);
569        }
570
571        @Override
572        public String getPath() {
573            return delegate.getPath();
574        }
575    }
576
577    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {
578
579        private static final String KEY = WebSocketContinuationFactory.class.getName();
580
581        private final WebSocketRegistry registry;
582
583        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
584            this.registry = registry;
585        }
586
587        @Override
588        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
589                final HttpServletResponse resp) {
590            return new WebSocketContinuation(inMessage, req, resp, registry);
591        }
592
593        @Override
594        public Message retrieveFromContinuation(final HttpServletRequest req) {
595            return Message.class.cast(req.getAttribute(KEY));
596        }
597    }
598
599    private static class WebSocketContinuation implements ContinuationProvider, Continuation {
600
601        private final Message message;
602
603        private final HttpServletRequest request;
604
605        private final HttpServletResponse response;
606
607        private final WebSocketRegistry registry;
608
609        private final ContinuationCallback callback;
610
611        private Object object;
612
613        private boolean resumed;
614
615        private boolean pending;
616
617        private boolean isNew;
618
619        private WebSocketContinuation(final Message message, final HttpServletRequest request,
620                final HttpServletResponse response, final WebSocketRegistry registry) {
621            this.message = message;
622            this.request = request;
623            this.response = response;
624            this.registry = registry;
625            this.request
626                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
627                            message.getExchange().getInMessage());
628            this.callback = message.getExchange().get(ContinuationCallback.class);
629        }
630
631        @Override
632        public Continuation getContinuation() {
633            return this;
634        }
635
636        @Override
637        public void complete() {
638            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
639            if (callback != null) {
640                final Exception ex = message.getExchange().get(Exception.class);
641                if (ex == null) {
642                    callback.onComplete();
643                } else {
644                    callback.onError(ex);
645                }
646            }
647            try {
648                response.getWriter().close();
649            } catch (final IOException e) {
650                throw new IllegalStateException(e);
651            }
652        }
653
654        @Override
655        public boolean suspend(final long timeout) {
656            isNew = false;
657            resumed = false;
658            pending = true;
659            message.getExchange().getInMessage().getInterceptorChain().suspend();
660            return true;
661        }
662
663        @Override
664        public void resume() {
665            resumed = true;
666            try {
667                registry.controller.invoke(request, response);
668            } catch (final ServletException e) {
669                throw new IllegalStateException(e);
670            }
671        }
672
673        @Override
674        public void reset() {
675            pending = false;
676            resumed = false;
677            isNew = false;
678            object = null;
679        }
680
681        @Override
682        public boolean isNew() {
683            return isNew;
684        }
685
686        @Override
687        public boolean isPending() {
688            return pending;
689        }
690
691        @Override
692        public boolean isResumed() {
693            return resumed;
694        }
695
696        @Override
697        public boolean isTimeout() {
698            return false;
699        }
700
701        @Override
702        public Object getObject() {
703            return object;
704        }
705
706        @Override
707        public void setObject(final Object o) {
708            object = o;
709        }
710
711        @Override
712        public boolean isReadyForWrite() {
713            return true;
714        }
715    }
716}