JDK序列化机制及源码解读三:ObjectInputStream对象输入流

这是Java序列化机制及源码解读系列的第三篇,主要学习JDK处理对象输入流的方法。有了ObjectOutputStream的分析,ObjectInputStream解读起来更得心应手,所以这里我们仅分析下ObjectInputStream的构造方法和readObject方法。

1、ObjectInputStream的构造方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public ObjectInputStream(InputStream in) throws IOException {
verifySubclass();
// 根据参数创建BlockDataInputStream实例,会创建一个输入流
bin = new BlockDataInputStream(in);
handles = new HandleTable(10);
vlist = new ValidationList();
serialFilter = ObjectInputFilter.Config.getSerialFilter();
// 自定义readObject判断
enableOverride = false;
// 读取流的头部信息,一般是0xaced
readStreamHeader();
// 设置块数据模式
bin.setBlockDataMode(true);
}

BlockDataInputStream跟 BlockDataOutputStream类似,不过是反过来,其作用是从输入流读取byte数据,read[Type]和read[Type]s将byte转为Java基元数据,并且会返回对应类型的对象。

2、readObject

readObject

两种情况读取对象,一种类重写了readObject,直接调用重写方法,否则调用readObject0反序列化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public final Object readObject()
throws IOException, ClassNotFoundException
{
// 重写了readObject,直接调用重写的readObject实现
if (enableOverride) {
return readObjectOverride();
}

// if nested read, passHandle contains handle of enclosing object
int outerHandle = passHandle;
try {
// 调用readObject0反序列化
Object obj = readObject0(false);
handles.markDependency(outerHandle, passHandle);
ClassNotFoundException ex = handles.lookupException(passHandle);
if (ex != null) {
throw ex;
}
if (depth == 0) {
vlist.doCallbacks();
}
return obj;
} finally {
passHandle = outerHandle;
if (closed && depth == 0) {
clear();
}
}
}
readObject0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
private Object readObject0(boolean unshared) throws IOException {
boolean oldMode = bin.getBlockDataMode();
if (oldMode) {
int remain = bin.currentBlockRemaining();
if (remain > 0) {
throw new OptionalDataException(remain);
} else if (defaultDataEnd) {
throw new OptionalDataException(true);
}
bin.setBlockDataMode(false);
}

//读取tc标识
byte tc;
while ((tc = bin.peekByte()) == TC_RESET) {
bin.readByte();
handleReset();
}

depth++; //递归调用标记
totalObjectRefs++;
try {
// 根据不同的TC标记读取内容
switch (tc) {
/*** 1、非对象非有效数据 对应调用即可 **/
case TC_NULL:
return readNull();

case TC_REFERENCE:
return readHandle(unshared);

case TC_CLASS:
return readClass(unshared);

case TC_CLASSDESC:
case TC_PROXYCLASSDESC:
return readClassDesc(unshared);

/*** 2、具体对象的反序列化调用checkResolve **/
case TC_STRING:
case TC_LONGSTRING:
return checkResolve(readString(unshared)); // 调用readString

case TC_ARRAY:
return checkResolve(readArray(unshared)); // 调用readArray

case TC_ENUM:
return checkResolve(readEnum(unshared)); // 调用readEnum

case TC_OBJECT:
return checkResolve(readOrdinaryObject(unshared)); //调用readOrdinaryObject

case TC_EXCEPTION:
IOException ex = readFatalException();
throw new WriteAbortedException("writing aborted", ex);

case TC_BLOCKDATA:
case TC_BLOCKDATALONG:
if (oldMode) {
bin.setBlockDataMode(true);
bin.peek(); // force header read
throw new OptionalDataException(
bin.currentBlockRemaining());
} else {
throw new StreamCorruptedException(
"unexpected block data");
}

case TC_ENDBLOCKDATA:
if (oldMode) {
throw new OptionalDataException(true);
} else {
throw new StreamCorruptedException(
"unexpected end of block data");
}

default:
throw new StreamCorruptedException(
String.format("invalid type code: %02X", tc));
}
} finally {
depth--;
bin.setBlockDataMode(oldMode);
}
}
read[String|Array|Enum]

