0x0

mia's file "the null pointer" hosting application -- l0bster h0sted
Log | Files | Refs | LICENSE

nsfw_detect.py (2258B)


      1 #!/usr/bin/env python3
      2 
      3 import numpy as np
      4 import os
      5 import sys
      6 from io import BytesIO
      7 from subprocess import run, PIPE, DEVNULL
      8 
      9 os.environ["GLOG_minloglevel"] = "2" # seriously :|
     10 import caffe
     11 
     12 class NSFWDetector:
     13     def __init__(self):
     14 
     15         npath = os.path.join(os.path.dirname(__file__), "nsfw_model")
     16         self.nsfw_net = caffe.Net(os.path.join(npath, "deploy.prototxt"),
     17                                   os.path.join(npath, "resnet_50_1by2_nsfw.caffemodel"),
     18                                   caffe.TEST)
     19         self.caffe_transformer = caffe.io.Transformer({'data': self.nsfw_net.blobs['data'].data.shape})
     20         self.caffe_transformer.set_transpose('data', (2, 0, 1))  # move image channels to outermost
     21         self.caffe_transformer.set_mean('data', np.array([104, 117, 123]))  # subtract the dataset-mean value in each channel
     22         self.caffe_transformer.set_raw_scale('data', 255)  # rescale from [0, 1] to [0, 255]
     23         self.caffe_transformer.set_channel_swap('data', (2, 1, 0))  # swap channels from RGB to BGR
     24 
     25     def _compute(self, img):
     26         image = caffe.io.load_image(BytesIO(img))
     27 
     28         H, W, _ = image.shape
     29         _, _, h, w = self.nsfw_net.blobs["data"].data.shape
     30         h_off = int(max((H - h) / 2, 0))
     31         w_off = int(max((W - w) / 2, 0))
     32         crop = image[h_off:h_off + h, w_off:w_off + w, :]
     33 
     34         transformed_image = self.caffe_transformer.preprocess('data', crop)
     35         transformed_image.shape = (1,) + transformed_image.shape
     36 
     37         input_name = self.nsfw_net.inputs[0]
     38         output_layers = ["prob"]
     39         all_outputs = self.nsfw_net.forward_all(blobs=output_layers,
     40                     **{input_name: transformed_image})
     41 
     42         outputs = all_outputs[output_layers[0]][0].astype(float)
     43 
     44         return outputs
     45 
     46     def detect(self, fpath):
     47         try:
     48             ff = run(["ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a", "-cpng", "-i", fpath], stdout=PIPE, stderr=DEVNULL, check=True)
     49             image_data = ff.stdout
     50         except:
     51             return -1.0
     52 
     53         scores = self._compute(image_data)
     54 
     55         return scores[1]
     56 
     57 if __name__ == "__main__":
     58     n = NSFWDetector()
     59 
     60     for inf in sys.argv[1:]:
     61         score = n.detect(inf)
     62         print(inf, score)