001/**
002 * Copyright (C) 2006-2020 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.server.front.memory;
017
018import static java.util.Optional.ofNullable;
019
020import java.io.ByteArrayOutputStream;
021import java.io.IOException;
022import java.io.OutputStreamWriter;
023import java.nio.charset.StandardCharsets;
024import java.util.ArrayList;
025import java.util.Collection;
026import java.util.List;
027
028import javax.enterprise.inject.spi.CDI;
029import javax.servlet.AsyncContext;
030import javax.servlet.AsyncEvent;
031import javax.servlet.AsyncListener;
032import javax.servlet.ServletContext;
033import javax.servlet.ServletException;
034import javax.servlet.ServletRequest;
035import javax.servlet.ServletResponse;
036import javax.servlet.http.HttpServletRequest;
037import javax.servlet.http.HttpServletResponse;
038
039import org.apache.cxf.transport.servlet.ServletController;
040
041import lombok.RequiredArgsConstructor;
042import lombok.extern.slf4j.Slf4j;
043
044@Slf4j
045@RequiredArgsConstructor
046public class AsyncContextImpl implements AsyncContext {
047
048    private final ServletRequest request;
049
050    private final InMemoryResponse response;
051
052    private final boolean originalRequestAndResponse;
053
054    private final Collection<AsyncListener> listeners = new ArrayList<>();
055
056    private final ServletController controller;
057
058    private long timeout;
059
060    AsyncContext start() {
061        final AsyncEvent event = new AsyncEvent(this, request, response);
062        executeOnListeners(l -> l.onStartAsync(event), listeners::clear);
063        return this;
064    }
065
066    public void onError(final Throwable throwable) {
067        final AsyncEvent event = new AsyncEvent(this, request, response, throwable);
068        executeOnListeners(l -> l.onError(event), null);
069        if (!response.isCommitted() && HttpServletResponse.class.isInstance(response)) {
070            final HttpServletResponse http = HttpServletResponse.class.cast(response);
071            http.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
072        }
073        complete();
074    }
075
076    private void executeOnListeners(final UnsafeConsumer<AsyncListener> fn, final Runnable afterCopy) {
077        final List<AsyncListener> listenersCopy;
078        synchronized (listeners) {
079            listenersCopy = new ArrayList<>(listeners.size());
080            ofNullable(afterCopy).ifPresent(Runnable::run);
081        }
082        listenersCopy.forEach(listener -> {
083            try {
084                fn.accept(listener);
085            } catch (final Throwable t) {
086                log.warn("callback failed for listener of type [" + listener.getClass().getName() + "]", t);
087            }
088        });
089    }
090
091    @Override
092    public ServletRequest getRequest() {
093        return request;
094    }
095
096    @Override
097    public ServletResponse getResponse() {
098        return response;
099    }
100
101    @Override
102    public boolean hasOriginalRequestAndResponse() {
103        return originalRequestAndResponse;
104    }
105
106    @Override
107    public void dispatch() {
108        final ServletRequest servletRequest = getRequest();
109        if (!HttpServletRequest.class.isInstance(servletRequest)) {
110            throw new IllegalStateException("Not a http request: " + servletRequest);
111        }
112
113        final HttpServletRequest sr = HttpServletRequest.class.cast(servletRequest);
114
115        String path = sr.getRequestURI();
116        final String cpath = sr.getContextPath();
117        if (cpath.length() > 1) {
118            path = path.substring(cpath.length());
119        }
120        dispatch(urlDecode(path));
121    }
122
123    @Override
124    public void dispatch(final String path) {
125        dispatch(request.getServletContext(), path);
126    }
127
128    @Override
129    public void dispatch(final ServletContext context, final String path) {
130        final ServletRequest servletRequest = getRequest();
131        if (!HttpServletRequest.class.isInstance(servletRequest)) {
132            throw new IllegalStateException("Not a http request: " + servletRequest);
133        }
134
135        final HttpServletRequest request = HttpServletRequest.class.cast(servletRequest);
136        if (request.getAttribute(ASYNC_REQUEST_URI) == null) {
137            request.setAttribute(ASYNC_REQUEST_URI, request.getRequestURI());
138            request.setAttribute(ASYNC_CONTEXT_PATH, request.getContextPath());
139            request.setAttribute(ASYNC_SERVLET_PATH, request.getServletPath());
140            request.setAttribute(ASYNC_PATH_INFO, request.getPathInfo());
141            request.setAttribute(ASYNC_QUERY_STRING, request.getQueryString());
142        }
143
144        try {
145            controller.invoke(request, response);
146        } catch (final ServletException ioe) {
147            onError(ioe);
148        }
149    }
150
151    @Override
152    public void complete() {
153        final AsyncEvent event = new AsyncEvent(this, request, response);
154        executeOnListeners(l -> l.onComplete(event), null);
155        try {
156            response.getOutputStream().close();
157        } catch (final IOException e) {
158            throw new IllegalStateException(e);
159        }
160    }
161
162    @Override
163    public void start(final Runnable run) {
164        run.run();
165    }
166
167    @Override
168    public void addListener(final AsyncListener listener) {
169        listeners.add(new AsyncListenerWrapper(listener, request, response));
170    }
171
172    @Override
173    public void addListener(final AsyncListener listener, final ServletRequest request,
174            final ServletResponse response) {
175        listeners.add(new AsyncListenerWrapper(listener, request, response));
176    }
177
178    @Override
179    public <T extends AsyncListener> T createListener(final Class<T> clazz) {
180        return CDI.current().select(clazz).get();
181    }
182
183    @Override
184    public void setTimeout(final long timeout) {
185        this.timeout = timeout;
186    }
187
188    @Override
189    public long getTimeout() {
190        return timeout;
191    }
192
193    // taken from tomcat
194    private static String urlDecode(final String str) {
195        if (str == null) {
196            return null;
197        }
198
199        if (str.indexOf('%') == -1) {
200            // No %nn sequences, so return string unchanged
201            return str;
202        }
203
204        final ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length() * 2);
205        final OutputStreamWriter osw = new OutputStreamWriter(baos, StandardCharsets.UTF_8);
206        final char[] sourceChars = str.toCharArray();
207        final int len = sourceChars.length;
208        int ix = 0;
209
210        try {
211            while (ix < len) {
212                char c = sourceChars[ix++];
213                if (c == '%') {
214                    osw.flush();
215                    if (ix + 2 > len) {
216                        throw new IllegalArgumentException("Missing digit: " + str);
217                    }
218                    char c1 = sourceChars[ix++];
219                    char c2 = sourceChars[ix++];
220                    if (isHexDigit(c1) && isHexDigit(c2)) {
221                        baos.write(x2c(c1, c2));
222                    } else {
223                        throw new IllegalArgumentException("Missing digit: " + str);
224                    }
225                } else {
226                    osw.append(c);
227                }
228            }
229            osw.flush();
230
231            return baos.toString(StandardCharsets.UTF_8.name());
232        } catch (final IOException ioe) {
233            throw new IllegalArgumentException(ioe);
234        }
235    }
236
237    private static boolean isHexDigit(final int c) {
238        return ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'));
239    }
240
241    private static int x2c(final char b1, final char b2) {
242        int digit = (b1 >= 'A') ? ((b1 & 0xDF) - 'A') + 10 : (b1 - '0');
243        digit *= 16;
244        digit += (b2 >= 'A') ? ((b2 & 0xDF) - 'A') + 10 : (b2 - '0');
245        return digit;
246    }
247
248    private static class AsyncListenerWrapper implements AsyncListener {
249
250        private final AsyncListener delegate;
251
252        private final ServletRequest request;
253
254        private final ServletResponse response;
255
256        private AsyncListenerWrapper(final AsyncListener delegate, final ServletRequest request,
257                final ServletResponse response) {
258            this.delegate = delegate;
259            this.request = request;
260            this.response = response;
261        }
262
263        @Override
264        public void onComplete(final AsyncEvent event) throws IOException {
265            delegate.onComplete(wrap(event));
266        }
267
268        @Override
269        public void onTimeout(final AsyncEvent event) throws IOException {
270            delegate.onTimeout(wrap(event));
271        }
272
273        @Override
274        public void onError(final AsyncEvent event) throws IOException {
275            delegate.onError(wrap(event));
276        }
277
278        @Override
279        public void onStartAsync(final AsyncEvent event) throws IOException {
280            delegate.onStartAsync(wrap(event));
281        }
282
283        private AsyncEvent wrap(final AsyncEvent event) {
284            if (request != null && response != null) {
285                return new AsyncEvent(event.getAsyncContext(), request, response, event.getThrowable());
286            }
287            return event;
288        }
289    }
290
291    @FunctionalInterface
292    private interface UnsafeConsumer<T> {
293
294        void accept(T t) throws IOException;
295    }
296}