#!/usr/bin/env python
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import bisect
import collections
import gzip
import json
import numbers
import os
import sys

# The 'symbols' package is looked up in ../third_party/symbols relative to this
# file. That directory has to be on sys.path before the import below runs,
# which is why the import cannot live with the other imports at the top.
_SYMBOLS_PATH = os.path.abspath(os.path.join(
    os.path.dirname(__file__),
    '..',
    'third_party',
    'symbols'))
sys.path.append(_SYMBOLS_PATH)
# pylint: disable=import-error
import symbols.elf_symbolizer as elf_symbolizer


# Relevant trace event phases from Chromium's
# src/base/trace_event/common/trace_event_common.h.
# Phase of metadata events; CollectProcesses() reads the 'process_name' and
# 'stackFrames' events that carry this phase.
TRACE_EVENT_PHASE_METADATA = 'M'
# Phase of memory dump events; CollectProcesses() reads 'process_mmaps' from
# the 'dumps' argument of events that carry this phase.
TRACE_EVENT_PHASE_MEMORY_DUMP = 'v'


class ProcessMemoryMaps(object):
  """Represents 'process_mmaps' trace file entry.

  Keeps the memory regions of a process sorted by start address so that the
  region containing a given address can be found with a binary search.
  """

  class Region(object):
    """A mapped memory region covering [start_address, end_address).

    Regions are ordered and compared by start address only. They can be
    compared with other regions as well as with plain integer addresses, which
    is what lets FindRegion() bisect a list of regions with an address.
    """

    def __init__(self, start_address, size, file_name):
      self._start_address = start_address
      self._size = size
      self._file_name = file_name

    @property
    def start_address(self):
      return self._start_address

    @property
    def end_address(self):
      return self._start_address + self._size

    @property
    def size(self):
      return self._size

    @property
    def file_name(self):
      return self._file_name

    def _AddressOf(self, other):
      """Returns the address that |other| is compared by.

      Raises:
        Exception: |other| is neither a Region nor an integer.
      """
      if isinstance(other, type(self)):
        return other._start_address
      if isinstance(other, numbers.Integral):
        return other
      raise Exception('Cannot compare with %s' % type(other))

    # Rich comparisons are spelled out (instead of __cmp__) because they work
    # on both Python 2 and Python 3; __cmp__ is ignored by Python 3.
    def __eq__(self, other):
      return self._start_address == self._AddressOf(other)

    def __ne__(self, other):
      return self._start_address != self._AddressOf(other)

    def __lt__(self, other):
      return self._start_address < self._AddressOf(other)

    def __le__(self, other):
      return self._start_address <= self._AddressOf(other)

    def __gt__(self, other):
      return self._start_address > self._AddressOf(other)

    def __ge__(self, other):
      return self._start_address >= self._AddressOf(other)

    def __hash__(self):
      # Keep hashing consistent with __eq__, which looks at the start address.
      return hash(self._start_address)

    def __repr__(self):
      return 'Region(0x{:X} - 0x{:X}, {})'.format(
          self.start_address, self.end_address, self.file_name)

  def __init__(self, process_mmaps):
    """Parses 'process_mmaps' dictionary.

    Args:
      process_mmaps: Dictionary with a 'vm_regions' list. Each entry provides
          'sa' (start address) and 'sz' (size) as hexadecimal strings and
          'mf' (file name of the mapping).

    Raises:
      AssertionError: Two regions with different start addresses overlap.
    """

    # int() yields arbitrarily large integers on both Python 2 and Python 3,
    # so 64-bit addresses are parsed correctly.
    regions = [
        self.Region(int(region_value['sa'], 16),
                    int(region_value['sz'], 16),
                    region_value['mf'])
        for region_value in process_mmaps['vm_regions']]
    regions.sort(key=lambda region: region.start_address)

    # Copy regions without duplicates and check for overlaps. Regions count as
    # duplicates when they start at the same address; only the first is kept.
    self._regions = []
    for region in regions:
      if self._regions:
        previous_region = self._regions[-1]
        if region.start_address == previous_region.start_address:
          continue
        assert region.start_address >= previous_region.end_address, \
            'Regions {} and {} overlap.'.format(previous_region, region)
      self._regions.append(region)

  @property
  def regions(self):
    return self._regions

  def FindRegion(self, address):
    """Finds region containing |address|. Returns None if none found."""

    # bisect_right() gives the position right after the last region whose
    # start address is <= |address|; that region is the only candidate.
    region_index = bisect.bisect_right(self._regions, address) - 1
    if region_index < 0:
      return None
    region = self._regions[region_index]
    if region.start_address <= address < region.end_address:
      return region
    return None


