MyBatis分页功能实现

时间:2022-09-20 20:54:19

MyBatis分页功能实现

用了几天 MyBatis，我觉得它不仅没有足够的亮点吸引我去使用，而且用起来很恶心；但有时为了团队协作，还是要接受它。下面是从一个项目中提取出来的 MyBatis 分页功能实现。

1. 在 MyBatis 配置文件中注册分页拦截器（dataBaseType 指定数据库类型，目前支持 mysql 和 oracle）

    <plugins>
        <plugin interceptor="org.exam.page.PageInterceptor">
            <property name="dataBaseType" value="mysql" />
        </plugin>
    </plugins>

其中相关的几个类:

package org.exam.page;
import java.io.Serializable;
import java.util.List;
/**
 * Pagination holder passed to MyBatis as the statement parameter object.
 *
 * <p>Callers set the requested page (current page / page size, or a raw offset); the
 * pagination interceptor fills in the total record count, from which the total page count
 * is derived.
 *
 * @param <T> element type used for {@code paraObject} and {@code resultList}
 */
public class Page<T> implements Serializable {
    /** Requested page number; 0 means "not specified". */
    private int currentPage;
    /** Number of records per page. */
    private int pageSize;
    /** Number of pages, derived from {@link #totalRecord} and {@link #pageSize}. */
    private int totalPage;
    /** Total number of records matched by the un-paginated query. */
    private int totalRecord;
    /** Zero-based index of the first record on the page. */
    private int offset;
    /** Extra parameter object; presumably query criteria (not read by the paging code itself). */
    private T paraObject;
    /** Page contents; not populated by the interceptor in this package. */
    private List<T> resultList;

    public int getOffset() {
        return offset;
    }

    public void setOffset(int offset) {
        this.offset = offset;
    }

    public int getCurrentPage() {
        return currentPage;
    }

    /** Sets the page number as-is (no clamping). */
    public void setCurrentPage(int currentPage) {
        this.currentPage = currentPage;
    }

    /**
     * Sets the page number and size together and derives the offset.
     *
     * @param currentPage 1-based page number; values {@code <= 0} are treated as 1
     * @param pageSize    records per page; values {@code <= 0} fall back to 10
     */
    public void setCurrentPage(int currentPage, int pageSize) {
        this.currentPage = currentPage <= 0 ? 1 : currentPage;
        this.pageSize = pageSize <= 0 ? 10 : pageSize;
        this.offset = (this.currentPage - 1) * this.pageSize;
    }

    public int getPageSize() {
        return pageSize;
    }

    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
    }

    public int getTotalRecord() {
        return totalRecord;
    }

    /**
     * Stores the total record count and recomputes the total page count (ceiling division).
     *
     * <p>If the page size is not positive (e.g. never set), the page count cannot be computed
     * and is set to 0 instead of failing with a division by zero.
     *
     * @param totalRecord total number of matching records
     */
    public void setTotalRecord(int totalRecord) {
        this.totalRecord = totalRecord;
        int size = getPageSize();
        int pages;
        if (size <= 0) {
            pages = 0;
        } else {
            pages = totalRecord / size;
            if (totalRecord % size != 0) {
                pages++; // a partially filled last page still counts
            }
        }
        this.setTotalPage(pages);
    }

    public int getTotalPage() {
        return totalPage;
    }

    public void setTotalPage(int totalPage) {
        this.totalPage = totalPage;
    }

    public T getParaObject() {
        return paraObject;
    }

    public void setParaObject(T paraObject) {
        this.paraObject = paraObject;
    }

    public List<T> getResultList() {
        return resultList;
    }

    public void setResultList(List<T> resultList) {
        this.resultList = resultList;
    }
}
package org.exam.page;

import java.lang.reflect.Field;
/**
 * Reflection helpers for reading and writing (possibly private) instance fields.
 * Fields are looked up on the object's runtime class first and then on its superclasses.
 */
