1 package net.sumaris.server.http.graphql;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import net.sumaris.core.exception.SumarisTechnicalException;
import net.sumaris.server.exception.ErrorCodes;
import net.sumaris.server.http.security.AuthService;
import net.sumaris.server.http.security.AuthUser;
import net.sumaris.server.vo.security.AuthDataVO;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.nuiton.i18n.I18n;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicReference;
61
62 public class SubscriptionWebSocketHandler extends TextWebSocketHandler {
63
64 private static final Logger log = LoggerFactory.getLogger(SubscriptionWebSocketHandler.class);
65
66 private final AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
67
68 private final boolean debug;
69
70 private List<WebSocketSession> sessions = new CopyOnWriteArrayList();
71
72 @Autowired
73 private GraphQL graphQL;
74
75 @Autowired
76 private ObjectMapper objectMapper;
77
78 @Autowired
79 private AuthService authService;
80
81 @Autowired
82 public SubscriptionWebSocketHandler() {
83 this.debug = log.isDebugEnabled();
84 }
85
86 @Override
87 public void afterConnectionEstablished(WebSocketSession session) throws Exception {
88
89 sessions.add(session);
90 }
91
92 @Override
93 public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
94 sessions.remove(session);
95 if (subscriptionRef.get() != null) subscriptionRef.get().cancel();
96 }
97
98 @Override
99 protected void handleTextMessage(WebSocketSession session, TextMessage message) {
100
101 Map<String, Object> request;
102 try {
103 request = objectMapper.readValue(message.asBytes(), Map.class);
104 if (debug) log.debug(I18n.t("sumaris.server.subscription.getRequest", request));
105 }
106 catch(IOException e) {
107 log.error(I18n.t("sumaris.server.error.subscription.badRequest", e.getMessage()));
108 return;
109 }
110
111 String type = Objects.toString(request.get("type"), "start");
112 if ("connection_init".equals(type)) {
113 handleInitConnection(session, request);
114 }
115 else if ("stop".equals(type)) {
116 if (subscriptionRef.get() != null) subscriptionRef.get().cancel();
117 }
118 else if ("start".equals(type)) {
119 handleStartConnection(session, request);
120 }
121 }
122
123 @Override
124 public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
125 session.close(CloseStatus.SERVER_ERROR);
126 }
127
128
129
130 protected void handleInitConnection(WebSocketSession session, Map<String, Object> request) {
131 Map<String, Object> payload = (Map<String, Object>) request.get("payload");
132 String authToken = MapUtils.getString(payload, "authToken");
133
134
135 if (StringUtils.isNotBlank(authToken)) {
136
137
138 try {
139 Optional<AuthUser> authUser = authService.authenticate(authToken);
140
141 if (authUser.isPresent()) {
142 UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(authUser.get().getUsername(), authToken, authUser.get().getAuthorities());
143 SecurityContextHolder.getContext().setAuthentication(authentication);
144 return;
145 }
146 }
147 catch(AuthenticationException e) {
148 log.warn("Unable to authenticate websocket session, using token: " + e.getMessage());
149
150 }
151 }
152
153
154 try {
155 session.sendMessage(new TextMessage(objectMapper.writeValueAsString(
156 ImmutableMap.of(
157 "type", "error",
158 "payload", getUnauthorizedErrorWithChallenge()
159 ))));
160 } catch (IOException e) {
161 throw new SumarisTechnicalException(e);
162 }
163 }
164
165
166 protected void handleStartConnection(WebSocketSession session, Map<String, Object> request) {
167
168 Map<String, Object> payload = (Map<String, Object>)request.get("payload");
169 final Object opId = request.get("id");
170
171
172 if (!isAuthenticated()) {
173 try {
174 session.close(CloseStatus.SERVICE_RESTARTED);
175 }
176 catch(IOException e) {
177
178 }
179 return;
180 }
181
182 String query = Objects.toString(payload.get("query"));
183 ExecutionResult executionResult = graphQL.execute(ExecutionInput.newExecutionInput()
184 .query(query)
185 .operationName((String) payload.get("operationName"))
186 .variables(GraphQLHelper.getVariables(payload, objectMapper))
187 .build());
188
189
190 if (CollectionUtils.isNotEmpty(executionResult.getErrors())) {
191 sendResponse(session,
192 ImmutableMap.of(
193 "id", opId,
194 "type", "error",
195 "payload", GraphQLHelper.processExecutionResult(executionResult))
196 );
197 return;
198 }
199
200 Publisher<ExecutionResult> stream = executionResult.getData();
201
202 stream.subscribe(new Subscriber<ExecutionResult>() {
203 @Override
204 public void onSubscribe(Subscription subscription) {
205 subscriptionRef.set(subscription);
206 if (subscriptionRef.get() != null) subscriptionRef.get().request(1);
207 }
208
209 @Override
210 public void onNext(ExecutionResult result) {
211 sendResponse(session, ImmutableMap.of(
212 "id", opId,
213 "type", "data",
214 "payload", GraphQLHelper.processExecutionResult(result))
215 );
216
217 if (subscriptionRef.get() != null) subscriptionRef.get().request(1);
218 }
219
220 @Override
221 public void onError(Throwable throwable) {
222 log.warn("GraphQL subscription error", throwable);
223 sendResponse(session,
224 ImmutableMap.of(
225 "id", opId,
226 "type", "error",
227 "payload", GraphQLHelper.processError(throwable))
228 );
229 }
230
231 @Override
232 public void onComplete() {
233 try {
234 session.close();
235 } catch (IOException e) {
236 log.error(e.getMessage(), e);
237 }
238 }
239
240 });
241 }
242
243 protected boolean isAuthenticated() {
244 Authentication auth = SecurityContextHolder.getContext().getAuthentication();
245 return (auth != null && auth.isAuthenticated());
246 }
247
248 protected void sendResponse(WebSocketSession session, Object value) {
249 try {
250 session.sendMessage(new TextMessage(objectMapper.writeValueAsString(value)));
251 } catch (IOException e) {
252 log.error(e.getMessage(), e);
253 }
254 }
255
256 protected Map<String, Object> getUnauthorizedErrorWithChallenge() {
257 AuthDataVO challenge = authService.createNewChallenge();
258 return ImmutableMap.of("message", getUnauthorizedErrorString(),
259 "challenge", challenge);
260 }
261
262 protected String getUnauthorizedErrorString() {
263 return GraphQLHelper.toJsonErrorString(ErrorCodes.UNAUTHORIZED, "Authentication required");
264 }
265 }