class StackFrames(object):
  """Represents 'stackFrames' trace file entry.

  A frame whose 'name' starts with 'pc:' holds a program counter as a
  hexadecimal number instead of a symbol name; those are the frames this
  class collects and later rewrites.
  """

  # Prefix of a frame name that holds an unsymbolized program counter.
  _PC_TAG = 'pc:'

  def __init__(self, stack_frames):
    """Constructs object using 'stackFrames' dictionary.

    Saves reference to the dictionary for later modification by SymbolizePCs().
    """
    self._stack_frames = stack_frames

  def CollectPCs(self):
    """Parses stack frames and returns list of PCs (program counters).

    The list has one entry per PC frame, so a PC that appears in several
    frames is listed several times.
    """
    return [pc for _, pc in self._IteratePCs()]

  def SymbolizePCs(self, pc_symbol_map):
    """Symbolizes PCs using the provided pc->symbol mapping.

    Function translates each PC using |pc_symbol_map| and updates stack_frames
    dictionary this object was created with.

    Returns:
      True if at least one frame name was replaced, False otherwise.
    """
    symbolized = False
    for frame, pc in self._IteratePCs():
      if pc in pc_symbol_map:
        frame['name'] = pc_symbol_map[pc]
        symbolized = True
    return symbolized

  def _ParsePC(self, name):
    """Returns the PC encoded in frame |name|, or None if it isn't a PC."""
    if not name.startswith(self._PC_TAG):
      return None
    # int() handles arbitrarily large values on both Python 2 and Python 3.
    return int(name[len(self._PC_TAG):], 16)

  def _IteratePCs(self):
    """Yields (frame, pc) for every frame whose name encodes a PC."""
    # values() instead of itervalues(): the latter doesn't exist on Python 3.
    for frame in self._stack_frames.values():
      pc = self._ParsePC(frame['name'])
      if pc is not None:
        yield (frame, pc)


class Symbolizer(object):
  """Convenience wrapper for ELFSymbolizer.

  Submits PCs to an ELFSymbolizer created for one binary and collects the
  results: names reported by the callback are stored in |pc_symbol_map|, PCs
  for which the callback got no name are stored in |failed_pcs|.
  """

  def __init__(self, binary_path, addr2line_path):
    self._pc_symbol_map = {}
    self._failed_pcs = set()
    # PCs that were submitted and whose callback hasn't run yet.
    self._queued_pcs = set()
    self._elf_symbolizer = elf_symbolizer.ELFSymbolizer(
        binary_path,
        addr2line_path,
        self._SymbolizerCallback)

  @property
  def pc_symbol_map(self):
    return self._pc_symbol_map

  @property
  def failed_pcs(self):
    return self._failed_pcs

  def SymbolizeAsync(self, pc, region):
    """Submits |pc|, which lies inside |region|, for symbolization.

    Does nothing if |pc| is already queued or already has a result, so each PC
    is submitted at most once no matter how often it is requested.
    """
    # Checking the results too matters: once the callback for a PC has run,
    # the PC is no longer in |_queued_pcs| and would be submitted again.
    if (pc in self._queued_pcs or
        pc in self._pc_symbol_map or
        pc in self._failed_pcs):
      return

    # Mark the PC as queued before submitting it, so that a callback that runs
    # before the call below returns still finds it in |_queued_pcs|.
    self._queued_pcs.add(pc)

    # SymbolizeAsync() asserts that the type of address is int. We operate
    # on longs (since they are raw pointers possibly from 64-bit processes).
    # It's OK to cast here because we're passing relative address, which
    # should always fit into int.
    self._elf_symbolizer.SymbolizeAsync(int(pc - region.start_address), pc)

  def Join(self):
    """Calls Join() on the wrapped ELFSymbolizer."""
    self._elf_symbolizer.Join()

  def _SymbolizerCallback(self, sym_info, pc):
    """Records the result for |pc| that ELFSymbolizer reports."""
    self._queued_pcs.discard(pc)
    if sym_info.name:
      self._pc_symbol_map[pc] = sym_info.name
    else:
      self._failed_pcs.add(pc)


