nsfw_detect.py (2258B)
#!/usr/bin/env python3
"""Score images/videos with the Caffe ``resnet_50_1by2_nsfw`` model.

Run as a script, every command-line argument is treated as a file path and
printed together with the score returned by :meth:`NSFWDetector.detect`.
"""

import numpy as np
import os
import sys
from io import BytesIO
from subprocess import run, PIPE, DEVNULL, SubprocessError

# Kept ahead of ``import caffe`` on purpose -- presumably glog reads this
# variable when the library is loaded, so setting it later would be too late.
os.environ["GLOG_minloglevel"] = "2"  # seriously :|
import caffe


class NSFWDetector:
    """Wrap the Caffe network and its input preprocessing."""

    def __init__(self):
        """Load the model stored in ``nsfw_model/`` next to this file."""
        npath = os.path.join(os.path.dirname(__file__), "nsfw_model")
        self.nsfw_net = caffe.Net(
            os.path.join(npath, "deploy.prototxt"),
            os.path.join(npath, "resnet_50_1by2_nsfw.caffemodel"),
            caffe.TEST,
        )

        self.caffe_transformer = caffe.io.Transformer(
            {'data': self.nsfw_net.blobs['data'].data.shape}
        )
        # move image channels to outermost
        self.caffe_transformer.set_transpose('data', (2, 0, 1))
        # subtract the dataset-mean value in each channel
        self.caffe_transformer.set_mean('data', np.array([104, 117, 123]))
        # rescale from [0, 1] to [0, 255]
        self.caffe_transformer.set_raw_scale('data', 255)
        # swap channels from RGB to BGR
        self.caffe_transformer.set_channel_swap('data', (2, 1, 0))

    def _compute(self, img):
        """Run the network on encoded image bytes.

        Args:
            img: Encoded image data (bytes) readable by
                ``caffe.io.load_image``.

        Returns:
            The first row of the ``prob`` output blob, converted to a float
            array.
        """
        image = caffe.io.load_image(BytesIO(img))

        # Center-crop to the network's input height/width. When the image is
        # smaller than the input along an axis the offset clamps to 0 and the
        # slice simply returns what is available.
        H, W, _ = image.shape
        _, _, h, w = self.nsfw_net.blobs["data"].data.shape
        h_off = max((H - h) // 2, 0)
        w_off = max((W - w) // 2, 0)
        crop = image[h_off:h_off + h, w_off:w_off + w, :]

        transformed_image = self.caffe_transformer.preprocess('data', crop)
        # Prepend a batch axis of length 1.
        batch = transformed_image[np.newaxis, ...]

        input_name = self.nsfw_net.inputs[0]
        output_layers = ["prob"]
        all_outputs = self.nsfw_net.forward_all(blobs=output_layers,
                                                **{input_name: batch})

        return all_outputs[output_layers[0]][0].astype(float)

    def detect(self, fpath):
        """Score the file at ``fpath``.

        A thumbnail is produced by the external ``ffmpegthumbnailer`` program
        (captured from its stdout) and fed to the network.

        Args:
            fpath: Path of the file to inspect.

        Returns:
            Element 1 of the network's ``prob`` output (presumably the NSFW
            probability -- confirm against the model), or ``-1.0`` when the
            thumbnail could not be generated.
        """
        cmd = ["ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a",
               "-cpng", "-i", fpath]
        try:
            ff = run(cmd, stdout=PIPE, stderr=DEVNULL, check=True)
        except (OSError, SubprocessError, ValueError):
            # Thumbnailer missing, not executable, exited non-zero, or the
            # path was rejected (e.g. embedded NUL). Deliberately narrower
            # than a bare ``except`` so KeyboardInterrupt/SystemExit still
            # propagate.
            return -1.0

        scores = self._compute(ff.stdout)

        return scores[1]


if __name__ == "__main__":
    n = NSFWDetector()

    for inf in sys.argv[1:]:
        score = n.detect(inf)
        print(inf, score)