package xuliang.module.common.Utils;


import lombok.extern.slf4j.Slf4j;
import ognl.DefaultClassResolver;
import ognl.MemberAccess;
import ognl.Ognl;
import ognl.OgnlException;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Member;
import java.lang.reflect.ReflectPermission;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLDecoder;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * @Author: xuliang 徐良
 * @Date: 2021/4/8 11:40
 * @Description: 解析xml文件中的sql语句为可执行语句
 * 目前只支持简单的if判断sql拼接
 * @Version v1.0
 */
public class SqlParseUtils
{
    private static final OgnlMemberAccess MEMBER_ACCESS = new OgnlMemberAccess();
    private static final DefaultClassResolver CLASS_RESOLVER = new DefaultClassResolver();
    private static ConcurrentHashMap<String, String> sqlMap = new ConcurrentHashMap<>();
    private static final Map<String, Object> expressionCache = new ConcurrentHashMap<>();
    private static final String STARTTOKEN = "#{";
    private static final String ENDTOKEN = "}";

    public static final Set<String> includeJars = new HashSet<>();

    public static void addScanJarPrefix(String prefix) {
        includeJars.add(prefix.trim());
    }

    static
    {
		// 配置需扫描的jar包路径,读者自行实现
        String needScanMapperJarName = "";
        
            String[] needScanMapperJarNames = needScanMapperJarName.split(";");
            for(String jarName : needScanMapperJarNames)
            {
                addScanJarPrefix(jarName);
            }
        
    }
    /**
     * 扫描mapper文件主入口方法
     * 在系统启动时调用
     * 调用示例:SqlParseUtils.loadMapperFile(this.getClass().getClassLoader());
     * <p>
     * idea开发构建maven web项目时,target/classes目录下不生成xml文件
     * 解决方法:
     * <build>
     *  <resources>
     *      <resource>
     *          <directory>${basedir}/src/main/java</directory>
     *          <includes>
     *              <include>**\/*.xml</include>
     *          </includes>
     *      </resource>
     *  </resources>
     * </build>
     * 原文链接:https://blog.csdn.net/u011213044/article/details/103341929/
     * classLoader:HotSwapClassLoader
     *
     * @param classLoader
     */
    public static void loadMapperFile(ClassLoader classLoader)
    {
        try
        {
            // 加载classpath
            Set<String> classPaths = new HashSet<>();
            Set<String> jarPaths = new HashSet<>();
            findClassPathsAndJars(classPaths, jarPaths, classLoader);
			// 配置需扫描的jar包路径,读者自行实现
            String needScanMapperDir = "";
            if(ObjectUtils.isEmpty(needScanMapperDir))
            {
                // 默认值
                needScanMapperDir = "*/mapper"; // 在idea开发环境中即存放在resources目录下
            }
            // todo 暂时作简单的路径匹配处理
            List<String> needScanMapperDirList = Arrays.asList(needScanMapperDir.replace("*/","").split(";"));
            System.out.println("needScanMapperDirs:{}", needScanMapperDirList);
            // 打包成jar包运行时需扫描jar包路径解析mapper文件
            for (String jarPath : jarPaths)
            {
                if (!isIncludeJar(jarPath))
                {
                    continue;
                }
                
                    System.out.println("scan jarpath : " + jarPath);
                
                scanMapperFileFromJarPath(jarPath, needScanMapperDirList);
            }
            for (String classPath : classPaths)
            {
               
                    System.out.println("scan classpath : " + classPath);
                
                scanMapperFileFromClassPath(classPath, needScanMapperDirList);
            }
        }
        catch (Exception ex)
        {
            log.error("load mapper file fail:" + ex.toString(), ex);
        }
    }

    private static boolean isIncludeJar(String path) {

        String jarName = new File(path).getName().toLowerCase();

        for (String include : includeJars)
        {
            if (jarName.startsWith(include))
            {
                return true;
            }
        }
        return false;
    }

