来公司一个月了,时间虽然不长,但是我感觉自己还是学到了不少的东西,至少在学校从没有过这样的感觉。并且我也从现在开始写博客了,希望博友能多来浏览,提出批评指正。先来看看我的第一篇博客。
事情起源于修改项目的UI(这个过程就不多说了)。我的指导人让我把项目的生成模版修改一下,适应我们的新样式,可是模版是freemarker写的,我完全都不会,没办法,只能自己去摸索了,费了九牛二虎之力终于修改好了。闲来无事,我就想这玩意好像还挺有用的,要不我也自己弄一个,于是就开始了编写我自己的代码生成器了。
好了,“废话”不多说,直接上代码。
自定义表字段结构和表结构
/**
 * Describes a single column of a database table: its name, SQL type,
 * mapped Java type, comment and whether it belongs to the primary key.
 *
 * <p>The underscore-style property names are kept on purpose: the accessors
 * below expose them as bean properties {@code field_name}, {@code sql_type},
 * {@code java_type}, {@code field_comment} and {@code is_primary}.
 */
public class MyField {
    /** Column name. */
    private String field_name;
    /** SQL type of the column. */
    private String sql_type;
    /** Java type the column is mapped to. */
    private String java_type;
    /** Column comment. */
    private String field_comment;
    /** {@code true} when the column is part of the primary key. */
    private boolean is_primary;

    /** @return the column name */
    public String getField_name() {
        return field_name;
    }

    /** @param field_name the column name */
    public void setField_name(String field_name) {
        this.field_name = field_name;
    }

    /** @return the SQL type of the column */
    public String getSql_type() {
        return sql_type;
    }

    /** @param sql_type the SQL type of the column */
    public void setSql_type(String sql_type) {
        this.sql_type = sql_type;
    }

    /** @return the Java type the column is mapped to */
    public String getJava_type() {
        return java_type;
    }

    /** @param java_type the Java type the column is mapped to */
    public void setJava_type(String java_type) {
        this.java_type = java_type;
    }

    /** @return the column comment */
    public String getField_comment() {
        return field_comment;
    }

    /** @param field_comment the column comment */
    public void setField_comment(String field_comment) {
        this.field_comment = field_comment;
    }

    /** @return {@code true} when the column is part of the primary key */
    public boolean isIs_primary() {
        return is_primary;
    }