读取实例值并返回对应对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//String的读取,并返回字符串对象
private String readString(boolean unshared) throws IOException {
String str;
byte tc = bin.readByte(); //tc获取标识
switch (tc) {
case TC_STRING:
str = bin.readUTF(); // TC_STRING标识,调用readUTF,先读取字符串长度,在读取字符串
break;

case TC_LONGSTRING:
str = bin.readLongUTF(); // TC_LONGSTRING标识,调用readLongUTF,先读取字符串长度,在读取字符串
break;

default:
throw new StreamCorruptedException(
String.format("invalid type code: %02X", tc));
}
passHandle = handles.assign(unshared ? unsharedMarker : str);
handles.finish(passHandle);
//返回字符串对象
return str;
}

// Array的读取,并返回array对象
private Object readArray(boolean unshared) throws IOException {
// 通过标识判断是否ARRAY类型
if (bin.readByte() != TC_ARRAY) {
throw new InternalError();
}

// readClassDesc
ObjectStreamClass desc = readClassDesc(false);
// 获取数组长度
int len = bin.readInt();

filterCheck(desc.forClass(), len);

Object array = null;
Class<?> cl, ccl = null;
if ((cl = desc.forClass()) != null) {
ccl = cl.getComponentType();
// 创建Array实例
array = Array.newInstance(ccl, len);
}

int arrayHandle = handles.assign(unshared ? unsharedMarker : array);
ClassNotFoundException resolveEx = desc.getResolveException();
if (resolveEx != null) {
handles.markException(arrayHandle, resolveEx);
}

if (ccl == null) {
for (int i = 0; i < len; i++) {
readObject0(false);
}
} else if (ccl.isPrimitive()) { //原生类型数组直接读取
if (ccl == Integer.TYPE) {
bin.readInts((int[]) array, 0, len);
} else if (ccl == Byte.TYPE) {
bin.readFully((byte[]) array, 0, len, true);
} else if (ccl == Long.TYPE) {
bin.readLongs((long[]) array, 0, len);
} else if (ccl == Float.TYPE) {
bin.readFloats((float[]) array, 0, len);
} else if (ccl == Double.TYPE) {
bin.readDoubles((double[]) array, 0, len);
} else if (ccl == Short.TYPE) {
bin.readShorts((short[]) array, 0, len);
} else if (ccl == Character.TYPE) {
bin.readChars((char[]) array, 0, len);
} else if (ccl == Boolean.TYPE) {
bin.readBooleans((boolean[]) array, 0, len);
} else {
throw new InternalError();
}
} else { //非原生类型数组调用readObject0读取
Object[] oa = (Object[]) array;
// 循环数组长度
for (int i = 0; i < len; i++) {
oa[i] = readObject0(false);
handles.markDependency(arrayHandle, passHandle);
}
}

handles.finish(arrayHandle);
passHandle = arrayHandle;
// 返回数组对象
return array;
}

// Enum的读取
private Enum<?> readEnum(boolean unshared) throws IOException {
if (bin.readByte() != TC_ENUM) {
throw new InternalError();
}

ObjectStreamClass desc = readClassDesc(false);
if (!desc.isEnum()) {
throw new InvalidClassException("non-enum class: " + desc);
}

int enumHandle = handles.assign(unshared ? unsharedMarker : null);
ClassNotFoundException resolveEx = desc.getResolveException();
if (resolveEx != null) {
handles.markException(enumHandle, resolveEx);
}

String name = readString(false);
Enum<?> result = null;
Class<?> cl = desc.forClass();
if (cl != null) {
try {
@SuppressWarnings("unchecked")
// 赋值给枚举变量具体的数值,数值来自读取的数据
Enum<?> en = Enum.valueOf((Class)cl, name);
result = en;
} catch (IllegalArgumentException ex) {
throw (IOException) new InvalidObjectException(
"enum constant " + name + " does not exist in " +
cl).initCause(ex);
}
if (!unshared) {
handles.setObject(enumHandle, result);
}
}

handles.finish(enumHandle);
passHandle = enumHandle;
//返回枚举数据
return result;
}
readOrdinaryObject:

readOrdinaryObject主要4个步骤:

  1. 根据TC_OBJECT判断是否是对象;
  2. 调用readClassDesc读取类元信息,用来创建实例;
  3. 调用Externalize/Serializable获取对象属性域信息并赋值给对象的属性;
  4. 如果是实现了readResolve,调用readResolve来直接替换前面第2-3步反序列化的对象。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
