From 53f9acd9224fffd3a220ae3753b53f0876c707f6 Mon Sep 17 00:00:00 2001 From: ltiongku Date: Mon, 12 Aug 2024 23:55:08 +0800 Subject: [PATCH] ml init --- pom.xml | 21 +-- .../controller/QRCodeTypeController.java | 10 +- .../com/safeqr/app/qrcode/model/URLModel.java | 2 +- .../service/URLVerificationService.java | 29 +--- .../safeqr/app/spark/model/URLFeatures.java | 148 ++++++++++-------- .../app/spark/service/MLModelService.java | 22 +++ 6 files changed, 124 insertions(+), 108 deletions(-) create mode 100644 src/main/java/com/safeqr/app/spark/service/MLModelService.java diff --git a/pom.xml b/pom.xml index ced849d..0a8cddd 100644 --- a/pom.xml +++ b/pom.xml @@ -121,26 +121,7 @@ jackson-annotations 2.17.2 - - - org.apache.spark - spark-core_2.13 - 3.4.3 - - - - org.apache.spark - spark-sql_2.13 - 3.4.3 - provided - - - - org.apache.spark - spark-mllib_2.13 - 3.4.3 - provided - + diff --git a/src/main/java/com/safeqr/app/qrcode/controller/QRCodeTypeController.java b/src/main/java/com/safeqr/app/qrcode/controller/QRCodeTypeController.java index 2397e98..aff7a4d 100644 --- a/src/main/java/com/safeqr/app/qrcode/controller/QRCodeTypeController.java +++ b/src/main/java/com/safeqr/app/qrcode/controller/QRCodeTypeController.java @@ -61,10 +61,12 @@ public class QRCodeTypeController { return ResponseEntity.ok(qrCodeTypeService.detectType(payload).block()); } - @PostMapping(API_URL_QRCODE_VERIFY_URL) - public ResponseEntity verifyURL(@RequestBody QRCodePayload payload) { - URLVerificationResponse response = urlVerificationService.verifyURL(payload); - return ResponseEntity.ok(response); + @PostMapping(value = API_URL_QRCODE_VERIFY_URL, produces = MediaType.APPLICATION_JSON_VALUE) + public ResponseEntity verifyURL(@RequestBody QRCodePayload payload, + @RequestHeader(required = false, name = HEADER_USER_ID) String userId) { + logger.info("User Id Invoking verify url endpoint: {}", userId); + return ResponseEntity.ok(qrCodeTypeService.scanQRCode(userId, payload)); + } @PostMapping(API_URL_QRCODE_VIRUS_TOTAL_CHECK) diff --git a/src/main/java/com/safeqr/app/qrcode/model/URLModel.java b/src/main/java/com/safeqr/app/qrcode/model/URLModel.java index db5984a..c5e9721 100644 --- a/src/main/java/com/safeqr/app/qrcode/model/URLModel.java +++ b/src/main/java/com/safeqr/app/qrcode/model/URLModel.java @@ -47,6 +47,6 @@ public final class URLModel extends QRCodeModel { @Override public String retrieveClassification() { - return ""; + return urlVerificationService.getClassification(this); } } diff --git a/src/main/java/com/safeqr/app/qrcode/service/URLVerificationService.java b/src/main/java/com/safeqr/app/qrcode/service/URLVerificationService.java index b4897fd..85caae4 100644 --- a/src/main/java/com/safeqr/app/qrcode/service/URLVerificationService.java +++ b/src/main/java/com/safeqr/app/qrcode/service/URLVerificationService.java @@ -1,11 +1,10 @@ package com.safeqr.app.qrcode.service; import static com.safeqr.app.constants.CommonConstants.*; - -import com.safeqr.app.qrcode.dto.request.QRCodePayload; -import com.safeqr.app.qrcode.dto.URLVerificationResponse; import com.safeqr.app.qrcode.entity.URLEntity; +import com.safeqr.app.qrcode.model.URLModel; import com.safeqr.app.qrcode.repository.URLRepository; +import com.safeqr.app.spark.service.MLModelService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -29,9 +28,11 @@ public class URLVerificationService { private static final int READ_TIMEOUT_MS = 10000; private static final Logger logger = LoggerFactory.getLogger(URLVerificationService.class); private final URLRepository urlRepository; + private final MLModelService mlModelService; @Autowired - public URLVerificationService(URLRepository urlRepository) { + public URLVerificationService(URLRepository urlRepository, MLModelService mlModelService) { this.urlRepository = urlRepository; + this.mlModelService = mlModelService; } // Regular expression pattern for shortening services @@ -425,22 +426,8 @@ public class URLVerificationService { return INFO_NON_SECURE_CONNECTION; } - public URLVerificationResponse verifyURL(QRCodePayload payload) { - URLVerificationResponse response = new URLVerificationResponse(); - try { - java.net.URL url = new java.net.URL(payload.getData()); - String protocol = url.getProtocol(); - if ("https".equalsIgnoreCase(protocol)) { - response.setSecure(true); - response.setMessage("The connection is secure."); - } else { - response.setSecure(false); - response.setMessage("The connection is not secure."); - } - } catch (Exception e) { - response.setSecure(false); - response.setMessage("Invalid URL."); - } - return response; + // Get Classification using ML Model + public String getClassification(URLModel urlModel){ + return mlModelService.predict(urlModel); } } \ No newline at end of file diff --git a/src/main/java/com/safeqr/app/spark/model/URLFeatures.java b/src/main/java/com/safeqr/app/spark/model/URLFeatures.java index 9910919..2c281d3 100644 --- a/src/main/java/com/safeqr/app/spark/model/URLFeatures.java +++ b/src/main/java/com/safeqr/app/spark/model/URLFeatures.java @@ -1,134 +1,158 @@ package com.safeqr.app.spark.model; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; +import com.safeqr.app.qrcode.entity.URLEntity; +import com.safeqr.app.qrcode.model.URLModel; +import lombok.*; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; -@Data +@Getter @Builder @NoArgsConstructor @AllArgsConstructor public class URLFeatures { - private Long domain; - private Long subdomain; - private Long topLevelDomain; - private Long query; - private Long fragment; - private Long redirect; - private Long path; - private Long redirectChain; - private Long hstsHeader; - private Long sslStripping; - private Long hostnameEmbedding; - private Long javascriptCheck; - private Long shorteningService; - private Long hasIpAddress; - private Long trackingDescriptions; - private Long urlEncoding; - private Long hasExecutable; - private Long tls; - private Long contents; + private Double domain; + private Double subdomain; + private Double topLevelDomain; + private Double query; + private Double fragment; + private Double redirect; + private Double path; + private Double redirectChain; + private Double hstsHeader; + private Double sslStripping; + private Double hostnameEmbedding; + private Double javascriptCheck; + private Double shorteningService; + private Double hasIpAddress; + private Double trackingDescriptions; + private Double urlEncoding; + private Double hasExecutable; + private Double tls; + private Double contents; private String target; // This is the label, may be null if predicting + public static URLFeatures fromEntity(URLModel urlModel) { + URLFeatures features = URLFeatures.builder() + .build(); + features.setDomain(urlModel.getDetails().getDomain()); + features.setSubdomain(urlModel.getDetails().getSubdomain()); + features.setTopLevelDomain(urlModel.getDetails().getTopLevelDomain()); + features.setQuery(urlModel.getDetails().getQuery()); + features.setFragment(urlModel.getDetails().getFragment()); + features.setPath(urlModel.getDetails().getPath()); + features.setRedirectChain(urlModel.getDetails().getRedirectChain()); + features.setHstsHeader(urlModel.getDetails().getHstsHeader()); + features.setSslStripping(urlModel.getDetails().getSslStripping()); + features.setHostnameEmbedding(urlModel.getDetails().getHostnameEmbedding()); + features.setJavascriptCheck(urlModel.getDetails().getJavascriptCheck()); + features.setShorteningService(urlModel.getDetails().getShorteningService()); + features.setHasIpAddress(urlModel.getDetails().getHasIpAddress()); + features.setTrackingDescriptions(urlModel.getDetails().getTrackingDescriptions()); + features.setUrlEncoding(urlModel.getDetails().getUrlEncoding()); + features.setHasExecutable(urlModel.getDetails().getHasExecutable()); + features.setTls(Math.toIntExact(urlModel.getData().getInfo().getId())); + features.setContents(urlModel.getData().getContents()); + + return features; + } + // Custom setter for tls (qr_code_type_id) - public void setTls(Long tls) { + public void setTls(Integer tls) { if (tls != null) { - this.tls = tls == 1 ? 0 : tls == 9 ? 1 : tls; + this.tls = tls == 1 ? 0.0 : tls == 9 ? 1.0 : tls.doubleValue(); } else { - this.tls = 0L; + this.tls = 0.0; } } // Custom setter for hostnameEmbedding and other similar columns - public void setHostnameEmbedding(Long hostnameEmbedding) { - this.hostnameEmbedding = (hostnameEmbedding != null && hostnameEmbedding != 0) ? 1L : 0L; + public void setHostnameEmbedding(Integer hostnameEmbedding) { + this.hostnameEmbedding = (hostnameEmbedding != null && hostnameEmbedding != 0) ? 1.0 : 0.0; } - public void setJavascriptCheck(Long javascriptCheck) { - this.javascriptCheck = (javascriptCheck != null && javascriptCheck != 0) ? 1L : 0L; + public void setJavascriptCheck(String javascriptCheck) { + this.javascriptCheck = (javascriptCheck != null && !javascriptCheck.isEmpty()) ? 1.0 : 0.0; } - public void setShorteningService(Long shorteningService) { - this.shorteningService = (shorteningService != null && shorteningService != 0) ? 1L : 0L; + public void setShorteningService(String shorteningService) { + this.shorteningService = (shorteningService != null && !shorteningService.isEmpty()) ? 1.0 : 0.0; } - public void setHasIpAddress(Long hasIpAddress) { - this.hasIpAddress = (hasIpAddress != null && hasIpAddress != 0) ? 1L : 0L; + public void setHasIpAddress(String hasIpAddress) { + this.hasIpAddress = (hasIpAddress != null && !hasIpAddress.isEmpty()) ? 1.0 : 0.0; } - public void setUrlEncoding(Long urlEncoding) { - this.urlEncoding = (urlEncoding != null && urlEncoding != 0) ? 1L : 0L; + public void setUrlEncoding(String urlEncoding) { + this.urlEncoding = (urlEncoding != null && !urlEncoding.isEmpty()) ? 1.0 : 0.0; } - public void setHasExecutable(Long hasExecutable) { - this.hasExecutable = (hasExecutable != null && hasExecutable != 0) ? 1L : 0L; + public void setHasExecutable(String hasExecutable) { + this.hasExecutable = (hasExecutable != null && !hasExecutable.isEmpty()) ? 1.0 : 0.0; } - public void setTrackingDescriptions(Long trackingDescriptions) { - this.trackingDescriptions = (trackingDescriptions != null && trackingDescriptions != 0) ? 1L : 0L; + public void setTrackingDescriptions(List trackingDescriptions) { + this.trackingDescriptions = (trackingDescriptions != null && !trackingDescriptions.isEmpty()) ? 1.0 : 0.0; } // Custom setter for sslStripping - public void setSslStripping(String sslStripping) { - if (sslStripping != null && "true".equalsIgnoreCase(sslStripping)) { - this.sslStripping = 1L; + public void setSslStripping(List sslStripping) { + if (sslStripping != null && !sslStripping.isEmpty() && sslStripping.get(0) != null) { + this.sslStripping = sslStripping.get(0) ? 1.0 : 0.0; } else { - this.sslStripping = 0L; + this.sslStripping = 0.0; } } // Custom setter for hstsHeader - public void setHstsHeader(String hstsHeader) { - if (hstsHeader == null || "0".equals(hstsHeader)) { - this.hstsHeader = 0L; - } else if (hstsHeader.startsWith("{") && hstsHeader.endsWith("}")) { + public void setHstsHeader(List hstsHeader) { + if (hstsHeader == null || hstsHeader.isEmpty()) { + this.hstsHeader = 0.0; + } else if (hstsHeader.get(0).startsWith("{") && hstsHeader.get(0).endsWith("}")) { Pattern pattern = Pattern.compile("\"(.*?)\""); - Matcher matcher = pattern.matcher(hstsHeader); + Matcher matcher = pattern.matcher(hstsHeader.get(0)); if (matcher.find() && matcher.group(1).toLowerCase().contains("no")) { - this.hstsHeader = 0L; + this.hstsHeader = 0.0; } else { - this.hstsHeader = 1L; + this.hstsHeader = 1.0; } } else { - this.hstsHeader = 0L; + this.hstsHeader = 1.0; } } // Custom setters for calculating string lengths public void setDomain(String domain) { - this.domain = (domain != null) ? (long) domain.length() : 0L; + this.domain = (domain != null) ? (double) domain.length() : 0.0; } public void setSubdomain(String subdomain) { - this.subdomain = (subdomain != null) ? (long) subdomain.length() : 0L; + this.subdomain = (subdomain != null) ? (double) subdomain.length() : 0.0; } public void setTopLevelDomain(String topLevelDomain) { - this.topLevelDomain = (topLevelDomain != null) ? (long) topLevelDomain.length() : 0L; + this.topLevelDomain = (topLevelDomain != null) ? (double) topLevelDomain.length() : 0.0; } public void setQuery(String query) { - this.query = (query != null) ? (long) query.length() : 0L; + this.query = (query != null) ? (double) query.length() : 0.0; } public void setFragment(String fragment) { - this.fragment = (fragment != null) ? (long) fragment.length() : 0L; + this.fragment = (fragment != null) ? (double) fragment.length() : 0.0; } public void setPath(String path) { - this.path = (path != null) ? (long) path.length() : 0L; + this.path = (path != null) ? (double) path.length() : 0.0; } - public void setRedirectChain(String redirectChain) { - this.redirectChain = (redirectChain != null) ? (long) redirectChain.length() : 0L; + public void setRedirectChain(List redirectChain) { + this.redirectChain = (redirectChain != null) ? (double) redirectChain.size() : 0.0; } public void setContents(String contents) { - this.contents = (contents != null) ? (long) contents.length() : 0L; + this.contents = (contents != null) ? (double) contents.length() : 0.0; } } diff --git a/src/main/java/com/safeqr/app/spark/service/MLModelService.java b/src/main/java/com/safeqr/app/spark/service/MLModelService.java new file mode 100644 index 0000000..25e5f65 --- /dev/null +++ b/src/main/java/com/safeqr/app/spark/service/MLModelService.java @@ -0,0 +1,22 @@ +package com.safeqr.app.spark.service; + +import com.safeqr.app.qrcode.model.URLModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + + +@Service +public class MLModelService { + private static final Logger logger = LoggerFactory.getLogger(MLModelService.class); + + public MLModelService() { + + } + + + public String predict(URLModel urlModel) { + + return "haha"; + } +}