    /**
     * 加载全项目classpath路径
     * @param classPaths
     * @param classLoader
     */
    private static void findClassPathsAndJars(Set<String> classPaths, Set<String> jarPaths, ClassLoader classLoader)
    {
        try
        {
            if (classLoader instanceof URLClassLoader)
            {
                URLClassLoader urlClassLoader = (URLClassLoader) classLoader;
                URL[] urLs = urlClassLoader.getURLs();
                for (URL url : urLs)
                {
                    String path = url.getPath();
                    path = URLDecoder.decode(path, "UTF-8");
                    // path : /d:/xxx
                    if (path.startsWith("/") && path.indexOf(":") == 2)
                    {
                        path = path.substring(1);
                    }
                    if (!path.toLowerCase().endsWith(".jar"))
                    {
                        classPaths.add(new File(path).getCanonicalPath());
                        continue;
                    }
                    jarPaths.add(path);
                }
            }
            ClassLoader parent = classLoader.getParent();
            if (parent != null)
            {
                findClassPathsAndJars(classPaths,jarPaths, parent);
            }
        }
        catch (Exception ex)
        {
            ex.printStackTrace();
        }
    }

    /**
     * 通过扫描指定的mapperDirectory 加载mapper文件
     * @param classPath
     * @param mapperDirectory
     */
    private static void scanMapperFileFromClassPath(String classPath, List<String> mapperDirectory)
    {
        File files[] = new File(classPath).listFiles();
        if (null == files || files.length == 0) return;
        for (File file : files)
        {
            // 如果该目录是指定的需要扫描的目录,则解析该目录下的mapper文件
            if (file.isDirectory() && !mapperDirectory.contains(file.getName()))
            {
                scanMapperFileFromClassPath(file.getAbsolutePath(), mapperDirectory);
            }
            // 不允许存在放xml文件的路径名与其他路径名相同的情况,否则扫描不到mapper文件 一般也不会重复
            else if (file.isDirectory() && mapperDirectory.contains(file.getName()))
            {
                parseMapperFile(file);
            }
        }
    }
    /**
     * 通过扫描指定的mapperDirectory 加载mapper文件
     * @param jarPath
     * @param mapperDirectory
     */
    private static void scanMapperFileFromJarPath(String jarPath, List<String> mapperDirectory)
    {
        JarFile jarFile = null;
        try
        {
            jarFile = new JarFile(jarPath);
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements())
            {
                JarEntry jarEntry = entries.nextElement();
                String entryName = jarEntry.getName();

                if (!jarEntry.isDirectory() && entryName.endsWith(".xml"))
                {
                    String xmlName = entryName.replace("/", ".");
                    InputStream inputStream = jarFile.getInputStream(jarEntry);

                    // 截取xml文件的上一级目录名称,判断是否是需要扫描的目录
                    String mapperName = entryName.substring(0, entryName.lastIndexOf("/"));
                    mapperName = mapperName.substring(mapperName.lastIndexOf("/") + 1);
                    if (mapperDirectory.contains(mapperName))
                    {
                        System.out.println("begin parse {}", xmlName);
                        parseMapperInputStream(inputStream);
                    }
                }
            }
        }
        catch (IOException e1)
        {
        }
        finally
        {
            if (jarFile != null) try
            {
                jarFile.close();
            }
            catch (IOException e)
            {
            }
        }
    }

    /**
     * 加载mapper文件中的sql语句
     * @param mapperDirectory
     */
    private static void parseMapperFile(File mapperDirectory)
    {
        try
        {
            File[] xmlFiles = mapperDirectory.listFiles();
            if(xmlFiles == null || xmlFiles.length==0)
            {
                log.warn("mapperDirectory["+ mapperDirectory.getAbsolutePath() +"] is empty!");
                return;
            }
            for(File xmlFile : xmlFiles)
            {
                if(xmlFile.isDirectory())
                {
                    log.warn("exist directory["+xmlFile.getName()+"] in mapperDirectory["+ mapperDirectory.getAbsolutePath() +"]");
                    continue;
                }
                if(!xmlFile.getName().endsWith(".xml"))
                {
                    log.warn("["+xmlFile.getName()+"] not an xml file in mapperDirectory["+ mapperDirectory.getAbsolutePath() +"]");
                    continue;
                }
                System.out.println("xml file name:{}", xmlFile.getAbsolutePath());
                Document doc = readXmlFile(xmlFile);
                Element rootElement = doc.getRootElement();
                String packageName = rootElement.attributeValue("namespace");
                List<Element> elements = rootElement.elements();
                for (Element element : elements)
                {
                    sqlMap.put(packageName + "." + element.getName(), element.getTextTrim());
                }
            }
        }
        catch(Exception ex)
        {
            log.error("parseMapperFile error:" + ex.toString(), ex);
        }
    }
    /**
     * 加载mapper文件中的sql语句
     * @param inputStream
     */
    private static void parseMapperInputStream(InputStream inputStream)
    {
        try
        {
            SAXReader xmlReader = new SAXReader(); //User.hbm.xml表示你要解析的xml文档
            Document doc = xmlReader.read(inputStream);
            Element rootElement = doc.getRootElement();
            String packageName = rootElement.attributeValue("namespace");
            List<Element> elements = rootElement.elements();
            for (Element element : elements)
            {
                sqlMap.put(packageName + "." + element.getName(), element.getTextTrim());
            }
        }
        catch(Exception ex)
        {
            log.error("parseMapperInputStream error:" + ex.toString(), ex);
        }
    }

    /**
     * 解析sql语句
     * 在这里将sql语句及所需要的入参整理好
     * mapperName sql语句映射Id
     * hashMap 包含sql操作所需要的参数
     * parameters  sql调用时参数准备
     */
    public static String prepareSql(String mapperName, HashMap<String, Object> hashMap, List<Object> parameterList)
    {
        String originalSql = sqlMap.get(mapperName);
        if(ObjectUtils.isEmpty(originalSql))
        {
            return null;
        }
        // 分两步
        // 1、判断语句提取及解析
        originalSql = parseIfNode(originalSql, hashMap);
//        System.out.println("After Convert If Sql:{}", originalSql);
        // 2、参数替换  替换#{} 为?  将#{}里边的key所对应的值放入 parameters(首先先存到list里)
        originalSql = replaceParameter(originalSql, hashMap, parameterList);
//        System.out.println("After Convert Parameter Sql:{}", originalSql);
        return originalSql;
    }

    /**
     * 参数替换  替换#{} 为?  将#{}里边的key所对应的值放入 parameters
     * @param originalSql
     * @param parameterList
     * @return
     */
    private static String replaceParameter(String originalSql, HashMap<String, Object> hashMap, List<Object> parameterList)
    {
        if(originalSql.contains("#{"))
        {
            int startIndex = originalSql.indexOf("#{");
            int endIndex = originalSql.indexOf("}");
            String beforeStr = originalSql.substring(0, startIndex);
            // 参数
            String containsifStr = originalSql.substring(startIndex + 2, endIndex).trim();
            // 准备参数取值
            parameterList.add(hashMap.get(containsifStr));
            // 该"}"后边的sql串
            String afterStr = originalSql.substring(endIndex +1);
            originalSql = beforeStr + "?" + afterStr;
            originalSql = replaceParameter(originalSql, hashMap, parameterList);
        }
        if(originalSql.contains("$["))
        {
            int startIndex = originalSql.indexOf("$[");
            int endIndex = originalSql.indexOf("]");
            String beforeStr = originalSql.substring(0, startIndex);
            // 参数
            String containsifStr = originalSql.substring(startIndex + 2, endIndex).trim();
            // 准备参数取值
            String parameter = (String)hashMap.get(containsifStr);
            // 该"]"后边的sql串
            String afterStr = originalSql.substring(endIndex +1);
            originalSql = beforeStr + parameter + afterStr;
            originalSql = replaceParameter(originalSql, hashMap, parameterList);
        }
        return originalSql;
    }

    /**
     * 处理sql中的if语句
     * @param originalSql
     * @return
     */
    private static String parseIfNode(String originalSql, HashMap<String, Object> hashMap)
    {
        try
        {
            if(originalSql.contains("<if"))
            {
                int startIndex = originalSql.indexOf("<if");
                int endIndex = originalSql.indexOf("</if>");
                String beforeStr = originalSql.substring(0, startIndex);
//            System.out.println("before:" + beforeStr);
                String containsifStr = originalSql.substring(startIndex, endIndex +5);
//            System.out.println("containsifStr:" + containsifStr);
                // 该if节点后边的sql串
                String afterStr = originalSql.substring(endIndex +5);
//            System.out.println("after:" + afterStr);
                // 截取if里边的条件
                String conditionStr = containsifStr.substring(containsifStr.indexOf("<"), containsifStr.indexOf(">") +1);
                // condition示例:(username != null and pwd == '2345') or pwd=='7788'
                conditionStr = conditionStr.substring(conditionStr.indexOf("test=\"")+6 , conditionStr.lastIndexOf("\""));
                // 替换转义字符
                conditionStr = replaceEscapeCharacter(conditionStr);
                Map context = Ognl.createDefaultContext(hashMap, MEMBER_ACCESS, CLASS_RESOLVER, null);
                // 截取if包含的sql语句
                String sqlStr = containsifStr.substring(containsifStr.indexOf(">")+1, containsifStr.indexOf("</if>"));

                if((boolean)Ognl.getValue(parseExpression(conditionStr), context, hashMap))
                {
                    originalSql = beforeStr + sqlStr + afterStr;
                }
                else
                {
                    originalSql = beforeStr + afterStr;
                }
                originalSql = parseIfNode(originalSql, hashMap);
            }
            return originalSql;
        }
        catch (Exception ex)
        {
            log.error("parseIfNode error:" + ex.toString(), ex);
        }
        return null;
    }

    /**
     * 替换转义字符
     * @param conditionStr
     * @return
     */
    private static String replaceEscapeCharacter(String conditionStr)
    {
        conditionStr = conditionStr.replace("&lt;","<") // 小于
                .replace("&gt;",">") // 大于
                .replace("&amp;","&") // 与
                .replace("&apos;","'") // 单引号
                .replace("&quot;","\""); // 双引号
        return conditionStr;
    }
	
	/**
     * 将xml文件流转换为dom对象
     *
     * @param xmlFile
     * @return
     * @throws DocumentException
     */
    public static Document readXmlFile(File xmlFile) throws Exception
    {
        // 读取并解析XML文档
        // SAXReader就是一个管道,用一个流的方式,把xml文件读出来
        SAXReader xmlReader = new SAXReader(); //User.hbm.xml表示你要解析的xml文档
        return xmlReader.read(xmlFile);
    }

    /**
     * 解析if表达式并放入缓存
     * @param expression
     * @return
     * @throws OgnlException
     */
    private static Object parseExpression(String expression) throws OgnlException
    {
        Object node = expressionCache.get(expression);
        if (node == null)
        {
            synchronized (expression.intern())
            {
                node = expressionCache.get(expression);
                if (node == null)
                {
                    Object nodeTmp = Ognl.parseExpression(expression);
                    expressionCache.put(expression, nodeTmp);
                    node = nodeTmp;
                }
            }
        }
        return node;
    }

    public static void main(String[] args) throws OgnlException
    {
        HashMap<String, Object> hashMap = new HashMap<>();
        hashMap.put("MODELTYPE", "station");

        Map context = Ognl.createDefaultContext(hashMap, MEMBER_ACCESS, CLASS_RESOLVER, null);
        System.out.println(Ognl.getValue(parseExpression("MODELTYPE == 'station'"), context, hashMap));
    }

    /**
     * 上边创建OgnlContext
     * 设置成员的访问权限,需要自己实现
     */
    private static class OgnlMemberAccess implements MemberAccess
    {
        private final boolean canControlMemberAccessible;

        public OgnlMemberAccess()
        {
            this.canControlMemberAccessible = canControlMemberAccessible();
        }

        /**
         * Checks whether can control member accessible.
         *
         * @return If can control member accessible, it return {@literal true}
         * @since 3.5.0
         */
        private boolean canControlMemberAccessible()
        {
            try
            {
                SecurityManager securityManager = System.getSecurityManager();
                if (null != securityManager)
                {
                    securityManager.checkPermission(new ReflectPermission("suppressAccessChecks"));
                }
            }
            catch (SecurityException e)
            {
                return false;
            }
            return true;
        }

        @Override
        public Object setup(Map context, Object target, Member member, String propertyName)
        {
            Object result = null;
            if (isAccessible(context, target, member, propertyName))
            {
                AccessibleObject accessible = (AccessibleObject) member;
                if (!accessible.isAccessible())
                {
                    result = Boolean.FALSE;
                    accessible.setAccessible(true);
                }
            }
            return result;
        }

        @Override
        public void restore(Map context, Object target, Member member, String propertyName, Object state)
        {
            // Flipping accessible flag is not thread safe. See #1648
        }

        @Override
        public boolean isAccessible(Map context, Object target, Member member, String propertyName)
        {
            return canControlMemberAccessible;
        }
    }
}

依赖(Maven;代码还导入了 lombok.extern.slf4j.Slf4j 用于日志,因此另需 Lombok 及 SLF4J 依赖):

<dependency>
    <groupId>dom4j</groupId>
    <artifactId>dom4j</artifactId>
    <version>1.6.1</version>
</dependency>
<dependency>
    <groupId>ognl</groupId>
    <artifactId>ognl</artifactId>
    <version>3.2.19</version>
</dependency>
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