手写spring(简易版)

时间:2024-06-25 13:37:20

本文版权归 远方的风lyh和博客园共有,欢迎转载,但须保留此段声明,并给出原文链接,谢谢合作。如有错误之处,望不吝批评指正!

理解Spring本质:

    相信大家之前在使用 Spring 的时候都配置过 web.xml 文件。下面这段配置其实就是注册了一个 Servlet:从 DispatcherServlet 的源码可以看到,它(通过父类)继承了 HttpServlet 抽象类,因此所有的请求都会交给 DispatcherServlet 来处理。    <servlet>

        <servlet-name>spring-mvc</servlet-name>
<servlet-class>org.springframework.web.servlet.DispatcherServlet</servlet-class>
<init-param>
<param-name>contextConfigLocation</param-name>
<param-value>WEB-INF/spring/spring-mvc.xml</param-value>
</init-param>
<load-on-startup>1</load-on-startup>
<async-supported>true</async-supported>
</servlet>
<servlet-mapping>
<servlet-name>spring-mvc</servlet-name>
<url-pattern>/</url-pattern>

</servlet-mapping>

手写spring(简易版)

配置

    web.xml: 配置一个 Servlet 并接收所有请求

<!DOCTYPE web-app PUBLIC
"-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
"http://java.sun.com/dtd/web-app_2_3.dtd" > <web-app>
<!-- Deployment descriptor: registers the hand-written dispatcher servlet as the single entry point. -->
<display-name>Archetype Created Web Application</display-name>
<servlet>
<servlet-name>MySpringMVC</servlet-name>
<servlet-class>cn.lyh.mySpring.MyDispatcherServlet</servlet-class>
<!-- Read by MyDispatcherServlet.init(): name of the properties file loaded from the classpath. -->
<init-param>
<param-name>contextConfigLocation</param-name>
<param-value>context.properties</param-value>
</init-param>
<!-- Instantiate and initialise the servlet when the web application starts, not on first request. -->
<load-on-startup>1</load-on-startup>
</servlet> <servlet-mapping>
<servlet-name>MySpringMVC</servlet-name>
<!-- Route every request path to the dispatcher servlet. -->
<url-pattern>/*</url-pattern>
</servlet-mapping>
</web-app>

    context.properties:

#包扫描
scan.package=cn.lyh.mySpringTest

注解类

手写spring(简易版)

/**
 * Marks a field whose value should be injected from the container.
 * The optional {@link #value()} names the bean to inject.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface MyAutowired {
    String value() default "";
}

/**
 * Marks a class as a request-handling controller.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyController {
    String value() default "";
}

/**
 * Declares the URL fragment handled by a controller class or by one of its methods.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface MyRequestMapping {
    String value() default "";
}

/**
 * Binds a handler-method parameter to a request parameter.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface MyRequestParam {
    /**
     * The name (alias) of the request parameter; required.
     *
     * @return the request parameter name
     */
    String value();
}

/**
 * Marks a class as the response-body post-processor.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyResponseAdvice {
}

/**
 * Marks a handler method whose return value is written to the response as JSON.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface MyResponseBody {
}

/**
 * Marks a class as a service bean; the optional {@link #value()} is a custom bean name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyService {
    String value() default "";
}

MyDispatcherServlet(核心实现):

    MyDispatcherServlet 继承了 HttpServlet,并重写了 doGet、doPost、init 方法

·

package cn.lyh.mySpring;

import cn.lyh.mySpring.Handler.ResponseBodyHandler;
import cn.lyh.mySpring.annotation.*;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import org.apache.log4j.Logger; import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URL;
import java.util.*; /***
*dispatcherServlet
* @author lyh
*/
public class MyDispatcherServlet extends HttpServlet {
    private static final Logger logger = Logger.getLogger(MyDispatcherServlet.class);
    private static final String CLASS_SUFFIX = ".class";

