模拟mybatis解析mapper文件中的sql语句为可执行语句
package xuliang.module.common.Utils;import lombok.extern.slf4j.Slf4j;import ognl.DefaultClassResolver;import ognl.MemberAccess;import ognl.Ognl;import ognl.OgnlException;import org.dom4j.Document;impo
·
package xuliang.module.common.Utils;
import lombok.extern.slf4j.Slf4j;
import ognl.DefaultClassResolver;
import ognl.MemberAccess;
import ognl.Ognl;
import ognl.OgnlException;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.xml.sax.InputSource;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Member;
import java.lang.reflect.ReflectPermission;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLDecoder;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
* @Author: xuliang 徐良
* @Date: 2021/4/8 11:40
* @Description: 解析xml文件中的sql语句为可执行语句
* 目前只支持简单的if判断sql拼接
* @Version v1.0
*/
public class SqlParseUtils
{
private static final OgnlMemberAccess MEMBER_ACCESS = new OgnlMemberAccess();
private static final DefaultClassResolver CLASS_RESOLVER = new DefaultClassResolver();
private static ConcurrentHashMap<String, String> sqlMap = new ConcurrentHashMap<>();
private static final Map<String, Object> expressionCache = new ConcurrentHashMap<>();
private static final String STARTTOKEN = "#{";
private static final String ENDTOKEN = "}";
public static final Set<String> includeJars = new HashSet<>();
public static void addScanJarPrefix(String prefix) {
includeJars.add(prefix.trim());
}
static
{
// 配置需扫描的jar包路径,读者自行实现
String needScanMapperJarName = "";
String[] needScanMapperJarNames = needScanMapperJarName.split(";");
for(String jarName : needScanMapperJarNames)
{
addScanJarPrefix(jarName);
}
}
/**
* 扫描mapper文件主入口方法
* 在系统启动时调用
* 调用示例:SqlParseUtils.loadMapperFile(this.getClass().getClassLoader());
* <p>
* idea开发构建maven web项目时,target/classes目录下不生成xml文件
* 解决方法:
* <build>
* <resources>
* <resource>
* <directory>${basedir}/src/main/java</directory>
* <includes>
* <include>**\/*.xml</include>
* </includes>
* </resource>
* </resources>
* </build>
* 原文链接:https://blog.csdn.net/u011213044/article/details/103341929/
* classLoader:HotSwapClassLoader
*
* @param classLoader
*/
public static void loadMapperFile(ClassLoader classLoader)
{
try
{
// 加载classpath
Set<String> classPaths = new HashSet<>();
Set<String> jarPaths = new HashSet<>();
findClassPathsAndJars(classPaths, jarPaths, classLoader);
// 配置需扫描的jar包路径,读者自行实现
String needScanMapperDir = "";
if(ObjectUtils.isEmpty(needScanMapperDir))
{
// 默认值
needScanMapperDir = "*/mapper"; // 在idea开发环境中即存放在resources目录下
}
// todo 暂时作简单的路径匹配处理
List<String> needScanMapperDirList = Arrays.asList(needScanMapperDir.replace("*/","").split(";"));
System.out.println("needScanMapperDirs:{}", needScanMapperDirList);
// 打包成jar包运行时需扫描jar包路径解析mapper文件
for (String jarPath : jarPaths)
{
if (!isIncludeJar(jarPath))
{
continue;
}
System.out.println("scan jarpath : " + jarPath);
scanMapperFileFromJarPath(jarPath, needScanMapperDirList);
}
for (String classPath : classPaths)
{
System.out.println("scan classpath : " + classPath);
scanMapperFileFromClassPath(classPath, needScanMapperDirList);
}
}
catch (Exception ex)
{
log.error("load mapper file fail:" + ex.toString(), ex);
}
}
private static boolean isIncludeJar(String path) {
String jarName = new File(path).getName().toLowerCase();
for (String include : includeJars)
{
if (jarName.startsWith(include))
{
return true;
}
}
return false;
}
/**
* 加载全项目classpath路径
* @param classPaths
* @param classLoader
*/
private static void findClassPathsAndJars(Set<String> classPaths, Set<String> jarPaths, ClassLoader classLoader)
{
try
{
if (classLoader instanceof URLClassLoader)
{
URLClassLoader urlClassLoader = (URLClassLoader) classLoader;
URL[] urLs = urlClassLoader.getURLs();
for (URL url : urLs)
{
String path = url.getPath();
path = URLDecoder.decode(path, "UTF-8");
// path : /d:/xxx
if (path.startsWith("/") && path.indexOf(":") == 2)
{
path = path.substring(1);
}
if (!path.toLowerCase().endsWith(".jar"))
{
classPaths.add(new File(path).getCanonicalPath());
continue;
}
jarPaths.add(path);
}
}
ClassLoader parent = classLoader.getParent();
if (parent != null)
{
findClassPathsAndJars(classPaths,jarPaths, parent);
}
}
catch (Exception ex)
{
ex.printStackTrace();
}
}
/**
* 通过扫描指定的mapperDirectory 加载mapper文件
* @param classPath
* @param mapperDirectory
*/
private static void scanMapperFileFromClassPath(String classPath, List<String> mapperDirectory)
{
File files[] = new File(classPath).listFiles();
if (null == files || files.length == 0) return;
for (File file : files)
{
// 如果该目录是指定的需要扫描的目录,则解析该目录下的mapper文件
if (file.isDirectory() && !mapperDirectory.contains(file.getName()))
{
scanMapperFileFromClassPath(file.getAbsolutePath(), mapperDirectory);
}
// 不允许存在放xml文件的路径名与其他路径名相同的情况,否则扫描不到mapper文件 一般也不会重复
else if (file.isDirectory() && mapperDirectory.contains(file.getName()))
{
parseMapperFile(file);
}
}
}
/**
* 通过扫描指定的mapperDirectory 加载mapper文件
* @param jarPath
* @param mapperDirectory
*/
private static void scanMapperFileFromJarPath(String jarPath, List<String> mapperDirectory)
{
JarFile jarFile = null;
try
{
jarFile = new JarFile(jarPath);
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements())
{
JarEntry jarEntry = entries.nextElement();
String entryName = jarEntry.getName();
if (!jarEntry.isDirectory() && entryName.endsWith(".xml"))
{
String xmlName = entryName.replace("/", ".");
InputStream inputStream = jarFile.getInputStream(jarEntry);
// 截取xml文件的上一级目录名称,判断是否是需要扫描的目录
String mapperName = entryName.substring(0, entryName.lastIndexOf("/"));
mapperName = mapperName.substring(mapperName.lastIndexOf("/") + 1);
if (mapperDirectory.contains(mapperName))
{
System.out.println("begin parse {}", xmlName);
parseMapperInputStream(inputStream);
}
}
}
}
catch (IOException e1)
{
}
finally
{
if (jarFile != null) try
{
jarFile.close();
}
catch (IOException e)
{
}
}
}
/**
* 加载mapper文件中的sql语句
* @param mapperDirectory
*/
private static void parseMapperFile(File mapperDirectory)
{
try
{
File[] xmlFiles = mapperDirectory.listFiles();
if(xmlFiles == null || xmlFiles.length==0)
{
log.warn("mapperDirectory["+ mapperDirectory.getAbsolutePath() +"] is empty!");
return;
}
for(File xmlFile : xmlFiles)
{
if(xmlFile.isDirectory())
{
log.warn("exist directory["+xmlFile.getName()+"] in mapperDirectory["+ mapperDirectory.getAbsolutePath() +"]");
continue;
}
if(!xmlFile.getName().endsWith(".xml"))
{
log.warn("["+xmlFile.getName()+"] not an xml file in mapperDirectory["+ mapperDirectory.getAbsolutePath() +"]");
continue;
}
System.out.println("xml file name:{}", xmlFile.getAbsolutePath());
Document doc = readXmlFile(xmlFile);
Element rootElement = doc.getRootElement();
String packageName = rootElement.attributeValue("namespace");
List<Element> elements = rootElement.elements();
for (Element element : elements)
{
sqlMap.put(packageName + "." + element.getName(), element.getTextTrim());
}
}
}
catch(Exception ex)
{
log.error("parseMapperFile error:" + ex.toString(), ex);
}
}
/**
* 加载mapper文件中的sql语句
* @param inputStream
*/
private static void parseMapperInputStream(InputStream inputStream)
{
try
{
SAXReader xmlReader = new SAXReader(); //User.hbm.xml表示你要解析的xml文档
Document doc = xmlReader.read(inputStream);
Element rootElement = doc.getRootElement();
String packageName = rootElement.attributeValue("namespace");
List<Element> elements = rootElement.elements();
for (Element element : elements)
{
sqlMap.put(packageName + "." + element.getName(), element.getTextTrim());
}
}
catch(Exception ex)
{
log.error("parseMapperInputStream error:" + ex.toString(), ex);
}
}
/**
* 解析sql语句
* 在这里将sql语句及所需要的入参整理好
* mapperName sql语句映射Id
* hashMap 包含sql操作所需要的参数
* parameters sql调用时参数准备
*/
public static String prepareSql(String mapperName, HashMap<String, Object> hashMap, List<Object> parameterList)
{
String originalSql = sqlMap.get(mapperName);
if(ObjectUtils.isEmpty(originalSql))
{
return null;
}
// 分两步
// 1、判断语句提取及解析
originalSql = parseIfNode(originalSql, hashMap);
// System.out.println("After Convert If Sql:{}", originalSql);
// 2、参数替换 替换#{} 为? 将#{}里边的key所对应的值放入 parameters(首先先存到list里)
originalSql = replaceParameter(originalSql, hashMap, parameterList);
// System.out.println("After Convert Parameter Sql:{}", originalSql);
return originalSql;
}
/**
* 参数替换 替换#{} 为? 将#{}里边的key所对应的值放入 parameters
* @param originalSql
* @param parameterList
* @return
*/
private static String replaceParameter(String originalSql, HashMap<String, Object> hashMap, List<Object> parameterList)
{
if(originalSql.contains("#{"))
{
int startIndex = originalSql.indexOf("#{");
int endIndex = originalSql.indexOf("}");
String beforeStr = originalSql.substring(0, startIndex);
// 参数
String containsifStr = originalSql.substring(startIndex + 2, endIndex).trim();
// 准备参数取值
parameterList.add(hashMap.get(containsifStr));
// 该"}"后边的sql串
String afterStr = originalSql.substring(endIndex +1);
originalSql = beforeStr + "?" + afterStr;
originalSql = replaceParameter(originalSql, hashMap, parameterList);
}
if(originalSql.contains("$["))
{
int startIndex = originalSql.indexOf("$[");
int endIndex = originalSql.indexOf("]");
String beforeStr = originalSql.substring(0, startIndex);
// 参数
String containsifStr = originalSql.substring(startIndex + 2, endIndex).trim();
// 准备参数取值
String parameter = (String)hashMap.get(containsifStr);
// 该"]"后边的sql串
String afterStr = originalSql.substring(endIndex +1);
originalSql = beforeStr + parameter + afterStr;
originalSql = replaceParameter(originalSql, hashMap, parameterList);
}
return originalSql;
}
/**
* 处理sql中的if语句
* @param originalSql
* @return
*/
private static String parseIfNode(String originalSql, HashMap<String, Object> hashMap)
{
try
{
if(originalSql.contains("<if"))
{
int startIndex = originalSql.indexOf("<if");
int endIndex = originalSql.indexOf("</if>");
String beforeStr = originalSql.substring(0, startIndex);
// System.out.println("before:" + beforeStr);
String containsifStr = originalSql.substring(startIndex, endIndex +5);
// System.out.println("containsifStr:" + containsifStr);
// 该if节点后边的sql串
String afterStr = originalSql.substring(endIndex +5);
// System.out.println("after:" + afterStr);
// 截取if里边的条件
String conditionStr = containsifStr.substring(containsifStr.indexOf("<"), containsifStr.indexOf(">") +1);
// condition示例:(username != null and pwd == '2345') or pwd=='7788'
conditionStr = conditionStr.substring(conditionStr.indexOf("test=\"")+6 , conditionStr.lastIndexOf("\""));
// 替换转义字符
conditionStr = replaceEscapeCharacter(conditionStr);
Map context = Ognl.createDefaultContext(hashMap, MEMBER_ACCESS, CLASS_RESOLVER, null);
// 截取if包含的sql语句
String sqlStr = containsifStr.substring(containsifStr.indexOf(">")+1, containsifStr.indexOf("</if>"));
if((boolean)Ognl.getValue(parseExpression(conditionStr), context, hashMap))
{
originalSql = beforeStr + sqlStr + afterStr;
}
else
{
originalSql = beforeStr + afterStr;
}
originalSql = parseIfNode(originalSql, hashMap);
}
return originalSql;
}
catch (Exception ex)
{
log.error("parseIfNode error:" + ex.toString(), ex);
}
return null;
}
/**
* 替换转义字符
* @param conditionStr
* @return
*/
private static String replaceEscapeCharacter(String conditionStr)
{
conditionStr = conditionStr.replace("<","<") // 小于
.replace(">",">") // 大于
.replace("&","&") // 与
.replace("'","'") // 单引号
.replace(""","\""); // 双引号
return conditionStr;
}
/**
* 将xml文件流转换为dom对象
*
* @param xmlFile
* @return
* @throws DocumentException
*/
public static Document readXmlFile(File xmlFile) throws Exception
{
// 读取并解析XML文档
// SAXReader就是一个管道,用一个流的方式,把xml文件读出来
SAXReader xmlReader = new SAXReader(); //User.hbm.xml表示你要解析的xml文档
return xmlReader.read(xmlFile);
}
/**
* 解析if表达式并放入缓存
* @param expression
* @return
* @throws OgnlException
*/
private static Object parseExpression(String expression) throws OgnlException
{
Object node = expressionCache.get(expression);
if (node == null)
{
synchronized (expression.intern())
{
node = expressionCache.get(expression);
if (node == null)
{
Object nodeTmp = Ognl.parseExpression(expression);
expressionCache.put(expression, nodeTmp);
node = nodeTmp;
}
}
}
return node;
}
public static void main(String[] args) throws OgnlException
{
HashMap<String, Object> hashMap = new HashMap<>();
hashMap.put("MODELTYPE", "station");
Map context = Ognl.createDefaultContext(hashMap, MEMBER_ACCESS, CLASS_RESOLVER, null);
System.out.println(Ognl.getValue(parseExpression("MODELTYPE == 'station'"), context, hashMap));
}
/**
* 上边创建OgnlContext
* 设置成员的访问权限,需要自己实现
*/
private static class OgnlMemberAccess implements MemberAccess
{
private final boolean canControlMemberAccessible;
public OgnlMemberAccess()
{
this.canControlMemberAccessible = canControlMemberAccessible();
}
/**
* Checks whether can control member accessible.
*
* @return If can control member accessible, it return {@literal true}
* @since 3.5.0
*/
private boolean canControlMemberAccessible()
{
try
{
SecurityManager securityManager = System.getSecurityManager();
if (null != securityManager)
{
securityManager.checkPermission(new ReflectPermission("suppressAccessChecks"));
}
}
catch (SecurityException e)
{
return false;
}
return true;
}
@Override
public Object setup(Map context, Object target, Member member, String propertyName)
{
Object result = null;
if (isAccessible(context, target, member, propertyName))
{
AccessibleObject accessible = (AccessibleObject) member;
if (!accessible.isAccessible())
{
result = Boolean.FALSE;
accessible.setAccessible(true);
}
}
return result;
}
@Override
public void restore(Map context, Object target, Member member, String propertyName, Object state)
{
// Flipping accessible flag is not thread safe. See #1648
}
@Override
public boolean isAccessible(Map context, Object target, Member member, String propertyName)
{
return canControlMemberAccessible;
}
}
}
依赖:
<dependency>
    <groupId>dom4j</groupId>
    <artifactId>dom4j</artifactId>
    <version>1.6.1</version>
</dependency>
<dependency>
    <groupId>ognl</groupId>
    <artifactId>ognl</artifactId>
    <version>3.2.19</version>
</dependency>
另外,代码使用了 lombok 的 @Slf4j 注解,项目中还需引入 lombok 及 slf4j-api 依赖。

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)