    /** @param is_primary whether the column is part of the primary key */
    public void setIs_primary(boolean is_primary) {
        this.is_primary = is_primary;
    }
}
public class MyTable { //表名 private String tableName; //普通字段集合,一个字段包括了:字段名,SQL类型,Java类型,注释,是否主键标志 private List<MyField> common_fields; //主键字段集合 private List<MyField> key_fields; //字段中是否有日期型的,如果有,那么在生成实体类时,就需要导java.util.Date包了 public boolean date_flag = false; //字段中是否有numeric和decimal类型的,如果有,那么,在生成实体类时就需要导java.math.BigDecimal包了 public boolean math_flag = false; //生成实体类,mapper接口等等时的基名 public String package_name_base; }
工具类(里面有一部分是我最开始想纯粹用java代码来生成想要的文件,后来发现这种方法完全是硬编码,所以放弃)
public class DataBaseUtil { private static String jdbc_url ; private static String jdbc_driver ; private static String jdbc_user ; private static String jdbc_password ; private static Integer jdbc_poolsize; private static List<Connection> connectionPool; /** * 静态代码块,用于读取db.properties,初始化和数据库连接相关的信息 */ static{ Properties properties = new Properties(); InputStream input = null; String path = "db.properties"; input = Thread.currentThread().getContextClassLoader().getResourceAsStream(path); try { properties.load(input); jdbc_driver = properties.getProperty("jdbc_driver"); jdbc_url = properties.getProperty("jdbc_url"); jdbc_user = properties.getProperty("jdbc_user"); jdbc_password = properties.getProperty("jdbc_password"); jdbc_poolsize = Integer.parseInt(properties.getProperty("jdbc_poolsize")); connectionPool = new ArrayList<Connection>(jdbc_poolsize); Class.forName(jdbc_driver); int i; for(i = 0 ; i < jdbc_poolsize ; i++){ Connection connection = DriverManager.getConnection(jdbc_url, jdbc_user, jdbc_password); connectionPool.add(connection); } } catch (IOException e) { System.out.println("加载属性文件失败,请检查db.properties文件是否在src下面"); System.exit(-1); } catch (ClassNotFoundException e) { StringBuffer buffer = new StringBuffer(); buffer.append("加载数据库驱动出错,请检查:\r\n"); buffer.append("(1)db.properties中的jdbc_driver是否写错?\r\n"); buffer.append("(2)数据库驱动JAR包是否已导入?"); System.out.println(buffer.toString()); System.exit(-1); } catch (SQLException e) { StringBuffer buffer = new StringBuffer(); buffer.append("Oh,My God! 
数据库连接失败,请检查:\r\n"); buffer.append("(1)数据库服务是否已开启?\r\n"); buffer.append("(2)db.properties中的jdbc_url是否写错?\r\n"); buffer.append("(3)db.properties中的jdbc_user是否写错?\r\n"); buffer.append("(4)db.properties中的jdbc_password是否写错?\r\n"); System.out.println(buffer.toString()); System.exit(-1); } finally{ try { if(input != null) input.close(); } catch (IOException e) { e.printStackTrace(); } } } /** * @return 获取数据库连接,可能返回null */ public static Connection getConnection(){ if(connectionPool != null){ if(connectionPool.size() > 0){ return connectionPool.remove(connectionPool.size() - 1); }else{ Connection connection = null; try { Class.forName(jdbc_driver); connection = DriverManager.getConnection(jdbc_url, jdbc_user, jdbc_password); return connection; } catch (ClassNotFoundException e) { e.printStackTrace(); return connection; } catch (SQLException e) { e.printStackTrace(); return connection; } } }else return null; } /** * @return 返回某个数据库下所有的表结构 * @throws Exception */ public static List<MyTable> getAllTables(){ List<MyTable> tables = null; Connection connection = getConnection(); if(connection != null){ PreparedStatement statement = null; ResultSet connectionRS = null , queryRS = null; tables = new ArrayList<MyTable>(); try { DatabaseMetaData metaData = connection.getMetaData(); connectionRS = metaData.getTables(null, null, null, new String[]{"TABLE"}); while(connectionRS.next()){ String tableName = connectionRS.getString(3); String querySQL = "show full columns from " + tableName; statement = connection.prepareStatement(querySQL); queryRS = statement.executeQuery(); MyTable myTable = new MyTable(); List<MyField> key_fields = new ArrayList<MyField>(); List<MyField> common_fields = new ArrayList<MyField>(); //统计主键和普通字段 while(queryRS.next()){ String field_name = queryRS.getString("Field"); //注意:得到的Type可能是varchar(n),也可能是datetime String sql_type = queryRS.getString("Type"); String key = queryRS.getString("Key"); String commont = queryRS.getString("Commont"); String java_type = 
getJavaTypeFromSQLType(sql_type); //设置字段相关信息 MyField field = new MyField(); field.setField_name(field_name); field.setSql_type(sql_type); field.setJava_type(java_type); field.setField_comment(commont); if(key != null && key.equalsIgnoreCase("pri")){ field.setIs_primary(true); key_fields.add(field); }else{ field.setIs_primary(false); common_fields.add(field); } //是否需要导java.math.BigDecimal包 if(java_type.equalsIgnoreCase("bigdecimal")) myTable.setMath_flag(true); if(java_type.equalsIgnoreCase("date")||java_type.equalsIgnoreCase("time")||java_type.equalsIgnoreCase("timestamp")) myTable.setDate_flag(true); } //设置表相关的信息 myTable.setTableName(tableName); myTable.setKey_fields(key_fields); myTable.setCommon_fields(common_fields); tables.add(myTable); } return tables; } catch (SQLException e) { e.printStackTrace(); return tables; } finally{ try { if(queryRS != null) queryRS.close(); if(connectionRS != null) connectionRS.close(); if(statement != null) statement.close(); if(connection != null) connection.close(); } catch (SQLException e) { e.printStackTrace(); } } }else{ return tables; } } /** * * @param tableName * @return 返回表名称为 tableName 的表的结构 * @throws Exception */ public static MyTable getTable(String tableName){ Connection connection = getConnection(); String sql = "show full columns from " + tableName; PreparedStatement preparedStatement = null; ResultSet rs = null; MyTable myTable = null; if(connection != null){ try { preparedStatement = connection.prepareStatement(sql); rs = preparedStatement.executeQuery(); myTable = new MyTable(); List<MyField> commonFieldList = new ArrayList<MyField>(); List<MyField> keyFieldList = new ArrayList<MyField>(); //获取表的结构 while(rs.next()){ String field_name = rs.getString("Field"); //注意:得到的Type可能是varchar(n),也可能是datetime String sql_type = rs.getString("Type"); String key = rs.getString("Key"); String commont = rs.getString("Comment"); String java_type = getJavaTypeFromSQLType(sql_type); //设置字段相关信息 MyField field = new MyField(); 
field.setField_name(field_name); field.setSql_type(sql_type); field.setJava_type(java_type); field.setField_comment(commont); if(key != null && key.equalsIgnoreCase("pri")){ field.setIs_primary(true); keyFieldList.add(field); }else{ field.setIs_primary(false); commonFieldList.add(field); } //是否需要导java.math.BigDecimal包 if(java_type.equalsIgnoreCase("bigdecimal")) myTable.setMath_flag(true); if(java_type.equalsIgnoreCase("date")||java_type.equalsIgnoreCase("time")||java_type.equalsIgnoreCase("timestamp")) myTable.setDate_flag(true); } myTable.setTableName(tableName); myTable.setCommon_fields(commonFieldList); myTable.setKey_fields(keyFieldList); return myTable; } catch (SQLException e) { e.printStackTrace(); return myTable; } finally{ try { if(rs != null) rs.close(); if(preparedStatement != null) preparedStatement.close(); if(connection != null) connection.close(); } catch (SQLException e) { e.printStackTrace(); } } }else{ return myTable; } } /** * * @param sqlType * @return SQL类型得到java类型 */ private static String getJavaTypeFromSQLType(String sqlType){ String javaType = null; int index = sqlType.indexOf("("); if(index != -1) sqlType = sqlType.substring(0, index); if(sqlType.equalsIgnoreCase("VARCHAR")||sqlType.equalsIgnoreCase("CHAR")||sqlType.contains("TEXT")) javaType = "String"; else if(sqlType.equalsIgnoreCase("NUMERIC")||sqlType.equalsIgnoreCase("DECIMAL")) javaType = "BigDecimal"; else if(sqlType.equalsIgnoreCase("BIT")) javaType = "boolean"; else if(sqlType.equalsIgnoreCase("TINYINT")) javaType = "byte"; else if(sqlType.equalsIgnoreCase("SAMLLINT")) javaType = "short"; else if(sqlType.equalsIgnoreCase("INTEGER")||sqlType.equalsIgnoreCase("int")||sqlType.equalsIgnoreCase("mediumint")) javaType = "int"; else if(sqlType.equalsIgnoreCase("BIGINT")) javaType = "long"; else if(sqlType.equalsIgnoreCase("REAL")) javaType = "float"; else if(sqlType.equalsIgnoreCase("FLOAT")||sqlType.equalsIgnoreCase("double")) javaType = "double"; else 
if(sqlType.equalsIgnoreCase("binary")||sqlType.equalsIgnoreCase("varbinary")||sqlType.equalsIgnoreCase("longvarbinary")) javaType = "byte[]"; else if(sqlType.equalsIgnoreCase("date")) javaType = "Date"; else if(sqlType.equalsIgnoreCase("time")) javaType = "Time"; else if(sqlType.equalsIgnoreCase("datetime")||sqlType.equalsIgnoreCase("timestamp")) javaType = "Timestamp"; return javaType; } /** * * @param str * @return 把一个字符串首字母大写 */ public static String capitalFirstChar(String str){ if(str == null || str.trim().equals("")) return str; else{ char[] charArray = str.toCharArray(); if(charArray[0] >= 'a' && charArray[0] <= 'z'){ charArray[0] = (char) (charArray[0] - 32); return String.valueOf(charArray); } else return str; } } /** * * @param packageName * @param tableName * @return 生成实体类,返回一个包含实体类内容的StringBuffer,只需要把这个StringBuffer写到一个.java文件即可 */ public static StringBuffer generateEntity(String packageName, String tableName){ MyTable myTable = getTable(tableName); StringBuffer buffer = new StringBuffer(); buffer.append("package ").append(packageName).append(";\r\n"); if(myTable.math_flag) buffer.append("import java.math.*;\r\n"); if(myTable.date_flag) buffer.append("import java.util.Date;\r\n"); buffer.append("\r\n"); buffer.append("public class ").append(capitalFirstChar(tableName)).append(" {\r\n"); StringBuffer contentBuffer = generateSetAndGet(myTable); buffer.append(contentBuffer); buffer.append("\r\n}"); return buffer; } /** * * @param myTable * @return 生成实体类的主体内容,包括属性声明,以及get和set方法 */ private static StringBuffer generateSetAndGet(MyTable myTable){ StringBuffer buffer = new StringBuffer(); //思想:先生成属性声明,再生成相应的get和set方法 //1.生成字段属性声明 myTable.getCommon_fields().addAll(myTable.getCommon_fields()); for(MyField field : myTable.getCommon_fields()){ String field_name = field.getField_name(); String field_type = field.getJava_type(); buffer.append("\tprivate " + field_type + " " + field_name + ";\r\n"); } /*//2.生成普通字段属性声明 for(MyField field : myTable.getCommon_fields()){ 
String field_name = field.getField_name(); String field_type = field.getJava_type(); buffer.append("\tprivate " + field_type + " " + field_name + ";\r\n"); } //3.生成主键字段属性的getter和setter for(MyField field : myTable.getKey_fields()){ String field_name = field.getField_name(); String field_type = field.getField_type(); //生成set方法 buffer.append("\tpublic void set" + capitalFirstChar(field_name) + "(" + field_type + " " + field_name + "){\r\n"); buffer.append("\t\tthis.").append(field_name).append(" = ").append(field_name).append(";\r\n"); buffer.append("\t}").append("\r\n"); //生成get方法 buffer.append("\tpublic ").append(field_type).append(" get").append(capitalFirstChar(field_name)).append("(){\r\n"); buffer.append("\t\treturn this.").append(field_name).append(";\r\n"); buffer.append("\t}").append("\r\n"); }*/ //4.生成普通字段属性的getter和setter for(MyField field : myTable.getCommon_fields()){ String field_name = field.getField_name(); String field_type = field.getJava_type(); //生成set方法 buffer.append("\tpublic void set" + capitalFirstChar(field_name) + "(" + field_type + " " + field_name + "){\r\n"); buffer.append("\t\tthis.").append(field_name).append(" = ").append(field_name).append(";\r\n"); buffer.append("\t}").append("\r\n"); //生成get方法 buffer.append("\tpublic ").append(field_type).append(" get").append(capitalFirstChar(field_name)).append("(){\r\n"); buffer.append("\t\treturn this.").append(field_name).append(";\r\n"); buffer.append("\t}").append("\r\n"); } return buffer; } /** * * @param myTable * @return 主键作为参数时,返回的参数列表:如果主键有多个,那么参数列表就是一个实体类,否则,就是主键所对应的属性字段为参数 */ private static StringBuffer generatePrimaryKeyParam(MyTable myTable){ StringBuffer buffer = new StringBuffer(); List<MyField> key_fields = myTable.getKey_fields(); if(key_fields.size() >= 2){ String tableName = myTable.getTableName(); String entityName = capitalFirstChar(tableName); buffer.append(entityName).append(" ").append(tableName); }else{ MyField key_field = key_fields.get(0); String field_type = 
key_field.getJava_type(); String field_name = key_field.getField_name(); buffer.append(field_type).append(" ").append(field_name); } return buffer; } /** * * @param packageName mapper接口的包名 * @param tableName 生成哪张表的mapper接口 * @return 生成mapper接口,包含下列方法: * (1)插入一条记录 * (2)查询记录的总数 * (3)根据主键来删除一条记录、根据主键来更新一条记录、根据主键来查询记录 */ public static StringBuffer generateMapperInterface(String packageName, String tableName){ MyTable myTable = getTable(tableName); StringBuffer buffer = new StringBuffer(); buffer.append("package ").append(packageName).append(";\r\n"); //import 实体类 buffer.append("import 实体类;").append("\r\n\r\n"); buffer.append("public interface ").append(capitalFirstChar(tableName)).append("Mapper{\r\n"); buffer.append("\r\n"); //插入一条记录 buffer.append("\t"); buffer.append("public void insertRecord(").append(capitalFirstChar(tableName)).append(" ").append(tableName).append(");"); buffer.append("\r\n"); //查询记录的总数 buffer.append("\t"); buffer.append("public int queryRecordCount(").append(")"); buffer.append("\r\n"); //先判断是否存在主键,如果不存在主键,则不生成和主键相关的方法 List<MyField> key_fields = myTable.getKey_fields(); if(key_fields != null){ //参数列表 StringBuffer params = generatePrimaryKeyParam(myTable); //根据主键来删除一条记录 buffer.append("\t"); buffer.append("public void deleteByPrimaryKey("); buffer.append(params); buffer.append(");"); buffer.append("\r\n"); //根据主键来更新记录 buffer.append("\t"); buffer.append("public void updateByPrimaryKey("); buffer.append(params); buffer.append(");"); buffer.append("\r\n"); //根据主键来查询 buffer.append("\t"); buffer.append("public void selectByPrimaryKey("); buffer.append(params); buffer.append(");"); buffer.append("\r\n"); } buffer.append("\r\n}"); return buffer; } }
生成器类代码
/**
 * Template-driven code generator: renders every FreeMarker template found in
 * {@code <project>/config/template} for one database table and writes the
 * result below {@code <project>/src} or {@code <project>/config}.
 */
public class MyGenerator {

    /**
     * Generates all artefacts for one table. The project root is taken from
     * the {@code user.dir} system property. Templates are recognised by the
     * part of their file name before the first dot ({@code po}, {@code dao},
     * {@code mapper}, {@code service}, {@code serviceImpl},
     * {@code controller}, {@code mybatis}, {@code spring}); any other file in
     * the template directory is skipped.
     *
     * <p>A template that fails to render has its stack trace printed and the
     * remaining templates are still processed.
     *
     * @param basePackageName base package of the generated sources, e.g. {@code com.demo}
     * @param tableName       table whose structure is passed to the templates as {@code table}
     * @throws IOException when the template directory cannot be read, the
     *                     table structure cannot be loaded, or a file cannot be written
     */
    public static void generate(String basePackageName, String tableName) throws IOException {
        String projectPath = System.getProperty("user.dir");

        // Directory of the base package below src. String.split takes a
        // regular expression, so the dot has to be escaped.
        File packageDir = new File(projectPath + File.separator + "src");
        for (String part : basePackageName.split("\\.")) {
            packageDir = new File(packageDir, part);
        }
        if (!packageDir.exists()) {
            packageDir.mkdirs();
        }

        String templatePath = projectPath + File.separator + "config" + File.separator + "template";
        File templateDir = new File(templatePath);
        Configuration configuration = new Configuration();
        configuration.setDefaultEncoding("UTF-8");
        configuration.setDirectoryForTemplateLoading(templateDir);
        configuration.setObjectWrapper(new DefaultObjectWrapper());

        File[] templateFiles = templateDir.listFiles();
        if (templateFiles == null) {
            // listFiles returns null when the directory cannot be listed.
            throw new IOException("Cannot list template directory: " + templateDir.getAbsolutePath());
        }

        // The table structure is the same for every template: load it once.
        MyTable table = DataBaseUtil.getTable(tableName);
        if (table == null) {
            throw new IOException("Cannot read the structure of table: " + tableName);
        }
        table.setPackage_name_base(basePackageName);
        Map<String, Object> paramMap = new HashMap<String, Object>();
        paramMap.put("table", table);

        for (File templateFile : templateFiles) {
            // Template files are named xxx.ftl; the part before the dot
            // selects the kind of artefact.
            String fileName = templateFile.getName();
            int dot = fileName.indexOf(".");
            if (!templateFile.isFile() || dot <= 0) {
                continue;
            }
            File targetFile = resolveTarget(fileName.substring(0, dot), packageDir, projectPath, tableName);
            if (targetFile == null) {
                continue;
            }
            File targetDir = targetFile.getParentFile();
            if (!targetDir.exists()) {
                targetDir.mkdirs();
            }

            // Load the template before opening the output file so a missing
            // or broken template does not leave an open stream behind.
            Template template = configuration.getTemplate(fileName);
            // Written as UTF-8 to match the encoding the templates are read with.
            BufferedWriter writer = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(targetFile), "UTF-8"));
            try {
                template.process(paramMap, writer);
            } catch (TemplateException e) {
                e.printStackTrace();
            } finally {
                writer.close();
            }
        }
    }