    /** Configuration loaded from the file named by the {@code contextConfigLocation} init parameter. */
    private final Properties contextConfig = new Properties();
    /** Fully-qualified names of every class found by the package scan. */
    private final List<String> classNames = new ArrayList<>();
    /** IoC container: bean name (or interface name) to bean instance. */
    private final Map<String, Object> ioc = new HashMap<>();
    /** Request URL to handler method. */
    private final Map<String, Method> handlerMapping = new HashMap<>();
    /** Optional response post-processor, taken from the first {@code @MyResponseAdvice} class found. */
    private ResponseBodyHandler responseBodyHandler;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatcherServlet(req, resp);
    }

    /**
     * Bootstraps the mini container when the servlet is loaded.
     *
     * @param config servlet configuration carrying the {@code contextConfigLocation} init parameter
     * @throws ServletException if any bootstrap step fails; the original failure is kept as the cause
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Let GenericServlet store the config so getServletConfig()/getServletContext() keep working.
        super.init(config);
        String contextConfigLocation = config.getInitParameter("contextConfigLocation");
        try {
            initMyDispatcherServlet(contextConfigLocation);
        } catch (Exception e) {
            logger.error("mySpring init failed", e);
            throw new ServletException(e.getMessage(), e);
        }
    }

    /**
     * Maps the request URL to a handler method and invokes it.
     *
     * @param request  current request
     * @param response current response
     */
    private void doDispatcherServlet(HttpServletRequest request, HttpServletResponse response) {
        invoke(request, response);
    }

    /**
     * Looks up the handler for the request path, invokes it and writes its result.
     * Unknown paths get a 404, unconvertible arguments a 400, handler failures a 500.
     */
    private void invoke(HttpServletRequest request, HttpServletResponse response) {
        String queryUrl = request.getRequestURI();
        // Mappings are registered without the context path, so strip it before the lookup.
        String contextPath = request.getContextPath();
        if (contextPath != null && !contextPath.isEmpty() && queryUrl.startsWith(contextPath)) {
            queryUrl = queryUrl.substring(contextPath.length());
        }
        queryUrl = queryUrl.replaceAll("/+", "/");

        Method method = handlerMapping.get(queryUrl);
        if (null == method) {
            writeNotFound(request, response);
            return;
        }

        try {
            Object[] paramValues = getMethodParamAndValue(request, response, method);
            String controllerClassName = toFirstWordLower(method.getDeclaringClass().getSimpleName());
            Object result = method.invoke(ioc.get(controllerClassName), paramValues);
            if (result == null) {
                return;
            }
            if (method.isAnnotationPresent(MyResponseBody.class)) {
                response.setHeader("content-type", "application/json;charset=UTF-8");
                if (null == responseBodyHandler) {
                    result = JSONObject.toJSONString(result, SerializerFeature.WriteMapNullValue);
                } else {
                    // NOTE(review): Object.equals() yields a boolean, so "true"/"false" is written here.
                    // This is presumably meant to call the handler's own processing method; the
                    // ResponseBodyHandler API is not visible from this file - confirm and replace.
                    result = responseBodyHandler.equals(result);
                }
            }
            response.getWriter().print(result);
            logger.debug("request-> " + request.getRequestURI() + ", response success ->" + response.getStatus());
        } catch (IllegalArgumentException e) {
            // Raised by parameter conversion (NumberFormatException) or by reflection on an argument mismatch.
            logger.error("bad request parameters: " + request.getRequestURI(), e);
            sendError(response, HttpServletResponse.SC_BAD_REQUEST, "400 bad request parameters");
        } catch (IllegalAccessException | InvocationTargetException e) {
            logger.error("handler invocation failed: " + request.getRequestURI(), e);
            sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "500 handler invocation failed");
        } catch (IOException e) {
            logger.error("write response failed: " + request.getRequestURI(), e);
        }
    }

    /** Writes the plain-text 404 body for a URL without a registered handler. */
    private void writeNotFound(HttpServletRequest request, HttpServletResponse response) {
        response.setStatus(HttpServletResponse.SC_NOT_FOUND);
        // The request URI is echoed back, so make sure it is never rendered as HTML.
        response.setContentType("text/plain;charset=UTF-8");
        logger.debug("request fail(404): " + request.getRequestURI());
        try (PrintWriter pw = response.getWriter()) {
            pw.print("404 not find -> " + request.getRequestURI());
            pw.flush();
        } catch (IOException e) {
            logger.error("write 404 response failed: " + request.getRequestURI(), e);
        }
    }

    /** Sends an error status unless the response has already been committed. */
    private void sendError(HttpServletResponse response, int status, String message) {
        if (response.isCommitted()) {
            return;
        }
        try {
            response.sendError(status, message);
        } catch (IOException e) {
            logger.error("send error " + status + " failed", e);
        }
    }

    /**
     * Resolves the argument array for a handler method.
     * Servlet request/response parameters receive the current objects; every other parameter is
     * read from the request by its {@code @MyRequestParam} name (or its reflective name) and
     * converted to int/long/float/double/boolean (boxed or primitive) or left as a String.
     * Binding to entity objects is not supported.
     *
     * @param request  current request
     * @param response current response
     * @param method   handler method about to be invoked
     * @return argument values in declaration order; a missing request parameter yields {@code null}
     * @throws NumberFormatException if a numeric parameter cannot be parsed
     */
    private Object[] getMethodParamAndValue(HttpServletRequest request, HttpServletResponse response, Method method) {
        Parameter[] parameters = method.getParameters();
        Object[] paramValues = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Class<?> type = parameters[i].getType();
            if (ServletRequest.class.isAssignableFrom(type)) {
                paramValues[i] = request;
            } else if (ServletResponse.class.isAssignableFrom(type)) {
                paramValues[i] = response;
            } else {
                String bindingValue = parameters[i].getName();
                if (parameters[i].isAnnotationPresent(MyRequestParam.class)) {
                    bindingValue = parameters[i].getAnnotation(MyRequestParam.class).value();
                }
                paramValues[i] = convertParam(request.getParameter(bindingValue), type);
            }
        }
        return paramValues;
    }

    /** Converts a raw request parameter to the declared parameter type; unknown types stay a String. */
    private Object convertParam(String raw, Class<?> type) {
        if (raw == null) {
            return null;
        }
        if (type == Integer.class || type == int.class) {
            return Integer.parseInt(raw);
        }
        if (type == Float.class || type == float.class) {
            return Float.parseFloat(raw);
        }
        if (type == Double.class || type == double.class) {
            return Double.parseDouble(raw);
        }
        if (type == Long.class || type == long.class) {
            return Long.parseLong(raw);
        }
        if (type == Boolean.class || type == boolean.class) {
            return Boolean.parseBoolean(raw);
        }
        return raw;
    }

    /**
     * Runs the bootstrap sequence: load config, scan packages, instantiate beans, inject, map URLs.
     *
     * @param contextConfigLocation classpath location of the properties file
     * @throws Exception if the configuration is missing or a mapping is invalid
     */
    private void initMyDispatcherServlet(String contextConfigLocation) throws Exception {
        logger.info("-----------------------------mySpring init start-----------------------------------------");
        logger.debug("doLoadConfig:" + contextConfigLocation);
        // load configuration
        doLoadConfig(contextConfigLocation);
        // package scan
        logger.debug("scan:" + contextConfig.getProperty("scan.package"));
        doScanner(contextConfig.getProperty("scan.package"));
        // create bean instances (IoC)
        doInstance();
        // dependency injection (DI)
        doAutowired();
        // URL mapping
        initHandlerMapping();
    }

    /**
     * Injects container beans into every {@code @MyAutowired} field.
     * A non-blank annotation value is used as the bean name; otherwise the field type's
     * fully-qualified name is tried first, then its lower-camel simple name.
     */
    private void doAutowired() {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(MyAutowired.class)) {
                    continue;
                }
                String alias = field.getAnnotation(MyAutowired.class).value().trim();
                Object dependency;
                if (!alias.isEmpty()) {
                    dependency = ioc.get(alias);
                } else {
                    dependency = ioc.get(field.getType().getName());
                    if (dependency == null) {
                        // Beans are also registered under their lower-camel simple name.
                        dependency = ioc.get(toFirstWordLower(field.getType().getSimpleName()));
                    }
                }
                if (dependency == null || !field.getType().isInstance(dependency)) {
                    logger.warn("no matching bean for field " + bean.getClass().getName() + "." + field.getName());
                    continue;
                }
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    logger.error("inject field " + bean.getClass().getName() + "." + field.getName() + " failed", e);
                }
            }
        }
    }

    /**
     * Builds the URL to handler-method table from every {@code @MyController} bean.
     *
     * @throws IllegalStateException if two handler methods resolve to the same URL
     */
    private void initHandlerMapping() {
        for (Object bean : ioc.values()) {
            Class<?> clazz = bean.getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) {
                continue;
            }
            // The class-level mapping is optional; a controller without one gets an empty prefix.
            MyRequestMapping classMapping = clazz.getAnnotation(MyRequestMapping.class);
            String prefix = classMapping == null ? "" : classMapping.value();
            for (Method method : clazz.getDeclaredMethods()) {
                MyRequestMapping methodMapping = method.getAnnotation(MyRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                String url = ("/" + prefix + "/" + methodMapping.value()).replaceAll("/+", "/");
                // each request URL must be unique
                if (handlerMapping.containsKey(url)) {
                    logger.error("mapping request url:" + url + "is already exist! request url must only");
                    throw new IllegalStateException("mapping:" + url + "is already exist!");
                }
                handlerMapping.put(url, method);
                logger.debug("mapping: " + url);
            }
        }
    }

    /**
     * Loads the properties file from the classpath into {@link #contextConfig}.
     *
     * @param contextConfigLocation classpath location of the properties file
     * @throws Exception if the location is blank, the resource is missing or it cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) throws Exception {
        if (contextConfigLocation == null || contextConfigLocation.trim().isEmpty()) {
            throw new IllegalArgumentException("contextConfigLocation is empty");
        }
        // try-with-resources closes the stream; a read failure propagates instead of being swallowed
        try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation)) {
            if (is == null) {
                logger.error("config:" + contextConfigLocation + " not exist");
                throw new IOException("config:" + contextConfigLocation + " not exist");
            }
            contextConfig.load(is);
        }
    }

    /**
     * Recursively collects the class names under a package into {@link #classNames}.
     * Only directory-based classpath entries are scanned.
     *
     * @param packageName dot-separated package to scan
     * @throws Exception if the package name is blank or the package URL cannot be converted to a path
     */
    private void doScanner(String packageName) throws Exception {
        if (packageName == null || packageName.trim().isEmpty()) {
            throw new IllegalArgumentException("init scan is empty");
        }
        String pkg = packageName.trim();
        String path = pkg.replace('.', '/');
        ClassLoader loader = this.getClass().getClassLoader();
        // ClassLoader resource names carry no leading slash; the slash form is kept as a fallback
        // for class loaders that only resolve it that way.
        URL url = loader.getResource(path);
        if (url == null) {
            url = loader.getResource("/" + path);
        }
        if (url == null) {
            logger.warn("scan package not found: " + pkg);
            return;
        }
        if (!"file".equals(url.getProtocol())) {
            logger.warn("scan package skipped, unsupported location: " + url);
            return;
        }
        // toURI() decodes escaped characters (e.g. %20) that URL.getFile() would leave in the path
        File[] files = new File(url.toURI()).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                // recurse into sub-packages
                doScanner(pkg + "." + fileName);
            } else if (fileName.endsWith(CLASS_SUFFIX)) {
                String className = pkg + "." + fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                logger.debug("scan class find:" + className);
                classNames.add(className);
            }
        }
    }

    /**
     * Instantiates every scanned class annotated with {@code @MyController}, {@code @MyService}
     * or {@code @MyResponseAdvice}. A class that fails is logged and skipped, so one broken
     * class does not stop the rest from being registered.
     */
    private void doInstance() {
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(MyController.class)) {
                    logger.debug("MyController instance: " + clazz.getName());
                    ioc.put(toFirstWordLower(clazz.getSimpleName()), newInstance(clazz));
                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    registerService(clazz);
                } else if (clazz.isAnnotationPresent(MyResponseAdvice.class)) {
                    registerResponseAdvice(clazz);
                }
            } catch (Exception | LinkageError e) {
                logger.error("instantiate '" + className + "' failed, skipped", e);
            }
        }
    }

    /**
     * Registers one service instance under its own name (or alias) and under each interface
     * it directly implements, so all keys resolve to the same object.
     */
    private void registerService(Class<?> clazz) throws ReflectiveOperationException {
        logger.debug("MyService instance: " + clazz.getName());
        String key = toFirstWordLower(clazz.getSimpleName());
        String alias = clazz.getAnnotation(MyService.class).value().trim();
        if (!alias.isEmpty()) {
            key = alias;
        }
        if (ioc.containsKey(key)) {
            throw new IllegalStateException("MyService instance: " + key + " is exist");
        }
        Object instance = newInstance(clazz);
        ioc.put(key, instance);
        for (Class<?> interClazz : clazz.getInterfaces()) {
            ioc.put(interClazz.getName(), instance);
        }
    }

    /** Installs the first {@code @MyResponseAdvice} class found as the response handler. */
    private void registerResponseAdvice(Class<?> clazz) throws ReflectiveOperationException {
        if (!ResponseBodyHandler.class.isAssignableFrom(clazz)) {
            throw new IllegalStateException("class+'" + clazz.getName() + "' must implement ResponseBodyHandler");
        }
        if (null == responseBodyHandler) {
            responseBodyHandler = (ResponseBodyHandler) newInstance(clazz);
        }
    }

    /** Creates an instance through the no-argument constructor. */
    private Object newInstance(Class<?> clazz) throws ReflectiveOperationException {
        return clazz.getDeclaredConstructor().newInstance();
    }

    /**
     * Lower-cases the first character of a name.
     *
     * @param name source name; may be {@code null} or empty
     * @return the name with its first character lower-cased, or the input unchanged when it is
     *         {@code null}, empty or does not start with an upper-case character
     */
    private String toFirstWordLower(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        char first = name.charAt(0);
        if (!Character.isUpperCase(first)) {
            return name;
        }
        return Character.toLowerCase(first) + name.substring(1);
    }
}