private Object readOrdinaryObject(boolean unshared)
throws IOException
{
/*** 1、首先根据TC_OBJECT判断是否是对象 **/
if (bin.readByte() != TC_OBJECT) {
throw new InternalError();
}

/*** 2、调用readClassDesc读取类元信息,并创建实例 **/
ObjectStreamClass desc = readClassDesc(false);
desc.checkDeserialize();

// 根据desc获取类名
Class<?> cl = desc.forClass();
if (cl == String.class || cl == Class.class
|| cl == ObjectStreamClass.class) {
throw new InvalidClassException("invalid class descriptor");
}

Object obj;
try {
// 类是否是能可实例化,可以的话调用newInstance无参构造创建实例
obj = desc.isInstantiable() ? desc.newInstance() : null;
} catch (Exception ex) {
throw (IOException) new InvalidClassException(
desc.forClass().getName(),
"unable to create instance").initCause(ex);
}

passHandle = handles.assign(unshared ? unsharedMarker : obj);
ClassNotFoundException resolveEx = desc.getResolveException();
if (resolveEx != null) {
handles.markException(passHandle, resolveEx);
}
/*** 3、调用Externalize或者Serializable获取对象属性域信息并赋值 **/
if (desc.isExternalizable()) { //实现了Externalizable,调用readExternalData
readExternalData((Externalizable) obj, desc);
} else { //实现了Serializable,调用readSerialData
readSerialData(obj, desc);
}

handles.finish(passHandle);

/*** 4、 判断是否实现了readResolve **/
if (obj != null &&
handles.lookupException(passHandle) == null &&
desc.hasReadResolveMethod())
{
// 实现了readResolve,调用readResolve来替换反序列化的对象
Object rep = desc.invokeReadResolve(obj);
if (unshared && rep.getClass().isArray()) {
rep = cloneArray(rep);
}
if (rep != obj) {
// Filter the replacement object
if (rep != null) {
if (rep.getClass().isArray()) {
filterCheck(rep.getClass(), Array.getLength(rep));
} else {
filterCheck(rep.getClass(), -1);
}
}
handles.setObject(passHandle, obj = rep);
}
}
// 返回对象
return obj;
}
ObjectStreamClass#readClassDesc和readNonProxy:读取类元信息

步骤:

  1. 获取类的类名;
  2. 获取类的serialVersionUID;
  3. 获取类使用的序列化方式;
  4. 获取类所有序列化的属性

这些信息将会被封装在ObjectStreamClass返回。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
private ObjectStreamClass readClassDesc(boolean unshared)
throws IOException
{
// 读取TC标识
byte tc = bin.peekByte();
ObjectStreamClass descriptor;
switch (tc) {
case TC_NULL:
descriptor = (ObjectStreamClass) readNull();
break;
case TC_REFERENCE:
descriptor = (ObjectStreamClass) readHandle(unshared);
break;
case TC_PROXYCLASSDESC:
descriptor = readProxyDesc(unshared);
break;
case TC_CLASSDESC:
// 如果TC是一个类描述符 意味着后面的数据是描述类信息,则调用readNonProxyDesc,readNonProxyDesc会调用readClassDescriptor,readClassDescriptor调用
descriptor = readNonProxyDesc(unshared);
break;
default:
throw new StreamCorruptedException(
String.format("invalid type code: %02X", tc));
}
if (descriptor != null) {
validateDescriptor(descriptor);
}
return descriptor;
}

void readNonProxy(ObjectInputStream in)
throws IOException, ClassNotFoundException
{
/*** 1、获取类的类名**/
name = in.readUTF(); // readUTF返回字符串,这里返回目标类名
/*** 2、获取类的serialVersionUID**/
suid = Long.valueOf(in.readLong()); // 接下来读取serialVersionUID
isProxy = false;

/*** 3、获取类使用的序列化方式 **/
byte flags = in.readByte(); //读取类使用的序列化方式
hasWriteObjectData =
((flags & ObjectStreamConstants.SC_WRITE_METHOD) != 0);
hasBlockExternalData =
((flags & ObjectStreamConstants.SC_BLOCK_DATA) != 0);
externalizable =
((flags & ObjectStreamConstants.SC_EXTERNALIZABLE) != 0);
boolean sflag =
((flags & ObjectStreamConstants.SC_SERIALIZABLE) != 0);
if (externalizable && sflag) {
throw new InvalidClassException(
name, "serializable and externalizable flags conflict");
}
serializable = externalizable || sflag;
isEnum = ((flags & ObjectStreamConstants.SC_ENUM) != 0);
if (isEnum && suid.longValue() != 0L) {
throw new InvalidClassException(name,
"enum descriptor has non-zero serialVersionUID: " + suid);
}
/*** 4、获取类序列化的属性 **/
int numFields = in.readShort(); //获取序列化的类属性个数
if (isEnum && numFields != 0) {
throw new InvalidClassException(name,
"enum descriptor has non-zero field count: " + numFields);
}
fields = (numFields > 0) ?
new ObjectStreamField[numFields] : NO_FIELDS;
// 循环获取每个类属性
for (int i = 0; i < numFields; i++) {
char tcode = (char) in.readByte();
String fname = in.readUTF();
String signature = ((tcode == 'L') || (tcode == '[')) ?
in.readTypeString() : new String(new char[] { tcode });
try {
fields[i] = new ObjectStreamField(fname, signature, false);
} catch (RuntimeException e) {
throw (IOException) new InvalidClassException(name,
"invalid descriptor for field " + fname).initCause(e);
}
}
computeFieldOffsets();
}
readExternalData/readSerialData获取对象属性域信息:

