001/** 002 * Copyright (C) 2006-2020 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.component.server.front.memory; 017 018import static java.util.Optional.ofNullable; 019 020import java.io.ByteArrayOutputStream; 021import java.io.IOException; 022import java.io.OutputStreamWriter; 023import java.nio.charset.StandardCharsets; 024import java.util.ArrayList; 025import java.util.Collection; 026import java.util.List; 027 028import javax.enterprise.inject.spi.CDI; 029import javax.servlet.AsyncContext; 030import javax.servlet.AsyncEvent; 031import javax.servlet.AsyncListener; 032import javax.servlet.ServletContext; 033import javax.servlet.ServletException; 034import javax.servlet.ServletRequest; 035import javax.servlet.ServletResponse; 036import javax.servlet.http.HttpServletRequest; 037import javax.servlet.http.HttpServletResponse; 038 039import org.apache.cxf.transport.servlet.ServletController; 040 041import lombok.RequiredArgsConstructor; 042import lombok.extern.slf4j.Slf4j; 043 044@Slf4j 045@RequiredArgsConstructor 046public class AsyncContextImpl implements AsyncContext { 047 048 private final ServletRequest request; 049 050 private final InMemoryResponse response; 051 052 private final boolean originalRequestAndResponse; 053 054 private final Collection<AsyncListener> listeners = new ArrayList<>(); 055 056 private final ServletController controller; 057 058 private long timeout; 059 060 AsyncContext start() { 061 final AsyncEvent event = new AsyncEvent(this, request, response); 062 executeOnListeners(l -> l.onStartAsync(event), listeners::clear); 063 return this; 064 } 065 066 public void onError(final Throwable throwable) { 067 final AsyncEvent event = new AsyncEvent(this, request, response, throwable); 068 executeOnListeners(l -> l.onError(event), null); 069 if (!response.isCommitted() && HttpServletResponse.class.isInstance(response)) { 070 final HttpServletResponse http = HttpServletResponse.class.cast(response); 071 http.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); 072 } 073 complete(); 074 } 075 076 private void executeOnListeners(final UnsafeConsumer<AsyncListener> fn, final Runnable afterCopy) { 077 final List<AsyncListener> listenersCopy; 078 synchronized (listeners) { 079 listenersCopy = new ArrayList<>(listeners.size()); 080 ofNullable(afterCopy).ifPresent(Runnable::run); 081 } 082 listenersCopy.forEach(listener -> { 083 try { 084 fn.accept(listener); 085 } catch (final Throwable t) { 086 log.warn("callback failed for listener of type [" + listener.getClass().getName() + "]", t); 087 } 088 }); 089 } 090 091 @Override 092 public ServletRequest getRequest() { 093 return request; 094 } 095 096 @Override 097 public ServletResponse getResponse() { 098 return response; 099 } 100 101 @Override 102 public boolean hasOriginalRequestAndResponse() { 103 return originalRequestAndResponse; 104 } 105 106 @Override 107 public void dispatch() { 108 final ServletRequest servletRequest = getRequest(); 109 if (!HttpServletRequest.class.isInstance(servletRequest)) { 110 throw new IllegalStateException("Not a http request: " + servletRequest); 111 } 112 113 final HttpServletRequest sr = HttpServletRequest.class.cast(servletRequest); 114 115 String path = sr.getRequestURI(); 116 final String cpath = sr.getContextPath(); 117 if (cpath.length() > 1) { 118 path = path.substring(cpath.length()); 119 } 120 dispatch(urlDecode(path)); 121 } 122 123 @Override 124 public void dispatch(final String path) { 125 dispatch(request.getServletContext(), path); 126 } 127 128 @Override 129 public void dispatch(final ServletContext context, final String path) { 130 final ServletRequest servletRequest = getRequest(); 131 if (!HttpServletRequest.class.isInstance(servletRequest)) { 132 throw new IllegalStateException("Not a http request: " + servletRequest); 133 } 134 135 final HttpServletRequest request = HttpServletRequest.class.cast(servletRequest); 136 if (request.getAttribute(ASYNC_REQUEST_URI) == null) { 137 request.setAttribute(ASYNC_REQUEST_URI, request.getRequestURI()); 138 request.setAttribute(ASYNC_CONTEXT_PATH, request.getContextPath()); 139 request.setAttribute(ASYNC_SERVLET_PATH, request.getServletPath()); 140 request.setAttribute(ASYNC_PATH_INFO, request.getPathInfo()); 141 request.setAttribute(ASYNC_QUERY_STRING, request.getQueryString()); 142 } 143 144 try { 145 controller.invoke(request, response); 146 } catch (final ServletException ioe) { 147 onError(ioe); 148 } 149 } 150 151 @Override 152 public void complete() { 153 final AsyncEvent event = new AsyncEvent(this, request, response); 154 executeOnListeners(l -> l.onComplete(event), null); 155 try { 156 response.getOutputStream().close(); 157 } catch (final IOException e) { 158 throw new IllegalStateException(e); 159 } 160 } 161 162 @Override 163 public void start(final Runnable run) { 164 run.run(); 165 } 166 167 @Override 168 public void addListener(final AsyncListener listener) { 169 listeners.add(new AsyncListenerWrapper(listener, request, response)); 170 } 171 172 @Override 173 public void addListener(final AsyncListener listener, final ServletRequest request, 174 final ServletResponse response) { 175 listeners.add(new AsyncListenerWrapper(listener, request, response)); 176 } 177 178 @Override 179 public <T extends AsyncListener> T createListener(final Class<T> clazz) { 180 return CDI.current().select(clazz).get(); 181 } 182 183 @Override 184 public void setTimeout(final long timeout) { 185 this.timeout = timeout; 186 } 187 188 @Override 189 public long getTimeout() { 190 return timeout; 191 } 192 193 // taken from tomcat 194 private static String urlDecode(final String str) { 195 if (str == null) { 196 return null; 197 } 198 199 if (str.indexOf('%') == -1) { 200 // No %nn sequences, so return string unchanged 201 return str; 202 } 203 204 final ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length() * 2); 205 final OutputStreamWriter osw = new OutputStreamWriter(baos, StandardCharsets.UTF_8); 206 final char[] sourceChars = str.toCharArray(); 207 final int len = sourceChars.length; 208 int ix = 0; 209 210 try { 211 while (ix < len) { 212 char c = sourceChars[ix++]; 213 if (c == '%') { 214 osw.flush(); 215 if (ix + 2 > len) { 216 throw new IllegalArgumentException("Missing digit: " + str); 217 } 218 char c1 = sourceChars[ix++]; 219 char c2 = sourceChars[ix++]; 220 if (isHexDigit(c1) && isHexDigit(c2)) { 221 baos.write(x2c(c1, c2)); 222 } else { 223 throw new IllegalArgumentException("Missing digit: " + str); 224 } 225 } else { 226 osw.append(c); 227 } 228 } 229 osw.flush(); 230 231 return baos.toString(StandardCharsets.UTF_8.name()); 232 } catch (final IOException ioe) { 233 throw new IllegalArgumentException(ioe); 234 } 235 } 236 237 private static boolean isHexDigit(final int c) { 238 return ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')); 239 } 240 241 private static int x2c(final char b1, final char b2) { 242 int digit = (b1 >= 'A') ? ((b1 & 0xDF) - 'A') + 10 : (b1 - '0'); 243 digit *= 16; 244 digit += (b2 >= 'A') ? ((b2 & 0xDF) - 'A') + 10 : (b2 - '0'); 245 return digit; 246 } 247 248 private static class AsyncListenerWrapper implements AsyncListener { 249 250 private final AsyncListener delegate; 251 252 private final ServletRequest request; 253 254 private final ServletResponse response; 255 256 private AsyncListenerWrapper(final AsyncListener delegate, final ServletRequest request, 257 final ServletResponse response) { 258 this.delegate = delegate; 259 this.request = request; 260 this.response = response; 261 } 262 263 @Override 264 public void onComplete(final AsyncEvent event) throws IOException { 265 delegate.onComplete(wrap(event)); 266 } 267 268 @Override 269 public void onTimeout(final AsyncEvent event) throws IOException { 270 delegate.onTimeout(wrap(event)); 271 } 272 273 @Override 274 public void onError(final AsyncEvent event) throws IOException { 275 delegate.onError(wrap(event)); 276 } 277 278 @Override 279 public void onStartAsync(final AsyncEvent event) throws IOException { 280 delegate.onStartAsync(wrap(event)); 281 } 282 283 private AsyncEvent wrap(final AsyncEvent event) { 284 if (request != null && response != null) { 285 return new AsyncEvent(event.getAsyncContext(), request, response, event.getThrowable()); 286 } 287 return event; 288 } 289 } 290 291 @FunctionalInterface 292 private interface UnsafeConsumer<T> { 293 294 void accept(T t) throws IOException; 295 } 296}