    /**
     * Maps a template name to the file that template generates.
     *
     * @return the target file, or {@code null} for an unknown template name
     */
    private static File resolveTarget(String templateName, File packageDir, String projectPath, String tableName) {
        String entityName = DataBaseUtil.capitalFirstChar(tableName);
        File configDir = new File(projectPath, "config");
        if (templateName.equalsIgnoreCase("po")) {
            // entity class
            return new File(new File(packageDir, "po"), entityName + ".java");
        }
        if (templateName.equalsIgnoreCase("dao")) {
            // mapper interface
            return new File(new File(packageDir, "mapper"), entityName + "Mapper.java");
        }
        if (templateName.equalsIgnoreCase("mapper")) {
            // mapper XML
            return new File(new File(packageDir, "mapper"), entityName + "Mapper.xml");
        }
        if (templateName.equalsIgnoreCase("service")) {
            // service interface
            return new File(new File(packageDir, "service"), entityName + "Service.java");
        }
        if (templateName.equalsIgnoreCase("serviceImpl")) {
            // service implementation
            return new File(new File(packageDir, "serviceImpl"), entityName + "ServiceImpl.java");
        }
        if (templateName.equalsIgnoreCase("controller")) {
            return new File(new File(packageDir, "controller"), entityName + "Controller.java");
        }
        if (templateName.equalsIgnoreCase("mybatis")) {
            // MyBatis core configuration
            return new File(new File(configDir, "mybatis"), "sqlMapConfig.xml");
        }
        if (templateName.equalsIgnoreCase("spring")) {
            // Spring core configuration
            return new File(new File(configDir, "spring"), "spring.xml");
        }
        return null;
    }
}
我把我写的模版文件放到config下面的template目录下(代码中也可以看出来),这里只贴出部分模版文件。
<#-- Generates the entity (PO) class for one table -->
package ${table.package_name_base}.po;
<#if table.date_flag>
import java.util.Date;
<#-- time and datetime/timestamp columns are mapped to these java.sql types -->
import java.sql.Time;
import java.sql.Timestamp;
</#if>
<#if table.math_flag>
import java.math.BigDecimal;
</#if>
public class ${table.tableName?cap_first} {
<#-- field declarations -->
<#list table.common_fields + table.key_fields as field>
    private ${field.java_type} ${field.field_name};
</#list>
<#-- getters -->
<#list table.common_fields + table.key_fields as field>
    public ${field.java_type} get${field.field_name?cap_first}(){
        return this.${field.field_name};
    }
</#list>
<#-- setters -->
<#list table.common_fields + table.key_fields as field>
    public void set${field.field_name?cap_first}(${field.java_type} ${field.field_name}){
        this.${field.field_name} = ${field.field_name};
    }
</#list>
}
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<#-- Generates the mapper.xml mapping file; the MyBatis version used here is 3.2.2 -->
<#assign all_fields = table.common_fields + table.key_fields>
<#assign entity_type = table.package_name_base + ".po." + table.tableName?cap_first>
<mapper namespace="${table.package_name_base}.mapper.${table.tableName?cap_first}Mapper">