class Process(object):
  """Holds various bits of information about a process in a trace file.

  Only |pid| is known when the object is created. |name|, |mmaps| and
  |stack_frames| start out as None and are assigned later by whoever parses
  the trace.
  """

  def __init__(self, pid):
    # Nothing but the pid is known yet.
    self.name = self.mmaps = self.stack_frames = None
    self.pid = pid


def CollectProcesses(trace):
  """Parses trace dictionary and returns the processes suitable for
     symbolization (which have both mmaps and stack_frames).

  Args:
    trace: Trace dictionary with a 'traceEvents' list. Events without a
        'name' are ignored; all other events must have 'pid' and 'ph'.

  Returns:
    List of Process objects that have both |mmaps| and |stack_frames| set.
  """

  process_map = {}
  # pid -> raw 'process_mmaps' dictionary of the last memory dump seen for
  # that process. Only the last dump is used, so parsing is deferred until all
  # events have been read instead of parsing every dump.
  last_process_mmaps = {}

  for event in trace['traceEvents']:
    name = event.get('name')
    if not name:
      continue

    pid = event['pid']
    process = process_map.get(pid)
    if process is None:
      process = Process(pid)
      process_map[pid] = process

    phase = event['ph']
    if phase == TRACE_EVENT_PHASE_METADATA:
      if name == 'process_name':
        process.name = event['args']['name']
      elif name == 'stackFrames':
        process.stack_frames = StackFrames(event['args']['stackFrames'])
    elif phase == TRACE_EVENT_PHASE_MEMORY_DUMP:
      process_mmaps = event['args']['dumps'].get('process_mmaps')
      if process_mmaps:
        last_process_mmaps[pid] = process_mmaps

  for pid, process_mmaps in last_process_mmaps.items():
    process = process_map[pid]
    # Processes without stack frames are dropped below, so there is no point
    # in parsing their memory maps.
    if process.stack_frames:
      process.mmaps = ProcessMemoryMaps(process_mmaps)

  return [p for p in process_map.values() if p.mmaps and p.stack_frames]