TestController:

import cn.lyh.mySpring.annotation.*;
import cn.lyh.mySpringTest.domain.User;
import cn.lyh.mySpringTest.service.TestService; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*; @MyController
@MyRequestMapping("/test")
public class TestController {
    @MyAutowired
    private TestService testService;

    /** Returns the two request parameters as plain text. */
    @MyRequestMapping("test1")
    public String test1(@MyRequestParam("name") String name,
                        @MyRequestParam("sex") Integer sex,
                        HttpServletRequest request,
                        HttpServletResponse response) throws IOException {
        return "name=" + name + "sex=" + sex;
    }

    /** Handler with no parameters and no response body. */
    @MyRequestMapping("test2")
    public void test2() {
    }

    /** Returns the request parameters as a JSON object. */
    @MyRequestMapping("test3")
    @MyResponseBody
    public Map<String, Object> test3(@MyRequestParam("name") String name,
                                     @MyRequestParam("sex") Integer sex,
                                     HttpServletRequest request,
                                     HttpServletResponse response) throws IOException {
        Map<String, Object> result = new HashMap<>();
        result.put("name", name);
        // was result.put("sex", name): the "sex" key must carry the sex parameter
        result.put("sex", sex);
        return result;
    }

    /** Returns a single user built from the request parameters as JSON. */
    @MyRequestMapping("test4")
    @MyResponseBody
    public User test4(@MyRequestParam("name") String name,
                      @MyRequestParam("sex") Integer sex,
                      HttpServletRequest request,
                      HttpServletResponse response) throws IOException {
        User user = new User();
        user.setName(name);
        user.setId(sex);
        return user;
    }

