SpringBoot 基于注解实现接口的代理Bean注入
在 Spring Boot 启动时,需要自己手动将接口的代理 Bean 注册到 Spring 容器中,这样在 service 层直接注入该接口类型即可。
1.在SpringBoot启动类上添加EnableProxyBeanScan注解
EnableProxyBeanScan 为自定义注解,通过 @Import 引入注册器,扫描被 @ProxyBean 注解的接口,或者被"自身带有 @ProxyBean 的注解"所标注的接口("注解继承")
ProxyBeanDefinitionRegistrar 实现了 ImportBeanDefinitionRegistrar,通过 ProxyInterfaceBeanBeanDefinitionScanner 扫描并注册 Bean 定义
ProxyFactoryBean为bean的工厂类,提供代理bean
ProxyHandler 为代理业务逻辑接口,其 execute 方法提供四个参数: Class(被代理的接口)、Object(代理对象)、Method(被代理的方法)、Object[](调用入参)
/**
 * Enables scanning for interfaces annotated (directly or through a meta-annotation) with {@link ProxyBean}
 * and registers a proxy bean for each of them.
 *
 * <p>Put it on the Spring Boot application class. When {@link #basePackages()} is empty (or contains only
 * blank entries), the package of the annotated class is scanned.
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Import(EnableProxyBeanScan.ProxyBeanDefinitionRegistrar.class)
public @interface EnableProxyBeanScan {

    /** Packages to scan; blank entries are ignored. */
    String[] basePackages() default {};

    /** Registers proxy bean definitions for the packages configured on {@link EnableProxyBeanScan}. */
    class ProxyBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar {

        @Override
        public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
            ProxyInterfaceBeanBeanDefinitionScanner scanner = new ProxyInterfaceBeanBeanDefinitionScanner(registry);
            scanner.scan(getBasePackages(importingClassMetadata));
        }

        /**
         * Resolves the packages to scan from the {@link EnableProxyBeanScan} attributes of the importing class.
         *
         * @param importingClassMetadata metadata of the class carrying {@link EnableProxyBeanScan}
         * @return the distinct, non-blank configured packages, or the importing class's own package if none
         */
        private String[] getBasePackages(AnnotationMetadata importingClassMetadata) {
            // getAnnotationAttributes expects the binary class name; getName() is correct even for nested types.
            Map<String, Object> attributes =
                    importingClassMetadata.getAnnotationAttributes(EnableProxyBeanScan.class.getName());
            Set<String> basePackages = new HashSet<>();
            if (attributes != null) {
                for (String item : (String[]) attributes.get("basePackages")) {
                    if (StringUtils.hasText(item)) {
                        basePackages.add(item);
                    }
                }
            }
            if (basePackages.isEmpty()) {
                basePackages.add(ClassUtils.getPackageName(importingClassMetadata.getClassName()));
            }
            return basePackages.toArray(new String[0]);
        }
    }
}
public class ProxyInterfaceBeanBeanDefinitionScanner extends ClassPathBeanDefinitionScanner { public ProxyInterfaceBeanBeanDefinitionScanner(BeanDefinitionRegistry registry) { //registry是Spring的Bean注册中心 // false表示不使用ClassPathBeanDefinitionScanner默认的TypeFilter // 默认的TypeFilter只会扫描带有@Service,@Controller,@Repository,@Component注解的类 super(registry,false); } @Override protected Set<BeanDefinitionHolder> doScan(String... basePackages) { addIncludeFilter(new AnnotationTypeFilter(ProxyBean.class)); Set<BeanDefinitionHolder> beanDefinitionHolders = super.doScan(basePackages); if (beanDefinitionHolders.isEmpty()){ System.err.println("No Interface Found!"); }else{ //创建代理对象 createBeanDefinition(beanDefinitionHolders); } return beanDefinitionHolders; } @Override protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) { AnnotationMetadata metadata = beanDefinition.getMetadata(); return metadata.isInterface() || metadata.isAbstract(); } /** * 为扫描到的接口创建代理对象 * @param beanDefinitionHolders */ private void createBeanDefinition(Set<BeanDefinitionHolder> beanDefinitionHolders) { for (BeanDefinitionHolder beanDefinitionHolder : beanDefinitionHolders) { GenericBeanDefinition beanDefinition = ((GenericBeanDefinition) beanDefinitionHolder.getBeanDefinition()); //将bean的真实类型改变为FactoryBean beanDefinition.getConstructorArgumentValues(). addGenericArgumentValue(beanDefinition.getBeanClassName()); beanDefinition.setBeanClass(ProxyFactoryBean.class); beanDefinition.setAutowireMode(GenericBeanDefinition.AUTOWIRE_BY_TYPE); } } }
/**
 * Declares that an interface is backed by a dynamic proxy whose calls are handled by {@link #value()}.
 *
 * <p>The annotation may also be placed on another annotation. That annotation then acts as a composed
 * marker: interfaces annotated with it are treated as if they carried this annotation.
 */
@Target({ElementType.TYPE, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface ProxyBean {

    /** The handler type that receives every method call made on the proxy. */
    Class<? extends ProxyHandler> value();
}
/**
 * Business logic behind a proxy bean: every method call made on the proxy is delegated here.
 */
public interface ProxyHandler {

    /**
     * Handles one call made on the proxy.
     *
     * @param proxyType   the proxied interface
     * @param proxy       the proxy instance the call was made on
     * @param proxyMethod the invoked method
     * @param args        the call arguments; java.lang.reflect.Proxy passes null when the method has no parameters
     * @return the value handed back to the caller of the proxy
     */
    Object execute(Class<?> proxyType, Object proxy, Method proxyMethod, Object[] args);
}
public class ProxyFactoryBean<T> implements FactoryBean { private static final Map<Class<? extends ProxyHandler>,ProxyHandler> ProxyHandlers = new ConcurrentHashMap<>(); private Class<T> interfaceClass; private Class<? extends ProxyHandler> proxyHandlerType; public ProxyFactoryBean(Class<T> interfaceClass) { this.interfaceClass = interfaceClass; this.proxyHandlerType = AnnotationUtils.findAnnotation(interfaceClass, ProxyBean.class).value(); if(!ProxyFactoryBean.ProxyHandlers.containsKey(proxyHandlerType)) { ProxyHandler proxyHandler = ClassUtils.newInstance(proxyHandlerType); SpringBean.inject(proxyHandler); ProxyFactoryBean.ProxyHandlers.put(proxyHandlerType, proxyHandler); } } @Override public T getObject() throws Exception { final ProxyHandler proxyHandler = ProxyFactoryBean.ProxyHandlers.get(proxyHandlerType); return (T) Proxy.newProxyInstance( interfaceClass.getClassLoader(), new Class[]{interfaceClass}, (proxy,method,args) -> proxyHandler.execute(interfaceClass,proxy,method,args) ); } @Override public Class<T> getObjectType() { return interfaceClass; } }
简单的例子:
类似spring-feign的接口发送Http请求
1.先定义一个注解HttpClient,和HttpClientProxyHandler
@Target({ElementType.TYPE,ElementType.ANNOTATION_TYPE}) @Retention(RetentionPolicy.RUNTIME) @ProxyBean(HttpClient.HttpClientProxyHandler.class) public @interface HttpClient { @Target({ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @interface Request{ String url(); RequestMethod method() default RequestMethod.POST; } /**简单定义下进行测试,实际实现肯定要比这个复杂*/ class HttpClientProxyHandler implements ProxyHandler { /**这个类虽然没有被Spring管理,不过通过这个注解可以实现SpringBean的注入和使用, * 见ProxyFactoryBean构造方法的代码 * SpringBean.inject(proxyHandler); */ @Autowired private RestTemplate template; @Override public Object execute(Class<?> proxyType,Object proxy, Method proxyMethod, Object[] args) { return template.postForObject( proxyMethod.getAnnotation(Request.class).url() ,args[0] ,proxyMethod.getReturnType() ); } } }
2.被代理的接口
/**
 * Example client interface. Because it carries {@code @HttpClient}, calls are handled by
 * {@code HttpClient.HttpClientProxyHandler} instead of by an implementation class.
 */
@HttpClient
public interface LoginService {

    /** Sends {@code username} as the request body to the URL "ddd" and returns the response body. */
    @HttpClient.Request(url = "ddd")
    String getUserAge(ExamineReqDto username);
}
3.测试,
这里没有进行细致的测试,但 RestTemplate 已成功注入,不影响后续使用。
最后,附Bean注入的代码:
/**
 * Static access to the Spring {@link ApplicationContext}, plus a minimal field-injection helper for objects
 * Spring does not manage itself (such as the proxy handlers created by ProxyFactoryBean).
 */
@Component
public class SpringBean implements ApplicationContextAware {
    private static final Logger log = LoggerFactory.getLogger(SpringBean.class);
    private static ApplicationContext applicationContext;

    private SpringBean() {
    }

    /** Invoked by Spring; stores the context for the static accessors below. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringBean.applicationContext = applicationContext;
    }

    /**
     * @throws IllegalStateException if Spring has not yet handed the context to this component
     */
    public static <T> T getSpringBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    /**
     * @throws IllegalStateException if Spring has not yet handed the context to this component
     */
    @SuppressWarnings("unchecked")
    public static <T> T getSpringBean(String beanName) {
        return (T) requireContext().getBean(beanName);
    }

    /**
     * Populates the {@code @Autowired} and {@code @Resource} fields of {@code object}, walking its class
     * hierarchy up to (excluding) {@link Object}. {@code @Autowired} fields are resolved by field type;
     * {@code @Resource} fields by {@code name()} when given, otherwise by field name. Does nothing for null.
     */
    public static void inject(Object object) {
        if (object == null) {
            return;
        }
        for (Class<?> type = object.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (field.isAnnotationPresent(Autowired.class)) {
                    Reflector.setFieldValue(object, field, getSpringBean(field.getType()));
                }
                Resource resource = field.getAnnotation(Resource.class);
                if (resource != null) {
                    // Honor an explicit @Resource(name = "...") like Spring does; fall back to the field name.
                    String beanName = resource.name().isEmpty() ? field.getName() : resource.name();
                    Reflector.setFieldValue(object, field, getSpringBean(beanName));
                }
            }
        }
    }

    /** Returns the stored context, failing with a descriptive message instead of a bare NPE when it is absent. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = SpringBean.applicationContext;
        if (context == null) {
            throw new IllegalStateException(
                    "SpringBean has not received the ApplicationContext yet; it is usable only after Spring initialized it");
        }
        return context;
    }
}
补全Http请求代理接口
/**
 * Marks an interface as a declarative HTTP client whose calls are handled by {@link HttpClientProxyHandler}.
 *
 * <p>The meta-annotation must be {@link ProxyBean}: both the scanner's include filter and ProxyFactoryBean
 * look only for {@code @ProxyBean}, so an interface marked through any other annotation is never proxied.
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@ProxyBean(HttpClientProxyHandler.class)
public @interface HttpClient {
}
import com.sinosoft.demo.componment.proxy.core.ProxyHandler;
import javafx.util.Builder;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.client.RestTemplate;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.*;
import java.util.*;
import static org.springframework.objenesis.instantiator.util.ClassUtils.newInstance;
/**
 * {@link ProxyHandler} that turns calls on HTTP client interfaces into HTTP requests.
 *
 * <p>Sources of each part of the request:
 * <ul>
 *   <li>URL and HTTP method: the method's {@link RequestMapping}. Composed annotations such as
 *       {@code @PostMapping} are resolved through attribute merging.</li>
 *   <li>Headers and timeouts: {@link HttpAttribute} on the interface and on the method.</li>
 *   <li>Interceptors: {@link HttpRequestInterceptors}.</li>
 * </ul>
 *
 * <p>A fresh {@link RestTemplate} is built for every call. Default interface methods run their own body
 * instead of being sent over HTTP.
 */
public class HttpClientProxyHandler implements ProxyHandler {

    @Override
    public Object execute(Class<?> proxyType, Object proxy, Method proxyMethod, Object[] args) {
        // java.lang.reflect.Proxy passes null (not an empty array) for methods without parameters.
        Object[] arguments = args == null ? new Object[0] : args;
        // equals/hashCode/toString are dispatched to the handler too; they must not become HTTP calls.
        if (proxyMethod.getDeclaringClass() == Object.class) {
            return handleObjectMethod(proxyType, proxy, proxyMethod, arguments);
        }
        if (proxyMethod.isDefault()) { // default methods are not proxied: run their own body
            return invokeDefaultMethod(proxy, proxyMethod, arguments);
        }
        RequestMapping requestMapping = AnnotatedElementUtils.findMergedAnnotation(proxyMethod, RequestMapping.class);
        if (requestMapping == null) {
            throw new UnsupportedOperationException(proxyMethod + " is neither a default method nor annotated with @RequestMapping");
        }
        String url = getRequestUrl(requestMapping);
        Object invokeParam = handleRequestObject(proxyMethod, arguments);
        Class<?> returnType = getReturnType(proxyMethod);
        MultiValueMap<String, String> httpAttributes = getHttpAttributes(proxyType, proxyMethod);
        RestTemplate template = createRestTemplate(httpAttributes, proxyType, proxyMethod);
        HttpEntity<Object> entity = new HttpEntity<>(invokeParam, headers(httpAttributes));
        // The call arguments are also passed as positional URI template variables.
        ResponseEntity<?> responseEntity =
                template.exchange(url, getHttpMethod(requestMapping), entity, returnType, arguments);
        return handleReturnObject(proxyMethod, invokeParam, responseEntity);
    }

    /** Implements the three java.lang.Object methods that java.lang.reflect.Proxy forwards to handlers. */
    private static Object handleObjectMethod(Class<?> proxyType, Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return args.length == 1 && proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return proxyType.getName() + "@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(method.toString());
        }
    }

    /**
     * Invokes the body of a default interface method on {@code proxy}.
     *
     * <p>Unchecked exceptions thrown by the default method propagate unchanged. Checked ones are wrapped in
     * {@link UndeclaredThrowableException}, because {@link ProxyHandler#execute} declares no checked exceptions.
     *
     * <p>NOTE(review): this relies on the private {@code Lookup(Class, int)} constructor, which later JDKs
     * restrict or remove; verify against the target JDK.
     */
    private static Object invokeDefaultMethod(Object proxy, Method method, Object[] args) {
        Class<?> declaringClass = method.getDeclaringClass();
        MethodHandle handle;
        try {
            Constructor<MethodHandles.Lookup> constructor =
                    MethodHandles.Lookup.class.getDeclaredConstructor(Class.class, int.class);
            constructor.setAccessible(true);
            int allModes = MethodHandles.Lookup.PUBLIC | MethodHandles.Lookup.PRIVATE
                    | MethodHandles.Lookup.PROTECTED | MethodHandles.Lookup.PACKAGE;
            handle = constructor.newInstance(declaringClass, allModes)
                    .unreflectSpecial(method, declaringClass)
                    .bindTo(proxy);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot invoke default method " + method, e);
        }
        try {
            return handle.invokeWithArguments(args);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable t) {
            throw new UndeclaredThrowableException(t);
        }
    }

    /** Returns the first path of the mapping; path and value are aliases, both are checked. */
    private String getRequestUrl(RequestMapping requestMapping) {
        String[] path = requestMapping.path().length > 0 ? requestMapping.path() : requestMapping.value();
        if (path.length > 0) {
            return path[0];
        }
        throw new UnsupportedOperationException("url not be null!");
    }

    /**
     * Determines the type the response body is converted to.
     *
     * <ul>
     *   <li>For a {@link CallBack} return type: the first type argument of a directly implemented
     *       {@code CallBack<T, R>}.</li>
     *   <li>For a {@code ResponseEntity<T>} return type: {@code T}.</li>
     *   <li>Otherwise: the declared return type.</li>
     * </ul>
     *
     * <p>A parameterized type argument such as {@code List<Foo>} is reduced to its raw class.
     */
    private Class<?> getReturnType(Method proxyMethod) {
        Class<?> returnType = proxyMethod.getReturnType();
        if (ClassUtils.isAssignable(CallBack.class, returnType)) {
            for (Type candidate : returnType.getGenericInterfaces()) {
                if (candidate instanceof ParameterizedType
                        && ((ParameterizedType) candidate).getRawType() == CallBack.class) {
                    return toRawClass(((ParameterizedType) candidate).getActualTypeArguments()[0]);
                }
            }
        }
        if (ClassUtils.isAssignable(ResponseEntity.class, returnType)) {
            Type genericReturnType = proxyMethod.getGenericReturnType();
            if (genericReturnType instanceof ParameterizedType) {
                return toRawClass(((ParameterizedType) genericReturnType).getActualTypeArguments()[0]);
            }
        }
        return returnType;
    }

    /** Maps a reflective type to a class usable as a RestTemplate response type; Object for wildcards/variables. */
    private static Class<?> toRawClass(Type type) {
        if (type instanceof Class) {
            return (Class<?>) type;
        }
        if (type instanceof ParameterizedType) {
            return (Class<?>) ((ParameterizedType) type).getRawType();
        }
        return Object.class;
    }

    /**
     * Builds the request body.
     *
     * <ul>
     *   <li>No arguments: no body.</li>
     *   <li>A single argument: that argument, or its {@code build()} result if it is a {@link Builder}.</li>
     *   <li>Several arguments: a name-to-value map keyed by {@link RequestParam} name or parameter name.</li>
     * </ul>
     */
    private Object handleRequestObject(Method proxyMethod, Object[] args) {
        if (args.length == 0) {
            return null;
        }
        if (args.length == 1) {
            return args[0] instanceof Builder ? ((Builder<?>) args[0]).build() : args[0];
        }
        Map<String, Object> paramsMap = new LinkedHashMap<>();
        Parameter[] parameters = proxyMethod.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            paramsMap.put(getParameterName(parameters[i]), args[i]);
        }
        return paramsMap;
    }

    /**
     * Uses the declared {@link RequestParam} name/value when present (the old code read the attribute's
     * default, which is always empty); otherwise the reflective parameter name (e.g. arg0 unless compiled
     * with -parameters).
     */
    private static String getParameterName(Parameter parameter) {
        RequestParam requestParam = AnnotationUtils.findAnnotation(parameter, RequestParam.class);
        if (requestParam != null) {
            if (!requestParam.name().isEmpty()) {
                return requestParam.name();
            }
            if (!requestParam.value().isEmpty()) {
                return requestParam.value();
            }
        }
        return parameter.getName();
    }

    /**
     * Converts the response into the declared return value.
     *
     * <ul>
     *   <li>{@link CallBack} return type: a new instance fed with the response and request.</li>
     *   <li>{@link ResponseEntity} return type: the entity itself.</li>
     *   <li>Otherwise: the body.</li>
     * </ul>
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Object handleReturnObject(Method proxyMethod, Object invokeParam, ResponseEntity<?> responseEntity) {
        Class<?> declaredReturnType = proxyMethod.getReturnType();
        if (ClassUtils.isAssignable(CallBack.class, declaredReturnType)) {
            CallBack callBack = (CallBack) newInstance(declaredReturnType);
            callBack.call(responseEntity, invokeParam);
            return callBack;
        }
        if (ClassUtils.isAssignable(ResponseEntity.class, declaredReturnType)) {
            return responseEntity;
        }
        return responseEntity.getBody();
    }

    /** First declared request method, POST when none is declared. */
    private HttpMethod getHttpMethod(RequestMapping requestMapping) {
        RequestMethod[] requestMethod = requestMapping.method();
        if (requestMethod.length > 0) {
            return HttpMethod.valueOf(requestMethod[0].name());
        }
        return HttpMethod.POST;
    }

    /** Collects {@link HttpAttribute}s, type-level first, then method-level. */
    private MultiValueMap<String, String> getHttpAttributes(Class<?> proxyType, Method proxyMethod) {
        MultiValueMap<String, String> attributes = new LinkedMultiValueMap<>();
        for (HttpAttribute attribute : proxyType.getAnnotationsByType(HttpAttribute.class)) {
            attributes.add(attribute.name(), attribute.value());
        }
        for (HttpAttribute attribute : proxyMethod.getAnnotationsByType(HttpAttribute.class)) {
            attributes.add(attribute.name(), attribute.value());
        }
        return attributes;
    }

    /** Builds request headers from the attributes minus the timeout keys; the argument is left untouched. */
    private HttpHeaders headers(MultiValueMap<String, String> multiValueMap) {
        MultiValueMap<String, String> headerValues = new LinkedMultiValueMap<>(multiValueMap);
        headerValues.remove(HttpAttribute.CONNECTION_TIMEOUT);
        headerValues.remove(HttpAttribute.SOCKET_TIMEOUT);
        return new HttpHeaders(headerValues);
    }

    /**
     * Creates a RestTemplate with buffered requests, the configured timeouts and the interceptors selected
     * by {@link HttpRequestInterceptors}. Type-level include/exclude is applied before method-level.
     */
    private RestTemplate createRestTemplate(MultiValueMap<String, String> attributes, Class<?> proxyType, Method proxyMethod) {
        SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory();
        if (attributes.containsKey(HttpAttribute.CONNECTION_TIMEOUT)) {
            factory.setConnectTimeout(Integer.parseInt(attributes.getFirst(HttpAttribute.CONNECTION_TIMEOUT)));
        }
        if (attributes.containsKey(HttpAttribute.SOCKET_TIMEOUT)) {
            factory.setReadTimeout(Integer.parseInt(attributes.getFirst(HttpAttribute.SOCKET_TIMEOUT)));
        }
        RestTemplate template = new RestTemplate();
        template.setRequestFactory(new BufferingClientHttpRequestFactory(factory));
        Set<Class<? extends ClientHttpRequestInterceptor>> interceptors = new LinkedHashSet<>();
        applyInterceptorConfig(interceptors, proxyType.getAnnotation(HttpRequestInterceptors.class));
        applyInterceptorConfig(interceptors, proxyMethod.getAnnotation(HttpRequestInterceptors.class));
        List<ClientHttpRequestInterceptor> interceptorsList = template.getInterceptors();
        for (Class<? extends ClientHttpRequestInterceptor> interceptor : interceptors) {
            interceptorsList.add(newInstance(interceptor));
        }
        return template;
    }

    /** Adds the included interceptor types, then removes the excluded ones; a null annotation is a no-op. */
    private static void applyInterceptorConfig(Set<Class<? extends ClientHttpRequestInterceptor>> interceptors,
                                               HttpRequestInterceptors config) {
        if (config == null) {
            return;
        }
        interceptors.addAll(Arrays.asList(config.include()));
        interceptors.removeAll(Arrays.asList(config.unInclude()));
    }
}
/**
 * Return-type hook for HTTP proxy methods. When a proxied method declares an implementing class as its
 * return type, the handler instantiates it and passes it the response together with the request object.
 *
 * @param <T> the response body type
 * @param <R> the request object type
 */
public interface CallBack<T, R> {

    /**
     * Receives the outcome of one HTTP call.
     *
     * @param entity       the HTTP response
     * @param requestParam the request object that was sent
     */
    void call(ResponseEntity<T> entity, R requestParam);
}
/**
 * A name/value attribute on an HTTP proxy interface or method. Attributes are sent as request headers,
 * except the two timeout keys below, which configure the request factory instead.
 * Repeatable; type-level attributes are collected before method-level ones.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Repeatable(HttpAttribute.HttpAttributes.class)
public @interface HttpAttribute {

    /** Key whose value is passed to SimpleClientHttpRequestFactory#setReadTimeout. */
    String SOCKET_TIMEOUT = "http.socket.timeout";

    /** Key whose value is passed to SimpleClientHttpRequestFactory#setConnectTimeout. */
    String CONNECTION_TIMEOUT = "http.connection.timeout";

    /** Attribute (header) name. */
    String name();

    /** Attribute (header) value. */
    String value();

    /** Container annotation that makes {@link HttpAttribute} repeatable. */
    @Target({ElementType.TYPE, ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @interface HttpAttributes {
        HttpAttribute[] value();
    }
}
/**
 * Selects the ClientHttpRequestInterceptor types used for an HTTP proxy interface or method. The handler
 * applies type-level settings first, then method-level ones; at each level {@code include} is added before
 * {@code unInclude} is removed.
 *
 * <p>Note: {@code @Inherited} only affects lookup on subclasses of an annotated class; it has no effect on
 * interfaces or methods.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface HttpRequestInterceptors {

    /** Interceptor types to add; a new instance of each is created for every call. */
    Class<? extends ClientHttpRequestInterceptor>[] include() default {};

    /** Interceptor types to remove from those selected so far. */
    Class<? extends ClientHttpRequestInterceptor>[] unInclude() default {};
}
CallBack(补充)---"集合"类结果数据
CallBack 类型的返回值是对响应报文实体的进一步封装。有时我们希望 Http 接口返回集合类型的数据结果(List<Pojo>),但该代理类没有直接提供针对集合类型返回值的封装,
比如保险系统针对被保人进行风控校验(第三方接口来校验,自己解析结果,"集合">0表示有风控风险,进行提示或者其他后续处理)
这里通过实现 CallBack(同时实现 Iterable)来变相支持对结果的迭代(foreach)操作:
public class InsuredRiskWarnInfo implements CallBack<ResponseDto,RequestDto>,Iterable<InsuredRiskWarnInfo>{ private List<InsuredRiskWarnInfo> data; @Getter private String name; //姓名 @Getter private String code; //错误码 @Getter private String message;//错误信息 public void merge(InsuredRiskWarnInfo info){ if(this.data == null) this.data = info.data; else this.data.addAll(info.data); } @Override public void call(ResponseEntity<ResponseDto> entity, RequestDto requestParam) { data = new ArrayList<>(); // if(entity.getStatusCode() != HttpStatus.OK) // throw new RuntimeException(entity.getStatusCode().getReasonPhrase()); /**这里是具体的对响应对象的数据封装逻辑,假设一些数据*/ Map<String,String> map = new HashMap<>(); map.put("1","错误1"); map.put("2","错误2"); map.put("3","错误3"); for(Map.Entry<String,String> item : map.entrySet()){ InsuredRiskWarnInfo info = new InsuredRiskWarnInfo(); info.name = "某某人"; //可以从请求对象requestParam中取 info.code = item.getKey(); info.message = item.getValue(); info.data = this.data; this.data.add(info); } } @Override public Iterator<InsuredRiskWarnInfo> iterator() { return data.iterator(); }
调用示例 :
/* Simulates the state after a request has completed. */
InsuredRiskWarnInfo riskInfo = new InsuredRiskWarnInfo();
riskInfo.call(null, null);
final String lineTemplate = "%s风控失败%s:%s";
for (InsuredRiskWarnInfo warning : riskInfo) {
    String line = String.format(lineTemplate, warning.getName(), warning.getCode(), warning.getMessage());
    System.err.println(line);
}
/* Expected output:
   某某人风控失败1:错误1
   某某人风控失败2:错误2
   某某人风控失败3:错误3
*/