readSerialData未重写readObject会调用默认处理方法defaultReadFields

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
private void readSerialData(Object obj, ObjectStreamClass desc) throws IOException {
ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
// 循环获取每个属性值数据
for (int i = 0; i < slots.length; i++) {
ObjectStreamClass slotDesc = slots[i].desc;

if (slots[i].hasData) {
if (obj == null || handles.lookupException(passHandle) != null) {
defaultReadFields(null, slotDesc); // skip field values
} else if (slotDesc.hasReadObjectMethod()) {
ThreadDeath t = null;
boolean reset = false;
SerialCallbackContext oldContext = curContext;
if (oldContext != null)
oldContext.check();
try {
curContext = new SerialCallbackContext(obj, slotDesc);

bin.setBlockDataMode(true);
// 重写readObject,调用重写方法
slotDesc.invokeReadObject(obj, this);
} catch (ClassNotFoundException ex) {
handles.markException(passHandle, ex);
} finally {
do {
try {
curContext.setUsed();
if (oldContext!= null)
oldContext.check();
curContext = oldContext;
reset = true;
} catch (ThreadDeath x) {
t = x; // defer until reset is true
}
} while (!reset);
if (t != null)
throw t;
}
defaultDataEnd = false;
} else {
//非重写,调用defaultReadFields
defaultReadFields(obj, slotDesc);
}

if (slotDesc.hasWriteObjectData()) {
skipCustomData();
} else {
bin.setBlockDataMode(false);
}
} else {
if (obj != null &&
slotDesc.hasReadObjectNoDataMethod() &&
handles.lookupException(passHandle) == null)
{
slotDesc.invokeReadObjectNoData(obj);
}
}
}
}

// 调用 类实现的readExternal方法获取
private void readExternalData(Externalizable obj, ObjectStreamClass desc) throws IOException {
SerialCallbackContext oldContext = curContext;
if (oldContext != null)
oldContext.check();
curContext = null;
try {
boolean blocked = desc.hasBlockExternalData();
if (blocked) {
bin.setBlockDataMode(true);
}
if (obj != null) {
try {
// 调用 实现的readExternal方法获取
obj.readExternal(this);
} catch (ClassNotFoundException ex) {
/*
* In most cases, the handle table has already propagated
* a CNFException to passHandle at this point; this mark
* call is included to address cases where the readExternal
* method has cons'ed and thrown a new CNFException of its
* own.
*/
handles.markException(passHandle, ex);
}
}
if (blocked) {
skipCustomData();
}
} finally {
if (oldContext != null)
oldContext.check();
curContext = oldContext;
}
}
defaultReadFields:属性值具体操作
  • 获取所有原生类型的属性数据并赋值对象属性:调用ObjectStreamClass#setPrimFieldValues给所有属性进行赋值
  • 获取所有非原生类型的属性数据并赋值:循环读取每个非原生属性,用readObject0递归调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
