栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

在mybatis执行SQL语句之前进行拦击处理实例

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

在mybatis执行SQL语句之前进行拦击处理实例

比较适用于在分页时候进行拦截。对分页的SQL语句通过封装处理,处理成不同的分页sql。

实用性比较强。

import java.sql.Connection; 
import java.sql.PreparedStatement; 
import java.sql.ResultSet; 
import java.sql.SQLException; 
import java.util.List; 
import java.util.Properties; 
 
import org.apache.ibatis.executor.parameter.ParameterHandler; 
import org.apache.ibatis.executor.statement.RoutingStatementHandler; 
import org.apache.ibatis.executor.statement.StatementHandler; 
import org.apache.ibatis.mapping.BoundSql; 
import org.apache.ibatis.mapping.MappedStatement; 
import org.apache.ibatis.mapping.ParameterMapping; 
import org.apache.ibatis.plugin.Interceptor; 
import org.apache.ibatis.plugin.Intercepts; 
import org.apache.ibatis.plugin.Invocation; 
import org.apache.ibatis.plugin.Plugin; 
import org.apache.ibatis.plugin.Signature; 
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; 
 
import com.yidao.utils.Page; 
import com.yidao.utils.ReflectHelper; 
 
  
@Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})}) 
public class PageInterceptor implements Interceptor { 
  private String dialect = ""; //数据库方言  
  private String pageSqlId = ""; //mapper.xml中需要拦截的ID(正则匹配)  
    
  public Object intercept(Invocation invocation) throws Throwable { 
    //对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类baseStatementHandler,  
    //baseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,  
    //SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是  
    //处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个  
    //StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的baseStatementHandler,即SimpleStatementHandler、  
    //PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。  
    //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候  
    //是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。 
    if(invocation.getTarget() instanceof RoutingStatementHandler){  
      RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();  
      StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");  
      BoundSql boundSql = delegate.getBoundSql(); 
      Object obj = boundSql.getParameterObject(); 
      if (obj instanceof Page) {  
 Page page = (Page) obj;  
 //通过反射获取delegate父类baseStatementHandler的mappedStatement属性  
 MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");  
 //拦截到的prepare方法参数是一个Connection对象  
 Connection connection = (Connection)invocation.getArgs()[0];  
 //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句  
 String sql = boundSql.getSql();  
 //给当前的page参数对象设置总记录数  
 this.setTotalRecord(page,  
     mappedStatement, connection);  
 //获取分页Sql语句  
 String pageSql = this.getPageSql(page, sql);  
 //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句  
 ReflectHelper.setFieldValue(boundSql, "sql", pageSql);  
      }  
    }  
    return invocation.proceed();  
  } 
   
    
  private void setTotalRecord(Page page,  
      MappedStatement mappedStatement, Connection connection) {  
    //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。  
    //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。  
    BoundSql boundSql = mappedStatement.getBoundSql(page);  
    //获取到我们自己写在Mapper映射语句中对应的Sql语句  
    String sql = boundSql.getSql();  
    //通过查询Sql语句获取到对应的计算总记录数的sql语句  
    String countSql = this.getCountSql(sql);  
    //通过BoundSql获取对应的参数映射  
    List parameterMappings = boundSql.getParameterMappings();  
    //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。  
    BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);  
    //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象  
    ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);  
    //通过connection建立一个countSql对应的PreparedStatement对象。  
    PreparedStatement pstmt = null;  
    ResultSet rs = null;  
    try {  
      pstmt = connection.prepareStatement(countSql);  
      //通过parameterHandler给PreparedStatement对象设置参数  
      parameterHandler.setParameters(pstmt);  
      //之后就是执行获取总记录数的Sql语句和获取结果了。  
      rs = pstmt.executeQuery();  
      if (rs.next()) {  
int totalRecord = rs.getInt(1);  
//给当前的参数page对象设置总记录数  
page.setTotalRecord(totalRecord);  
      }  
    } catch (SQLException e) {  
      e.printStackTrace();  
    } finally {  
      try {  
if (rs != null)  
  rs.close();  
 if (pstmt != null)  
  pstmt.close();  
      } catch (SQLException e) {  
e.printStackTrace();  
      }  
    }  
  }  
   
    
  private String getCountSql(String sql) {  
    int index = sql.indexOf("from");  
    return "select count(*) " + sql.substring(index);  
  }  
   
    
  private String getPageSql(Page page, String sql) {  
    StringBuffer sqlBuffer = new StringBuffer(sql);  
    if ("mysql".equalsIgnoreCase(dialect)) {  
      return getMysqlPageSql(page, sqlBuffer);  
    } else if ("oracle".equalsIgnoreCase(dialect)) {  
      return getOraclePageSql(page, sqlBuffer);  
    }  
    return sqlBuffer.toString();  
  }  
   
    
  private String getMysqlPageSql(Page page, StringBuffer sqlBuffer) {  
   //计算第一条记录的位置,Mysql中记录的位置是从0开始的。  
//   System.out.println("page:"+page.getPage()+"-------"+page.getRows()); 
   int offset = (page.getPage() - 1) * page.getRows();  
   sqlBuffer.append(" limit ").append(offset).append(",").append(page.getRows());  
   return sqlBuffer.toString();  
  }  
   
    
  private String getOraclePageSql(Page page, StringBuffer sqlBuffer) {  
   //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的  
   int offset = (page.getPage() - 1) * page.getRows() + 1;  
   sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getRows());  
   sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);  
   //上面的Sql语句拼接之后大概是这个样子:  
   //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16  
   return sqlBuffer.toString();  
  }  
   
    
       
  public Object plugin(Object arg0) {  
    // TODO Auto-generated method stub  
    if (arg0 instanceof StatementHandler) {  
      return Plugin.wrap(arg0, this);  
    } else {  
      return arg0;  
    }  
  }  
  
    
  public void setProperties(Properties p) { 
     
  } 
 
  public String getDialect() { 
    return dialect; 
  } 
 
  public void setDialect(String dialect) { 
    this.dialect = dialect; 
  } 
 
  public String getPageSqlId() { 
    return pageSqlId; 
  } 
 
  public void setPageSqlId(String pageSqlId) { 
    this.pageSqlId = pageSqlId; 
  } 
   
} 

