diff --git a/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/Wss4jSecurityInterceptor.java b/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/Wss4jSecurityInterceptor.java index 01dad04dd..92fdc1446 100644 --- a/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/Wss4jSecurityInterceptor.java +++ b/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/Wss4jSecurityInterceptor.java @@ -145,6 +145,8 @@ public class Wss4jSecurityInterceptor extends AbstractWsSecurityInterceptor impl private CallbackHandler samlCallbackHandler; + private CallbackHandler attachmentCallbackHandler; + // Allow RSA 15 to maintain default behavior private boolean allowRSA15KeyTransportAlgorithm = true; @@ -384,6 +386,14 @@ public void setSecurementUseDerivedKey(boolean securementUseDerivedKey) { public void setSecurementSamlCallbackHandler(CallbackHandler samlCallbackHandler) { this.samlCallbackHandler = samlCallbackHandler; } + + /** + * Sets the attachment callback handler used for SwA signature/encryption + * @param attachmentCallbackHandler + */ + public void setAttachmentCallbackHandler (CallbackHandler attachmentCallbackHandler) { + this.attachmentCallbackHandler = attachmentCallbackHandler; + } /** Sets the server-side time to live */ public void setValidationTimeToLive(int validationTimeToLive) { @@ -604,6 +614,8 @@ protected RequestData initializeRequestData(MessageContext messageContext) { requestData.setUseDerivedKeyForMAC(securementUseDerivedKey); requestData.setWssConfig(wssConfig); + + requestData.setAttachmentCallbackHandler(attachmentCallbackHandler); messageContext.setProperty(WSHandlerConstants.TTL_TIMESTAMP, Integer.toString(securementTimeToLive)); diff --git a/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandler.java b/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandler.java new file mode 100644 index 000000000..063080752 --- /dev/null +++ b/spring-ws-security/src/main/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandler.java @@ -0,0 +1,153 @@ +package org.springframework.ws.soap.security.wss4j2.callback; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import javax.activation.DataHandler; +import javax.activation.DataSource; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.xml.soap.AttachmentPart; +import javax.xml.soap.MimeHeader; +import javax.xml.soap.SOAPException; + +import org.apache.wss4j.common.ext.AttachmentRequestCallback; +import org.apache.wss4j.common.ext.AttachmentResultCallback; +import org.springframework.ws.soap.SoapMessage; +import org.springframework.ws.soap.saaj.SaajSoapMessage; + +/** + * A CallbackHandler to be used to sign/encrypt SAAJ SOAP Attachments. + */ +public class SAAJAttachmentCallbackHandler implements CallbackHandler { + + private SaajSoapMessage soapMessage; + + public SAAJAttachmentCallbackHandler(SoapMessage soapMessage) { + this.soapMessage = (SaajSoapMessage) soapMessage; + } + + @Override + public void handle (Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof AttachmentRequestCallback) { + AttachmentRequestCallback attachmentRequestCallback = (AttachmentRequestCallback) callback; + + List attachmentList = new ArrayList<>(); + attachmentRequestCallback.setAttachments(attachmentList); + + String attachmentId = attachmentRequestCallback.getAttachmentId(); + if ("Attachments".equals(attachmentId)) { + // Load all attachments + attachmentId = null; + } + loadAttachments(attachmentList, attachmentId, attachmentRequestCallback.isRemoveAttachments()); + } else if (callback instanceof AttachmentResultCallback) { + AttachmentResultCallback attachmentResultCallback = (AttachmentResultCallback) callback; + AttachmentPart attachmentPart = soapMessage.getSaajMessage() + .createAttachmentPart(new DataHandler( + new InputStreamDataSource( + attachmentResultCallback.getAttachment() + .getSourceStream(), + attachmentResultCallback.getAttachment() + .getMimeType()))); + attachmentPart.setContentId(attachmentResultCallback.getAttachmentId()); + + Map headers = attachmentResultCallback.getAttachment() + .getHeaders(); + for (Map.Entry entry : headers.entrySet()) { + attachmentPart.addMimeHeader(entry.getKey(), entry.getValue()); + } + + soapMessage.getSaajMessage() + .addAttachmentPart(attachmentPart); + + } else { + throw new UnsupportedCallbackException(callback, "Unsupported callback"); + } + } + } + + @SuppressWarnings("unchecked") + private void loadAttachments ( + List attachmentList, + String attachmentId, + boolean removeAttachments) + throws IOException { + // Calling LazyAttachmentCollection.size() here to force it to load the attachments + Iterator iterator = soapMessage.getSaajMessage() + .getAttachments(); + + while (iterator.hasNext()) { + AttachmentPart attachmentPart = iterator.next(); + if (attachmentId != null && !attachmentId.equals(attachmentPart.getContentId())) { + continue; + } + org.apache.wss4j.common.ext.Attachment att = new org.apache.wss4j.common.ext.Attachment(); + att.setMimeType(attachmentPart.getContentType()); + att.setId(attachmentPart.getContentId()); + try { + att.setSourceStream(attachmentPart.getDataHandler() + .getInputStream()); + } catch (SOAPException e) { + throw new IOException("Soap exception: " + e.getMessage()); + } + Iterator mimeHeaders = attachmentPart.getAllMimeHeaders(); + while (mimeHeaders.hasNext()) { + MimeHeader mimeHeader = mimeHeaders.next(); + att.addHeader(mimeHeader.getName(), mimeHeader.getValue()); + } + attachmentList.add(att); + + if (removeAttachments) { + iterator.remove(); + } + } + } + + /** + * Activation framework {@code DataSource} that wraps a Spring {@code InputStreamSource}. + * + * @author Arjen Poutsma + * @since 1.0.0 + */ + private static class InputStreamDataSource implements DataSource { + + private final InputStream inputStream; + + private final String contentType; + + public InputStreamDataSource(InputStream inputStream, String contentType) { + this.inputStream = inputStream; + this.contentType = contentType; + } + + @Override + public InputStream getInputStream () throws IOException { + return inputStream; + } + + @Override + public OutputStream getOutputStream () { + throw new UnsupportedOperationException("Read-only javax.activation.DataSource"); + } + + @Override + public String getContentType () { + return contentType; + } + + @Override + public String getName () { + throw new UnsupportedOperationException("DataSource name not available"); + } + + } + +} \ No newline at end of file diff --git a/spring-ws-security/src/test/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandlerTest.java b/spring-ws-security/src/test/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandlerTest.java new file mode 100644 index 000000000..f129ffbdb --- /dev/null +++ b/spring-ws-security/src/test/java/org/springframework/ws/soap/security/wss4j2/callback/SAAJAttachmentCallbackHandlerTest.java @@ -0,0 +1,135 @@ +package org.springframework.ws.soap.security.wss4j2.callback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; + +import org.apache.wss4j.common.ext.Attachment; +import org.apache.wss4j.common.ext.AttachmentRequestCallback; +import org.apache.wss4j.common.ext.AttachmentResultCallback; +import org.junit.Test; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.ws.soap.SoapMessage; +import org.springframework.ws.soap.saaj.SaajSoapMessageFactory; + +public class SAAJAttachmentCallbackHandlerTest { + + private static final String ATTACHMENTS = "Attachments"; + private static final String CONTENT_ID_TEST = "123456"; + private static final String CONTENT_TYPE_TEST = "application/xml"; + private static final String SAMPLE_XML_CONTENT = ""; + + @Test + public void requestCallbackNoDeleteTest () { + SaajSoapMessageFactory soapMessageFactory = createSAAJMessageFactory(); + SoapMessage soapMessage = soapMessageFactory.createWebServiceMessage(); + soapMessage.setSoapAction(null); + soapMessage.addAttachment(CONTENT_ID_TEST, new ByteArrayResource(SAMPLE_XML_CONTENT.getBytes()), CONTENT_TYPE_TEST); + + SAAJAttachmentCallbackHandler attachmentCallbackHandler = new SAAJAttachmentCallbackHandler(soapMessage); + AttachmentRequestCallback attachmentRequestCallback = new AttachmentRequestCallback(); + + attachmentRequestCallback.setRemoveAttachments(false); + attachmentRequestCallback.setAttachmentId(ATTACHMENTS); + Callback[] callbacks = new Callback[1]; + callbacks[0] = attachmentRequestCallback; + + try { + attachmentCallbackHandler.handle(callbacks); + } catch (IOException | UnsupportedCallbackException e) { + fail(e.getMessage()); + } + assertEquals("Callback must have an attachment element", 1, attachmentRequestCallback.getAttachments() + .size()); + assertEquals("Callback attachment must have the same content-ID", CONTENT_ID_TEST, attachmentRequestCallback.getAttachments() + .get(0) + .getId()); + + assertTrue("SoapMessage must conserve its attachment", soapMessage.getAttachments() + .hasNext()); + + } + + @Test + public void requestCallbackDeleteTest () { + SaajSoapMessageFactory soapMessageFactory = createSAAJMessageFactory(); + SoapMessage soapMessage = soapMessageFactory.createWebServiceMessage(); + soapMessage.setSoapAction(null); + soapMessage.addAttachment(CONTENT_ID_TEST, new ByteArrayResource(SAMPLE_XML_CONTENT.getBytes()), CONTENT_TYPE_TEST); + + SAAJAttachmentCallbackHandler attachmentCallbackHandler = new SAAJAttachmentCallbackHandler(soapMessage); + AttachmentRequestCallback attachmentRequestCallback = new AttachmentRequestCallback(); + + attachmentRequestCallback.setRemoveAttachments(true); + attachmentRequestCallback.setAttachmentId(CONTENT_ID_TEST); + Callback[] callbacks = new Callback[1]; + callbacks[0] = attachmentRequestCallback; + + try { + attachmentCallbackHandler.handle(callbacks); + } catch (IOException | UnsupportedCallbackException e) { + fail(e.getMessage()); + } + assertEquals("Callback must have an attachment element", 1, attachmentRequestCallback.getAttachments() + .size()); + assertEquals("Callback attachment must have the same content-ID", CONTENT_ID_TEST, attachmentRequestCallback.getAttachments() + .get(0) + .getId()); + + assertTrue("SoapMessage must not conserve its attachments", !soapMessage.getAttachments() + .hasNext()); + } + + @Test + public void responseCallbackTest () { + SaajSoapMessageFactory soapMessageFactory = createSAAJMessageFactory(); + SoapMessage soapMessage = soapMessageFactory.createWebServiceMessage(); + soapMessage.setSoapAction(null); + + SAAJAttachmentCallbackHandler attachmentCallbackHandler = new SAAJAttachmentCallbackHandler(soapMessage); + AttachmentResultCallback attachmentResultCallback = new AttachmentResultCallback(); + + Attachment attachment = new Attachment(); + attachment.setId(CONTENT_ID_TEST); + attachment.setMimeType(CONTENT_TYPE_TEST); + attachment.setSourceStream(new ByteArrayInputStream(SAMPLE_XML_CONTENT.getBytes())); + + attachmentResultCallback.setAttachment(attachment); + attachmentResultCallback.setAttachmentId(CONTENT_ID_TEST); + Callback[] callbacks = new Callback[1]; + callbacks[0] = attachmentResultCallback; + + assertTrue("SoapMessage must not have attachments at start", !soapMessage.getAttachments() + .hasNext()); + + try { + attachmentCallbackHandler.handle(callbacks); + } catch (IOException | UnsupportedCallbackException e) { + fail(e.getMessage()); + } + + assertTrue("SoapMessage must have attachments after result handle", soapMessage.getAttachments() + .hasNext()); + + org.springframework.ws.mime.Attachment resultAttachment = soapMessage.getAttachments() + .next(); + + assertEquals("SoapMessage attachment must have the same content-ID", CONTENT_ID_TEST, resultAttachment.getContentId()); + assertEquals("SoapMessage attachment must have the same content-type", CONTENT_TYPE_TEST, resultAttachment.getContentType()); + + } + + private SaajSoapMessageFactory createSAAJMessageFactory () { + SaajSoapMessageFactory soapMessageFactory = new SaajSoapMessageFactory(); + soapMessageFactory.setSoapVersion(org.springframework.ws.soap.SoapVersion.SOAP_12); + soapMessageFactory.afterPropertiesSet(); + return soapMessageFactory; + } + +}