    <#-- Counts the records of the table -->
    <select id="queryRecordCount" resultType="java.lang.Integer">
        select count(*) from ${table.tableName}
    </select>

    <#-- Inserts one record; insert/update statements take no resultType attribute -->
    <insert id="insertRecord" parameterType="${entity_type}">
        insert into ${table.tableName}
        (<#list all_fields as field>${field.field_name}<#if field_has_next>, </#if></#list>)
        values
        (<#list all_fields as field>${"#"}{${field.field_name}}<#if field_has_next>, </#if></#list>)
    </insert>

    <#-- Updates by primary key. The entity is the parameter for single and composite
         keys alike, because the statement binds every ordinary column as well as the key. -->
    <#if table.key_fields?size != 0 && table.common_fields?size != 0>
    <update id="updateByPrimary" parameterType="${entity_type}">
        update ${table.tableName}
        set <#list table.common_fields as field>${field.field_name} = ${"#"}{${field.field_name}}<#if field_has_next>, </#if></#list>
        where <#list table.key_fields as field>${field.field_name} = ${"#"}{${field.field_name}}<#if field_has_next> and </#if></#list>
    </update>
    </#if>
</mapper>
执行代码生成器之前的项目结构
执行代码生成器,会看到效果(有错误是因为我没有导入相关的jar包)
代码写得有点简单,没有什么技术点,但是能够让我在以后的学习和工作中,省去无休止地手写po类和增删改查代码这些让人头疼的事。
如果过程中有什么错误或是不足,请指出,谢谢。