xml配置:

 
   
     
     
     
   
   
   
   
      
      
     
    

Page类

package com.yidao.utils;  
 
public class Page { 
   
  private Integer rows; 
   
  private Integer page = 1; 
   
  private Integer totalRecord; 
 
  public Integer getRows() { 
    return rows; 
  } 
 
  public void setRows(Integer rows) { 
    this.rows = rows; 
  } 
 
  public Integer getPage() { 
    return page; 
  } 
 
  public void setPage(Integer page) { 
    this.page = page; 
  } 
 
  public Integer getTotalRecord() { 
    return totalRecord; 
  } 
 
  public void setTotalRecord(Integer totalRecord) { 
    this.totalRecord = totalRecord; 
  } 
   
} 

ReflectHelper类

package com.yidao.utils; 
 
import java.lang.reflect.Field; 
 
import org.apache.commons.lang3.reflect.FieldUtils; 
 
public class ReflectHelper { 
   
  public static Object getFieldValue(Object obj , String fieldName ){ 
     
    if(obj == null){ 
      return null ; 
    } 
     
    Field targetField = getTargetField(obj.getClass(), fieldName); 
     
    try { 
      return FieldUtils.readField(targetField, obj, true ) ; 
    } catch (IllegalAccessException e) { 
      e.printStackTrace(); 
    }  
    return null ; 
  } 
   
  public static Field getTargetField(Class targetClass, String fieldName) { 
    Field field = null; 
 
    try { 
      if (targetClass == null) { 
 return field; 
      } 
 
      if (Object.class.equals(targetClass)) { 
 return field; 
      } 
 
      field = FieldUtils.getDeclaredField(targetClass, fieldName, true); 
      if (field == null) { 
 field = getTargetField(targetClass.getSuperclass(), fieldName); 
      } 
    } catch (Exception e) { 
    } 
 
    return field; 
  } 
   
  public static void setFieldValue(Object obj , String fieldName , Object value ){ 
    if(null == obj){return;} 
    Field targetField = getTargetField(obj.getClass(), fieldName);  
    try { 
FieldUtils.writeField(targetField, obj, value) ; 
    } catch (IllegalAccessException e) { 
      e.printStackTrace(); 
    }  
  }  
}

  以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持考高分网。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/146581.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号