View Javadoc
1   package net.sumaris.server.http.graphql;
2   
3   /*-
4    * #%L
5    * SUMARiS:: Server
6    * %%
7    * Copyright (C) 2018 SUMARiS Consortium
8    * %%
9    * This program is free software: you can redistribute it and/or modify
10   * it under the terms of the GNU General Public License as
11   * published by the Free Software Foundation, either version 3 of the
12   * License, or (at your option) any later version.
13   * 
14   * This program is distributed in the hope that it will be useful,
15   * but WITHOUT ANY WARRANTY; without even the implied warranty of
16   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17   * GNU General Public License for more details.
18   * 
19   * You should have received a copy of the GNU General Public
20   * License along with this program.  If not, see
21   * <http://www.gnu.org/licenses/gpl-3.0.html>.
22   * #L%
23   */
24  
25  import com.fasterxml.jackson.databind.ObjectMapper;
26  import com.google.common.collect.ImmutableMap;
27  import graphql.ExecutionInput;
28  import graphql.ExecutionResult;
29  import graphql.GraphQL;
30  import net.sumaris.core.exception.SumarisTechnicalException;
31  import net.sumaris.server.exception.ErrorCodes;
32  import net.sumaris.server.http.security.AuthService;
33  import net.sumaris.server.http.security.AuthUser;
34  import net.sumaris.server.vo.security.AuthDataVO;
35  import org.apache.commons.collections4.CollectionUtils;
36  import org.apache.commons.collections4.MapUtils;
37  import org.apache.commons.lang3.StringUtils;
38  import org.nuiton.i18n.I18n;
39  import org.reactivestreams.Publisher;
40  import org.reactivestreams.Subscriber;
41  import org.reactivestreams.Subscription;
42  import org.slf4j.Logger;
43  import org.slf4j.LoggerFactory;
44  import org.springframework.beans.factory.annotation.Autowired;
45  import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
46  import org.springframework.security.core.Authentication;
47  import org.springframework.security.core.AuthenticationException;
48  import org.springframework.security.core.context.SecurityContextHolder;
49  import org.springframework.web.socket.CloseStatus;
50  import org.springframework.web.socket.TextMessage;
51  import org.springframework.web.socket.WebSocketSession;
52  import org.springframework.web.socket.handler.TextWebSocketHandler;
53  
54  import java.io.IOException;
55  import java.util.List;
56  import java.util.Map;
57  import java.util.Objects;
58  import java.util.Optional;
59  import java.util.concurrent.CopyOnWriteArrayList;
60  import java.util.concurrent.atomic.AtomicReference;
61  
62  public class SubscriptionWebSocketHandler extends TextWebSocketHandler {
63  
64      private static final Logger log = LoggerFactory.getLogger(SubscriptionWebSocketHandler.class);
65  
66      private final AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
67  
68      private final boolean debug;
69  
70      private List<WebSocketSession> sessions = new CopyOnWriteArrayList();
71  
72      @Autowired
73      private GraphQL graphQL;
74  
75      @Autowired
76      private ObjectMapper objectMapper;
77  
78      @Autowired
79      private AuthService authService;
80  
81      @Autowired
82      public SubscriptionWebSocketHandler() {
83          this.debug = log.isDebugEnabled();
84      }
85  
86      @Override
87      public void afterConnectionEstablished(WebSocketSession session) throws Exception {
88          // keep all sessions (for broadcast)
89          sessions.add(session);
90      }
91  
92      @Override
93      public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
94          sessions.remove(session);
95          if (subscriptionRef.get() != null) subscriptionRef.get().cancel();
96      }
97  
98      @Override
99      protected void handleTextMessage(WebSocketSession session, TextMessage message) {
100 
101         Map<String, Object> request;
102         try {
103             request = objectMapper.readValue(message.asBytes(), Map.class);
104             if (debug) log.debug(I18n.t("sumaris.server.subscription.getRequest", request));
105         }
106         catch(IOException e) {
107             log.error(I18n.t("sumaris.server.error.subscription.badRequest", e.getMessage()));
108             return;
109         }
110 
111         String type = Objects.toString(request.get("type"), "start");
112         if ("connection_init".equals(type)) {
113             handleInitConnection(session, request);
114         }
115         else if ("stop".equals(type)) {
116             if (subscriptionRef.get() != null) subscriptionRef.get().cancel();
117         }
118         else if ("start".equals(type)) {
119             handleStartConnection(session, request);
120         }
121     }
122 
123     @Override
124     public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
125         session.close(CloseStatus.SERVER_ERROR);
126     }
127 
128     /* -- protected methods -- */
129 
130     protected void handleInitConnection(WebSocketSession session, Map<String, Object> request) {
131         Map<String, Object> payload = (Map<String, Object>) request.get("payload");
132         String authToken = MapUtils.getString(payload, "authToken");
133 
134         // Has token: try to authenticate
135         if (StringUtils.isNotBlank(authToken)) {
136 
137             // try to authenticate
138             try {
139                 Optional<AuthUser> authUser = authService.authenticate(authToken);
140                 // If success
141                 if (authUser.isPresent()) {
142                     UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(authUser.get().getUsername(), authToken, authUser.get().getAuthorities());
143                     SecurityContextHolder.getContext().setAuthentication(authentication);
144                     return; // OK
145                 }
146             }
147             catch(AuthenticationException e) {
148                 log.warn("Unable to authenticate websocket session, using token: " + e.getMessage());
149                 // Continue
150             }
151         }
152 
153         // Not auth: send a new challenge
154         try {
155             session.sendMessage(new TextMessage(objectMapper.writeValueAsString(
156                     ImmutableMap.of(
157                             "type", "error",
158                             "payload", getUnauthorizedErrorWithChallenge()
159                     ))));
160         } catch (IOException e) {
161             throw new SumarisTechnicalException(e);
162         }
163     }
164 
165 
166     protected void handleStartConnection(WebSocketSession session, Map<String, Object> request) {
167 
168         Map<String, Object> payload = (Map<String, Object>)request.get("payload");
169         final Object opId = request.get("id");
170 
171         // Check authenticated
172         if (!isAuthenticated()) {
173             try {
174                 session.close(CloseStatus.SERVICE_RESTARTED);
175             }
176             catch(IOException e) {
177                 // continue
178             }
179             return;
180         }
181 
182         String query = Objects.toString(payload.get("query"));
183         ExecutionResult executionResult = graphQL.execute(ExecutionInput.newExecutionInput()
184                 .query(query)
185                 .operationName((String) payload.get("operationName"))
186                 .variables(GraphQLHelper.getVariables(payload, objectMapper))
187                 .build());
188 
189         // If error: send error then disconnect
190         if (CollectionUtils.isNotEmpty(executionResult.getErrors())) {
191             sendResponse(session,
192                          ImmutableMap.of(
193                                 "id", opId,
194                                 "type", "error",
195                                 "payload", GraphQLHelper.processExecutionResult(executionResult))
196                 );
197             return;
198         }
199 
200         Publisher<ExecutionResult> stream = executionResult.getData();
201 
202         stream.subscribe(new Subscriber<ExecutionResult>() {
203             @Override
204             public void onSubscribe(Subscription subscription) {
205                 subscriptionRef.set(subscription);
206                 if (subscriptionRef.get() != null) subscriptionRef.get().request(1);
207             }
208 
209             @Override
210             public void onNext(ExecutionResult result) {
211                 sendResponse(session, ImmutableMap.of(
212                                     "id", opId,
213                                     "type", "data",
214                                     "payload", GraphQLHelper.processExecutionResult(result))
215                 );
216 
217                 if (subscriptionRef.get() != null) subscriptionRef.get().request(1);
218             }
219 
220             @Override
221             public void onError(Throwable throwable) {
222                 log.warn("GraphQL subscription error", throwable);
223                 sendResponse(session,
224                              ImmutableMap.of(
225                                     "id", opId,
226                                     "type", "error",
227                                     "payload", GraphQLHelper.processError(throwable))
228                 );
229             }
230 
231             @Override
232             public void onComplete() {
233                 try {
234                     session.close();
235                 } catch (IOException e) {
236                     log.error(e.getMessage(), e);
237                 }
238             }
239 
240         });
241     }
242 
243     protected boolean isAuthenticated() {
244         Authentication auth = SecurityContextHolder.getContext().getAuthentication();
245         return (auth != null && auth.isAuthenticated());
246     }
247 
248     protected void sendResponse(WebSocketSession session, Object value)  {
249         try {
250             session.sendMessage(new TextMessage(objectMapper.writeValueAsString(value)));
251         } catch (IOException e) {
252             log.error(e.getMessage(), e);
253         }
254     }
255 
256     protected Map<String, Object> getUnauthorizedErrorWithChallenge() {
257         AuthDataVO challenge = authService.createNewChallenge();
258         return ImmutableMap.of("message", getUnauthorizedErrorString(),
259                                "challenge", challenge);
260     }
261 
262     protected String getUnauthorizedErrorString() {
263         return GraphQLHelper.toJsonErrorString(ErrorCodes.UNAUTHORIZED, "Authentication required");
264     }
265 }