private void defaultReadFields(Object obj, ObjectStreamClass desc) throws IOException {
// 获取类
Class<?> cl = desc.forClass();
if (cl != null && obj != null && !cl.isInstance(obj)) {
throw new ClassCastException();
}

/*** 1、获取所有原生类型的属性数据并赋值对象属性**/
// 获取属性是原生类型的属性数据个数
int primDataSize = desc.getPrimDataSize();
if (primVals == null || primVals.length < primDataSize) {
primVals = new byte[primDataSize];
}
// 读取所有原生类型的属性数据
bin.readFully(primVals, 0, primDataSize, false);
if (obj != null) {// 检查是否实例化
// 设置属性值 即赋值操作
desc.setPrimFieldValues(obj, primVals);
}

int objHandle = passHandle;
/*** 2、获取所有非原生类型的属性数据并赋值 **/
ObjectStreamField[] fields = desc.getFields(false);
// 给每个非原生类型属性进行实例化一个对象操作
Object[] objVals = new Object[desc.getNumObjFields()];
int numPrimFields = fields.length - objVals.length;
// 循环读取每个非原生属性,用readObject0递归调用
for (int i = 0; i < objVals.length; i++) {
ObjectStreamField f = fields[numPrimFields + i];
objVals[i] = readObject0(f.isUnshared());
if (f.getField() != null) {
handles.markDependency(objHandle, passHandle);
}
}
if (obj != null) {
desc.setObjFieldValues(obj, objVals);
}
passHandle = objHandle;
}

// ObjectStreamClass#setPrimFieldValues
// 原生属性类型赋值,将字节buf付给obj属性
void setPrimFieldValues(Object obj, byte[] buf) {
if (obj == null) {
throw new NullPointerException();
}
// 循环赋值
for (int i = 0; i < numPrimFields; i++) {
long key = writeKeys[i];
if (key == Unsafe.INVALID_FIELD_OFFSET) {
continue; // discard value
}
int off = offsets[i];
// 根据序列化后类型码进行赋值,如类型码为Z,则按照布尔类型给obj赋值
switch (typeCodes[i]) {
case 'Z':
unsafe.putBoolean(obj, key, Bits.getBoolean(buf, off));
break;

case 'B':
unsafe.putByte(obj, key, buf[off]);
break;

case 'C':
unsafe.putChar(obj, key, Bits.getChar(buf, off));
break;

case 'S':
unsafe.putShort(obj, key, Bits.getShort(buf, off));
break;

case 'I':
unsafe.putInt(obj, key, Bits.getInt(buf, off));
break;

case 'F':
unsafe.putFloat(obj, key, Bits.getFloat(buf, off));
break;

case 'J':
unsafe.putLong(obj, key, Bits.getLong(buf, off));
break;

case 'D':
unsafe.putDouble(obj, key, Bits.getDouble(buf, off));
break;

default:
throw new InternalError();
}
}
}

3、总结-对象输入流规则

跟对象输出规则一致,也是按照三部分进行,其实每部分也都一样的,不过一个是写入,一个是读取罢了。

ObjectInputStream读取规则如下:

  1. 第一部分获取序列化头信息:描述序列化协议的信息和版本;
  2. 第二部分获取类元信息包括序列化的类名称、以及哪些属性是被序列化的,获取到后用newInstance构造一个类的对象。
  3. 第三部分获取属性域的值信息,赋值对应序列化的每个属性。

更新-拓展:

weblogic经常打补丁的地方:resolveClass。

1、原生ObjectInputStream.resolveClass很简单,就是获取类名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
protected Class<?> resolveClass(ObjectStreamClass desc)
throws IOException, ClassNotFoundException
{
String name = desc.getName();
try {
return Class.forName(name, false, latestUserDefinedLoader());
} catch (ClassNotFoundException ex) {
Class<?> cl = primClasses.get(name);
if (cl != null) {
return cl;
} else {
throw ex;
}
}
}

在readNonProxyDesc中进行了调用

image-20210809165619482

在readClassDesc中,当TC是一个类描述符,调用readNonProxyDesc获取类名。

image-20210809165722640

2、Weblogic的weblogic.rjvm.InboundMsgAbbrev中,对类进行了黑名单检测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
protected Class resolveClass(ObjectStreamClass descriptor) throws ClassNotFoundException, IOException {
try {
// 进行黑名单检测
this.checkLegacyBlacklistIfNeeded(descriptor.getName());
} catch (InvalidClassException var4) {
throw var4;
}

Class c = super.resolveClass(descriptor);
if(c == null) {
throw new ClassNotFoundException("super.resolveClass returns null.");
} else {
ObjectStreamClass localDesc = ObjectStreamClass.lookup(c);
if(localDesc != null && localDesc.getSerialVersionUID() != descriptor.getSerialVersionUID()) {
throw new ClassNotFoundException("different serialVersionUID. local: " + localDesc.getSerialVersionUID() + " remote: " + descriptor.getSerialVersionUID());
} else {
return c;
}
}
}