public class ReflectUtil {
    /**
     * Reads the value of the named field of {@code obj} via reflection.
     *
     * @param obj       target object; must not be {@code null}
     * @param fieldName name of the field to read
     * @return the field's value, or {@code null} if no such field exists in the hierarchy
     * @throws IllegalStateException if the field exists but cannot be read
     */
    public static Object getFieldValue(Object obj, String fieldName) {
        Field field = ReflectUtil.getField(obj, fieldName);
        if (field == null) {
            return null;
        }
        field.setAccessible(true);
        try {
            return field.get(obj);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(
                    "Cannot read field '" + fieldName + "' of " + obj.getClass().getName(), e);
        }
    }

    /**
     * Finds the named field declared on {@code obj}'s class or any superclass below {@code Object}.
     *
     * @param obj       target object; must not be {@code null}
     * @param fieldName name of the field to look up
     * @return the field, or {@code null} if no class in the hierarchy declares it
     */
    public static Field getField(Object obj, String fieldName) {
        for (Class<?> clazz = obj.getClass(); clazz != null && clazz != Object.class;
                clazz = clazz.getSuperclass()) {
            try {
                return clazz.getDeclaredField(fieldName);
            } catch (NoSuchFieldException e) {
                // Not declared at this level; a superclass may declare it.
            }
        }
        return null;
    }

    /**
     * Sets the named field of {@code obj} to a string value. Does nothing if the field does not exist.
     *
     * @param obj        target object; must not be {@code null}
     * @param fieldName  name of the field to write
     * @param fieldValue value to assign
     * @throws IllegalArgumentException if the field's type cannot hold a {@code String}
     * @throws IllegalStateException    if the field cannot be written
     */
    public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
        setFieldValue(obj, fieldName, (Object) fieldValue);
    }

    /**
     * Sets the named field of {@code obj} to an arbitrary value. Does nothing if the field does not exist.
     *
     * @param obj        target object; must not be {@code null}
     * @param fieldName  name of the field to write
     * @param fieldValue value to assign (unboxed automatically for primitive fields)
     * @throws IllegalArgumentException if the value is not assignable to the field's type
     * @throws IllegalStateException    if the field cannot be written (e.g. a static final field)
     */
    public static void setFieldValue(Object obj, String fieldName, Object fieldValue) {
        Field field = ReflectUtil.getField(obj, fieldName);
        if (field == null) {
            return;
        }
        field.setAccessible(true);
        try {
            field.set(obj, fieldValue);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(
                    "Cannot write field '" + fieldName + "' of " + obj.getClass().getName(), e);
        }
    }
}
package org.exam.page;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
/**
 * MyBatis plugin that paginates statements whose parameter object is a {@link Page}.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection)}. It first runs a COUNT query
 * to fill in the page's total record count, then rewrites the statement SQL with the
 * dialect-specific paging clause.
 */
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class})})
public class PageInterceptor implements Interceptor {
    /** Database dialect from the "dataBaseType" plugin property; "mysql" and "oracle" are supported. */
    private String dataBaseType;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // NOTE(review): assumes the target is the RoutingStatementHandler itself. If another plugin
        // has already wrapped it in a proxy, this cast will throw ClassCastException.
        RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        Object obj = boundSql.getParameterObject();
        if (obj instanceof Page<?>) { // only statements whose parameter is a Page are paginated
            Page<?> page = (Page<?>) obj;
            MappedStatement mappedStatement =
                    (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
            Connection connection = (Connection) invocation.getArgs()[0];
            String sql = boundSql.getSql();
            this.setTotalRecord(page, mappedStatement, boundSql, connection);
            String pageSql = this.getPageSql(page, sql);
            // Replace the SQL that the upcoming prepare() call will use.
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object obj) {
        if (obj instanceof StatementHandler) {
            return Plugin.wrap(obj, this);
        }
        return obj;
    }

    @Override
    public void setProperties(Properties properties) {
        this.dataBaseType = properties.getProperty("dataBaseType");
    }

    /** Returns the SQL wrapped with the configured dialect's paging clause, or unchanged if the dialect is unknown. */
    private String getPageSql(Page<?> page, String sql) {
        if ("mysql".equalsIgnoreCase(this.dataBaseType)) {
            return getMysqlPageSql(page, sql);
        } else if ("oracle".equalsIgnoreCase(this.dataBaseType)) {
            return getOraclePageSql(page, sql);
        }
        return sql;
    }

    /**
     * Computes the zero-based index of the first record to return.
     *
     * <p>A positive current page takes precedence. Otherwise the page's explicit offset is used.
     * The result is never negative.
     */
    private int getStartOffset(Page<?> page) {
        int offset = page.getCurrentPage() > 0
                ? (page.getCurrentPage() - 1) * page.getPageSize()
                : page.getOffset();
        return Math.max(offset, 0);
    }

    /** Appends a MySQL {@code limit offset,size} clause (MySQL offsets are zero-based). */
    private String getMysqlPageSql(Page<?> page, String sql) {
        return new StringBuilder(sql)
                .append(" limit ").append(getStartOffset(page))
                .append(",").append(page.getPageSize())
                .toString();
    }

    /**
     * Wraps the SQL in an Oracle ROWNUM window (ROWNUM is one-based).
     *
     * <p>The upper bound is applied in the middle query so Oracle can stop early. The lower bound
     * is applied on the materialised row number.
     */
    private String getOraclePageSql(Page<?> page, String sql) {
        int firstRow = getStartOffset(page) + 1;
        int endExclusive = firstRow + page.getPageSize();
        StringBuilder pageSql = new StringBuilder();
        pageSql.append(" SELECT * FROM (");
        pageSql.append(" SELECT A.*,ROWNUM RN FROM (");
        pageSql.append(sql).append(" ) A WHERE ROWNUM<").append(endExclusive);
        pageSql.append(") B WHERE B.RN >=").append(firstRow);
        return pageSql.toString();
    }

    /**
     * Runs a COUNT query over the original SQL and stores the result in {@code page}.
     *
     * <p>Parameters are bound from the intercepted statement's own {@link BoundSql}. Its parameter
     * mappings and any additional parameters (such as those MyBatis generates for dynamic SQL) are
     * therefore reused as-is; the COUNT wrapper does not change the placeholders.
     *
     * @param page            the statement's parameter object
     * @param mappedStatement the mapped statement being executed
     * @param boundSql        the intercepted statement's bound SQL (still un-paginated)
     * @param connection      the connection the statement is being prepared on
     * @throws SQLException if the COUNT query fails
     */
    private void setTotalRecord(Page<?> page, MappedStatement mappedStatement, BoundSql boundSql,
            Connection connection) throws SQLException {
        String countSql = this.getCountSql(boundSql.getSql());
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, boundSql);
        PreparedStatement pstmt = connection.prepareStatement(countSql);
        try {
            parameterHandler.setParameters(pstmt);
            ResultSet rs = pstmt.executeQuery();
            try {
                if (rs.next()) {
                    page.setTotalRecord(rs.getInt(1));
                }
            } finally {
                rs.close(); // close the result set before its statement
            }
        } finally {
            pstmt.close();
        }
    }

