8
8
import java .util .List ;
9
9
10
10
import javax .xml .bind .UnmarshalException ;
11
- import javax .xml .transform .sax . SAXSource ;
11
+ import javax .xml .transform .Source ;
12
12
import javax .xml .transform .stream .StreamSource ;
13
13
14
- import org .dmg .pmml .Extension ;
15
14
import org .dmg .pmml .PMML ;
16
15
import org .junit .Test ;
17
- import org .xml .sax .InputSource ;
18
- import org .xml .sax .XMLReader ;
19
- import org .xml .sax .helpers .XMLReaderFactory ;
16
+ import org .xml .sax .SAXParseException ;
20
17
21
18
import static org .junit .Assert .assertEquals ;
19
+ import static org .junit .Assert .assertTrue ;
22
20
import static org .junit .Assert .fail ;
23
21
24
22
public class XXEAttackTest {
@@ -30,30 +28,31 @@ public void unmarshal() throws Exception {
30
28
System .setProperty ("javax.xml.accessExternalDTD" , "file" );
31
29
32
30
try (InputStream is = ResourceUtil .getStream (XXEAttackTest .class );){
33
- pmml = JAXBUtil .unmarshalPMML (new StreamSource (is ));
31
+ Source source = new StreamSource (is );
32
+
33
+ pmml = JAXBUtil .unmarshalPMML (source );
34
34
} finally {
35
35
System .clearProperty ("javax.xml.accessExternalDTD" );
36
36
}
37
37
38
- List <Extension > extensions = pmml .getExtensions ();
39
- assertEquals (1 , extensions .size ());
38
+ List <?> content = ExtensionUtil .getContent (pmml );
40
39
41
- Extension extension = extensions .get (0 );
42
- assertEquals (Arrays .asList ("lol" ), extension .getContent ());
40
+ assertEquals (Arrays .asList ("lol" ), content );
43
41
}
44
42
45
43
@ Test
46
44
public void unmarshalSecurely () throws Exception {
47
45
48
46
try (InputStream is = ResourceUtil .getStream (XXEAttackTest .class )){
49
- XMLReader reader = XMLReaderFactory .createXMLReader ();
50
- reader .setFeature ("http://apache.org/xml/features/disallow-doctype-decl" , true );
47
+ Source source = SAXUtil .createFilteredSource (is );
51
48
52
- JAXBUtil .unmarshalPMML (new SAXSource ( reader , new InputSource ( is )) );
49
+ JAXBUtil .unmarshalPMML (source );
53
50
54
51
fail ();
55
52
} catch (UnmarshalException ue ){
56
- // Ignored
53
+ Throwable cause = ue .getCause ();
54
+
55
+ assertTrue (cause instanceof SAXParseException );
57
56
}
58
57
}
59
58
}
0 commit comments