|
@@ -0,0 +1,279 @@
|
|
|
+package com.ywt.gateway.filter;
|
|
|
+
|
|
|
+import cn.hutool.crypto.digest.DigestUtil;
|
|
|
+import cn.hutool.json.JSONUtil;
|
|
|
+import cn.hutool.jwt.JWT;
|
|
|
+import cn.hutool.jwt.JWTPayload;
|
|
|
+import cn.hutool.jwt.JWTUtil;
|
|
|
+import cn.hutool.jwt.signers.JWTSignerUtil;
|
|
|
+import com.ywt.gateway.configuration.BizCfg;
|
|
|
+import com.ywt.gateway.decorator.RecorderServerHttpRequestDecorator;
|
|
|
+import com.ywt.gateway.model.*;
|
|
|
+import com.ywt.gateway.utils.BizUtil;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.cloud.gateway.filter.GatewayFilter;
|
|
|
+import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
|
|
|
+import org.springframework.core.io.buffer.DataBuffer;
|
|
|
+import org.springframework.http.*;
|
|
|
+import org.springframework.http.server.reactive.ServerHttpRequest;
|
|
|
+import org.springframework.http.server.reactive.ServerHttpResponse;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+import org.springframework.util.MultiValueMap;
|
|
|
+import reactor.core.publisher.Flux;
|
|
|
+
|
|
|
+import java.io.UnsupportedEncodingException;
|
|
|
+import java.net.URLEncoder;
|
|
|
+import java.nio.CharBuffer;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.util.Date;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.Optional;
|
|
|
+import java.util.concurrent.atomic.AtomicReference;
|
|
|
+
|
|
|
+/**
|
|
|
+ * 处理鉴权
|
|
|
+ * @author Walker
|
|
|
+ * Created on 2023/11/10
|
|
|
+ */
|
|
|
+@Component
|
|
|
+public class AuthGatewayFilterFactory extends AbstractGatewayFilterFactory<HostLocationInfo> {
|
|
|
+ private final Logger logger = LoggerFactory.getLogger(AuthGatewayFilterFactory.class);
|
|
|
+ public static final String AUTH_TYPE_WEB = "web";
|
|
|
+ public static final String AUTH_TYPE_WECHATMP = "wechatmp";
|
|
|
+ public static final String AUTH_TYPE_API = "api";
|
|
|
+ public static final String KEY_AUTHDATA = "authdata";
|
|
|
+ public static final String KEY_AUTH_DATA = "auth-data";
|
|
|
+ public static final String KEY_CLEAN_DATA = "clean-data";
|
|
|
+ public static final String KEY_AUTH_PARAM = "auth-param";
|
|
|
+ public static final String KEY_BEARER = "Bearer ";
|
|
|
+ public static final String KEY_APPID = "appid";
|
|
|
+ public static final String KEY_EXP = "exp";
|
|
|
+ public static final String KEY_CHECKSUM = "checksum";
|
|
|
+ public static final String KEY_AUTH_APPID = "auth-appid";
|
|
|
+ public static final String KEY_AUTHORIZATION = "Authorization";
|
|
|
+ public static final String KEY_IAT = "iat";
|
|
|
+ public static final String POST = "post";
|
|
|
+ public static final String YES = "yes";
|
|
|
+ public static final int HTTP_STATUS_CODE_601 = 601;
|
|
|
+ public static final int HTTP_STATUS_CODE_602 = 602;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private BizCfg bizCfg;
|
|
|
+
|
|
|
+ public AuthGatewayFilterFactory() {
|
|
|
+ super(HostLocationInfo.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public GatewayFilter apply(HostLocationInfo config) {
|
|
|
+ return (exchange, chain) -> {
|
|
|
+
|
|
|
+ ServerHttpResponse response = exchange.getResponse();
|
|
|
+ ServerHttpRequest request = exchange.getRequest();
|
|
|
+ String requestUrl = request.getPath().toString();
|
|
|
+ String methodName = request.getMethodValue();
|
|
|
+ HttpHeaders headers = request.getHeaders();
|
|
|
+ String host = "";
|
|
|
+ if (request.getLocalAddress() != null) {
|
|
|
+ host = request.getLocalAddress().getHostString();
|
|
|
+ } else {
|
|
|
+ logger.warn("Could not get local address!");
|
|
|
+ }
|
|
|
+ // 获取 request body
|
|
|
+ AtomicReference<String> requestBody = new AtomicReference<>("");
|
|
|
+ // 复用现成的 decorator
|
|
|
+ RecorderServerHttpRequestDecorator requestDecorator = new RecorderServerHttpRequestDecorator(request);
|
|
|
+ Flux<DataBuffer> body = requestDecorator.getBody();
|
|
|
+ body.subscribe(buffer -> {
|
|
|
+ CharBuffer charBuffer = StandardCharsets.UTF_8.decode(buffer.asByteBuffer());
|
|
|
+ requestBody.set(charBuffer.toString());
|
|
|
+ });
|
|
|
+ String bodyStr = requestBody.get();
|
|
|
+ try {
|
|
|
+ String auth = config.getAuth();
|
|
|
+ String protocol = config.getProtocol();
|
|
|
+ if (auth == null || auth.isEmpty()) {
|
|
|
+ // 不需要授权
|
|
|
+
|
|
|
+ proxyPass(config, null, headers, response);
|
|
|
+ return chain.filter(exchange);
|
|
|
+ } else {
|
|
|
+ AuthInfo authInfo = bizCfg.getAuths().stream().filter(i -> auth.equals(i.getName())).findFirst().orElse(null);
|
|
|
+ if (authInfo == null) throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(),
|
|
|
+ "No Auth");
|
|
|
+ String authType = authInfo.getType();
|
|
|
+ String name = authInfo.getName();
|
|
|
+ MultiValueMap<String, HttpCookie> cookieMap = request.getCookies();
|
|
|
+ switch (authType) {
|
|
|
+ case AUTH_TYPE_WEB:
|
|
|
+ case AUTH_TYPE_WECHATMP:
|
|
|
+ String tokenName = String.format("t%d", BizUtil.getCRC32Checksum(name.getBytes()));
|
|
|
+ String cookieName = authInfo.getCookieName();
|
|
|
+ if (cookieName != null && !cookieName.isEmpty()) {
|
|
|
+ tokenName = cookieName;
|
|
|
+ }
|
|
|
+ HttpCookie cookie = cookieMap.getFirst(tokenName);
|
|
|
+ if (cookie != null) {
|
|
|
+ String tokenStr = cookie.getValue();
|
|
|
+ JWT jwt = JWTUtil.parseToken(tokenStr);
|
|
|
+ JWTPayload payload = jwt.getPayload();
|
|
|
+ String authdata = (String) payload.getClaim(KEY_AUTHDATA);
|
|
|
+ if (authdata != null && !authdata.isEmpty()) {
|
|
|
+ headers.add(KEY_AUTH_DATA, authdata);
|
|
|
+ }
|
|
|
+ if (authInfo.getParams() != null) {
|
|
|
+ headers.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authInfo.getParams()));
|
|
|
+ }
|
|
|
+
|
|
|
+ proxyPass(config, authInfo, headers, response);
|
|
|
+ return chain.filter(exchange);
|
|
|
+ } else {
|
|
|
+ boolean authResponse401 = config.isAuthResponse401();
|
|
|
+ // 处理授权跳转
|
|
|
+ if (AUTH_TYPE_WECHATMP.equals(authType)) {
|
|
|
+ if (authResponse401) {
|
|
|
+ throw new HttpMsgException(HttpStatus.UNAUTHORIZED, HttpStatus.UNAUTHORIZED.value(), "");
|
|
|
+ } else {
|
|
|
+ String returnUrl = String.format("%s://%s%s?r=%s", protocol, host, authInfo.getUrl(),
|
|
|
+ URLEncoder.encode(String.format("%s://%s%s", protocol, host, requestUrl),
|
|
|
+ StandardCharsets.UTF_8.name()));
|
|
|
+ String weRdtUrl = String.format("https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=%s&state=STATE#wechat_redirect",
|
|
|
+ authInfo.getWeappid(), URLEncoder.encode(returnUrl, StandardCharsets.UTF_8.name()),
|
|
|
+ authInfo.getScope());
|
|
|
+ response.setStatusCode(HttpStatus.FOUND);
|
|
|
+ response.getHeaders().set(HttpHeaders.LOCATION, weRdtUrl);
|
|
|
+ return response.setComplete();
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if (authResponse401) {
|
|
|
+ throw new HttpMsgException(HttpStatus.UNAUTHORIZED, HttpStatus.UNAUTHORIZED.value(), "");
|
|
|
+ } else {
|
|
|
+ String rUrl = String.format("%s://%s%s", protocol, host, requestUrl);
|
|
|
+ String url = String.format("%s?r=%s", authInfo.getUrl(), URLEncoder.encode(rUrl,
|
|
|
+ StandardCharsets.UTF_8.name()));
|
|
|
+ response.setStatusCode(HttpStatus.FOUND);
|
|
|
+ response.getHeaders().set(HttpHeaders.LOCATION, url);
|
|
|
+ return response.setComplete();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case AUTH_TYPE_API:
|
|
|
+ if (POST.equalsIgnoreCase(methodName))
|
|
|
+ throw new HttpMsgException(HttpStatus.METHOD_NOT_ALLOWED, HttpStatus.METHOD_NOT_ALLOWED.value(),
|
|
|
+ HttpStatus.METHOD_NOT_ALLOWED.getReasonPhrase());
|
|
|
+ String authStr = Optional.ofNullable(headers.getFirst(KEY_AUTHORIZATION)).orElse("");
|
|
|
+ if (!authStr.startsWith(KEY_BEARER))
|
|
|
+ throw new HttpMsgException(HttpStatus.UNAUTHORIZED, HttpStatus.UNAUTHORIZED.value(),
|
|
|
+ "Auth Fail");
|
|
|
+
|
|
|
+ // JWT 验证
|
|
|
+ JWT jwt = JWTUtil.parseToken(authStr.replace(KEY_BEARER, ""));
|
|
|
+ JWTPayload payload = jwt.getPayload();
|
|
|
+ if (payload == null)
|
|
|
+ throw new HttpMsgException(HttpStatus.INTERNAL_SERVER_ERROR, HttpStatus.INTERNAL_SERVER_ERROR.value(),
|
|
|
+ "Parse Payload Error");
|
|
|
+ String appId = (String) payload.getClaim(KEY_APPID);
|
|
|
+ if (appId == null) {
|
|
|
+ throw new HttpMsgException(HttpStatus.INTERNAL_SERVER_ERROR, HttpStatus.INTERNAL_SERVER_ERROR.value(),
|
|
|
+ "Payload 必需包含 appid");
|
|
|
+ }
|
|
|
+ AppInfo appInfo = bizCfg.getApps().stream().filter(i -> appId.equals(i.getAppid())).findFirst().orElse(null);
|
|
|
+ if (appInfo == null)
|
|
|
+ throw new HttpMsgException(HttpStatus.INTERNAL_SERVER_ERROR, HttpStatus.INTERNAL_SERVER_ERROR.value(),
|
|
|
+ "不正确的 appid");
|
|
|
+ String appSecret = appInfo.getAppsecret();
|
|
|
+ Long exp = (Long) payload.getClaim(KEY_EXP);
|
|
|
+ // 判断 token 是否过期
|
|
|
+ if (exp != null && exp > 0 && (new Date()).getTime() > exp) {
|
|
|
+ throw new HttpMsgException(HttpStatus.INTERNAL_SERVER_ERROR, HTTP_STATUS_CODE_602,
|
|
|
+ "Token expired");
|
|
|
+ }
|
|
|
+ //判断是否需要验证checksum
|
|
|
+ if (appInfo.isChecksum()) {
|
|
|
+ String checksum = (String) payload.getClaim(KEY_CHECKSUM);
|
|
|
+ if (!DigestUtil.md5Hex(bodyStr).equals(checksum)) {
|
|
|
+ throw new HttpMsgException(HttpStatus.INTERNAL_SERVER_ERROR, HTTP_STATUS_CODE_601,
|
|
|
+ "Checksum Error");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ String authdata = (String) payload.getClaim(KEY_AUTHDATA);
|
|
|
+ if (authdata != null && !authdata.isEmpty()) {
|
|
|
+ headers.add(KEY_AUTH_DATA, authdata);
|
|
|
+ }
|
|
|
+ Map<String, String> authParamMap = new HashMap<>();
|
|
|
+ authParamMap.putAll(authInfo.getParams());
|
|
|
+ authParamMap.putAll(appInfo.getParams());
|
|
|
+ headers.add(KEY_AUTH_PARAM, JSONUtil.toJsonStr(authParamMap));
|
|
|
+
|
|
|
+ // 下发当前授权的 appid 至后端
|
|
|
+ headers.add(KEY_AUTH_APPID, appInfo.getAppid());
|
|
|
+
|
|
|
+ proxyPass(config, authInfo, headers, response);
|
|
|
+ return chain.filter(exchange);
|
|
|
+ default:
|
|
|
+ throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(), "No Auth");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (HttpMsgException e) {
|
|
|
+ response.setStatusCode(e.getHttpStatus());
|
|
|
+ response.getHeaders().add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
|
|
|
+ DataBuffer dataBuffer = response.bufferFactory().wrap(JSONUtil.toJsonStr(new BaseResponse(e.getStatusCode(), e.getMessage())).getBytes());
|
|
|
+ return response.writeWith(Flux.just(dataBuffer));
|
|
|
+ } catch (UnsupportedEncodingException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private void proxyPass(HostLocationInfo locationInfo, AuthInfo authInfo, HttpHeaders headers, ServerHttpResponse response) throws HttpMsgException {
|
|
|
+ // 原网关的“选择后端服务”、“解析协议”部分代码不需要重编码实现,直接使用 Spring Cloud Gateway 的 uri 实现转发
|
|
|
+// if (locationInfo.getServers() == null || locationInfo.getServers().isEmpty())
|
|
|
+// throw new HttpMsgException(HttpStatus.BAD_GATEWAY, HttpStatus.BAD_GATEWAY.value(),
|
|
|
+// "No Server");
|
|
|
+ // “重定向”部分代码不需要重编码实现,直接使用 Filter 的 RedirectTo 实现
|
|
|
+
|
|
|
+ // 设置授权
|
|
|
+ if (authInfo == null) {
|
|
|
+ String refreshAuth = Optional.ofNullable(locationInfo.getRefreshAuth()).orElse("");
|
|
|
+ authInfo = bizCfg.getAuths().stream().filter(i -> refreshAuth.equals(i.getName())).findFirst().orElse(null);
|
|
|
+ }
|
|
|
+ if (authInfo != null && !AUTH_TYPE_API.equals(authInfo.getType())) {
|
|
|
+ String authDataStr = Optional.ofNullable(headers.getFirst(KEY_AUTH_DATA)).orElse("");
|
|
|
+ String cleanAuthStr = Optional.ofNullable(headers.getFirst(KEY_CLEAN_DATA)).orElse("");
|
|
|
+ String cookieName = Optional.ofNullable(authInfo.getCookieName()).orElse("");
|
|
|
+ String appSecret = Optional.ofNullable(authInfo.getJwtSecret()).orElse("");
|
|
|
+ String name = Optional.ofNullable(authInfo.getName()).orElse("");
|
|
|
+ // 签发JWT授权
|
|
|
+ if (!authDataStr.isEmpty()) {
|
|
|
+ Map<String, Object> payload = new HashMap<>();
|
|
|
+ payload.put(KEY_AUTHDATA, authDataStr);
|
|
|
+ payload.put(KEY_IAT, new Date().getTime());
|
|
|
+ String token = JWTUtil.createToken(payload, JWTSignerUtil.hs256(appSecret.getBytes()));
|
|
|
+ String tokenName = String.format("t%d", BizUtil.getCRC32Checksum(name.getBytes()));
|
|
|
+ if (!cookieName.isEmpty()) tokenName = cookieName;
|
|
|
+ ResponseCookie cookie = ResponseCookie.from(tokenName, token)
|
|
|
+ .httpOnly(false)
|
|
|
+ .path("/")
|
|
|
+ .maxAge(authInfo.getMaxAge())
|
|
|
+ .domain(authInfo.getCookieDomain())
|
|
|
+ .build();
|
|
|
+ response.addCookie(cookie);
|
|
|
+ }
|
|
|
+ // 清除授权
|
|
|
+ if (YES.equals(cleanAuthStr)) {
|
|
|
+ String tokenName = String.format("t%d", BizUtil.getCRC32Checksum(name.getBytes()));
|
|
|
+ if (!cookieName.isEmpty()) tokenName = cookieName;
|
|
|
+ ResponseCookie cookie = ResponseCookie.from(tokenName, "")
|
|
|
+ .httpOnly(false)
|
|
|
+ .path("/")
|
|
|
+ .maxAge(-1)
|
|
|
+ .domain(authInfo.getCookieDomain())
|
|
|
+ .build();
|
|
|
+ response.addCookie(cookie);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|