    /** Wraps the original SQL in a {@code SELECT COUNT(1)} subquery. */
    private String getCountSql(String sql) {
        return "SELECT COUNT(1) FROM(" + sql + ") temp";
    }

}

2.接口声明分页方法

/**
 * MyBatis mapper for {@code User} rows.
 */
public interface UserRepository {
    /**
     * Loads one page of users.
     *
     * <p>The {@link Page} argument is the statement's parameter object, which is what lets the
     * registered PageInterceptor recognise the query, fill in the total count and append the
     * paging clause.
     *
     * @param page requested page (current page and page size, or an offset)
     * @return the users on the requested page
     */
    ArrayList<User> getByPage(Page<User> page);
}

3. 在 Mapper 映射文件（Mapping.xml）中配置查询，参数类型为自定义的 org.exam.page.Page。注意：SQL 中不要自己写分页子句（如 limit），分页语句由拦截器自动追加。

    <select id="getByPage" parameterType="org.exam.page.Page" resultType="org.exam.domain.User">
        select * from user
    </select>

4.调用例子

    /** Requests the second page of ten users and prints it. */
    @Test
    public void testGetByPage() {
        Page<User> userPage = new Page<User>();
        userPage.setPageSize(10);
        userPage.setCurrentPage(2);
        ArrayList<User> result = userRepository.getByPage(userPage);
        System.out.println(result);
    }

本文源码:http://download.csdn.net/detail/xiejx618/8946941