def SymbolizeProcess(process, addr2line_path):
  """Symbolizes all PCs in the given process.

  PCs are grouped by the file name of the memory region that contains them,
  and each file is symbolized with its own Symbolizer. Progress and problems
  are printed to stdout.

  Args:
    process: Process whose |stack_frames| and |mmaps| are both set.
    addr2line_path: Path to the addr2line binary handed to Symbolizer.

  Returns:
    What |process.stack_frames.SymbolizePCs()| returns for the collected
    symbols, or False if the process has no PCs at all.
  """

  pcs = process.stack_frames.CollectPCs()
  if not pcs:
    return False

  # print is called with a single parenthesized argument throughout, which
  # behaves the same as a Python 2 print statement and as a Python 3 call.
  print('Symbolizing {} ({})...'.format(process.name, process.pid))

  def _SubPrintf(message, *args):
    print(('  ' + message).format(*args))

  # Group PCs by the file that backs the region they fall into.
  unresolved_pcs_count = 0
  pending_pc_regions = collections.defaultdict(list)
  for pc in pcs:
    region = process.mmaps.FindRegion(pc)
    if region is None:
      unresolved_pcs_count += 1
      continue
    pending_pc_regions[region.file_name].append((pc, region))

  if unresolved_pcs_count != 0:
    _SubPrintf('{} PCs were not resolved.', unresolved_pcs_count)

  pc_symbol_map = {}
  for file_name, pc_regions in pending_pc_regions.items():
    # Only regions backed by an existing file given by absolute path can be
    # symbolized.
    problem = None
    if not os.path.isabs(file_name):
      problem = 'not a file'
    elif not os.path.isfile(file_name):
      problem = "file doesn't exist"
    if problem:
      _SubPrintf('Not symbolizing {} PCs for "{}": {}.',
                 len(pc_regions), file_name, problem)
      continue

    symbolizer = Symbolizer(file_name, addr2line_path)
    for pc, region in pc_regions:
      symbolizer.SymbolizeAsync(pc, region)
    symbolizer.Join()
    _SubPrintf('{}: {} PCs symbolized ({} failed)',
               file_name,
               len(symbolizer.pc_symbol_map),
               len(symbolizer.failed_pcs))
    pc_symbol_map.update(symbolizer.pc_symbol_map)

  return process.stack_frames.SymbolizePCs(pc_symbol_map)


def FindInSystemPath(binary_name):
  """Looks for |binary_name| in the directories of the PATH variable.

  Args:
    binary_name: File name to look for in each PATH directory.

  Returns:
    Path of the first regular file named |binary_name| found in a PATH
    directory, or None if there is no such file or PATH is not set.
  """
  # PATH may be missing from the environment altogether; report that the same
  # way as "not found" instead of failing with KeyError.
  search_path = os.environ.get('PATH')
  if search_path is None:
    return None

  for directory in search_path.split(os.pathsep):
    binary_path = os.path.join(directory, binary_name)
    if os.path.isfile(binary_path):
      return binary_path
  return None


# Suffix used for backup files. main() appends it to the trace file path to
# name the copy of the original trace it keeps before rewriting the trace
# (unless --no-backup is given).
BACKUP_FILE_TAG = '.BACKUP'

def main():
  """Symbolizes the trace file named on the command line, in place.

  Reads the trace, symbolizes every suitable process and, if anything was
  symbolized, writes the trace back to the same path. Unless --no-backup is
  given, the original file is first renamed to <path> + BACKUP_FILE_TAG.
  Exits with an error message when addr2line can't be found in PATH.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('file',
                      help='Trace file to symbolize (.json or .json.gz)')
  # The default is the boolean True (not the string 'true') so that
  # options.backup is always a bool, matching what store_false sets.
  parser.add_argument('--no-backup',
                      dest='backup', default=True, action='store_false',
                      help="Don't create {} files".format(BACKUP_FILE_TAG))
  options = parser.parse_args()

  trace_file_path = options.file
  def _OpenTraceFile(mode):
    # |mode| is 'r' or 'w'; gzipped traces are detected by the .gz suffix.
    if trace_file_path.endswith('.gz'):
      return gzip.open(trace_file_path, mode + 'b')
    else:
      return open(trace_file_path, mode + 't')

  addr2line_path = FindInSystemPath('addr2line')
  if addr2line_path is None:
    sys.exit("Can't symbolize - no addr2line in PATH.")

  print('Reading trace file...')
  with _OpenTraceFile('r') as trace_file:
    trace = json.load(trace_file)

  processes = CollectProcesses(trace)

  # Every process has to be visited, so don't short-circuit on first success.
  update_trace = False
  for process in processes:
    if SymbolizeProcess(process, addr2line_path):
      update_trace = True

  if update_trace:
    if options.backup:
      backup_file_path = trace_file_path + BACKUP_FILE_TAG
      print('Backing up trace file to {}...'.format(backup_file_path))
      os.rename(trace_file_path, backup_file_path)

    print('Updating trace file...')
    with _OpenTraceFile('w') as trace_file:
      json.dump(trace, trace_file)
  else:
    print('No PCs symbolized - not updating trace file.')


# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
  main()