    /** Returns a one-element user list built from the request parameters as JSON. */
    @MyRequestMapping("test5")
    @MyResponseBody
    public List<User> test5(@MyRequestParam("name") String name,
                            @MyRequestParam("sex") Integer sex,
                            HttpServletRequest request,
                            HttpServletResponse response) throws IOException {
        List<User> list = new ArrayList<>();
        User user = new User();
        user.setName(name);
        user.setId(sex);
        list.add(user);
        return list;
    }

    /**
     * Returns a user with a null name, to show how null fields are serialised.
     * The method is an overload of test5 but is mapped to "test6"; the name is kept so the
     * class interface does not change.
     */
    @MyRequestMapping("test6")
    @MyResponseBody
    public List<User> test5(HttpServletRequest request,
                            HttpServletResponse response) throws IOException {
        List<User> list = new ArrayList<>();
        User user = new User();
        user.setName(null);
        user.setId(1);
        list.add(user);
        return list;
    }
}

pom文件依赖:

 <!-- Servlet API: compile-time only ("provided"), the servlet container supplies it at runtime. -->
 <dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>4.0.1</version>
<scope>provided</scope>
</dependency>
<!-- NOTE(review): slf4j-api (1.6.6) and slf4j-log4j12 (1.7.2) are on different versions; consider aligning them. -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.6.6</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.2</version>
</dependency>
<!-- log4j: MyDispatcherServlet logs through org.apache.log4j.Logger directly. -->
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<!-- fastjson: used to serialise @MyResponseBody return values to JSON. -->
<!-- NOTE(review): 1.2.47 is an old release; check security advisories before using it beyond a demo. -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.47</version>
<scope>compile</scope>
</dependency>

最后附上源码地址(MySpring Module)