浏览代码

fix: 处理自定义请求头

Walker 1 年之前
父节点
当前提交
4d57ac358f
共有 1 个文件被更改,包括 23 次插入10 次删除
  1. 23 10
      src/main/java/com/ywt/gateway/filter/AuthGatewayFilterFactory.java

+ 23 - 10
src/main/java/com/ywt/gateway/filter/AuthGatewayFilterFactory.java

@@ -35,6 +35,7 @@ import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * 处理鉴权
+ *
  * @author Walker
  * Created on 2023/11/10
  */
@@ -76,6 +77,7 @@ public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostL
             String requestUrl = request.getPath().toString();
             String methodName = request.getMethodValue();
             HttpHeaders headers = new HttpHeaders();
+            HttpHeaders customHeaders = new HttpHeaders();
             headers.addAll(request.getHeaders());
             String host = "";
             if (request.getLocalAddress() != null) {
@@ -103,8 +105,9 @@ public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostL
                     return chain.filter(exchange);
                 } else {
                     AuthInfo authInfo = bizCfg.getAuths().stream().filter(i -> auth.equals(i.getName())).findFirst().orElse(null);
-                    if (authInfo == null) throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(),
-                            "No Auth");
+                    if (authInfo == null)
+                        throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(),
+                                "No Auth");
                     String authType = authInfo.getType();
                     String name = authInfo.getName();
                     MultiValueMap<String, HttpCookie> cookieMap = request.getCookies();
@@ -123,14 +126,19 @@ public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostL
                                 JWTPayload payload = jwt.getPayload();
                                 String authdata = (String) payload.getClaim(KEY_AUTHDATA);
                                 if (authdata != null && !authdata.isEmpty()) {
-                                    headers.add(KEY_AUTH_DATA, authdata);
+                                    customHeaders.add(KEY_AUTH_DATA, authdata);
                                 }
                                 if (authInfo.getParams() != null) {
-                                    headers.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authInfo.getParams()));
+                                    customHeaders.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authInfo.getParams()));
                                 }
-
+                                headers.addAll(customHeaders);
                                 proxyPass(config, authInfo, headers, response);
-                                return chain.filter(exchange);
+
+                                ServerHttpRequest.Builder modifiedReqBuilder = request.mutate();
+                                customHeaders.forEach((s, strings) -> modifiedReqBuilder.header(s, (strings != null && strings.size() > 0) ? strings.get(0) : ""));
+                                return chain.filter(exchange.mutate()
+                                        .request(modifiedReqBuilder.build())
+                                        .build());
                             } else {
                                 boolean authResponse401 = config.isAuthResponse401();
                                 // 处理授权跳转
@@ -202,7 +210,7 @@ public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostL
                             }
                             String authdata = (String) payload.getClaim(KEY_AUTHDATA);
                             if (authdata != null && !authdata.isEmpty()) {
-                                headers.add(KEY_AUTH_DATA, authdata);
+                                customHeaders.add(KEY_AUTH_DATA, authdata);
                             }
                             Map<String, String> authParamMap = new HashMap<>();
                             if (authInfo.getParams() != null) {
@@ -211,13 +219,18 @@ public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostL
                             if (appInfo.getParams() != null) {
                                 authParamMap.putAll(appInfo.getParams());
                             }
-                            headers.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authParamMap));
+                            customHeaders.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authParamMap));
 
                             // 下发当前授权的 appid 至后端
-                            headers.add(KEY_AUTH_APPID, appInfo.getAppid());
+                            customHeaders.add(KEY_AUTH_APPID, appInfo.getAppid());
+                            headers.addAll(customHeaders);
 
                             proxyPass(config, authInfo, headers, response);
-                            return chain.filter(exchange);
+                            ServerHttpRequest.Builder modifiedReqBuilder = request.mutate();
+                            customHeaders.forEach((s, strings) -> modifiedReqBuilder.header(s, (strings != null && strings.size() > 0) ? strings.get(0) : ""));
+                            return chain.filter(exchange.mutate()
+                                    .request(modifiedReqBuilder.build())
+                                    .build());
                         default:
                             throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(), "No Auth");
                     }